In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from elastic_weight_consolidation import ElasticWeightConsolidation

In [0]:
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

In [0]:
mnist_train = datasets.MNIST("../data", train=True, download=True, transform=transforms.ToTensor())
mnist_test = datasets.MNIST("../data", train=False, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(mnist_train, batch_size = 100, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size = 100, shuffle=False)

In [0]:
class LinearLayer(nn.Module):
    """``nn.Linear`` followed by an activation and, optionally, batch norm.

    Computes ``bn(act(lin(x)))`` when ``use_bn`` is True, else ``act(lin(x))``.
    Note that batch norm is applied *after* the activation.

    Args:
        input_dim: number of input features of the linear layer.
        output_dim: number of output features of the linear layer.
        act: ``'relu'`` (default) for ``nn.ReLU``; ``None`` for no activation;
            or any callable (e.g. an ``nn.Module``) applied to the linear output.
        use_bn: if True, apply ``nn.BatchNorm1d(output_dim)`` after ``act``.

    Raises:
        ValueError: if ``act`` is a string other than ``'relu'``.
        TypeError: if ``act`` is neither a string, ``None`` nor callable.
    """

    def __init__(self, input_dim, output_dim, act='relu', use_bn=False):
        super(LinearLayer, self).__init__()
        self.use_bn = use_bn
        self.lin = nn.Linear(input_dim, output_dim)
        # Validate `act` here so a bad value fails at construction time
        # instead of surfacing as "'str' object is not callable" in forward().
        if isinstance(act, str):
            if act != 'relu':
                raise ValueError(
                    f"unsupported activation {act!r}; expected 'relu', None or a callable"
                )
            self.act = nn.ReLU()
        elif act is None:
            self.act = nn.Identity()
        elif callable(act):
            self.act = act
        else:
            raise TypeError(f"act must be 'relu', None or a callable, got {type(act).__name__}")
        if use_bn:
            self.bn = nn.BatchNorm1d(output_dim)

    def forward(self, x):
        out = self.act(self.lin(x))
        if self.use_bn:
            out = self.bn(out)
        return out

class Flatten(nn.Module):
    """Collapse every dimension except the leading (batch) one.

    An input of shape ``(N, d1, d2, ...)`` becomes ``(N, d1 * d2 * ...)``.
    """

    def forward(self, x):
        # reshape (unlike view) copies when needed, so non-contiguous inputs,
        # e.g. after transpose/permute, are handled instead of raising.
        return x.reshape(x.shape[0], -1)


In [0]:
class BaseModel(nn.Module):
    """Three-layer MLP classifier.

    The input is flattened, passed through two hidden ``LinearLayer`` blocks
    (ReLU + batch norm, each ``num_hidden`` wide) and a final ``nn.Linear``
    that produces ``num_outputs`` raw scores (no softmax is applied).
    """

    def __init__(self, num_inputs, num_hidden, num_outputs):
        super().__init__()
        # Submodules are registered in the same order and under the same
        # attribute names as before, so state_dict keys and init order match.
        self.f1 = Flatten()
        self.lin1 = LinearLayer(input_dim=num_inputs, output_dim=num_hidden, use_bn=True)
        self.lin2 = LinearLayer(input_dim=num_hidden, output_dim=num_hidden, use_bn=True)
        self.lin3 = nn.Linear(in_features=num_hidden, out_features=num_outputs)

    def forward(self, x):
        flat = self.f1(x)
        hidden = self.lin1(flat)
        hidden = self.lin2(hidden)
        return self.lin3(hidden)

In [0]:
# Classification loss on raw logits (BaseModel's last layer has no softmax);
# "mean" is the library default, spelled out for the reader.
crit = nn.CrossEntropyLoss(reduction="mean")

In [0]:
# MLP: 28*28 flattened pixels -> 100 hidden units -> 10 class scores,
# wrapped by the EWC trainer with the loss defined above.
ewc = ElasticWeightConsolidation(
    BaseModel(num_inputs=28 * 28, num_hidden=100, num_outputs=10),
    crit=crit,
    lr=1e-4,
)

In [0]:
from tqdm import tqdm

In [41]:
# Task 1: train on MNIST for 4 epochs.
# `images`/`labels` instead of `input`/`target`: a top-level loop variable
# named `input` would shadow the builtin input() for the rest of the session.
for _ in range(4):
    for images, labels in tqdm(train_loader):
        ewc.forward_backward_update(images, labels)

100%|██████████| 600/600 [00:07<00:00, 75.32it/s]
100%|██████████| 600/600 [00:07<00:00, 76.07it/s]
100%|██████████| 600/600 [00:08<00:00, 74.12it/s]
100%|██████████| 600/600 [00:08<00:00, 73.90it/s]


In [0]:
# After training on MNIST, register the current parameters with the EWC
# wrapper (presumably as the anchor for the penalty applied while training the
# next task; the behavior lives in the external elastic_weight_consolidation
# module; confirm there).
# NOTE(review): 100 and 300 are positional magic numbers whose meaning is
# defined by the external API (possibly batch size and number of batches used
# for the estimate); confirm and pass them as named arguments.
ewc.register_ewc_params(mnist_train, 100, 300)

In [0]:
# Task 2 data: Fashion-MNIST, prepared the same way as MNIST (shuffled
# training batches, fixed-order test batches of 100).
f_mnist_train = datasets.FashionMNIST(
    root="../data", train=True, download=True, transform=transforms.ToTensor()
)
f_mnist_test = datasets.FashionMNIST(
    root="../data", train=False, download=True, transform=transforms.ToTensor()
)
f_train_loader = DataLoader(dataset=f_mnist_train, batch_size=100, shuffle=True)
f_test_loader = DataLoader(dataset=f_mnist_test, batch_size=100, shuffle=False)

In [44]:
# Task 2: continue training the same EWC-wrapped model on Fashion-MNIST
# for 4 epochs.
# `images`/`labels` instead of `input`/`target`: a top-level loop variable
# named `input` would shadow the builtin input() for the rest of the session.
for _ in range(4):
    for images, labels in tqdm(f_train_loader):
        ewc.forward_backward_update(images, labels)

100%|██████████| 600/600 [00:09<00:00, 62.14it/s]
100%|██████████| 600/600 [00:09<00:00, 66.03it/s]
100%|██████████| 600/600 [00:09<00:00, 66.56it/s]
100%|██████████| 600/600 [00:09<00:00, 65.95it/s]


In [0]:
# Register the parameters obtained after the Fashion-MNIST task with the EWC
# wrapper (semantics defined by the external elastic_weight_consolidation
# module; confirm there). Not needed for the evaluation below, but keeps the
# wrapper ready for a possible third task.
# NOTE(review): 100 and 300 are positional magic numbers; see the external API
# for their meaning and consider naming them.
ewc.register_ewc_params(f_mnist_train, 100, 300)

In [0]:
def accu(model, dataloader):
    """Return the classification accuracy of ``model`` over ``dataloader``.

    Accuracy is (total correct predictions) / (total samples), so a smaller
    final batch is weighted by its real size instead of counting as a full
    batch.

    Args:
        model: module mapping an input batch to per-class scores; the
            predicted class is the argmax over dim 1. Its train/eval mode is
            restored before returning.
        dataloader: iterable of ``(inputs, targets)`` pairs where ``targets``
            holds integer class labels, one per sample.

    Returns:
        A 0-dim float tensor with the fraction of correctly classified samples.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    was_training = model.training
    model.eval()  # e.g. batch norm uses running statistics during evaluation
    correct = 0
    total = 0
    try:
        # Inference only: don't build an autograd graph for every batch.
        with torch.no_grad():
            for inputs, targets in dataloader:
                preds = model(inputs).argmax(dim=1)
                correct += (preds == targets).sum().item()
                total += targets.numel()
    finally:
        # Hand the model back in whatever mode the caller had it in.
        model.train(was_training)
    if total == 0:
        raise ValueError("dataloader yielded no samples")
    return torch.tensor(correct / total)

In [47]:
# Accuracy on the Fashion-MNIST test set (the task trained most recently).
accu(model=ewc.model, dataloader=f_test_loader)

tensor(0.8188)

In [48]:
# Accuracy on the original MNIST test set: how much of the first task is
# retained after the subsequent Fashion-MNIST training.
accu(model=ewc.model, dataloader=test_loader)

tensor(0.7027)