In [1]:
import os
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from warmup_scheduler import GradualWarmupScheduler
import albumentations

# NOTE(review): RGB2YHUTransform is not referenced in the visible cells — confirm before removing.
from work.utils.dataset import PandasDataset, RGB2YHUTransform, RGB2YHVTransform
from work.utils.models import EfficientNet
from work.utils.train import train_model
from work.utils.metrics import model_checkpoint

In [2]:
backbone_model = 'efficientnet-b0'

# Backbone name -> local checkpoint file handed to EfficientNet(pre_trained_model=...).
# The default absolute path only exists on the original author's machine; set the
# EFFICIENTNET_WEIGHTS environment variable to run elsewhere.
pretrained_model = {
    backbone_model: os.environ.get(
        'EFFICIENTNET_WEIGHTS',
        '/home/woshington/Projects/Doutorado/work/efficientnet-b0-08094119.pth',
    )
}

data_dir = '../../dataset'
images_dir = os.path.join(data_dir, 'tiles')

# Training table; expected to contain a `fold` column used for the train/valid split.
df_train = pd.read_csv("../data/train_5fold.csv")

In [3]:
# --- Reproducibility: seed Python's, NumPy's and PyTorch's global RNGs. ---
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# --- Data loading. ---
shuffle = True  # NOTE(review): not referenced by the visible cells.
batch_size = 2
num_workers = 4

# --- Model / optimisation. ---
output_classes = 5
init_lr = 3e-4
warmup_factor = 10  # optimizer starts at init_lr / warmup_factor
warmup_epochs = 1
n_epochs = 15

# Prefer the GPU whenever one is visible, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("Using device:", device)

# One logit per output unit; NOTE(review): assumes targets are `output_classes`-wide
# binary encodings (e.g. ordinal bins) — confirm against PandasDataset.
loss_function = nn.BCEWithLogitsLoss()


Using device: cuda


In [4]:
# Training augmentation: RGB2YHVTransform (presumably an RGB -> YHV colour-space
# conversion; see work.utils.dataset), then random transpose / vertical flip /
# horizontal flip, each applied with probability 0.5.
transforms = albumentations.Compose([
    RGB2YHVTransform(),
    albumentations.Transpose(p=0.5),
    albumentations.VerticalFlip(p=0.5),
    albumentations.HorizontalFlip(p=0.5),
])

# Validation: the same colour-space conversion, no random augmentation.
valid_transforms = albumentations.Compose([
    RGB2YHVTransform(),
])

In [5]:
# Header names in the CSV may carry stray whitespace; stripping is idempotent.
df_train.columns = df_train.columns.str.strip()

# Rows of this fold are held out for validation; every other fold trains.
valid_fold = 3
is_valid = df_train['fold'] == valid_fold

# Boolean masks select rows regardless of the index labels. The previous
# np.where(...)[0] + .loc combination mixed positions with labels and only
# worked because read_csv yields a default RangeIndex.
train = df_train[~is_valid]
valid = df_train[is_valid]

# Positional row indexes, kept for any cell that still refers to them.
train_indexes = np.flatnonzero(~is_valid.to_numpy())
valid_indexes = np.flatnonzero(is_valid.to_numpy())

assert len(train) + len(valid) == len(df_train)
assert len(valid) > 0, f"no rows found for validation fold {valid_fold}"

train_dataset = PandasDataset(images_dir, train, transforms=transforms)
valid_dataset = PandasDataset(images_dir, valid, transforms=valid_transforms)

In [6]:
# Training batches are drawn in a fresh random order every epoch. Validation is
# read sequentially: shuffling buys nothing at evaluation time, and a fixed order
# keeps per-sample outputs aligned with the rows of `valid`.
# Both loaders use the configured `batch_size` rather than a hard-coded value.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    sampler=RandomSampler(train_dataset),
)
valid_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    sampler=SequentialSampler(valid_dataset),
)

