Skip to content

Commit

Permalink
Fix load_module wrapping (#3948)
Browse files Browse the repository at this point in the history
For old type (g3) module loaders, the resolution wasn't working, breaking internal par tests. 

This makes `DeprecatedModuleLoader.load_module` work similarly to `exec_module` that it ensures that when the old module is loaded, both the new and old names are pointing to the same module in the module cache.

cc @cduck - this should be straightforward. The reason it didn't come up earlier, as I haven't tested on g3 since my change in #3917 (comment) - I think that more principled rewiring surfaced this bug :)
  • Loading branch information
balopat committed Mar 23, 2021
1 parent c27dcff commit 51dc096
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 12 deletions.
31 changes: 26 additions & 5 deletions cirq/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Workarounds for compatibility issues between versions and libraries."""
import functools
import importlib
from importlib.machinery import ModuleSpec
import os
import re
import sys
Expand Down Expand Up @@ -325,16 +324,38 @@ def __init__(self, loader: Any, old_module_name: str, new_module_name: str):
# in older environments this line makes them work as well
if hasattr(loader, 'load_module'):
# mypy#2427
self.load_module = loader.load_module # type: ignore
self.load_module = self._wrap_load_module(loader.load_module) # type: ignore
if hasattr(loader, 'create_module'):
# mypy#2427
self.create_module = loader.create_module # type: ignore
self.old_module_name = old_module_name
self.new_module_name = new_module_name

def create_module(self, spec: ModuleSpec) -> ModuleType:
return self.loader.create_module(spec)

def module_repr(self, module: ModuleType) -> str:
return self.loader.module_repr(module)

def _wrap_load_module(self, method: Any) -> Any:
def load_module(fullname: str) -> ModuleType:
assert fullname == self.old_module_name, (
f"DeprecatedModuleLoader for {self.old_module_name} was asked to "
f"load {fullname}"
)
if self.new_module_name in sys.modules:
sys.modules[self.old_module_name] = sys.modules[self.new_module_name]
return sys.modules[self.old_module_name]
method(self.new_module_name)
# https://docs.python.org/3.5/library/importlib.html#importlib.abc.Loader.load_module
assert self.new_module_name in sys.modules, (
f"Wrapped loader {self.loader} was "
f"expected to insert "
f"{self.new_module_name} in sys.modules "
f"but it did not."
)
sys.modules[self.old_module_name] = sys.modules[self.new_module_name]
return sys.modules[self.old_module_name]

return load_module

def _wrap_exec_module(self, method: Any) -> Any:
def exec_module(module: ModuleType) -> None:
assert module.__name__ == self.old_module_name, (
Expand Down
65 changes: 58 additions & 7 deletions cirq/_compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import types
import warnings
from types import ModuleType
from typing import Callable
from typing import Callable, Optional
from importlib.machinery import ModuleSpec

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -679,19 +680,69 @@ def exec_module(self, module: ModuleType) -> None:
assert 'new' not in sys.modules


def test_loader_wrappers():
def test_loader_create_module():
class EmptyLoader(importlib.abc.Loader):
pass

dml = DeprecatedModuleLoader(EmptyLoader(), 'old', 'new')
# the default implementation is from the abstract class, which is just pass
assert dml.create_module('test') is None

fake_mod = ModuleType('hello')

class CreateModuleLoader(importlib.abc.Loader):
def create_module(self, spec: ModuleSpec) -> Optional[ModuleType]:
return fake_mod

assert (
DeprecatedModuleLoader(CreateModuleLoader(), 'old', 'new').create_module(None) == fake_mod
)


def test_deprecated_module_loader_load_module_wrapper():
hello_module = types.ModuleType('hello')

class StubLoader(importlib.abc.Loader):
def module_repr(self, module: ModuleType) -> str:
return 'hello'

def load_module(self, fullname: str) -> ModuleType:
# we simulate loader behavior - it is assumed that loaders will set the
# module cache with the loaded module
sys.modules[fullname] = hello_module
return hello_module

with pytest.raises(AssertionError, match="for old was asked to load something_else"):
DeprecatedModuleLoader(StubLoader(), 'old', 'new').load_module('something_else')

# new module already loaded
sys.modules['new_hello'] = hello_module
assert (
DeprecatedModuleLoader(StubLoader(), 'old_hello', 'new_hello').load_module('old_hello')
== hello_module
)
assert 'old_hello' in sys.modules and sys.modules['old_hello'] == sys.modules['new_hello']
del sys.modules['new_hello']
del sys.modules['old_hello']

# new module is not loaded
assert (
DeprecatedModuleLoader(StubLoader(), 'old_hello', 'new_hello').load_module('old_hello')
== hello_module
)
assert 'new_hello' in sys.modules
assert 'old_hello' in sys.modules and sys.modules['old_hello'] == sys.modules['new_hello']
del sys.modules['new_hello']
del sys.modules['old_hello']


def test_deprecated_module_loader_repr():
class StubLoader(importlib.abc.Loader):
def module_repr(self, module: ModuleType) -> str:
return 'hello'

module = types.ModuleType('old')
assert DeprecatedModuleLoader(StubLoader(), 'old', 'new').module_repr(module) == 'hello'
assert DeprecatedModuleLoader(StubLoader(), 'old', 'new').load_module('test') == hello_module
assert (
DeprecatedModuleLoader(StubLoader(), 'old_hello', 'new_hello').module_repr(module)
== 'hello'
)


def test_invalidate_caches():
Expand Down

0 comments on commit 51dc096

Please sign in to comment.