In [9]:
import numpy as np
from warmup_scheduler import GradualWarmupScheduler
import pandas as pd
import os
import torch
from torch import nn, optim
from sklearn.model_selection import train_test_split

from work.utils.dataset import PandasDataset
from work.utils.dataset import RemovePenMarkAlbumentations
import albumentations as A

from torch.utils.data import DataLoader
from work.utils.models import EfficientNet
from work.utils.train import train_model
from work.utils.metrics import model_checkpoint

In [2]:
# Backbone architecture name; also used as the lookup key for its local weights file.
backbone_model = 'efficientnet-b0'
pretrained_model = {backbone_model: 'pre-trained-models/efficientnet-b0-08094119.pth'}

# Root data directory and the tiles sub-directory beneath it.
data_dir = 'data'
images_dir = os.path.join(data_dir, 'tiles')

# Use the GPU when one is visible to torch, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [10]:
# Load the data.
# `remove-images.csv` carries an `entropy` column per `image_id`; it is sorted by
# descending entropy. `train_val.csv` is the metadata table the split is drawn from.
df_remove = pd.read_csv(f"{data_dir}/remove-images.csv")
df_remove = df_remove.sort_values(by=["entropy"], ascending=False)
df = pd.read_csv(f"{data_dir}/train_val.csv")

# Rows of the entropy table whose entropy is strictly below 1.3.
df_remove_filtered = df_remove.loc[df_remove["entropy"] < 1.3]

# Keep only the metadata rows whose image_id appears among those low-entropy rows.
# NOTE(review): this KEEPS the ids listed in remove-images.csv (entropy < 1.3) rather
# than excluding them -- confirm that is the intended direction of the filter.
df_filtered = df.loc[df["image_id"].isin(df_remove_filtered["image_id"])]

print(f"Original shape: {df.shape}")
print(f"Filtered shape: {df_filtered.shape}")

array([1.3160478, 1.3184186, 1.3194076, 1.3210733, 1.3219315])

In [4]:
# DataLoader settings.
batch_size = 2
num_workers = 4

# Model / optimisation settings: number of model outputs, Adam's initial learning
# rate, the training loss, and the number of epochs to train for.
output_classes = 5
init_lr = 3e-4
epochs = 50
loss_function = nn.BCEWithLogitsLoss()

# Not referenced by the cells visible in this notebook.
n_folds = 5

# Same device selection as in the configuration cell: GPU if available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
# Training-time augmentation pipeline: a transpose and both flips, each applied
# independently with probability 0.5.
# RemovePenMarkAlbumentations() is imported but currently left out of the pipeline.
transforms = A.Compose(
    transforms=[
        A.Transpose(p=0.5),
        A.VerticalFlip(p=0.5),
        A.HorizontalFlip(p=0.5),
    ]
)

In [6]:
# Hold out 20% of the filtered images for validation. The split is seeded so that
# re-running the notebook yields the same hold-out set; otherwise a checkpoint from
# a previous run could be evaluated on images it was trained on.
split_seed = 42
df_train, df_val = train_test_split(df_filtered, test_size=0.20, random_state=split_seed)

# NOTE(review): tiles are read from this literal path, not from `images_dir`
# ('data/tiles') defined in the configuration cell -- confirm which one is correct.
tiles_dir = "../dataset/tiles"
dataset_train = PandasDataset(tiles_dir, df_train, transforms=transforms)  # augmented
dataset_valid = PandasDataset(tiles_dir, df_val)                           # no augmentation

print(f"train: {len(dataset_train)} images | validation: {len(dataset_valid)} images ")

train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)

# Validation must iterate in dataset order: `df_val` is also handed to `train_model`,
# so predictions presumably get compared against its rows positionally (verify in
# work.utils.train). Shuffling here would silently misalign them and brings no
# benefit for evaluation.
valid_loader = DataLoader(dataset_valid, batch_size=batch_size, shuffle=False, num_workers=num_workers)

model = EfficientNet(backbone_model, output_classes, weights_path=pretrained_model.get(backbone_model))
optimizer = optim.Adam(model.parameters(), lr=init_lr)

# Learning-rate schedule: `warmup_epochs` of gradual warmup (target lr is
# 10 x init_lr per the scheduler's `multiplier`), then cosine annealing over the
# remaining epochs. The two schedulers share `warmup_epochs` so they cannot drift apart.
warmup_epochs = 1
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs - warmup_epochs)
scheduler = GradualWarmupScheduler(
    optimizer,
    multiplier=10,
    total_epoch=warmup_epochs,
    after_scheduler=scheduler_cosine,
)

