In [11]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score

from tqdm import tqdm
import sys
sys.path.insert(1, '../')
from early_stopping_cls import EarlyStopping

device = torch.device("mps")

In [2]:
PATH = '/Users/rishinigam/t81_588_course/datasets/iris-image'

In [7]:
# transformation and data augmentation for training and validation
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(360),
    transforms.RandomResizedCrop(256, scale=(0.5, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

train_dataset = ImageFolder(root=PATH, transform=train_transforms)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)

val_dataset = ImageFolder(root=PATH, transform=val_transforms)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [8]:
# modeling
class_cnt = len(train_dataset.classes)
model = nn.Sequential(
    # Features
    nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),

    # classifier
    nn.Flatten(),
    nn.Dropout(0.5),
    nn.Linear(in_features=64*6*6, out_features=512),
    nn.ReLU(inplace=True),
    nn.Linear(in_features=512, out_features=class_cnt)).to(device)

optimizer = optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()

In [9]:
def validation(model, loader):
    loss_num = loss_denom = 0.0
    model.eval()
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs.to(device))
            loss_num += loss_fn(outputs, labels.to(device))
            loss_denom += 1

    return loss_num/loss_denom

In [12]:
BATCH_SIZE = 16
es = EarlyStopping()

epoch = 0
done = False
while epoch < 1000 and not done:
    epoch += 1
    steps = list(enumerate(train_dataloader))
    pbar = tqdm(steps)
    model.train()
    for i, (x_batch, y_batch) in pbar:
        y_batch_pred = model(x_batch.to(device))
        loss = loss_fn(y_batch_pred, y_batch.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss, current = loss.item(), (i+1)* len(x_batch)
        if i == len(steps) - 1:
            model.eval()
            vloss = validation(model, val_dataloader)
            if es(model, vloss):
                done = True
            pbar.set_description(
                f"Epoch: {epoch}, tloss: {loss}, vloss: {vloss:>7f}, {es.status}"
            )
        else:
            pbar.set_description(
                f"Epoch: {epoch}, tloss: {loss:}"
            )

Epoch: 1, tloss: 0.5449950098991394, vloss: 0.780232, : 100%|██████████| 14/14 [00:02<00:00,  5.01it/s]
Epoch: 2, tloss: 0.9308778643608093, vloss: 0.759767, Improvement found, counter reset to 0: 100%|██████████| 14/14 [00:02<00:00,  5.91it/s]
Epoch: 3, tloss: 1.3093557357788086, vloss: 0.820800, No improvement in the last 1 epochs: 100%|██████████| 14/14 [00:02<00:00,  5.57it/s]
Epoch: 4, tloss: 0.30158838629722595, vloss: 0.774938, No improvement in the last 2 epochs: 100%|██████████| 14/14 [00:02<00:00,  5.26it/s]
Epoch: 5, tloss: 0.35183247923851013, vloss: 0.735944, Improvement found, counter reset to 0: 100%|██████████| 14/14 [00:02<00:00,  5.64it/s]
Epoch: 6, tloss: 0.8285929560661316, vloss: 0.726671, Improvement found, counter reset to 0: 100%|██████████| 14/14 [00:02<00:00,  5.48it/s]
Epoch: 7, tloss: 1.3088037967681885, vloss: 0.760502, No improvement in the last 1 epochs: 100%|██████████| 14/14 [00:02<00:00,  5.60it/s]
Epoch: 8, tloss: 0.6994555592536926, vloss: 0.751191, 

In [14]:
# validation
model.eval()
preds = []
targets = []
with torch.no_grad():
    for inputs, labels in val_dataloader:
        outputs = model(inputs.to(device))
        _, predictions = torch.max(outputs, 1)
        preds.extend(predictions.cpu().numpy())
        targets.extend(labels.cpu().numpy())

correct = accuracy_score(targets, preds)
print(f"Accuracy: {correct}")

Accuracy: 0.6935866983372921