In [7]:
# Build the classifier head on the configured backbone, passing the backbone -> checkpoint
# mapping (presumably used to load initial weights, per the "Loaded pretrained weights"
# log), and move it onto the selected device.
model = EfficientNet(
    backbone=backbone_model,
    output_dimensions=output_classes,
    pre_trained_model=pretrained_model,
).to(device)

Loaded pretrained weights for efficientnet-b0


In [8]:
# Adam starts at init_lr / warmup_factor. GradualWarmupScheduler with
# multiplier=warmup_factor is meant to ramp the LR up to init_lr over `warmup_epochs`,
# then hand over to cosine annealing for the remaining `n_epochs - warmup_epochs` epochs.
# NOTE(review): assumes GradualWarmupScheduler targets base_lr * multiplier — confirm
# against the warmup_scheduler package.
optimizer = optim.Adam(model.parameters(), lr=init_lr / warmup_factor)
scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs - warmup_epochs)
scheduler = GradualWarmupScheduler(
    optimizer,
    multiplier=warmup_factor,
    total_epoch=warmup_epochs,
    after_scheduler=scheduler_cosine,
)

save_path = 'models/with-noise-yhv.pth'
# Create the checkpoint directory up front so a fresh checkout does not fail at the
# first "best model" save, after a full epoch of training.
os.makedirs(os.path.dirname(save_path), exist_ok=True)

In [9]:
# Train for `n_epochs`. Per-epoch validation metrics go to `metrics_path`. Judging by the
# logged output ("Salvando o melhor modelo... old -> new"), `model_checkpoint` saves to
# `save_path` whenever the mean validation kappa improves.
# NOTE: in the recorded run, val_loss bottoms out at epoch 3 and kappa peaks at epoch 4,
# while training loss keeps falling. Later epochs mostly overfit; consider fewer epochs
# or early stopping.
metrics_path = "logs/with-noise-yhv.txt"
os.makedirs(os.path.dirname(metrics_path), exist_ok=True)

train_model(
    model=model,
    epochs=n_epochs,
    optimizer=optimizer,
    scheduler=scheduler,
    train_dataloader=train_loader,
    valid_dataloader=valid_loader,
    checkpoint=model_checkpoint,
    device=device,
    loss_function=loss_function,
    path_to_save_metrics=metrics_path,
    path_to_save_model=save_path,
)

Epoch 1/15



loss: 0.66532, smooth loss: 0.40873: 100%|██████████| 3610/3610 [32:22<00:00,  1.86it/s]
100%|██████████| 903/903 [07:57<00:00,  1.89it/s]


metrics {'val_loss': 0.3711381, 'val_acc': {'mean': 37.24282548725605, 'std': 1.15247016491602, 'ci_5': 35.34626066684723, 'ci_95': 39.224377274513245}, 'val_kappa': {'mean': 0.6604293589486305, 'std': 0.01389070968864319, 'ci_5': 0.6375778317778442, 'ci_95': 0.6823712123870127}, 'val_f1': {'mean': 0.326611205637455, 'std': 0.011639576973464323, 'ci_5': 0.3072046637535095, 'ci_95': 0.34626040905714034}, 'val_recall': {'mean': 0.33108225437998773, 'std': 0.011058732013172359, 'ci_5': 0.31285790503025057, 'ci_95': 0.34895030409097666}, 'val_precision': {'mean': 0.39871926182508466, 'std': 0.01630618915746472, 'ci_5': 0.3723119616508484, 'ci_95': 0.42466580271720883}}
Salvando o melhor modelo... 0.0 -> 0.6604293589486305
Epoch 2/15



loss: 0.20095, smooth loss: 0.37288: 100%|██████████| 3610/3610 [34:04<00:00,  1.77it/s]
100%|██████████| 903/903 [07:34<00:00,  1.99it/s]


