In [1]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'

In [2]:
import torch
import torch.nn as nn
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Subset

# Observing catastrophic forgetting on MNIST

Code adapted from:
https://gist.github.com/xmfbit/b27cdbff68870418bdb8cefa86a2d558

In [3]:
# set_a is the original training set, set_b is the additional, new classes

set_a = [1, 3, 5, 7, 9]

set_b = [0, 2, 4, 6, 8]

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

root = '/home/boto/yasiyu/data/incremental'

'cuda'

In [5]:
batch_size = 32

## Data preparation

In [6]:
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])  # mean, std dev

# if not exist, download mnist dataset
train_set = dset.MNIST(root=root, train=True, transform=trans, download=True)
test_set = dset.MNIST(root=root, train=False, transform=trans, download=True)

len(train_set)
len(test_set)

60000

10000

In [7]:
example_indices_train_a = []
example_indices_train_b = []

for i, (im, label) in enumerate(train_set):
    if label in set_a:
        example_indices_train_a.append(i)
    else:
        example_indices_train_b.append(i)
        
len(example_indices_train_a)
len(example_indices_train_b)


example_indices_test_a = []
example_indices_test_b = []

for i, (im, label) in enumerate(test_set):
    if label in set_a:
        example_indices_test_a.append(i)
    else:
        example_indices_test_b.append(i)
        
len(example_indices_test_a)
len(example_indices_test_b)

30508

29492

5074

4926

In [8]:
example_indices_train_a[:5]
example_indices_train_b[:5]

example_indices_test_a[:5]
example_indices_test_b[:5]

[0, 3, 4, 6, 7]

[1, 2, 5, 9, 13]

[0, 2, 5, 7, 8]

[1, 3, 4, 6, 10]

In [10]:
train_set_a = Subset(train_set, example_indices_train_a)
train_set_b = Subset(train_set, example_indices_train_b)

test_set_a = Subset(test_set, example_indices_test_a)
test_set_b = Subset(test_set, example_indices_test_b)

In [11]:
train_loader_a = torch.utils.data.DataLoader(
                 dataset=train_set_a,
                 batch_size=batch_size,
                 shuffle=True)
train_loader_b = torch.utils.data.DataLoader(
                 dataset=train_set_b,
                 batch_size=batch_size,
                 shuffle=True)

test_loader_a = torch.utils.data.DataLoader(
                dataset=test_set_a,
                batch_size=batch_size,
                shuffle=False)
test_loader_b = torch.utils.data.DataLoader(
                dataset=test_set_b,
                batch_size=batch_size,
                shuffle=False)

print('Length of training set A (in batches):', len(train_loader_a))
print('Length of training set B:', len(train_loader_b))

print('Length of test set A:', len(test_loader_a))
print('Length of test set B:', len(test_loader_b))

Length of training set A (in batches): 954
Length of training set B: 922
Length of test set A: 159
Length of test set B: 154


## Models

**Outputs 10 classes**

In [12]:
class MLPNet(nn.Module):
    def __init__(self):
        super(MLPNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 500)
        self.fc2 = nn.Linear(500, 256)
        self.fc3 = nn.Linear(256, 10)
    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    def name(self):
        return "MLP"

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
    def name(self):
        return "LeNet"

## Training on examples from class set A

In [27]:
print_every = 100  # evaluate on test set print_every steps

model = LeNet().to(device=device)  # initially weights are random

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

criterion = nn.CrossEntropyLoss()

