Importaciones

In [None]:
import os
import copy
import time
import pandas as pd
from torchvision.io import read_image
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url
from torch.utils.data import Dataset
from torch import autograd
import torchvision.transforms as T
import matplotlib.pyplot as plt
from torchvision.models import vgg16
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import torch.nn.utils.prune as prune
from heapq import nsmallest
import torch.optim.lr_scheduler as lr_scheduler
import torch.optim as optim

In [None]:
import nni
from nni.compression.pytorch.pruning import MovementPruner
from nni.compression.pytorch import TorchEvaluator

In [None]:
# Mini-batch size shared by the training and validation DataLoaders.
BATCH_SIZE = 32

In [None]:
# Sanity check: the cell output shows whether PyTorch can see a CUDA device.
torch.cuda.is_available()

In [None]:
# Free GPU memory that PyTorch has cached but is not currently using.
torch.cuda.empty_cache() 

In [None]:
# Shell escape: print the GPU status reported by the nvidia-smi tool.
!nvidia-smi

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Clase de carga del dataset CUB-200-2011 (el dataset se descarga automáticamente al ejecutar las siguientes celdas)

In [None]:
class Cub2011(Dataset):
    """Dataset over the CUB_200_2011 image folder.

    Reads ``images.txt``, ``image_class_labels.txt`` and
    ``train_test_split.txt`` from ``<root>/CUB_200_2011`` and serves
    ``(image, target)`` pairs for the requested split. Targets are the label
    from ``image_class_labels.txt`` minus 1.

    Args:
        root: Directory that contains (or will receive) ``CUB_200_2011``.
            ``~`` is expanded.
        train: Keep rows with ``is_training_img == 1`` when True, and rows
            with ``is_training_img == 0`` otherwise.
        transform: Optional callable applied to every loaded image.
        loader: Callable that turns a file path into an image.
        download: When True, download and extract the archive unless the
            files are already present.

    Raises:
        RuntimeError: If the metadata cannot be read or a listed image file
            is missing (after the optional download).
    """

    base_folder = 'CUB_200_2011/images'
    url = 'https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz?download=1'
    filename = 'CUB_200_2011.tgz'
    tgz_md5 = '97eceeb196236b17998738112f37df78'

    def __init__(self, root, train=True, transform=None, loader=default_loader, download=True):
        self.root = os.path.expanduser(root)
        self.transform = transform
        # Honor the caller-supplied loader; it used to be overwritten with
        # default_loader, which made the argument a no-op.
        self.loader = loader
        self.train = train

        if download:
            self._download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

    def _load_metadata(self):
        """Build ``self.data``: one row per image of the selected split."""
        meta_dir = os.path.join(self.root, 'CUB_200_2011')
        images = pd.read_csv(os.path.join(meta_dir, 'images.txt'), sep=' ',
                             names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(meta_dir, 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'])
        train_test_split = pd.read_csv(os.path.join(meta_dir, 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'])

        data = images.merge(image_class_labels, on='img_id')
        data = data.merge(train_test_split, on='img_id')

        wanted_flag = 1 if self.train else 0
        self.data = data[data.is_training_img == wanted_flag]

    def _check_integrity(self):
        """Return True when the metadata loads and every listed image exists.

        Any failure while reading the metadata is treated as "not present"
        so that the caller can decide to download. The first missing image
        path is printed before returning False.
        """
        try:
            self._load_metadata()
        except Exception:
            return False

        for rel_path in self.data.filepath:
            filepath = os.path.join(self.root, self.base_folder, rel_path)
            if not os.path.isfile(filepath):
                print(filepath)
                return False
        return True

    def _download(self):
        """Download and extract the archive into ``self.root`` if needed."""
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        # The expected MD5 is handed to download_url together with the URL.
        download_url(self.url, self.root, self.filename, self.tgz_md5)

        with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
            tar.extractall(path=self.root)

    def __len__(self):
        """Number of images in the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the row at position ``idx``."""
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, self.base_folder, sample.filepath)
        target = sample.target - 1  # shift the file's labels down by one
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)

        return img, target

In [None]:
# Image preprocessing pipeline: random resized crop to 224, random horizontal
# flip, tensor conversion, then per-channel normalization.
transform = T.Compose(
    [
        T.RandomResizedCrop(size=224),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# Validation must be deterministic: resize + center crop replaces the random
# crop/flip augmentation, with the same normalization as `transform`.
val_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_ds = Cub2011('.', train=True, transform=transform)
# Same root as the training split. The previous '.s' root pointed at a
# different directory, so the archive was downloaded and extracted twice.
val_ds = Cub2011('.', train=False, transform=val_transform)

# Only the training loader shuffles; validation order stays fixed.
ds = {'train': DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True),
      'val': DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)}

# Number of samples per split.
ds_sizes = {'train': len(train_ds),
            'val': len(val_ds)}

Función para mostrar los parámetros de la capa

Carga el modelo preentrenado (He utilizado una VGG16 para simplificar)

In [None]:
# VGG16 loaded with the 'IMAGENET1K_V1' weights; the last classifier layer is
# replaced by a new linear layer with 200 outputs.
model = vgg16(weights='IMAGENET1K_V1')
model.classifier[6] = nn.Linear(in_features=4096, out_features=200)

In [None]:
# Move the model to the device selected earlier in the notebook.
model = model.to(device)

In [None]:
# Per-epoch metric history. training_model appends one plain number per epoch
# to each list, so the lists can be plotted or serialized directly.
val_bal_acc = []
val_acc = []
val_loss = []

train_bal_acc = []
train_acc = []
train_loss = []
NCLAS = 200  # number of classes; side length of the confusion matrix


def _balanced_accuracy(confusion):
    """Return the sum of per-class recalls divided by the number of classes.

    Rows of ``confusion`` are true labels and columns are predictions.
    Classes without any sample add nothing to the sum but still count in the
    denominator.
    """
    support = confusion.sum(axis=1)
    present = support > 0
    recalls = np.diag(confusion)[present] / support[present]
    return float(recalls.sum() / confusion.shape[0])


def training_model(model, optimizer, criterion, lr_scheduler, max_steps, max_epochs, *args, **kwargs):
    """Train and validate ``model``, recording metrics in the module-level lists.

    Each epoch runs a 'train' pass and a 'val' pass over the loaders in the
    global ``ds`` dict, on the global ``device``. After every pass the loss,
    accuracy and balanced accuracy are appended to the matching ``train_*`` /
    ``val_*`` lists as plain floats, and the accuracy is printed.

    Args:
        model: Module to optimize.
        optimizer: Optimizer stepped once per training batch.
        criterion: Loss called as ``criterion(output, labels)``.
        lr_scheduler: Scheduler stepped once after each training pass.
        max_steps: Optional cap on the total number of ``optimizer.step()``
            calls; ``None`` means no cap.
        max_epochs: Optional cap on the number of epochs; ``None`` means no
            cap.
        *args, **kwargs: Accepted and ignored.

    Raises:
        ValueError: If both ``max_steps`` and ``max_epochs`` are ``None``,
            since the loop would have no stopping condition.
    """
    if max_steps is None and max_epochs is None:
        raise ValueError('training_model needs max_steps and/or max_epochs to know when to stop')

    steps_done = 0
    epoch = 0
    while (max_epochs is None or epoch < max_epochs) and (max_steps is None or steps_done < max_steps):
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train(False) is the same as eval()

            running_loss = 0.0
            running_corrects = 0
            samples_seen = 0
            CF = np.zeros((NCLAS, NCLAS))  # Confusion matrix: rows = truth, cols = prediction

            for inputs, labels in ds[phase]:
                # Stop mid-epoch once the step budget is used up.
                if is_train and max_steps is not None and steps_done >= max_steps:
                    break

                inputs, labels = inputs.to(device), labels.to(device)
                if is_train:
                    optimizer.zero_grad()
                with torch.set_grad_enabled(is_train):
                    output = model(inputs)
                    _, preds = torch.max(output, 1)
                    loss = criterion(output, labels)
                if is_train:
                    loss.backward()
                    optimizer.step()
                    steps_done += 1

                # Bring the batch results to the host once, then do all the
                # bookkeeping in NumPy instead of per-sample tensor indexing.
                labels_np = labels.detach().cpu().numpy()
                preds_np = preds.detach().cpu().numpy()
                batch_count = inputs.size(0)

                running_loss += loss.item() * batch_count
                running_corrects += int((labels_np == preds_np).sum())
                samples_seen += batch_count
                # add.at accumulates correctly when a (label, pred) pair
                # repeats within the batch.
                np.add.at(CF, (labels_np, preds_np), 1)

            if is_train:
                lr_scheduler.step()

            # Normalize by the samples actually processed, which equals the
            # split size for a full pass and stays correct for a truncated one.
            denom = max(samples_seen, 1)
            epoch_loss = running_loss / denom
            epoch_acc = running_corrects / denom
            epoch_bal_acc = _balanced_accuracy(CF)

            if phase == 'val':
                val_bal_acc.append(epoch_bal_acc)
                val_acc.append(epoch_acc)
                val_loss.append(epoch_loss)
                print(f'Val Acc: {epoch_acc:.4f}')
            else:
                train_bal_acc.append(epoch_bal_acc)
                train_acc.append(epoch_acc)
                train_loss.append(epoch_loss)
                print(f'Train Acc: {epoch_acc:.4f}')

        epoch += 1


In [None]:

# --- Movement pruning setup -------------------------------------------------
# The optimizer and scheduler classes are wrapped with nni.trace before being
# handed to TorchEvaluator (see the NNI evaluator docs for why).
traced_optimizer = nni.trace(optim.SGD)(model.parameters(), lr=0.001, momentum=0.9)
# Prune every Conv2d and Linear layer to 20% sparsity.
config_list = [{'op_types': ['Conv2d', 'Linear'],
                'sparsity_per_layer': 0.2}]
criterion = nn.CrossEntropyLoss()
# NOTE: this name shadows the `torch.optim.lr_scheduler` module alias imported
# at the top of the notebook; the fully qualified path is used here on purpose.
lr_scheduler = nni.trace(torch.optim.lr_scheduler.StepLR)(traced_optimizer, step_size=7, gamma=0.1)
evaluator = TorchEvaluator(training_func=training_model, optimizers=traced_optimizer,
                           criterion=criterion, lr_schedulers=lr_scheduler)

# warm_up_step – The total optimizer.step() number before start pruning for warm up. Make sure warm_up_step is smaller than cool_down_beginning_step.
# cool_down_beginning_step – The number of steps at which sparsity stops growing, note that the sparsity stop growing doesn’t mean masks not changed.
# One optimizer step is taken per training batch, so the number of steps per
# epoch is the number of batches in the training loader. The loader keeps the
# last partial batch, which `len(train_ds) // BATCH_SIZE` would not count.
steps_per_epoch = len(ds['train'])
WARM_UP_EPOCHS = 6
COOL_DOWN_BEGIN_EPOCH = 8
TRAINING_EPOCHS = 70
warm_up_step = steps_per_epoch * WARM_UP_EPOCHS
cool_down_begin_step = steps_per_epoch * COOL_DOWN_BEGIN_EPOCH
pruner = MovementPruner(model, config_list, evaluator,
                        warm_up_step=warm_up_step,
                        cool_down_beginning_step=cool_down_begin_step,
                        training_epochs=TRAINING_EPOCHS,
                        movement_mode='hard')
# compress() returns a pair; only the masks are kept here.
_, masks = pruner.compress()