In [1]:
import os
import random

import albumentations as Albu
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data.sampler import RandomSampler
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from warmup_scheduler import GradualWarmupScheduler

from utils.dataset import PandasDataset, RGB2XYZTransform
from utils.metrics import evaluation, format_metrics
from utils.metrics import model_checkpoint
from utils.models import EfficientNetApi, EfficientNetApiGem
from utils.train import train_model

In [2]:
# --- Reproducibility ---------------------------------------------------------
seed = 42

# --- Data loading ------------------------------------------------------------
shuffle = True
batch_size = 8
num_workers = 4

# --- Model / optimisation ----------------------------------------------------
output_classes = 5
init_lr = 3e-4
warmup_factor = 10
warmup_epochs = 1
n_epochs = 50

# Prefer the GPU when one is visible; the model and the training loop both
# receive this `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# One logit per output unit, each scored with an independent sigmoid + binary
# cross-entropy.
loss_function = nn.BCEWithLogitsLoss()

# Seed every RNG the pipeline touches (NumPy, Python, PyTorch).
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)

# Paths are relative to the notebook's working directory.
ROOT_DIR = '../..'
data_dir = '../../../dataset'
images_dir = os.path.join(data_dir, 'tiles')

Using device: cuda


In [3]:
# ImageNet-pretrained EfficientNet-B0 backbone from torchvision.
load_model = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)

# Wrap the backbone in the project's head (the name suggests GeM pooling; see
# utils.models) sized to `output_classes`, and move it to the active device.
model = EfficientNetApiGem(model=load_model, output_dimensions=output_classes).to(device)

In [4]:
# NOTE(review): the device report and loss construction repeat the config cell.
# The reseed is the part that matters: it resets the global RNGs after the model
# has been built, so later random draws (e.g. training-batch order) start from
# `seed` regardless of how much randomness model construction consumed, if any.
print(f"Using device: {device}")
loss_function = torch.nn.BCEWithLogitsLoss()

np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)

Using device: cuda


In [5]:
# Hold out one fold of the 5-fold table for validation; every other row trains.
valid_fold = 3

df_train_ = pd.read_csv(f"{ROOT_DIR}/data/train_5fold.csv")
# CSV headers may carry stray whitespace; normalise them once, up front.
df_train_.columns = df_train_.columns.str.strip()

# np.flatnonzero yields *positions*, so select with .iloc (positional) rather
# than .loc (label-based); this stays correct whatever index the frame carries.
# Rows whose fold is missing compare unequal to `valid_fold` and land in train.
is_valid = (df_train_['fold'] == valid_fold).to_numpy()
train_indexes = np.flatnonzero(~is_valid)
valid_indexes = np.flatnonzero(is_valid)

df_train = df_train_.iloc[train_indexes]
df_val = df_train_.iloc[valid_indexes]

df_test = pd.read_csv(f"{ROOT_DIR}/data/test.csv")
# Apply the same header normalisation as the fold table so both expose identical names.
df_test.columns = df_test.columns.str.strip()

#### View the train / validation / test split sizes

In [6]:
# Sanity-check the split before looking at sizes: train and validation must
# partition the fold table exactly, with no row in both.
assert len(df_train) + len(df_val) == len(df_train_), "train/val split drops or duplicates rows"
assert df_train.index.intersection(df_val.index).empty, "train and validation rows overlap"
(df_train.shape, df_val.shape, df_test.shape)

((7219, 5), (1805, 5), (1592, 4))

In [7]:
# Training pipeline: colour-space conversion (RGB2XYZTransform from
# utils.dataset), then random transpose and vertical / horizontal flips, each
# applied with probability 0.5.
transforms = Albu.Compose(
    [
        RGB2XYZTransform(),
        Albu.Transpose(p=0.5),
        Albu.VerticalFlip(p=0.5),
        Albu.HorizontalFlip(p=0.5),
    ]
)

# Validation / test pipeline: the colour conversion only, no augmentation.
valid_transforms = Albu.Compose([RGB2XYZTransform()])