metrics {'val_loss': 0.3686064, 'val_acc': {'mean': 41.203213238716124, 'std': 1.1704267375126294, 'ci_5': 39.27977979183197, 'ci_95': 43.10249388217926}, 'val_kappa': {'mean': 0.7027116273573382, 'std': 0.014239227807903253, 'ci_5': 0.6793818828007175, 'ci_95': 0.7264697790666725}, 'val_f1': {'mean': 0.37948748674988747, 'std': 0.012718722960974137, 'ci_5': 0.35894930064678193, 'ci_95': 0.40065894573926925}, 'val_recall': {'mean': 0.3886094211935997, 'std': 0.011763332948187315, 'ci_5': 0.36998034715652467, 'ci_95': 0.40892868191003795}, 'val_precision': {'mean': 0.4709523431956768, 'std': 0.014160680777828777, 'ci_5': 0.4479375913739204, 'ci_95': 0.4939526543021202}}
Salvando o melhor modelo... 0.6604293589486305 -> 0.7027116273573382
Epoch 3/15



loss: 0.08286, smooth loss: 0.20693: 100%|██████████| 3610/3610 [29:14<00:00,  2.06it/s]
100%|██████████| 903/903 [07:02<00:00,  2.14it/s]


metrics {'val_loss': 0.32694173, 'val_acc': {'mean': 54.41551274061203, 'std': 1.212073833038065, 'ci_5': 52.518005669116974, 'ci_95': 56.45429491996765}, 'val_kappa': {'mean': 0.7710719785796073, 'std': 0.012521517031397303, 'ci_5': 0.7488856960256862, 'ci_95': 0.7917376252477882}, 'val_f1': {'mean': 0.4641020483672619, 'std': 0.012339086092282042, 'ci_5': 0.4438016802072525, 'ci_95': 0.4847140058875084}, 'val_recall': {'mean': 0.45638982266187667, 'std': 0.011358661465250223, 'ci_5': 0.4376010850071907, 'ci_95': 0.4754104524850845}, 'val_precision': {'mean': 0.4953705088496208, 'std': 0.01363993801641386, 'ci_5': 0.4736652851104736, 'ci_95': 0.5183945387601853}}
Salvando o melhor modelo... 0.7027116273573382 -> 0.7710719785796073
Epoch 4/15



loss: 0.05766, smooth loss: 0.16770: 100%|██████████| 3610/3610 [29:06<00:00,  2.07it/s]
100%|██████████| 903/903 [07:03<00:00,  2.13it/s]


metrics {'val_loss': 0.33437333, 'val_acc': {'mean': 55.02637141942978, 'std': 1.2029495902123977, 'ci_5': 53.07479500770569, 'ci_95': 57.06371068954468}, 'val_kappa': {'mean': 0.7806070020375689, 'std': 0.01248182854491143, 'ci_5': 0.7590224880680382, 'ci_95': 0.8008043070435261}, 'val_f1': {'mean': 0.46952192717790603, 'std': 0.012512751814642706, 'ci_5': 0.45054401755332946, 'ci_95': 0.49021091759204866}, 'val_recall': {'mean': 0.4629181105196476, 'std': 0.01169710382510572, 'ci_5': 0.44459050744771955, 'ci_95': 0.4821698278188705}, 'val_precision': {'mean': 0.4974699110686779, 'std': 0.013578784875110462, 'ci_5': 0.4757212147116661, 'ci_95': 0.5193062037229538}}
Salvando o melhor modelo... 0.7710719785796073 -> 0.7806070020375689
Epoch 5/15



loss: 0.04668, smooth loss: 0.13232: 100%|██████████| 3610/3610 [29:16<00:00,  2.06it/s]
100%|██████████| 903/903 [07:03<00:00,  2.13it/s]


metrics {'val_loss': 0.34659922, 'val_acc': {'mean': 55.761496025323865, 'std': 1.229709527200851, 'ci_5': 53.684210777282715, 'ci_95': 57.783931493759155}, 'val_kappa': {'mean': 0.7792397737379754, 'std': 0.013028301736880838, 'ci_5': 0.7572933982273296, 'ci_95': 0.8006583811435499}, 'val_f1': {'mean': 0.4763098585307598, 'std': 0.012607047030484216, 'ci_5': 0.4555456817150116, 'ci_95': 0.4986086219549179}, 'val_recall': {'mean': 0.4701937953233719, 'std': 0.011901046526281179, 'ci_5': 0.4505683422088623, 'ci_95': 0.4903592631220817}, 'val_precision': {'mean': 0.49943355792760846, 'std': 0.01335505952555002, 'ci_5': 0.47756308764219285, 'ci_95': 0.521995696425438}}
Epoch 6/15



