In [0]:
from PIL import Image
import os
import os.path
import numpy as np
import pickle

from torchvision.datasets.vision import VisionDataset
from torchvision.datasets.utils import check_integrity, download_and_extract_archive


class CIFAR10(VisionDataset):
   
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    def __init__(self, root, classes=np.arange(10), train=True, transform=None, target_transform=None,
                 download=False):

        super(CIFAR10, self).__init__(root, transform=transform,
                                      target_transform=target_transform)

        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data = []
        self.targets = []

        # now load the picked numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()

        data = []
        targets = []

        for i in range(len(self)):
            if self.targets[i] in classes:
                data.append(self.data[i])
                targets.append(self.targets[i])

        self.data=np.array(data)
        self.targets=targets


    def _load_meta(self):
        path = os.path.join(self.root, self.base_folder, self.meta['filename'])
        if not check_integrity(path, self.meta['md5']):
            raise RuntimeError('Dataset metadata file not found or corrupted.' +
                               ' You can use download=True to download it')
        with open(path, 'rb') as infile:
            data = pickle.load(infile, encoding='latin1')
            self.classes = data[self.meta['key']]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index


    def __len__(self):
        return len(self.data)

    def _check_integrity(self):
        root = self.root
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, self.base_folder, filename)
            if not check_integrity(fpath, md5):
                return False
        return True

    def download(self):
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)

    def extra_repr(self):
        return "Split: {}".format("Train" if self.train is True else "Test")

    def append(self, data, targets):
        self.data = np.concatenate((self.data, data))
        self.targets = np.concatenate((self.targets, targets))

    def get_class_imgs(self, target):
        images = []
        for i, img in enumerate(self.data):
            if self.targets[i] == target:
                images.append(img)

        return images


class CIFAR100(CIFAR10):
    base_folder = 'cifar-100-python'
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]

    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]
    meta = {
        'filename': 'meta',
        'key': 'fine_label_names',
        'md5': '7973b15100ade9c7d40fb424638fde48',
    }


In [0]:
import torch
a = torch.randn([2, 3, 4])
b = torch.randn([2, 3])
b

tensor([[ 0.2462,  0.7337, -0.4384],
        [ 0.4247, -0.2700,  0.2603]])

In [0]:
b = b.unsqueeze(0)
b

tensor([[[ 0.2462,  0.7337, -0.4384],
         [ 0.4247, -0.2700,  0.2603]]])

In [0]:
b.squeeze()

tensor([[ 0.2462,  0.7337, -0.4384],
        [ 0.4247, -0.2700,  0.2603]])

In [0]:
b.mean(0).squeeze()

tensor([[ 0.2462,  0.7337, -0.4384],
        [ 0.4247, -0.2700,  0.2603]])

In [0]:
#from resnet_cifar import resnet32

In [0]:
#resnet implementation like prof one!
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=10):
        self.inplanes = 16
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        self.avgpool = nn.AvgPool2d(8, stride=1)
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

def resnet20(pretrained=False, **kwargs):
    n = 3
    model = ResNet(BasicBlock, [n, n, n], **kwargs)
    return model

def resnet32(pretrained=False, **kwargs):
    n = 5
    model = ResNet(BasicBlock, [n, n, n], **kwargs)
    return model

def resnet56(pretrained=False, **kwargs):
    n = 9
    model = ResNet(Bottleneck, [n, n, n], **kwargs)
    return model


# finetuning

In [22]:
!pip3 install 'livelossplot'
import numpy as np

import torch

import torch.nn as nn
import torch.optim as optim

from torchvision import transforms
from torch.utils.data import DataLoader,Subset
from livelossplot import PlotLosses
from torch.backends import cudnn 
#from resnet import resnet32
from sklearn.model_selection import train_test_split
from torchvision.models import resnet18
import copy



In [0]:
#Hyper-parameters
DEVICE = 'cuda'
NUM_CLASSES = 10
BATCH_SIZE = 128
ClASSES_BATCH =10
LR = 0.01 #default 
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-5
NUM_EPOCHS = 70
STEP_SIZE= 49 # How many epochs before decreasing learning rate (if using a step-down policy)
GAMMA=0.1
DRAW=False

In [0]:
#definition of train and test functions
liveloss=PlotLosses()
BEST_ACC=0
best_net_acc=None
logs={}

