## Try 2: fine-tuning VGG16-BN on the 5-class shot-type dataset

In [1]:
import os
import sys
sys.path.append(os.getcwd())

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data.dataloader import DataLoader


In [2]:
# Dataset root; holds ImageFolder-style `train` and `val` subdirectories.
PATH = "../finetune_data/output/"
train_dir = os.path.join(PATH, "train")
val_dir = os.path.join(PATH, "val")

# Training hyper-parameters.
num_classes = 5
batch_size = 16
num_epochs = 10

# Feature-extraction flag: intended to mean only the classifier head is trained.
feature_extract = True

In [3]:
from tqdm import tqdm
import os

def train(model, iterator, optimizer, criterion, accuracy, device):
    
  epoch_loss = 0
  epoch_cor = 0.0
  total_samples = 0

  model.train()

  pbar = tqdm(total=len(iterator), dynamic_ncols=True)
  
  for (x, y) in iterator:
    x = x.to(device)
    y = y.to(device)
    
    optimizer.zero_grad()
            
    y_pred = model(x)
    
    loss = criterion(y_pred, y)
    
    cor, n_samples = accuracy(y_pred, y)
    
    loss.backward()
    
    optimizer.step()
    
    epoch_loss += loss.item()
    epoch_cor += cor
    total_samples += n_samples

    pbar.update(1) 
  epoch_loss /= len(iterator)
  epoch_acc = epoch_cor / total_samples
  pbar.close()
      
  return epoch_loss, epoch_acc

In [4]:
def evaluate(model, iterator, criterion, accuracy, device):
    
  epoch_loss = 0
  epoch_cor = 0.0
  total_samples = 0
  
  model.eval()

  pbar = tqdm(total=len(iterator), dynamic_ncols=True)
  
  with torch.no_grad():
    for (x, y) in iterator:

      x = x.to(device)
      y = y.to(device)

      y_pred = model(x)

      loss = criterion(y_pred, y)

      correct, n_samples = accuracy(y_pred, y)

      epoch_loss += loss.item()
      epoch_cor += correct
      total_samples += n_samples

      pbar.update(1)
      
  epoch_loss /= len(iterator)
  epoch_acc = epoch_cor/total_samples

  pbar.close()
      
  return epoch_loss, epoch_acc

In [5]:
# Kept at module level in case other cells reference it; calc_accuracy does
# not need it.
softmax = nn.Softmax(dim=-1)
def calc_accuracy(pred, target):
    """Count correct top-1 predictions in a batch.

    Args:
        pred: model outputs (logits), with class scores along the last dim.
        target: integer class labels, one per row of ``pred``.

    Returns:
        ``(correct, n)``:
        - ``correct`` is a 0-d tensor counting rows whose argmax equals ``target``.
        - ``n`` is ``len(target)``.
    """
    # Softmax is monotonic, so taking the argmax of the raw logits yields the
    # same predicted class without computing the softmax first.
    pred_labels = torch.argmax(pred, dim=-1)

    correct = torch.sum(pred_labels == target)

    return correct, len(target)

def _grayscale_eval_pipeline():
    """Build the shared preprocessing pipeline used for every split.

    Steps:
    - Convert the image to 3-channel grayscale.
    - Resize so the shorter side is 224, then center-crop to 224x224.
    - Convert to a tensor.
    - Normalize with the per-channel means/stds listed below.
    """
    return transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize(224),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])


# Train and val currently use identical preprocessing (no augmentation);
# each split gets its own Compose instance.
data_transforms = {split: _grayscale_eval_pipeline() for split in ('train', 'val')}

In [6]:
from torch.optim import Adam
from torchvision.models import vgg16, vgg16_bn

model_name = "shottypes_vgg16_bn"
SEED = 42

# Prefer MPS, then CUDA, then CPU, so the cell also runs on machines
# without Apple-silicon GPU support.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# NOTE(review): EPOCHS=12 and the loader batch size of 32 below override
# num_epochs=10 / batch_size=16 from the config cell. Kept to match the
# logged run; consolidate once the intended values are confirmed.
EPOCHS = 12
n_class = num_classes

# Seed the global torch RNG so head initialisation and shuffling are more
# repeatable across runs.
torch.manual_seed(SEED)

model = vgg16_bn(pretrained=True)

# In feature-extraction mode, freeze the backbone (`model.features`).
# The optimizer below only receives classifier parameters, so this does not
# change which weights are updated; it just skips computing unused gradients.
# BatchNorm layers in `features` still update running stats in train mode.
if feature_extract:
    for param in model.features.parameters():
        param.requires_grad = False

# NOTE(review): classifier[0] and [3] are replaced by freshly initialised
# layers (presumably the same shapes as the pretrained ones — confirm).
# This discards their pretrained weights; check that the re-initialisation
# is intended.
model.classifier[0] = torch.nn.Linear(7 * 7 * 512, 4096)
model.classifier[3] = torch.nn.Linear(4096, 4096)
model.classifier[6] = torch.nn.Linear(4096, n_class)

model = model.to(device)

train_set = datasets.ImageFolder(train_dir, data_transforms['train'])
val_set = datasets.ImageFolder(val_dir, data_transforms['val'])

train_loader = DataLoader(train_set, shuffle=True, batch_size=32, num_workers=8)
val_loader = DataLoader(val_set, shuffle=False, batch_size=32, num_workers=8)

# Only the classifier head is optimised.
optimizer = Adam([
    {'params': model.classifier.parameters(), 'lr': 1e-4}
])

criterion = nn.CrossEntropyLoss()