loss: 0.03575, smooth loss: 0.10400: 100%|██████████| 3610/3610 [29:14<00:00,  2.06it/s]
100%|██████████| 903/903 [07:04<00:00,  2.13it/s]


metrics {'val_loss': 0.38410437, 'val_acc': {'mean': 56.712354600429535, 'std': 1.2390155589795941, 'ci_5': 54.67866837978363, 'ci_95': 58.725762367248535}, 'val_kappa': {'mean': 0.7744861017603517, 'std': 0.013505563669357087, 'ci_5': 0.7520293402308849, 'ci_95': 0.7953119768238889}, 'val_f1': {'mean': 0.4898073253631592, 'std': 0.012874326233795908, 'ci_5': 0.46833808720111847, 'ci_95': 0.5106616616249084}, 'val_recall': {'mean': 0.48266786324977873, 'std': 0.012206627576031883, 'ci_5': 0.46268461644649506, 'ci_95': 0.5025647640228271}, 'val_precision': {'mean': 0.5143263120055198, 'std': 0.013551361532388975, 'ci_5': 0.4912600785493851, 'ci_95': 0.5364024579524994}}
Epoch 7/15



loss: 0.02265, smooth loss: 0.07828: 100%|██████████| 3610/3610 [29:13<00:00,  2.06it/s]
100%|██████████| 903/903 [07:04<00:00,  2.13it/s]


metrics {'val_loss': 0.44368303, 'val_acc': {'mean': 56.08016628026962, 'std': 1.2302036573254929, 'ci_5': 53.961217403411865, 'ci_95': 58.06094408035278}, 'val_kappa': {'mean': 0.7652437198298887, 'std': 0.013797615879271077, 'ci_5': 0.7428966773652254, 'ci_95': 0.7869717868883879}, 'val_f1': {'mean': 0.48349935057759286, 'std': 0.01239691498058713, 'ci_5': 0.4631100133061409, 'ci_95': 0.504277378320694}, 'val_recall': {'mean': 0.47662996184825895, 'std': 0.0118197787757199, 'ci_5': 0.4573521688580513, 'ci_95': 0.4965847671031952}, 'val_precision': {'mean': 0.505854597657919, 'std': 0.01316351697567171, 'ci_5': 0.4842665895819664, 'ci_95': 0.5272669464349746}}
Epoch 8/15



loss: 0.01750, smooth loss: 0.05926: 100%|██████████| 3610/3610 [29:13<00:00,  2.06it/s]
100%|██████████| 903/903 [07:04<00:00,  2.13it/s]


metrics {'val_loss': 0.5012648, 'val_acc': {'mean': 56.39689745903015, 'std': 1.232833777039414, 'ci_5': 54.45706397294998, 'ci_95': 58.39335322380066}, 'val_kappa': {'mean': 0.7683456769006534, 'std': 0.013714323311044175, 'ci_5': 0.7456100155915246, 'ci_95': 0.7896040139710523}, 'val_f1': {'mean': 0.4905656627416611, 'std': 0.012717279982412635, 'ci_5': 0.4696858897805214, 'ci_95': 0.5115055531263352}, 'val_recall': {'mean': 0.4828641346693039, 'std': 0.012160058577485916, 'ci_5': 0.46251715570688245, 'ci_95': 0.5031343907117843}, 'val_precision': {'mean': 0.5157651947140693, 'std': 0.013427067311677181, 'ci_5': 0.4935809552669525, 'ci_95': 0.5371984034776688}}
Epoch 9/15



loss: 0.01429, smooth loss: 0.04694: 100%|██████████| 3610/3610 [29:06<00:00,  2.07it/s]
100%|██████████| 903/903 [07:03<00:00,  2.13it/s]


