## Imports

In [None]:
%env CUDA_VISIBLE_DEVICES=1

In [None]:
import torch
from torch import optim
from ignite.handlers import Checkpoint

In [None]:
%run -n ../train_classification.py

In [None]:
%run ../models/classification/__init__.py
%run ../models/checkpoint/__init__.py
%run ../datasets/__init__.py
%run ../utils/__init__.py

In [None]:
# Device used by every later cell. Prefer the GPU, but fall back to the CPU
# when CUDA is not usable so the notebook still runs on a machine without one.
# To force the CPU, replace the right-hand side with torch.device('cpu').
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

## Load stuff

### Load datasets

In [None]:
# Dataset to load, plus the keyword arguments forwarded to
# prepare_data_classification for every split built in this notebook.
dataset_name = 'covid-uc'
dataset_kwargs = {
    'max_samples': None,
    'batch_size': 26,
    'image_size': (512, 512),
}
# Extra keyword arguments forwarded only when building the 'train' loader.
train_kwargs = {
    'augment': True,
    # 'augment_label': 'covid',
    'oversample': True,
    'oversample_label': 'covid',
    'oversample_max_ratio': 10,
}

# The 'val' loader receives only the shared kwargs (no train_kwargs).
train_dataloader = prepare_data_classification(dataset_name, 'train',
                                               **dataset_kwargs, **train_kwargs)
val_dataloader = prepare_data_classification(dataset_name, 'val', **dataset_kwargs)
# Display the length of the training dataset.
len(train_dataloader.dataset)

### Load pretrained model

In [None]:
# When True, the models built in the sections below are wrapped in
# nn.DataParallel.
multiple_gpu = True

In [None]:
# Earlier runs kept for reference; only the uncommented run_name is loaded.
# run_name = '0704_005511_covid-kaggle_tfs-small_lr1e-06'
# run_name = '0714_232500_cxr14_densenet-121_lr1e-06'
# run_name = '0714_232518_cxr14_densenet-121_lr1e-06'
# run_name = '0716_133211_cxr14_densenet-121_lr1e-06_us_aug-0_Pneumonia'
run_name = '0717_120222_covid-x_densenet-121_lr1e-06_os_aug-covid'
# NOTE(review): multiple_gpu is hard-coded to False here instead of using the
# `multiple_gpu` flag -- presumably so that `.model.base_cnn` can be read
# directly from the loaded model in the fine-tuning section; confirm.
compiled_model = load_compiled_model_classification(run_name,
                                                    debug=False,
                                                    multiple_gpu=False,
                                                    device=DEVICE)

In [None]:
# Display the loaded model's `metadata` attribute.
compiled_model.metadata

### Use pre-trained weights in a new model

In [None]:
# Keep a handle on the loaded model; `compiled_model` is rebound further down
# when the new model is created.
old_compiled_model = compiled_model

In [None]:
# Learning rate, CNN architecture name and run name for the fine-tuning run.
# The run name embeds a timestamp, the dataset name, the CNN name and the lr.
lr = 0.0001
cnn_name = 'densenet-121'
run_name = f'{get_timestamp()}_{dataset_name}_{cnn_name}_lr{lr}_os-max10_aug_pre-covid-x'
run_name

In [None]:
# Build a new model for the current dataset's labels, passing the previously
# loaded model's `base_cnn` as `pretrained_cnn` (imagenet=False, freeze=False),
# and move it to DEVICE.
model = init_empty_model(cnn_name,
                         train_dataloader.dataset.labels,
                         multilabel=train_dataloader.dataset.multilabel,
                         pretrained_cnn=old_compiled_model.model.base_cnn,
                         imagenet=False,
                         freeze=False,
                        ).to(DEVICE)

# NOTE(review): `nn` is not imported by this notebook's import cell;
# presumably one of the %run scripts brings it into the namespace -- confirm.
if multiple_gpu:
    model = nn.DataParallel(model)

# The optimizer is created after the optional DataParallel wrap, so it is
# built from the final `model` object's parameters.
optimizer = optim.Adam(model.parameters(), lr=lr)

# TODO: metadata!!
# Rebinds `compiled_model`: from here on it refers to the new model.
compiled_model = CompiledModel(model, optimizer)

# Drop the extra name bound to the previously loaded model.
del old_compiled_model

In [None]:
run_name

### ...or create model

In [None]:
# Alternative to the fine-tuning cells: build a fresh model with
# imagenet=True and no `pretrained_cnn`. Running this cell overwrites
# lr, cnn_name, run_name, model, optimizer and compiled_model.
lr = 0.000001
cnn_name = 'densenet-121'
run_name = f'{get_timestamp()}_{dataset_name}_{cnn_name}_lr{lr}'