In [8]:
# Strip header whitespace on every split, not just df_train, so all three frames
# expose the same column names to PandasDataset. This is idempotent if a frame's
# headers were already stripped.
df_train, df_val, df_test = (
    frame.rename(columns=str.strip) for frame in (df_train, df_val, df_test)
)

# Augmented pipeline for training; only the colour conversion for val/test.
train_dataset = PandasDataset(images_dir, df_train, transforms=transforms)
valid_dataset = PandasDataset(images_dir, df_val, transforms=valid_transforms)
test_dataset = PandasDataset(images_dir, df_test, transforms=valid_transforms)

In [9]:
# Training batches are drawn in random order, driven by the config flag
# `shuffle` (shuffle=True makes DataLoader build a RandomSampler internally).
# Validation and test run in dataset order: aggregate metrics do not depend on
# order, and a fixed order keeps predictions aligned with the rows of
# df_val / df_test.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle
)
valid_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
)

In [10]:
# Adam starts at init_lr / warmup_factor. Per warmup_scheduler's
# multiplier/total_epoch semantics, the warm-up scheduler ramps this back up to
# init_lr over `warmup_epochs`, then hands over to cosine annealing for the
# remaining epochs.
optimizer = optim.Adam(model.parameters(), lr=init_lr / warmup_factor)
scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=n_epochs - warmup_epochs,
)
scheduler = GradualWarmupScheduler(
    optimizer,
    multiplier=warmup_factor,
    total_epoch=warmup_epochs,
    after_scheduler=scheduler_cosine,
)

In [11]:
# Create the output folders up front. Otherwise a fresh checkout without
# logs/ or models/ fails the first time metrics or a checkpoint are written.
# exist_ok makes this a no-op on re-runs.
os.makedirs("logs", exist_ok=True)
os.makedirs("models", exist_ok=True)

# Train with early stopping after `patience` epochs without improvement.
# The best weights go to path_to_save_model, per the "Salvando o melhor modelo"
# (saving best model) log lines emitted via model_checkpoint.
train_model(
    model=model,
    epochs=n_epochs,
    optimizer=optimizer,
    scheduler=scheduler,
    train_dataloader=train_loader,
    valid_dataloader=valid_loader,
    checkpoint=model_checkpoint,
    device=device,
    loss_function=loss_function,
    path_to_save_metrics="logs/with-noise-xyz.txt",
    path_to_save_model="models/efficientnet-xyz.pth",
    patience=5,
)

Epoch 1/50



loss: 0.41192, smooth loss: 0.38587: 100%|██████████| 903/903 [10:05<00:00,  1.49it/s]
100%|██████████| 226/226 [02:36<00:00,  1.45it/s]


metrics {'val_loss': np.float32(0.33706138), 'val_acc': {'mean': np.float64(44.27096953392029), 'std': np.float64(1.171892221023603), 'ci_5': np.float64(42.43767261505127), 'ci_95': np.float64(46.20498716831207)}, 'val_kappa': {'mean': np.float64(0.7298986000985215), 'std': np.float64(0.01232557676247516), 'ci_5': np.float64(0.7082781031891044), 'ci_95': np.float64(0.7495382091581599)}, 'val_f1': {'mean': np.float64(0.3736350327432156), 'std': np.float64(0.011632533395105305), 'ci_5': np.float64(0.35374375283718107), 'ci_95': np.float64(0.39240901470184325)}, 'val_recall': {'mean': np.float64(0.3796866256594658), 'std': np.float64(0.01103874274868711), 'ci_5': np.float64(0.3609559327363968), 'ci_95': np.float64(0.3976366207003593)}, 'val_precision': {'mean': np.float64(0.4489036958217621), 'std': np.float64(0.017669090968495428), 'ci_5': np.float64(0.4194111436605453), 'ci_95': np.float64(0.4772007346153259)}}
Salvando o melhor modelo... 0.0 -> 0.7298986000985215
Epoch 2/50



