In [0]:
!pip install -q torch
!pip install -q torchvision

In [0]:
import torch
import torchvision
import torch.distributions as ds
import torch.utils

# per-sample preprocessing: ToTensor first, then a move onto the CUDA device
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    lambda image: image.cuda(),
])

# get the MNIST datasets (train and test splits share root, transform and download flag)
mnist_options = dict(root="/", transform=transform, download=True)
train_mnist = torchvision.datasets.MNIST(train=True, **mnist_options)
test_mnist = torchvision.datasets.MNIST(train=False, **mnist_options)

In [0]:
import torch.utils

# this class returns a PyTorch dataset with just the MNIST digits given in the list digits
class DigitDataset(torch.utils.data.Dataset):
  """Subset of a labelled dataset restricted to the labels listed in `digits`.

  The source dataset is read as one batch through a DataLoader; the rows whose
  target equals any entry of `digits` are stored in `self.inputs` and
  `self.targets`, and indexing returns the matching (input, target) pair.
  """

  def __init__(self, mnist, digits):
    super().__init__()
    # batch_size=len(mnist) makes the loader yield the whole dataset at once
    whole_set = torch.utils.data.DataLoader(mnist, batch_size=len(mnist))
    for images, labels in whole_set:
      # boolean mask that is True wherever the label is one of the wanted digits
      keep = labels == digits[0]
      for digit in digits[1:]:
        keep = keep | (labels == digit)
      self.inputs = images[keep]
      self.targets = labels[keep]

  def __len__(self):
    return len(self.targets)

  def __getitem__(self, i):
    return self.inputs[i], self.targets[i]

# dataset with just every 0, 1, 2, 3, and 4 from MNIST
zero_four = DigitDataset(train_mnist, [0, 1, 2, 3, 4])
# one single-digit dataset for each of 5, 6, 7, 8 and 9
five, six, seven, eight, nine = (
  DigitDataset(train_mnist, [digit]) for digit in range(5, 10)
)

In [0]:
# our CNN feature extractor, which we first train to classify MNIST
class FeatureExtractor(torch.nn.Module):
  """Convolutional classifier whose pre-ReLU hidden activations are used as features.

  The positional `layers` are applied in order, the result is flattened per sample,
  projected to `feature_dim`, batch-normalised (these are the "features"), passed
  through ReLU and a final linear layer, and turned into log-probabilities.
  """

  # attribute-name prefix used when registering each recognised layer type
  _LAYER_PREFIXES = (
    (torch.nn.Conv2d, "conv"),
    (torch.nn.ReLU, "relu"),
    (torch.nn.MaxPool2d, "maxpool"),
    (torch.nn.BatchNorm2d, "bn"),
    (torch.nn.Dropout2d, "dropout"),
  )

  # DO NOT CHANGE THE VALUE OF conv_output_dim
  def __init__(self, *layers, conv_output_dim=3136, feature_dim=128, class_dim=10):
    """Register the given layers and build the feature/classification heads.

    Args:
      *layers: callables applied in order to the input batch.
      conv_output_dim: flattened size of the last layer's output per sample.
      feature_dim: size of the extracted feature vector.
      class_dim: number of output classes.
    """
    torch.nn.Module.__init__(self)
    self.features = []
    counters = {}
    for layer in layers:
      # Every module is registered, so that parameters(), .cuda()/.to() and the
      # optimiser see it; types outside the table get a generic "layer_N" name.
      if isinstance(layer, torch.nn.Module):
        prefix = "layer"
        for layer_type, type_prefix in self._LAYER_PREFIXES:
          if isinstance(layer, layer_type):
            prefix = type_prefix
            break
        index = counters.get(prefix, 0)
        counters[prefix] = index + 1
        self.add_module(f"{prefix}_{index}", layer)
      self.features.append(layer)
    self.conv_output_dim = conv_output_dim
    self.feature_dim = feature_dim
    self.class_dim = class_dim
    self.feature_space = torch.nn.Linear(self.conv_output_dim, self.feature_dim)
    self.bn = torch.nn.BatchNorm1d(self.feature_dim)
    self.relu = torch.nn.ReLU()
    self.class_space = torch.nn.Linear(self.feature_dim, self.class_dim)
    self.lsm = torch.nn.LogSoftmax(dim=1)

  def forward(self, batch):
    """Return `(log_probabilities, features)` for a batch of inputs."""
    for layer in self.features:
      batch = layer(batch)
    # flatten everything but the batch dimension
    batch = batch.view(batch.shape[0], -1)
    batch = self.feature_space(batch)
    # the features are taken after batch norm and before the ReLU
    features = self.bn(batch)
    batch = self.relu(features)
    batch = self.class_space(batch)
    return self.lsm(batch), features

  def extract_features(self, batch):
    """Return only the feature part of `forward`."""
    return self(batch)[1]

  def predict(self, batch, argmax=False):
    """Return log-probabilities, or the arg-max class index per row when `argmax` is true."""
    if not argmax:
      return self(batch)[0]
    else:
      return self(batch)[0].argmax(dim=1)

  def train_classifier(self, data, batch_size, n_epochs, lr=0.01):
    """Train the whole network as a classifier with Adam and NLL loss.

    Args:
      data: dataset yielding `(input, target)` pairs.
      batch_size: mini-batch size for the shuffled DataLoader.
      n_epochs: number of passes over `data`.
      lr: Adam learning rate.
    """
    # batches are moved to wherever the model's parameters live
    device = next(self.parameters()).device
    sampler = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)

    optim = torch.optim.Adam(self.parameters(), lr=lr)

    nllloss = torch.nn.NLLLoss()

    for _ in range(n_epochs):
      for inputs, targets in sampler:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optim.zero_grad()
        out = self.predict(inputs)
        loss = nllloss(out, targets)
        loss.backward()
        optim.step()

