In [None]:
class Model(nn.Module):
    """Fully connected classifier over flattened 3x32x32 inputs.

    Layer widths: 3072 -> 1024 -> 512 -> 256 -> 128 -> 128 -> 10. A ReLU
    follows each of the first five linear layers; the last layer returns raw
    logits.
    """

    def __init__(self):
        super().__init__()
        # Attributes are created in this order (fc1..fc6, then relu) so the
        # parameter names and initialization order stay fixed.
        self.fc1 = nn.Linear(3 * 32 * 32, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, 128)
        self.fc5 = nn.Linear(128, 128)
        self.fc6 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return logits of shape (batch, 10) for the batch ``x``."""
        # Flatten every dimension except the leading batch dimension.
        hidden = x.view(x.shape[0], -1)
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5):
            hidden = self.relu(layer(hidden))
        return self.fc6(hidden)

In [None]:
class EWC(object):
    """Elastic Weight Consolidation (EWC) regularizer.

    On construction this estimates a diagonal per-parameter importance from
    squared gradients over ``dataloaders``. It also stores a snapshot of the
    model's trainable parameters. ``penalty`` returns
    ``sum_n importance[n] * (p_n - snapshot[n]) ** 2``.

    @article{kirkpatrick2017overcoming,
        title={Overcoming catastrophic forgetting in neural networks},
        author={Kirkpatrick, James and Pascanu, Razvan and Rabinowitz, Neil and Veness, Joel and Desjardins, Guillaume and Rusu, Andrei A and Milan, Kieran and Quan, John and Ramalho, Tiago and Grabska-Barwinska, Agnieszka and others},
        journal={Proceedings of the national academy of sciences},
        year={2017},
        url={https://arxiv.org/abs/1612.00796}
    }
    """

    def __init__(self, model: nn.Module, dataloaders: list, device):
        """
        Args:
            model: network whose trainable parameters are regularized.
            dataloaders: iterables of batches. Only ``batch[0]`` (the inputs)
                is used, and each iterable must support ``len()``.
            device: device the inputs are moved to before the forward pass.
        """
        self.model = model
        self.dataloaders = dataloaders
        self.device = device

        # Only parameters that require grad are tracked and penalized.
        self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
        self._means = {}
        # Diagonal importance (empirical Fisher) per tracked parameter.
        self._precision_matrices = self._calculate_importance()

        # Snapshot of the current parameter values; the penalty anchors to these.
        for n, p in self.params.items():
            self._means[n] = p.clone().detach()

    def _calculate_importance(self):
        """Estimate the diagonal Fisher information of the tracked parameters.

        For every batch, the model's own argmax prediction is used as the
        label. The batch-mean NLL is back-propagated, and the squared gradient
        is divided by the total ``len()`` of all loaders (the number of batches
        for a DataLoader). Parameters that receive no gradient keep zero
        importance. The model's train/eval mode is restored afterwards.
        """
        print('Computing EWC')

        precision_matrices = {n: torch.zeros_like(p) for n, p in self.params.items()}

        was_training = self.model.training
        self.model.eval()
        number_data = sum(len(loader) for loader in self.dataloaders)
        try:
            for dataloader in self.dataloaders:
                for data in dataloader:
                    self.model.zero_grad()
                    inputs = data[0].to(self.device)
                    # Keep one row per sample. Flattening the whole batch into a
                    # single row would mix the logits of different samples
                    # whenever batch_size > 1.
                    output = self.model(inputs).reshape(inputs.shape[0], -1)
                    label = output.argmax(dim=1)

                    loss = F.nll_loss(F.log_softmax(output, dim=1), label)
                    loss.backward()

                    for n, p in self.params.items():
                        if p.grad is not None:  # None if p did not take part in the forward pass
                            precision_matrices[n] += p.grad.detach() ** 2 / number_data
        finally:
            self.model.train(was_training)
        return precision_matrices

    def penalty(self, model: nn.Module):
        """Return the importance-weighted squared distance to the snapshot.

        Parameters of ``model`` that were not tracked at construction time
        (e.g. frozen ones) are skipped. Returns ``0`` if nothing is tracked.
        """
        loss = 0
        for n, p in model.named_parameters():
            if n not in self._precision_matrices:
                continue
            _loss = self._precision_matrices[n] * (p - self._means[n]) ** 2
            loss += _loss.sum()
        return loss

In [None]:
class MAS(object):
    """Memory Aware Synapses (MAS) regularizer.

    Importance of each trainable parameter is the absolute gradient of the
    batch-mean squared L2 norm of the model output. It is summed over batches
    and divided by the total ``len()`` of all loaders. ``penalty`` returns
    ``sum_n importance[n] * (p_n - snapshot[n]) ** 2``.

    @article{aljundi2017memory,
      title={Memory Aware Synapses: Learning what (not) to forget},
      author={Aljundi, Rahaf and Babiloni, Francesca and Elhoseiny, Mohamed and Rohrbach, Marcus and Tuytelaars, Tinne},
      booktitle={ECCV},
      year={2018},
      url={https://eccv2018.org/openaccess/content_ECCV_2018/papers/Rahaf_Aljundi_Memory_Aware_Synapses_ECCV_2018_paper.pdf}
    }
    """

    def __init__(self, model: nn.Module, dataloaders: list, device):
        """
        Args:
            model: network whose trainable parameters are regularized.
            dataloaders: iterables of batches. Only ``batch[0]`` (the inputs)
                is used, and each iterable must support ``len()``.
            device: device the inputs are moved to before the forward pass.
        """
        self.model = model
        self.dataloaders = dataloaders
        # Only parameters that require grad are tracked and penalized.
        self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
        self._means = {}
        self.device = device
        self._precision_matrices = self._calculate_importance()

        # Snapshot of the current parameter values; the penalty anchors to these.
        for n, p in self.params.items():
            self._means[n] = p.clone().detach()

    def _calculate_importance(self):
        """Accumulate |d(mean_batch ||f(x)||^2) / d theta| per tracked parameter.

        Parameters that receive no gradient keep zero importance. The model's
        train/eval mode is restored afterwards.
        """
        print('Computing MAS')

        precision_matrices = {n: torch.zeros_like(p) for n, p in self.params.items()}

        was_training = self.model.training
        self.model.eval()
        num_data = sum(len(loader) for loader in self.dataloaders)
        try:
            for dataloader in self.dataloaders:
                for data in dataloader:
                    self.model.zero_grad()
                    output = self.model(data[0].to(self.device))
                    # Out-of-place square so the forward output is not mutated.
                    loss = output.pow(2).sum(dim=1).mean()
                    loss.backward()

                    for n, p in self.params.items():
                        if p.grad is not None:  # None if p did not take part in the forward pass
                            precision_matrices[n] += p.grad.detach().abs() / num_data
        finally:
            self.model.train(was_training)
        return precision_matrices

    def penalty(self, model: nn.Module):
        """Return the importance-weighted squared distance to the snapshot.

        Parameters of ``model`` that were not tracked at construction time
        (e.g. frozen ones) are skipped. Returns ``0`` if nothing is tracked.
        """
        loss = 0
        for n, p in model.named_parameters():
            if n not in self._precision_matrices:
                continue
            _loss = self._precision_matrices[n] * (p - self._means[n]) ** 2
            loss += _loss.sum()
        return loss

In [None]:
class SCP(object):
    """SCP regularizer.

    OPEN REVIEW VERSION:
    https://openreview.net/forum?id=BJge3TNKwH

    Importance of each trainable parameter is the squared gradient of
    ``dot(sigma, vector)``, averaged over ``L`` directions ``sigma`` returned
    by ``sample_spherical``. ``vector`` is the model output summed over all
    samples and divided by the number of batches. ``penalty`` returns
    ``sum_n importance[n] * (p_n - snapshot[n]) ** 2``.
    """

    def __init__(self, model: nn.Module, dataloaders: list, L: int, device):
        """
        Args:
            model: network whose trainable parameters are regularized.
            dataloaders: iterables of batches. Only ``batch[0]`` (the inputs)
                is used, and each iterable must support ``len()``.
            L: number of random directions to average over.
            device: device the inputs are moved to before the forward pass.
        """
        self.model = model
        self.dataloaders = dataloaders
        # Only parameters that require grad are tracked and penalized.
        self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
        self._means = {}
        self.L = L
        self.device = device
        self._precision_matrices = self.calculate_importance()

        # Snapshot of the current parameter values; the penalty anchors to these.
        for n, p in self.params.items():
            self._means[n] = p.clone().detach()

    def calculate_importance(self):
        """Compute the per-parameter importance.

        If the dataloaders yield no batches, every importance stays zero.
        Parameters that receive no gradient also keep zero importance. The
        model's train/eval mode is restored afterwards.
        """
        print('Computing SCP')

        precision_matrices = {n: torch.zeros_like(p) for n, p in self.params.items()}

        was_training = self.model.training
        self.model.eval()
        num_data = sum(len(loader) for loader in self.dataloaders)
        try:
            # NOTE(review): this divides by the number of batches, not samples,
            # so `vector` is roughly batch_size * per-sample mean output.
            # The constant factor is absorbed by the penalty weight.
            vector = 0.
            for dataloader in self.dataloaders:
                for data in dataloader:
                    output = self.model(data[0].to(self.device))  # (batch, out_dim)
                    vector = vector + output.sum(dim=0) / num_data  # (out_dim,)

            if not torch.is_tensor(vector):
                return precision_matrices  # no data seen: importance stays zero

            # Directions are indexed as columns: shape (out_dim, L). Match the
            # dtype/device of `vector` so torch.dot is well-formed.
            sigmas = torch.as_tensor(
                sample_spherical(self.L, ndim=vector.numel())
            ).to(vector)

            for k in range(self.L):
                self.model.zero_grad()
                ro = torch.dot(sigmas[:, k], vector)
                # The graph behind `vector` is reused for every direction; it is
                # released after the last backward pass.
                ro.backward(retain_graph=k < self.L - 1)

                for n, p in self.params.items():
                    if p.grad is not None:  # None if p did not take part in the forward pass
                        precision_matrices[n] += p.grad.detach() ** 2 / self.L
        finally:
            self.model.train(was_training)
        return precision_matrices

    def penalty(self, model: nn.Module):
        """Return the importance-weighted squared distance to the snapshot.

        Parameters of ``model`` that were not tracked at construction time
        (e.g. frozen ones) are skipped. Returns ``0`` if nothing is tracked.
        """
        loss = 0
        for n, p in model.named_parameters():
            if n not in self._precision_matrices:
                continue
            _loss = self._precision_matrices[n] * (p - self._means[n]) ** 2
            loss += _loss.sum()
        return loss