Importaciones

In [1]:
import os
import copy
import time
import pandas as pd
from torchvision.io import read_image
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url
from torch.utils.data import Dataset
import torchvision.transforms as T
import matplotlib.pyplot as plt
from torchvision.models import vgg16
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
import numpy as np
import torch.nn.utils.prune as prune
from heapq import nsmallest

Activar acceleración por hardware -> GPU

In [2]:
torch.cuda.is_available()

True

In [3]:
!nvidia-smi

Thu Nov  3 16:09:21 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 522.06       Driver Version: 522.06       CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ... WDDM  | 00000000:04:00.0 Off |                  N/A |
| 45%   32C    P8    13W /  95W |   1842MiB /  2048MiB |     12%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla M40 24GB     TCC   | 00000000:2B:00.0 Off |           1488973049 |
| N/A   58C    P8    17W / 250W |     11MiB / 23040MiB |      0%      Default |
|       

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Función de carga del dataset CUB (El dataset se descarga automáticamente al ejecutar las siguientes celdas)

In [5]:
class Cub2011(Dataset):
    base_folder = 'CUB_200_2011/images'
    url = 'https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz?download=1'
    filename = 'CUB_200_2011.tgz'
    tgz_md5 = '97eceeb196236b17998738112f37df78'

    def __init__(self, root, train=True, transform=None, loader=default_loader, download=True):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.loader = default_loader
        self.train = train

        if download:
            self._download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

    def _load_metadata(self):
        images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
                             names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'])
        train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'])

        data = images.merge(image_class_labels, on='img_id')
        self.data = data.merge(train_test_split, on='img_id')

        if self.train:
            self.data = self.data[self.data.is_training_img == 1]
        else:
            self.data = self.data[self.data.is_training_img == 0]

    def _check_integrity(self):
        try:
            self._load_metadata()
        except Exception:
            return False

        for index, row in self.data.iterrows():
            filepath = os.path.join(self.root, self.base_folder, row.filepath)
            if not os.path.isfile(filepath):
                print(filepath)
                return False
        return True

    def _download(self):
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        download_url(self.url, self.root, self.filename, self.tgz_md5)

        with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
            tar.extractall(path=self.root)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, self.base_folder, sample.filepath)
        target = sample.target - 1 
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)

        return img, target

Preprocesado de las imágenes

In [6]:
transform = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

Carga del dataset (Colab no soporta un batch_size muy alto así que lo dejo en 64 para la prueba)



In [7]:
train_ds = Cub2011('.', train=True, transform = transform)
val_ds = Cub2011('.s', train=False, transform = transform)

ds = {'train': DataLoader(train_ds, batch_size = 64, shuffle=True),
      'val': DataLoader(val_ds, batch_size = 64, shuffle=False)}


ds_sizes = {'train': len(train_ds),
      'val': len(val_ds)}

Files already downloaded and verified
Files already downloaded and verified


Función de entrenamiento del modelo

In [8]:
def train_model(model, criterion, optimizer, num_epochs=40, nclas=200, patience=8):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    #best_acc = 0.0
    best_bal_acc = 0.0

    #early stopping
    best_epoch = 0
    
    epochs_bal_acc = []


    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            CF = np.zeros((nclas,nclas)) # Confusion matrix

            # Iterate over data.
            for inputs, labels in ds[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                for i in range(len(labels.data)):
                    CF[labels.data[i]][preds[i]] +=1
                
            #if phase == 'train':
            #    scheduler.step()

            epoch_loss = running_loss / ds_sizes[phase]
            epoch_acc = running_corrects.double() / ds_sizes[phase]
            recalli = 0
            for i in range(nclas):
                TP = CF[i][i]
                FN = 0
                for j in range(nclas):
                    if i!=j:
                        FN+=CF[i][j]
                if (TP+FN) !=0:
                    recalli+= TP/(TP+FN)
            epoch_bal_acc = recalli/nclas
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} Balanced Acc: {epoch_bal_acc:.4f}')
            if phase == 'val':
                epochs_bal_acc.append(epoch_bal_acc)
            # deep copy the model
            if phase == 'val' and epoch_bal_acc > best_bal_acc:
                best_bal_acc = epoch_bal_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                best_epoch = epoch
            
            if phase == 'val' and epoch - best_epoch > patience:
                print('Early stopping')
                break

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Balanced Acc: {best_bal_acc:4f}')
    print(epochs_bal_acc)

    # load best model weights
    #model.load_state_dict(best_model_wts)
    return model, best_bal_acc

