In [1]:
import os

from sklearn.preprocessing import LabelEncoder

from tqdm import tqdm
import timm

import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn

from src.model import NW_CNN
from src.dataset import CropsPytorchDataset
from src.utils import Averager

In [2]:
le = LabelEncoder()
le.fit(['human-hold-bicycle', 'human-ride-bicycle', 'human-ride-motorcycle', 'human-walk-bicycle', 'human-walk-motorcycle', 'human-hold-motorcycle'])
train_dataset = CropsPytorchDataset(img_dir='data/annotated_data/clips/', anno_file='data_anno/training_data.csv', label_encoder=le)
val_dataset = CropsPytorchDataset(img_dir='data/annotated_data/clips/', anno_file='data_anno/testing_data.csv', label_encoder=le)

train_dataloader = DataLoader(train_dataset, batch_size=40, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=40, shuffle=True)

In [3]:
out_dir = 'products/models'
if not os.path.isdir(out_dir):
    os.mkdir(out_dir)

In [4]:
model = timm.create_model('resnet18', num_classes = 6, in_chans=1)
# model = NW_CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)

device = torch.device('cuda')

In [5]:
#In case you want to continue training from checkpoint
model = torch.load('products/models/NW_CNN_checkpoint.pth')
model.train()
model.to(device)

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, m

In [6]:
train_losses = []
val_losses = []
train_loss_hist = Averager()
val_loss_hist = Averager()
best_valid_loss=float('inf')
max_epochs = 40

for epoch in range(20, max_epochs):  # loop over the dataset multiple times
    print(f"\nEPOCH {epoch} of {max_epochs}")
    train_loss_hist.reset()
    prog_bar = tqdm(train_dataloader, total=len(train_dataloader), leave=True)
    for inputs, labels in prog_bar:
        labels = torch.squeeze(labels, dim=1).long().to(device)
        inputs = inputs.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        train_loss_hist.send(loss)
        loss.backward()
        optimizer.step()

   
    print(f"Epoch #{epoch} train loss: {train_loss_hist.value:.3f}")
    train_losses.append(train_loss_hist.value)

    val_loss_hist.reset()  
    
    prog_bar = tqdm(val_dataloader, total=len(val_dataloader), leave=True)
    for inputs, labels in prog_bar:
        labels = torch.squeeze(labels, dim=1).long().to(device)
        inputs = inputs.to(device)
        outputs = model(inputs)
        with torch.no_grad():
            loss = criterion(outputs, labels)
        val_loss_hist.send(loss)      
     
    print(f"Epoch #{epoch} validation loss: {val_loss_hist.value:.3f}")
    val_losses.append(val_loss_hist.value)

    #Saving best model
    if val_loss_hist.value < best_valid_loss:
        best_valid_loss = val_loss_hist.value
        print(f"\nBest validation loss: {best_valid_loss}")
        print(f"\nSaving best model for epoch: {epoch}\n")
        torch.save(model, os.path.join('products/models', f'NW_CNN_checkpoint.pth'))

    
print('Finished Training')


EPOCH 20 of 40


100%|██████████| 192/192 [00:18<00:00, 10.67it/s]


Epoch #20 train loss: 0.612


100%|██████████| 88/88 [00:07<00:00, 11.65it/s]


Epoch #20 validation loss: 0.555

Best validation loss: 0.5548872351646423

Saving best model for epoch: 20


EPOCH 21 of 40


100%|██████████| 192/192 [00:18<00:00, 10.55it/s]


Epoch #21 train loss: 0.611


100%|██████████| 88/88 [00:07<00:00, 12.07it/s]


Epoch #21 validation loss: 0.554

Best validation loss: 0.5539177656173706

Saving best model for epoch: 21


EPOCH 22 of 40


100%|██████████| 192/192 [00:17<00:00, 10.75it/s]


Epoch #22 train loss: 0.611


100%|██████████| 88/88 [00:07<00:00, 12.11it/s]


Epoch #22 validation loss: 0.553

Best validation loss: 0.5528751015663147

Saving best model for epoch: 22


EPOCH 23 of 40


100%|██████████| 192/192 [00:17<00:00, 10.88it/s]


Epoch #23 train loss: 0.613


100%|██████████| 88/88 [00:07<00:00, 11.98it/s]


Epoch #23 validation loss: 0.550

Best validation loss: 0.5498838424682617

Saving best model for epoch: 23


EPOCH 24 of 40


100%|██████████| 192/192 [00:17<00:00, 10.85it/s]


Epoch #24 train loss: 0.614


100%|██████████| 88/88 [00:07<00:00, 12.14it/s]


Epoch #24 validation loss: 0.552

EPOCH 25 of 40


100%|██████████| 192/192 [00:17<00:00, 10.98it/s]


Epoch #25 train loss: 0.610


100%|██████████| 88/88 [00:07<00:00, 11.72it/s]


Epoch #25 validation loss: 0.553

EPOCH 26 of 40


100%|██████████| 192/192 [00:18<00:00, 10.59it/s]


Epoch #26 train loss: 0.613


