FIX ParameterList issue
JeanKossaifi committed Mar 4, 2021
1 parent b6cc797 · commit f26c79e
Showing 15 changed files with 105 additions and 25 deletions.
10 changes: 8 additions & 2 deletions doc/conf.py
@@ -39,13 +39,19 @@
    'sphinx.ext.todo',
    'sphinx.ext.viewcode',
    'sphinx.ext.githubpages',
-    "myst_nb",
    # "nbsphinx",
+    # "myst_nb",
    # "sphinx_nbexamples",
    # 'jupyter_sphinx',
    # 'matplotlib.sphinxext.plot_directive',
-    'sphinx.ext.imgmath', #'sphinx.ext.mathjax',
+    'sphinx.ext.mathjax', #'sphinx.ext.imgmath',
    'numpydoc.numpydoc',
]

# # # Sphinx-nbexamples
# process_examples = False
# example_gallery_config = dict(pattern='+/+.ipynb')

# Remove the permalinks ("¶" symbols)
html_add_permalinks = ""

2 changes: 1 addition & 1 deletion doc/index.rst
@@ -90,7 +90,7 @@ tensorized linear layers, tensor dropout and more!
   modules/api
   dev_guide/index
   about
   /tensor_regression_layers


.. only:: html

1 change: 1 addition & 0 deletions doc/install.rst
@@ -54,3 +54,4 @@ You are now ready to build the doc (here in html)::
    make html

The results will be in `_build/html`

10 changes: 5 additions & 5 deletions tltorch/_factorized_linear.py
@@ -19,7 +19,7 @@
from tensorly import (validate_tt_rank, validate_cp_rank,
                      validate_tucker_rank, validate_tt_matrix_rank)

-from .base import TensorModule
+from .base import TensorModule, ParameterList
from . import init


@@ -115,7 +115,7 @@ def __init__(self, in_features, out_features, tensorized_shape, rank, bias=True)
        self.rank = validate_tucker_rank(tensorized_shape, rank=rank)

        self.core = nn.Parameter(torch.Tensor(*self.rank))
-        self.factors = nn.ParameterList(nn.Parameter(torch.Tensor(s, r))\
+        self.factors = ParameterList(nn.Parameter(torch.Tensor(s, r))\
                                     for (s, r) in zip(tensorized_shape, self.rank))

        self.init_from_random(False)
@@ -215,7 +215,7 @@ def __init__(self, in_features, out_features, tensorized_shape, rank, bias=True)
        self.rank = validate_cp_rank(tensorized_shape, rank=rank)

        self.weights = nn.Parameter(torch.Tensor(self.rank))
-        self.factors = nn.ParameterList(nn.Parameter(torch.Tensor(s, self.rank)) for s in tensorized_shape)
+        self.factors = ParameterList(nn.Parameter(torch.Tensor(s, self.rank)) for s in tensorized_shape)

        self.init_from_random(decompose_full_weight=False)

@@ -323,7 +323,7 @@ class TTLinear(BaseFactorizedLinear):
    def __init__(self, in_features, out_features, tensorized_shape, rank, bias=True):
        super().__init__(in_features, out_features, tensorized_shape, rank, bias=bias)
        self.rank = validate_tt_rank(tensorized_shape, rank=rank)
-        self.factors = nn.ParameterList()
+        self.factors = ParameterList()
        for i, s in enumerate(self.tensorized_shape):
            self.factors.append(nn.Parameter(torch.Tensor(self.rank[i], s, self.rank[i+1])))

@@ -439,7 +439,7 @@ class TTMLinear(BaseFactorizedLinear):
    def __init__(self, in_features, out_features, tensorized_shape, rank='same', bias=True):
        super().__init__(in_features, out_features, tensorized_shape, rank, bias=bias)
        self.rank = validate_tt_matrix_rank(tensorized_shape, rank=rank)
-        self.factors = nn.ParameterList()
+        self.factors = ParameterList()
        self.ndim = len(tensorized_shape) // 2
        self.out_shape = tensorized_shape[:self.ndim]
        self.in_shape = tensorized_shape[self.ndim:]
4 changes: 3 additions & 1 deletion tltorch/_tcl.py
@@ -15,6 +15,8 @@
import tensorly as tl
tl.set_backend('pytorch')

+from .base import ParameterList


class TCL(nn.Module):
"""Tensor Contraction Layer [1]_
@@ -55,7 +57,7 @@ def __init__(self, input_shape, rank, verbose=0, bias=False, **kwargs):
        self.contraction_modes = list(range(1, self.n_input + 1))
        factors = [nn.Parameter(torch.Tensor(r, s))
                   for (s, r) in zip(self.input_shape, self.rank)]
-        self.factors = nn.ParameterList(parameters=factors)
+        self.factors = ParameterList(parameters=factors)
        if bias:
            self.bias = nn.Parameter(
                tl.tensor(self.output_shape), requires_grad=True)
5 changes: 3 additions & 2 deletions tltorch/_tensor_lasso.py
@@ -10,6 +10,7 @@
from torch.nn import functional as F

import tltorch as tltorch
+from .base import ParameterList

# Author: Jean Kossaifi
# License: BSD 3 clause
@@ -149,7 +150,7 @@ def apply(self, module):

        rank = module.rank
        context = tl.context(module.core)
-        lasso_weights = nn.ParameterList([nn.Parameter(torch.ones(r, **context)) for r in rank])
+        lasso_weights = ParameterList([nn.Parameter(torch.ones(r, **context)) for r in rank])
        setattr(module, 'lasso_weights', lasso_weights)
        handle = module.register_decomposition_forward_pre_hook(self, 'L1Regularizer')
        return module
@@ -288,7 +289,7 @@ def apply(self, module):
        TensorModule (with Regularization hook)
        """
        rank = module.rank[1:-1]
-        lasso_weights = nn.ParameterList([nn.Parameter(torch.ones(1, 1, r)) for r in rank])
+        lasso_weights = ParameterList([nn.Parameter(torch.ones(1, 1, r)) for r in rank])
        setattr(module, 'lasso_weights', lasso_weights)
        handle = module.register_decomposition_forward_pre_hook(self, 'L1Regularizer')
        return module
8 changes: 4 additions & 4 deletions tltorch/_trl.py
@@ -15,7 +15,7 @@
from tensorly.decomposition import parafac, tucker, tensor_train
from tensorly import validate_tt_rank, validate_cp_rank, validate_tucker_rank

-from .base import TensorModule
+from .base import TensorModule, ParameterList
from . import init

class BaseTRL(TensorModule):
@@ -145,7 +145,7 @@ def __init__(self, input_shape, output_shape, rank, project_input=False,
        self.project_input = project_input

        self.core = nn.Parameter(torch.Tensor(*self.rank))
-        self.factors = nn.ParameterList(nn.Parameter(torch.Tensor(s, r))\
+        self.factors = ParameterList(nn.Parameter(torch.Tensor(s, r))\
                                     for (s, r) in zip(self.weight_shape, self.rank))

        self.n_factor = len(self.factors)
@@ -289,7 +289,7 @@ def __init__(self, input_shape, output_shape, rank, bias=False, verbose=0, **kwargs):
        self.rank = validate_cp_rank(self.weight_shape, rank=rank)

        self.weights = nn.Parameter(torch.Tensor(self.rank))
-        self.factors = nn.ParameterList(nn.Parameter(torch.Tensor(s, self.rank)) for s in self.weight_shape)
+        self.factors = ParameterList(nn.Parameter(torch.Tensor(s, self.rank)) for s in self.weight_shape)

        self.init_from_random(decompose_full_weight=False)

@@ -386,7 +386,7 @@ def __init__(self, input_shape, output_shape, rank, bias=False, verbose=0, **kwargs):

        self.rank = validate_tt_rank(self.weight_shape, rank=rank)

-        self.factors = nn.ParameterList()
+        self.factors = ParameterList()
        for i, s in enumerate(self.weight_shape):
            self.factors.append(nn.Parameter(torch.Tensor(self.rank[i], s, self.rank[i+1])))

67 changes: 67 additions & 0 deletions tltorch/base.py
@@ -65,3 +65,70 @@ def _process_decomposition(self):

        return decomposition


class ParameterList(nn.Module):
    def __init__(self, parameters=None):
        super().__init__()
        self.keys = []
        self.counter = 0
        if parameters is not None:
            self.extend(parameters)

    def _unique_key(self):
        """Creates a new unique key"""
        key = f'param_{self.counter}'
        self.counter += 1
        return key

    def append(self, element):
        # p = nn.Parameter(element)
        key = self._unique_key()
        self.register_parameter(key, element)
        self.keys.append(key)

    def insert(self, index, element):
        # p = nn.Parameter(element)
        key = self._unique_key()
        self.register_parameter(key, element)
        self.keys.insert(index, key)

    def pop(self, index=-1):
        item = self[index]
        self.__delitem__(index)
        return item

    def __getitem__(self, index):
        keys = self.keys[index]
        if isinstance(keys, list):
            # return self.__class__([getattr(self, key) for key in keys])
            params = [getattr(self, key) for key in keys]
            return self.__class__(params)
        return getattr(self, keys)

Inline comments on tltorch/base.py:

merajhashemi (Member) commented on Mar 5, 2021:

    How does this solve the issue with DataParallel?

JeanKossaifi (Author, Member) replied on Mar 5, 2021:

    This ParameterList is no different from any PyTorch module and doesn't directly access _parameters. There are potential issues with this class in general, but we shouldn't face any of them, as we only use it internally to hold a list of decomposition factors.

    See the original issue at pytorch/pytorch#36035 (comment).

    def __setitem__(self, index, value):
        self.register_parameter(self.keys[index], value)

    def __delitem__(self, index):
        delattr(self, self.keys[index])
        self.keys.__delitem__(index)

    def __len__(self):
        return len(self.keys)

    def extend(self, parameters):
        for param in parameters:
            self.append(param)
        return self  # return self so that __iadd__ yields the list, not None

    def __iadd__(self, parameters):
        return self.extend(parameters)

    def extra_repr(self) -> str:
        child_lines = []
        for k, p in self._parameters.items():
            size_str = 'x'.join(str(size) for size in p.size())
            device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
            parastr = 'Parameter containing: [{} of size {}{}]'.format(
                torch.typename(p), size_str, device_str)
            child_lines.append(' (' + str(k) + '): ' + parastr)
        tmpstr = '\n'.join(child_lines)
        return tmpstr
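
For context, here is a minimal usage sketch (not part of the commit) of the new ParameterList, with the import path implied by the relative imports above. Because each factor is stored through register_parameter and retrieved with getattr, the factors behave like ordinary registered parameters, which is what sidesteps the nn.ParameterList/DataParallel problem discussed in the thread above (pytorch/pytorch#36035):

import torch
from torch import nn
from tltorch.base import ParameterList

# Hold a list of decomposition factors, as the factorized layers in this commit do
factors = ParameterList([nn.Parameter(torch.randn(5, 3)) for _ in range(3)])

assert len(factors) == 3                     # __len__
assert factors[0].shape == (5, 3)            # __getitem__ resolves via getattr
assert len(list(factors.parameters())) == 3  # registered like any module parameter

# List-like mutation
factors.append(nn.Parameter(torch.randn(4, 2)))
last = factors.pop()                          # removes and returns the last factor
factors[0] = nn.Parameter(torch.zeros(5, 3))  # re-registers under the same key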
3 changes: 2 additions & 1 deletion tltorch/factorized_conv/_cp_conv.py
@@ -6,6 +6,7 @@
# License: BSD 3 clause

from ._base_conv import Conv1D, BaseFactorizedConv
+from ..base import ParameterList
from .. import init
from tensorly import validate_cp_rank
import torch
@@ -76,7 +77,7 @@ def __init__(self, in_channels, out_channels, kernel_size, rank, order=None,
        self.rank = validate_cp_rank(self.kernel_shape, rank=self.rank)

        self.weights = nn.Parameter(torch.Tensor(self.rank))
-        self.factors = nn.ParameterList([nn.Parameter(torch.Tensor(s, self.rank)) for s in self.kernel_shape])
+        self.factors = ParameterList([nn.Parameter(torch.Tensor(s, self.rank)) for s in self.kernel_shape])

        self.init_from_random(decompose_full_weight=False)

3 changes: 2 additions & 1 deletion tltorch/factorized_conv/_tt_conv.py
@@ -7,6 +7,7 @@

from ._base_conv import Conv1D, BaseFactorizedConv
from .. import init
+from ..base import ParameterList

from tensorly import validate_tt_rank
import torch
@@ -76,7 +77,7 @@ def __init__(self, in_channels, out_channels, kernel_size, rank, order=None,
        self.tt_shape = tuple(tt_shape)
        self.rank = tl.tt_tensor.validate_tt_rank(self.tt_shape, rank)

-        self.factors = nn.ParameterList()
+        self.factors = ParameterList()
        for i, s in enumerate(self.tt_shape):
            self.factors.append(nn.Parameter(torch.Tensor(self.rank[i], s, self.rank[i+1])))

3 changes: 2 additions & 1 deletion tltorch/factorized_conv/_tucker_conv.py
@@ -6,6 +6,7 @@
# License: BSD 3 clause

from ._base_conv import Conv1D, BaseFactorizedConv
+from ..base import ParameterList
from .. import init


@@ -85,7 +86,7 @@ def __init__(self, in_channels, out_channels, kernel_size, rank, modes_fixed_rank
        self.modes_fixed_rank = modes_fixed_rank

        self.core = nn.Parameter(torch.Tensor(*self.rank))
-        self.factors = nn.ParameterList(nn.Parameter(torch.Tensor(s, r))\
+        self.factors = ParameterList(nn.Parameter(torch.Tensor(s, r))\
                                     for (s, r) in zip(self.kernel_shape, self.rank))

        self.init_from_random(decompose_full_weight=False)
2 changes: 1 addition & 1 deletion tltorch/factorized_conv/tests/test_factorized_conv.py
@@ -20,7 +20,7 @@
def single_conv_test(FactorizedConv, implementation, random_tensor_generator, reconstruction_fun,
                     order=2, rank=None, rng=None, input_channels=2, output_channels=4,
                     kernel_size=3, batch_size=1, activation_size=8, device='cpu'):
-    rng = random.check_random_state(rng)
+    rng = tl.check_random_state(rng)
    input_shape = (batch_size, input_channels) + (activation_size, )*order
    kernel_shape = (output_channels, input_channels) + (kernel_size, )*order

2 changes: 1 addition & 1 deletion tltorch/tests/test_factorized_linear.py
@@ -10,7 +10,7 @@
@pytest.mark.parametrize('FactorizedLinear', [TuckerLinear, CPLinear, TTLinear, TTMLinear])
def test_FactorizedLinear(FactorizedLinear):
    random_state = 12345
-    rng = random.check_random_state(random_state)
+    rng = tl.check_random_state(random_state)
    batch_size = 2
    in_features = 9
    out_features = 16
2 changes: 1 addition & 1 deletion tltorch/tests/test_tcl.py
@@ -8,7 +8,7 @@

def test_tcl():
    random_state = 12345
-    rng = random.check_random_state(random_state)
+    rng = tl.check_random_state(random_state)
    batch_size = 2
    in_shape = (4, 5, 6)
    out_shape = (2, 3, 5)
8 changes: 4 additions & 4 deletions tltorch/tests/test_trl.py
@@ -71,7 +71,7 @@ def test_trl(TRL, random_tensor, true_rank, rank):
    # fix the random seed for reproducibility
    random_state = 12345

-    rng = random.check_random_state(random_state)
+    rng = tl.check_random_state(random_state)
    tol = 0.08

    # Generate a random tensor
@@ -119,7 +119,7 @@ def test_TuckerTRL(order, project_input, learn_pool):

    # fix the random seed for reproducibility and create random input
    random_state = 12345
-    rng = random.check_random_state(random_state)
+    rng = tl.check_random_state(random_state)
    data = tl.tensor(rng.random_sample((batch_size, in_features) + (spatial_size, )*order))

    # Build a simple net with avg-pool, flatten + fully-connected
@@ -154,7 +154,7 @@ def net(data):
    res_trl = trl(data)

    testing.assert_array_almost_equal(res_fc, res_trl)

@pytest.mark.parametrize('TRL', [TuckerTRL, CPTRL, TensorTrainTRL])
@pytest.mark.parametrize('bias', [True, False])
def test_TRL_from_linear(TRL, bias):
@@ -168,7 +168,7 @@ def test_TRL_from_linear(TRL, bias):

    # fix the random seed for reproducibility and create random input
    random_state = 12345
-    rng = random.check_random_state(random_state)
+    rng = tl.check_random_state(random_state)
    data = tl.tensor(rng.random_sample((batch_size, in_features)))
    fc = nn.Linear(in_features, out_features, bias=bias)
    res_fc = fc(tl.copy(data))
