In [4]:
import os
import random

import albumentations as Albu
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data.sampler import RandomSampler
from torchvision.models import efficientnet_b1, EfficientNet_B1_Weights
from warmup_scheduler import GradualWarmupScheduler

from utils.dataset import PandasDataset
from utils.metrics import evaluation, format_metrics
from utils.metrics import model_checkpoint
from utils.models import EfficientNetApi
from utils.train import train_model

In [5]:
# Training hyper-parameters.
batch_size, num_workers = 6, 4
output_classes = 5
init_lr = 3e-4
warmup_factor, warmup_epochs = 2, 1
n_epochs = 50
shuffle = True

# Use the GPU whenever PyTorch can see one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

loss_function = nn.BCEWithLogitsLoss()

# Seed the PyTorch, Python and NumPy global RNGs with the same value.
seed = 42
for seeder in (torch.manual_seed, random.seed, np.random.seed):
    seeder(seed)

# Paths are relative to the notebook's working directory.
ROOT_DIR = '../..'
data_dir = '../../../dataset'
images_dir = os.path.join(data_dir, 'tiles')

Using device: cuda


In [6]:
# Pretrained torchvision EfficientNet-B1 (default weights enum), wrapped by the
# project's EfficientNetApi (utils.models) with `output_classes` output
# dimensions and dropout 0.6, then moved to the selected device.
load_model = efficientnet_b1(weights=EfficientNet_B1_Weights.DEFAULT)
model = EfficientNetApi(
    model=load_model,
    output_dimensions=output_classes,
    dropout_rate=0.6,
).to(device)

In [7]:
# Re-seed after building the model. Model construction may draw from the RNGs,
# so resetting them here keeps the RNG state seen by the data pipeline
# independent of that. The device print and the loss definition already live
# in the config cell and are not repeated here.
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

Using device: cuda


In [8]:
# Fold `valid_fold` of the 5-fold CSV is held out for validation; all other
# folds are used for training.
valid_fold = 3

df_train_ = pd.read_csv(f"{ROOT_DIR}/data/train_5fold.csv")
df_train_.columns = df_train_.columns.str.strip()  # normalise stray whitespace in headers
is_valid = df_train_['fold'] == valid_fold
train_indexes = np.where(~is_valid)[0]  # positional row numbers
valid_indexes = np.where(is_valid)[0]

# np.where yields positions, so select with .iloc. reset_index gives each split
# a contiguous 0..n-1 index regardless of how the dataset class looks up rows.
df_train = df_train_.iloc[train_indexes].reset_index(drop=True)
df_val = df_train_.iloc[valid_indexes].reset_index(drop=True)

df_test = pd.read_csv(f"{ROOT_DIR}/data/test.csv")
df_test.columns = df_test.columns.str.strip()  # same header normalisation as the training CSV

#### View the train / validation / test split sizes

In [9]:
# Sanity check: the train and validation splits must together cover every row
# of the 5-fold CSV exactly once.
assert len(df_train) + len(df_val) == len(df_train_), "train/val split lost or duplicated rows"
(df_train.shape, df_val.shape, df_test.shape)

((7219, 5), (1805, 5), (1592, 4))

In [10]:
# Train-time augmentation: transpose, vertical flip and horizontal flip, each
# with p=0.5, applied in this order by Albu.Compose.
flip_augmentations = [
    Albu.Transpose(p=0.5),
    Albu.VerticalFlip(p=0.5),
    Albu.HorizontalFlip(p=0.5),
]
transforms = Albu.Compose(flip_augmentations)

In [11]:
# Column names were already stripped on df_train_ when the CSV was loaded, so
# df_train is not mutated again here.
# Only the training dataset receives augmentations; validation and test use
# transforms=None.
train_dataset = PandasDataset(images_dir, df_train, transforms=transforms)
valid_dataset = PandasDataset(images_dir, df_val, transforms=None)
test_dataset = PandasDataset(images_dir, df_test, transforms=None)

