In [1]:
from torch import optim
from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights
import torch
import random
import numpy as np
import torch.nn as nn
import albumentations as Albu
import pandas as pd
from torch.utils.data.sampler import RandomSampler
from warmup_scheduler import GradualWarmupScheduler
import os
from utils.dataset import PandasDataset
from utils.metrics import model_checkpoint
from utils.train import train_model
from utils.models import EfficientNetApi

NameError: name 'Literal' is not defined

In [2]:
# --- Reproducibility -------------------------------------------------------
# Seed PyTorch, Python's `random` and NumPy with the same value.
seed = 42
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

# --- Data loading ----------------------------------------------------------
shuffle = True
batch_size = 6
num_workers = 4

# --- Model / objective -----------------------------------------------------
output_classes = 5
loss_function = nn.BCEWithLogitsLoss()

# --- Optimisation schedule -------------------------------------------------
init_lr = 3e-4
warmup_factor = 2
warmup_epochs = 1
n_epochs = 50

# --- Compute device --------------------------------------------------------
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# --- Paths (relative to the notebook's working directory) -------------------
ROOT_DIR = '../..'
data_dir = '../../../dataset'
images_dir = os.path.join(data_dir, 'tiles')

Using device: cuda


In [3]:
# EfficientNet-B2 backbone initialised from torchvision's default weights.
load_model = efficientnet_b2(weights=EfficientNet_B2_Weights.DEFAULT)

# Wrap the backbone with the project's EfficientNetApi (one output per class,
# dropout 0.6) and move the result to the selected device.
model = EfficientNetApi(
    model=load_model,
    output_dimensions=output_classes,
    dropout_rate=0.6,
)
model = model.to(device)

In [4]:
# Re-apply the global seed to NumPy, Python's `random` and PyTorch, and
# rebuild the loss object. This cell runs after the model has been
# constructed, so the RNGs are reset to the seeded state at this point.
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)

loss_function = nn.BCEWithLogitsLoss()
print("Using device:", device)

Using device: cuda


In [5]:
# Load the 5-fold training manifest and strip stray whitespace from its
# column names so that 'fold' can be addressed reliably.
df_train_ = pd.read_csv(f"{ROOT_DIR}/data/train_5fold.csv")
df_train_.columns = df_train_.columns.str.strip()

# Fold held out for validation; every other fold is used for training.
valid_fold = 3

# np.where returns *positional* row numbers. They must therefore be consumed
# with .iloc: .loc would interpret them as index labels, which only happens
# to coincide with positions while the frame keeps its default RangeIndex.
train_indexes = np.where(df_train_['fold'] != valid_fold)[0]
valid_indexes = np.where(df_train_['fold'] == valid_fold)[0]

df_train = df_train_.iloc[train_indexes]
df_val = df_train_.iloc[valid_indexes]

# Separate test manifest.
df_test = pd.read_csv(f"{ROOT_DIR}/data/test.csv")

#### view data

In [6]:
# (rows, columns) of the train / validation / test frames, in that order.
tuple(frame.shape for frame in (df_train, df_val, df_test))

((7219, 5), (1805, 5), (1592, 4))

In [7]:
# Training-time augmentation pipeline: transpose, vertical flip and
# horizontal flip, each configured with p=0.5.
transforms = Albu.Compose(
    [
        augmentation(p=0.5)
        for augmentation in (Albu.Transpose, Albu.VerticalFlip, Albu.HorizontalFlip)
    ]
)

In [8]:
# Make sure the training frame's column names carry no surrounding whitespace.
df_train.columns = df_train.columns.str.strip()

# One PandasDataset per split, all reading from images_dir. Only the training
# split receives the augmentation pipeline; validation and test get none.
train_dataset, valid_dataset, test_dataset = (
    PandasDataset(images_dir, frame, transforms=augmentation)
    for frame, augmentation in (
        (df_train, transforms),
        (df_val, None),
        (df_test, None),
    )
)