In [9]:
class AlphaConv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(AlphaConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        #One alpha per filter
        #self.alpha = nn.Parameter(torch.ones(out_channels, 1, 1, 1))
        self.alpha = nn.Parameter(torch.ones(out_channels))
        
    def forward(self, x):
        #return super(AlphaConv2d, self).forward(x) * self.alpha
        return super(AlphaConv2d, self).forward(x) * self.alpha.unsqueeze(1).unsqueeze(2)

Función para mostrar los parámetros de la capa

In [10]:
def check_weights(m):
    if type(m) == AlphaConv2d:
        for module in m.named_parameters():
            print(module)

Carga el modelo preentrenado (He utilizado una VGG16 para simplificar)

In [11]:
model = vgg16(weights='IMAGENET1K_V1')
model.classifier[6] = nn.Linear(4096, 200)

In [12]:
for name, module in model.named_modules():
  if type(module) == nn.Conv2d:
    new_module = AlphaConv2d(module.in_channels, module.out_channels, module.kernel_size, module.stride, module.padding, module.dilation, module.groups, True)
    new_module.weight = module.weight
    new_module.bias = module.bias
    model.features[int(name.split('.')[1])] = new_module

In [13]:
conv = [] #Conv layers
fc = [] #FC layers

for name, module in model.named_modules():
    if type(module) == AlphaConv2d:
        conv.append(module.alpha)
    elif type(module) == nn.Linear:
        fc.append(module.weight)
        fc.append(module.bias)

optimizer = torch.optim.SGD([
                {'params': conv},
                {'params': fc, 'lr': 0.005}
          ], weight_decay  = 0.005, momentum = 0.9, lr = 0.0005)

adam_optimizer = torch.optim.Adam([
                {'params': conv},
                {'params': fc, 'lr': 0.01}
            ], weight_decay  = 0.005, lr = 0.001)

criterion = nn.CrossEntropyLoss()


In [14]:

optimizer_last = torch.optim.SGD(model.classifier[6].parameters(), lr=0.005, momentum=0.9, weight_decay=0.005)
adam_optimizer_last = torch.optim.Adam(model.classifier[6].parameters(), lr=0.01, weight_decay=0.005)
model = model.to(device)
model, _ = train_model(model, criterion, optimizer_last, num_epochs=20, nclas=200)

Epoch 0/19
----------
train Loss: 4.0674 Acc: 0.1461 Balanced Acc: 0.1461
val Loss: 2.9069 Acc: 0.3407 Balanced Acc: 0.3452

Epoch 1/19
----------
train Loss: 2.7207 Acc: 0.3458 Balanced Acc: 0.3458
val Loss: 2.4438 Acc: 0.4068 Balanced Acc: 0.4120

Epoch 2/19
----------
train Loss: 2.4293 Acc: 0.4002 Balanced Acc: 0.4002
val Loss: 2.2485 Acc: 0.4474 Balanced Acc: 0.4524

Epoch 3/19
----------
train Loss: 2.2676 Acc: 0.4369 Balanced Acc: 0.4370
val Loss: 2.1705 Acc: 0.4512 Balanced Acc: 0.4553

Epoch 4/19
----------
train Loss: 2.1809 Acc: 0.4543 Balanced Acc: 0.4542
val Loss: 2.1448 Acc: 0.4569 Balanced Acc: 0.4599

Epoch 5/19
----------
train Loss: 2.0794 Acc: 0.4623 Balanced Acc: 0.4623
val Loss: 2.0311 Acc: 0.4824 Balanced Acc: 0.4873

Epoch 6/19
----------
train Loss: 2.0231 Acc: 0.4795 Balanced Acc: 0.4795
val Loss: 2.0778 Acc: 0.4700 Balanced Acc: 0.4735

Epoch 7/19
----------
train Loss: 1.9775 Acc: 0.4887 Balanced Acc: 0.4886
val Loss: 2.0324 Acc: 0.4860 Balanced Acc: 0.4880



In [15]:

betterAcc = True
previousAcc = 0.0

#Lock non alphaconv2d modules grad

for name, module in model.named_modules():
  if type(module) != AlphaConv2d:
      for param in module.parameters():
            if(param.requires_grad):
                param.requires_grad = False
        

while betterAcc:
    # Train the factors alpha by Eq.[3]
   
    for name, module in model.named_modules():
      if type(module) == AlphaConv2d:
          for param in module.parameters():
              param.requires_grad = False
          module.alpha.requires_grad = True


    model_ft,_ = train_model(model, criterion, optimizer, num_epochs=20, nclas =200) # Set to 40 in the paper
    
    #Calculo del gradiente de los parámetros alpha
    
    alpha_grad = {}
    for inputs, labels in ds['train']:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        #Get the gradient of all alpha parameters in the vgg16 model
        for name, param in model.named_parameters():
            if 'alpha' in name:
                if name.split('.')[0]+'.'+name.split('.')[1] not in alpha_grad:
                    alpha_grad[name.split('.')[0]+'.'+name.split('.')[1]] = (param.grad/len(ds['train']))
                else:
                    alpha_grad[name.split('.')[0]+'.'+name.split('.')[1]] += (param.grad/len(ds['train']))
    
    betas = []
    
    for name, module in model_ft.named_modules():
        if type(module) == AlphaConv2d:
            module.alpha.data = torch.abs(alpha_grad[name] * module.alpha.data) #Transform to beta
            betas.extend(module.alpha)
    
    PERC = 0.10
    pruneVal = max(nsmallest(int(len(betas)*PERC),betas))
    print("Betas length:")
    print(len(betas))
    
    for name, module in model_ft.named_modules():
        if type(module) == AlphaConv2d:
            mask = module.alpha > pruneVal
            print(f'Pruned {torch.sum((mask) == 0)} filters' ) 
            #ToDo: poner las betas podadas a 0
            mask=mask.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand_as(module.weight.data)
            prune.custom_from_mask(module, 'weight', mask)
            
            

    #Fine tune the model
    for name, module in model_ft.named_modules():
        if type(module) == AlphaConv2d:
            for param in module.parameters():
                param.requires_grad = True
            module.alpha.requires_grad = False

            
    model_ft,current_acc = train_model(model_ft, criterion, optimizer, num_epochs=40, nclas =200) # Set to 40 in the paper
    if current_acc-previousAcc > 0.3: #tau = 0.3
        previousAcc = current_acc
        model = model_ft
    else:
        betterAcc = False
    break



Epoch 0/19
----------
train Loss: 1.7326 Acc: 0.5464 Balanced Acc: 0.5465
val Loss: 1.9868 Acc: 0.4953 Balanced Acc: 0.4989

Epoch 1/19
----------
train Loss: 1.6641 Acc: 0.5689 Balanced Acc: 0.5690
val Loss: 1.9479 Acc: 0.5071 Balanced Acc: 0.5109

Epoch 2/19
----------
train Loss: 1.6862 Acc: 0.5701 Balanced Acc: 0.5701
val Loss: 1.9867 Acc: 0.5038 Balanced Acc: 0.5076

Epoch 3/19
----------
train Loss: 1.7058 Acc: 0.5697 Balanced Acc: 0.5699
val Loss: 2.0289 Acc: 0.4950 Balanced Acc: 0.4998

Epoch 4/19
----------
train Loss: 1.7204 Acc: 0.5697 Balanced Acc: 0.5699
val Loss: 2.0282 Acc: 0.4986 Balanced Acc: 0.5037

Epoch 5/19
----------
train Loss: 1.7580 Acc: 0.5714 Balanced Acc: 0.5715
val Loss: 2.0414 Acc: 0.5002 Balanced Acc: 0.5048

Epoch 6/19
----------
train Loss: 1.7827 Acc: 0.5696 Balanced Acc: 0.5697
val Loss: 2.0689 Acc: 0.5050 Balanced Acc: 0.5077

Epoch 7/19
----------
train Loss: 1.8077 Acc: 0.5657 Balanced Acc: 0.5659
val Loss: 2.0865 Acc: 0.5036 Balanced Acc: 0.5090



KeyboardInterrupt: 