filename = model_name + '.pt'
best_valid_loss = float('+inf')
best_epoch = 0
best_valid_acc = 0

for epoch in range(EPOCHS):

    print("EPOCH:", epoch+1)

    train_loss, train_acc_1 = train(model, train_loader, optimizer, criterion, calc_accuracy, device)

    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\tTrain Acc @1: {train_acc_1*100:6.2f}%')

    valid_loss, valid_acc_1 = evaluate(model, val_loader, criterion, calc_accuracy, device)

    print(f'\tValid Loss: {valid_loss:.3f}')
    print(f'\tValid Acc @1: {valid_acc_1*100:6.2f}%')

    # Checkpoint whenever validation loss improves.
    if best_valid_loss > valid_loss:
        best_valid_loss = valid_loss
        best_valid_acc = valid_acc_1
        best_epoch = epoch + 1
        # Save to a temp file, then atomically swap it in, so an interrupted
        # save never leaves the run without its previous best checkpoint.
        tmp_filename = filename + '.tmp'
        torch.save(model.state_dict(), tmp_filename)
        os.replace(tmp_filename, filename)
        print(f"\tEpoch {best_epoch} saved")

    print(f'\tBest Valid Loss @1: {best_valid_loss:.3f} in Epoch {best_epoch}')



EPOCH: 1


100%|██████████| 12/12 [01:00<00:00,  5.03s/it]


	Train Loss: 1.387
	Train Acc @1:  35.99%


100%|██████████| 2/2 [00:13<00:00,  6.66s/it]


	Valid Loss: 1.223
	Valid Acc @1:  61.36%
	Epoch 1 saved
	Best Valid Loss @1: 1.223 in Epoch 1
EPOCH: 2


100%|██████████| 12/12 [00:55<00:00,  4.64s/it]


	Train Loss: 0.897
	Train Acc @1:  73.35%


100%|██████████| 2/2 [00:13<00:00,  6.98s/it]


	Valid Loss: 1.009
	Valid Acc @1:  65.91%
	Epoch 2 saved
	Best Valid Loss @1: 1.009 in Epoch 2
EPOCH: 3


100%|██████████| 12/12 [00:55<00:00,  4.66s/it]


	Train Loss: 0.478
	Train Acc @1:  85.99%


100%|██████████| 2/2 [00:13<00:00,  6.87s/it]


	Valid Loss: 1.014
	Valid Acc @1:  56.82%
	Best Valid Loss @1: 1.009 in Epoch 2
EPOCH: 4


100%|██████████| 12/12 [00:55<00:00,  4.64s/it]


	Train Loss: 0.206
	Train Acc @1:  95.05%


100%|██████████| 2/2 [00:13<00:00,  6.98s/it]


	Valid Loss: 0.947
	Valid Acc @1:  63.64%
	Epoch 4 saved
	Best Valid Loss @1: 0.947 in Epoch 4
EPOCH: 5


100%|██████████| 12/12 [00:56<00:00,  4.67s/it]


	Train Loss: 0.062
	Train Acc @1:  98.90%


100%|██████████| 2/2 [00:13<00:00,  6.66s/it]


	Valid Loss: 1.015
	Valid Acc @1:  63.64%
	Best Valid Loss @1: 0.947 in Epoch 4
EPOCH: 6


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


	Train Loss: 0.031
	Train Acc @1:  99.18%


100%|██████████| 2/2 [00:12<00:00,  6.37s/it]


	Valid Loss: 1.041
	Valid Acc @1:  68.18%
	Best Valid Loss @1: 0.947 in Epoch 4
EPOCH: 7


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


	Train Loss: 0.029
	Train Acc @1:  98.90%


100%|██████████| 2/2 [00:13<00:00,  6.60s/it]


	Valid Loss: 1.056
	Valid Acc @1:  63.64%
	Best Valid Loss @1: 0.947 in Epoch 4
EPOCH: 8


100%|██████████| 12/12 [00:55<00:00,  4.61s/it]


	Train Loss: 0.024
	Train Acc @1:  99.45%


100%|██████████| 2/2 [00:12<00:00,  6.47s/it]


	Valid Loss: 1.203
	Valid Acc @1:  61.36%
	Best Valid Loss @1: 0.947 in Epoch 4
EPOCH: 9


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


	Train Loss: 0.046
	Train Acc @1:  99.18%


100%|██████████| 2/2 [00:12<00:00,  6.37s/it]


	Valid Loss: 1.228
	Valid Acc @1:  65.91%
	Best Valid Loss @1: 0.947 in Epoch 4
EPOCH: 10


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


	Train Loss: 0.035
	Train Acc @1:  99.45%


100%|██████████| 2/2 [00:12<00:00,  6.40s/it]


	Valid Loss: 1.206
	Valid Acc @1:  59.09%
	Best Valid Loss @1: 0.947 in Epoch 4
EPOCH: 11


100%|██████████| 12/12 [00:55<00:00,  4.58s/it]


	Train Loss: 0.019
	Train Acc @1:  99.45%


100%|██████████| 2/2 [00:12<00:00,  6.37s/it]


	Valid Loss: 1.136
	Valid Acc @1:  65.91%
	Best Valid Loss @1: 0.947 in Epoch 4
EPOCH: 12


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


	Train Loss: 0.019
	Train Acc @1:  99.45%


100%|██████████| 2/2 [00:13<00:00,  6.93s/it]

	Valid Loss: 1.345
	Valid Acc @1:  70.45%
	Best Valid Loss @1: 0.947 in Epoch 4



