<a href="https://colab.research.google.com/github/quang-vo-ds/banana_leaf_disease_detection/blob/main/banana_leaf_disease_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Initial Setup

In [1]:
# Install the extra packages this notebook imports (quiet mode; versions are not pinned).
!pip -q install pydicom
!pip -q install timm
!pip -q install catalyst

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m15.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m446.7/446.7 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m244.2/244.2 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m101.6/101.6 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
from glob import glob
from sklearn.model_selection import GroupKFold, StratifiedKFold
import cv2
from skimage import io
import torch
from torch import nn
import os
from datetime import datetime
import time
import random
import cv2
import torchvision
from torchvision import transforms
import pandas as pd
import numpy as np
from tqdm import tqdm

import matplotlib.pyplot as plt
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from torch.cuda.amp import autocast, GradScaler
from torch.nn.modules.loss import _WeightedLoss
import torch.nn.functional as F

import sklearn
import warnings
import joblib
from sklearn.metrics import roc_auc_score, log_loss
from sklearn import metrics
import warnings
import cv2
import pydicom
import timm
#from efficientnet_pytorch import EfficientNet
from scipy.ndimage import zoom

In [3]:
# Mount Google Drive and switch into the project folder; every path below is
# derived from the resulting working directory.
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/Vin_ML_Course/Final_Project
root_dir = os.getcwd()
# Directory the training CSV is read from.
save_data_dir = os.path.join(root_dir, "output/processed_data")
# Directory the per-epoch model checkpoints are written to.
save_model_dir = os.path.join(root_dir, "output/checkpoints")

Mounted at /content/drive
/content/drive/MyDrive/Vin_ML_Course/Final_Project


In [4]:
# Central training configuration read by the transforms, data loaders and main loop.
CFG = {
    'fold_num': 5,  # number of StratifiedKFold splits
    'seed': 719,  # passed to seed_everything and used as the fold-split random_state
    'model_arch': 'tf_efficientnet_b4_ns',  # architecture name handed to timm.create_model
    'img_size': 512,  # side length (pixels) used by the crop/resize transforms
    'train_all': False,  # default of prepare_dataloader's `train_all` argument
    'epochs': 10,  # epochs trained per fold
    'train_bs': 16,  # training batch size
    'valid_bs': 32,  # validation batch size
    'T_0': 10,  # T_0 of CosineAnnealingWarmRestarts
    'lr': 1e-4,  # Adam learning rate
    'min_lr': 1e-6,  # eta_min of the cosine scheduler
    'weight_decay':1e-6,  # Adam weight decay
    'num_workers': 4,  # DataLoader worker processes
    'accum_iter': 2, # support batch accumulation for backprop with an effectively larger batch size
    'verbose_step': 1,  # refresh the progress-bar description every N steps
    'device': 'cuda:0'  # NOTE(review): the main loop builds its own torch.device and does not read this key
}

In [5]:
# Load the training table from the processed-data directory and preview its first rows.
path = os.path.join(save_data_dir, "train.csv")
train = pd.read_csv(path)
train.head()

Unnamed: 0,id,label,label_name,path
0,Banana___Healthy_29.jpg,0,healthy,/content/drive/MyDrive/Vin_ML_Course/Final_Pro...
1,Banana___Healthy_295.jpg,0,healthy,/content/drive/MyDrive/Vin_ML_Course/Final_Pro...
2,Banana___Healthy_189.jpg,0,healthy,/content/drive/MyDrive/Vin_ML_Course/Final_Pro...
3,Banana___Healthy_141.jpg,0,healthy,/content/drive/MyDrive/Vin_ML_Course/Final_Pro...
4,Banana___Healthy_226.jpg,0,healthy,/content/drive/MyDrive/Vin_ML_Course/Final_Pro...


## Utils

In [6]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and configures cuDNN for deterministic kernels.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may select different convolution algorithms from run
    # to run, which defeats the deterministic flag above, so it must be off.
    torch.backends.cudnn.benchmark = False