loss: 0.22062, smooth loss: 0.33542: 100%|██████████| 903/903 [09:51<00:00,  1.53it/s]
100%|██████████| 226/226 [02:02<00:00,  1.85it/s]
  _warn_get_lr_called_within_step(self)


metrics {'val_loss': np.float32(0.28499892), 'val_acc': {'mean': np.float64(50.765650898218155), 'std': np.float64(1.1588111486812671), 'ci_5': np.float64(48.86426627635956), 'ci_95': np.float64(52.576178312301636)}, 'val_kappa': {'mean': np.float64(0.7782135102777165), 'std': np.float64(0.011035137480311962), 'ci_5': np.float64(0.7594752747647072), 'ci_95': np.float64(0.796247533091356)}, 'val_f1': {'mean': np.float64(0.4461217902004719), 'std': np.float64(0.012124918204641007), 'ci_5': np.float64(0.4263900279998779), 'ci_95': np.float64(0.46623875200748444)}, 'val_recall': {'mean': np.float64(0.44571535238623616), 'std': np.float64(0.011316592383648108), 'ci_5': np.float64(0.42788256853818896), 'ci_95': np.float64(0.4647971898317337)}, 'val_precision': {'mean': np.float64(0.5219291223883629), 'std': np.float64(0.013402630531479507), 'ci_5': np.float64(0.4997529312968254), 'ci_95': np.float64(0.5441346585750579)}}
Salvando o melhor modelo... 0.7298986000985215 -> 0.7782135102777165
Epoch 3/50

loss: 0.15815, smooth loss: 0.26150: 100%|██████████| 903/903 [09:31<00:00,  1.58it/s]
100%|██████████| 226/226 [02:02<00:00,  1.85it/s]


metrics {'val_loss': np.float32(0.29573908), 'val_acc': {'mean': np.float64(54.839224588871005), 'std': np.float64(1.1643879790971172), 'ci_5': np.float64(52.90858745574951), 'ci_95': np.float64(56.7313015460968)}, 'val_kappa': {'mean': np.float64(0.7781033863756001), 'std': np.float64(0.011840865910553286), 'ci_5': np.float64(0.7584689594689568), 'ci_95': np.float64(0.7969110896092344)}, 'val_f1': {'mean': np.float64(0.48329090091586113), 'std': np.float64(0.012502390873978779), 'ci_5': np.float64(0.4629761502146721), 'ci_95': np.float64(0.5027846544981003)}, 'val_recall': {'mean': np.float64(0.48537938845157624), 'std': np.float64(0.011607687460213371), 'ci_5': np.float64(0.46569063514471054), 'ci_95': np.float64(0.503139990568161)}, 'val_precision': {'mean': np.float64(0.5489448027014733), 'std': np.float64(0.01363614584195948), 'ci_5': np.float64(0.5258326232433319), 'ci_95': np.float64(0.5706386476755142)}}
Epoch 4/50



loss: 0.15507, smooth loss: 0.18654: 100%|██████████| 903/903 [09:33<00:00,  1.57it/s]
100%|██████████| 226/226 [02:00<00:00,  1.87it/s]


metrics {'val_loss': np.float32(0.3234469), 'val_acc': {'mean': np.float64(57.853573250770566), 'std': np.float64(1.2136896709120721), 'ci_5': np.float64(55.84487318992615), 'ci_95': np.float64(59.781162440776825)}, 'val_kappa': {'mean': np.float64(0.8065899935455927), 'std': np.float64(0.011575307847020202), 'ci_5': np.float64(0.7878776006980814), 'ci_95': np.float64(0.8251677299914945)}, 'val_f1': {'mean': np.float64(0.5170908218622208), 'std': np.float64(0.012389896117271081), 'ci_5': np.float64(0.4954427793622017), 'ci_95': np.float64(0.5369038075208664)}, 'val_recall': {'mean': np.float64(0.5188881051838398), 'std': np.float64(0.012270993203181797), 'ci_5': np.float64(0.4975251227617264), 'ci_95': np.float64(0.5383736073970795)}, 'val_precision': {'mean': np.float64(0.5217737020552158), 'std': np.float64(0.01267389588695666), 'ci_5': np.float64(0.4986429825425148), 'ci_95': np.float64(0.5425685524940491)}}
Salvando o melhor modelo... 0.7782135102777165 -> 0.8065899935455927
Epoch 5/50