save_path = 'pre-trained-models/removed-images.pth'
metrics_path = "logs/history/removed-images.txt"

# Make sure both output directories exist before starting a multi-hour run.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
os.makedirs(os.path.dirname(metrics_path), exist_ok=True)

train_model(
    model,
    epochs,
    optimizer,
    scheduler,
    train_loader,
    valid_loader,
    df_val,
    checkpoint=model_checkpoint,
    device=device,
    loss_function=loss_function,
    path_to_save_metrics=metrics_path,
    path_to_save_model=save_path,
)



train: 6433 images | validation: 1609 images 
Loaded pretrained weights for efficientnet-b0
Epoch 1/50



loss: 0.19289, smooth loss: 0.34276: 100%|██████████| 3217/3217 [22:40<00:00,  2.36it/s]
100%|██████████| 805/805 [01:48<00:00,  7.44it/s]


Epoch 2/50



loss: 0.66366, smooth loss: 0.54756: 100%|██████████| 3217/3217 [22:30<00:00,  2.38it/s]
100%|██████████| 805/805 [01:48<00:00,  7.42it/s]


Epoch 3/50



loss: 0.65118, smooth loss: 0.47414: 100%|██████████| 3217/3217 [22:31<00:00,  2.38it/s]
100%|██████████| 805/805 [01:48<00:00,  7.42it/s]


Epoch 4/50



loss: 0.82752, smooth loss: 0.42192: 100%|██████████| 3217/3217 [22:31<00:00,  2.38it/s]
100%|██████████| 805/805 [01:48<00:00,  7.41it/s]


Epoch 5/50



loss: 0.34460, smooth loss: 0.36397: 100%|██████████| 3217/3217 [22:31<00:00,  2.38it/s]
100%|██████████| 805/805 [01:48<00:00,  7.41it/s]


Epoch 6/50



loss: 0.29393, smooth loss: 0.39305: 100%|██████████| 3217/3217 [22:30<00:00,  2.38it/s]
100%|██████████| 805/805 [01:48<00:00,  7.41it/s]


Epoch 7/50



loss: 0.19315, smooth loss: 0.38934: 100%|██████████| 3217/3217 [22:30<00:00,  2.38it/s]
100%|██████████| 805/805 [01:48<00:00,  7.41it/s]


Epoch 8/50



loss: 0.26637, smooth loss: 0.37560: 100%|██████████| 3217/3217 [22:29<00:00,  2.38it/s]
100%|██████████| 805/805 [01:48<00:00,  7.42it/s]


Epoch 9/50



loss: 0.44706, smooth loss: 0.36533: 100%|██████████| 3217/3217 [22:31<00:00,  2.38it/s]
100%|██████████| 805/805 [01:48<00:00,  7.41it/s]


Epoch 10/50



loss: 0.71648, smooth loss: 0.34013: 100%|██████████| 3217/3217 [22:30<00:00,  2.38it/s]
100%|██████████| 805/805 [01:48<00:00,  7.41it/s]


Epoch 11/50



loss: 0.30395, smooth loss: 0.36622: 100%|██████████| 3217/3217 [22:29<00:00,  2.38it/s]
100%|██████████| 805/805 [01:48<00:00,  7.40it/s]


Epoch 12/50



loss: 0.33344, smooth loss: 0.33689: 100%|██████████| 3217/3217 [22:30<00:00,  2.38it/s]
100%|██████████| 805/805 [01:48<00:00,  7.41it/s]


Epoch 13/50



loss: 0.28462, smooth loss: 0.30389: 100%|██████████| 3217/3217 [22:30<00:00,  2.38it/s]
100%|██████████| 805/805 [01:48<00:00,  7.41it/s]


Epoch 14/50



loss: 0.27154, smooth loss: 0.26109: 100%|██████████| 3217/3217 [22:30<00:00,  2.38it/s]
100%|██████████| 805/805 [01:48<00:00,  7.41it/s]


Epoch 15/50



loss: 0.15762, smooth loss: 0.28679: 100%|██████████| 3217/3217 [22:30<00:00,  2.38it/s]
100%|██████████| 805/805 [01:48<00:00,  7.41it/s]