In [0]:
from sklearn.model_selection import train_test_split
#train function + validation
def train(net, train_dataloader,val_dataloader):
  best_acc=BEST_ACC
  cudnn.benchmark #optimizes benchmark
  criterion = nn.CrossEntropyLoss() # for classification, Cross Entropy
  #criterion = nn.BCELoss()#binary CrossEntropyLoss
  parameters_to_optimize = net.parameters() # In this case we optimize over all the parameters 
  optimizer = optim.SGD(parameters_to_optimize, lr=LR, weight_decay=WEIGHT_DECAY)
  scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)
  net.to(DEVICE)


  for epoch in range(NUM_EPOCHS):
      logs={}
      if(epoch%5 == 0 ):
        print('-' * 30)
        print('Epoch {}/{}'.format(epoch+1, NUM_EPOCHS))
        for param_group in optimizer.param_groups:
          print('Learning rate:{}'.format(param_group['lr']))
      
      running_loss = 0.0
      val_loss = 0.0
      running_corrects_train = 0
      running_corrects_val = 0
      #train
      for inputs, labels, index in train_dataloader:
          inputs = inputs.to(DEVICE)
          labels = labels.to(DEVICE)

          net.train(True)
          # zero the parameter gradients
          optimizer.zero_grad()
          # forward
          outputs = net(inputs)
          _, preds = torch.max(outputs, 1)
          loss = criterion(outputs, labels)
          loss.backward() #backward pass: compute gradients
          optimizer.step() #update weights based on accumulated gradients

          # statistics
          running_loss += loss.item() * inputs.size(0)
          running_corrects_train += torch.sum(preds == labels.data)


      #validation
      
      for inputs,labels,index in val_dataloader:
        inputs,labels=inputs.to(DEVICE),labels.to(DEVICE)
        net.train(False)
        outputs = net(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)

        val_loss+=loss.item() * inputs.size(0)
        running_corrects_val += torch.sum(preds == labels.data)
      
   
     # Calculate average losses
      epoch_loss = running_loss / len(train_dataloader.dataset)
      val_loss = val_loss / len(val_dataloader.dataset)
      # Calculate accuracy
      epoch_acc = running_corrects_train.double() / len(train_dataloader.dataset)
      valid_acc = running_corrects_val / float(len(val_dataloader.dataset))
      '''
      #calcolo media per liveloss
      running_loss += loss.detach() * inputs.size(0)
      running_corrects += torch.sum(preds == labels.data)
      epoch_acc=running_corrects.float()/float(len(train_dataloader) * BATCH_SIZE )
      epoch_loss=running_loss/float(len(train_dataloader) *BATCH_SIZE)
      logs['log loss'] = epoch_loss.item()
      logs['accuracy'] = epoch_acc.item()
      liveloss.update(logs)
      if(epoch%5 == 0 ):
        print('Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))
        if DRAW:
          liveloss.draw()
      '''
      if (valid_acc > best_acc):
        best_acc = valid_acc
        best_net = copy.deepcopy(net.state_dict())
      if(epoch%5 == 0 ):
        print('Train Loss: {:.4f} Train Acc: {:.4f}'.format(epoch_loss, epoch_acc))
        print('Val Loss: {:.4f} Val Acc: {:.4f}'.format(val_loss, valid_acc))
      scheduler.step()
  
  net.load_state_dict(best_net)
  return net

#test function
def test(net, test_dataloader):
  net.to(DEVICE)
  net.train(False)

  running_corrects = 0
  for images, labels, _ in test_dataloader:
    images = images.to(DEVICE)
    labels = labels.to(DEVICE)

    # Forward Pass
    outputs = net(images)
    # Get predictions
    _, preds = torch.max(outputs.data, 1)
    # Update Corrects
    running_corrects += torch.sum(preds == labels.data).data.item()

  # Calculate Accuracy
  accuracy = running_corrects / float(len(test_dataloader.dataset))
  print('Test Accuracy: {}'.format(accuracy))




  

In [0]:
#define images transformation
train_transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

test_transform = transforms.Compose([transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),
                                    transforms.RandomHorizontalFlip(),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                                    #transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
                                    ])

test_transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                                    #transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
                                    ])

In [0]:
#class splits
range_classes = np.arange(100)
classes= np.array_split(range_classes, 10)


In [0]:
#net = resnet18()
net = resnet32()