loss: 0.25727, smooth loss: 0.14764: 100%|██████████| 903/903 [09:35<00:00,  1.57it/s]
100%|██████████| 226/226 [02:01<00:00,  1.86it/s]


metrics {'val_loss': np.float32(0.33797818), 'val_acc': {'mean': np.float64(58.078060793876645), 'std': np.float64(1.15270273871618), 'ci_5': np.float64(56.23268485069275), 'ci_95': np.float64(60.00000238418579)}, 'val_kappa': {'mean': np.float64(0.7961981817825472), 'std': np.float64(0.012805338436477885), 'ci_5': np.float64(0.7745375343218047), 'ci_95': np.float64(0.8168108946123726)}, 'val_f1': {'mean': np.float64(0.5281326229274272), 'std': np.float64(0.01223964296836117), 'ci_5': np.float64(0.5070422768592835), 'ci_95': np.float64(0.5476847052574157)}, 'val_recall': {'mean': np.float64(0.5277835808694362), 'std': np.float64(0.012106160872816334), 'ci_5': np.float64(0.5065186083316803), 'ci_95': np.float64(0.5473936975002289)}, 'val_precision': {'mean': np.float64(0.545230679512024), 'std': np.float64(0.012445087352284692), 'ci_5': np.float64(0.5237867414951325), 'ci_95': np.float64(0.5653471320867538)}}
Epoch 6/50



loss: 0.24980, smooth loss: 0.10596: 100%|██████████| 903/903 [09:36<00:00,  1.57it/s]
100%|██████████| 226/226 [02:01<00:00,  1.86it/s]


metrics {'val_loss': np.float32(0.37152332), 'val_acc': {'mean': np.float64(61.98598338365555), 'std': np.float64(1.1378801515962007), 'ci_5': np.float64(60.10803133249283), 'ci_95': np.float64(63.87811899185181)}, 'val_kappa': {'mean': np.float64(0.816991794208443), 'std': np.float64(0.012071989739902499), 'ci_5': np.float64(0.7970512824453291), 'ci_95': np.float64(0.8363190616681155)}, 'val_f1': {'mean': np.float64(0.5610440965294838), 'std': np.float64(0.012300529787370456), 'ci_5': np.float64(0.540840157866478), 'ci_95': np.float64(0.5816447019577027)}, 'val_recall': {'mean': np.float64(0.5592311269044876), 'std': np.float64(0.012072922047134074), 'ci_5': np.float64(0.5387717604637146), 'ci_95': np.float64(0.5799090147018432)}, 'val_precision': {'mean': np.float64(0.5767368113994599), 'std': np.float64(0.012526126676701863), 'ci_5': np.float64(0.5566595077514649), 'ci_95': np.float64(0.5974617660045624)}}
Salvando o melhor modelo... 0.8065899935455927 -> 0.816991794208443
Epoch 7/50

loss: 0.04261, smooth loss: 0.09158: 100%|██████████| 903/903 [09:33<00:00,  1.57it/s]
100%|██████████| 226/226 [02:01<00:00,  1.86it/s]