Epoch 16/50



loss: 0.13094, smooth loss: 0.27758: 100%|██████████| 3217/3217 [22:30<00:00,  2.38it/s]
100%|██████████| 805/805 [01:48<00:00,  7.41it/s]


Epoch 17/50



loss: 0.22893, smooth loss: 0.27006: 100%|██████████| 3217/3217 [22:31<00:00,  2.38it/s]
100%|██████████| 805/805 [01:48<00:00,  7.41it/s]


Epoch 18/50



loss: 0.22181, smooth loss: 0.28298: 100%|██████████| 3217/3217 [22:30<00:00,  2.38it/s]
100%|██████████| 805/805 [01:48<00:00,  7.41it/s]


Epoch 19/50



loss: 0.56986, smooth loss: 0.29898: 100%|██████████| 3217/3217 [22:30<00:00,  2.38it/s]
100%|██████████| 805/805 [01:48<00:00,  7.42it/s]


Epoch 20/50



loss: 0.24890, smooth loss: 0.27832: 100%|██████████| 3217/3217 [22:30<00:00,  2.38it/s]
100%|██████████| 805/805 [01:48<00:00,  7.40it/s]


Epoch 21/50



loss: 0.03613, smooth loss: 0.31170: 100%|██████████| 3217/3217 [22:30<00:00,  2.38it/s]
100%|██████████| 805/805 [01:48<00:00,  7.42it/s]


Epoch 22/50



loss: 0.06961, smooth loss: 0.23693: 100%|██████████| 3217/3217 [22:30<00:00,  2.38it/s]
100%|██████████| 805/805 [01:48<00:00,  7.41it/s]


Epoch 23/50



loss: 0.12256, smooth loss: 0.23848: 100%|██████████| 3217/3217 [22:30<00:00,  2.38it/s]
100%|██████████| 805/805 [01:49<00:00,  7.35it/s]


Epoch 24/50



loss: 0.09114, smooth loss: 0.24301: 100%|██████████| 3217/3217 [22:16<00:00,  2.41it/s]
100%|██████████| 805/805 [01:47<00:00,  7.48it/s]


Epoch 25/50



loss: 0.13483, smooth loss: 0.23758: 100%|██████████| 3217/3217 [22:17<00:00,  2.40it/s]
100%|██████████| 805/805 [01:47<00:00,  7.47it/s]


Epoch 26/50



loss: 0.19264, smooth loss: 0.26626: 100%|██████████| 3217/3217 [22:18<00:00,  2.40it/s]
100%|██████████| 805/805 [01:47<00:00,  7.47it/s]


Epoch 27/50



loss: 0.02700, smooth loss: 0.25329: 100%|██████████| 3217/3217 [22:18<00:00,  2.40it/s]
100%|██████████| 805/805 [01:47<00:00,  7.49it/s]


Epoch 28/50



loss: 0.61208, smooth loss: 0.24532: 100%|██████████| 3217/3217 [22:18<00:00,  2.40it/s]
100%|██████████| 805/805 [01:47<00:00,  7.46it/s]


Epoch 29/50



loss: 0.10573, smooth loss: 0.22180: 100%|██████████| 3217/3217 [22:18<00:00,  2.40it/s]
100%|██████████| 805/805 [01:47<00:00,  7.47it/s]


Epoch 30/50



loss: 0.17553, smooth loss: 0.27396: 100%|██████████| 3217/3217 [22:17<00:00,  2.40it/s]
100%|██████████| 805/805 [01:47<00:00,  7.48it/s]


Epoch 31/50



loss: 0.08644, smooth loss: 0.24895: 100%|██████████| 3217/3217 [22:17<00:00,  2.41it/s]
100%|██████████| 805/805 [01:47<00:00,  7.47it/s]


Epoch 32/50



loss: 0.15631, smooth loss: 0.21872: 100%|██████████| 3217/3217 [22:17<00:00,  2.40it/s]
100%|██████████| 805/805 [01:47<00:00,  7.47it/s]


Epoch 33/50



loss: 0.16503, smooth loss: 0.20644: 100%|██████████| 3217/3217 [22:17<00:00,  2.40it/s]
100%|██████████| 805/805 [01:47<00:00,  7.45it/s]


Epoch 34/50



loss: 0.19203, smooth loss: 0.23863: 100%|██████████| 3217/3217 [22:17<00:00,  2.41it/s]
100%|██████████| 805/805 [01:47<00:00,  7.47it/s]


