In [1]:
import copy
import torch
import torch.nn as nn
import torch.optim as optim
from collections import OrderedDict

In [2]:
class Model(nn.Module):
    def __init__(self, num_param=2):
        super().__init__()
        self.linear = nn.Linear(num_param, 1, bias=False)
#         self.relu = nn.ReLU()
    def forward(self, x):
        x = self.linear(x)
#         x = self.relu(x)
        return x

class ComplexModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 2, 3, bias=False)
        self.bn = nn.BatchNorm2d(2)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

In [3]:
def clone_module(module, memo=None):

    if memo is None:
        memo = {}

    # Create a copy of the module.
    # https://github.com/pytorch/pytorch/blob/65bad41cbec096aa767b3752843eddebf845726f/torch/nn/modules/module.py#L1171
    if not isinstance(module, torch.nn.Module):
        return module
    clone = module.__new__(type(module))
    clone.__dict__ = module.__dict__.copy()
    clone._parameters = clone._parameters.copy()
    clone._buffers = clone._buffers.copy()
    clone._modules = clone._modules.copy()

    # Re-write all parameters
    if hasattr(clone, '_parameters'):
        for param_key in clone._parameters:
            if module._parameters[param_key] is not None:
                param = module._parameters[param_key]
                param_ptr = param.data_ptr
                if param_ptr in memo:
                    clone._parameters[param_key] = memo[param_ptr]
                else:
                    cloned = param.clone()
                    clone._parameters[param_key] = cloned
                    memo[param_ptr] = cloned

    # Then, recurse for each submodule
    if hasattr(clone, '_modules'):
        for module_key in clone._modules:
            clone._modules[module_key] = clone_module(
                module._modules[module_key],
                memo=memo,
            )

    # Finally, rebuild the flattened parameters for RNNs
    if hasattr(clone, 'flatten_parameters'):
        clone = clone._apply(lambda x: x)
    return clone

In [4]:
def update(module, memo=None):
    if memo is None:
        memo = {}

    # Update the params
    for param_key in module._parameters:
        p = module._parameters[param_key]
        if p is not None and hasattr(p, 'update') and p.update is not None:
            if p in memo:
                module._parameters[param_key] = memo[p]
            else:
                updated = p + p.update
                memo[p] = updated
                module._parameters[param_key] = updated

    # Then, recurse for each submodule
    for module_key in module._modules:
        module._modules[module_key] = update(
            module._modules[module_key],
            memo=memo,
        )

    # Rebuild the flattened parameters for RNNs
    if hasattr(module, 'flatten_parameters'):
        module._apply(lambda x: x)
    return module

In [5]:
def backprop(model, out, lr=0.1, debug=False):
    ok_params = (p for p in model.parameters() if p.requires_grad)
    grad = torch.autograd.grad(out, ok_params, create_graph=True)
    
    ok_params = (p for p in model.parameters() if p.requires_grad)
    for p, g in zip(ok_params, grad):
        if g is not None:
            if debug:
                print(f"Gradient : {g}")
            p.update = -lr * g
    return update(model)

In [6]:
# copy_state_dict = {}
# for p1, p2 in zip(m1.parameters(), m2.parameters()):
#     p2[:] = p1.clone()
#     p2._copy(p1.clone())

# Test 1: Simple Linear Model

In [6]:
m1 = Model()

In [7]:
print(m1.linear.weight)

Parameter containing:
tensor([[0.1400, 0.6083]], requires_grad=True)


In [8]:
outer_optim = optim.SGD(m1.parameters(), lr=0.1)

In [9]:
outer_optim.zero_grad()
for t in range(3):
    total = 0.0
    count = 0
    m2 = clone_module(m1)
    
    x = torch.ones(2) * (t+1)
    out = m2(x)
    
    backprop(m2, out)
    
    print(f"Model orig: {m1.linear.weight}")
    print(f"Model clone: {m2.linear.weight}")
    
    x2 = torch.tensor([0.0, (t+1)**2])
    out2 = m2(x2)
    out2.backward()
    print(m1.linear.weight.grad)
#     print(torch.autograd.grad(out2, m1.parameters()))
    
    total += out2
    count += 1
    print()