In [0]:
# encoder and decoder networks
# feature_dim is the size of the output from the feature extractor
# hidden_dim is the size of the hidden layer
# latent_dim is the memory dimension size (M)
# class_dim is the number of classes

class Encoder(torch.nn.Module):
  """Maps feature vectors to class log-probabilities and a latent vector.

  A shared linear + batch-norm + ReLU trunk feeds two heads: a batch-normalised
  linear projection to `latent_dim`, and a linear + log-softmax classifier.
  """

  def __init__(self, feature_dim=128, hidden_dim=512, latent_dim=10, class_dim=10):
    super().__init__()
    # shared trunk: feature -> hidden
    self.feature_to_hidden = torch.nn.Linear(feature_dim, hidden_dim)
    self.bn0 = torch.nn.BatchNorm1d(hidden_dim)
    self.relu0 = torch.nn.ReLU()
    # latent head
    self.hidden_to_latent = torch.nn.Linear(hidden_dim, latent_dim)
    self.bn_latent = torch.nn.BatchNorm1d(latent_dim)
    # classification head
    self.hidden_to_class = torch.nn.Linear(hidden_dim, class_dim)
    self.lsm = torch.nn.LogSoftmax(dim=1)

  def forward(self, batch):
    """Return `(class_log_probabilities, latent_vectors)` for a batch of features."""
    hidden = self.relu0(self.bn0(self.feature_to_hidden(batch)))
    latent_vecs = self.bn_latent(self.hidden_to_latent(hidden))
    class_vecs = self.lsm(self.hidden_to_class(hidden))
    return class_vecs, latent_vecs
  
class Decoder(torch.nn.Module):
  """Maps latent vectors back to feature space.

  Two linear layers, each followed by batch norm, with a ReLU after the first.
  """

  def __init__(self, feature_dim=128, hidden_dim=512, latent_dim=10):
    super().__init__()
    # latent -> hidden
    self.latent_to_hidden = torch.nn.Linear(latent_dim, hidden_dim)
    self.bn0 = torch.nn.BatchNorm1d(hidden_dim)
    self.relu0 = torch.nn.ReLU()
    # hidden -> feature
    self.hidden_to_feature = torch.nn.Linear(hidden_dim, feature_dim)
    self.bn1 = torch.nn.BatchNorm1d(feature_dim)

  def forward(self, batch):
    """Return reconstructed feature vectors for a batch of latent vectors."""
    hidden = self.relu0(self.bn0(self.latent_to_hidden(batch)))
    return self.bn1(self.hidden_to_feature(hidden))

