# 1. Import modules (or libraries)

In [1]:
# from notebook.services.config import ConfigManager
# cm = ConfigManager()
# cm.update('livereveal', {
#         'width': "90%",
#         'height': "90%",
#         'scroll': True,
# })

In [2]:
import sys
import os
import numpy as np
import gc
from datetime import datetime

In [None]:
import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn
from datetime import datetime
from tqdm import tqdm

# 2. Define the device for training

In [None]:
# Select the training device: the first CUDA GPU when torch reports one is
# available, otherwise the CPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
DEVICE  # bare last expression so the notebook displays the chosen device

# 3. Define model
![alt text](../figures/vgg16-class.png)

In [None]:
def getVGGModel(num_classes=3):
  """Build an ImageNet-pretrained VGG16 (batch-norm variant) for fine-tuning.

  The convolutional feature extractor is frozen and the stock classifier is
  replaced by a small two-layer head, so only the new head receives gradient
  updates.

  Args:
    num_classes: Number of output logits of the new head. Defaults to 3,
      which is the value this notebook was originally written for.

  Returns:
    The modified torchvision VGG module (still on the CPU; the caller moves
    it to the target device).
  """
  vgg16 = models.vgg16_bn(weights=models.vgg.VGG16_BN_Weights.IMAGENET1K_V1)

  # Freeze the conv layers. The attribute must be spelled `requires_grad`;
  # assigning to any other name only creates an unused Python attribute and
  # leaves the weights trainable.
  # NOTE(review): the BatchNorm layers inside `features` presumably still
  # update their running statistics while the model is in train mode — confirm
  # whether that is intended.
  for conv_param in vgg16.features.parameters():
    conv_param.requires_grad = False

  # Take the head's input width from the pretrained classifier's first Linear
  # layer instead of hard-coding it, so it always matches the feature output.
  in_features = vgg16.classifier[0].in_features

  # Replace w/ new classification layers
  vgg16.classifier = nn.Sequential(
    nn.Linear(in_features, 1024),
    nn.ReLU(inplace=True),
    nn.Dropout(p=0.5),
    nn.Linear(1024, num_classes)
  )

  return vgg16

In [None]:
# Instantiate the fine-tuning model defined above.
model = getVGGModel()
    
# Move the model to DEVICE; as the cell's last expression, the returned module
# is also displayed by the notebook.
model.to(DEVICE)

# 4. Define hyperparameters   

In [None]:
# Hyperparameters used by the cells below:
#   lr           - learning rate passed to the Adam optimizer
#   beta1, beta2 - Adam's `betas` pair
#   batch_size   - batch size passed to construct_dataloaders
#   epochs       - number of epochs passed to train()
hp = {"lr":1e-5, "beta1":0.9, "beta2":0.999, "batch_size":16, "epochs":5}

# 5. Load dataset

In [None]:
def load_datasets(train_path, val_path, test_path):
  """Open the train / validation / (optional) test image folders.

  Each directory is read with `datasets.ImageFolder`, and every image is
  resized and converted to a tensor.

  Args:
    train_path: Root directory of the training images.
    val_path: Root directory of the validation images.
    test_path: Root directory of the test images, or None for no test set.

  Returns:
    A `(train_dataset, val_dataset, test_dataset)` tuple; `test_dataset` is
    None when `test_path` is None.
  """
  # NOTE(review): 244x244 may be a typo for 224x224 (the usual VGG input
  # size) — confirm before changing, since it alters the training data.
  preprocessing = transforms.Compose([
    transforms.Resize((244,244)),
    transforms.ToTensor(),
  ])

  def _open_folder(root):
    # One ImageFolder per split, all sharing the same preprocessing.
    return datasets.ImageFolder(root, transform=preprocessing)

  train_dataset = _open_folder(train_path)
  val_dataset = _open_folder(val_path)
  test_dataset = None if test_path is None else _open_folder(test_path)

  print(f"Train set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")
  return train_dataset, val_dataset, test_dataset

def construct_dataloaders(train_set, val_set, test_set, batch_size, shuffle=True):
  """Wrap the datasets in `torch.utils.data.DataLoader` objects.

  Args:
    train_set: Training dataset.
    val_set: Validation dataset.
    test_set: Test dataset, or None when there is no test split.
    batch_size: Number of samples per batch for every loader.
    shuffle: Whether the *training* loader reshuffles each epoch. The
      validation and test loaders are never shuffled.

  Returns:
    A `(train_dataloader, val_dataloader, test_dataloader)` tuple;
    `test_dataloader` is None when `test_set` is None.
  """
  # Keyword arguments make it explicit which DataLoader options are being set.
  train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=shuffle)
  val_dataloader = torch.utils.data.DataLoader(val_set, batch_size=batch_size)
  # Decide from the function's own argument, not from a notebook-level global.
  test_dataloader = (
    torch.utils.data.DataLoader(test_set, batch_size=batch_size)
    if test_set is not None else None
  )
  return train_dataloader, val_dataloader, test_dataloader

In [None]:
# Set the paths to the training, validation, and test image folders below.
# Leave test_path as None when there is no test set.
train_path, val_path, test_path = "/tmp/Dataset_2/Train/", "/tmp/Dataset_2/Validation/", None
# Open the image folders as datasets (prints the train / validation sizes).
train_set, val_set, test_set = load_datasets(train_path, val_path, test_path)
# Wrap the datasets in batched loaders; the final True is the `shuffle` argument.
train_dataloader, val_dataloader, test_dataloader = construct_dataloaders(train_set, val_set, test_set, hp["batch_size"], True)

