In [1]:
import os
import random

from copy import deepcopy
from datetime import datetime
from collections import defaultdict

import torch
import numpy as np
import pandas as pd
import albumentations as A
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import classification_report
from torch.utils.data import Dataset, DataLoader

from densenet import DenseNet

seed = datetime.today().year
os.environ["PYTHONHASHSEED"] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.use_deterministic_algorithms(True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
class ImageDataset(Dataset):
    def __init__(self, metadata: pd.DataFrame, transform=None):
        super().__init__()
        self.metadata = metadata.reset_index(drop=True)
        self.classes_to_idx = {
            cls: idx for idx, cls in enumerate(self.metadata.target.unique())
        }
        self.transform = transform

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        metadata = self.metadata.iloc[idx]
        image = Image.open(metadata.image)
        image = np.asarray(image.convert("RGB"))
        label = self.classes_to_idx[metadata.target]
        if not self.transform is None:
            image = self.transform(image=image)["image"]
        image = image.transpose(2, 0, 1)
        return image, label


def create_metadata(root: str) -> pd.DataFrame:
    classes = sorted([folder for folder in os.listdir(root)])
    metadata = []
    for cls in classes:
        subfolder = f"{root}/{cls}"
        for image in os.listdir(subfolder):
            metadata += [{"image": f"{subfolder}/{image}", "target": cls}]
    metadata = pd.DataFrame(metadata)
    return metadata


In [3]:
batch_size = 32
epochs = 20

metadata_train = create_metadata("train")
metadata_valid = create_metadata("valid")
metadata_test = create_metadata("test")

transform_train = A.Compose(
    [
        A.VerticalFlip(p=1 / 16),
        A.HorizontalFlip(p=1 / 16),
        A.ColorJitter(p=1 / 16),
        A.Affine(1.25, p=1 / 16),
        A.Resize(224, 224),
        A.Normalize(),
    ]
)

transform_eval = A.Compose([A.Resize(224, 224), A.Normalize()])

dataset_train = ImageDataset(metadata_train, transform_train)
dataset_valid = ImageDataset(metadata_valid, transform_eval)
dataset_test = ImageDataset(metadata_test, transform_eval)


trainloader = DataLoader(dataset_train, batch_size, True)
validloader = DataLoader(dataset_valid, batch_size)
testloader = DataLoader(dataset_test, batch_size)

model = DenseNet(len(dataset_train.classes_to_idx), 3).to(device)
optimizer = torch.optim.AdamW(model.parameters(), 1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(trainloader) * 10)
scaler = torch.cuda.amp.GradScaler()
loss_function_train = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
loss_function_eval = torch.nn.CrossEntropyLoss(reduction="none")


In [4]:
loss_best = torch.inf
history = []
for i in range(1, epochs + 1):
    print(f"Epoch {i}/{epochs}")
    progbar = tqdm(total=len(trainloader))
    model.train()
    metrics_train = defaultdict(float)
    results_train = defaultdict(list)
    for i, (inputs, targets) in enumerate(trainloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = loss_function_train(outputs, targets)

        if loss.isnan():
            break

        scale = scaler.get_scale()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        optimizer.step()
        scaler.update()
        if scale > scaler.get_scale():
            scheduler.step()
        optimizer.zero_grad()

        loss = loss.detach().cpu().numpy()
        outputs = outputs.detach().cpu().argmax(-1).numpy()
        targets = targets.detach().cpu().numpy()

        score = classification_report(
            targets, outputs, output_dict=True, zero_division=0
        )

        results_train["loss"] += [loss]
        results_train["accuracy"] += [score["accuracy"]]
        results_train["precision"] += [score["macro avg"]["precision"]]
        results_train["recall"] += [score["macro avg"]["recall"]]
        results_train["f1"] += [score["macro avg"]["f1-score"]]

        metrics_train["loss"] = np.mean(loss)
        metrics_train["accuracy"] = np.mean(results_train["accuracy"])
        metrics_train["precision"] = np.mean(results_train["precision"])
        metrics_train["recall"] = np.mean(results_train["recall"])
        metrics_train["f1"] = np.mean(results_train["f1"])

        progbar.set_postfix(metrics_train)
        progbar.update(1)

    model.eval()
    metrics_valid = defaultdict(float)
    results_valid = defaultdict(list)
    for i, (inputs, targets) in enumerate(validloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = loss_function_eval(outputs, targets)

        loss = loss.detach().cpu().numpy().tolist()
        outputs = outputs.detach().cpu().argmax(-1).numpy().tolist()
        targets = targets.detach().cpu().numpy().tolist()

        results_valid["loss"] += loss
        results_valid["output"] += outputs
        results_valid["target"] += targets

    score = classification_report(
        results_valid["target"],
        results_valid["output"],
        output_dict=True,
        zero_division=0,
    )
    metrics_valid["val_loss"] = np.mean(loss)
    metrics_valid["val_accuracy"] = score["accuracy"]
    metrics_valid["val_precision"] = score["macro avg"]["precision"]
    metrics_valid["val_recall"] = score["macro avg"]["recall"]
    metrics_valid["val_f1"] = score["macro avg"]["f1-score"]

    metrics = {**metrics_train, **metrics_valid}
    progbar.set_postfix(metrics)
    progbar.close()
    history += [metrics]
    if loss_best > history[-1]["val_loss"]:
        loss_best = history[-1]["val_loss"]
        model_best = deepcopy(model.state_dict())
        optimizer_best = deepcopy(optimizer.state_dict())
        scheduler_best = deepcopy(scheduler.state_dict())
        scaler_best = deepcopy(scaler.state_dict())

    model.load_state_dict(model_best)
    optimizer.load_state_dict(optimizer_best)
    scheduler.load_state_dict(scheduler_best)
    scaler.load_state_dict(scaler_best)


Epoch 1/20


100%|██████████| 1825/1825 [09:23<00:00,  3.24it/s, loss=3.88, accuracy=0.0828, precision=0.0456, recall=0.0487, f1=0.0461, val_loss=2.35, val_accuracy=0.243, val_precision=0.24, val_recall=0.244, val_f1=0.199]


Epoch 2/20


100%|██████████| 1825/1825 [09:21<00:00,  3.25it/s, loss=2.83, accuracy=0.387, precision=0.252, recall=0.252, f1=0.25, val_loss=1.56, val_accuracy=0.539, val_precision=0.636, val_recall=0.539, val_f1=0.523]


Epoch 3/20


100%|██████████| 1825/1825 [09:30<00:00,  3.20it/s, loss=2.57, accuracy=0.576, precision=0.417, recall=0.416, f1=0.414, val_loss=0.95, val_accuracy=0.725, val_precision=0.774, val_recall=0.725, val_f1=0.715]


Epoch 4/20


100%|██████████| 1825/1825 [04:47<00:00,  6.34it/s, loss=2.76, accuracy=0.666, precision=0.512, recall=0.51, f1=0.509, val_loss=nan, val_accuracy=0.0025, val_precision=6.25e-6, val_recall=0.0025, val_f1=1.25e-5]


Epoch 5/20


100%|██████████| 1825/1825 [04:45<00:00,  6.40it/s, loss=2, accuracy=0.663, precision=0.507, recall=0.505, f1=0.504, val_loss=nan, val_accuracy=0.0025, val_precision=6.25e-6, val_recall=0.0025, val_f1=1.25e-5]


Epoch 6/20


100%|██████████| 1825/1825 [04:46<00:00,  6.37it/s, loss=2.67, accuracy=0.668, precision=0.514, recall=0.512, f1=0.511, val_loss=nan, val_accuracy=0.0025, val_precision=6.25e-6, val_recall=0.0025, val_f1=1.25e-5]


Epoch 7/20


100%|██████████| 1825/1825 [04:42<00:00,  6.46it/s, loss=2.93, accuracy=0.666, precision=0.513, recall=0.511, f1=0.51, val_loss=nan, val_accuracy=0.0025, val_precision=6.25e-6, val_recall=0.0025, val_f1=1.25e-5]


Epoch 8/20


100%|██████████| 1825/1825 [04:42<00:00,  6.46it/s, loss=2.64, accuracy=0.665, precision=0.512, recall=0.509, f1=0.508, val_loss=nan, val_accuracy=0.0025, val_precision=6.25e-6, val_recall=0.0025, val_f1=1.25e-5]


Epoch 9/20


100%|██████████| 1825/1825 [04:44<00:00,  6.42it/s, loss=2.68, accuracy=0.659, precision=0.506, recall=0.504, f1=0.503, val_loss=nan, val_accuracy=0.0025, val_precision=6.25e-6, val_recall=0.0025, val_f1=1.25e-5]


Epoch 10/20


100%|██████████| 1825/1825 [04:49<00:00,  6.30it/s, loss=2.31, accuracy=0.662, precision=0.507, recall=0.506, f1=0.505, val_loss=nan, val_accuracy=0.0025, val_precision=6.25e-6, val_recall=0.0025, val_f1=1.25e-5]


Epoch 11/20


100%|██████████| 1825/1825 [04:42<00:00,  6.45it/s, loss=2.59, accuracy=0.66, precision=0.504, recall=0.502, f1=0.501, val_loss=nan, val_accuracy=0.0025, val_precision=6.25e-6, val_recall=0.0025, val_f1=1.25e-5]


Epoch 12/20


100%|██████████| 1825/1825 [04:44<00:00,  6.41it/s, loss=2.82, accuracy=0.667, precision=0.513, recall=0.511, f1=0.51, val_loss=nan, val_accuracy=0.0025, val_precision=6.25e-6, val_recall=0.0025, val_f1=1.25e-5]


Epoch 13/20


100%|██████████| 1825/1825 [04:45<00:00,  6.39it/s, loss=2.93, accuracy=0.667, precision=0.513, recall=0.512, f1=0.511, val_loss=nan, val_accuracy=0.0025, val_precision=6.25e-6, val_recall=0.0025, val_f1=1.25e-5]


Epoch 14/20


100%|██████████| 1825/1825 [04:47<00:00,  6.35it/s, loss=2.5, accuracy=0.659, precision=0.504, recall=0.502, f1=0.501, val_loss=nan, val_accuracy=0.0025, val_precision=6.25e-6, val_recall=0.0025, val_f1=1.25e-5]


Epoch 15/20


100%|██████████| 1825/1825 [04:43<00:00,  6.44it/s, loss=2.66, accuracy=0.664, precision=0.51, recall=0.508, f1=0.507, val_loss=nan, val_accuracy=0.0025, val_precision=6.25e-6, val_recall=0.0025, val_f1=1.25e-5]


Epoch 16/20


100%|██████████| 1825/1825 [04:42<00:00,  6.45it/s, loss=2.2, accuracy=0.664, precision=0.509, recall=0.507, f1=0.506, val_loss=nan, val_accuracy=0.0025, val_precision=6.25e-6, val_recall=0.0025, val_f1=1.25e-5]


Epoch 17/20


100%|██████████| 1825/1825 [04:44<00:00,  6.41it/s, loss=2.73, accuracy=0.665, precision=0.51, recall=0.509, f1=0.508, val_loss=nan, val_accuracy=0.0025, val_precision=6.25e-6, val_recall=0.0025, val_f1=1.25e-5]


Epoch 18/20


100%|██████████| 1825/1825 [04:46<00:00,  6.38it/s, loss=2.19, accuracy=0.663, precision=0.508, recall=0.506, f1=0.505, val_loss=nan, val_accuracy=0.0025, val_precision=6.25e-6, val_recall=0.0025, val_f1=1.25e-5]


Epoch 19/20


100%|██████████| 1825/1825 [04:43<00:00,  6.45it/s, loss=2.63, accuracy=0.662, precision=0.508, recall=0.506, f1=0.505, val_loss=nan, val_accuracy=0.0025, val_precision=6.25e-6, val_recall=0.0025, val_f1=1.25e-5]


Epoch 20/20


100%|██████████| 1825/1825 [04:41<00:00,  6.48it/s, loss=3.26, accuracy=0.664, precision=0.51, recall=0.508, f1=0.507, val_loss=nan, val_accuracy=0.0025, val_precision=6.25e-6, val_recall=0.0025, val_f1=1.25e-5]


In [5]:
model.eval()
metrics_test = defaultdict(float)
results_test = defaultdict(list)
progbar = tqdm(total=len(testloader))
for i, (inputs, targets) in enumerate(testloader):
    inputs = inputs.to(device)
    targets = targets.to(device)

    with torch.cuda.amp.autocast():
        outputs = model(inputs)
        loss = loss_function_eval(outputs, targets)

    loss = loss.detach().cpu().numpy().tolist()
    outputs = outputs.detach().cpu().argmax(-1).numpy().tolist()
    targets = targets.detach().cpu().numpy().tolist()

    results_test["loss"] += loss
    results_test["output"] += outputs
    results_test["target"] += targets

    score = classification_report(
        results_test["target"],
        results_test["output"],
        output_dict=True,
        zero_division=0,
    )
    metrics_test["loss"] = np.mean(loss)
    metrics_test["accuracy"] = score["accuracy"]
    metrics_test["precision"] = score["macro avg"]["precision"]
    metrics_test["recall"] = score["macro avg"]["recall"]
    metrics_test["f1"] = score["macro avg"]["f1-score"]
    progbar.set_postfix(metrics_test)
    progbar.update(1)
metrics_test = {k: v / (i + 1) for k, v in metrics_test.items()}
progbar.close()


100%|██████████| 63/63 [00:05<00:00, 10.58it/s, loss=0.671, accuracy=0.762, precision=0.817, recall=0.762, f1=0.757]


: 