for epoch in range(1):
    print(f'Epoch {epoch}')
    
    ave_loss, ave_acc = 0.0, 0.0
    
    for batch_idx, (x, target) in enumerate(train_loader_a):  # trainning on set A classes
        _ = model.train()
        optimizer.zero_grad()
        x, target = x.to(device=device), target.to(device=device)
        scores = model(x)
        
        loss = criterion(scores, target)  # for the mini-batch
        ave_loss = ave_loss * 0.9 + loss.data.item() * 0.1
        
        _, preds = scores.max(1)
        accuracy = (target == preds).float().mean()  # for the mini-batch
        
        loss.backward()
        optimizer.step()

        if (batch_idx + 1) % print_every == 0 or (batch_idx + 1) == len(train_loader_a):
            print(f'==>>> training - epoch: {epoch}, batch index: {batch_idx + 1}, '
                  f'train ave loss: {ave_loss:.6f}, mini-batch accuracy: {accuracy:.4f}')
    
            # evaluate on entire test set to see how performance changes as we go through training
            test_sets = [
                ('test set A', test_loader_a),  # class set A, the ones that were trained on
                ('test set B', test_loader_b)   # class set B, the future class set
            ]
            
            _ = model.eval()  # put model to evaluation mode
            with torch.no_grad():
                
                for set_name, test_loader in test_sets:
                
                    test_loss, test_acc = 0.0, 0.0
                    
                    for test_batch_idx, (x, target) in enumerate(test_loader):
                        x, target = x.to(device=device), target.to(device=device)
                
                        scores = model(x)
                        loss = criterion(scores, target)

                        _, preds = scores.max(1)
                        test_acc += (target == preds).float().mean()
                        test_loss += loss.item()
                    
                    num_batches = len(test_loader)
                    test_loss = test_loss / num_batches
                    test_acc = test_acc / num_batches
                    print(f'       {set_name} - epoch: {epoch}, test loss: {test_loss:.6f}, acc: {test_acc:.4f}')

Epoch 0
==>>> training - epoch: 0, batch index: 100, train ave loss: 0.318821, mini-batch accuracy: 0.8750
       test set A - epoch: 0, test loss: 0.201957, acc: 0.9375
       test set B - epoch: 0, test loss: 9.662710, acc: 0.0000
==>>> training - epoch: 0, batch index: 200, train ave loss: 0.115018, mini-batch accuracy: 1.0000
       test set A - epoch: 0, test loss: 0.120792, acc: 0.9604
       test set B - epoch: 0, test loss: 10.003709, acc: 0.0000
==>>> training - epoch: 0, batch index: 300, train ave loss: 0.105032, mini-batch accuracy: 1.0000
       test set A - epoch: 0, test loss: 0.082159, acc: 0.9743
       test set B - epoch: 0, test loss: 9.356331, acc: 0.0000
==>>> training - epoch: 0, batch index: 400, train ave loss: 0.098413, mini-batch accuracy: 0.9375
       test set A - epoch: 0, test loss: 0.065251, acc: 0.9815
       test set B - epoch: 0, test loss: 8.885456, acc: 0.0000
==>>> training - epoch: 0, batch index: 500, train ave loss: 0.058098, mini-batch accuracy:

In [28]:
torch.save(model.state_dict(), f'{model.name()}_epoch{epoch+1}_set_a')

## Train the model trained on class set A on examples from class set B

In [29]:
checkpoint = torch.load(f'{model.name()}_epoch{epoch+1}_set_a', map_location=device)
model.load_state_dict(checkpoint)

<All keys matched successfully>

We evaluate and print out more often to see the changes in accuracy on class set A and B

In [30]:
print_every = 1  # evaluate on test set print_every steps