metrics {'val_loss': np.float32(0.4566208), 'val_acc': {'mean': np.float64(61.346703457832334), 'std': np.float64(1.169979135253806), 'ci_5': np.float64(59.38781052827835), 'ci_95': np.float64(63.2686972618103)}, 'val_kappa': {'mean': np.float64(0.7966342302867576), 'std': np.float64(0.013082980757605342), 'ci_5': np.float64(0.7750121260115357), 'ci_95': np.float64(0.8180876929617336)}, 'val_f1': {'mean': np.float64(0.5341060686707496), 'std': np.float64(0.012882731949575735), 'ci_5': np.float64(0.5126706451177597), 'ci_95': np.float64(0.555432739853859)}, 'val_recall': {'mean': np.float64(0.5287434825599193), 'std': np.float64(0.011994830994508755), 'ci_5': np.float64(0.5082443356513977), 'ci_95': np.float64(0.5482612490653992)}, 'val_precision': {'mean': np.float64(0.5760494379997253), 'std': np.float64(0.013617326875402484), 'ci_5': np.float64(0.5535870045423508), 'ci_95': np.float64(0.598091122508049)}}
Epoch 8/50



loss: 0.02907, smooth loss: 0.06687: 100%|██████████| 903/903 [10:09<00:00,  1.48it/s]
100%|██████████| 226/226 [02:06<00:00,  1.79it/s]


metrics {'val_loss': np.float32(0.44231334), 'val_acc': {'mean': np.float64(61.59977833628655), 'std': np.float64(1.1311313657038218), 'ci_5': np.float64(59.72298979759216), 'ci_95': np.float64(63.38227242231369)}, 'val_kappa': {'mean': np.float64(0.8066223484315447), 'std': np.float64(0.01264932635916433), 'ci_5': np.float64(0.7854078944807874), 'ci_95': np.float64(0.8271060872435106)}, 'val_f1': {'mean': np.float64(0.5572920401096344), 'std': np.float64(0.011870177828810988), 'ci_5': np.float64(0.5378538966178894), 'ci_95': np.float64(0.5764776527881622)}, 'val_recall': {'mean': np.float64(0.5532120250463486), 'std': np.float64(0.011623606046823937), 'ci_5': np.float64(0.534056282043457), 'ci_95': np.float64(0.5713724762201309)}, 'val_precision': {'mean': np.float64(0.5712075859308243), 'std': np.float64(0.012211359353462067), 'ci_5': np.float64(0.5511557787656785), 'ci_95': np.float64(0.5901307612657547)}}
Epoch 9/50



loss: 0.04767, smooth loss: 0.04439: 100%|██████████| 903/903 [10:04<00:00,  1.49it/s]
100%|██████████| 226/226 [02:08<00:00,  1.76it/s]


metrics {'val_loss': np.float32(0.4787623), 'val_acc': {'mean': np.float64(60.50387805700302), 'std': np.float64(1.1167304858735254), 'ci_5': np.float64(58.55955481529236), 'ci_95': np.float64(62.3268723487854)}, 'val_kappa': {'mean': np.float64(0.824149975257004), 'std': np.float64(0.011441856799801746), 'ci_5': np.float64(0.8048797466909264), 'ci_95': np.float64(0.842378119235352)}, 'val_f1': {'mean': np.float64(0.5463025863766671), 'std': np.float64(0.011778864074314872), 'ci_5': np.float64(0.527370747923851), 'ci_95': np.float64(0.5657207131385803)}, 'val_recall': {'mean': np.float64(0.5497767493724823), 'std': np.float64(0.011640439500496756), 'ci_5': np.float64(0.5304398536682129), 'ci_95': np.float64(0.5688558369874954)}, 'val_precision': {'mean': np.float64(0.5589139988422394), 'std': np.float64(0.01194514476306654), 'ci_5': np.float64(0.5399866372346878), 'ci_95': np.float64(0.5793697834014893)}}
Salvando o melhor modelo... 0.816991794208443 -> 0.824149975257004
Epoch 10/50



loss: 0.05215, smooth loss: 0.03890: 100%|██████████| 903/903 [09:59<00:00,  1.51it/s]
100%|██████████| 226/226 [02:10<00:00,  1.74it/s]