In [12]:
# Training batches are drawn in random order each epoch. With shuffle=True
# (config cell), DataLoader builds the same RandomSampler(train_dataset) that
# was previously passed explicitly.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle
)
# Evaluation does not need a random order. Iterating sequentially makes the
# runs repeatable and does not consume the torch RNG used by training. It
# presumably also keeps predictions aligned with df_val / df_test rows;
# confirm against PandasDataset.
valid_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
)

In [13]:
# Adam starts at init_lr / warmup_factor. NOTE(review): GradualWarmupScheduler
# with multiplier=warmup_factor presumably ramps the LR back up to init_lr over
# warmup_epochs, then defers to cosine annealing for the remaining epochs.
optimizer = optim.Adam(model.parameters(), lr=init_lr / warmup_factor)
cosine_epochs = n_epochs - warmup_epochs
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_epochs)
scheduler = GradualWarmupScheduler(
    optimizer,
    multiplier=warmup_factor,
    total_epoch=warmup_epochs,
    after_scheduler=scheduler_cosine,
)

In [14]:
# Output locations for the training log and the best checkpoint. The test cell
# below reloads the checkpoint from models/b1.pth.
metrics_path = "logs/b1.txt"
checkpoint_path = "models/b1.pth"

# NOTE(review): it is not visible whether train_model creates missing parent
# directories. Create them up front so the first save cannot fail on a fresh
# checkout.
for out_path in (metrics_path, checkpoint_path):
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

train_model(
    model=model,
    epochs=n_epochs,
    optimizer=optimizer,
    scheduler=scheduler,
    train_dataloader=train_loader,
    valid_dataloader=valid_loader,
    checkpoint=model_checkpoint,
    device=device,
    loss_function=loss_function,
    path_to_save_metrics=metrics_path,
    path_to_save_model=checkpoint_path,
    patience=5,
)

Epoch 1/50



loss: 0.26871, smooth loss: 0.34580: 100%|██████████| 1204/1204 [13:12<00:00,  1.52it/s]
100%|██████████| 301/301 [01:55<00:00,  2.61it/s]


VAL_LOSS     0.292
VAL_ACC      Mean: 50.277 | Std: 1.151 | 95% CI: [48.421, 52.078]
VAL_KAPPA    Mean: 0.768 | Std: 0.012 | 95% CI: [0.749, 0.786]
VAL_F1       Mean: 0.434 | Std: 0.012 | 95% CI: [0.415, 0.454]
VAL_RECALL   Mean: 0.434 | Std: 0.011 | 95% CI: [0.416, 0.452]
VAL_PRECISION Mean: 0.528 | Std: 0.013 | 95% CI: [0.506, 0.550]
Salvando o melhor modelo... 0.0 -> 0.7679253809124893
Epoch 2/50



loss: 0.17449, smooth loss: 0.27641: 100%|██████████| 1204/1204 [12:39<00:00,  1.59it/s]
100%|██████████| 301/301 [01:55<00:00,  2.60it/s]
  _warn_get_lr_called_within_step(self)


VAL_LOSS     0.290
VAL_ACC      Mean: 54.809 | Std: 1.150 | 95% CI: [52.909, 56.676]
VAL_KAPPA    Mean: 0.768 | Std: 0.012 | 95% CI: [0.748, 0.787]
VAL_F1       Mean: 0.456 | Std: 0.012 | 95% CI: [0.436, 0.475]
VAL_RECALL   Mean: 0.461 | Std: 0.011 | 95% CI: [0.443, 0.478]
VAL_PRECISION Mean: 0.558 | Std: 0.013 | 95% CI: [0.535, 0.579]
Salvando o melhor modelo... 0.7679253809124893 -> 0.7681001925213348
Epoch 3/50