for epoch in range(1):
    print(f'Epoch {epoch}')
    
    ave_loss, ave_acc = 0.0, 0.0
    
    for batch_idx, (x, target) in enumerate(train_loader_b):  # trainning on set B classes
        _ = model.train()
        optimizer.zero_grad()
        x, target = x.to(device=device), target.to(device=device)
        scores = model(x)
        
        loss = criterion(scores, target)  # for the mini-batch
        ave_loss = ave_loss * 0.9 + loss.data.item() * 0.1
        
        _, preds = scores.max(1)
        accuracy = (target == preds).float().mean()  # for the mini-batch
        ave_acc = ave_acc * 0.9 + accuracy * 0.1
        
        loss.backward()
        optimizer.step()

        if (batch_idx + 1) % print_every == 0 or (batch_idx + 1) == len(train_loader_a):
            print(f'==>>> training - epoch: {epoch}, batch index: {batch_idx + 1}, '
                  f'train ave loss: {ave_loss:.6f}, train ave accuracy: {ave_acc:.4f}')
    
            # evaluate on entire test set to see how performance changes as we go through training
            test_sets = [
                ('test set A', test_loader_a),  # class set A, the ones that were trained on
                ('test set B', test_loader_b)   # class set B, the future class set
            ]
            
            _ = model.eval()  # put model to evaluation mode
            with torch.no_grad():
                
                for set_name, test_loader in test_sets:
                
                    test_loss, test_acc = 0.0, 0.0
                    example_count = 0
                    for test_batch_idx, (x, target) in enumerate(test_loader):
                        example_count += len(target)
                        x, target = x.to(device=device), target.to(device=device)
                
                        scores = model(x)
                        loss = criterion(scores, target)

                        _, preds = scores.max(1)
                        test_acc += (target == preds).float().mean()
                        test_loss += loss.item()
                    
                    num_batches = len(test_loader)
                    test_loss = test_loss / num_batches
                    test_acc = test_acc / num_batches
                    print(f'       {set_name} - epoch: {epoch}, test loss: {test_loss:.6f}, acc: {test_acc:.4f}')
        

        if batch_idx >= 40:
            break

Epoch 0
==>>> training - epoch: 0, batch index: 1, train ave loss: 0.895778, train ave accuracy: 0.0000
       test set A - epoch: 0, test loss: 0.040592, acc: 0.9900
       test set B - epoch: 0, test loss: 6.232926, acc: 0.0000
==>>> training - epoch: 0, batch index: 2, train ave loss: 1.450474, train ave accuracy: 0.0000
       test set A - epoch: 0, test loss: 0.131978, acc: 0.9821
       test set B - epoch: 0, test loss: 3.939050, acc: 0.0000
==>>> training - epoch: 0, batch index: 3, train ave loss: 1.709852, train ave accuracy: 0.0000
       test set A - epoch: 0, test loss: 0.502876, acc: 0.9229
       test set B - epoch: 0, test loss: 2.923585, acc: 0.0000
==>>> training - epoch: 0, batch index: 4, train ave loss: 1.835282, train ave accuracy: 0.0000
       test set A - epoch: 0, test loss: 1.029391, acc: 0.8491
       test set B - epoch: 0, test loss: 2.515279, acc: 0.0000
==>>> training - epoch: 0, batch index: 5, train ave loss: 1.907622, train ave accuracy: 0.0000
       t

       test set B - epoch: 0, test loss: 0.370065, acc: 0.8875
==>>> training - epoch: 0, batch index: 38, train ave loss: 0.967660, train ave accuracy: 0.6639
       test set A - epoch: 0, test loss: 10.964502, acc: 0.0000
       test set B - epoch: 0, test loss: 0.361872, acc: 0.8975
==>>> training - epoch: 0, batch index: 39, train ave loss: 0.895481, train ave accuracy: 0.6944
       test set A - epoch: 0, test loss: 10.815017, acc: 0.0000
       test set B - epoch: 0, test loss: 0.421051, acc: 0.8646
==>>> training - epoch: 0, batch index: 40, train ave loss: 0.860998, train ave accuracy: 0.7062
       test set A - epoch: 0, test loss: 10.632121, acc: 0.0000
       test set B - epoch: 0, test loss: 0.407827, acc: 0.8670
==>>> training - epoch: 0, batch index: 41, train ave loss: 0.814458, train ave accuracy: 0.7262
       test set A - epoch: 0, test loss: 10.515493, acc: 0.0000
       test set B - epoch: 0, test loss: 0.343331, acc: 0.8970