In [0]:
class mPFC(torch.nn.Module):
  """Feature extractor plus encoder/decoder pair with per-class latent Gaussians.

  After each call to `consolidate`, a mean vector and covariance matrix of the
  encoder's latent vectors are stored per class (`class_means`, `class_covs`).
  `generate` samples from those Gaussians and decodes the samples into
  pseudo feature vectors, which `consolidate` mixes into later training batches.

  `class_means` and `class_covs` are plain tensor attributes (not buffers), so
  they are not part of `state_dict()` and are not moved by `.to()`.
  """

  def __init__(self, feature_dim=128, hidden_dim=512, latent_dim=10, device="cuda"):
    """Build the three sub-networks and place them on `device`.

    Args:
      feature_dim: size of the feature extractor's output.
      hidden_dim: hidden layer size of the encoder and decoder.
      latent_dim: size of the latent vectors.
      device: device the sub-networks are moved to (default keeps the CUDA placement).
    """
    torch.nn.Module.__init__(self)

    self.fe = FeatureExtractor(
      torch.nn.Conv2d(1, 32, 3, padding=1), # convolutional layer with 32 filters and kernel size 3x3
      torch.nn.BatchNorm2d(32),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2), # max pool layer with kernel size 2x2
      torch.nn.Conv2d(32, 64, 3, padding=1), # convolutional layer with 64 filters and kernel size 3x3
      torch.nn.BatchNorm2d(64),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2), # max pool layer with kernel size 2x2
      feature_dim=feature_dim
    ).to(device)

    self.e = Encoder(feature_dim=feature_dim, latent_dim=latent_dim, hidden_dim=hidden_dim).to(device)

    self.d = Decoder(feature_dim=feature_dim, latent_dim=latent_dim, hidden_dim=hidden_dim).to(device)

    self.class_means = None
    self.class_covs = None

    self.latent_dim = latent_dim

  def _device(self):
    # device the encoder's parameters live on; batches and statistics are placed there
    return next(self.e.parameters()).device

  # method to generate pseudo examples along with their corresponding labels
  def generate(self, n, classes=None):
    """Sample `n` latent vectors per class and decode them.

    Args:
      n: number of samples drawn for each selected class.
      classes: optional index (int, list, slice or tensor) selecting a subset of
        the stored classes; all stored classes are used when None.

    Returns:
      `(decoded_features, labels)`, or `(None, None)` if no class statistics
      have been stored yet. Rows are ordered sample-major, class-minor, and
      `labels[i]` is the class index row `i` was sampled from.
    """
    if self.class_means is None or self.class_covs is None:
      return None, None

    means, covs = self.class_means, self.class_covs
    class_ids = torch.arange(means.shape[0], device=means.device)
    if classes is not None:
      # Index the id vector itself so the labels always match the sampled classes,
      # whatever kind of index was passed.
      class_ids = class_ids[classes].reshape(-1)
      means, covs = means[class_ids], covs[class_ids]

    # sample from class distributions
    dist = ds.multivariate_normal.MultivariateNormal(means, covariance_matrix=covs)
    inputs = dist.rsample((n,)).contiguous().view(-1, self.class_means.shape[-1])
    labels = class_ids.repeat(n)
    return self.d(inputs), labels

  def _rehearsal_batch(self, inputs, targets, n_pseudo, device):
    # Turn a raw batch into features and append decoded pseudo examples (if any exist).
    targets = targets.to(device)
    # the feature extractor is never optimised here, so no graph is needed for it
    with torch.no_grad():
      inputs = self.fe.extract_features(inputs.to(device))
    pseudo_inputs, pseudo_targets = self.generate(n_pseudo)
    if pseudo_inputs is not None:
      inputs = torch.cat((inputs, pseudo_inputs))
      targets = torch.cat((targets, pseudo_targets))
    return inputs, targets

  # train on a specific task
  def consolidate(self, data, batch_size, n_epochs, new_classes, lr=0.01):
    """Train encoder/decoder on `data` and refresh the per-class latent statistics.

    Labels in `data` must be the next `new_classes` consecutive class indices
    after the ones already stored, because the statistics are indexed by label.

    Args:
      data: dataset yielding `(input, target)` pairs.
      batch_size: mini-batch size; also the number of pseudo examples drawn per
        stored class for every batch.
      n_epochs: number of training passes over `data`.
      new_classes: how many classes `data` adds to the ones already stored.
      lr: SGD learning rate for the encoder and the decoder.

    Raises:
      ValueError: if any class ends up with fewer than 2 examples, for which
        the covariance estimate is undefined.
    """
    device = self._device()
    sampler = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)

    optim_e = torch.optim.SGD(self.e.parameters(), lr=lr)
    optim_d = torch.optim.SGD(self.d.parameters(), lr=lr)

    nllloss = torch.nn.NLLLoss()
    mse_loss = torch.nn.MSELoss()

    # this loop corresponds to the mPFC task training algorithm from our paper
    for _ in range(n_epochs):
      for inputs, targets in sampler:
        optim_e.zero_grad()
        optim_d.zero_grad()
        inputs, targets = self._rehearsal_batch(inputs, targets, batch_size, device)
        classes, latent_vecs = self.e(inputs)
        out_features = self.d(latent_vecs)
        # reconstruction loss plus classification loss
        loss = mse_loss(out_features, inputs) + nllloss(classes, targets)
        loss.backward()
        optim_e.step()
        optim_d.step()

    optim_e.zero_grad()
    optim_d.zero_grad()

    n_old = 0 if self.class_means is None else self.class_means.shape[0]
    n_classes = n_old + new_classes
    class_means = torch.zeros(n_classes, self.latent_dim, device=device)
    class_covs = torch.zeros(n_classes, self.latent_dim, self.latent_dim, device=device)
    examples_per_class = torch.zeros(n_classes, device=device)

    self.e.eval()
    self.d.eval()

    # this loop corresponds to the mPFC task storage algorithm from our paper
    # no_grad leaves each parameter's requires_grad flag as the caller set it
    with torch.no_grad():
      for inputs, targets in sampler:
        inputs, targets = self._rehearsal_batch(inputs, targets, batch_size, device)
        _, latent_vecs = self.e(inputs)
        examples_per_class += targets.bincount(minlength=n_classes).float()
        # accumulate per-class sums of x and of the outer products x x^T
        class_means.index_add_(0, targets, latent_vecs)
        class_covs.index_add_(0, targets, torch.bmm(latent_vecs.unsqueeze(-1), latent_vecs.unsqueeze(1)))

    self.e.train()
    self.d.train()

    if bool((examples_per_class < 2).any()):
      raise ValueError(
        "consolidate needs at least 2 examples for every class index; "
        f"got per-class counts {examples_per_class.tolist()}"
      )

    # store the new means and covariance matrices
    # unbiased covariance: sum(x x^T) / (n - 1) - n / (n - 1) * mean mean^T
    counts = examples_per_class.unsqueeze(1).unsqueeze(1)
    means = class_means / examples_per_class.unsqueeze(1)
    outer = torch.bmm(means.unsqueeze(-1), means.unsqueeze(1))
    self.class_means = means
    self.class_covs = class_covs / (counts - 1) - counts / (counts - 1) * outer