100%|██████████| 88/88 [00:07<00:00, 12.32it/s]


Epoch #26 validation loss: 0.555

EPOCH 27 of 40


100%|██████████| 192/192 [00:17<00:00, 10.91it/s]


Epoch #27 train loss: 0.612


100%|██████████| 88/88 [00:07<00:00, 11.93it/s]


Epoch #27 validation loss: 0.549

Best validation loss: 0.5488559007644653

Saving best model for epoch: 27


EPOCH 28 of 40


100%|██████████| 192/192 [00:17<00:00, 11.06it/s]


Epoch #28 train loss: 0.613


100%|██████████| 88/88 [00:07<00:00, 11.91it/s]


Epoch #28 validation loss: 0.551

EPOCH 29 of 40


100%|██████████| 192/192 [00:17<00:00, 10.94it/s]


Epoch #29 train loss: 0.611


100%|██████████| 88/88 [00:07<00:00, 12.13it/s]


Epoch #29 validation loss: 0.550

EPOCH 30 of 40


100%|██████████| 192/192 [00:17<00:00, 11.19it/s]


Epoch #30 train loss: 0.611


100%|██████████| 88/88 [00:07<00:00, 12.03it/s]


Epoch #30 validation loss: 0.551

EPOCH 31 of 40


100%|██████████| 192/192 [00:17<00:00, 11.09it/s]


Epoch #31 train loss: 0.610


100%|██████████| 88/88 [00:07<00:00, 11.67it/s]


Epoch #31 validation loss: 0.550

EPOCH 32 of 40


100%|██████████| 192/192 [00:17<00:00, 10.70it/s]


Epoch #32 train loss: 0.612


100%|██████████| 88/88 [00:07<00:00, 11.61it/s]


Epoch #32 validation loss: 0.549

EPOCH 33 of 40


100%|██████████| 192/192 [00:18<00:00, 10.40it/s]


Epoch #33 train loss: 0.612


100%|██████████| 88/88 [00:07<00:00, 12.10it/s]


Epoch #33 validation loss: 0.548

Best validation loss: 0.5480824708938599

Saving best model for epoch: 33


EPOCH 34 of 40


100%|██████████| 192/192 [00:17<00:00, 10.76it/s]


Epoch #34 train loss: 0.610


100%|██████████| 88/88 [00:07<00:00, 11.61it/s]


Epoch #34 validation loss: 0.552

EPOCH 35 of 40


100%|██████████| 192/192 [00:17<00:00, 10.90it/s]


Epoch #35 train loss: 0.613


100%|██████████| 88/88 [00:08<00:00, 10.77it/s]


Epoch #35 validation loss: 0.548

EPOCH 36 of 40


100%|██████████| 192/192 [00:25<00:00,  7.59it/s]


Epoch #36 train loss: 0.609


100%|██████████| 88/88 [00:07<00:00, 11.73it/s]


Epoch #36 validation loss: 0.550

EPOCH 37 of 40


100%|██████████| 192/192 [00:19<00:00,  9.80it/s]


Epoch #37 train loss: 0.613


100%|██████████| 88/88 [00:08<00:00, 10.55it/s]


Epoch #37 validation loss: 0.547

Best validation loss: 0.5472320914268494

Saving best model for epoch: 37


EPOCH 38 of 40


100%|██████████| 192/192 [00:18<00:00, 10.18it/s]


Epoch #38 train loss: 0.611


100%|██████████| 88/88 [00:07<00:00, 11.19it/s]


Epoch #38 validation loss: 0.554

EPOCH 39 of 40


100%|██████████| 192/192 [00:17<00:00, 10.84it/s]


Epoch #39 train loss: 0.611


100%|██████████| 88/88 [00:07<00:00, 11.56it/s]

Epoch #39 validation loss: 0.550
Finished Training





In [9]:
model = torch.load('products/models/NW_CNN_checkpoint.pth')
model.to(torch.device('cpu'))
model.eval();

In [10]:
from torchmetrics import Accuracy, Recall, Precision, ConfusionMatrix, MetricCollection
metrics = MetricCollection([Accuracy(task="multiclass", num_classes=6, average='macro'),
                            Precision(task="multiclass", num_classes=6, average='macro'),
                            Recall(task="multiclass", num_classes=6, average='macro'),
                            ConfusionMatrix(task="multiclass", num_classes=6)])

for i, (crop, target) in enumerate(val_dataloader):
    prediction = torch.argmax(model(crop), dim=1)
    target = torch.squeeze(target)
    metrics(prediction, target)

print(metrics.compute())
metrics.reset()


{'MulticlassAccuracy': tensor(0.2615), 'MulticlassPrecision': tensor(0.3536), 'MulticlassRecall': tensor(0.2615), 'MulticlassConfusionMatrix': tensor([[  34,    0,  122,    0,   27,    0],
        [   0,    0,    2,    0,    0,    0],
        [  14,    0, 2930,    0,   15,    0],
        [   1,    0,  195,    0,    0,    0],
        [  10,    0,  122,    0,   20,    0],
        [   0,    0,    0,    0,    0,    0]])}
