In [1]:
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
from time import time
from torchvision import datasets, transforms
from torch import nn, optim
from tqdm.notebook import trange, tqdm
from torch.autograd import Variable

In [2]:
class scramble(object):
    def __init__(self):
        self.seed = np.random.randint(10 ** 8)

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
        Returns:
            np.array: Image scrambled and converted to numpy array.
        """
        np.random.seed(self.seed)
        pic = np.array(pic).flatten()
        np.random.shuffle(pic)
        pic = (pic.reshape(28, 28) / 255).astype(float)
        return pic
        
    def __repr__(self):
        return self.__class__.__name__ + '()'

In [3]:
s1 = scramble()
s2 = scramble()


trainset1 = datasets.MNIST('./files', train=True, transform=transforms.Compose([
    s1,
    transforms.ToTensor()
]))
trainset2 = datasets.MNIST('./files', train=True, transform=transforms.Compose([
    s2,
    transforms.ToTensor()
]))
testset1 = datasets.MNIST('./files', train=False, transform=transforms.Compose([
    s1,
    transforms.ToTensor()
]))
testset2 = datasets.MNIST('./files', train=False, transform=transforms.Compose([
    s2,
    transforms.ToTensor()
]))

In [4]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 400),
            nn.ReLU(),
            nn.Linear(400, 400),
            nn.ReLU(),
            nn.Linear(400, 10)
        )
        
    def forward(self, x):
        return self.model(x.float())
        


In [5]:
train_loader1 = torch.utils.data.DataLoader(
    trainset1,
    batch_size=1,
    num_workers=2,
    drop_last=False,
    shuffle=True)

test_loader1 = torch.utils.data.DataLoader(
    testset1,
    batch_size=1,
    num_workers=2,
    drop_last=False)


train_loader2 = torch.utils.data.DataLoader(
    trainset2,
    batch_size=1,
    num_workers=2,
    drop_last=False,
    shuffle=True)

test_loader2 = torch.utils.data.DataLoader(
    testset2,
    batch_size=1,
    num_workers=2,
    drop_last=False)

In [6]:
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
net = Net()
net.cuda()

optimizer = torch.optim.SGD(net.parameters(), lr=10 ** -3)
criterion = nn.CrossEntropyLoss()

trainer = create_supervised_trainer(net, optimizer, criterion, device='cuda:0')

evaluator = create_supervised_evaluator(net, device='cuda:0')

@trainer.on(Events.ITERATION_COMPLETED(every=10000))
def log_training_loss(trainer):
    print(f"Epoch[{trainer.state.epoch}] Loss: {trainer.state.output:.2f}")


@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    net.eval()
    tot = 0
    correct = 0
    with torch.no_grad():
        for item in test_loader1:
            ims = item[0].cuda()
            labs = item[1].cuda()
            preds = net(ims)
            preds = torch.sigmoid(preds).cpu().detach().numpy()
            right = preds.argmax(axis=1) == labs.cpu().detach().numpy()
            tot += len(right)
            correct += sum(right)

    print("Epoch {} Validation Accuracy: {}%".format(trainer.state.epoch, round(100 * correct / tot, 3)))
    net.train()
trainer.run(train_loader1, max_epochs=10)


Epoch[1] Loss: 0.18
Epoch[1] Loss: 0.15
Epoch[1] Loss: 0.26
Epoch[1] Loss: 0.29
Epoch[1] Loss: 0.11
Epoch[1] Loss: 0.16
Epoch 1 Validation Accuracy: 92.73%
Epoch[2] Loss: 0.04


Engine run is terminating due to exception: 


KeyboardInterrupt: 

In [10]:
net2 = Net()
net2.load_state_dict(net.state_dict())

<All keys matched successfully>

In [150]:
def get_fisher(net, tset):
    net.eval()
    lf = nn.CrossEntropyLoss()
    sums = [np.zeros(tuple(param.shape)) for param in net.parameters()]
    for pic, lab in tqdm(tset):
        out = net(pic.cuda())
        loss = lf(out, lab.cuda())
        net.zero_grad()
        loss.backward()
        
        for i, param in enumerate(net.parameters()):
            sums[i] += param.grad.cpu().detach().numpy() ** 2
    n = len(tset)
    net.train()
    return sums

In [151]:
fisher = get_fisher(net, train_loader1)

  0%|          | 0/60000 [00:00<?, ?it/s]

In [154]:
torch.zeros((5, 5))

tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])

In [149]:
fisher

array([2.98416710e-04, 2.25864234e-04, 9.90082760e-05, ...,
       9.62385681e-03, 1.42198027e-02, 1.40491480e-02])

In [13]:
class EWC(nn.Module):
    def __init__(self, fisher, pweights, lam):
        super(EWC, self).__init__()
        self.fisher = torch.tensor(fisher)
        self.pweights = torch.tensor(np.concatenate([param.cpu().detach().numpy().flatten() for param in pweights]))
        self.n = len(self.fisher)
        self.lam = lam
        self.lf = nn.CrossEntropyLoss()
    
    def forward(self, outputs, labels, params):
        weights = torch.tensor(np.concatenate([param.cpu().detach().numpy().flatten() for param in params]))
        return self.lf(outputs, labels) + (self.lam / (2)) * (self.fisher * (weights - self.pweights) ** 2).sum()

In [120]:
(param1 - param2).norm(2)

tensor(0.0011, device='cuda:0', grad_fn=<NormBackward1>)

In [148]:
Variable( torch.FloatTensor(1), requires_grad=True).to('cuda:0')

tensor([1.3001e-18], device='cuda:0', grad_fn=<CopyBackwards>)

In [111]:
reg = Variable( torch.FloatTensor(1), requires_grad=True).to('cuda:0')
for param1, param2 in zip(net.parameters(), net2.parameters()):
    print(param1.shape, param2.shape)
    reg = reg + (param1 - param2).norm(2)
    print(reg)

torch.Size([400, 784]) torch.Size([400, 784])
tensor([-3046216.], device='cuda:0', grad_fn=<AddBackward0>)
torch.Size([400]) torch.Size([400])
tensor([-3046216.], device='cuda:0', grad_fn=<AddBackward0>)
torch.Size([400, 400]) torch.Size([400, 400])
tensor([-3046216.], device='cuda:0', grad_fn=<AddBackward0>)
torch.Size([400]) torch.Size([400])
tensor([-3046216.], device='cuda:0', grad_fn=<AddBackward0>)
torch.Size([10, 400]) torch.Size([10, 400])
tensor([-3046216.], device='cuda:0', grad_fn=<AddBackward0>)
torch.Size([10]) torch.Size([10])
tensor([-3046216.], device='cuda:0', grad_fn=<AddBackward0>)


In [110]:
reg

tensor([0.0231], device='cuda:0', grad_fn=<AddBackward0>)

In [14]:
import pdb
criterion = EWC(fisher, net.parameters(), 0.5)
optimizer = torch.optim.SGD(net2.parameters(), lr=10 ** -3)
net2.cuda()
for epoch in range(10):  # loop over the dataset multiple times

    for i, data in enumerate(train_loader2, 0):
        # get the inputs; data is a list of [inputs, labels]
        rl = 0.0
        inputs, labels = data
        inputs = inputs.cuda()
        labels = labels.cuda()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net2(inputs)
        pdb.set_trace()
        loss = criterion(outputs, labels, net2.parameters())
        rl += loss.item()
        loss.backward()
        optimizer.step()
        if i % 1000 == 999:
            print("Loss: {}".format(rl / 1000))
            rl = 0.0

        # print statistics

        if i % 20000 == 19999:    # print every 2000 mini-batches
            net2.eval()
            tot = 0
            correct = 0
            with torch.no_grad():
                for item in test_loader1:
                    ims = item[0].cuda()
                    labs = item[1].cuda()
                    preds = net2(ims)
                    preds = torch.sigmoid(preds).cpu().detach().numpy()
                    right = preds.argmax(axis=1) == labs.cpu().detach().numpy()
                    tot += len(right)
                    correct += sum(right)

            print("Epoch {} Test Accuracy Task 1: {}%".format(epoch, round(100 * correct / tot, 3)))
            tot = 0
            correct = 0
            with torch.no_grad():
                for item in test_loader2:
                    ims = item[0].cuda()
                    labs = item[1].cuda()
                    preds = net2(ims)
                    preds = torch.sigmoid(preds).cpu().detach().numpy()
                    right = preds.argmax(axis=1) == labs.cpu().detach().numpy()
                    tot += len(right)
                    correct += sum(right)

            print("Epoch {} Test Accuracy Task 2: {}%".format(epoch, round(100 * correct / tot, 3)))
            net2.train()
print('Finished Training')


> [0;32m<ipython-input-14-95c39fd25402>[0m(20)[0;36m<module>[0;34m()[0m
[0;32m     18 [0;31m        [0moutputs[0m [0;34m=[0m [0mnet2[0m[0;34m([0m[0minputs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     19 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 20 [0;31m        [0mloss[0m [0;34m=[0m [0mcriterion[0m[0;34m([0m[0moutputs[0m[0;34m,[0m [0mlabels[0m[0;34m,[0m [0mnet2[0m[0;34m.[0m[0mparameters[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     21 [0;31m        [0mrl[0m [0;34m+=[0m [0mloss[0m[0;34m.[0m[0mitem[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     22 [0;31m        [0mloss[0m[0;34m.[0m[0mbackward[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> outputs
tensor([[ 0.3418,  0.0343, -0.3148,  0.5553, -1.2397,  1.9939, -0.6209,  0.1416,
         -0.2269, -0.6056]], device='cuda:

--KeyboardInterrupt--

KeyboardInterrupt: Interrupted by user


FileNotFoundError: [Errno 2] No such file or directory

In [66]:
v = np.random.rand(10 ** 7)
start = time.time()
for i in range(10 ** 2):
    v ** 2
print(time.time() - start)

1.483384370803833


In [20]:
net.eval()
tot = 0
correct = 0
with torch.no_grad():
    for item in test_loader2:
        ims = item[0].cuda()
        labs = item[1].cuda()
        preds = net(ims)
        preds = torch.sigmoid(preds).cpu().detach().numpy()
        right = preds.argmax(axis=1) == labs.cpu().detach().numpy()
        tot += len(right)
        correct += sum(right)

print("Epoch {} Validation Accuracy: {}%".format(trainer.state.epoch, round(100 * correct / tot, 3)))
net.train()

Epoch 10 Validation Accuracy: 98.16%


Net(
  (model): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=784, out_features=400, bias=True)
    (2): ReLU()
    (3): Linear(in_features=400, out_features=400, bias=True)
    (4): ReLU()
    (5): Linear(in_features=400, out_features=10, bias=True)
  )
)

In [23]:
for name, param in net.named_parameters():
    print(name)
    print(param)

model.1.weight
Parameter containing:
tensor([[-0.0010, -0.0215, -0.0250,  ...,  0.0065,  0.0283, -0.0226],
        [ 0.0252, -0.0094,  0.0043,  ...,  0.0127,  0.0122, -0.0153],
        [ 0.0194, -0.0278, -0.0274,  ..., -0.0106,  0.0075,  0.0269],
        ...,
        [ 0.0841, -0.0356,  0.0574,  ...,  0.0791, -0.0258, -0.0337],
        [-0.0259, -0.0220, -0.0053,  ..., -0.0801, -0.0152,  0.0464],
        [-0.0285,  0.0353, -0.0504,  ..., -0.0264,  0.0110,  0.0104]],
       device='cuda:0', requires_grad=True)
model.1.bias
Parameter containing:
tensor([ 0.0709, -0.0061, -0.0501,  0.0325,  0.0203, -0.0372,  0.0115, -0.0025,
        -0.0258,  0.0369,  0.0506, -0.0083,  0.0359,  0.0308,  0.0575,  0.0993,
        -0.0386, -0.0272,  0.0063,  0.0296, -0.0769, -0.0245,  0.0014, -0.0311,
        -0.0185,  0.0223,  0.0750,  0.0502, -0.0095,  0.0383,  0.0711,  0.0243,
         0.0408,  0.1451, -0.0033,  0.0863,  0.0722,  0.0396, -0.0185,  0.0249,
        -0.0129, -0.0681,  0.0310,  0.0209,  0.073