metrics {'val_loss': np.float32(0.48536095), 'val_acc': {'mean': np.float64(61.94747923612594), 'std': np.float64(1.1308250183610353), 'ci_5': np.float64(60.110801458358765), 'ci_95': np.float64(63.767313957214355)}, 'val_kappa': {'mean': np.float64(0.8142262069222415), 'std': np.float64(0.012500046375299645), 'ci_5': np.float64(0.7938538509636034), 'ci_95': np.float64(0.8342224900930637)}, 'val_f1': {'mean': np.float64(0.5706789320111275), 'std': np.float64(0.011950951314651211), 'ci_5': np.float64(0.551391926407814), 'ci_95': np.float64(0.5909215539693833)}, 'val_recall': {'mean': np.float64(0.570021347284317), 'std': np.float64(0.0120999289617883), 'ci_5': np.float64(0.5505019277334213), 'ci_95': np.float64(0.5904801666736603)}, 'val_precision': {'mean': np.float64(0.5753425026535988), 'std': np.float64(0.011869411732136586), 'ci_5': np.float64(0.555832251906395), 'ci_95': np.float64(0.5950737535953522)}}
Epoch 11/50



loss: 0.02348, smooth loss: 0.04274: 100%|██████████| 903/903 [10:39<00:00,  1.41it/s]
100%|██████████| 226/226 [02:14<00:00,  1.68it/s]


metrics {'val_loss': np.float32(0.5243827), 'val_acc': {'mean': np.float64(61.256897354125975), 'std': np.float64(1.1420365485783184), 'ci_5': np.float64(59.39058065414429), 'ci_95': np.float64(63.1024956703186)}, 'val_kappa': {'mean': np.float64(0.8081470523824309), 'std': np.float64(0.012685555759076374), 'ci_5': np.float64(0.7875033436382556), 'ci_95': np.float64(0.8286190484891112)}, 'val_f1': {'mean': np.float64(0.5620880718827248), 'std': np.float64(0.012074702306959105), 'ci_5': np.float64(0.541930741071701), 'ci_95': np.float64(0.5815528094768524)}, 'val_recall': {'mean': np.float64(0.5611928955316544), 'std': np.float64(0.012024456528858854), 'ci_5': np.float64(0.5411147147417068), 'ci_95': np.float64(0.5811546504497528)}, 'val_precision': {'mean': np.float64(0.570023129940033), 'std': np.float64(0.012305337800118853), 'ci_5': np.float64(0.5492014646530151), 'ci_95': np.float64(0.5900336742401123)}}
Epoch 12/50



loss: 0.04401, smooth loss: 0.02894: 100%|██████████| 903/903 [10:51<00:00,  1.39it/s]
100%|██████████| 226/226 [02:34<00:00,  1.46it/s]


metrics {'val_loss': np.float32(0.52590674), 'val_acc': {'mean': np.float64(61.49567866921425), 'std': np.float64(1.100800609622338), 'ci_5': np.float64(59.72298979759216), 'ci_95': np.float64(63.379502296447754)}, 'val_kappa': {'mean': np.float64(0.8168642755082525), 'std': np.float64(0.012116835793390264), 'ci_5': np.float64(0.7969778372547669), 'ci_95': np.float64(0.8360696993660603)}, 'val_f1': {'mean': np.float64(0.5610902594327927), 'std': np.float64(0.011854474650925638), 'ci_5': np.float64(0.5421401739120484), 'ci_95': np.float64(0.5805017203092575)}, 'val_recall': {'mean': np.float64(0.5626221778988838), 'std': np.float64(0.011555428195718037), 'ci_5': np.float64(0.5437200039625167), 'ci_95': np.float64(0.5821237325668335)}, 'val_precision': {'mean': np.float64(0.5804717517495155), 'std': np.float64(0.011819385579777594), 'ci_5': np.float64(0.5610912561416626), 'ci_95': np.float64(0.6004781275987625)}}
Epoch 13/50



loss: 0.01155, smooth loss: 0.02245: 100%|██████████| 903/903 [10:48<00:00,  1.39it/s]
100%|██████████| 226/226 [02:03<00:00,  1.83it/s]