In [9]:
# Training batches are drawn in random order via an explicit RandomSampler.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, num_workers=num_workers, sampler=RandomSampler(train_dataset)
)

# Evaluation loaders iterate sequentially (shuffle=False). Shuffling brings no
# benefit when only scoring, makes the prediction order differ from the row
# order of the source frame, and draws from the global torch RNG on every
# pass, which perturbs the training shuffle of the following epochs.
valid_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
)

In [10]:
# Adam starts at init_lr / warmup_factor; the warmup scheduler is given
# multiplier=warmup_factor over warmup_epochs epochs and then hands control
# to cosine annealing for the remaining n_epochs - warmup_epochs epochs.
optimizer = optim.Adam(model.parameters(), lr=init_lr / warmup_factor)
scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=n_epochs - warmup_epochs,
)
scheduler = GradualWarmupScheduler(
    optimizer,
    multiplier=warmup_factor,
    total_epoch=warmup_epochs,
    after_scheduler=scheduler_cosine,
)

In [11]:
# Run the training loop for up to n_epochs epochs with a patience of 5.
# Metrics are written to logs/b2.txt and the model to models/b2.pth.
train_model(
    # what to train and how to optimise it
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_function=loss_function,
    device=device,
    # data
    train_dataloader=train_loader,
    valid_dataloader=valid_loader,
    # schedule / stopping
    epochs=n_epochs,
    patience=5,
    # checkpointing and logging
    checkpoint=model_checkpoint,
    path_to_save_model="models/b2.pth",
    path_to_save_metrics="logs/b2.txt",
)

Epoch 1/50



loss: 0.25129, smooth loss: 0.33740: 100%|██████████| 1204/1204 [11:00<00:00,  1.82it/s]
100%|██████████| 301/301 [01:32<00:00,  3.26it/s]


VAL_LOSS     0.281
VAL_ACC      Mean: 52.278 | Std: 1.183 | 95% CI: [50.360, 54.241]
VAL_KAPPA    Mean: 0.787 | Std: 0.011 | 95% CI: [0.768, 0.805]
VAL_F1       Mean: 0.457 | Std: 0.012 | 95% CI: [0.437, 0.476]
VAL_RECALL   Mean: 0.465 | Std: 0.011 | 95% CI: [0.446, 0.485]
VAL_PRECISION Mean: 0.550 | Std: 0.014 | 95% CI: [0.527, 0.572]
Salvando o melhor modelo... 0.0 -> 0.786536666439305
Epoch 2/50



loss: 0.14834, smooth loss: 0.27134: 100%|██████████| 1204/1204 [10:30<00:00,  1.91it/s]
100%|██████████| 301/301 [01:29<00:00,  3.35it/s]
  _warn_get_lr_called_within_step(self)


VAL_LOSS     0.284
VAL_ACC      Mean: 56.038 | Std: 1.163 | 95% CI: [54.127, 58.006]
VAL_KAPPA    Mean: 0.789 | Std: 0.012 | 95% CI: [0.770, 0.809]
VAL_F1       Mean: 0.472 | Std: 0.012 | 95% CI: [0.452, 0.493]
VAL_RECALL   Mean: 0.476 | Std: 0.011 | 95% CI: [0.458, 0.495]
VAL_PRECISION Mean: 0.573 | Std: 0.011 | 95% CI: [0.555, 0.592]
Salvando o melhor modelo... 0.786536666439305 -> 0.7892299375422803
Epoch 3/50



loss: 0.06460, smooth loss: 0.17439: 100%|██████████| 1204/1204 [10:26<00:00,  1.92it/s]
100%|██████████| 301/301 [01:29<00:00,  3.36it/s]


