In [90]:
import timm
from timm.loss import BinaryCrossEntropy
from timm.optim import create_optimizer_v2
import torch
from pytorch_accelerated.callbacks import SaveBestModelCallback
from pytorch_accelerated.trainer import DEFAULT_CALLBACKS

from src.data.datasets.coin_data import CoinData, CoinDataFolder
from src.training.trainer import TimmMixupTrainer

In [91]:
%matplotlib inline

# Enable autoreloading of imported modules.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [92]:
# Set training arguments, hardcoded here for clarity
image_size = (224, 224)
lr = 5e-3
smoothing = 0.1
mixup = 0.2
cutmix = 1.0
batch_size = 16
bce_target_thresh = 0.2
num_epochs = 10

In [93]:
# load data
coin_data = CoinData()
num_classes = len(coin_data.images_and_targets)

../data/raw/CN_dataset_04_23/data_types_example


[('../data/raw/CN_dataset_04_23/data_types_example/1/CN_type_1_cn_coin_8022_p.jpg', 0), ('../data/raw/CN_dataset_04_23/data_types_example/1/CN_type_1_MK_18203122_cn_coin_6383_o.jpg', 0), ('../data/raw/CN_dataset_04_23/data_types_example/2/CN_type_2_cn_coin_8024_p.jpg', 1), ('../data/raw/CN_dataset_04_23/data_types_example/3/CN_type_3_BNF_Platzhalter_cn_coin_11904_o.jpg', 2), ('../data/raw/CN_dataset_04_23/data_types_example/3/CN_type_3_MK_18247614_cn_coin_6696_o.jpg', 2), ('../data/raw/CN_dataset_04_23/data_types_example/5/CN_type_5_cn_coin_7685_p.jpg', 3), ('../data/raw/CN_dataset_04_23/data_types_example/6/CN_type_6_cn_coin_7686_p.jpg', 4), ('../data/raw/CN_dataset_04_23/data_types_example/8/CN_type_8_cn_coin_7689_p.jpg', 5), ('../data/raw/CN_dataset_04_23/data_types_example/8/CN_type_8_cn_coin_15352_p.jpg', 5), ('../data/raw/CN_dataset_04_23/data_types_example/11/CN_type_11_cn_coin_8036_p.jpg', 6), ('../data/raw/CN_dataset_04_23/data_types_example/11/CN_type_11_MK_18247626_cn_coin_8

In [94]:
mixup_args = dict(
    mixup_alpha=mixup,
    cutmix_alpha=cutmix,
    label_smoothing=smoothing,
    num_classes=num_classes,
)

In [95]:
# Create model using timm
model = timm.create_model(
    "resnet50d", pretrained=False, num_classes=num_classes, drop_path_rate=0.05
)

In [96]:
# Load data config associated with the model to use in data augmentation pipeline
data_config = timm.data.resolve_data_config({}, model=model, verbose=True)
data_mean = data_config["mean"]
data_std = data_config["std"]

In [97]:
train_dataset, eval_dataset = coin_data.generate_train_val_datasets(val_pct=0.3, image_size=image_size, data_mean=data_mean, data_std=data_std)

In [98]:
# Create optimizer
optimizer = create_optimizer_v2(
    model, opt="RMSprop", lr=lr, weight_decay=0.01
)

In [99]:
optimizer

RMSprop (
Parameter Group 0
    alpha: 0.9
    centered: False
    differentiable: False
    eps: 1e-08
    foreach: None
    lr: 0.005
    maximize: False
    momentum: 0.9
    weight_decay: 0.0

Parameter Group 1
    alpha: 0.9
    centered: False
    differentiable: False
    eps: 1e-08
    foreach: None
    lr: 0.005
    maximize: False
    momentum: 0.9
    weight_decay: 0.01
)

In [100]:
# As we are using Mixup, we can use BCE during training and CE for evaluation
train_loss_fn = BinaryCrossEntropy(
    target_threshold=bce_target_thresh, smoothing=smoothing
)
validate_loss_fn = torch.nn.CrossEntropyLoss()

In [101]:
# Create trainer and start training
trainer = TimmMixupTrainer(
    model=model,
    optimizer=optimizer,
    loss_func=train_loss_fn,
    eval_loss_fn=validate_loss_fn,
    mixup_args=mixup_args,
    num_classes=num_classes,
    callbacks=[
        *DEFAULT_CALLBACKS,
        SaveBestModelCallback(watch_metric="accuracy", greater_is_better=True),
    ],
)

In [102]:
trainer.train(
        per_device_batch_size=batch_size,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        num_epochs=num_epochs,
        create_scheduler_fn=trainer.create_scheduler,
    )


Starting training run

Starting epoch 1


100%|██████████| 7/7 [00:03<00:00,  2.19it/s]



train_loss_epoch: 0.3571896553039551


100%|██████████| 3/3 [00:01<00:00,  2.75it/s]



ema_model_accuracy: 0.0

accuracy: 0.17391304671764374

eval_loss_epoch: 4.99569034576416

Starting epoch 2


100%|██████████| 7/7 [00:02<00:00,  2.60it/s]



train_loss_epoch: 0.35857322812080383


100%|██████████| 3/3 [00:01<00:00,  2.81it/s]



ema_model_accuracy: 0.043478261679410934

accuracy: 0.043478261679410934

eval_loss_epoch: 1551.0035400390625

Starting epoch 3


100%|██████████| 7/7 [00:02<00:00,  2.60it/s]



train_loss_epoch: 0.6006650328636169


100%|██████████| 3/3 [00:01<00:00,  2.42it/s]



ema_model_accuracy: 0.08695652335882187

accuracy: 0.0

eval_loss_epoch: 5615.0390625

Starting epoch 4


100%|██████████| 7/7 [00:02<00:00,  2.58it/s]



train_loss_epoch: 0.07354910671710968


100%|██████████| 3/3 [00:01<00:00,  2.71it/s]



ema_model_accuracy: 0.17391304671764374

accuracy: 0.043478261679410934

eval_loss_epoch: 4.716425895690918

Starting epoch 5


100%|██████████| 7/7 [00:02<00:00,  2.52it/s]



train_loss_epoch: 0.07665299624204636


100%|██████████| 3/3 [00:01<00:00,  2.62it/s]



ema_model_accuracy: 0.17391304671764374

accuracy: 0.021739130839705467

eval_loss_epoch: 5.092090129852295

Starting epoch 6


100%|██████████| 7/7 [00:02<00:00,  2.42it/s]



train_loss_epoch: 0.08072541654109955


100%|██████████| 3/3 [00:01<00:00,  2.67it/s]



ema_model_accuracy: 0.17391304671764374

accuracy: 0.043478261679410934

eval_loss_epoch: 4.5323381423950195

Starting epoch 7


100%|██████████| 7/7 [00:02<00:00,  2.42it/s]



train_loss_epoch: 0.052062150090932846


100%|██████████| 3/3 [00:01<00:00,  2.59it/s]



ema_model_accuracy: 0.17391304671764374

accuracy: 0.17391304671764374

eval_loss_epoch: 4.613162517547607

Starting epoch 8


100%|██████████| 7/7 [00:02<00:00,  2.45it/s]



train_loss_epoch: 0.05658640339970589


100%|██████████| 3/3 [00:01<00:00,  2.60it/s]



ema_model_accuracy: 0.17391304671764374

accuracy: 0.17391304671764374

eval_loss_epoch: 4.565769195556641

Starting epoch 9


100%|██████████| 7/7 [00:02<00:00,  2.45it/s]



train_loss_epoch: 0.04543003439903259


100%|██████████| 3/3 [00:01<00:00,  2.59it/s]



ema_model_accuracy: 0.17391304671764374

accuracy: 0.17391304671764374

eval_loss_epoch: 4.5083489418029785

Starting epoch 10


100%|██████████| 7/7 [00:02<00:00,  2.41it/s]



train_loss_epoch: 0.049835123121738434


100%|██████████| 3/3 [00:01<00:00,  2.60it/s]



ema_model_accuracy: 0.17391304671764374

accuracy: 0.17391304671764374

eval_loss_epoch: 4.499942779541016
Finishing training run
Loading checkpoint with accuracy: 0.17391304671764374 from epoch 1


In [104]:
trainer.evaluate(dataset=eval_dataset) 


Starting evaluation run


100%|██████████| 6/6 [00:01<00:00,  3.59it/s]


ema_model_accuracy: 0.17391304671764374

accuracy: 0.17391304671764374

evaluation_loss: 4.995690822601318
Finishing evaluation run





In [109]:
checkpoint = torch.load(f="../training/best_model.pt", map_location=torch.device('cpu'))

In [111]:
checkpoint.keys()

dict_keys(['model_state_dict', 'optimizer_state_dict', 'accuracy'])