for i in range(int(100/ClASSES_BATCH)):
#cambio il numero di classi di output
    net.fc = nn.Linear(64, 10+i*10)

    #creating dataset for current iteration
    train_dataset = CIFAR100(root='data/', classes=classes[i], train=True, download=True, transform=train_transform)
    test_dataset = CIFAR100(root='data/', classes=classes[i],  train=False, download=True, transform=test_transform)

    #subsetting train set in train and validation
    train_indexes,val_indexes=train_test_split(range(len(train_dataset)),test_size=0.2,random_state=42,stratify=train_dataset.targets)
    val_dataset=Subset(train_dataset,val_indexes)
    train_dataset=Subset(train_dataset,train_indexes)
    
    #debug length
    print('Len Train : {}'.format(len(train_dataset)))
    print('Len Valid : {}'.format(len(val_dataset)))
    print('Len Test : {}'.format(len(test_dataset)))

    if i != 0:
      #creating dataset for test on previous classes
      previous_classes = np.array([])
      for j in range(i):
        previous_classes = np.concatenate((previous_classes, classes[j])).astype(int)
      test_prev_dataset = CIFAR100(root='data/', classes=previous_classes,  train=False, download=True, transform=test_transform)

      #creating dataset for all classes
      #all_classes=np.concatenate((current_classes, classes[i]))
      #ho modificato questa riga mettendo al posto di current_classes, previous_classes
      all_classes = np.concatenate((previous_classes, classes[i]))
      test_all_dataset = CIFAR100(root='data/', classes=all_classes,  train=False, download=True, transform=test_transform)

      test_prev_dataloader = DataLoader(test_prev_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=4)
      test_all_dataloader = DataLoader(test_all_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=4)

    #creating dataloaders
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=4)
    val_dataloader=DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=4)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=4)
    


    net = train(net, train_dataloader,val_dataloader)
    print('Test on new classes')
    test(net, test_dataloader)

    if i!=0:
      print('Test on previous classes')
      test(net, test_prev_dataloader)
      print('Test on all classes')
      test(net, test_all_dataloader)


LwF

To implement LwF I follow the slide given by professor, I take the finetuning and I add the distillation loss.
I follow the forum to implement if and I inser it in the train phase as loss

In [0]:
from sklearn.model_selection import train_test_split
#train function + validation
def train(net, train_dataloader,val_dataloader):
  best_acc=BEST_ACC
  cudnn.benchmark #optimizes benchmark
  #criterion = nn.CrossEntropyLoss() # for classification, Cross Entropy
  pos_weight = torch.ones([128])
  criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
  #criterion = nn.BCELoss()#binary CrossEntropyLoss
  parameters_to_optimize = net.parameters() # In this case we optimize over all the parameters 
  optimizer = optim.SGD(parameters_to_optimize, lr=LR, weight_decay=WEIGHT_DECAY)
  scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)
  net.to(DEVICE)


  for epoch in range(NUM_EPOCHS):
      logs={}
      if(epoch%5 == 0 ):
        print('-' * 30)
        print('Epoch {}/{}'.format(epoch+1, NUM_EPOCHS))
        for param_group in optimizer.param_groups:
          print('Learning rate:{}'.format(param_group['lr']))
      
      running_loss = 0.0
      val_loss = 0.0
      running_corrects_train = 0
      running_corrects_val = 0
      #train
      for inputs, labels, index in train_dataloader:
          inputs = inputs.to(DEVICE)
          labels = labels.to(DEVICE)

          net.train(True)
          # zero the parameter gradients
          optimizer.zero_grad()
          # forward
          outputs = net(inputs)
          _, preds = torch.max(outputs, 1)
          loss = criterion(outputs, labels)
          loss.backward() #backward pass: compute gradients
          optimizer.step() #update weights based on accumulated gradients

          # statistics
          running_loss += loss.item() * inputs.size(0)
          running_corrects_train += torch.sum(preds == labels.data)


      #validation
      
      for inputs,labels,index in val_dataloader:
        inputs,labels=inputs.to(DEVICE),labels.to(DEVICE)
        net.train(False)
        outputs = net(inputs)
        _, preds = torch.max(outputs, 1)
        #new part!
        loss = criterion(outputs, labels)
        val_loss+=loss.item() * inputs.size(0)
        running_corrects_val += torch.sum(preds == labels.data)
      
   
     # Calculate average losses
      epoch_loss = running_loss / len(train_dataloader.dataset)
      val_loss = val_loss / len(val_dataloader.dataset)
      # Calculate accuracy
      epoch_acc = running_corrects_train.double() / len(train_dataloader.dataset)
      valid_acc = running_corrects_val / float(len(val_dataloader.dataset))
      '''
      #calcolo media per liveloss
      running_loss += loss.detach() * inputs.size(0)
      running_corrects += torch.sum(preds == labels.data)
      epoch_acc=running_corrects.float()/float(len(train_dataloader) * BATCH_SIZE )
      epoch_loss=running_loss/float(len(train_dataloader) *BATCH_SIZE)
      logs['log loss'] = epoch_loss.item()
      logs['accuracy'] = epoch_acc.item()
      liveloss.update(logs)
      if(epoch%5 == 0 ):
        print('Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))
        if DRAW:
          liveloss.draw()
      '''
      if (valid_acc > best_acc):
        best_acc = valid_acc
        best_net = copy.deepcopy(net.state_dict())
      if(epoch%5 == 0 ):
        print('Train Loss: {:.4f} Train Acc: {:.4f}'.format(epoch_loss, epoch_acc))
        print('Val Loss: {:.4f} Val Acc: {:.4f}'.format(val_loss, valid_acc))
      scheduler.step()
  
  net.load_state_dict(best_net)
  return net