VAL_LOSS     0.346
VAL_ACC      Mean: 53.896 | Std: 1.203 | 95% CI: [51.967, 55.903]
VAL_KAPPA    Mean: 0.755 | Std: 0.013 | 95% CI: [0.733, 0.777]
VAL_F1       Mean: 0.446 | Std: 0.012 | 95% CI: [0.427, 0.466]
VAL_RECALL   Mean: 0.450 | Std: 0.011 | 95% CI: [0.433, 0.468]
VAL_PRECISION Mean: 0.540 | Std: 0.013 | 95% CI: [0.518, 0.560]
Epoch 4/50



loss: 0.03560, smooth loss: 0.13808: 100%|██████████| 1204/1204 [10:30<00:00,  1.91it/s]
100%|██████████| 301/301 [01:30<00:00,  3.33it/s]


VAL_LOSS     0.344
VAL_ACC      Mean: 53.725 | Std: 1.149 | 95% CI: [51.856, 55.626]
VAL_KAPPA    Mean: 0.763 | Std: 0.012 | 95% CI: [0.743, 0.783]
VAL_F1       Mean: 0.434 | Std: 0.011 | 95% CI: [0.417, 0.452]
VAL_RECALL   Mean: 0.446 | Std: 0.010 | 95% CI: [0.430, 0.464]
VAL_PRECISION Mean: 0.560 | Std: 0.009 | 95% CI: [0.545, 0.576]
Epoch 5/50



loss: 0.02230, smooth loss: 0.10201: 100%|██████████| 1204/1204 [10:28<00:00,  1.92it/s]
100%|██████████| 301/301 [01:30<00:00,  3.31it/s]


VAL_LOSS     0.318
VAL_ACC      Mean: 59.682 | Std: 1.165 | 95% CI: [57.784, 61.609]
VAL_KAPPA    Mean: 0.818 | Std: 0.011 | 95% CI: [0.799, 0.835]
VAL_F1       Mean: 0.540 | Std: 0.012 | 95% CI: [0.519, 0.560]
VAL_RECALL   Mean: 0.537 | Std: 0.012 | 95% CI: [0.516, 0.557]
VAL_PRECISION Mean: 0.579 | Std: 0.012 | 95% CI: [0.559, 0.599]
Salvando o melhor modelo... 0.7892299375422803 -> 0.8176330290307612
Epoch 6/50



loss: 0.03094, smooth loss: 0.07442: 100%|██████████| 1204/1204 [10:24<00:00,  1.93it/s]
100%|██████████| 301/301 [01:28<00:00,  3.39it/s]


VAL_LOSS     0.358
VAL_ACC      Mean: 58.538 | Std: 1.170 | 95% CI: [56.731, 60.501]
VAL_KAPPA    Mean: 0.818 | Std: 0.011 | 95% CI: [0.801, 0.836]
VAL_F1       Mean: 0.531 | Std: 0.012 | 95% CI: [0.511, 0.551]
VAL_RECALL   Mean: 0.529 | Std: 0.012 | 95% CI: [0.511, 0.549]
VAL_PRECISION Mean: 0.569 | Std: 0.013 | 95% CI: [0.549, 0.589]
Salvando o melhor modelo... 0.8176330290307612 -> 0.8181692650655135
Epoch 7/50



loss: 0.01661, smooth loss: 0.07178: 100%|██████████| 1204/1204 [10:24<00:00,  1.93it/s]
100%|██████████| 301/301 [01:29<00:00,  3.37it/s]


VAL_LOSS     0.382
VAL_ACC      Mean: 58.333 | Std: 1.140 | 95% CI: [56.454, 60.277]
VAL_KAPPA    Mean: 0.801 | Std: 0.012 | 95% CI: [0.782, 0.821]
VAL_F1       Mean: 0.533 | Std: 0.012 | 95% CI: [0.514, 0.552]
VAL_RECALL   Mean: 0.527 | Std: 0.012 | 95% CI: [0.508, 0.547]
VAL_PRECISION Mean: 0.554 | Std: 0.012 | 95% CI: [0.536, 0.574]
Epoch 8/50



loss: 0.00830, smooth loss: 0.06647: 100%|██████████| 1204/1204 [10:25<00:00,  1.92it/s]
100%|██████████| 301/301 [01:29<00:00,  3.37it/s]


