In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
import numpy as np
from PIL import Image

from resnet_cifar import resnet32

import math


In [2]:
# Hyperparameters
LR = 0.1             # learning rate handed to the SGD optimizer
WEIGHT_DECAY = 1e-5  # weight decay handed to the SGD optimizer
BATCH_SIZE = 128     # mini-batch size (consumer not shown in this cell)
NUM_EPOCHS = 5       # number of training epochs (consumer not shown in this cell)
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing on the first device transfer.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
def kaiming_normal_init(m):
    """Re-initialise the weight of a conv or linear layer with Kaiming-normal.

    Suitable as a callback for ``nn.Module.apply``. ``nn.Conv2d`` weights are
    initialised with the ``'relu'`` nonlinearity and ``nn.Linear`` weights with
    ``'sigmoid'``. Biases are not modified, and any other object is ignored.

    Args:
        m: the object to initialise in place.

    Returns:
        None.
    """
    if isinstance(m, nn.Conv2d):
        gain_for = 'relu'
    elif isinstance(m, nn.Linear):
        gain_for = 'sigmoid'
    else:
        # Neither a conv nor a linear layer: nothing to do.
        return
    nn.init.kaiming_normal_(m.weight, nonlinearity=gain_for)



In [5]:
class LWF(nn.Module):
    """Classifier built on ``resnet32()`` whose final layer can be widened.

    ``self.model`` holds the full network. ``self.feature_extractor`` wraps all
    of its children except the last one, and ``self.fc`` is the same module
    object as ``self.model.fc`` (the classification head).
    """

    def __init__(self, num_classes):
        """Build the network, the head, the losses and the optimizer.

        Args:
            num_classes: number of output units of the initial head.
        """
        super().__init__()
        self.model = resnet32()
        self.model.apply(kaiming_normal_init)
        # Modify output layers.
        # NOTE(review): 64 is assumed to be the feature width produced by
        # resnet32() -- confirm against resnet_cifar.
        self.model.fc = nn.Linear(64, num_classes)

        # Save FC layer in attributes. It has to be read from self.model:
        # self.feature_extractor is only created a few lines below, so reading
        # it here raised AttributeError.
        self.fc = self.model.fc
        # Save other layers in attributes (everything but the last child,
        # which is assumed to be the fc head).
        backbone_layers = list(self.model.children())[:-1]
        self.feature_extractor = nn.DataParallel(nn.Sequential(*backbone_layers))

        self.loss = nn.CrossEntropyLoss()
        self.dist_loss = nn.BCEWithLogitsLoss()

        self.optimizer = optim.SGD(self.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

        # n_classes is incremented before processing new data in an iteration
        # n_known is set to n_classes after all data for an iteration has been processed
        self.n_classes = 0
        self.n_known = 0

    def forward(self, x):
        """Run ``x`` through the feature extractor and the head.

        Args:
            x: input batch; the first dimension is treated as the batch size.

        Returns:
            The output of ``self.fc`` applied to the flattened features.
        """
        features = self.feature_extractor(x)
        flat = features.view(features.size(0), -1)
        return self.fc(flat)

    def increment_classes(self, new_classes):
        """Add n classes in the final fc layer.

        The head is replaced by a new bias-free ``nn.Linear``. If no class is
        known yet (``self.n_known == 0``) the new head has ``len(new_classes)``
        outputs, otherwise it has the old number of outputs plus
        ``len(new_classes)``. Rows of the old weight matrix are copied into the
        new one as far as both matrices overlap, and ``self.n_classes`` is
        increased by ``len(new_classes)``.

        Args:
            new_classes: sized collection of the classes being added; only its
                length is used.
        """
        n = len(new_classes)
        print('new classes: ', n)
        in_features = self.fc.in_features
        out_features = self.fc.out_features
        weight = self.fc.weight.data

        if self.n_known == 0:
            new_out_features = n
        else:
            new_out_features = out_features + n
        print('new out features: ', new_out_features)
        # Create the new head on the device of the old one so a model that
        # was already moved to the GPU does not end up with a CPU-only layer.
        # NOTE(review): self.optimizer was built from the parameters that
        # existed in __init__; the new head is not part of it unless the
        # caller rebuilds the optimizer -- confirm.
        self.model.fc = nn.Linear(in_features, new_out_features, bias=False).to(weight.device)
        self.fc = self.model.fc

        # Initialise the weight tensor directly. kaiming_normal_init() only
        # acts on Conv2d/Linear modules, so handing it the Parameter (as was
        # done before) silently did nothing. 'sigmoid' is the nonlinearity that
        # helper uses for Linear layers.
        nn.init.kaiming_normal_(self.fc.weight, nonlinearity='sigmoid')
        # Copy only the rows both matrices have. On the first increment the new
        # head can be narrower than the initial one, and assigning the whole old
        # matrix then failed with a shape mismatch.
        rows = min(out_features, new_out_features)
        self.fc.weight.data[:rows] = weight[:rows]
        self.n_classes += n
