Importaciones

In [1]:
import os
import copy
import time
import pandas as pd
from torchvision.io import read_image
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url
from torch.utils.data import Dataset
from torch import autograd
import torchvision.transforms as T
import matplotlib.pyplot as plt
from torchvision.models import vgg16
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
import numpy as np
import torch.nn.utils.prune as prune
from heapq import nsmallest

Activar acceleración por hardware -> GPU

In [2]:
torch.cuda.is_available()

True

In [3]:
!nvidia-smi

Wed Nov  2 15:20:50 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 522.06       Driver Version: 522.06       CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ... WDDM  | 00000000:04:00.0 Off |                  N/A |
| 44%   34C    P8    13W /  95W |   1582MiB /  2048MiB |      7%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla M40 24GB     TCC   | 00000000:2B:00.0 Off |           1488973049 |
| N/A   77C    P0    62W / 250W |     11MiB / 23040MiB |      0%      Default |
|       

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Función de carga del dataset CUB (El dataset se descarga automáticamente al ejecutar las siguientes celdas)

In [5]:
class Cub2011(Dataset):
    base_folder = 'CUB_200_2011/images'
    url = 'https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz?download=1'
    filename = 'CUB_200_2011.tgz'
    tgz_md5 = '97eceeb196236b17998738112f37df78'

    def __init__(self, root, train=True, transform=None, loader=default_loader, download=True):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.loader = default_loader
        self.train = train

        if download:
            self._download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

    def _load_metadata(self):
        images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
                             names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'])
        train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'])

        data = images.merge(image_class_labels, on='img_id')
        self.data = data.merge(train_test_split, on='img_id')

        if self.train:
            self.data = self.data[self.data.is_training_img == 1]
        else:
            self.data = self.data[self.data.is_training_img == 0]

    def _check_integrity(self):
        try:
            self._load_metadata()
        except Exception:
            return False

        for index, row in self.data.iterrows():
            filepath = os.path.join(self.root, self.base_folder, row.filepath)
            if not os.path.isfile(filepath):
                print(filepath)
                return False
        return True

    def _download(self):
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        download_url(self.url, self.root, self.filename, self.tgz_md5)

        with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
            tar.extractall(path=self.root)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, self.base_folder, sample.filepath)
        target = sample.target - 1 
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)

        return img, target

In [6]:
transform = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [7]:
train_ds = Cub2011('.', train=True, transform = transform)
val_ds = Cub2011('.s', train=False, transform = transform)

ds = {'train': DataLoader(train_ds, batch_size = 64, shuffle=True),
      'val': DataLoader(val_ds, batch_size = 64, shuffle=False)}


ds_sizes = {'train': len(train_ds),
      'val': len(val_ds)}

Files already downloaded and verified
Files already downloaded and verified


In [8]:
def train_model(model, criterion, optimizer, scheduler, dataloaders, nclas, dataset_sizes, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_bal_acc = 0.0

    val_bal_acc = []
    val_acc = []
    val_loss = []

    train_bal_acc = []
    train_acc = []
    train_loss = []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            CF = np.zeros((nclas,nclas)) # Confusion matrix

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                for i in range(len(labels.data)):
                    CF[labels.data[i]][preds[i]] +=1

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            recalli = 0
            for i in range(nclas):
                TP = CF[i][i]
                FN = 0
                for j in range(nclas):
                    if i!=j:
                        FN+=CF[i][j]
                if (TP+FN) !=0:
                    recalli+= TP/(TP+FN)
            epoch_bal_acc = recalli/nclas

            if phase == 'val':
                val_bal_acc.append(epoch_bal_acc)
                val_acc.append(epoch_acc)
                val_loss.append(epoch_loss)
            else:
                train_bal_acc.append(epoch_bal_acc)
                train_acc.append(epoch_acc)
                train_loss.append(epoch_loss)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} Balanced Acc: {epoch_bal_acc:.4f}')

            # deep copy the model
            if phase == 'val' and epoch_bal_acc > best_bal_acc:
                best_acc = epoch_acc
                best_bal_acc = epoch_bal_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    #print(f'Best val Acc: {best_acc:4f}')
    print(f'Best val Balanced Acc: {best_bal_acc:4f}')

    print('Validation:')
    print('Val_bal_acc:', val_bal_acc)
    print('Val_acc:', val_acc)
    print('Val_loss:', val_loss)

    print('Training:')
    print('Train_bal_acc:', train_bal_acc)
    print('Train_acc:', train_acc)
    print('Train_loss:', train_loss)

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [9]:
class TopKBinarizer(autograd.Function):
    @staticmethod
    def forward(ctx, inputs: torch.tensor, threshold: float):
        mask = inputs.clone()
        _, idx = inputs.flatten().sort(descending=True)
        j = int(threshold * inputs.numel())

        flat_out = mask.flatten()
        flat_out[idx[j:]] = 0
        flat_out[idx[:j]] = 1
        return mask
    
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None