def get_img(path):
    """Read an image from disk and return it in RGB channel order.

    Args:
        path: Filesystem path passed to ``cv2.imread``.

    Returns:
        A C-contiguous array holding the image with its last axis reversed
        (OpenCV loads BGR, so the result is RGB).

    Raises:
        FileNotFoundError: If ``cv2.imread`` returns ``None``, which it does
            for a missing or undecodable file instead of raising.
    """
    im_bgr = cv2.imread(path)
    if im_bgr is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    # Reversing the channel axis yields a negative-stride view; copy it so that
    # downstream consumers such as torch.from_numpy accept the array.
    return np.ascontiguousarray(im_bgr[:, :, ::-1])

## Dataset

In [7]:
class BananaDataset(Dataset):
    """Dataset that loads images listed in a DataFrame.

    The frame must provide a ``path`` column (image file locations) and, when
    labels are requested, an integer ``label`` column.

    Args:
        df: Source DataFrame; a private copy is stored on ``self.df``.
        transforms: Optional callable invoked as ``transforms(image=img)`` whose
            result is indexed with ``['image']``.
        output_label: When truthy, ``__getitem__`` returns ``(img, target)``;
            otherwise only ``img``.
        one_hot_label: When truthy (and labels are requested), targets are
            one-hot rows instead of integer class ids.
        num_classes: Width of the one-hot encoding. Defaults to
            ``max(label) + 1`` of ``df``; pass it explicitly when ``df`` is a
            subset that may not contain the highest class id.
    """

    def __init__(self, df,
                 transforms=None,
                 output_label=True,
                 one_hot_label=False,
                 num_classes=None,
                ):

        super().__init__()
        self.df = df.copy()
        self.transforms = transforms
        # Normalise the flags once so every later check agrees on their meaning.
        self.output_label = bool(output_label)
        self.one_hot_label = bool(one_hot_label)
        # Cache the paths as an array: cheaper than building a row Series per item.
        self.paths = self.df['path'].values

        if self.output_label:
            self.labels = self.df['label'].values
            if self.one_hot_label:
                if num_classes is None:
                    num_classes = int(self.labels.max()) + 1
                # Row i of the identity matrix is the one-hot vector for class i.
                self.labels = np.eye(num_classes)[self.labels]

    def __len__(self):
        """Return the number of rows (samples) in the frame."""
        return self.df.shape[0]

    def __getitem__(self, index: int):
        """Load, optionally transform, and return the sample at ``index``."""
        img = get_img(self.paths[index])

        if self.transforms:
            img = self.transforms(image=img)['image']

        if self.output_label:
            return img, self.labels[index]
        return img

## Image Augmentation

In [8]:
from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)

from albumentations.pytorch import ToTensorV2

def get_train_transforms():
    """Build the augmentation pipeline applied to training images.

    The output side length is read from ``CFG['img_size']``. The stages run in
    the order listed and end with conversion to a tensor.
    """
    side = CFG['img_size']
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    stages = [
        # Geometric augmentations.
        RandomResizedCrop(side, side),
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        ShiftScaleRotate(p=0.5),
        # Colour augmentations.
        # NOTE(review): shift limits of 0.2 look very small for this transform — confirm they are intended.
        HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
        RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1, 0.1), p=0.5),
        # Normalisation happens before the dropout-style stages below.
        Normalize(mean=channel_mean, std=channel_std, max_pixel_value=255.0, p=1.0),
        CoarseDropout(p=0.5),
        Cutout(p=0.5),
        ToTensorV2(p=1.0),
    ]
    return Compose(stages, p=1.)


