In [1]:
%cd ..

/home/antonbabenko/Projects/newsclass01


In [2]:
%load_ext autoreload
%autoreload 2
%aimport src.model
%aimport src.data
%aimport src.train
%aimport src.utils
%aimport src.metrics

In [3]:
import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import ttach as tta
import torch.nn as nn
import albumentations as A
from torch.utils.data import DataLoader
from albumentations.pytorch import ToTensorV2
from efficientnet_pytorch import EfficientNet
from pytorch_toolbelt import losses as L

from src.utils import set_seed, load_splits, get_tensorboard_writer
from src.data import ImageClassificationDataset
from src.train import fit
from src.metrics import f1_macro

In [4]:
SEED = 42
BATCH_SIZE = 32
EPOCHS = 40
LEARNING_RATE = 5e-3

In [5]:
set_seed(SEED)

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
logs_dir = Path("logs") 
states_dir = Path("states") / "cv"
data_dir = Path("data")
folds = Path("folds") / "cv"
image_path = data_dir / "images" / "images"

In [8]:
splits = load_splits(folds, val_folds=[2], train_folds=[0, 1, 3, 4]) 

In [9]:
num_classes = splits["train"].source.nunique()

In [10]:
transform = A.Compose([
    A.Resize(256, 256),
    A.RandomCrop(224, 224),
    A.HorizontalFlip(),
    A.RandomRotate90(0.5),
    A.ColorJitter(),
    A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
    ToTensorV2()
])
    
test_transform = A.Compose([
    A.Resize(256, 256),
    A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
    ToTensorV2()
])

In [11]:
train_dataset = ImageClassificationDataset(
    df=splits["train"],
    folder=image_path, 
    mode="train",
    transform=transform
)

valid_dataset = ImageClassificationDataset(
    df=splits["val"],
    folder=image_path, 
    mode="val",
    transform=transform
)

# test_dataset = ImageClassificationDataset(
#     df=splits["test"],
#     folder=image_path, 
#     mode="test",
#     transform=test_transform
# )

In [12]:
# train_dataset = Subset(train_dataset, 20)
# valid_dataset = Subset(valid_dataset, 20)
# test_dataset = Subset(test_dataset, 20)

In [13]:
train_data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

val_data_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

# test_data_loader = DataLoader(
#     dataset=test_dataset,
#     batch_size=BATCH_SIZE,
#     shuffle=False,
# )

In [14]:
dataloaders = {
    "train": train_data_loader,
    "val": val_data_loader,
#     "test": test_data_loader
}

In [15]:
mini_batch_data = next(iter(val_data_loader))

### Model Setup

In [16]:
model_name = 'efficientnet-b4'

In [17]:
model = EfficientNet.from_pretrained(model_name)

Loaded pretrained weights for efficientnet-b4


In [18]:
num_ftrs = model._fc.in_features
model._fc = torch.nn.Linear(num_ftrs, num_classes)
model = model.to(device)

### Metrics

In [19]:
metrics = {
    "f1_macro": f1_macro
}

### Tensorboard setup

In [20]:
model_name = f"{model_name}-{int(datetime.datetime.now().timestamp())}"
writer = get_tensorboard_writer(logs_dir, model_name) 

In [21]:
# criterion = nn.CrossEntropyLoss().to(device)
criterion = L.FocalLoss().to(device)
params = list(model.parameters())
optimizer = torch.optim.Adam(params, lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3)

In [22]:
fit(
    model,
    criterion,
    optimizer,
    dataloaders,
    device,
    scheduler=scheduler,
    metrics=metrics,
    epochs=EPOCHS,
    model_name=model_name,
    model_folder=states_dir,
    writer=writer,
    fit_type="cv"
)

Epoch: 1/40, train phase: 100%|██████████| 616/616 [07:04<00:00,  1.45it/s]
Epoch: 1/40, validation phase: 100%|██████████| 154/154 [00:54<00:00,  2.83it/s]


{'f1_macro': 0.07454929617086377}
Epoch: 1/40, time: 478.8328278064728 train loss: 0.7494287510996356, val loss: 0.8331470197832523


Epoch: 2/40, train phase:   0%|          | 0/616 [00:00<?, ?it/s]

Checkpoint was saved


Epoch: 2/40, train phase: 100%|██████████| 616/616 [06:57<00:00,  1.47it/s]
Epoch: 2/40, validation phase: 100%|██████████| 154/154 [00:53<00:00,  2.88it/s]


{'f1_macro': 0.14458319787283802}
Epoch: 2/40, time: 471.2422857284546 train loss: 0.6953326252319043, val loss: 0.8758661457003676


Epoch: 3/40, train phase:   0%|          | 0/616 [00:00<?, ?it/s]

Checkpoint was saved


Epoch: 3/40, train phase: 100%|██████████| 616/616 [06:59<00:00,  1.47it/s]
Epoch: 3/40, validation phase: 100%|██████████| 154/154 [00:53<00:00,  2.88it/s]
Epoch: 4/40, train phase:   0%|          | 0/616 [00:00<?, ?it/s]

{'f1_macro': 0.12993452312740378}
Epoch: 3/40, time: 473.1540901660919 train loss: 0.6297334525977193, val loss: 2.342773691022457


Epoch: 4/40, train phase: 100%|██████████| 616/616 [07:05<00:00,  1.45it/s]
Epoch: 4/40, validation phase: 100%|██████████| 154/154 [00:53<00:00,  2.87it/s]


{'f1_macro': 0.2524325768828478}
Epoch: 4/40, time: 479.15545201301575 train loss: 0.559888451444273, val loss: 0.5941699919603802


Epoch: 5/40, train phase:   0%|          | 0/616 [00:00<?, ?it/s]

Checkpoint was saved


