Permalink
Browse files

ENH: extensions/autoreload: move methods out of the Plugin class, and…

… rewrite some code to be cleaner
  • Loading branch information...
1 parent 0e4733c commit ca85fb01db782f22ea594e936cb8b54095e9ca63 @pv committed Sep 18, 2011
Showing with 83 additions and 48 deletions.
  1. +83 −48 IPython/extensions/autoreload.py
@@ -30,6 +30,9 @@ def _get_compiled_ext():
PY_COMPILED_EXT = _get_compiled_ext()
class ModuleReloader(object):
+ enabled = False
+ """Whether this reloader is enabled"""
+
failed = {}
"""Modules that failed to reload: {module: mtime-on-failed-reload, ...}"""
@@ -45,9 +48,46 @@ class ModuleReloader(object):
old_objects = {}
"""(module-name, name) -> weakref, for replacing old code objects"""
+ def mark_module_skipped(self, module_name):
+ """Skip reloading the named module in the future"""
+ try:
+ del self.modules[module_name]
+ except KeyError:
+ pass
+ self.skip_modules[module_name] = True
+
+ def mark_module_reloadable(self, module_name):
+ """Reload the named module in the future (if it is imported)"""
+ try:
+ del self.skip_modules[module_name]
+ except KeyError:
+ pass
+ self.modules[module_name] = True
+
+ def aimport_module(self, module_name):
+ """Import a module, and mark it reloadable
+
+ Returns
+ -------
+ top_module : module
+ The imported module if it is top-level, or the top-level
+ top_name : module
+ Name of top_module
+
+ """
+ self.mark_module_reloadable(module_name)
+
+ __import__(module_name)
+ top_name = module_name.split('.')[0]
+ top_module = sys.modules[top_name]
+ return top_module, top_name
+
def check(self, check_all=False):
"""Check whether some modules need to be reloaded."""
+ if not self.enabled and not check_all:
+ return
+
if check_all or self.check_all:
modules = sys.modules.keys()
else:
@@ -67,33 +107,36 @@ def check(self, check_all=False):
continue
filename = m.__file__
- dirname = os.path.dirname(filename)
path, ext = os.path.splitext(filename)
if ext.lower() == '.py':
ext = PY_COMPILED_EXT
- filename = os.path.join(dirname, path + PY_COMPILED_EXT)
+ pyc_filename = path + PY_COMPILED_EXT
+ py_filename = filename
+ else:
+ pyc_filename = filename
+ py_filename = filename[:-1]
if ext != PY_COMPILED_EXT:
continue
try:
- pymtime = os.stat(filename[:-1]).st_mtime
- if pymtime <= os.stat(filename).st_mtime:
+ pymtime = os.stat(py_filename).st_mtime
+ if pymtime <= os.stat(pyc_filename).st_mtime:
continue
- if self.failed.get(filename[:-1], None) == pymtime:
+ if self.failed.get(py_filename, None) == pymtime:
continue
except OSError:
continue
try:
superreload(m, reload, self.old_objects)
- if filename[:-1] in self.failed:
- del self.failed[filename[:-1]]
+ if py_filename in self.failed:
+ del self.failed[py_filename]
except:
print >> sys.stderr, "[autoreload of %s failed: %s]" % (
modname, traceback.format_exc(1))
- self.failed[filename[:-1]] = pymtime
+ self.failed[py_filename] = pymtime
#------------------------------------------------------------------------------
# superreload
@@ -226,26 +269,12 @@ def superreload(module, reload=reload, old_objects={}):
from IPython.core.plugin import Plugin
from IPython.core.hooks import TryNext
-class Autoreload(Plugin):
- def __init__(self, shell=None, config=None):
- super(Autoreload, self).__init__(shell=shell, config=config)
-
- self.shell.define_magic('autoreload', self.magic_autoreload)
- self.shell.define_magic('aimport', self.magic_aimport)
- self.shell.set_hook('pre_run_code_hook', self.pre_run_code_hook)
-
- self._enabled = False
+class AutoreloadInterface(object):
+ def __init__(self, *a, **kw):
+ super(AutoreloadInterface, self).__init__(*a, **kw)
self._reloader = ModuleReloader()
self._reloader.check_all = False
- def pre_run_code_hook(self, ipself):
- if not self._enabled:
- raise TryNext
- try:
- self._reloader.check()
- except:
- pass
-
def magic_autoreload(self, ipself, parameter_s=''):
r"""%autoreload => Reload modules automatically
@@ -293,15 +322,15 @@ def magic_autoreload(self, ipself, parameter_s=''):
if parameter_s == '':
self._reloader.check(True)
elif parameter_s == '0':
- self._enabled = False
+ self._reloader.enabled = False
elif parameter_s == '1':
self._reloader.check_all = False
- self._enabled = True
+ self._reloader.enabled = True
elif parameter_s == '2':
self._reloader.check_all = True
- self._enabled = True
+ self._reloader.enabled = True
- def magic_aimport(self, ipself, parameter_s=''):
+ def magic_aimport(self, ipself, parameter_s='', stream=None):
"""%aimport => Import modules for automatic reloading.
%aimport
@@ -321,38 +350,44 @@ def magic_aimport(self, ipself, parameter_s=''):
to_reload.sort()
to_skip = self._reloader.skip_modules.keys()
to_skip.sort()
+ if stream is None:
+ stream = sys.stdout
if self._reloader.check_all:
- print "Modules to reload:\nall-expect-skipped"
+ stream.write("Modules to reload:\nall-except-skipped\n")
else:
- print "Modules to reload:\n%s" % ' '.join(to_reload)
- print "\nModules to skip:\n%s" % ' '.join(to_skip)
+ stream.write("Modules to reload:\n%s\n" % ' '.join(to_reload))
+ stream.write("\nModules to skip:\n%s\n" % ' '.join(to_skip))
elif modname.startswith('-'):
modname = modname[1:]
- try:
- del self._reloader.modules[modname]
- except KeyError:
- pass
- self._reloader.skip_modules[modname] = True
+ self._reloader.mark_module_skipped(modname)
else:
- try:
- del self._reloader.skip_modules[modname]
- except KeyError:
- pass
- self._reloader.modules[modname] = True
+ top_module, top_name = self._reloader.aimport_module(modname)
- # Inject module to user namespace; handle also submodules properly
- __import__(modname)
- basename = modname.split('.')[0]
- mod = sys.modules[basename]
- ipself.push({basename: mod})
+ # Inject module to user namespace
+ ipself.push({top_name: top_module})
+ def pre_run_code_hook(self, ipself):
+ if not self._reloader.enabled:
+ raise TryNext
+ try:
+ self._reloader.check()
+ except:
+ pass
+
+class AutoreloadPlugin(AutoreloadInterface, Plugin):
+ def __init__(self, shell=None, config=None):
+ super(AutoreloadPlugin, self).__init__(shell=shell, config=config)
+
+ self.shell.define_magic('autoreload', self.magic_autoreload)
+ self.shell.define_magic('aimport', self.magic_aimport)
+ self.shell.set_hook('pre_run_code_hook', self.pre_run_code_hook)
_loaded = False
def load_ipython_extension(ip):
"""Load the extension in IPython."""
global _loaded
if not _loaded:
- plugin = Autoreload(shell=ip, config=ip.config)
+ plugin = AutoreloadPlugin(shell=ip, config=ip.config)
ip.plugin_manager.register_plugin('autoreload', plugin)
_loaded = True

0 comments on commit ca85fb0

Please sign in to comment.