def get_valid_transforms():
    """Build the augmentation-free pipeline applied to validation images.

    Images are resized to ``CFG['img_size']`` on both sides, normalised with
    the same statistics as the training pipeline, and converted to tensors.
    """
    side = CFG['img_size']
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    stages = [
        Resize(side, side),
        Normalize(mean=channel_mean, std=channel_std, max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]
    return Compose(stages, p=1.)

## Model

In [9]:
class MyImgClassifier(nn.Module):
    """Image classifier: a timm backbone whose ``classifier`` head is replaced.

    Args:
        model_arch: Architecture name forwarded to ``timm.create_model``.
        n_class: Number of output units of the new linear head.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, model_arch, n_class, pretrained=False):
        super().__init__()
        backbone = timm.create_model(model_arch, pretrained=pretrained)
        # Swap the stock head for a fresh linear layer with the same input width.
        # (A Dropout + Linear head was tried earlier and is not used.)
        head_in = backbone.classifier.in_features
        backbone.classifier = nn.Linear(head_in, n_class)
        self.model = backbone

    def forward(self, x):
        """Return the backbone's output (the new head's logits) for ``x``."""
        return self.model(x)

## Training API

In [10]:
def prepare_dataloader(df, trn_idx, val_idx, train_all=CFG['train_all']):
    """Create the training and validation DataLoaders for one fold.

    Args:
        df: Full training DataFrame (must provide ``label`` and ``path``).
        trn_idx: Positional row indices of the training split.
        val_idx: Positional row indices of the validation split.
        train_all: Kept for backward compatibility; this function does not
            read it.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles and uses
        the training augmentations; the validation loader keeps order and uses
        the validation transforms.
    """
    # Fold indices are positions, so select with iloc; label-based .loc is only
    # equivalent when the frame still carries a default RangeIndex.
    train_ = df.iloc[trn_idx].reset_index(drop=True)
    valid_ = df.iloc[val_idx].reset_index(drop=True)

    train_ds = BananaDataset(train_, transforms=get_train_transforms(), output_label=True, one_hot_label=False)
    valid_ds = BananaDataset(valid_, transforms=get_valid_transforms(), output_label=True)

    train_loader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=CFG['train_bs'],
        pin_memory=False,
        drop_last=False,
        shuffle=True,
        num_workers=CFG['num_workers'],
        # To rebalance classes instead, import BalanceClassSampler from
        # catalyst.data.sampler, pass it as `sampler=` and drop shuffle=True
        # (DataLoader rejects a sampler combined with shuffle).
        #sampler=BalanceClassSampler(labels=train_['label'].values, mode="downsampling")
    )
    val_loader = torch.utils.data.DataLoader(
        valid_ds,
        batch_size=CFG['valid_bs'],
        num_workers=CFG['num_workers'],
        shuffle=False,
        pin_memory=False,
    )
    return train_loader, val_loader

def train_one_epoch(epoch, model, loss_fn, optimizer, train_loader, device, scaler, scheduler=None, schd_batch_update=False):
    """Train ``model`` for one pass over ``train_loader`` with AMP and gradient accumulation.

    Args:
        epoch: Epoch number, used only in the progress-bar text.
        model: Module to train; switched to training mode.
        loss_fn: Callable ``loss_fn(predictions, labels)`` returning a scalar loss.
        optimizer: Optimizer stepped every ``CFG['accum_iter']`` batches and on
            the final batch.
        train_loader: Iterable of ``(images, labels)`` batches supporting ``len``.
        device: Device the batches are moved to.
        scaler: Gradient scaler used for the backward pass and optimizer step.
        scheduler: Optional LR scheduler.
        schd_batch_update: If True the scheduler steps after every optimizer
            step; otherwise once at the end of the epoch.

    Returns:
        The exponentially smoothed training loss shown in the progress bar, or
        ``None`` when the loader yields no batches.
    """
    model.train()

    accum_iter = max(1, int(CFG['accum_iter']))
    n_steps = len(train_loader)
    running_loss = None

    pbar = tqdm(enumerate(train_loader), total=n_steps)
    for step, (imgs, image_labels) in pbar:
        imgs = imgs.to(device).float()
        image_labels = image_labels.to(device).long()

        # Only the forward pass and the loss belong under autocast.
        with autocast():
            image_preds = model(imgs)
            loss = loss_fn(image_preds, image_labels)

        # Divide by the accumulation length so the summed gradients match one
        # large batch rather than growing with accum_iter.
        scaler.scale(loss / accum_iter).backward()

        loss_value = loss.item()
        if running_loss is None:
            running_loss = loss_value
        else:
            running_loss = running_loss * .99 + loss_value * .01

        is_last_step = (step + 1) == n_steps
        if ((step + 1) % accum_iter == 0) or is_last_step:
            # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if scheduler is not None and schd_batch_update:
                scheduler.step()

        if ((step + 1) % CFG['verbose_step'] == 0) or is_last_step:
            pbar.set_description(f'epoch {epoch} loss: {running_loss:.4f}')

    if scheduler is not None and not schd_batch_update:
        scheduler.step()

    return running_loss

def valid_one_epoch(epoch, model, loss_fn, val_loader, device, scheduler=None, schd_loss_update=False):
    """Evaluate ``model`` on ``val_loader`` and report loss and accuracy.

    Gradient tracking is disabled inside this function, so it no longer relies
    on the caller wrapping the call in ``torch.no_grad()``.

    Args:
        epoch: Epoch number, used only in the progress-bar text.
        model: Module to evaluate; switched to eval mode.
        loss_fn: Callable ``loss_fn(predictions, labels)`` returning a scalar loss.
        val_loader: Iterable of ``(images, labels)`` batches supporting ``len``.
        device: Device the batches are moved to.
        scheduler: Optional LR scheduler stepped after evaluation.
        schd_loss_update: If True the scheduler is stepped with the mean
            validation loss; otherwise it is stepped without arguments.

    Returns:
        ``(mean_loss, accuracy)`` over all validation samples.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model.eval()

    loss_sum = 0
    sample_num = 0
    image_preds_all = []
    image_targets_all = []
    n_steps = len(val_loader)

    pbar = tqdm(enumerate(val_loader), total=n_steps)
    with torch.no_grad():
        for step, (imgs, image_labels) in pbar:
            imgs = imgs.to(device).float()
            image_labels = image_labels.to(device).long()

            image_preds = model(imgs)
            image_preds_all.append(torch.argmax(image_preds, 1).cpu().numpy())
            image_targets_all.append(image_labels.cpu().numpy())

            loss = loss_fn(image_preds, image_labels)

            # Weight each batch loss by its size so the average is per sample.
            batch_size = image_labels.shape[0]
            loss_sum += loss.item() * batch_size
            sample_num += batch_size

            if ((step + 1) % CFG['verbose_step'] == 0) or ((step + 1) == n_steps):
                pbar.set_description(f'epoch {epoch} loss: {loss_sum/sample_num:.4f}')

    if sample_num == 0:
        raise ValueError('val_loader yielded no samples')

    image_preds_all = np.concatenate(image_preds_all)
    image_targets_all = np.concatenate(image_targets_all)
    val_loss = loss_sum / sample_num
    val_acc = float((image_preds_all == image_targets_all).mean())
    print('validation multi-class accuracy = {:.4f}'.format(val_acc))

    if scheduler is not None:
        if schd_loss_update:
            scheduler.step(val_loss)
        else:
            scheduler.step()

    return val_loss, val_acc

