In [1]:
from torch import optim
from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights
import torch
import random
import numpy as np
import torch.nn as nn
import albumentations as Albu
import pandas as pd
from torch.utils.data.sampler import RandomSampler
from warmup_scheduler import GradualWarmupScheduler
import os
from utils.dataset import PandasDataset
from utils.metrics import model_checkpoint
from utils.train import train_model
from utils.models import EfficientNetApi

In [2]:
# ---- Hyper-parameters -------------------------------------------------------
seed = 42
shuffle = True  # NOTE(review): not referenced by any visible cell; the loaders pass samplers instead
batch_size = 6
num_workers = 4
output_classes = 5  # passed to the model wrapper as `output_dimensions`
# The optimizer starts at init_lr / warmup_factor and the warm-up scheduler is
# built with multiplier=warmup_factor, presumably ramping up to init_lr.
init_lr = 3e-4
warmup_factor = 2
warmup_epochs = 1
n_epochs = 50

# ---- Paths (relative to the notebook's working directory) -------------------
ROOT_DIR = '../..'
data_dir = '../../../dataset'
images_dir = os.path.join(data_dir, 'tiles')

# ---- Runtime ----------------------------------------------------------------
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

loss_function = nn.BCEWithLogitsLoss()

# ---- Reproducibility: seed torch, Python's `random` and NumPy ---------------
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

Using device: cuda


In [3]:
# EfficientNet-B2 backbone initialised with torchvision's DEFAULT pretrained weights.
load_model = efficientnet_b2(weights=EfficientNet_B2_Weights.DEFAULT)

# Wrap the backbone in the project head (one output per `output_classes`) and
# move it to the selected device. `dropout_rate` is forwarded to EfficientNetApi.
model = EfficientNetApi(
    model=load_model,
    output_dimensions=output_classes,
    dropout_rate=0.6,
).to(device)

In [4]:
# Re-seed every RNG now that the model has been built.
# Building the model head above may consume torch's RNG (weight initialisation),
# so re-seeding here makes data sampling / augmentation randomness independent of
# how the model was constructed. `device` and `loss_function` are defined once in
# the configuration cell and are not redefined here.
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

Using device: cuda


In [5]:
# Fold held out for validation; every other fold of the 5-fold split is used for training.
VALID_FOLD = 3

df_train_ = pd.read_csv(f"{ROOT_DIR}/data/train_5fold.csv")
# Header names in the CSV may carry stray whitespace; normalise them once here.
df_train_.columns = df_train_.columns.str.strip()

# Select rows with a boolean mask rather than feeding positional np.where results
# into the label-based `.loc`. The two only coincide while the frame has a default
# 0..n-1 index.
is_valid = df_train_['fold'] == VALID_FOLD
df_train = df_train_.loc[~is_valid]
df_val = df_train_.loc[is_valid]

# Positional row numbers of each split.
train_indexes = np.flatnonzero(~is_valid.to_numpy())
valid_indexes = np.flatnonzero(is_valid.to_numpy())

df_test = pd.read_csv(f"{ROOT_DIR}/data/test.csv")
# Apply the same header normalisation as for the training CSV.
df_test.columns = df_test.columns.str.strip()

#### View the train / validation / test split sizes

In [6]:
(df_train.shape, df_val.shape, df_test.shape)

((7219, 5), (1805, 5), (1592, 4))

In [7]:
# Geometric train-time augmentations, each applied independently with probability 0.5.
# The validation and test datasets are built with transforms=None.
transforms = Albu.Compose(
    transforms=[
        Albu.Transpose(p=0.5),
        Albu.VerticalFlip(p=0.5),
        Albu.HorizontalFlip(p=0.5),
    ]
)

In [8]:
# Column names were already stripped on the source frame before the train/val
# split, so the frames are wrapped directly without any further in-place mutation.
# Only the training set is augmented; evaluation data is left untouched.
train_dataset = PandasDataset(images_dir, df_train, transforms=transforms)
valid_dataset = PandasDataset(images_dir, df_val, transforms=None)
test_dataset = PandasDataset(images_dir, df_test, transforms=None)

In [9]:
# Training batches are drawn in a fresh random order every epoch.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    sampler=RandomSampler(train_dataset),
)

# Evaluation loaders iterate in dataset order (DataLoader's default sequential
# sampler). Shuffling gains nothing when evaluating. It would also make the order
# of predictions differ from the row order of df_val / df_test, and tie results to
# the global RNG state.
valid_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
)

In [10]:
# Adam starts at init_lr / warmup_factor. GradualWarmupScheduler is configured with
# multiplier=warmup_factor over warmup_epochs and then hands control to cosine
# annealing for the remaining n_epochs - warmup_epochs epochs.
# NOTE(review): the `_warn_get_lr_called_within_step` warning seen in training
# presumably comes from the warm-up wrapper querying the cosine scheduler's
# get_lr(). TODO confirm.
optimizer = optim.Adam(params=model.parameters(), lr=init_lr / warmup_factor)

scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=n_epochs - warmup_epochs,
)
scheduler = GradualWarmupScheduler(
    optimizer,
    multiplier=warmup_factor,
    total_epoch=warmup_epochs,
    after_scheduler=scheduler_cosine,
)

In [11]:
# The metric log and best checkpoint are written under logs/ and models/.
# Create those directories up front so a fresh checkout can run this cell
# (a no-op if they already exist or if train_model creates them itself).
os.makedirs("logs", exist_ok=True)
os.makedirs("models", exist_ok=True)

# Train with early stopping: stop after `patience` epochs without improvement.
# Per the training log, the best model is selected on validation kappa.
train_model(
    model=model,
    epochs=n_epochs,
    optimizer=optimizer,
    scheduler=scheduler,
    train_dataloader=train_loader,
    valid_dataloader=valid_loader,
    checkpoint=model_checkpoint,
    device=device,
    loss_function=loss_function,
    path_to_save_metrics="logs/b2.txt",
    path_to_save_model="models/b2.pth",
    patience=5,
)

Epoch 1/50



loss: 0.21219, smooth loss: 0.34523: 100%|██████████| 1204/1204 [13:47<00:00,  1.46it/s]
100%|██████████| 301/301 [02:05<00:00,  2.40it/s]


VAL_LOSS     0.283
VAL_ACC      Mean: 54.179 | Std: 1.170 | 95% CI: [52.355, 56.180]
VAL_KAPPA    Mean: 0.779 | Std: 0.012 | 95% CI: [0.760, 0.798]
VAL_F1       Mean: 0.455 | Std: 0.012 | 95% CI: [0.436, 0.474]
VAL_RECALL   Mean: 0.464 | Std: 0.011 | 95% CI: [0.447, 0.482]
VAL_PRECISION Mean: 0.558 | Std: 0.015 | 95% CI: [0.532, 0.582]
Salvando o melhor modelo... 0.0 -> 0.7787572475555404
Epoch 2/50



loss: 0.09602, smooth loss: 0.27118: 100%|██████████| 1204/1204 [13:48<00:00,  1.45it/s]
100%|██████████| 301/301 [02:05<00:00,  2.40it/s]
  _warn_get_lr_called_within_step(self)


VAL_LOSS     0.283
VAL_ACC      Mean: 54.036 | Std: 1.173 | 95% CI: [52.186, 56.122]
VAL_KAPPA    Mean: 0.775 | Std: 0.012 | 95% CI: [0.756, 0.794]
VAL_F1       Mean: 0.466 | Std: 0.012 | 95% CI: [0.447, 0.485]
VAL_RECALL   Mean: 0.477 | Std: 0.011 | 95% CI: [0.459, 0.496]
VAL_PRECISION Mean: 0.578 | Std: 0.013 | 95% CI: [0.557, 0.599]
Epoch 3/50



loss: 0.04654, smooth loss: 0.17406: 100%|██████████| 1204/1204 [13:48<00:00,  1.45it/s]
100%|██████████| 301/301 [02:05<00:00,  2.40it/s]


VAL_LOSS     0.413
VAL_ACC      Mean: 54.052 | Std: 1.220 | 95% CI: [52.022, 56.014]
VAL_KAPPA    Mean: 0.721 | Std: 0.014 | 95% CI: [0.697, 0.744]
VAL_F1       Mean: 0.414 | Std: 0.011 | 95% CI: [0.395, 0.433]
VAL_RECALL   Mean: 0.430 | Std: 0.010 | 95% CI: [0.414, 0.448]
VAL_PRECISION Mean: 0.513 | Std: 0.021 | 95% CI: [0.477, 0.545]
Epoch 4/50



loss: 0.02404, smooth loss: 0.12970: 100%|██████████| 1204/1204 [13:48<00:00,  1.45it/s]
100%|██████████| 301/301 [02:05<00:00,  2.40it/s]


VAL_LOSS     0.388
VAL_ACC      Mean: 58.462 | Std: 1.173 | 95% CI: [56.562, 60.501]
VAL_KAPPA    Mean: 0.764 | Std: 0.014 | 95% CI: [0.742, 0.787]
VAL_F1       Mean: 0.502 | Std: 0.012 | 95% CI: [0.482, 0.523]
VAL_RECALL   Mean: 0.504 | Std: 0.011 | 95% CI: [0.486, 0.523]
VAL_PRECISION Mean: 0.569 | Std: 0.014 | 95% CI: [0.546, 0.591]
Epoch 5/50



loss: 0.03727, smooth loss: 0.10251: 100%|██████████| 1204/1204 [13:48<00:00,  1.45it/s]
100%|██████████| 301/301 [02:05<00:00,  2.40it/s]


VAL_LOSS     0.523
VAL_ACC      Mean: 57.147 | Std: 1.202 | 95% CI: [55.235, 59.224]
VAL_KAPPA    Mean: 0.725 | Std: 0.015 | 95% CI: [0.699, 0.747]
VAL_F1       Mean: 0.462 | Std: 0.012 | 95% CI: [0.442, 0.482]
VAL_RECALL   Mean: 0.465 | Std: 0.011 | 95% CI: [0.448, 0.482]
VAL_PRECISION Mean: 0.549 | Std: 0.014 | 95% CI: [0.525, 0.572]
Epoch 6/50



loss: 0.01338, smooth loss: 0.08542: 100%|██████████| 1204/1204 [13:46<00:00,  1.46it/s]
100%|██████████| 301/301 [02:05<00:00,  2.41it/s]


VAL_LOSS     0.493
VAL_ACC      Mean: 58.728 | Std: 1.212 | 95% CI: [56.787, 60.668]
VAL_KAPPA    Mean: 0.763 | Std: 0.014 | 95% CI: [0.740, 0.787]
VAL_F1       Mean: 0.502 | Std: 0.013 | 95% CI: [0.481, 0.523]
VAL_RECALL   Mean: 0.497 | Std: 0.012 | 95% CI: [0.478, 0.516]
VAL_PRECISION Mean: 0.557 | Std: 0.014 | 95% CI: [0.534, 0.579]

Early stopping at epoch 6. No improvement for 5 epochs.
Best epoch: 1 with kappa: 0.7788


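# Test-set evaluation

Reload the best checkpoint (selected by validation kappa during training) and score it on the held-out test split.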

In [12]:
from utils.metrics import evaluation, format_metrics

# Reload the best checkpoint written by train_model.
# - map_location lets a checkpoint saved from GPU load on a CPU-only machine.
# - weights_only=True restricts unpickling to tensors and plain containers,
#   which is all a state_dict needs.
model.load_state_dict(
    torch.load("models/b2.pth", map_location=device, weights_only=True)
)
# Make sure dropout is off for inference (harmless if evaluation() already does this).
model.eval()

response = evaluation(model, test_loader, device)
result = format_metrics(response[0])
print(result)

100%|██████████| 266/266 [01:40<00:00,  2.64it/s]


VAL_ACC      Mean: 52.448 | Std: 1.269 | 95% CI: [50.314, 54.585]
VAL_KAPPA    Mean: 0.779 | Std: 0.013 | 95% CI: [0.757, 0.799]
VAL_F1       Mean: 0.434 | Std: 0.012 | 95% CI: [0.414, 0.455]
VAL_RECALL   Mean: 0.451 | Std: 0.012 | 95% CI: [0.432, 0.471]
VAL_PRECISION Mean: 0.558 | Std: 0.016 | 95% CI: [0.529, 0.583]