loss: 0.08349, smooth loss: 0.18228: 100%|██████████| 1204/1204 [12:38<00:00,  1.59it/s]
100%|██████████| 301/301 [01:55<00:00,  2.60it/s]


VAL_LOSS     0.284
VAL_ACC      Mean: 58.060 | Std: 1.151 | 95% CI: [56.230, 60.000]
VAL_KAPPA    Mean: 0.803 | Std: 0.012 | 95% CI: [0.783, 0.822]
VAL_F1       Mean: 0.503 | Std: 0.012 | 95% CI: [0.482, 0.523]
VAL_RECALL   Mean: 0.498 | Std: 0.011 | 95% CI: [0.479, 0.517]
VAL_PRECISION Mean: 0.547 | Std: 0.013 | 95% CI: [0.525, 0.569]
Salvando o melhor modelo... 0.7681001925213348 -> 0.8029652612767363
Epoch 4/50



loss: 0.03991, smooth loss: 0.12235: 100%|██████████| 1204/1204 [12:41<00:00,  1.58it/s]
100%|██████████| 301/301 [01:55<00:00,  2.60it/s]


VAL_LOSS     0.358
VAL_ACC      Mean: 55.461 | Std: 1.183 | 95% CI: [53.573, 57.507]
VAL_KAPPA    Mean: 0.766 | Std: 0.013 | 95% CI: [0.744, 0.788]
VAL_F1       Mean: 0.476 | Std: 0.012 | 95% CI: [0.456, 0.497]
VAL_RECALL   Mean: 0.470 | Std: 0.012 | 95% CI: [0.451, 0.490]
VAL_PRECISION Mean: 0.537 | Std: 0.014 | 95% CI: [0.516, 0.561]
Epoch 5/50



loss: 0.01760, smooth loss: 0.08798: 100%|██████████| 1204/1204 [13:44<00:00,  1.46it/s]
100%|██████████| 301/301 [02:14<00:00,  2.24it/s]


VAL_LOSS     0.458
VAL_ACC      Mean: 57.010 | Std: 1.213 | 95% CI: [55.014, 59.003]
VAL_KAPPA    Mean: 0.749 | Std: 0.014 | 95% CI: [0.727, 0.772]
VAL_F1       Mean: 0.452 | Std: 0.012 | 95% CI: [0.433, 0.472]
VAL_RECALL   Mean: 0.450 | Std: 0.010 | 95% CI: [0.434, 0.468]
VAL_PRECISION Mean: 0.510 | Std: 0.014 | 95% CI: [0.489, 0.533]
Epoch 6/50



loss: 0.01118, smooth loss: 0.06584: 100%|██████████| 1204/1204 [14:57<00:00,  1.34it/s]
100%|██████████| 301/301 [02:14<00:00,  2.24it/s]


VAL_LOSS     0.365
VAL_ACC      Mean: 61.442 | Std: 1.141 | 95% CI: [59.665, 63.324]
VAL_KAPPA    Mean: 0.818 | Std: 0.012 | 95% CI: [0.798, 0.837]
VAL_F1       Mean: 0.551 | Std: 0.012 | 95% CI: [0.532, 0.572]
VAL_RECALL   Mean: 0.545 | Std: 0.012 | 95% CI: [0.526, 0.565]
VAL_PRECISION Mean: 0.575 | Std: 0.012 | 95% CI: [0.556, 0.595]
Salvando o melhor modelo... 0.8029652612767363 -> 0.8179133476043071
Epoch 7/50



loss: 0.01852, smooth loss: 0.05566: 100%|██████████| 1204/1204 [14:45<00:00,  1.36it/s]
100%|██████████| 301/301 [01:55<00:00,  2.61it/s]