model = init_empty_model(cnn_name,
                         train_dataloader.dataset.labels,
                         multilabel=train_dataloader.dataset.multilabel,
                         imagenet=True,
                         freeze=False,
                        ).to(DEVICE)

# Optionally wrap the model in DataParallel, as controlled by `multiple_gpu`.
if multiple_gpu:
    model = nn.DataParallel(model)

# Adam over the parameters of the (possibly wrapped) model.
optimizer = optim.Adam(model.parameters(), lr=lr)

compiled_model = CompiledModel(model, optimizer)

## Train

In [None]:
# Loss identifier, passed to both train_model and evaluate_and_save.
loss_name = 'cross-entropy'

In [None]:
%%time

# Alternative metric list kept for reference:
# print_metrics = ['loss', 'acc', 'hamming']
print_metrics = ['loss', 'acc', 'prec_covid', 'recall_covid']

# Train `compiled_model` for 5 epochs on the train/val loaders; the call
# returns a (train_metrics, val_metrics) pair.
train_metrics, val_metrics = train_model(run_name,
                                         compiled_model,
                                         train_dataloader,
                                         val_dataloader,
                                         n_epochs=5,
                                         loss_name=loss_name,
                                         print_metrics=print_metrics,
                                         debug=False,
                                         device=DEVICE,
                                        )

In [None]:
# NOTE(review): no earlier cell in this notebook assigns `output`; the only
# visible assignment is the clamped tensor in the "Test BCE loss" section.
# Presumably a leftover from capturing the training cell's output -- confirm
# or remove.
print(output)

In [None]:
# Build the loader for the 'test' split with the shared loader settings.
test_dataloader = prepare_data_classification(dataset_name, 'test', **dataset_kwargs)

In [None]:
# Loaders handed to evaluate_and_save below.
# NOTE(review): test_dataloader is created above but is not listed here, so
# only the train and val loaders are evaluated -- confirm this is intended.
dataloaders = [
    train_dataloader,
    val_dataloader,
]

In [None]:
# Passes the underlying model (compiled_model.model), not the CompiledModel
# wrapper, together with the same loss name used for training.
evaluate_and_save(run_name,
                  compiled_model.model,
                  dataloaders,
                  loss_name,
                  debug=False,
                  device=DEVICE)

## Debug stuff

### Test metrics

In [None]:
from ignite.metrics import Accuracy, Recall, Precision

In [None]:
%run ../metrics/classification/__init__.py
%run ../metrics/classification/specificity.py
%run ../metrics/classification/accuracy.py
%run ../metrics/classification/hamming.py

In [None]:
# Metric instances under test, configured with the project's output transform.
acc = MultilabelAccuracy(output_transform=_transform_remove_loss_and_round)
ham = Hamming(output_transform=_transform_remove_loss_and_round)

In [None]:
# Toy batch: 2 samples x 3 labels. The second row holds non-binary scores.
outputs = torch.tensor([[0, 1, 1],
                        [0.3, 0.7, 0.8],
                       ])
target = torch.tensor([[0, 0, 1],
                       [0, 1, 1],
                      ])

In [None]:
# update() is called directly, so the transform is applied by hand to a
# 3-tuple (presumably (loss, outputs, target), with 0 as a dummy loss --
# verify against _transform_remove_loss_and_round).
acc.reset()
acc.update(_transform_remove_loss_and_round((0, outputs, target)))
acc.compute()

In [None]:
# Same check for the Hamming metric.
ham.reset()
ham.update(_transform_remove_loss_and_round((0, outputs, target)))
ham.compute()

In [None]:
# Metric instances with default arguments: the project's Specificity plus
# ignite's Recall and Precision.
sp = Specificity()
rec = Recall()
prec = Precision()

In [None]:
# Transform obtained for class index 0 (presumably reduces a
# (loss, outputs, target) tuple to a per-class problem -- verify against
# _get_transform_one_class).
fn = _get_transform_one_class(0)

In [None]:
# outputs = torch.tensor([[1, 2, 1, 0, 0]])
# target = torch.tensor([[1, 0, 1, 1, 2]])
# Toy batch: 3 samples x 3 class scores, with integer class targets.
outputs = torch.tensor([[0, 20, -1],
                        [-40, 2, 3],
                        [17, 5, 6],
                       ])
target = torch.tensor([0, 0, 2])
# Rebinds outputs/target to the transformed pair, which is what the metric
# cells below consume.
outputs, target = fn((0, outputs, target))
outputs, target

In [None]:
sp.reset()
sp.update((outputs, target))
sp.compute()