VAL_LOSS     0.398
VAL_ACC      Mean: 59.526 | Std: 1.196 | 95% CI: [57.618, 61.496]
VAL_KAPPA    Mean: 0.807 | Std: 0.012 | 95% CI: [0.786, 0.827]
VAL_F1       Mean: 0.523 | Std: 0.013 | 95% CI: [0.502, 0.544]
VAL_RECALL   Mean: 0.527 | Std: 0.012 | 95% CI: [0.507, 0.548]
VAL_PRECISION Mean: 0.579 | Std: 0.014 | 95% CI: [0.557, 0.603]
Epoch 9/50



loss: 0.01160, smooth loss: 0.04091: 100%|██████████| 1204/1204 [10:27<00:00,  1.92it/s]
100%|██████████| 301/301 [01:29<00:00,  3.37it/s]


VAL_LOSS     0.440
VAL_ACC      Mean: 59.853 | Std: 1.151 | 95% CI: [58.006, 61.717]
VAL_KAPPA    Mean: 0.826 | Std: 0.011 | 95% CI: [0.808, 0.843]
VAL_F1       Mean: 0.540 | Std: 0.012 | 95% CI: [0.521, 0.560]
VAL_RECALL   Mean: 0.545 | Std: 0.012 | 95% CI: [0.526, 0.565]
VAL_PRECISION Mean: 0.551 | Std: 0.012 | 95% CI: [0.533, 0.572]
Salvando o melhor modelo... 0.8181692650655135 -> 0.8255615484803469
Epoch 10/50



loss: 0.00898, smooth loss: 0.03455: 100%|██████████| 1204/1204 [10:26<00:00,  1.92it/s]
100%|██████████| 301/301 [01:29<00:00,  3.37it/s]


VAL_LOSS     0.428
VAL_ACC      Mean: 60.530 | Std: 1.162 | 95% CI: [58.670, 62.330]
VAL_KAPPA    Mean: 0.832 | Std: 0.011 | 95% CI: [0.814, 0.850]
VAL_F1       Mean: 0.552 | Std: 0.012 | 95% CI: [0.533, 0.572]
VAL_RECALL   Mean: 0.555 | Std: 0.012 | 95% CI: [0.536, 0.575]
VAL_PRECISION Mean: 0.564 | Std: 0.012 | 95% CI: [0.544, 0.583]
Salvando o melhor modelo... 0.8255615484803469 -> 0.8321029280745594
Epoch 11/50



loss: 0.00480, smooth loss: 0.02785: 100%|██████████| 1204/1204 [10:25<00:00,  1.92it/s]
100%|██████████| 301/301 [01:29<00:00,  3.36it/s]


VAL_LOSS     0.501
VAL_ACC      Mean: 60.888 | Std: 1.176 | 95% CI: [58.947, 62.717]
VAL_KAPPA    Mean: 0.828 | Std: 0.011 | 95% CI: [0.809, 0.847]
VAL_F1       Mean: 0.548 | Std: 0.012 | 95% CI: [0.527, 0.567]
VAL_RECALL   Mean: 0.554 | Std: 0.012 | 95% CI: [0.533, 0.573]
VAL_PRECISION Mean: 0.550 | Std: 0.012 | 95% CI: [0.529, 0.570]
Epoch 12/50



loss: 0.00337, smooth loss: 0.02284: 100%|██████████| 1204/1204 [10:26<00:00,  1.92it/s]
100%|██████████| 301/301 [01:29<00:00,  3.37it/s]


VAL_LOSS     0.542
VAL_ACC      Mean: 62.438 | Std: 1.153 | 95% CI: [60.609, 64.432]
VAL_KAPPA    Mean: 0.822 | Std: 0.012 | 95% CI: [0.803, 0.842]
VAL_F1       Mean: 0.559 | Std: 0.012 | 95% CI: [0.540, 0.580]
VAL_RECALL   Mean: 0.557 | Std: 0.012 | 95% CI: [0.538, 0.578]
VAL_PRECISION Mean: 0.565 | Std: 0.013 | 95% CI: [0.545, 0.587]
Epoch 13/50