VAL_LOSS     0.368
VAL_ACC      Mean: 62.631 | Std: 1.164 | 95% CI: [60.831, 64.654]
VAL_KAPPA    Mean: 0.825 | Std: 0.012 | 95% CI: [0.806, 0.845]
VAL_F1       Mean: 0.564 | Std: 0.012 | 95% CI: [0.545, 0.585]
VAL_RECALL   Mean: 0.560 | Std: 0.012 | 95% CI: [0.540, 0.580]
VAL_PRECISION Mean: 0.573 | Std: 0.013 | 95% CI: [0.553, 0.595]
Salvando o melhor modelo... 0.8179133476043071 -> 0.8252063291363636
Epoch 8/50



loss: 0.00698, smooth loss: 0.04239: 100%|██████████| 1204/1204 [13:21<00:00,  1.50it/s]
100%|██████████| 301/301 [01:58<00:00,  2.53it/s]


VAL_LOSS     0.384
VAL_ACC      Mean: 63.034 | Std: 1.168 | 95% CI: [61.108, 64.986]
VAL_KAPPA    Mean: 0.827 | Std: 0.012 | 95% CI: [0.807, 0.845]
VAL_F1       Mean: 0.567 | Std: 0.012 | 95% CI: [0.546, 0.587]
VAL_RECALL   Mean: 0.562 | Std: 0.012 | 95% CI: [0.543, 0.582]
VAL_PRECISION Mean: 0.580 | Std: 0.013 | 95% CI: [0.559, 0.601]
Salvando o melhor modelo... 0.8252063291363636 -> 0.8266492107073063
Epoch 9/50



loss: 0.00690, smooth loss: 0.04040: 100%|██████████| 1204/1204 [13:07<00:00,  1.53it/s]
100%|██████████| 301/301 [01:58<00:00,  2.53it/s]


VAL_LOSS     0.403
VAL_ACC      Mean: 63.379 | Std: 1.135 | 95% CI: [61.551, 65.263]
VAL_KAPPA    Mean: 0.839 | Std: 0.011 | 95% CI: [0.821, 0.856]
VAL_F1       Mean: 0.575 | Std: 0.012 | 95% CI: [0.554, 0.594]
VAL_RECALL   Mean: 0.572 | Std: 0.012 | 95% CI: [0.552, 0.591]
VAL_PRECISION Mean: 0.586 | Std: 0.013 | 95% CI: [0.565, 0.606]
Salvando o melhor modelo... 0.8266492107073063 -> 0.8389711731209907
Epoch 10/50



loss: 0.01284, smooth loss: 0.04473: 100%|██████████| 1204/1204 [13:40<00:00,  1.47it/s]
100%|██████████| 301/301 [02:01<00:00,  2.47it/s]


VAL_LOSS     0.438
VAL_ACC      Mean: 59.919 | Std: 1.137 | 95% CI: [58.116, 61.776]
VAL_KAPPA    Mean: 0.821 | Std: 0.012 | 95% CI: [0.801, 0.840]
VAL_F1       Mean: 0.547 | Std: 0.012 | 95% CI: [0.528, 0.567]
VAL_RECALL   Mean: 0.544 | Std: 0.012 | 95% CI: [0.524, 0.564]
VAL_PRECISION Mean: 0.563 | Std: 0.012 | 95% CI: [0.543, 0.583]
Epoch 11/50



loss: 0.00517, smooth loss: 0.04511: 100%|██████████| 1204/1204 [13:55<00:00,  1.44it/s]
100%|██████████| 301/301 [02:12<00:00,  2.28it/s]


VAL_LOSS     0.455
VAL_ACC      Mean: 62.314 | Std: 1.168 | 95% CI: [60.443, 64.266]
VAL_KAPPA    Mean: 0.825 | Std: 0.012 | 95% CI: [0.806, 0.846]
VAL_F1       Mean: 0.558 | Std: 0.012 | 95% CI: [0.538, 0.579]
VAL_RECALL   Mean: 0.551 | Std: 0.012 | 95% CI: [0.533, 0.572]
VAL_PRECISION Mean: 0.575 | Std: 0.013 | 95% CI: [0.554, 0.597]
Epoch 12/50