Model orig: Parameter containing:
tensor([[0.1400, 0.6083]], requires_grad=True)
Model clone: tensor([[0.0400, 0.5083]], grad_fn=<AddBackward0>)
tensor([[0., 1.]])

Model orig: Parameter containing:
tensor([[0.1400, 0.6083]], requires_grad=True)
Model clone: tensor([[-0.0600,  0.4083]], grad_fn=<AddBackward0>)
tensor([[0., 5.]])

Model orig: Parameter containing:
tensor([[0.1400, 0.6083]], requires_grad=True)
Model clone: tensor([[-0.1600,  0.3083]], grad_fn=<AddBackward0>)
tensor([[ 0., 14.]])



In [10]:
total /= count
# outer_optim.zero_grad()
# total.backward()
print(m1.linear.weight.grad)
outer_optim.step()

tensor([[ 0., 14.]])


In [11]:
print(m1.linear.weight)
print(m2.linear.weight)

Parameter containing:
tensor([[ 0.1400, -0.7917]], requires_grad=True)
tensor([[-0.1600,  0.3083]], grad_fn=<AddBackward0>)


# Test 2: Conv Model

In [13]:
m1 = ComplexModel()

In [14]:
print(m1.conv.weight)

Parameter containing:
tensor([[[[-0.2575, -0.1216, -0.2419],
          [ 0.3088, -0.2674,  0.1279],
          [ 0.0775, -0.1218,  0.0791]]],


        [[[ 0.2890, -0.0700,  0.1894],
          [-0.1374,  0.2413, -0.1832],
          [ 0.3039,  0.2421, -0.1715]]]], requires_grad=True)


In [15]:
outer_optim = optim.SGD(m1.parameters(), lr=0.001, momentum=0.9)

In [27]:
outer_optim.zero_grad()
for t in range(3):
    total = 0.0
    count = 0
    m2 = clone_module(m1)
    
    x = torch.randn(1, 1, 5, 5)
    out = m2(x).sum()
    gt = torch.sin(x.sum())
    loss = (gt - out) ** 2
    
#     grads = torch.autograd.grad(loss, (p for p in m2.parameters() if p.requires_grad), create_graph=True)
    backprop(m2, loss, lr=0.001)
    
    print(f"No change: {torch.allclose(m1.conv.weight, m2.conv.weight)}")
    print(f"Support loss {loss.item()}")
#     print(f"Model orig: {m1.conv.weight}")
#     print(f"Model clone: {m2.conv.weight}")
    
    x2 = torch.randn(1, 1, 5, 5)
    out2 = m2(x2).sum()
    gt2 = torch.sin(x2.sum())
    loss2 = (gt2 - out2) ** 2
    loss2.backward()
    print(f"Query loss {loss2.item()}")
    print(m1.conv.weight.grad)
#     print(torch.autograd.grad(out2, m1.parameters()))
    
    total += loss2
    count += 1
    print()

No change: True
Support loss 0.01822948455810547
Query loss 0.23389102518558502
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]])

No change: True
Support loss 0.9629865288734436
Query loss 0.3391243815422058
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]])

No change: True
Support loss 0.1363360732793808
Query loss 0.6685811877250671
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]])



In [28]:
print(total)
total /= count
# outer_optim.zero_grad()
# total.backward()
print(m1.conv.weight.grad)
outer_optim.step()

tensor(0.6686, grad_fn=<AddBackward0>)
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]])


In [18]:
print(m1.conv.weight)
print(m2.conv.weight)

Parameter containing:
tensor([[[[ 0.1967, -0.3055,  0.0550],
          [ 0.1962,  0.1963,  0.1426],
          [ 0.1129,  0.2057,  0.2414]]],


        [[[ 0.2165,  0.1041, -0.0909],
          [ 0.1635, -0.0960, -0.2761],
          [ 0.3091, -0.0802,  0.0284]]]], requires_grad=True)
tensor([[[[ 0.2164, -0.2719,  0.0461],
          [ 0.1779,  0.2398,  0.1853],
          [ 0.1554,  0.1954,  0.1766]]],


        [[[ 0.1904,  0.3132, -0.1621],
          [ 0.2632, -0.1344, -0.1651],
          [ 0.2386, -0.0665,  0.1100]]]], grad_fn=<AddBackward0>)


# Test 3: SimpleNet

(Note: This requires specific models from the repo)