Epoch 35/50



loss: 0.46432, smooth loss: 0.20803: 100%|██████████| 3217/3217 [22:17<00:00,  2.41it/s]
100%|██████████| 805/805 [01:48<00:00,  7.45it/s]


Epoch 36/50



loss: 0.39530, smooth loss: 0.22771: 100%|██████████| 3217/3217 [22:17<00:00,  2.41it/s]
100%|██████████| 805/805 [01:47<00:00,  7.48it/s]


Epoch 37/50



loss: 0.17161, smooth loss: 0.19809: 100%|██████████| 3217/3217 [22:17<00:00,  2.40it/s]
100%|██████████| 805/805 [01:47<00:00,  7.46it/s]


Epoch 38/50



loss: 0.14992, smooth loss: 0.23161: 100%|██████████| 3217/3217 [22:17<00:00,  2.40it/s]
100%|██████████| 805/805 [01:47<00:00,  7.48it/s]


Epoch 39/50



loss: 0.28897, smooth loss: 0.21953: 100%|██████████| 3217/3217 [22:18<00:00,  2.40it/s]
100%|██████████| 805/805 [01:47<00:00,  7.48it/s]


Epoch 40/50



loss: 0.04079, smooth loss: 0.23166: 100%|██████████| 3217/3217 [22:18<00:00,  2.40it/s]
100%|██████████| 805/805 [01:47<00:00,  7.48it/s]


Epoch 41/50



loss: 0.30202, smooth loss: 0.22291: 100%|██████████| 3217/3217 [22:16<00:00,  2.41it/s]
100%|██████████| 805/805 [01:47<00:00,  7.48it/s]


Epoch 42/50



loss: 0.36101, smooth loss: 0.24058: 100%|██████████| 3217/3217 [22:16<00:00,  2.41it/s]
100%|██████████| 805/805 [01:48<00:00,  7.45it/s]


Epoch 43/50



loss: 0.06505, smooth loss: 0.23013: 100%|██████████| 3217/3217 [22:17<00:00,  2.41it/s]
100%|██████████| 805/805 [01:47<00:00,  7.46it/s]


Epoch 44/50



loss: 0.31235, smooth loss: 0.17784: 100%|██████████| 3217/3217 [22:16<00:00,  2.41it/s]
100%|██████████| 805/805 [01:47<00:00,  7.47it/s]


Epoch 45/50



loss: 0.01971, smooth loss: 0.19824: 100%|██████████| 3217/3217 [22:17<00:00,  2.41it/s]
100%|██████████| 805/805 [01:47<00:00,  7.46it/s]


Epoch 46/50



loss: 0.03468, smooth loss: 0.23096: 100%|██████████| 3217/3217 [22:17<00:00,  2.41it/s]
100%|██████████| 805/805 [01:47<00:00,  7.47it/s]


Epoch 47/50



loss: 0.38374, smooth loss: 0.24168: 100%|██████████| 3217/3217 [22:17<00:00,  2.41it/s]
100%|██████████| 805/805 [01:47<00:00,  7.46it/s]


Epoch 48/50



loss: 0.18908, smooth loss: 0.17868: 100%|██████████| 3217/3217 [22:17<00:00,  2.41it/s]
100%|██████████| 805/805 [01:47<00:00,  7.46it/s]


Epoch 49/50



loss: 0.23410, smooth loss: 0.18333: 100%|██████████| 3217/3217 [22:17<00:00,  2.41it/s]
100%|██████████| 805/805 [01:47<00:00,  7.46it/s]


Epoch 50/50



loss: 0.24000, smooth loss: 0.22716: 100%|██████████| 3217/3217 [22:17<00:00,  2.40it/s]
100%|██████████| 805/805 [01:48<00:00,  7.45it/s]


In [9]:
# Manual checkpoint call made after training. In the recorded run it printed
# "Salvando o melhor modelo... 9 -> 10" and returned 10.
# NOTE(review): 9 and 10 look like hand-picked (previous best, new score) values chosen
# to force a save of the current weights -- confirm against work.utils.metrics.
# NOTE(review): `save_path` is the same path given to `train_model`, so any checkpoint
# written there during training is presumably overwritten by this call.
model_checkpoint(model, 9, 10, save_path)

Salvando o melhor modelo... 9 -> 10


10