## Try 2: fine-tune the VGG16-BN classifier head for shot-type classification

In [6]:
import os
import sys
sys.path.append(os.getcwd())

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data.dataloader import DataLoader


In [22]:
# Dataset root; the training cell reads `train/` and `val/` below it with
# torchvision's ImageFolder.
PATH = "../finetune_data/output/"

# os.path.join works whether or not PATH ends with a path separator.
train_dir = os.path.join(PATH, "train")
val_dir = os.path.join(PATH, "val")

num_classes = 5
batch_size = 16
num_epochs = 10

# Feature-extraction flag (train the classifier head on a frozen backbone).
# NOTE(review): confirm the training cell actually reads this flag.
feature_extract = True

# Seed torch's global RNG so DataLoader shuffling and freshly initialised
# layers repeat across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

In [None]:
from tqdm import tqdm
import os

def train(model, iterator, optimizer, criterion, accuracy, device):
    """Run one training epoch over ``iterator``.

    Args:
        model: network to optimise; it is switched to training mode here.
        iterator: sized iterable of ``(x, y)`` batches (e.g. a DataLoader);
            ``len(iterator)`` must give the number of batches.
        optimizer: optimiser that is zeroed and stepped once per batch.
        criterion: loss function, called as ``criterion(y_pred, y)``.
        accuracy: callable returning ``(n_correct, n_samples)`` for a batch;
            ``n_correct`` may be a Python number or a 0-dim tensor.
        device: target passed to ``.to()`` for every batch.

    Returns:
        ``(epoch_loss, epoch_acc)`` as Python floats. ``epoch_loss`` is the
        unweighted mean of the per-batch losses. ``epoch_acc`` is the total
        number of correct samples divided by the total number of samples.

    Raises:
        ValueError: if ``iterator`` has no batches (the averages would
            otherwise divide by zero).
    """
    n_batches = len(iterator)
    if n_batches == 0:
        raise ValueError("train() received an empty iterator")

    epoch_loss = 0.0
    epoch_cor = 0.0
    total_samples = 0

    model.train()

    # The context manager closes the bar even if a batch raises, so no stale
    # progress bars are left behind in the notebook output.
    with tqdm(total=n_batches, dynamic_ncols=True) as pbar:
        for x, y in iterator:
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            y_pred = model(x)
            loss = criterion(y_pred, y)
            cor, n_samples = accuracy(y_pred, y)

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            # float() turns a (possibly on-device) count tensor into a Python
            # number, so the running total and the returned accuracy are plain
            # floats.
            epoch_cor += float(cor)
            total_samples += n_samples

            pbar.update(1)

    epoch_loss /= n_batches
    epoch_acc = epoch_cor / total_samples
    return epoch_loss, epoch_acc

In [None]:
def evaluate(model, iterator, criterion, accuracy, device):
    """Evaluate ``model`` on ``iterator`` without tracking gradients.

    Args:
        model: network to evaluate; it is switched to eval mode here.
        iterator: sized iterable of ``(x, y)`` batches (e.g. a DataLoader);
            ``len(iterator)`` must give the number of batches.
        criterion: loss function, called as ``criterion(y_pred, y)``.
        accuracy: callable returning ``(n_correct, n_samples)`` for a batch;
            ``n_correct`` may be a Python number or a 0-dim tensor.
        device: target passed to ``.to()`` for every batch.

    Returns:
        ``(epoch_loss, epoch_acc)`` as Python floats. ``epoch_loss`` is the
        unweighted mean of the per-batch losses. ``epoch_acc`` is the total
        number of correct samples divided by the total number of samples.

    Raises:
        ValueError: if ``iterator`` has no batches (the averages would
            otherwise divide by zero).
    """
    n_batches = len(iterator)
    if n_batches == 0:
        raise ValueError("evaluate() received an empty iterator")

    epoch_loss = 0.0
    epoch_cor = 0.0
    total_samples = 0

    model.eval()

    # The bar is context-managed so it is closed even if a batch raises.
    with torch.no_grad(), tqdm(total=n_batches, dynamic_ncols=True) as pbar:
        for x, y in iterator:
            x = x.to(device)
            y = y.to(device)

            y_pred = model(x)
            loss = criterion(y_pred, y)
            correct, n_samples = accuracy(y_pred, y)

            epoch_loss += loss.item()
            # Convert a (possibly on-device) count tensor to a Python number.
            epoch_cor += float(correct)
            total_samples += n_samples

            pbar.update(1)

    epoch_loss /= n_batches
    epoch_acc = epoch_cor / total_samples
    return epoch_loss, epoch_acc

In [25]:
# Kept at module level in case other cells use it. calc_accuracy no longer
# needs it, because softmax does not change which score is largest.
softmax = nn.Softmax(dim=-1)


