In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import autograd

In [2]:
class BaseModel(nn.Module):
    """Small network made of two stacked linear layers.

    The input is mapped to a 20-unit hidden representation by ``l1`` and then
    to ``num_op`` outputs by ``l2``. No activation is applied between the two
    layers, so the overall mapping is affine.

    Args:
        num_ip: Number of input features.
        num_op: Number of output units.
    """

    def __init__(self, num_ip, num_op):
        super().__init__()
        # Layers are created in this order so parameter names/initialisation
        # stay l1.* followed by l2.*.
        self.l1 = nn.Linear(num_ip, 20)
        self.l2 = nn.Linear(20, num_op)

    def forward(self, x):
        """Return ``l2(l1(x))`` for the input tensor ``x``."""
        hidden = self.l1(x)
        return self.l2(hidden)

In [3]:
import numpy as np

In [4]:
# 32 input features, 4 output units.
model = BaseModel(num_ip=32, num_op=4)

In [5]:
from torch.utils.data import Dataset, DataLoader

In [6]:
class ElasticWeightConsolidation:
    """Trainer that adds a quadratic consolidation penalty to a task loss.

    After ``register_ewc_params`` has been called, every parameter ``p`` of the
    wrapped model has two buffers registered on the model: a snapshot of its
    value (``*_estimated_mean``) and a squared-gradient importance estimate
    (``*_estimated_fisher``). Subsequent updates minimise

        crit(model(input), target) + weight / 2 * sum(fisher * (p - mean) ** 2)

    Args:
        model: The ``nn.Module`` to train.
        crit: Callable ``crit(output, target)`` returning the task loss.
        lr: Learning rate handed to the Adam optimizer.
        weight: Multiplier of the consolidation penalty.
    """

    def __init__(self, model, crit, lr=0.001, weight=0.1):
        self.model = model
        self.weight = weight
        self.crit = crit
        # Remembered so the optimizer can be rebuilt when the model is replaced.
        self.lr = lr
        self.optimizer = optim.Adam(self.model.parameters(), lr)

    @staticmethod
    def _buffer_prefix(param_name):
        """Turn a dotted parameter name into a valid buffer-name prefix."""
        # Buffer names may not contain '.', so 'l1.weight' becomes 'l1__weight'.
        return param_name.replace('.', '__')

    def _update_mean_params(self):
        """Snapshot the current value of every parameter as ``*_estimated_mean``."""
        for param_name, param in self.model.named_parameters():
            self.model.register_buffer(
                self._buffer_prefix(param_name) + '_estimated_mean',
                param.detach().clone(),
            )

    def _update_fisher_params(self, current_ds, batch_size, num_batch):
        """Estimate per-parameter importance and store it as ``*_estimated_fisher``.

        Draws at most ``num_batch`` shuffled batches from ``current_ds``, takes
        the log-probability the model assigns to each sample's own label, and
        stores the squared gradient of their mean for every parameter.

        NOTE: this squares the gradient of the *averaged* log-likelihood; it is
        not the average of per-sample squared gradients.

        Args:
            current_ds: Map-style dataset yielding ``(input, target)`` pairs,
                where ``target`` holds one integer class index per sample.
            batch_size: Batch size used while sampling.
            num_batch: Maximum number of batches to draw.

        Raises:
            ValueError: If no batch could be drawn (empty dataset or
                ``num_batch <= 0``).
        """
        dl = DataLoader(current_ds, batch_size, shuffle=True)
        log_liklihoods = []
        # zip() with the range first stops after exactly ``num_batch`` batches
        # without fetching an extra one from the loader.
        for _, (input, target) in zip(range(num_batch), dl):
            output = F.log_softmax(self.model(input), dim=1)
            # Pick, for each row, the log-probability of that row's label.
            labels = target.reshape(-1, 1).long()
            log_liklihoods.append(output.gather(1, labels).squeeze(1))
        if not log_liklihoods:
            raise ValueError(
                'Cannot estimate Fisher information: no batches were drawn '
                '(empty dataset or num_batch <= 0).'
            )
        log_likelihood = torch.cat(log_liklihoods).mean()
        named_params = list(self.model.named_parameters())
        grad_log_liklihood = autograd.grad(
            log_likelihood, [param for _, param in named_params]
        )
        for (param_name, _), grad in zip(named_params, grad_log_liklihood):
            self.model.register_buffer(
                self._buffer_prefix(param_name) + '_estimated_fisher',
                grad.detach().clone() ** 2,
            )

    def register_ewc_params(self, dataset, batch_size, num_batches):
        """Record importance estimates and a parameter snapshot for ``dataset``."""
        self._update_fisher_params(dataset, batch_size, num_batches)
        self._update_mean_params()

    def _compute_consolidation_loss(self, weight):
        """Return ``weight / 2 * sum(fisher * (param - mean) ** 2)``.

        Returns ``0`` when the consolidation buffers have not been registered
        for every parameter yet (i.e. before ``register_ewc_params``).
        """
        losses = []
        for param_name, param in self.model.named_parameters():
            prefix = self._buffer_prefix(param_name)
            estimated_mean = getattr(self.model, prefix + '_estimated_mean', None)
            estimated_fisher = getattr(self.model, prefix + '_estimated_fisher', None)
            if estimated_mean is None or estimated_fisher is None:
                # Nothing to consolidate against yet.
                return 0
            losses.append((estimated_fisher * (param - estimated_mean) ** 2).sum())
        return (weight / 2) * sum(losses)

    def forward_backward_update(self, input, target):
        """Run one optimisation step on ``crit`` plus the consolidation penalty."""
        output = self.model(input)
        loss = self._compute_consolidation_loss(self.weight) + self.crit(output, target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def save(self, filename):
        """Serialise the whole model (including consolidation buffers) to ``filename``."""
        torch.save(self.model, filename)

    def load(self, filename):
        """Replace the model with the one stored in ``filename``.

        SECURITY: this unpickles a whole module, which can execute arbitrary
        code; only load files from a trusted source. Newer PyTorch releases
        default ``torch.load`` to ``weights_only=True``, which rejects
        whole-module pickles - TODO confirm against the installed version.
        """
        self.model = torch.load(filename)
        # The previous optimizer still points at the old model's parameters;
        # rebuild it so later updates actually train the loaded model.
        self.optimizer = optim.Adam(self.model.parameters(), self.lr)

In [7]:
class RandomDataset(Dataset):
    """Synthetic dataset that draws a fresh random sample on every access.

    The requested index is ignored: each ``__getitem__`` call returns a new
    32-element float vector drawn from ``np.random.rand`` and a single integer
    label drawn from ``np.random.randint(0, 3)`` (so labels are 0, 1 or 2).
    """

    def __len__(self):
        """Report a fixed nominal size of 10000 samples."""
        return 10000

    def __getitem__(self, index):
        """Return ``(features, label)`` as a float32 tensor and an int64 tensor of shape (1,)."""
        # Draw order (features first, then label) matters for seeded runs.
        features = np.random.rand(32)
        label = np.random.randint(0, 3, 1)
        return (
            torch.tensor(features, dtype=torch.float32),
            torch.tensor(label, dtype=torch.long),
        )

In [13]:
# Synthetic dataset used both for training and for the EWC statistics below.
ds = RandomDataset()

In [14]:
# Task loss: cross-entropy over the model's raw outputs and integer class labels.
crit = nn.CrossEntropyLoss()

In [10]:
# Wrap the model and loss in the EWC trainer (default lr and penalty weight).
ewc = ElasticWeightConsolidation(model=model, crit=crit)

In [15]:
# Shuffled mini-batches of 5 samples for the training loop.
dl = DataLoader(dataset=ds, batch_size=5, shuffle=True)

In [16]:
# One pass over the data. Targets arrive as (batch, 1); flatten them to (batch,)
# for the loss. reshape(-1) keeps the batch dimension even for a size-1 batch,
# which torch.squeeze would collapse to a scalar.
for inputs, targets in dl:
    ewc.forward_backward_update(inputs, targets.reshape(-1))
# Record the parameter snapshot and importance estimates from 10 batches of 5.
ewc.register_ewc_params(ds, 5, 10)

tensor(0., grad_fn=<MulBackward0>)
tensor(3.6783e-09, grad_fn=<MulBackward0>)
tensor(1.2652e-08, grad_fn=<MulBackward0>)
tensor(2.7241e-08, grad_fn=<MulBackward0>)
tensor(3.6557e-08, grad_fn=<MulBackward0>)
tensor(4.1845e-08, grad_fn=<MulBackward0>)
tensor(5.2557e-08, grad_fn=<MulBackward0>)
tensor(6.1444e-08, grad_fn=<MulBackward0>)
tensor(5.9770e-08, grad_fn=<MulBackward0>)
tensor(5.0217e-08, grad_fn=<MulBackward0>)
tensor(4.1888e-08, grad_fn=<MulBackward0>)
tensor(4.0935e-08, grad_fn=<MulBackward0>)
tensor(3.8803e-08, grad_fn=<MulBackward0>)
tensor(3.5675e-08, grad_fn=<MulBackward0>)
tensor(3.4343e-08, grad_fn=<MulBackward0>)
tensor(3.4015e-08, grad_fn=<MulBackward0>)
tensor(3.2781e-08, grad_fn=<MulBackward0>)
tensor(3.0843e-08, grad_fn=<MulBackward0>)
tensor(2.7784e-08, grad_fn=<MulBackward0>)
tensor(3.0785e-08, grad_fn=<MulBackward0>)
tensor(4.2405e-08, grad_fn=<MulBackward0>)
tensor(5.9448e-08, grad_fn=<MulBackward0>)
tensor(8.3283e-08, grad_fn=<MulBackward0>)
tensor(1.1421e-07, 

tensor(1.8644e-06, grad_fn=<MulBackward0>)
tensor(1.8617e-06, grad_fn=<MulBackward0>)
tensor(1.8233e-06, grad_fn=<MulBackward0>)
tensor(1.7955e-06, grad_fn=<MulBackward0>)
tensor(1.7697e-06, grad_fn=<MulBackward0>)
tensor(1.7180e-06, grad_fn=<MulBackward0>)
tensor(1.7153e-06, grad_fn=<MulBackward0>)
tensor(1.7535e-06, grad_fn=<MulBackward0>)
tensor(1.7527e-06, grad_fn=<MulBackward0>)
tensor(1.7593e-06, grad_fn=<MulBackward0>)
tensor(1.6935e-06, grad_fn=<MulBackward0>)
tensor(1.6491e-06, grad_fn=<MulBackward0>)
tensor(1.5349e-06, grad_fn=<MulBackward0>)
tensor(1.4450e-06, grad_fn=<MulBackward0>)
tensor(1.3775e-06, grad_fn=<MulBackward0>)
tensor(1.3281e-06, grad_fn=<MulBackward0>)
tensor(1.2631e-06, grad_fn=<MulBackward0>)
tensor(1.2145e-06, grad_fn=<MulBackward0>)
tensor(1.1498e-06, grad_fn=<MulBackward0>)
tensor(1.0742e-06, grad_fn=<MulBackward0>)
tensor(9.6428e-07, grad_fn=<MulBackward0>)
tensor(8.6422e-07, grad_fn=<MulBackward0>)
tensor(7.9119e-07, grad_fn=<MulBackward0>)
tensor(7.41

tensor(8.6219e-07, grad_fn=<MulBackward0>)
tensor(8.5108e-07, grad_fn=<MulBackward0>)
tensor(8.3261e-07, grad_fn=<MulBackward0>)
tensor(8.0416e-07, grad_fn=<MulBackward0>)
tensor(7.8756e-07, grad_fn=<MulBackward0>)
tensor(7.8134e-07, grad_fn=<MulBackward0>)
tensor(7.8231e-07, grad_fn=<MulBackward0>)
tensor(7.8244e-07, grad_fn=<MulBackward0>)
tensor(7.8685e-07, grad_fn=<MulBackward0>)
tensor(8.0065e-07, grad_fn=<MulBackward0>)
tensor(8.1878e-07, grad_fn=<MulBackward0>)
tensor(8.2876e-07, grad_fn=<MulBackward0>)
tensor(8.4396e-07, grad_fn=<MulBackward0>)
tensor(8.6882e-07, grad_fn=<MulBackward0>)
tensor(9.1401e-07, grad_fn=<MulBackward0>)
tensor(9.5554e-07, grad_fn=<MulBackward0>)
tensor(1.0083e-06, grad_fn=<MulBackward0>)
tensor(1.0463e-06, grad_fn=<MulBackward0>)
tensor(1.0844e-06, grad_fn=<MulBackward0>)
tensor(1.1284e-06, grad_fn=<MulBackward0>)
tensor(1.1746e-06, grad_fn=<MulBackward0>)
tensor(1.1817e-06, grad_fn=<MulBackward0>)
tensor(1.2035e-06, grad_fn=<MulBackward0>)
tensor(1.23

tensor(2.7432e-06, grad_fn=<MulBackward0>)
tensor(2.7830e-06, grad_fn=<MulBackward0>)
tensor(2.7706e-06, grad_fn=<MulBackward0>)
tensor(2.7397e-06, grad_fn=<MulBackward0>)
tensor(2.7922e-06, grad_fn=<MulBackward0>)
tensor(2.7752e-06, grad_fn=<MulBackward0>)
tensor(2.7307e-06, grad_fn=<MulBackward0>)
tensor(2.7702e-06, grad_fn=<MulBackward0>)
tensor(2.7452e-06, grad_fn=<MulBackward0>)
tensor(2.6881e-06, grad_fn=<MulBackward0>)
tensor(2.6055e-06, grad_fn=<MulBackward0>)
tensor(2.5434e-06, grad_fn=<MulBackward0>)
tensor(2.4905e-06, grad_fn=<MulBackward0>)
tensor(2.3775e-06, grad_fn=<MulBackward0>)
tensor(2.2544e-06, grad_fn=<MulBackward0>)
tensor(2.1844e-06, grad_fn=<MulBackward0>)
tensor(2.1042e-06, grad_fn=<MulBackward0>)
tensor(2.0148e-06, grad_fn=<MulBackward0>)
tensor(1.8950e-06, grad_fn=<MulBackward0>)
tensor(1.7648e-06, grad_fn=<MulBackward0>)
tensor(1.6367e-06, grad_fn=<MulBackward0>)
tensor(1.5256e-06, grad_fn=<MulBackward0>)
tensor(1.4516e-06, grad_fn=<MulBackward0>)
tensor(1.42

tensor(2.6965e-06, grad_fn=<MulBackward0>)
tensor(2.6755e-06, grad_fn=<MulBackward0>)
tensor(2.6517e-06, grad_fn=<MulBackward0>)
tensor(2.6259e-06, grad_fn=<MulBackward0>)
tensor(2.6204e-06, grad_fn=<MulBackward0>)
tensor(2.6124e-06, grad_fn=<MulBackward0>)
tensor(2.5975e-06, grad_fn=<MulBackward0>)
tensor(2.5869e-06, grad_fn=<MulBackward0>)
tensor(2.5785e-06, grad_fn=<MulBackward0>)
tensor(2.5700e-06, grad_fn=<MulBackward0>)
tensor(2.5643e-06, grad_fn=<MulBackward0>)
tensor(2.5603e-06, grad_fn=<MulBackward0>)
tensor(2.5652e-06, grad_fn=<MulBackward0>)
tensor(2.5683e-06, grad_fn=<MulBackward0>)
tensor(2.5722e-06, grad_fn=<MulBackward0>)
tensor(2.5832e-06, grad_fn=<MulBackward0>)
tensor(2.6052e-06, grad_fn=<MulBackward0>)
tensor(2.6339e-06, grad_fn=<MulBackward0>)
tensor(2.6669e-06, grad_fn=<MulBackward0>)
tensor(2.6974e-06, grad_fn=<MulBackward0>)
tensor(2.7285e-06, grad_fn=<MulBackward0>)
tensor(2.7820e-06, grad_fn=<MulBackward0>)
tensor(2.8365e-06, grad_fn=<MulBackward0>)
tensor(2.85

tensor(4.0546e-06, grad_fn=<MulBackward0>)
tensor(4.0707e-06, grad_fn=<MulBackward0>)
tensor(4.0558e-06, grad_fn=<MulBackward0>)
tensor(4.0686e-06, grad_fn=<MulBackward0>)
tensor(4.0763e-06, grad_fn=<MulBackward0>)
tensor(4.0762e-06, grad_fn=<MulBackward0>)
tensor(4.0496e-06, grad_fn=<MulBackward0>)
tensor(4.0326e-06, grad_fn=<MulBackward0>)
tensor(4.0002e-06, grad_fn=<MulBackward0>)
tensor(3.9669e-06, grad_fn=<MulBackward0>)
tensor(3.9327e-06, grad_fn=<MulBackward0>)
tensor(3.9116e-06, grad_fn=<MulBackward0>)
tensor(3.8973e-06, grad_fn=<MulBackward0>)
tensor(3.8830e-06, grad_fn=<MulBackward0>)
tensor(3.8662e-06, grad_fn=<MulBackward0>)
tensor(3.8582e-06, grad_fn=<MulBackward0>)
tensor(3.8480e-06, grad_fn=<MulBackward0>)
tensor(3.8483e-06, grad_fn=<MulBackward0>)
tensor(3.8516e-06, grad_fn=<MulBackward0>)
tensor(3.8787e-06, grad_fn=<MulBackward0>)
tensor(3.9259e-06, grad_fn=<MulBackward0>)
tensor(3.9548e-06, grad_fn=<MulBackward0>)
tensor(3.9898e-06, grad_fn=<MulBackward0>)
tensor(4.00

tensor(4.5799e-06, grad_fn=<MulBackward0>)
tensor(4.5778e-06, grad_fn=<MulBackward0>)
tensor(4.5727e-06, grad_fn=<MulBackward0>)
tensor(4.5689e-06, grad_fn=<MulBackward0>)
tensor(4.5703e-06, grad_fn=<MulBackward0>)
tensor(4.5768e-06, grad_fn=<MulBackward0>)
tensor(4.5790e-06, grad_fn=<MulBackward0>)
tensor(4.5851e-06, grad_fn=<MulBackward0>)
tensor(4.5914e-06, grad_fn=<MulBackward0>)
tensor(4.5981e-06, grad_fn=<MulBackward0>)
tensor(4.6039e-06, grad_fn=<MulBackward0>)
tensor(4.6140e-06, grad_fn=<MulBackward0>)
tensor(4.6332e-06, grad_fn=<MulBackward0>)
tensor(4.6650e-06, grad_fn=<MulBackward0>)
tensor(4.7196e-06, grad_fn=<MulBackward0>)
tensor(4.7748e-06, grad_fn=<MulBackward0>)
tensor(4.8418e-06, grad_fn=<MulBackward0>)
tensor(4.8909e-06, grad_fn=<MulBackward0>)
tensor(4.9281e-06, grad_fn=<MulBackward0>)
tensor(4.9896e-06, grad_fn=<MulBackward0>)
tensor(5.0396e-06, grad_fn=<MulBackward0>)
tensor(5.0694e-06, grad_fn=<MulBackward0>)
tensor(5.0908e-06, grad_fn=<MulBackward0>)
tensor(5.11

tensor(5.9051e-06, grad_fn=<MulBackward0>)
tensor(5.9217e-06, grad_fn=<MulBackward0>)
tensor(5.9506e-06, grad_fn=<MulBackward0>)
tensor(5.9683e-06, grad_fn=<MulBackward0>)
tensor(5.9761e-06, grad_fn=<MulBackward0>)
tensor(5.9694e-06, grad_fn=<MulBackward0>)
tensor(5.9796e-06, grad_fn=<MulBackward0>)
tensor(5.9869e-06, grad_fn=<MulBackward0>)
tensor(5.9789e-06, grad_fn=<MulBackward0>)
tensor(5.9846e-06, grad_fn=<MulBackward0>)
tensor(5.9921e-06, grad_fn=<MulBackward0>)
tensor(5.9885e-06, grad_fn=<MulBackward0>)
tensor(5.9923e-06, grad_fn=<MulBackward0>)
tensor(6.0002e-06, grad_fn=<MulBackward0>)
tensor(6.0212e-06, grad_fn=<MulBackward0>)
tensor(6.0357e-06, grad_fn=<MulBackward0>)
tensor(6.0374e-06, grad_fn=<MulBackward0>)
tensor(6.0280e-06, grad_fn=<MulBackward0>)
tensor(6.0118e-06, grad_fn=<MulBackward0>)
tensor(5.9828e-06, grad_fn=<MulBackward0>)
tensor(5.9616e-06, grad_fn=<MulBackward0>)
tensor(5.9527e-06, grad_fn=<MulBackward0>)
tensor(5.9549e-06, grad_fn=<MulBackward0>)
tensor(5.97