Importaciones

In [None]:
import os
import copy
import time
import pandas as pd
from torchvision.io import read_image
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url
from torch.utils.data import Dataset
import torchvision.transforms as T
import matplotlib.pyplot as plt
from torchvision.models import vgg16
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
import numpy as np
import torch.nn.utils.prune as prune
from heapq import nsmallest

Activar aceleración por hardware -> GPU

In [None]:
# Sanity check: displays True when PyTorch can use a CUDA device in this runtime.
torch.cuda.is_available()

In [None]:
# Shell out to nvidia-smi to show the GPU status report (only works on an NVIDIA runtime).
!nvidia-smi

In [None]:
# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Función de carga del dataset CUB (El dataset se descarga automáticamente al ejecutar las siguientes celdas)

In [None]:
class Cub2011(Dataset):
    """CUB-200-2011 image dataset backed by the official metadata text files.

    The archive is expected (or downloaded) under ``root`` and must unpack to a
    ``CUB_200_2011`` directory containing ``images.txt``,
    ``image_class_labels.txt``, ``train_test_split.txt`` and an ``images``
    folder.

    Args:
        root: Directory that holds (or will hold) the extracted dataset.
            ``~`` is expanded.
        train: If True, keep the rows flagged ``is_training_img == 1``;
            otherwise keep the rows flagged ``0``.
        transform: Optional callable applied to every loaded image.
        loader: Callable mapping a file path to an image object.
        download: If True, download and extract the archive when the local
            copy fails the integrity check.

    Raises:
        RuntimeError: If the metadata cannot be read or any listed image file
            is missing after the optional download.
    """
    base_folder = 'CUB_200_2011/images'
    url = 'https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz?download=1'
    filename = 'CUB_200_2011.tgz'
    tgz_md5 = '97eceeb196236b17998738112f37df78'

    def __init__(self, root, train=True, transform=None, loader=default_loader, download=True):
        self.root = os.path.expanduser(root)
        self.transform = transform
        # Use the loader the caller passed in; it used to be silently replaced
        # by default_loader, making the argument a no-op.
        self.loader = loader
        self.train = train

        if download:
            self._download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

    def _load_metadata(self):
        """Read the three metadata files and keep only the requested split in ``self.data``."""
        meta_dir = os.path.join(self.root, 'CUB_200_2011')
        images = pd.read_csv(os.path.join(meta_dir, 'images.txt'), sep=' ',
                             names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(meta_dir, 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'])
        train_test_split = pd.read_csv(os.path.join(meta_dir, 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'])

        # Join everything on the image id so each row carries path, label and split flag.
        data = images.merge(image_class_labels, on='img_id')
        self.data = data.merge(train_test_split, on='img_id')

        if self.train:
            self.data = self.data[self.data.is_training_img == 1]
        else:
            self.data = self.data[self.data.is_training_img == 0]

    def _check_integrity(self):
        """Return True when the metadata loads and every listed image file exists.

        The first missing path is printed to help diagnose partial downloads.
        """
        try:
            self._load_metadata()
        except Exception:
            # Missing or unreadable metadata simply means "not available yet".
            return False

        image_root = os.path.join(self.root, self.base_folder)
        # Iterating the column directly avoids the per-row Series built by iterrows().
        for relative_path in self.data['filepath']:
            filepath = os.path.join(image_root, relative_path)
            if not os.path.isfile(filepath):
                print(filepath)
                return False
        return True

    def _download(self):
        """Download and extract the archive unless a valid local copy already exists."""
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        download_url(self.url, self.root, self.filename, self.tgz_md5)

        # NOTE(review): extractall trusts the member paths inside the archive; the
        # archive is MD5-checked by download_url above, but consider a tarfile
        # extraction filter if the source ever changes.
        with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
            tar.extractall(path=self.root)

    def __len__(self):
        """Number of images in the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the ``idx``-th row of the selected split.

        ``target`` is the label from the metadata minus one (presumably the
        file uses 1-based class ids, giving 0-based targets -- confirm).
        """
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, self.base_folder, sample.filepath)
        target = sample.target - 1
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)

        return img, target

Preprocesado de las imágenes

In [None]:
# Image preprocessing pipeline: random 224-pixel crop, random horizontal flip,
# conversion to a tensor, then per-channel normalisation (the mean/std values are
# presumably the ImageNet statistics expected by the pretrained weights).
preprocessing_steps = [
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = T.Compose(preprocessing_steps)

Carga del dataset (Colab no soporta un batch_size muy alto, así que lo dejo en 64 para la prueba)



In [None]:
# Validation must be deterministic: random crops/flips would make the validation
# metric (used for early stopping and for the pruning accept/reject decision)
# vary from run to run. Use a fixed resize + centre crop with the same
# normalisation as the training pipeline.
val_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Both splits live under the same root; the previous '.s' root for the validation
# split was a typo that forced a second download into a different directory.
train_ds = Cub2011('.', train=True, transform=transform)
val_ds = Cub2011('.', train=False, transform=val_transform)

# Only the training loader is shuffled; validation order is kept fixed.
ds = {'train': DataLoader(train_ds, batch_size=64, shuffle=True),
      'val': DataLoader(val_ds, batch_size=64, shuffle=False)}

# Sample counts per split, used to average loss/accuracy over an epoch.
ds_sizes = {'train': len(train_ds),
            'val': len(val_ds)}

Función de entrenamiento del modelo

In [None]:
def train_model(model, criterion, optimizer, num_epochs=40, nclas=200, patience=8):
    """Train ``model`` with a train/val phase per epoch and early stopping.

    Relies on the notebook-level ``ds`` (dict of DataLoaders keyed by
    ``'train'``/``'val'``), ``ds_sizes`` (sample count per split) and
    ``device``.

    Args:
        model: Network to optimise (already on ``device``).
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer stepping the parameters to train.
        num_epochs: Maximum number of epochs.
        nclas: Number of classes; size of the confusion matrix and divisor of
            the balanced accuracy.
        patience: Training stops once more than ``patience`` epochs have
            passed since the best validation balanced accuracy.

    Returns:
        ``(model, best_bal_acc)`` where ``model`` carries the weights of the
        epoch with the best validation balanced accuracy and ``best_bal_acc``
        is that score.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_bal_acc = 0.0
    best_epoch = 0          # epoch index of the best validation score so far
    epochs_bal_acc = []     # validation balanced accuracy per epoch
    stop_early = False

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0
            CF = np.zeros((nclas, nclas))  # confusion matrix: rows = true, cols = predicted

            for inputs, labels in ds[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Track gradients only while training.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Statistics: move the small index tensors to the CPU once per
                # batch instead of indexing the NumPy matrix with device tensors
                # one sample at a time.
                running_loss += loss.item() * inputs.size(0)
                labels_np = labels.detach().cpu().numpy()
                preds_np = preds.detach().cpu().numpy()
                running_corrects += int((labels_np == preds_np).sum())
                # add.at accumulates correctly even when (label, pred) pairs repeat.
                np.add.at(CF, (labels_np, preds_np), 1)

            epoch_loss = running_loss / ds_sizes[phase]
            epoch_acc = running_corrects / ds_sizes[phase]

            # Balanced accuracy: per-class recall summed over the classes that
            # appeared, divided by nclas (absent classes contribute 0).
            row_totals = CF.sum(axis=1)
            seen = row_totals > 0
            epoch_bal_acc = float((np.diag(CF)[seen] / row_totals[seen]).sum() / nclas)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} Balanced Acc: {epoch_bal_acc:.4f}')

            if phase == 'val':
                epochs_bal_acc.append(epoch_bal_acc)
                if epoch_bal_acc > best_bal_acc:
                    # New best: remember the score, the epoch and a copy of the weights.
                    best_bal_acc = epoch_bal_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                    best_epoch = epoch
                elif epoch - best_epoch > patience:
                    print('Early stopping')
                    stop_early = True

        print()

        # A bare `break` inside the phase loop only left that loop, so training
        # used to continue; leave the epoch loop explicitly instead.
        if stop_early:
            break

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Balanced Acc: {best_bal_acc:.4f}')
    print(epochs_bal_acc)

    # Restore the best weights so the returned model matches the returned score.
    model.load_state_dict(best_model_wts)
    return model, best_bal_acc

In [None]:
class AlphaConv2d(nn.Conv2d):
    """Conv2d whose output channels are each multiplied by a learnable factor.

    ``alpha`` holds one scalar per output channel (filter), initialised
    uniformly in [0, 1). The forward pass is the regular convolution followed
    by a per-channel scaling with ``alpha``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # One scaling factor per filter.
        self.alpha = nn.Parameter(torch.rand(out_channels))

    def forward(self, x):
        """Convolve ``x`` and scale every output channel by its alpha."""
        conv_out = super().forward(x)
        # Shape (C, 1, 1) broadcasts over the trailing (C, H, W) dims of the output.
        channel_scale = self.alpha[:, None, None]
        return conv_out * channel_scale

Función para mostrar los parámetros de la capa

In [None]:
def check_weights(m):
    """Print each ``(name, parameter)`` pair of ``m`` if it is exactly an AlphaConv2d.

    Modules of any other type (including subclasses) are ignored.
    """
    if type(m) is not AlphaConv2d:
        return
    for named_param in m.named_parameters():
        print(named_param)

Carga el modelo preentrenado (He utilizado una VGG16 para simplificar)

In [None]:
# Pretrained VGG16; the last classifier layer is replaced by a fresh
# 200-output linear head (the original layer's input width is kept).
model = vgg16(weights='IMAGENET1K_V1')
head_in_features = model.classifier[6].in_features
model.classifier[6] = nn.Linear(head_in_features, 200)

In [None]:
# Swap every plain Conv2d in `model.features` for an AlphaConv2d that shares the
# pretrained weight and bias tensors.
# - Iterate over a snapshot so the container is not modified while a live
#   module generator is walking it.
# - Use the positional index directly instead of parsing it out of the dotted
#   module name.
# - Only request a bias parameter when the source layer actually has one.
for layer_idx, module in list(enumerate(model.features)):
    if type(module) == nn.Conv2d:
        new_module = AlphaConv2d(module.in_channels, module.out_channels, module.kernel_size,
                                 module.stride, module.padding, module.dilation, module.groups,
                                 module.bias is not None)
        # Reuse the pretrained parameters rather than the fresh random init.
        new_module.weight = module.weight
        new_module.bias = module.bias
        model.features[layer_idx] = new_module

In [None]:
# Parameter groups: the alpha factor of every AlphaConv2d, and the weight and
# bias of every fully connected layer (in module order).
conv = [m.alpha for m in model.modules() if type(m) is AlphaConv2d]
fc = [p for m in model.modules() if type(m) is nn.Linear for p in (m.weight, m.bias)]

# SGD variant: alphas use the base learning rate, the FC layers a larger one.
sgd_param_groups = [
    {'params': conv},
    {'params': fc, 'lr': 0.005},
]
optimizer = torch.optim.SGD(sgd_param_groups, lr=0.0005, momentum=0.9, weight_decay=0.005)

# Adam variant over the same two groups.
adam_param_groups = [
    {'params': conv},
    {'params': fc, 'lr': 0.01},
]
adam_optimizer = torch.optim.Adam(adam_param_groups, lr=0.001, weight_decay=0.005)

criterion = nn.CrossEntropyLoss()


In [None]:

# Stage 1: train only the new 200-way classification head.
optimizer_last = torch.optim.SGD(model.classifier[6].parameters(), lr=0.005, momentum=0.9, weight_decay=0.005)
adam_optimizer_last = torch.optim.Adam(model.classifier[6].parameters(), lr=0.01, weight_decay=0.005)
model = model.to(device)
model, _ = train_model(model, criterion, optimizer_last, num_epochs=60, nclas=200)

# torch.save needs a file path: the previous target '.' is a directory, so the
# call failed after the whole training run. Save the state_dict, which is the
# portable way to checkpoint a model.
torch.save(model.state_dict(), 'vgg16_alpha_cub200_head.pth')

In [None]:

# Iterative pruning: (1) train the alpha factors, (2) score each filter with
# beta = |dL/dalpha * alpha|, (3) mask the lowest-scoring filters, (4) fine-tune,
# and repeat while the validation balanced accuracy keeps improving.
PERC = 0.10          # fraction of all filters masked per round
MIN_ACC_GAIN = 0.003  # tau = 0.3 %: minimum gain required to run another round

betterAcc = True
previousAcc = 0.0

while betterAcc:
    # --- 1. Train the alpha factors (conv weights and biases frozen) ---
    for name, module in model.named_modules():
        if type(module) == AlphaConv2d:
            for param in module.parameters():
                param.requires_grad = False
            module.alpha.requires_grad = True

    # NOTE: train_model works in place, so model_ft is the same object as model;
    # a round that is rejected below is therefore not rolled back.
    model_ft, _ = train_model(model, criterion, optimizer, num_epochs=60, nclas=200)  # Set to 40 in the paper

    # --- 2. Average gradient of the loss w.r.t. every alpha over the training set ---
    num_batches = len(ds['train'])
    alpha_grad = {}
    for inputs, labels in ds['train']:
        inputs = inputs.to(device)
        labels = labels.to(device)
        # Clear gradients first: backward() accumulates, so without this each
        # batch would add the running sum of all previous batches again.
        model_ft.zero_grad()
        loss = criterion(model_ft(inputs), labels)
        loss.backward()
        for name, module in model_ft.named_modules():
            if type(module) == AlphaConv2d:
                batch_grad = module.alpha.grad.detach() / num_batches
                if name in alpha_grad:
                    alpha_grad[name] += batch_grad
                else:
                    alpha_grad[name] = batch_grad
    model_ft.zero_grad()

    # --- 3. Importance score per filter ---
    # Keep the scores separate from alpha: overwriting alpha with beta would
    # change the scaling used in every subsequent forward pass.
    betas = {}
    for name, module in model_ft.named_modules():
        if type(module) == AlphaConv2d:
            betas[name] = torch.abs(alpha_grad[name] * module.alpha.detach())

    all_betas = torch.cat(list(betas.values()))
    num_to_prune = int(all_betas.numel() * PERC)
    if num_to_prune == 0:
        # Too few filters for the requested fraction; nothing left to prune.
        break
    # Threshold = the num_to_prune-th smallest score across all layers.
    pruneVal = torch.sort(all_betas).values[num_to_prune - 1]

    # --- 4. Mask every filter whose score is at or below the threshold ---
    for name, module in model_ft.named_modules():
        if type(module) == AlphaConv2d:
            keep = betas[name] > pruneVal
            print(f'Pruned {int((~keep).sum())} filters')
            # Broadcast the per-filter decision over the whole filter kernel.
            mask = keep.view(-1, 1, 1, 1).expand_as(module.weight)
            prune.custom_from_mask(module, 'weight', mask)

    # --- 5. Fine-tune conv weights and FC layers with the alphas frozen ---
    for name, module in model_ft.named_modules():
        if type(module) == AlphaConv2d:
            for param in module.parameters():
                param.requires_grad = True
            module.alpha.requires_grad = False

    # `optimizer` only holds the alphas and the FC parameters, so reusing it here
    # would never update the conv weights. Build an optimizer over the parameters
    # that exist after pruning, with the same hyper-parameters as `optimizer`.
    conv_params = [p for m in model_ft.modules() if type(m) == AlphaConv2d
                   for p in m.parameters() if p.requires_grad]
    fc_params = [p for m in model_ft.modules() if type(m) == nn.Linear
                 for p in m.parameters()]
    finetune_optimizer = torch.optim.SGD([
                    {'params': conv_params},
                    {'params': fc_params, 'lr': 0.005}
              ], weight_decay=0.005, momentum=0.9, lr=0.0005)

    model_ft = model_ft.to(device)
    model_ft, current_acc = train_model(model_ft, criterion, finetune_optimizer, num_epochs=60, nclas=200)  # Set to 40 in the paper
    if current_acc - previousAcc > MIN_ACC_GAIN:
        previousAcc = current_acc
        model = model_ft
    else:
        betterAcc = False