def calc_accuracy(pred, target):
    """Count correct top-1 predictions in a batch.

    Args:
        pred: model outputs whose last dimension holds one score per class
            (raw logits are fine).
        target: integer class labels, one per row of ``pred``.

    Returns:
        ``(correct, n)``. ``correct`` is a 0-dim tensor counting the rows
        whose highest-scoring class equals the label. ``n`` is
        ``len(target)``.
    """
    # Argmax the raw scores directly. Softmax is monotonic, so it cannot
    # change the winner, but in floating point it can round nearly-equal top
    # logits to an exact tie, which argmax then breaks by lowest index.
    pred_labels = torch.argmax(pred, dim=-1)
    correct = torch.sum(pred_labels == target)
    return correct, len(target)

def _preprocess_pipeline():
    """Build the deterministic preprocessing used for every split.

    Images are converted to grayscale and replicated to 3 channels, so the
    3-channel pretrained network accepts them. The shorter side is resized to
    224 and then center-cropped to 224x224. Finally the image is converted to
    a tensor and normalised with the ImageNet mean/std that torchvision's
    pretrained models expect.
    """
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)
    steps = [
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize(224),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]
    return transforms.Compose(steps)


# Train and val currently share the same pipeline (no augmentation); each
# split gets its own Compose instance.
data_transforms = {split: _preprocess_pipeline() for split in ('train', 'val')}

In [None]:
from torch.optim import Adam
from torchvision.models import vgg16, vgg16_bn, VGG16_BN_Weights

model_name = "shottypes_vgg16_bn"

# Prefer Apple-silicon MPS (the original target), then CUDA, then CPU, so the
# cell also runs on machines without an MPS backend.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# NOTE(review): the config cell sets num_epochs = 10 and batch_size = 16, but
# this run uses 12 epochs and batch size 32. Values kept as-is; reconcile.
EPOCHS = 12
loader_batch_size = 32
n_class = num_classes

# IMAGENET1K_V1 is the checkpoint that the deprecated `pretrained=True` loaded.
model = vgg16_bn(weights=VGG16_BN_Weights.IMAGENET1K_V1)

if feature_extract:
    # Freeze the convolutional backbone so backward() neither computes nor
    # stores its gradients.
    # NOTE(review): train() calls model.train(), so the backbone's BatchNorm
    # running statistics are still updated from this data.
    for param in model.features.parameters():
        param.requires_grad = False

# NOTE(review): classifier[0] and classifier[3] already have these shapes in
# VGG16-BN. Replacing them discards their ImageNet weights and
# re-initialises them. Kept to preserve this run's behaviour; confirm intended.
model.classifier[0] = torch.nn.Linear(7 * 7 * 512, 4096)
model.classifier[3] = torch.nn.Linear(4096, 4096)
model.classifier[6] = torch.nn.Linear(4096, n_class)

model = model.to(device)

train_set = datasets.ImageFolder(train_dir, data_transforms['train'])
val_set = datasets.ImageFolder(val_dir, data_transforms['val'])

# ImageFolder maps sorted folder names to label indices. Fail fast if the two
# splits disagree on that mapping or it does not match the classifier head.
assert train_set.classes == val_set.classes, (train_set.classes, val_set.classes)
assert len(train_set.classes) == n_class, train_set.classes

train_loader = DataLoader(train_set, shuffle=True, batch_size=loader_batch_size, num_workers=8)
val_loader = DataLoader(val_set, shuffle=False, batch_size=loader_batch_size, num_workers=8)

# Feature extraction trains only the head; otherwise fine-tune every layer.
trainable_params = model.classifier.parameters() if feature_extract else model.parameters()
optimizer = Adam(trainable_params, lr=1e-4)

criterion = nn.CrossEntropyLoss()

filename = model_name + '.pt'
best_valid_loss = float('+inf')
best_epoch = 0
best_valid_acc = 0

for epoch in range(EPOCHS):

    print("EPOCH:", epoch+1)

    train_loss, train_acc_1 = train(model, train_loader, optimizer, criterion, calc_accuracy, device)

    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\tTrain Acc @1: {train_acc_1*100:6.2f}%')

    valid_loss, valid_acc_1 = evaluate(model, val_loader, criterion, calc_accuracy, device)

    print(f'\tValid Loss: {valid_loss:.3f}')
    print(f'\tValid Acc @1: {valid_acc_1*100:6.2f}%')

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        best_valid_acc = valid_acc_1
        best_epoch = epoch + 1
        # Write to a temp file, then atomically swap it into place. An
        # interrupted save then never destroys the previous best checkpoint.
        tmp_filename = filename + '.tmp'
        torch.save(model.state_dict(), tmp_filename)
        os.replace(tmp_filename, filename)
        print(f"\tEpoch {best_epoch} saved")

    print(f'\tBest Valid Loss @1: {best_valid_loss:.3f} in Epoch {best_epoch}')

# `model` now holds the last epoch's weights; the best ones are in `filename`.
print(f'Best Valid Acc @1: {best_valid_acc*100:6.2f}% (epoch {best_epoch})')