class MaskedConv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(MaskedConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        size = self.weight.size()
        #Create a mask for each filter
        mask_size = (size[0], 1, size[2], size[3])
        self.mask_scores = nn.Parameter(torch.Tensor(*mask_size))
        

    @staticmethod
    def mask_(mask_scores, threshold):
        mask = TopKBinarizer.apply(mask_scores, threshold)
        return mask

        
    def forward(self, x):
        # 10% of pruned filters
        mask = self.mask_(self.mask_scores, 0.1)
        #print(mask)
        #print(self.weight)
        self.weight.data = self.weight.data * mask
        return torch.nn.functional.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
    
    def backward(self, grad_output):
        return grad_output, None
        




Función para mostrar los parámetros de la capa

In [10]:
def check_weights(m):
    if type(m) == AlphaConv2d:
        for module in m.named_parameters():
            print(module)

Carga el modelo preentrenado (He utilizado una VGG16 para simplificar)

In [11]:
model = vgg16(weights='IMAGENET1K_V1')
model.classifier[6] = nn.Linear(4096, 200)

In [12]:
for name, module in model.named_modules():
  if type(module) == nn.Conv2d:
    new_module = AlphaConv2d(module.in_channels, module.out_channels, module.kernel_size, module.stride, module.padding, module.dilation, module.groups, True)
    new_module.weight = module.weight
    new_module.bias = module.bias
    model.features[int(name.split('.')[1])] = new_module

In [13]:
conv = [] #Conv layers
fc = [] #FC layers

for name, module in model.named_modules():
    if type(module) == AlphaConv2d:
        conv.append(module.alpha)
    elif type(module) == nn.Linear:
        fc.append(module.weight)
        fc.append(module.bias)

optimizer = torch.optim.SGD([
                {'params': conv},
                {'params': fc, 'lr': 0.005}
          ], weight_decay  = 0.005, momentum = 0.9, lr = 0.0005)

adam_optimizer = torch.optim.Adam([
                {'params': conv},
                {'params': fc, 'lr': 0.01}
            ], weight_decay  = 0.005, lr = 0.001)

criterion = nn.CrossEntropyLoss()


In [14]:

optimizer_last = torch.optim.SGD(model.classifier[6].parameters(), lr=0.005, momentum=0.9, weight_decay=0.005)
adam_optimizer_last = torch.optim.Adam(model.classifier[6].parameters(), lr=0.01, weight_decay=0.005)
model = model.to(device)
model, _ = train_model(model, criterion, optimizer_last, num_epochs=60, nclas=200)
torch.save(model, '.')

Epoch 0/59
----------
train Loss: 5.3123 Acc: 0.0028 Balanced Acc: 0.0028
val Loss: 5.2962 Acc: 0.0064 Balanced Acc: 0.0068

Epoch 1/59
----------
train Loss: 5.3097 Acc: 0.0032 Balanced Acc: 0.0032
val Loss: 5.2936 Acc: 0.0055 Balanced Acc: 0.0054

Epoch 2/59
----------
train Loss: 5.3062 Acc: 0.0038 Balanced Acc: 0.0038
val Loss: 5.2911 Acc: 0.0072 Balanced Acc: 0.0070

Epoch 3/59
----------
train Loss: 5.3053 Acc: 0.0035 Balanced Acc: 0.0035
val Loss: 5.2887 Acc: 0.0081 Balanced Acc: 0.0079

Epoch 4/59
----------
train Loss: 5.3029 Acc: 0.0030 Balanced Acc: 0.0030
val Loss: 5.2872 Acc: 0.0085 Balanced Acc: 0.0091

Epoch 5/59
----------
train Loss: 5.2992 Acc: 0.0048 Balanced Acc: 0.0048
val Loss: 5.2850 Acc: 0.0116 Balanced Acc: 0.0112

Epoch 6/59
----------
train Loss: 5.2978 Acc: 0.0043 Balanced Acc: 0.0043
val Loss: 5.2830 Acc: 0.0129 Balanced Acc: 0.0133

Epoch 7/59
----------
train Loss: 5.2956 Acc: 0.0058 Balanced Acc: 0.0058
val Loss: 5.2811 Acc: 0.0143 Balanced Acc: 0.0141



PermissionError: [Errno 13] Permission denied: '.'

In [15]:

betterAcc = True
previousAcc = 0.0


        

while betterAcc:
    # Train the factors alpha by Eq.[3]
   
    for name, module in model.named_modules():
      if type(module) == AlphaConv2d:
          for param in module.parameters():
              param.requires_grad = False
          module.alpha.requires_grad = True


    model_ft,_ = train_model(model, criterion, optimizer, num_epochs=60, nclas =200) # Set to 40 in the paper
    
    #Calculo del gradiente de los parámetros alpha
    
    alpha_grad = {}
    for inputs, labels in ds['train']:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        #Get the gradient of all alpha parameters in the vgg16 model
        for name, param in model.named_parameters():
            if 'alpha' in name:
                if name.split('.')[0]+'.'+name.split('.')[1] not in alpha_grad:
                    alpha_grad[name.split('.')[0]+'.'+name.split('.')[1]] = (param.grad/len(ds['train']))
                else:
                    alpha_grad[name.split('.')[0]+'.'+name.split('.')[1]] += (param.grad/len(ds['train']))
    
    betas = []
    for name, module in model_ft.named_modules():
        if type(module) == AlphaConv2d:
            module.alpha.data = torch.abs(alpha_grad[name] * module.alpha.data) #Transform to beta
            betas.extend(module.alpha)
    
    PERC = 0.10
    pruneVal = max(nsmallest(int(len(betas)*PERC),betas))
    
    for name, module in model_ft.named_modules():
        if type(module) == AlphaConv2d:
            mask = module.alpha > pruneVal
            print(f'Pruned {torch.sum((mask) == 0)} filters' ) 
            mask=mask.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand_as(module.weight.data)
            prune.custom_from_mask(module, 'weight', mask)
            
            

    #Fine tune the model
    for name, module in model_ft.named_modules():
        if type(module) == AlphaConv2d:
            for param in module.parameters():
                param.requires_grad = True
            module.alpha.requires_grad = False

            
    model_ft = model_ft.to(device)
    model_ft,current_acc = train_model(model_ft, criterion, optimizer, num_epochs=60, nclas =200) # Set to 40 in the paper
    if current_acc-previousAcc > 0.003: #tau = 0.3
        previousAcc = current_acc
        model = model_ft
    else:
        betterAcc = False



Epoch 0/59
----------
train Loss: 5.2342 Acc: 0.0120 Balanced Acc: 0.0121
val Loss: 5.2070 Acc: 0.0128 Balanced Acc: 0.0127

Epoch 1/59
----------
train Loss: 5.2066 Acc: 0.0123 Balanced Acc: 0.0123
val Loss: 5.1887 Acc: 0.0178 Balanced Acc: 0.0174

Epoch 2/59
----------
train Loss: 5.1886 Acc: 0.0135 Balanced Acc: 0.0135
val Loss: 5.1779 Acc: 0.0197 Balanced Acc: 0.0194

Epoch 3/59
----------
train Loss: 5.1694 Acc: 0.0150 Balanced Acc: 0.0150
val Loss: 5.1668 Acc: 0.0171 Balanced Acc: 0.0172

Epoch 4/59
----------
train Loss: 5.1464 Acc: 0.0190 Balanced Acc: 0.0190
val Loss: 5.1525 Acc: 0.0185 Balanced Acc: 0.0188

Epoch 5/59
----------
train Loss: 5.1371 Acc: 0.0184 Balanced Acc: 0.0184
val Loss: 5.1435 Acc: 0.0178 Balanced Acc: 0.0180

Epoch 6/59
----------
train Loss: 5.1140 Acc: 0.0189 Balanced Acc: 0.0189
val Loss: 5.1155 Acc: 0.0214 Balanced Acc: 0.0214

Epoch 7/59
----------
train Loss: 5.0949 Acc: 0.0177 Balanced Acc: 0.0177
val Loss: 5.1035 Acc: 0.0216 Balanced Acc: 0.0217



KeyboardInterrupt: 