## Library imports

In [25]:
#standard libraries

import random
import os
import numpy as np
import time
import copy

from matplotlib.colors import ListedColormap # colormaps
import matplotlib as plt # visualizations
%matplotlib inline


#Pytorch
import torch #operations in tensors
from torch import nn #layer package and activation functions
from torch import optim # optimization package
from torch.optim import lr_scheduler #scheduler package
import torchvision
from torchvision.models import resnet18, ResNet18_Weights #pre-trained architectures


#Data Set
from torch.utils import data #create new datasets or iterate over one already created
from torchvision import datasets #preloaded datasets
from torchvision import transforms #transformations on the data after it is loaded


#pre-trained models
from torchvision import models #load different pre-trained models

from tqdm.auto import tqdm
from prettytable import PrettyTable
#from trainer import Trainer
#from callbacks import ModelCheckpoint, EarlyStopping

try:
  import torchinfo
except:
  !pip install torchinfo
from torchinfo import summary # Information of implemented architectures



Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [8]:
#pip install --user matplotlib
#pip install --user --upgrade matplotlib

Note: you may need to restart the kernel to use updated packages.


In [16]:
#pip install torch

Collecting torch
  Downloading torch-2.0.1-cp38-cp38-win_amd64.whl (172.4 MB)
Installing collected packages: torch
Successfully installed torch-2.0.1
Note: you may need to restart the kernel to use updated packages.


In [20]:
#pip install torchvision

Collecting torchvisionNote: you may need to restart the kernel to use updated packages.

  Downloading torchvision-0.15.2-cp38-cp38-win_amd64.whl (1.2 MB)
Installing collected packages: torchvision
Successfully installed torchvision-0.15.2


## Seeds for Reproductibility

In [26]:
def set_seed(seed=None, seed_torch=True):
  if seed is None:
    seed = np.random.choice(2 ** 32)
  random.seed(seed)
  np.random.seed(seed)
  if seed_torch:
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    #torch.backends.cudnn.benchmark = False
    #torch.backends.cudnn.deterministic = True
  
  print(f'Seed {seed} has been assigned.')

In [27]:
set_seed(42)

Seed 42 has been assigned.


## Device
We select the available device to perform the training and tests.

In [28]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print('Using {}'.format(device))

Using cpu


## Useful functions

### Method to visualize learning curves

In [29]:
def visualize_learning_curves(results):
    fig = plt.figure(figsize=(15,5))
    ax = plt.subplot(121)
    epochs = range(1, len(results["train_loss"])+1)
    plt.plot(epochs, results["train_loss"])
    plt.plot(epochs,results["val_loss"])
    plt.title("Loss vs epoch")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend(["train", "val"])

    ax = plt.subplot(122)
    plt.plot(epochs, results["train_acc"])
    plt.plot(epochs, results["val_acc"])
    plt.title("ACC vs epoch")
    plt.xlabel("Epochs")
    plt.ylabel("ACC")
    plt.legend(["train", "val"])
    plt.show()

### Summary of Results

In [30]:
def print_table(headers, values):
    table = PrettyTable(headers)
    for i in values:
        table.add_row(i)
    table.float_format = '.3'    
    print(table)

def add_results(name, ckp_results):
    return [name, ckp_results['train_loss'], ckp_results['val_loss'], 
                  ckp_results['train_acc'], ckp_results['val_acc'] ]

def total_num_parameters(model):
    return sum(p.numel() for p in model.parameters())

## Dataset

Download  from : https://download.pytorch.org/tutorial/hymenoptera_data.zip , and extract it to the current directory. In colab we can use !wget to download it and !unzip to extract it:



In [35]:
#!wget https://download.pytorch.org/tutorial/hymenoptera_data.zip
#!unzip hymenoptera_data.zip

'wget' is not recognized as an internal or external command,
operable program or batch file.
'unzip' is not recognized as an internal or external command,
operable program or batch file.


In [36]:
import urllib.request

url = 'https://download.pytorch.org/tutorial/hymenoptera_data.zip'
file_name = 'hymenoptera_data.zip'

urllib.request.urlretrieve(url, file_name)

('hymenoptera_data.zip', <http.client.HTTPMessage at 0x1dbec157be0>)

In [37]:
import zipfile

zip_file = 'hymenoptera_data.zip'
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
    zip_ref.extractall()