In [1]:
!pip install torchvision



In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import autograd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

In [17]:
def find_accuracy(model, dataloader):
    model = model.eval()
    accuracy = 0
    sum = 0 
    for input, target in dataloader:
        output = model(input.to(device))
        sum += (output.argmax(dim=1).long() == target.to(device)).float().mean()
    accuracy = float(sum) / len(dataloader)
    return accuracy

class LinearLayer(nn.Module):
    def __init__(self, input_dim, output_dim, batch_norm=True):
        super(LinearLayer, self).__init__()
        self.batch_norm = batch_norm
        self.linear = nn.Linear(input_dim, output_dim)
        self.activation = nn.ReLU()
        self.batchNorm = None
        
        if self.batch_norm:
            self.batchNorm = nn.BatchNorm1d(output_dim)
            
    def forward(self, x):
        if self.batch_norm:
            return self.batchNorm(self.activation(self.linear(x)))
            
        return self.activation(self.linear(x))

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.shape[0], -1)
    
class Model(nn.Module):
    def __init__(self, num_inputs, num_hidden, num_outputs):
        super(Model, self).__init__()
        self.f1 = Flatten()
        self.linearLayer1 = LinearLayer(num_inputs, num_hidden)
        self.linearLayer2 = LinearLayer(num_hidden, num_hidden)
        self.linearLayer3 = nn.Linear(num_hidden, num_outputs)
        
    def forward(self, x):
        return self.linearLayer3(self.linearLayer2(self.linearLayer1(self.f1(x))))

In [19]:

def ewc_loss(model, weight, estimated_fishers, estimated_means):
    losses = []
    for param_name, param in model.named_parameters():
        estimated_mean = estimated_means[param_name]
        estimated_fisher = estimated_fishers[param_name]
        losses.append((estimated_fisher * (param - estimated_mean) ** 2).sum())
    return (weight / 2) * sum(losses)

def estimate_ewc_params(model, train_ds, batch_size=100, num_batch=300, estimate_type='true'):
    estimated_mean = {}

    for param_name, param in model.named_parameters():
        estimated_mean[param_name] = param.data.clone()
        
    estimated_fisher = {}
    dl = DataLoader(train_ds, batch_size, shuffle=True)
    
    for n, p in model.named_parameters():
        estimated_fisher[n] = torch.zeros_like(p)
        
    model.eval()
    for i, (input, target) in enumerate(dl):
        if i > num_batch:
            break
        model.zero_grad()

        output = model(input.to(device))
        label = target.to(device)
        loss = F.nll_loss(F.log_softmax(output, dim=1), label)
        loss.backward()
        
        for n, p in model.named_parameters():
            estimated_fisher[n].data += p.grad.data ** 2 / len(dl)

    estimated_fisher = {n: p for n, p in estimated_fisher.items()}
    return estimated_mean, estimated_fisher

In [21]:
# Load MNIST dataset, representint task A
mnist_train = datasets.MNIST("../data", train=True, download=True, transform=transforms.ToTensor())
mnist_test = datasets.MNIST("../data", train=False, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(mnist_train, batch_size = 100, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size = 100, shuffle=False)

In [25]:
EPOCHS = 10
lr=0.001
weight=100000 
accuracies = {}

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

criterion = nn.CrossEntropyLoss()

model = Model(num_inputs = 28 * 28, num_hidden = 100, num_outputs = 10).to(device)
optimizer = optim.Adam(model.parameters(), lr)

for _ in range(EPOCHS):
    for input, target in tqdm(train_loader):
        output = model(input.to(device))
        loss = criterion(output, target.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
accuracies['mnist_initial'] = find_accuracy(model, test_loader)

100%|████████████████████████████████████████| 600/600 [00:01<00:00, 398.08it/s]
100%|████████████████████████████████████████| 600/600 [00:01<00:00, 414.11it/s]
100%|████████████████████████████████████████| 600/600 [00:01<00:00, 409.79it/s]
100%|████████████████████████████████████████| 600/600 [00:01<00:00, 412.06it/s]
100%|████████████████████████████████████████| 600/600 [00:01<00:00, 408.63it/s]
100%|████████████████████████████████████████| 600/600 [00:01<00:00, 412.86it/s]
100%|████████████████████████████████████████| 600/600 [00:01<00:00, 411.17it/s]
100%|████████████████████████████████████████| 600/600 [00:01<00:00, 409.71it/s]
100%|████████████████████████████████████████| 600/600 [00:01<00:00, 411.01it/s]
100%|████████████████████████████████████████| 600/600 [00:01<00:00, 408.58it/s]


In [27]:
f_mnist_train = datasets.FashionMNIST("../data", train=True, download=True, transform=transforms.ToTensor())
f_mnist_test = datasets.FashionMNIST("../data", train=False, download=True, transform=transforms.ToTensor())
f_train_loader = DataLoader(f_mnist_train, batch_size = 100, shuffle=True)
f_test_loader = DataLoader(f_mnist_test, batch_size = 100, shuffle=False)

In [32]:

estimated_mean, estimated_fisher = estimate_ewc_params(model, mnist_train)

for _ in range(EPOCHS):
    for input, target in tqdm(f_train_loader):
        output = model(input.to(device))
        loss = ewc_loss(model, weight, estimated_fisher, estimated_mean) + criterion(output, target.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
accuracies['mnist_EWC'] = find_accuracy(model, test_loader)
accuracies['f_mnist_EWC'] = find_accuracy(model, f_test_loader)

100%|████████████████████████████████████████| 600/600 [00:01<00:00, 372.26it/s]
100%|████████████████████████████████████████| 600/600 [00:01<00:00, 371.95it/s]
100%|████████████████████████████████████████| 600/600 [00:01<00:00, 372.20it/s]
100%|████████████████████████████████████████| 600/600 [00:01<00:00, 373.79it/s]
100%|████████████████████████████████████████| 600/600 [00:01<00:00, 373.83it/s]
100%|████████████████████████████████████████| 600/600 [00:01<00:00, 365.60it/s]
100%|████████████████████████████████████████| 600/600 [00:01<00:00, 373.42it/s]
100%|████████████████████████████████████████| 600/600 [00:01<00:00, 374.73it/s]
100%|████████████████████████████████████████| 600/600 [00:01<00:00, 374.61it/s]
100%|████████████████████████████████████████| 600/600 [00:01<00:00, 371.95it/s]


In [34]:
accuracies 

{'mnist_initial': 0.9805001831054687,
 'mnist_EWC': 0.9651000213623047,
 'f_mnist_EWC': 0.8382001495361329}