mpfc = mPFC(feature_dim=128, hidden_dim=512, latent_dim=20)

# train the feature extractor, then switch it to eval mode and freeze its weights
mpfc.fe.train_classifier(train_mnist, batch_size=256, n_epochs=2)
mpfc.fe.eval()
for param in mpfc.fe.parameters():
  param.requires_grad_(False)

print("Finished training feature extractor")

# IMPORTANT NOTE: THE MPFC MUST BE TRAINED ON DIGITS IN ASCENDING ORDER, OTHERWISE
# THE METHOD FOR CALCULATING MEANS AND COVARIANCES WILL NOT WORK. ITS INDEXING RELIES
# ON THE SEQUENTIAL ASCENDING PROPERTY OF NEW CLASSES.
# each entry: (task dataset, number of classes it adds, message printed when done)
consolidation_schedule = [
  (zero_four, 5, "Finished digits 0 through 4"),
  (five, 1, "Finished digit 5"),
  (six, 1, "Finished digit 6"),
  (seven, 1, "Finished digit 7"),
  (eight, 1, "Finished digit 8"),
  (nine, 1, "Finished digit 9"),
]
for task_data, n_new, done_message in consolidation_schedule:
  mpfc.consolidate(task_data, batch_size=64, n_epochs=6, new_classes=n_new)
  print(done_message)

Finished training feature extractor
Finished digits 0 through 4
Finished digit 5
Finished digit 6
Finished digit 7
Finished digit 8
Finished digit 9


In [0]:
# report all-class accuracy on the entire MNIST test set
mpfc.e.eval()
for param in mpfc.e.parameters():
  param.requires_grad_(False)

# batch_size=len(test_mnist): the loader yields the whole test split as one batch
test_sampler = torch.utils.data.DataLoader(test_mnist, batch_size=len(test_mnist))
for inputs, targets in test_sampler:
  log_probs, _ = mpfc.e(mpfc.fe.extract_features(inputs))
  preds = log_probs.argmax(dim=1)
  n_correct = (preds == targets.cuda()).sum().item()
  print(n_correct / float(len(targets)))

# put the encoder back into trainable / training mode
for param in mpfc.e.parameters():
  param.requires_grad_(True)
_ = mpfc.e.train()

0.976


In [0]:
# report all-class accuracy of classification on generated pseudo-examples
mpfc.d.eval()
for param in mpfc.d.parameters():
  param.requires_grad_(False)

# 500 decoded samples per stored class, classified by the encoder
features, labels = mpfc.generate(500)
preds = mpfc.e(features)[0].argmax(dim=1)
n_correct = (preds == labels).sum().item()
print(n_correct / float(len(labels)))

# put the decoder back into trainable / training mode
for param in mpfc.d.parameters():
  param.requires_grad_(True)
_ = mpfc.d.train()

0.9992