loss: 0.02501, smooth loss: 0.02334: 100%|██████████| 1204/1204 [10:25<00:00,  1.93it/s]
100%|██████████| 301/301 [01:29<00:00,  3.37it/s]


VAL_LOSS     0.426
VAL_ACC      Mean: 62.208 | Std: 1.152 | 95% CI: [60.332, 64.100]
VAL_KAPPA    Mean: 0.822 | Std: 0.012 | 95% CI: [0.803, 0.842]
VAL_F1       Mean: 0.563 | Std: 0.012 | 95% CI: [0.544, 0.582]
VAL_RECALL   Mean: 0.560 | Std: 0.012 | 95% CI: [0.540, 0.579]
VAL_PRECISION Mean: 0.574 | Std: 0.012 | 95% CI: [0.555, 0.594]
Epoch 14/50



loss: 0.00143, smooth loss: 0.02521: 100%|██████████| 1204/1204 [10:25<00:00,  1.92it/s]
100%|██████████| 301/301 [01:29<00:00,  3.37it/s]


VAL_LOSS     0.479
VAL_ACC      Mean: 61.992 | Std: 1.170 | 95% CI: [60.108, 63.936]
VAL_KAPPA    Mean: 0.826 | Std: 0.012 | 95% CI: [0.807, 0.845]
VAL_F1       Mean: 0.561 | Std: 0.012 | 95% CI: [0.542, 0.583]
VAL_RECALL   Mean: 0.557 | Std: 0.012 | 95% CI: [0.538, 0.578]
VAL_PRECISION Mean: 0.579 | Std: 0.012 | 95% CI: [0.558, 0.600]
Epoch 15/50



loss: 0.00147, smooth loss: 0.02232: 100%|██████████| 1204/1204 [10:27<00:00,  1.92it/s]
100%|██████████| 301/301 [01:29<00:00,  3.36it/s]


VAL_LOSS     0.557
VAL_ACC      Mean: 63.502 | Std: 1.137 | 95% CI: [61.715, 65.429]
VAL_KAPPA    Mean: 0.815 | Std: 0.013 | 95% CI: [0.795, 0.836]
VAL_F1       Mean: 0.557 | Std: 0.012 | 95% CI: [0.537, 0.579]
VAL_RECALL   Mean: 0.557 | Std: 0.011 | 95% CI: [0.538, 0.577]
VAL_PRECISION Mean: 0.618 | Std: 0.013 | 95% CI: [0.597, 0.640]

Early stopping at epoch 15. No improvement for 5 epochs.
Best epoch: 10 with kappa: 0.8321


# tests

In [12]:
from utils.metrics import evaluation, format_metrics

# Restore the checkpoint written during training. map_location=device remaps
# the stored tensors onto the device selected in the config cell, so a
# checkpoint saved on CUDA still loads on a CPU-only machine.
model.load_state_dict(
    torch.load("models/b2.pth", map_location=device)
)

# Score the restored model on the test loader; the first element of the
# returned value is what format_metrics turns into the printed report.
response = evaluation(model, test_loader, device)
result = format_metrics(response[0])
print(result)

100%|██████████| 266/266 [01:10<00:00,  3.75it/s]


VAL_ACC      Mean: 59.282 | Std: 1.208 | 95% CI: [57.286, 61.244]
VAL_KAPPA    Mean: 0.826 | Std: 0.012 | 95% CI: [0.806, 0.846]
VAL_F1       Mean: 0.537 | Std: 0.013 | 95% CI: [0.516, 0.558]
VAL_RECALL   Mean: 0.541 | Std: 0.013 | 95% CI: [0.520, 0.562]
VAL_PRECISION Mean: 0.551 | Std: 0.013 | 95% CI: [0.530, 0.572]
