In [1]:
import copy

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

from model import InceptionResnetV2

In [2]:
# Run on the GPU when CUDA reports itself as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# Instantiate the network with InceptionResnetV2's default constructor
# arguments (the class is imported from the local `model` module).
model = InceptionResnetV2()

In [4]:
# Place the model on the selected device. The call is the cell's last
# expression, so its return value is echoed below as the module's repr.
model.to(device)

InceptionResnetV2(
  (model): Sequential(
    (0): Stem(
      (branch1): Sequential(
        (0): BasicConv2d(
          (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2))
          (batch_norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU()
        )
        (1): BasicConv2d(
          (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
          (batch_norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU()
        )
        (2): BasicConv2d(
          (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (batch_norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU()
        )
      )
      (branch2_a): MaxPool2d(kernel_size=(3, 3), stride=2, padding=0, dilation=1, ceil_mode=False)
      (branch2_b): BasicConv2d(
        (conv): Conv2d(64, 96, kernel_size=(3, 3), stride=(2,

In [5]:
# (height, width) every image is resized to before it reaches the network.
IMAGE_SIZE = (299, 299)

# Short local aliases to keep the pipeline definitions readable.
transforms = torchvision.transforms
ImageFolder = torchvision.datasets.ImageFolder

# Training pipeline: resize, random flip / perspective / rotation
# augmentations, then conversion to a tensor.
train_steps = [
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomPerspective(distortion_scale=0.2),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
]
train_transforms = transforms.Compose(train_steps)

# Validation pipeline: no augmentation, only resize and tensor conversion.
val_steps = [
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
]
val_transforms = transforms.Compose(val_steps)

# Both splits are read from folder trees under ./data.
train_dataset = ImageFolder(root='./data/train', transform=train_transforms)
val_dataset = ImageFolder(root='./data/val', transform=val_transforms)

In [6]:
def accuracy(outputs, labels):
    """Return the fraction of rows in ``outputs`` whose arg-max equals the label.

    Args:
        outputs: Tensor of per-class scores; the predicted class of each row
            is the index of its largest entry along dim 1.
        labels: Tensor of target class indices, compared element-wise with
            the predicted classes.

    Returns:
        A 0-dim CPU float tensor holding ``correct / total``. An empty batch
        yields ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    # log-softmax is a monotonic transform of each row, so the arg-max of the
    # raw scores already identifies the predicted class; no normalisation needed.
    preds = outputs.argmax(dim=1)

    total = len(preds)
    if total == 0:
        return torch.tensor(0.0)

    correct = (preds == labels).sum().item()
    return torch.tensor(correct / total)

In [7]:
# Cross-entropy loss, optimised with Adam over every model parameter using
# the optimiser's default hyper-parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters())

In [8]:
def evaluate(model, batch, criterion, optimizer=None):
    """Run a forward pass on one batch and return its loss and accuracy.

    This function only computes metrics: it does not zero gradients, call
    ``backward`` or step an optimizer. Callers decide on gradient handling
    (e.g. wrapping the call in ``torch.no_grad()`` for validation).

    Args:
        model: Callable applied to the inputs to produce predictions.
        batch: Pair ``(inputs, labels)``; both are moved to the notebook-level
            ``device`` before the forward pass.
        criterion: Callable invoked as ``criterion(prediction, labels)``.
        optimizer: Unused. Accepted (and optional) only so existing call
            sites that pass an optimizer keep working.

    Returns:
        Tuple ``(loss, acc)``: the value returned by ``criterion`` and the
        value returned by ``accuracy`` for this batch.
    """
    inputs, labels = batch
    inputs = inputs.to(device)
    labels = labels.to(device)

    prediction = model(inputs)
    loss = criterion(prediction, labels)
    acc = accuracy(prediction, labels)

    return loss, acc

In [9]:
def train(model, batch, criterion, optimizer):
    """Perform one optimisation step on a single batch.

    Gradients are cleared, the batch is scored through ``evaluate``, the loss
    is back-propagated and the optimizer takes one step.

    Args:
        model: Model passed through to ``evaluate``.
        batch: Batch passed through to ``evaluate``.
        criterion: Loss callable passed through to ``evaluate``.
        optimizer: Optimizer whose gradients are zeroed and which is stepped.

    Returns:
        Tuple ``(loss, acc)`` as produced by ``evaluate`` for this batch.
    """
    # Start from clean gradients so this step is not mixed with the last one.
    optimizer.zero_grad()

    batch_loss, batch_acc = evaluate(model, batch, criterion, optimizer)

    batch_loss.backward()
    optimizer.step()

    return batch_loss, batch_acc

In [10]:
# Mini-batch size shared by both loaders.
BATCH_SIZE = 32

# Training data is shuffled; validation data is not.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [11]:
# Training with early stopping on the validation loss.
#
# max_epochs: upper bound on the number of epochs.
# patience:   number of consecutive epochs without a new best validation loss
#             after which training stops.
# min_loss / cur_patience: best validation loss so far and the number of
#             epochs elapsed since it was reached.
max_epochs = 100
min_loss = np.inf
cur_patience = 0
patience = 10

for epoch in range(1, max_epochs + 1):
    # Running sums are weighted by the number of samples in each batch so
    # that a smaller final batch does not skew the epoch averages.
    train_loss = 0.0
    train_acc = 0.0
    train_seen = 0
    val_loss = 0.0
    val_acc = 0.0
    val_seen = 0

    model.train()
    for batch in tqdm(train_loader, total=len(train_loader), leave=False):
        batch_loss, batch_acc = train(model, batch, criterion, optimizer)
        n_samples = len(batch[1])
        train_loss += batch_loss.item() * n_samples
        train_acc += float(batch_acc) * n_samples
        train_seen += n_samples

    model.eval()
    with torch.no_grad():
        for batch in tqdm(val_loader, total=len(val_loader), leave=False):
            batch_loss, batch_acc = evaluate(model, batch, criterion, optimizer)
            n_samples = len(batch[1])
            val_loss += batch_loss.item() * n_samples
            val_acc += float(batch_acc) * n_samples
            val_seen += n_samples

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    train_loss /= max(train_seen, 1)
    train_acc /= max(train_seen, 1)
    val_loss /= max(val_seen, 1)
    val_acc /= max(val_seen, 1)

    print(f'Epoch: {epoch}')
    print(f'Train Loss: {train_loss:.4}, Train Acc: {train_acc:.3}')
    print(f'Validation Loss: {val_loss:.4}, Validation Acc: {val_acc:.3}')
    print('-' * 100)

    if val_loss < min_loss:
        cur_patience = 0
        min_loss = val_loss
        # state_dict() hands back references to the live parameter tensors,
        # which later optimizer steps overwrite in place. Deep-copy it so
        # `best_model` really is a frozen snapshot of the best epoch.
        best_model = copy.deepcopy(model.state_dict())
    else:
        cur_patience += 1
        # `>=` rather than `==` so the loop still stops if the counter ever
        # overshoots the threshold.
        if cur_patience >= patience:
            # Counter is reset before leaving the loop, as in the original cell.
            cur_patience = 0
            break

HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 1
Train Loss: 3.874, Train Acc: 0.397
Validation Loss: 1.671, Validation Acc: 0.567
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 2
Train Loss: 2.842, Train Acc: 0.497
Validation Loss: 2.243, Validation Acc: 0.471
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 3
Train Loss: 2.282, Train Acc: 0.508
Validation Loss: 9.025, Validation Acc: 0.479
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 4
Train Loss: 2.237, Train Acc: 0.504
Validation Loss: 1.055, Validation Acc: 0.663
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 5
Train Loss: 1.542, Train Acc: 0.617
Validation Loss: 12.87, Validation Acc: 0.561
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 6
Train Loss: 1.913, Train Acc: 0.637
Validation Loss: 0.6187, Validation Acc: 0.798
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 7
Train Loss: 1.576, Train Acc: 0.667
Validation Loss: 1.441, Validation Acc: 0.668
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 8
Train Loss: 1.101, Train Acc: 0.709
Validation Loss: 6.714, Validation Acc: 0.683
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 9
Train Loss: 1.243, Train Acc: 0.694
Validation Loss: 0.6886, Validation Acc: 0.795
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 10
Train Loss: 0.9344, Train Acc: 0.743
Validation Loss: 0.5578, Validation Acc: 0.817
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 11
Train Loss: 0.742, Train Acc: 0.774
Validation Loss: 0.3691, Validation Acc: 0.861
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 12
Train Loss: 0.7118, Train Acc: 0.78
Validation Loss: 0.3869, Validation Acc: 0.858
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 13
Train Loss: 0.614, Train Acc: 0.801
Validation Loss: 0.3942, Validation Acc: 0.844
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 14
Train Loss: 0.4945, Train Acc: 0.829
Validation Loss: 0.4489, Validation Acc: 0.844
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 15
Train Loss: 0.5101, Train Acc: 0.833
Validation Loss: 0.2949, Validation Acc: 0.874
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 16
Train Loss: 0.4564, Train Acc: 0.847
Validation Loss: 0.2745, Validation Acc: 0.898
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 17
Train Loss: 0.4705, Train Acc: 0.845
Validation Loss: 0.9645, Validation Acc: 0.807
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 18
Train Loss: 1.403, Train Acc: 0.687
Validation Loss: 0.4667, Validation Acc: 0.828
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 19
Train Loss: 0.8252, Train Acc: 0.782
Validation Loss: 0.506, Validation Acc: 0.836
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 20
Train Loss: 0.666, Train Acc: 0.814
Validation Loss: 0.2609, Validation Acc: 0.892
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 21
Train Loss: 0.5084, Train Acc: 0.84
Validation Loss: 0.2659, Validation Acc: 0.902
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 22
Train Loss: 0.4646, Train Acc: 0.847
Validation Loss: 0.3877, Validation Acc: 0.881
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 23
Train Loss: 0.4917, Train Acc: 0.856
Validation Loss: 0.3051, Validation Acc: 0.892
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 24
Train Loss: 0.3624, Train Acc: 0.875
Validation Loss: 0.2204, Validation Acc: 0.916
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 25
Train Loss: 0.3555, Train Acc: 0.881
Validation Loss: 0.2416, Validation Acc: 0.912
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 26
Train Loss: 0.3609, Train Acc: 0.891
Validation Loss: 9.0, Validation Acc: 0.736
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 27
Train Loss: 0.437, Train Acc: 0.869
Validation Loss: 0.2814, Validation Acc: 0.885
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 28
Train Loss: 0.3893, Train Acc: 0.888
Validation Loss: 0.2832, Validation Acc: 0.917
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 29
Train Loss: 0.293, Train Acc: 0.895
Validation Loss: 0.2473, Validation Acc: 0.924
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 30
Train Loss: 0.3263, Train Acc: 0.892
Validation Loss: 0.3006, Validation Acc: 0.923
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 31
Train Loss: 0.529, Train Acc: 0.862
Validation Loss: 0.4219, Validation Acc: 0.873
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 32
Train Loss: 0.4424, Train Acc: 0.872
Validation Loss: 0.2243, Validation Acc: 0.929
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 33
Train Loss: 0.3186, Train Acc: 0.895
Validation Loss: 0.509, Validation Acc: 0.844
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

Epoch: 34
Train Loss: 0.8505, Train Acc: 0.809
Validation Loss: 0.9736, Validation Acc: 0.798
----------------------------------------------------------------------------------------------------


In [12]:
# Write the `best_model` state dict recorded by the training loop to disk.
torch.save(obj=best_model, f='model.pt')