metrics {'val_loss': 0.5393909, 'val_acc': {'mean': 56.216565078496934, 'std': 1.234226890046813, 'ci_5': 54.12742495536804, 'ci_95': 58.17174315452576}, 'val_kappa': {'mean': 0.7724576375823518, 'std': 0.01360111479229477, 'ci_5': 0.7507823467557602, 'ci_95': 0.7940907358285073}, 'val_f1': {'mean': 0.4943665365278721, 'std': 0.012659066891605628, 'ci_5': 0.47434520721435547, 'ci_95': 0.514823517203331}, 'val_recall': {'mean': 0.48759201073646546, 'std': 0.012392311515518395, 'ci_5': 0.4673697426915169, 'ci_95': 0.5076794147491455}, 'val_precision': {'mean': 0.5147203017771244, 'std': 0.013155515561681784, 'ci_5': 0.49262667298316953, 'ci_95': 0.5365533947944641}}
Epoch 10/15



loss: 0.01092, smooth loss: 0.03581: 100%|██████████| 3610/3610 [29:07<00:00,  2.07it/s]
100%|██████████| 903/903 [07:01<00:00,  2.14it/s]


metrics {'val_loss': 0.5824832, 'val_acc': {'mean': 55.43196700811386, 'std': 1.227950877026541, 'ci_5': 53.35180163383484, 'ci_95': 57.34071731567383}, 'val_kappa': {'mean': 0.7717014318907488, 'std': 0.013381767352031816, 'ci_5': 0.749559106873139, 'ci_95': 0.7931673237217114}, 'val_f1': {'mean': 0.48734716391563415, 'std': 0.012587505482038875, 'ci_5': 0.4668102741241455, 'ci_95': 0.5069258034229278}, 'val_recall': {'mean': 0.4804419705569744, 'std': 0.012348141428037744, 'ci_5': 0.45950971394777296, 'ci_95': 0.500304388999939}, 'val_precision': {'mean': 0.5077060605287552, 'std': 0.013092065322531068, 'ci_5': 0.48501157760620117, 'ci_95': 0.5279776841402054}}
Epoch 11/15



loss: 0.01090, smooth loss: 0.02669: 100%|██████████| 3610/3610 [29:08<00:00,  2.06it/s]
100%|██████████| 903/903 [07:02<00:00,  2.14it/s]


metrics {'val_loss': 0.6329251, 'val_acc': {'mean': 55.9409973859787, 'std': 1.2329908250797064, 'ci_5': 53.85041832923889, 'ci_95': 57.894736528396606}, 'val_kappa': {'mean': 0.771768374158152, 'std': 0.0135555397183713, 'ci_5': 0.7498494685583973, 'ci_95': 0.7928146623514971}, 'val_f1': {'mean': 0.49283202648162844, 'std': 0.012542763743206195, 'ci_5': 0.4727387949824333, 'ci_95': 0.5125750750303268}, 'val_recall': {'mean': 0.4869143050014973, 'std': 0.012308778438618514, 'ci_5': 0.46721028834581374, 'ci_95': 0.5063706487417221}, 'val_precision': {'mean': 0.5090836697816848, 'std': 0.013018474553178663, 'ci_5': 0.48676607608795164, 'ci_95': 0.5289625257253647}}
Epoch 12/15



loss: 0.01176, smooth loss: 0.02053: 100%|██████████| 3610/3610 [29:06<00:00,  2.07it/s]
100%|██████████| 903/903 [07:03<00:00,  2.13it/s]