loss: 0.00137, smooth loss: 0.02283: 100%|██████████| 1204/1204 [14:15<00:00,  1.41it/s]
100%|██████████| 301/301 [02:01<00:00,  2.49it/s]


VAL_LOSS     0.475
VAL_ACC      Mean: 63.332 | Std: 1.139 | 95% CI: [61.496, 65.266]
VAL_KAPPA    Mean: 0.814 | Std: 0.012 | 95% CI: [0.794, 0.834]
VAL_F1       Mean: 0.568 | Std: 0.012 | 95% CI: [0.549, 0.588]
VAL_RECALL   Mean: 0.559 | Std: 0.012 | 95% CI: [0.540, 0.579]
VAL_PRECISION Mean: 0.592 | Std: 0.013 | 95% CI: [0.572, 0.613]
Epoch 13/50



loss: 0.00471, smooth loss: 0.00935: 100%|██████████| 1204/1204 [13:24<00:00,  1.50it/s]
100%|██████████| 301/301 [02:00<00:00,  2.51it/s]


VAL_LOSS     0.539
VAL_ACC      Mean: 64.724 | Std: 1.125 | 95% CI: [62.881, 66.537]
VAL_KAPPA    Mean: 0.814 | Std: 0.013 | 95% CI: [0.792, 0.835]
VAL_F1       Mean: 0.581 | Std: 0.012 | 95% CI: [0.560, 0.601]
VAL_RECALL   Mean: 0.572 | Std: 0.012 | 95% CI: [0.552, 0.591]
VAL_PRECISION Mean: 0.606 | Std: 0.013 | 95% CI: [0.585, 0.626]
Epoch 14/50



loss: 0.01329, smooth loss: 0.01138: 100%|██████████| 1204/1204 [12:57<00:00,  1.55it/s]
100%|██████████| 301/301 [01:58<00:00,  2.54it/s]


VAL_LOSS     0.505
VAL_ACC      Mean: 64.774 | Std: 1.139 | 95% CI: [62.825, 66.593]
VAL_KAPPA    Mean: 0.828 | Std: 0.012 | 95% CI: [0.809, 0.848]
VAL_F1       Mean: 0.591 | Std: 0.012 | 95% CI: [0.571, 0.612]
VAL_RECALL   Mean: 0.586 | Std: 0.012 | 95% CI: [0.567, 0.607]
VAL_PRECISION Mean: 0.600 | Std: 0.013 | 95% CI: [0.580, 0.622]

Early stopping at epoch 14. No improvement for 5 epochs.
Best epoch: 9 with kappa: 0.8390


# Test-set evaluation of the best checkpoint

In [15]:
# Reload the best checkpoint written during training and score it on the test
# split. map_location lets a checkpoint saved on GPU be restored on a CPU-only
# machine (and vice versa).
model.load_state_dict(
    torch.load("models/b1.pth", map_location=device)
)
response = evaluation(model, test_loader, device)
# NOTE(review): evaluation() appears to return a sequence whose first item is
# what format_metrics consumes; confirm against utils.metrics. The formatted
# lines are labelled "VAL_" even though these are test-set metrics.
result = format_metrics(response[0])
print(result)

100%|██████████| 266/266 [01:36<00:00,  2.75it/s]


VAL_ACC      Mean: 63.799 | Std: 1.159 | 95% CI: [61.935, 65.704]
VAL_KAPPA    Mean: 0.834 | Std: 0.013 | 95% CI: [0.812, 0.854]
VAL_F1       Mean: 0.572 | Std: 0.013 | 95% CI: [0.553, 0.594]
VAL_RECALL   Mean: 0.569 | Std: 0.013 | 95% CI: [0.550, 0.591]
VAL_PRECISION Mean: 0.584 | Std: 0.013 | 95% CI: [0.563, 0.606]