In [None]:
rec.reset()
rec.update((outputs, target))
rec.compute().item()

In [None]:
prec.reset()
prec.update((outputs, target))
prec.compute().item()

### Test BCE loss

In [None]:
import torch
import numpy as np
from torch.nn.functional import binary_cross_entropy

In [None]:
%run ../losses/wbce.py

In [None]:
# Clamp bound used further down when the loss is recomputed by hand.
EPS = 1e-5

In [None]:
# Multilabel target: 3 samples x 6 labels, with 3 positive entries in total.
target = torch.tensor([[1, 0, 0, 0, 0, 0],
                       [0, 0, 1, 0, 1, 0],
                       [0, 0, 0, 0, 0, 0],
                      ])
bce = WeigthedBCELoss()

In [None]:
# Case 1: rows 1-2 equal the target; row 3 predicts all ones against an
# all-zero target row.
output_o = torch.tensor([[1, 0, 0, 0, 0, 0],
                         [0, 0, 1, 0, 1, 0],
                         [1, 1, 1, 1, 1, 1],
                        ]).float()
bce(output_o, target)

In [None]:
# Case 2: as case 1, but row 1 also misses its single positive label.
output_o = torch.tensor([[0, 0, 0, 0, 0, 0],
                         [0, 0, 1, 0, 1, 0],
                         [1, 1, 1, 1, 1, 1],
                        ]).float()
bce(output_o, target)

In [None]:
# Case 3: predictions equal the target everywhere.
output_o = torch.tensor([[1, 0, 0, 0, 0, 0],
                         [0, 0, 1, 0, 1, 0],
                         [0, 0, 0, 0, 0, 0],
                        ]).float()
bce(output_o, target)

In [None]:
# Case 4: only the positive label of row 1 is missed.
output_o = torch.tensor([[0, 0, 0, 0, 0, 0],
                         [0, 0, 1, 0, 1, 0],
                         [0, 0, 0, 0, 0, 0],
                        ]).float()
bce(output_o, target)

In [None]:
# Clamp the most recently assigned output_o into [EPS, 1 - EPS] so the
# logarithms taken below stay finite.
output = output_o.clamp(min=EPS, max=1-EPS)
output

In [None]:
# Entry counts over the whole target tensor: all elements, entries equal
# to 1, and the remainder.
total = np.prod(target.size())
positive = (target == 1).sum().item()
negative = total - positive
total, positive, negative

In [None]:
# Weight for positive (BP) and negative (BN) entries: total / class count.
BP = total / positive
BN = total / negative
BP, BN

In [None]:
target.size(), output.size()

In [None]:
# Positive-label term: target * log(output).
left = (target * torch.log(output))
left

In [None]:
# Negative-label term: (1 - target) * log(1 - output).
right = (1-target) * torch.log(1-output)
right

In [None]:
-(weights*(left + right)).sum()

In [None]:
weights = torch.zeros(target.size())
weights

In [None]:
weights[target == 0] = BN
weights[target == 1] = BP
weights

In [None]:
# Reference value: torch's binary_cross_entropy with the per-entry weights,
# summed, evaluated on the unclamped predictions (output_o).
binary_cross_entropy(output_o, target.float(), weight=weights, reduction='sum')

In [None]:
# Value from the project's WeigthedBCELoss on the same inputs, for comparison.
bce(output_o, target)

In [None]:
def calc_conv_output_size(input_size, padding, kernel_size, stride, dilation=1):
    """Return the output length of a convolution along one spatial dimension.

    Uses the output-shape formula documented for ``torch.nn.Conv2d``::

        floor((input_size + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1

    Args:
        input_size: Length of the input along the dimension.
        padding: Padding added to each side of the input.
        kernel_size: Kernel length along the dimension.
        stride: Step between consecutive kernel applications.
        dilation: Spacing between kernel elements. Defaults to 1 (a standard,
            non-dilated convolution); a default of 0 would drop the kernel
            term from the formula altogether.

    Returns:
        The output length. It is an ``int`` when all arguments are ints,
        because floor division is used.

    Raises:
        ZeroDivisionError: If ``stride`` is 0.
    """
    # Span covered by one (possibly dilated) kernel application.
    effective_kernel = dilation * (kernel_size - 1) + 1
    # Floor division: a partial window at the end yields no output element.
    return (input_size + 2 * padding - effective_kernel) // stride + 1

In [None]:
# Example layer: 3 input channels, 32 output channels, 8x8 kernel, stride 4.
# NOTE(review): `nn` is not imported by this notebook's import cells;
# presumably one of the %run scripts brings it into the namespace -- confirm.
conv = nn.Conv2d(3, 32, (8, 8), stride=4)