# 6. Define optimizer

In [None]:
# Adam optimizer over all model parameters, configured from the `hp` dict.
opt = torch.optim.Adam(model.parameters(),lr=hp["lr"], betas=(hp["beta1"], hp["beta2"]))

# 7. Define loss function

In [None]:
# Cross-entropy loss with default settings; it is applied to the raw model
# outputs and integer class labels in both train() and eval_model().
loss_fn = nn.CrossEntropyLoss()

# 8. Train model

## 8.1 Define evaluation function

In [None]:
@torch.no_grad()
def eval_model(data_loader, model, loss_fn, DEVICE):
  """Compute the mean loss and accuracy of `model` over `data_loader`.

  Both metrics are averaged over *samples*, so a smaller final batch does not
  get more weight than the others. The model is switched to eval mode and left
  that way; callers that continue training must switch it back.

  Args:
    data_loader: Iterable yielding `(inputs, labels)` batches.
    model: Module mapping inputs to per-class scores.
    loss_fn: Loss callable taking `(pred, labels)`. Assumed to return the
      mean over the batch (the default for `nn.CrossEntropyLoss`).
    DEVICE: Device the batches are moved to before the forward pass.

  Returns:
    `(loss, accuracy)` as 0-dim tensors; accuracy is a fraction in [0, 1].

  Raises:
    ValueError: If `data_loader` yields no samples.
  """
  model.eval()
  # Accumulate on the device so no host sync is needed per batch.
  total_loss = torch.zeros((), device=DEVICE)
  total_correct = torch.zeros((), device=DEVICE)
  total_samples = 0

  for x, y in data_loader:
    x, y = x.to(DEVICE), y.to(DEVICE)
    pred = model(x)
    batch_size = len(x)
    # loss_fn already averages over the batch; weight by batch size so the
    # final division yields a true per-sample mean.
    total_loss += loss_fn(pred, y) * batch_size
    total_correct += torch.sum(torch.argmax(pred, dim=1) == y)
    total_samples += batch_size

  if total_samples == 0:
    raise ValueError("eval_model received an empty data_loader")

  return total_loss / total_samples, total_correct / total_samples

## 8.2 Define train function

In [None]:
def train(train_loader, val_loader, model, opt, loss_fn, epochs, DEVICE):
  """Train `model` for `epochs` epochs, validating after each one.

  For every epoch this runs one pass over `train_loader` with a tqdm progress
  bar showing the running per-sample loss and accuracy, prints the elapsed
  time, then evaluates on `val_loader` with `eval_model` and prints the result.

  Args:
    train_loader: Loader yielding `(inputs, labels)` training batches.
    val_loader: Loader passed to `eval_model` after each epoch.
    model: Module to optimize (updated in place).
    opt: Optimizer over the model's parameters.
    loss_fn: Loss callable taking `(pred, labels)`. Assumed to return the
      mean over the batch (the default for `nn.CrossEntropyLoss`).
    epochs: Number of passes over `train_loader`.
    DEVICE: Device the batches are moved to.

  Returns:
    None. Progress and metrics are only printed.
  """
  for epoch in range(epochs):
    # eval_model leaves the model in eval mode, so re-enable train mode here.
    model.train(True)
    print(f"Epoch {epoch+1}/{epochs}:")

    start_time = datetime.now()

    # Plain Python numbers: accumulating the loss tensor itself would keep it
    # attached to the autograd graph for the whole epoch.
    running_loss = 0.0
    running_correct = 0
    seen = 0

    with tqdm(
        total=len(train_loader),
        bar_format='{l_bar}{bar:10}{r_bar}',
        desc=f'Epoch {epoch+1:3d}/{epochs:3d}',  # 1-based, like the header above
        disable=False
    ) as t:
      for x, y in train_loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        pred = model(x)
        loss = loss_fn(pred, y)

        opt.zero_grad()
        loss.backward()
        opt.step()

        batch_size = len(x)
        # The batch loss is a mean; weight it by batch size so that dividing
        # by `seen` gives the per-sample average.
        running_loss += loss.item() * batch_size
        running_correct += torch.sum(torch.argmax(pred, dim=1) == y).item()
        seen += batch_size

        t.set_postfix_str(
          'loss: {:.4f}, acc: {:.2f}%'.format(
            running_loss/seen,
            100*running_correct/seen,
          ),
        )
        t.update(1)

    # total_seconds() covers the whole duration; .seconds drops whole days.
    elapsed = datetime.now() - start_time
    print(f"Time: {int(elapsed.total_seconds())}s")

    val_loss, val_acc = eval_model(val_loader, model, loss_fn, DEVICE)
    # eval_model reports accuracy as a fraction; scale it to match the "%".
    print(f"Val loss: {float(val_loss):.4f}, Val accuracy: {100*float(val_acc):.2f}%\n")

## 8.3 Start training

In [None]:
# Run the training loop for hp["epochs"] epochs, validating after each epoch.
train(train_dataloader, val_dataloader, model, opt, loss_fn, hp["epochs"], DEVICE)