In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
import numpy as np
from PIL import Image

In [3]:
from resnet_cifar import resnet32

In [4]:
# Hyperparameters
LR = 1e-2
WEIGHT_DECAY = 0.00001
BATCH_SIZE = 128
NUM_EPOCHS = 70
# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
N = 100       # total number of classes to learn
NB_CL = 10    # classes introduced per incremental batch
K = 2000      # total exemplar memory budget (all classes together)

In [17]:
class iCarl(nn.Module):
    """Incremental classifier following iCaRL (Rebuffi et al., CVPR 2017).

    Wraps a CIFAR ResNet-32 whose final ``fc`` layer is the (growing)
    classifier. Exemplars are stored per class in ``self.exemplar_sets``; the
    list position is the class label, so exemplar sets must be constructed in
    increasing label order.

    Args:
        num_classes: number of output units of the initial classifier.
    """

    def __init__(self, num_classes):
        super(iCarl, self).__init__()
        self.feature_extractor = resnet32()
        self.feature_extractor.fc = nn.Linear(self.feature_extractor.fc.in_features, num_classes)

        self.loss = nn.CrossEntropyLoss()
        # Distillation loss between sigmoid outputs of the old and new network.
        self.dist_loss = nn.BCELoss()
        self.optimizer = optim.SGD(self.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

        self.num_classes = num_classes  # current width of the classifier
        self.num_known = 0              # classes learned in previous increments
        self.exemplar_sets = []         # exemplar_sets[y]: stored images of class y

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _normalized_mean(features):
        """Return the mean of the rows of ``features`` rescaled to unit L2 norm."""
        mean = features.mean(dim=0)
        return mean / mean.norm()

    @staticmethod
    def _nearest_mean(features, means):
        """Return, for each row of ``features``, the index of the closest row of ``means``."""
        return torch.cdist(features, means).argmin(dim=1)

    @staticmethod
    def _herding_selection(features, m):
        """Greedy herding selection (iCaRL Algorithm 4).

        Returns up to ``m`` distinct row indices of ``features`` (shape (N, D)),
        in selection order. At step k the chosen row is the one that brings the
        average of the k selected rows closest to the mean of all rows.
        """
        m = min(m, features.shape[0])
        if m <= 0:
            return []
        class_mean = features.mean(dim=0)
        running_sum = torch.zeros_like(class_mean)
        available = torch.ones(features.shape[0], dtype=torch.bool, device=features.device)
        chosen = []
        for k in range(1, m + 1):
            candidate_means = (running_sum + features) / k
            dist = (class_mean - candidate_means).norm(dim=1)
            # Mask instead of deleting rows so indices keep pointing at the right images.
            dist[~available] = float('inf')
            idx = int(dist.argmin())
            chosen.append(idx)
            available[idx] = False
            running_sum = running_sum + features[idx]
        return chosen

    @staticmethod
    def _loader_features(net, loader):
        """Return the L2-normalised features of every image yielded by ``loader``.

        ``net`` must accept ``features=True``; ``loader`` must yield
        (image, label, index) triples.
        """
        chunks = []
        with torch.no_grad():
            for images, _, _ in loader:
                feats = net(images.to(DEVICE), features=True)
                chunks.append(F.normalize(feats, p=2, dim=1))
        return torch.cat(chunks)

    def _extract_features(self, x):
        """Return the penultimate-layer activations of the network for ``x``.

        Captures the input fed to ``feature_extractor.fc`` via a temporary
        forward hook. NOTE(review): assumes resnet32 passes its pooled,
        flattened features straight into ``fc`` — confirm in resnet_cifar.
        """
        captured = []
        handle = self.feature_extractor.fc.register_forward_hook(
            lambda module, inputs, output: captured.append(inputs[0]))
        try:
            self.feature_extractor(x)
        finally:
            handle.remove()
        return captured[0]

    def _grow_classifier(self, n_new):
        """Add ``n_new`` output units to ``fc``, keeping the trained weights of the old ones."""
        old_fc = self.feature_extractor.fc
        new_fc = nn.Linear(old_fc.in_features, old_fc.out_features + n_new).to(old_fc.weight.device)
        with torch.no_grad():
            new_fc.weight[:old_fc.out_features] = old_fc.weight
            new_fc.bias[:old_fc.out_features] = old_fc.bias
        self.feature_extractor.fc = new_fc
        self.num_classes += n_new
        # The previous optimizer still references the discarded fc parameters.
        self.optimizer = optim.SGD(self.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    # --------------------------------------------------------------- public API

    def __classifyNME__(self, net, data, exemplars):
        """Accuracy of ``net`` on ``data`` under the nearest-mean-of-exemplars rule.

        Args:
            net: model whose forward accepts ``features=True`` (e.g. an iCarl).
            data: evaluation dataset yielding (image, label, index) triples.
            exemplars: mapping class label -> dataset of that class's exemplars,
                also yielding (image, label, index) triples.

        Returns:
            Fraction of ``data`` classified correctly (0.0 for an empty dataset).

        Raises:
            ValueError: if no class in ``exemplars`` has any exemplar.
        """
        print('NME')
        net.eval()  # inference mode for BatchNorm/Dropout; gradients are disabled separately
        class_labels, class_means = [], []
        for label in exemplars:
            if len(exemplars[label]) == 0:
                continue
            exemplar_loader = DataLoader(exemplars[label], shuffle=False, batch_size=BATCH_SIZE, num_workers=4)
            class_means.append(self._normalized_mean(self._loader_features(net, exemplar_loader)))
            class_labels.append(int(label))
        if not class_means:
            raise ValueError('no exemplars available to compute class means')
        class_means = torch.stack(class_means)
        class_labels = torch.tensor(class_labels, device=class_means.device)

        # Keep every sample (no drop_last) so the accuracy covers the whole dataset.
        dataloader = DataLoader(data, shuffle=False, batch_size=BATCH_SIZE, num_workers=4)
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels, _ in dataloader:
                feats = F.normalize(net(images.to(DEVICE), features=True), p=2, dim=1)
                preds = class_labels[self._nearest_mean(feats, class_means)]
                correct += (preds == labels.to(preds.device)).sum().item()
                total += labels.size(0)
        return correct / total if total else 0.0

    def forward(self, x, features=False):
        """Return class logits for ``x``, or penultimate features if ``features`` is True."""
        if features:
            return self._extract_features(x)
        return self.feature_extractor(x)

    def __classify__(self, x):
        """Nearest-mean-of-exemplars prediction for a batch ``x``.

        Returns a LongTensor with one predicted class label per row of ``x``.
        Classes without stored exemplars are never predicted.

        Raises:
            ValueError: if no class has stored exemplars.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                labels, means = [], []
                for y, exemplars in enumerate(self.exemplar_sets):
                    if len(exemplars) == 0:
                        continue
                    batch = torch.stack(list(exemplars)).to(DEVICE)
                    feats = F.normalize(self._extract_features(batch), p=2, dim=1)
                    means.append(self._normalized_mean(feats))
                    labels.append(y)
                if not means:
                    raise ValueError('no exemplars stored: cannot classify')
                phi_x = F.normalize(self._extract_features(x.to(DEVICE)), p=2, dim=1)
                nearest = self._nearest_mean(phi_x, torch.stack(means))
                return torch.tensor(labels, device=nearest.device)[nearest]
        finally:
            self.train(was_training)

    def combine_dataset_with_exemplars(self, dataset):
        """Append each stored exemplar set, labelled with its class, to ``dataset``.

        Relies on ``dataset.append(images, labels)`` provided by the project's
        dataset class. NOTE(review): confirm it accepts the stored image format.
        """
        for label, exemplar_images in enumerate(self.exemplar_sets):
            if len(exemplar_images) == 0:
                continue
            dataset.append(exemplar_images, [label] * len(exemplar_images))

    def update_representation(self, data):
        """Train on new-class data plus stored exemplars (iCaRL Algorithm 3).

        Grows the classifier for labels >= ``num_classes`` and records the
        sigmoid outputs of the pre-update network as distillation targets. It
        then trains for NUM_EPOCHS with classification + distillation loss.
        Afterwards every current class counts as known.

        Args:
            data: training dataset of the new classes. It must yield
                (image, label, index) triples with index in [0, len(data)), and
                expose ``train_labels`` (or ``targets``) and ``append``.
        """
        labels = getattr(data, 'train_labels', None)
        if labels is None:
            labels = data.targets
        seen = {int(label) for label in labels}
        new_classes = sorted(c for c in seen if c >= self.num_classes)
        if new_classes:
            # Size the classifier by the largest label so non-contiguous labels still fit.
            self._grow_classifier(new_classes[-1] + 1 - self.num_classes)
        print('{} new classes'.format(len(new_classes)))

        # Form the combined training set: new data + exemplars of old classes.
        self.combine_dataset_with_exemplars(data)
        dataloader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

        print('Store network outputs with pre-update parameters')
        q = torch.zeros(len(data), self.num_classes, device=DEVICE)
        self.eval()  # don't let this pass update BatchNorm running statistics
        with torch.no_grad():
            for images, _, indexes in dataloader:
                q[indexes.to(DEVICE)] = torch.sigmoid(self.forward(images.to(DEVICE)))

        self.train()
        for epoch in range(1, NUM_EPOCHS + 1):
            for i, (images, labels, indexes) in enumerate(dataloader):
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                indexes = indexes.to(DEVICE)
                # PyTorch accumulates gradients across backward passes; reset them.
                self.optimizer.zero_grad()
                output = self.forward(images)
                # Cross-entropy classification loss over all current classes.
                loss = self.loss(output, labels)
                # Distillation keeps old-class outputs close to the stored targets.
                if self.num_known > 0:
                    g = torch.sigmoid(output[:, :self.num_known])
                    q_i = q[indexes, :self.num_known]
                    # BCELoss averages over all entries; scaling by num_known
                    # equals summing the per-class mean losses.
                    loss = loss + self.dist_loss(g, q_i) * self.num_known
                loss.backward()
                self.optimizer.step()
                if (i + 1) % 10 == 0:
                    print ('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'
                           %(epoch, NUM_EPOCHS, i+1, len(dataloader), loss.item()))

        # Everything learned so far becomes "old" for the next increment.
        self.num_known = self.num_classes

    def __construct_exemplar_set(self, images, m):
        """Pick up to ``m`` exemplars of one class by herding and store them.

        The chosen images are appended to ``self.exemplar_sets`` as a new entry
        (an empty list if nothing can be chosen). Call this once per class, in
        increasing label order.

        Args:
            images: indexable sequence of image tensors of a single class
                (assumed already transformed, shape (C, H, W) — TODO confirm).
            m: target number of exemplars (capped at ``len(images)``).
        """
        if len(images) == 0 or m <= 0:
            self.exemplar_sets.append([])
            return
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                batch = torch.stack(list(images))
                feats = torch.cat([self._extract_features(chunk.to(DEVICE))
                                   for chunk in torch.split(batch, BATCH_SIZE)])
                feats = F.normalize(feats, p=2, dim=1)
        finally:
            self.train(was_training)
        chosen = self._herding_selection(feats, m)
        self.exemplar_sets.append([images[i] for i in chosen])

    def __reduce_exemplar_set(self, m):
        """Trim every stored exemplar set to its first ``m`` items.

        Herding stores exemplars in selection order, so truncation keeps the
        most representative ones. The list is updated in place.
        """
        self.exemplar_sets[:] = [exemplars[:m] for exemplars in self.exemplar_sets]
        
   
            

In [2]:
#main
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from dataset import CIFAR10, CIFAR100


In [None]:
def show_images(images):
    """Display a batch of images side by side in a single row.

    Args:
        images: array-like indexable along axis 0 with ``.shape``; each element
            is passed to ``imshow`` unchanged. NOTE(review): ``imshow`` expects
            channel-last (H, W[, C]) data — torch CHW tensors need permuting first.
    """
    n_images = images.shape[0]
    if n_images == 0:
        return  # GridSpec cannot be built with zero columns
    # One row of n images: the figure is n units wide and 1 unit tall.
    fig = plt.figure(figsize=(n_images, 1))
    gs = gridspec.GridSpec(1, n_images, figure=fig)
    gs.update(wspace=0.05, hspace=0.05)

    for i, img in enumerate(images):
        ax = fig.add_subplot(gs[i])
        ax.axis('off')  # hides ticks and tick labels
        ax.set_aspect('equal')
        ax.imshow(img)
    plt.show()

In [3]:
# Hyperparameters for the incremental-learning driver
TOT_CLASSES = 100  # total number of classes
NUM_CL = 10        # classes introduced per incremental step
K = 2000           # total exemplar memory budget
# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 128

In [None]:
# Per-channel normalisation constants, shared by both pipelines.
# NOTE(review): these are the commonly quoted CIFAR-10 statistics; CIFAR-100's differ slightly — confirm.
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2023, 0.1994, 0.2010)

# Training: random 32x32 crop with 4px padding + horizontal flip.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

# Evaluation: no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

# Later cells refer to `train_transform` / `test_transform`; expose both spellings.
train_transform = transform_train
test_transform = transform_test


In [None]:
# Build the learner (an instance, not the class); its classifier starts with NUM_CL outputs.
icarl = iCarl(NUM_CL)
icarl = icarl.to(DEVICE)

In [None]:
for s in range(0, TOT_CLASSES, NUM_CL):
    # Classes introduced in this incremental step.
    step_classes = range(s, s + NUM_CL)
    train_dataset = CIFAR100(root='data/', classes=step_classes, train=True, download=True, transform=transform_train)
    test_dataset = CIFAR100(root='data/', classes=step_classes, train=False, download=True, transform=transform_test)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    # Evaluation order does not matter; keep it deterministic.
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # Update the representation via backprop.
    icarl.update_representation(train_dataset)
    # Per-class exemplar budget now that the classifier covers more classes.
    m = K // icarl.num_classes

    # Reduce exemplar sets for known classes (the method is name-mangled).
    icarl._iCarl__reduce_exemplar_set(m)

    # TODO: construct exemplar sets for the new classes, e.g. for each y in
    # step_classes call icarl._iCarl__construct_exemplar_set(<images of class y>, m);
    # this needs a dataset accessor returning the images of a single class.

In [5]:
# Sanity check: show which class labels each incremental step covers.
for s in range(0, TOT_CLASSES, NUM_CL):
    print(range(s, NUM_CL + s))

range(0, 10)
range(10, 20)
range(20, 30)
range(30, 40)
range(40, 50)
range(50, 60)
range(60, 70)
range(70, 80)
range(80, 90)
range(90, 100)


In [9]:
import torch

torch.zeros((3, 4))  # display a 3x4 tensor of zeros

tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

In [11]:
a, b, c = 40, 10, 2
a = a / b  # note: `a` is rebound to 4.0 here
d = a / c

d

2.0

In [16]:
a, b, c = 40, 10, 2
(a / b) / c  # left-associative, same as a/b/c

2.0

# Paper implementation


In [1]:
from cifar100 import CIFAR100

In [None]:
# Split class labels 0..N-1 into N // NB_CL consecutive incremental batches
# (with N=100, NB_CL=10: ten batches of ten classes each).
range_classes = np.arange(N)
classes = np.array_split(range_classes, N // NB_CL)
for iteration in range(N/NB_CL):
    #save results each increment
    
    #prepare data for current batch
    train_dataset = CIFAR100(root='data/', classes=classes[i], train=True, download=True, transform=train_transform)
    test_dataset = CIFAR100(root='data/', classes=classes[i],  train=False, download=True, transform=test_transform)
    
    #add stored exemplars to training
    
    #training loop
    for epoch in range(NUM_EPOCHS):
        #shuffle training
        
        #for batch
        
        #distillation
        if iteration>0:
            