In [None]:
import torch
from torch import optim
from ignite.engine import Engine
from ignite.handlers import Checkpoint

In [None]:
# Show which physical GPUs this kernel is allowed to see. CUDA device ordinals
# (such as the index passed to torch.device below) are relative to this list.
!echo $CUDA_VISIBLE_DEVICES

In [None]:
# GPU ordinal to run on. Ordinals index into the devices listed in
# CUDA_VISIBLE_DEVICES (printed above), not necessarily physical GPU IDs.
GPU_INDEX = 1
DEVICE = torch.device('cuda', GPU_INDEX)
DEVICE

### Load data

In [None]:
# Execute the datasets package inside this kernel's namespace. Presumably this
# is what defines prepare_data_classification used in the next cell — confirm.
%run ../datasets/__init__.py

In [None]:
# Build one dataloader per split, plus an oversampled variant of the train split.
dataset_name = 'covid-kaggle'
kwargs = dict(max_samples=None, batch_size=10)

train_dataloader = prepare_data_classification(dataset_name, 'train', **kwargs)
# Same train split; presumably oversamples examples labelled 'covid' — confirm.
train_dataloader_os = prepare_data_classification(
    dataset_name,
    'train',
    oversample=True,
    oversample_label='covid',
    **kwargs,
)
val_dataloader, test_dataloader = (
    prepare_data_classification(dataset_name, split, **kwargs)
    for split in ('val', 'test')
)

train_dataloader.dataset.size()

### Load model

In [None]:
# DEVICE is already set in the configuration cell near the top of the notebook.
# Re-display it here instead of redefining it, so the two definitions cannot
# silently drift apart.
DEVICE

In [None]:
# Load the model and checkpoint helper scripts into the notebook namespace.
# Presumably these provide load_compiled_model_classification used below — confirm.
%run ../models/classification/__init__.py
%run ../models/checkpoint/__init__.py

In [None]:
# Identifier of the trained run whose checkpoint is evaluated below.
run_name = '0704_005511_covid-kaggle_tfs-small_lr1e-06'
debug_run = True  # forwarded as the loader's `debug` flag

compiled_model = load_compiled_model_classification(
    run_name,
    debug=debug_run,
    device=DEVICE,
)

In [None]:
# Display the wrapped model object for a quick visual inspection.
compiled_model.model

### Evaluate on train / val / test splits

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Load loss definitions into the namespace. Nothing here is referenced by name
# in the visible cells; presumably evaluate_model relies on it — confirm.
%run ../losses/__init__.py

In [None]:
# Load classification metric helpers (presumably used inside evaluate_model — confirm).
%run ../metrics/classification/__init__.py

In [None]:
# `-n` runs the script without setting __name__ to '__main__', so any
# `if __name__ == '__main__':` training entry point is skipped and only the
# module-level definitions are loaded. Presumably provides evaluate_model — confirm.
%run -n ../train_classification.py

In [None]:
# Load confusion-matrix utilities; presumably defines plot_cm used in the plotting cell.
%run ../utils/cm.py

In [None]:
# Evaluate the loaded model on every split, in train -> val -> test order.
# NOTE(review): `loss_name` is not defined in any visible cell; it presumably
# comes from one of the %run scripts above — confirm before relying on it.
train_metrics, val_metrics, test_metrics = [
    evaluate_model(compiled_model.model, loader, loss_name=loss_name)
    for loader in (train_dataloader, val_dataloader, test_dataloader)
]

In [None]:
# Display the test-split metrics returned by evaluate_model (its 'cm' entry is
# plotted below).
test_metrics

In [None]:
# Side-by-side confusion matrices for the train, val and test splits.
# NOTE(review): `labels` is not defined in any visible cell; it presumably comes
# from one of the %run scripts above — confirm it matches the model's classes.
cm_panels = [
    ('train', train_metrics),
    ('val', val_metrics),
    ('test', test_metrics),
]

plt.figure(figsize=(6 * len(cm_panels), 6))
for panel_idx, (split_title, split_metrics) in enumerate(cm_panels, start=1):
    # plot_cm presumably draws on the current axes, so select the panel first.
    plt.subplot(1, len(cm_panels), panel_idx)
    plot_cm(split_metrics['cm'], labels, split_title)

# tight_layout's parameters are keyword-only. The old positional call
# `tight_layout(1)` was deprecated and is rejected by recent Matplotlib releases.
plt.tight_layout(pad=1)