metrics {'val_loss': 0.6637844, 'val_acc': {'mean': 55.09966780543327, 'std': 1.228413040979246, 'ci_5': 53.07202488183975, 'ci_95': 57.06371068954468}, 'val_kappa': {'mean': 0.7712769839394986, 'std': 0.013561990281308613, 'ci_5': 0.7485651900766844, 'ci_95': 0.7928735066775299}, 'val_f1': {'mean': 0.48742575642466546, 'std': 0.012366494981551686, 'ci_5': 0.4666631892323494, 'ci_95': 0.5062244683504105}, 'val_recall': {'mean': 0.4824883170127869, 'std': 0.012276254903123597, 'ci_5': 0.4617574542760849, 'ci_95': 0.5010228365659714}, 'val_precision': {'mean': 0.4996923592984676, 'std': 0.012733098904970975, 'ci_5': 0.47815521359443663, 'ci_95': 0.5186867564916611}}
Epoch 13/15



loss: 0.01107, smooth loss: 0.01598: 100%|██████████| 3610/3610 [29:12<00:00,  2.06it/s]
100%|██████████| 903/903 [07:01<00:00,  2.14it/s]


metrics {'val_loss': 0.6875027, 'val_acc': {'mean': 55.84753467440605, 'std': 1.2207332708048908, 'ci_5': 53.73961329460144, 'ci_95': 57.731299102306366}, 'val_kappa': {'mean': 0.7732416422836013, 'std': 0.013533738798728262, 'ci_5': 0.7502321313269255, 'ci_95': 0.7942126028484182}, 'val_f1': {'mean': 0.4965928558409214, 'std': 0.012320205775306434, 'ci_5': 0.4762835368514061, 'ci_95': 0.5158090978860855}, 'val_recall': {'mean': 0.4923291586637497, 'std': 0.012290326558030655, 'ci_5': 0.47165232598781587, 'ci_95': 0.512037169933319}, 'val_precision': {'mean': 0.5074684229791164, 'std': 0.012595033723088265, 'ci_5': 0.48634358644485476, 'ci_95': 0.5274286657571793}}
Epoch 14/15



loss: 0.01042, smooth loss: 0.01218: 100%|██████████| 3610/3610 [29:02<00:00,  2.07it/s]
100%|██████████| 903/903 [07:03<00:00,  2.13it/s]


metrics {'val_loss': 0.70701694, 'val_acc': {'mean': 55.72703610658645, 'std': 1.2216847955764818, 'ci_5': 53.684210777282715, 'ci_95': 57.72852897644043}, 'val_kappa': {'mean': 0.7698533355120811, 'std': 0.013661827399770135, 'ci_5': 0.7474491749687122, 'ci_95': 0.7916056310995245}, 'val_f1': {'mean': 0.494777637809515, 'std': 0.012265939325412548, 'ci_5': 0.4744821161031723, 'ci_95': 0.5142522543668747}, 'val_recall': {'mean': 0.49064872911572455, 'std': 0.012228644609554027, 'ci_5': 0.4701273009181023, 'ci_95': 0.5106257885694504}, 'val_precision': {'mean': 0.5052973806560039, 'std': 0.01252835718937515, 'ci_5': 0.4840704932808876, 'ci_95': 0.5249433249235153}}
Epoch 15/15



loss: 0.01110, smooth loss: 0.01084: 100%|██████████| 3610/3610 [29:17<00:00,  2.05it/s]
100%|██████████| 903/903 [07:03<00:00,  2.13it/s]


metrics {'val_loss': 0.7159609, 'val_acc': {'mean': 55.29096977114678, 'std': 1.2309620549232445, 'ci_5': 53.24099659919739, 'ci_95': 57.34071731567383}, 'val_kappa': {'mean': 0.7692383992615909, 'std': 0.01370252368190862, 'ci_5': 0.7463088087311143, 'ci_95': 0.790861715714273}, 'val_f1': {'mean': 0.49345261174440386, 'std': 0.012429264237152743, 'ci_5': 0.47333581298589705, 'ci_95': 0.5139328181743622}, 'val_recall': {'mean': 0.4896802387535572, 'std': 0.012476483548452492, 'ci_5': 0.46974306404590604, 'ci_95': 0.5099111109972}, 'val_precision': {'mean': 0.5033396010994912, 'std': 0.012625210477940421, 'ci_5': 0.48311374336481094, 'ci_95': 0.5240724086761475}}