## Main loop

In [11]:
if __name__ == '__main__':
    # Cross-validated training: one freshly initialised model per stratified
    # fold, with a checkpoint saved after every epoch.
    # for training only, need nightly build pytorch
    # NOTE(review): the line above is inherited; confirm whether it still applies.

    seed_everything(CFG['seed'])

    # torch.save does not create missing directories, so make sure the target exists.
    os.makedirs(save_model_dir, exist_ok=True)

    folds = StratifiedKFold(n_splits=CFG['fold_num'], shuffle=True, random_state=CFG['seed']).split(np.arange(train.shape[0]), train.label.values)

    for fold, (trn_idx, val_idx) in enumerate(folds):
        print('Training with {} started'.format(fold))
        print(len(trn_idx), len(val_idx))

        train_loader, val_loader = prepare_dataloader(train, trn_idx, val_idx)

        # Falls back to CPU when CUDA is unavailable (CFG['device'] is not consulted here).
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        model = MyImgClassifier(CFG['model_arch'], train.label.nunique(), pretrained=True).to(device)
        scaler = GradScaler()
        optimizer = torch.optim.Adam(model.parameters(), lr=CFG['lr'], weight_decay=CFG['weight_decay'])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=CFG['T_0'], T_mult=1, eta_min=CFG['min_lr'], last_epoch=-1)

        loss_tr = nn.CrossEntropyLoss().to(device)
        loss_fn = nn.CrossEntropyLoss().to(device)

        for epoch in range(CFG['epochs']):
            # The scheduler is stepped once per epoch by train_one_epoch.
            train_one_epoch(epoch, model, loss_tr, optimizer, train_loader, device, scaler, scheduler=scheduler, schd_batch_update=False)

            # scheduler=None here so validation does not step the scheduler a second time.
            with torch.no_grad():
                valid_one_epoch(epoch, model, loss_fn, val_loader, device, scheduler=None, schd_loss_update=False)

            # One checkpoint per (fold, epoch); the file name carries no extension.
            torch.save(model.state_dict(), os.path.join(save_model_dir,'{}_fold_{}_{}'.format(CFG['model_arch'], fold, epoch)))

        # Drop references and release cached GPU memory before the next fold.
        del model, optimizer, train_loader, val_loader, scaler, scheduler
        torch.cuda.empty_cache()