metrics {'val_loss': np.float32(0.53392047), 'val_acc': {'mean': np.float64(63.56758998632431), 'std': np.float64(1.1127311551535035), 'ci_5': np.float64(61.7174506187439), 'ci_95': np.float64(65.42936563491821)}, 'val_kappa': {'mean': np.float64(0.8228741592215256), 'std': np.float64(0.012148501834921456), 'ci_5': np.float64(0.8028979042672931), 'ci_95': np.float64(0.8426812848781877)}, 'val_f1': {'mean': np.float64(0.5843468905091286), 'std': np.float64(0.011989093809077384), 'ci_5': np.float64(0.5653336018323898), 'ci_95': np.float64(0.6046482712030411)}, 'val_recall': {'mean': np.float64(0.5818142438530922), 'std': np.float64(0.01202504457419712), 'ci_5': np.float64(0.5626953780651093), 'ci_95': np.float64(0.601875239610672)}, 'val_precision': {'mean': np.float64(0.5918981665968895), 'std': np.float64(0.0120385566833749), 'ci_5': np.float64(0.5726122468709945), 'ci_95': np.float64(0.6120307743549347)}}
Epoch 14/50



loss: 0.01118, smooth loss: 0.02099: 100%|██████████| 903/903 [09:42<00:00,  1.55it/s]
100%|██████████| 226/226 [02:04<00:00,  1.82it/s]


metrics {'val_loss': np.float32(0.6020033), 'val_acc': {'mean': np.float64(63.535069370269774), 'std': np.float64(1.1231041116806173), 'ci_5': np.float64(61.66204810142517), 'ci_95': np.float64(65.37396311759949)}, 'val_kappa': {'mean': np.float64(0.8197788450224225), 'std': np.float64(0.0121766851231142), 'ci_5': np.float64(0.7999253816548043), 'ci_95': np.float64(0.8394051149141389)}, 'val_f1': {'mean': np.float64(0.5683405811190605), 'std': np.float64(0.01214805509823548), 'ci_5': np.float64(0.5485566645860672), 'ci_95': np.float64(0.5882485061883926)}, 'val_recall': {'mean': np.float64(0.565779797077179), 'std': np.float64(0.011643795231645845), 'ci_5': np.float64(0.547292348742485), 'ci_95': np.float64(0.5854935586452484)}, 'val_precision': {'mean': np.float64(0.5784571298956871), 'std': np.float64(0.012797417878024266), 'ci_5': np.float64(0.5565723180770874), 'ci_95': np.float64(0.599713659286499)}}

Early stopping at epoch 14. No improvement for 5 epochs.
Best epoch: 9 with kapp

# Test-set evaluation of the best checkpoint

In [13]:
# Reload the best checkpoint written during training and score it on the
# held-out test set.
# - map_location lets a checkpoint saved on GPU load onto whatever `device` is
#   active.
# - weights_only=True restricts unpickling to tensors and plain containers,
#   which is all a state_dict needs.
model.load_state_dict(
    torch.load("models/efficientnet-xyz.pth", map_location=device, weights_only=True)
)
# Explicitly switch off dropout / batch-norm updates for inference.
model.eval()

response = evaluation(model, test_loader, device)
# NOTE: format_metrics labels its rows VAL_*, but these numbers come from test_loader.
result = format_metrics(response[0])
print(result)

100%|██████████| 199/199 [01:53<00:00,  1.76it/s]


VAL_ACC      Mean: 59.57 | Std: 1.18 | 95% CI: [57.60, 61.56]
VAL_KAPPA    Mean: 0.81 | Std: 0.01 | 95% CI: [0.79, 0.84]
VAL_F1       Mean: 0.54 | Std: 0.01 | 95% CI: [0.52, 0.56]
VAL_RECALL   Mean: 0.54 | Std: 0.01 | 95% CI: [0.52, 0.56]
VAL_PRECISION Mean: 0.55 | Std: 0.01 | 95% CI: [0.53, 0.57]