#test function
def test(net, test_dataloader):
  net.to(DEVICE)
  net.train(False)

  running_corrects = 0
  for images, labels, _ in test_dataloader:
    images = images.to(DEVICE)
    labels = labels.to(DEVICE)

    # Forward Pass
    outputs = net(images)
    # Get predictions
    _, preds = torch.max(outputs.data, 1)
    # Update Corrects
    running_corrects += torch.sum(preds == labels.data).data.item()

  # Calculate Accuracy
  accuracy = running_corrects / float(len(test_dataloader.dataset))
  print('Test Accuracy: {}'.format(accuracy))

In [39]:
#net = resnet18()
net = resnet32()


for i in range(int(100/ClASSES_BATCH)):
#cambio il numero di classi di output
    net.fc = nn.Linear(64, 10+i*10)

    #creating dataset for current iteration
    train_dataset = CIFAR100(root='data/', classes=classes[i], train=True, download=True, transform=train_transform)
    test_dataset = CIFAR100(root='data/', classes=classes[i],  train=False, download=True, transform=test_transform)

    #subsetting train set in train and validation
    train_indexes,val_indexes=train_test_split(range(len(train_dataset)),test_size=0.2,random_state=42,stratify=train_dataset.targets)
    val_dataset=Subset(train_dataset,val_indexes)
    train_dataset=Subset(train_dataset,train_indexes)
    
    #debug length
    print('Len Train : {}'.format(len(train_dataset)))
    print('Len Valid : {}'.format(len(val_dataset)))
    print('Len Test : {}'.format(len(test_dataset)))

    if i != 0:
      #creating dataset for test on previous classes
      previous_classes = np.array([])
      for j in range(i):
        previous_classes = np.concatenate((previous_classes, classes[j])).astype(int)
      test_prev_dataset = CIFAR100(root='data/', classes=previous_classes,  train=False, download=True, transform=test_transform)

      #creating dataset for all classes
      #all_classes=np.concatenate((current_classes, classes[i]))
      #ho modificato questa riga mettendo al posto di current_classes, previous_classes
      all_classes = np.concatenate((previous_classes, classes[i]))
      test_all_dataset = CIFAR100(root='data/', classes=all_classes,  train=False, download=True, transform=test_transform)

      test_prev_dataloader = DataLoader(test_prev_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=4)
      test_all_dataloader = DataLoader(test_all_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=4)

    #creating dataloaders
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=4)
    val_dataloader=DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=4)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=4)
    


    net = train(net, train_dataloader,val_dataloader)
    print('Test on new classes')
    test(net, test_dataloader)

    if i!=0:
      print('Test on previous classes')
      test(net, test_prev_dataloader)
      print('Test on all classes')
      test(net, test_all_dataloader)

Files already downloaded and verified
Files already downloaded and verified
Len Train : 4000
Len Valid : 1000
Len Test : 1000
------------------------------
Epoch 1/70
Learning rate:0.01


ValueError: ignored