Training with 0 started
2200 550


  model = create_fn(


Downloading model.safetensors:   0%|          | 0.00/77.9M [00:00<?, ?B/s]

epoch 0 loss: 1.2105: 100%|██████████| 138/138 [03:13<00:00,  1.41s/it]
epoch 0 loss: 0.7718: 100%|██████████| 18/18 [00:50<00:00,  2.79s/it]


validation multi-class accuracy = 0.7255


epoch 1 loss: 0.6981: 100%|██████████| 138/138 [01:19<00:00,  1.74it/s]
epoch 1 loss: 0.4210: 100%|██████████| 18/18 [00:21<00:00,  1.21s/it]


validation multi-class accuracy = 0.8600


epoch 2 loss: 0.5270: 100%|██████████| 138/138 [01:19<00:00,  1.74it/s]
epoch 2 loss: 0.4374: 100%|██████████| 18/18 [00:15<00:00,  1.16it/s]


validation multi-class accuracy = 0.8418


epoch 3 loss: 0.4103: 100%|██████████| 138/138 [01:22<00:00,  1.68it/s]
epoch 3 loss: 0.3933: 100%|██████████| 18/18 [00:13<00:00,  1.31it/s]


validation multi-class accuracy = 0.8655


epoch 4 loss: 0.4332: 100%|██████████| 138/138 [01:21<00:00,  1.70it/s]
epoch 4 loss: 0.3726: 100%|██████████| 18/18 [00:14<00:00,  1.22it/s]


validation multi-class accuracy = 0.8655


epoch 5 loss: 0.3743: 100%|██████████| 138/138 [01:18<00:00,  1.76it/s]
epoch 5 loss: 0.3667: 100%|██████████| 18/18 [00:15<00:00,  1.18it/s]


validation multi-class accuracy = 0.8618


epoch 6 loss: 0.3854: 100%|██████████| 138/138 [01:19<00:00,  1.73it/s]
epoch 6 loss: 0.3465: 100%|██████████| 18/18 [00:15<00:00,  1.16it/s]


validation multi-class accuracy = 0.8727


epoch 7 loss: 0.3949: 100%|██████████| 138/138 [01:19<00:00,  1.74it/s]
epoch 7 loss: 0.3505: 100%|██████████| 18/18 [00:15<00:00,  1.17it/s]


validation multi-class accuracy = 0.8727


epoch 8 loss: 0.3141: 100%|██████████| 138/138 [01:19<00:00,  1.73it/s]
epoch 8 loss: 0.3353: 100%|██████████| 18/18 [00:16<00:00,  1.09it/s]


validation multi-class accuracy = 0.8745


epoch 9 loss: 0.3718: 100%|██████████| 138/138 [01:19<00:00,  1.73it/s]
epoch 9 loss: 0.3416: 100%|██████████| 18/18 [00:15<00:00,  1.19it/s]


validation multi-class accuracy = 0.8727
Training with 1 started
2200 550


  model = create_fn(
epoch 0 loss: 1.1510: 100%|██████████| 138/138 [01:17<00:00,  1.78it/s]
epoch 0 loss: 0.6889: 100%|██████████| 18/18 [00:13<00:00,  1.29it/s]


validation multi-class accuracy = 0.7636


epoch 1 loss: 0.7311: 100%|██████████| 138/138 [01:19<00:00,  1.74it/s]
epoch 1 loss: 0.4584: 100%|██████████| 18/18 [00:14<00:00,  1.26it/s]


validation multi-class accuracy = 0.8545


epoch 2 loss: 0.4856: 100%|██████████| 138/138 [01:18<00:00,  1.76it/s]
epoch 2 loss: 0.3896: 100%|██████████| 18/18 [00:14<00:00,  1.25it/s]


validation multi-class accuracy = 0.8673


epoch 3 loss: 0.4488: 100%|██████████| 138/138 [01:20<00:00,  1.72it/s]
epoch 3 loss: 0.3509: 100%|██████████| 18/18 [00:13<00:00,  1.33it/s]


validation multi-class accuracy = 0.8764


epoch 4 loss: 0.4372: 100%|██████████| 138/138 [01:18<00:00,  1.76it/s]
epoch 4 loss: 0.3736: 100%|██████████| 18/18 [00:13<00:00,  1.30it/s]


validation multi-class accuracy = 0.8745


epoch 5 loss: 0.4086: 100%|██████████| 138/138 [01:19<00:00,  1.73it/s]
epoch 5 loss: 0.3564: 100%|██████████| 18/18 [00:14<00:00,  1.28it/s]


validation multi-class accuracy = 0.8818


epoch 6 loss: 0.4133: 100%|██████████| 138/138 [01:19<00:00,  1.73it/s]
epoch 6 loss: 0.3180: 100%|██████████| 18/18 [00:13<00:00,  1.32it/s]


validation multi-class accuracy = 0.8891


epoch 7 loss: 0.4026: 100%|██████████| 138/138 [01:18<00:00,  1.75it/s]
epoch 7 loss: 0.3154: 100%|██████████| 18/18 [00:13<00:00,  1.30it/s]


validation multi-class accuracy = 0.8909


epoch 8 loss: 0.3285: 100%|██████████| 138/138 [01:19<00:00,  1.73it/s]
epoch 8 loss: 0.3020: 100%|██████████| 18/18 [00:14<00:00,  1.26it/s]


validation multi-class accuracy = 0.8982


epoch 9 loss: 0.3391: 100%|██████████| 138/138 [01:18<00:00,  1.76it/s]
epoch 9 loss: 0.2948: 100%|██████████| 18/18 [00:13<00:00,  1.30it/s]


validation multi-class accuracy = 0.8927
Training with 2 started
2200 550


  model = create_fn(
epoch 0 loss: 1.1892: 100%|██████████| 138/138 [01:17<00:00,  1.78it/s]
epoch 0 loss: 0.7520: 100%|██████████| 18/18 [00:14<00:00,  1.26it/s]


validation multi-class accuracy = 0.7382


epoch 1 loss: 0.7019: 100%|██████████| 138/138 [01:20<00:00,  1.71it/s]
epoch 1 loss: 0.5716: 100%|██████████| 18/18 [00:14<00:00,  1.21it/s]


validation multi-class accuracy = 0.7873


epoch 2 loss: 0.5447: 100%|██████████| 138/138 [01:18<00:00,  1.76it/s]
epoch 2 loss: 0.5010: 100%|██████████| 18/18 [00:14<00:00,  1.21it/s]


validation multi-class accuracy = 0.8182


epoch 3 loss: 0.5067: 100%|██████████| 138/138 [01:18<00:00,  1.76it/s]
epoch 3 loss: 0.4677: 100%|██████████| 18/18 [00:14<00:00,  1.20it/s]


validation multi-class accuracy = 0.8418


epoch 4 loss: 0.5292: 100%|██████████| 138/138 [01:18<00:00,  1.75it/s]
epoch 4 loss: 0.4407: 100%|██████████| 18/18 [00:14<00:00,  1.22it/s]


validation multi-class accuracy = 0.8273


epoch 5 loss: 0.3903: 100%|██████████| 138/138 [01:18<00:00,  1.75it/s]
epoch 5 loss: 0.3958: 100%|██████████| 18/18 [00:15<00:00,  1.17it/s]


validation multi-class accuracy = 0.8491


epoch 6 loss: 0.4747: 100%|██████████| 138/138 [01:19<00:00,  1.73it/s]
epoch 6 loss: 0.3796: 100%|██████████| 18/18 [00:14<00:00,  1.24it/s]


validation multi-class accuracy = 0.8727


epoch 7 loss: 0.3408: 100%|██████████| 138/138 [01:18<00:00,  1.76it/s]
epoch 7 loss: 0.3644: 100%|██████████| 18/18 [00:15<00:00,  1.17it/s]


validation multi-class accuracy = 0.8691


epoch 8 loss: 0.3011: 100%|██████████| 138/138 [01:19<00:00,  1.73it/s]
epoch 8 loss: 0.3786: 100%|██████████| 18/18 [00:15<00:00,  1.16it/s]


validation multi-class accuracy = 0.8655


epoch 9 loss: 0.3546: 100%|██████████| 138/138 [01:22<00:00,  1.68it/s]
epoch 9 loss: 0.3555: 100%|██████████| 18/18 [00:15<00:00,  1.18it/s]


validation multi-class accuracy = 0.8655
Training with 3 started
2200 550


  model = create_fn(
epoch 0 loss: 1.2215: 100%|██████████| 138/138 [01:20<00:00,  1.72it/s]
epoch 0 loss: 0.8167: 100%|██████████| 18/18 [00:16<00:00,  1.11it/s]


validation multi-class accuracy = 0.7382


epoch 1 loss: 0.7530: 100%|██████████| 138/138 [01:19<00:00,  1.74it/s]
epoch 1 loss: 0.4964: 100%|██████████| 18/18 [00:14<00:00,  1.22it/s]


validation multi-class accuracy = 0.8236


epoch 2 loss: 0.5006: 100%|██████████| 138/138 [01:18<00:00,  1.76it/s]
epoch 2 loss: 0.4811: 100%|██████████| 18/18 [00:14<00:00,  1.21it/s]


validation multi-class accuracy = 0.8200


epoch 3 loss: 0.5077: 100%|██████████| 138/138 [01:17<00:00,  1.78it/s]
epoch 3 loss: 0.3678: 100%|██████████| 18/18 [00:15<00:00,  1.16it/s]


validation multi-class accuracy = 0.8600


epoch 4 loss: 0.4264: 100%|██████████| 138/138 [01:19<00:00,  1.73it/s]
epoch 4 loss: 0.3384: 100%|██████████| 18/18 [00:16<00:00,  1.09it/s]


validation multi-class accuracy = 0.8800


epoch 5 loss: 0.4008: 100%|██████████| 138/138 [01:18<00:00,  1.76it/s]
epoch 5 loss: 0.3428: 100%|██████████| 18/18 [00:16<00:00,  1.11it/s]


validation multi-class accuracy = 0.8727


epoch 6 loss: 0.3994: 100%|██████████| 138/138 [01:19<00:00,  1.73it/s]
epoch 6 loss: 0.3417: 100%|██████████| 18/18 [00:15<00:00,  1.13it/s]


validation multi-class accuracy = 0.8782


epoch 7 loss: 0.3940: 100%|██████████| 138/138 [01:19<00:00,  1.74it/s]
epoch 7 loss: 0.3309: 100%|██████████| 18/18 [00:14<00:00,  1.20it/s]


validation multi-class accuracy = 0.8764


epoch 8 loss: 0.3311: 100%|██████████| 138/138 [01:17<00:00,  1.79it/s]
epoch 8 loss: 0.3225: 100%|██████████| 18/18 [00:14<00:00,  1.23it/s]


validation multi-class accuracy = 0.8800


epoch 9 loss: 0.3214: 100%|██████████| 138/138 [01:18<00:00,  1.75it/s]
epoch 9 loss: 0.3306: 100%|██████████| 18/18 [00:15<00:00,  1.13it/s]


validation multi-class accuracy = 0.8764
Training with 4 started
2200 550


  model = create_fn(
epoch 0 loss: 1.2184: 100%|██████████| 138/138 [01:18<00:00,  1.76it/s]
epoch 0 loss: 0.7643: 100%|██████████| 18/18 [00:14<00:00,  1.21it/s]


validation multi-class accuracy = 0.7127


epoch 1 loss: 0.7313: 100%|██████████| 138/138 [01:16<00:00,  1.80it/s]
epoch 1 loss: 0.5519: 100%|██████████| 18/18 [00:15<00:00,  1.16it/s]


validation multi-class accuracy = 0.7909


epoch 2 loss: 0.6212: 100%|██████████| 138/138 [01:18<00:00,  1.75it/s]
epoch 2 loss: 0.4313: 100%|██████████| 18/18 [00:15<00:00,  1.17it/s]


validation multi-class accuracy = 0.8364


epoch 3 loss: 0.4669: 100%|██████████| 138/138 [01:17<00:00,  1.77it/s]
epoch 3 loss: 0.4030: 100%|██████████| 18/18 [00:14<00:00,  1.25it/s]


validation multi-class accuracy = 0.8545


epoch 4 loss: 0.5178: 100%|██████████| 138/138 [01:17<00:00,  1.78it/s]
epoch 4 loss: 0.4119: 100%|██████████| 18/18 [00:15<00:00,  1.14it/s]


validation multi-class accuracy = 0.8473


epoch 5 loss: 0.3900: 100%|██████████| 138/138 [01:18<00:00,  1.75it/s]
epoch 5 loss: 0.3893: 100%|██████████| 18/18 [00:15<00:00,  1.15it/s]


validation multi-class accuracy = 0.8582


epoch 6 loss: 0.3945: 100%|██████████| 138/138 [01:18<00:00,  1.75it/s]
epoch 6 loss: 0.3685: 100%|██████████| 18/18 [00:15<00:00,  1.16it/s]


validation multi-class accuracy = 0.8564


epoch 7 loss: 0.3758: 100%|██████████| 138/138 [01:16<00:00,  1.79it/s]
epoch 7 loss: 0.3642: 100%|██████████| 18/18 [00:15<00:00,  1.13it/s]


validation multi-class accuracy = 0.8618


epoch 8 loss: 0.3678: 100%|██████████| 138/138 [01:18<00:00,  1.75it/s]
epoch 8 loss: 0.3804: 100%|██████████| 18/18 [00:14<00:00,  1.20it/s]


validation multi-class accuracy = 0.8491


epoch 9 loss: 0.3442: 100%|██████████| 138/138 [01:17<00:00,  1.78it/s]
epoch 9 loss: 0.3521: 100%|██████████| 18/18 [00:15<00:00,  1.16it/s]


validation multi-class accuracy = 0.8691