Epoch: 5/40, train phase: 100%|██████████| 616/616 [07:01<00:00,  1.46it/s]
Epoch: 5/40, validation phase: 100%|██████████| 154/154 [00:53<00:00,  2.88it/s]


{'f1_macro': 0.2686941217462447}
Epoch: 5/40, time: 475.3205852508545 train loss: 0.5308320921619227, val loss: 0.7098926579710191


Epoch: 6/40, train phase:   0%|          | 0/616 [00:00<?, ?it/s]

Checkpoint was saved


Epoch: 6/40, train phase: 100%|██████████| 616/616 [07:03<00:00,  1.45it/s]
Epoch: 6/40, validation phase: 100%|██████████| 154/154 [00:55<00:00,  2.78it/s]
Epoch: 7/40, train phase:   0%|          | 0/616 [00:00<?, ?it/s]

{'f1_macro': 0.23841736508454925}
Epoch: 6/40, time: 478.73261308670044 train loss: 0.5120428024654406, val loss: 0.6204316828093553


Epoch: 7/40, train phase: 100%|██████████| 616/616 [07:08<00:00,  1.44it/s]
Epoch: 7/40, validation phase: 100%|██████████| 154/154 [00:52<00:00,  2.91it/s]


{'f1_macro': 0.29592375918773406}
Epoch: 7/40, time: 481.93724513053894 train loss: 0.49977869204479014, val loss: 0.550253440171934


Epoch: 8/40, train phase:   0%|          | 0/616 [00:00<?, ?it/s]

Checkpoint was saved


Epoch: 8/40, train phase: 100%|██████████| 616/616 [06:55<00:00,  1.48it/s]
Epoch: 8/40, validation phase: 100%|██████████| 154/154 [00:52<00:00,  2.92it/s]
Epoch: 9/40, train phase:   0%|          | 0/616 [00:00<?, ?it/s]

{'f1_macro': 0.24770145928169268}
Epoch: 8/40, time: 467.707154750824 train loss: 0.49132841470259825, val loss: 0.572717948377435


Epoch: 9/40, train phase: 100%|██████████| 616/616 [06:50<00:00,  1.50it/s]
Epoch: 9/40, validation phase: 100%|██████████| 154/154 [00:52<00:00,  2.95it/s]


{'f1_macro': 0.33336602841887497}
Epoch: 9/40, time: 462.2965052127838 train loss: 0.48266636281850056, val loss: 0.48768098063275295


Epoch: 10/40, train phase:   0%|          | 0/616 [00:00<?, ?it/s]

Checkpoint was saved


Epoch: 10/40, train phase: 100%|██████████| 616/616 [06:49<00:00,  1.50it/s]
Epoch: 10/40, validation phase: 100%|██████████| 154/154 [00:52<00:00,  2.96it/s]
Epoch: 11/40, train phase:   0%|          | 0/616 [00:00<?, ?it/s]

{'f1_macro': 0.2959091677885439}
Epoch: 10/40, time: 461.8872344493866 train loss: 0.4738398236764603, val loss: 0.5776516892825286


Epoch: 11/40, train phase: 100%|██████████| 616/616 [06:46<00:00,  1.51it/s]
Epoch: 11/40, validation phase: 100%|██████████| 154/154 [00:51<00:00,  2.98it/s]
Epoch: 12/40, train phase:   0%|          | 0/616 [00:00<?, ?it/s]

{'f1_macro': 0.329613793856617}
Epoch: 11/40, time: 458.38076877593994 train loss: 0.46976008564513666, val loss: 0.5143769050249594


Epoch: 12/40, train phase: 100%|██████████| 616/616 [06:55<00:00,  1.48it/s]
Epoch: 12/40, validation phase: 100%|██████████| 154/154 [00:57<00:00,  2.67it/s]
Epoch: 13/40, train phase:   0%|          | 0/616 [00:00<?, ?it/s]

{'f1_macro': 0.29381688548368723}
Epoch: 12/40, time: 473.0814280509949 train loss: 0.4619527160259024, val loss: 0.7128719202758091


Epoch: 13/40, train phase: 100%|██████████| 616/616 [07:23<00:00,  1.39it/s]
Epoch: 13/40, validation phase: 100%|██████████| 154/154 [00:54<00:00,  2.84it/s]


{'f1_macro': 0.3393575367615059}
Epoch: 13/40, time: 497.5826597213745 train loss: 0.4532629638067617, val loss: 0.507262180916549


Epoch: 14/40, train phase:   0%|          | 0/616 [00:00<?, ?it/s]

Checkpoint was saved


Epoch: 14/40, train phase: 100%|██████████| 616/616 [07:12<00:00,  1.43it/s]
Epoch: 14/40, validation phase: 100%|██████████| 154/154 [00:55<00:00,  2.79it/s]


{'f1_macro': 0.3774437381120091}
Epoch: 14/40, time: 487.3890085220337 train loss: 0.42016663285788985, val loss: 0.4347921196216254


Epoch: 15/40, train phase:   0%|          | 0/616 [00:00<?, ?it/s]

Checkpoint was saved


Epoch: 15/40, train phase: 100%|██████████| 616/616 [07:03<00:00,  1.45it/s]
Epoch: 15/40, validation phase: 100%|██████████| 154/154 [00:53<00:00,  2.90it/s]
Epoch: 16/40, train phase:   0%|          | 0/616 [00:00<?, ?it/s]

{'f1_macro': 0.3770447262939835}
Epoch: 15/40, time: 476.39202404022217 train loss: 0.4131841394730501, val loss: 0.4243992874283476


Epoch: 16/40, train phase:  42%|████▏     | 261/616 [02:57<04:01,  1.47it/s]


KeyboardInterrupt: 

In [None]:
tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform())