## Imports

In [1]:
%run -n train_classification.py

## Load data

In [2]:
%run datasets/__init__.py

In [3]:
# Small CXR-14 subset so the whole notebook can be re-run quickly.
dataset_name = 'cxr14'
max_samples = 100
BS = 15

# Keyword arguments shared by the train and val loaders.
loader_kwargs = dict(max_samples=max_samples, batch_size=BS)

train_dataloader = prepare_data_classification(dataset_name, dataset_type='train', **loader_kwargs)
val_dataloader = prepare_data_classification(dataset_name, dataset_type='val', **loader_kwargs)

# Size of the training subset (presumably (n_samples, n_labels)).
train_dataloader.dataset.size()

Loading train dataset...
Loading val dataset...


(100, 14)

### Load CXR pretrained

In [4]:
%run models/classification/__init__.py

In [16]:
# Rebuild the network used in the CXR-14 pre-training run and restore its
# weights from that run's latest checkpoint.
# In this cell the tiny loader (10 samples) is only used to read the label set
# and the multilabel flag needed to build a matching model.
cxr_dataloader = prepare_data_classification('cxr14', dataset_type='train',
                                             max_samples=10)
run_name = 'cxr_pretrain'
debug_run = True

pretrained_cnn = init_empty_model('resnet',
                                  cxr_dataloader.dataset.labels,
                                  multilabel=cxr_dataloader.dataset.multilabel,
                                 ).to(DEVICE)

# Never stepped here; presumably needed because the checkpoint objects from
# `to_save_checkpoint()` include optimizer state. It must wrap
# `pretrained_cnn`'s own parameters: `model` is only created in a later cell,
# so referencing it here fails on a fresh kernel.
dummy_optimizer = optim.Adam(pretrained_cnn.parameters(), lr=0.0001)

compiled_model = CompiledModel(pretrained_cnn, dummy_optimizer)
filepath = get_latest_filepath(run_name, classification=True, debug=debug_run)
# map_location lets a checkpoint saved on one device (e.g. GPU) be restored
# on whatever DEVICE this session uses.
checkpoint = torch.load(filepath, map_location=DEVICE)
Checkpoint.load_objects(compiled_model.to_save_checkpoint(), checkpoint)

Loading train dataset...


### Create COVID model

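To transfer weights from the CXR-14 model loaded above, pass its backbone (`pretrained_cnn.base_cnn`) as the `pretrained_cnn` argument of `init_empty_model`.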

In [13]:
# Learning rate; it is also embedded in the run name to tell runs apart.
lr = 0.000001
run_name = f'cxr14_lr{lr}'

# Set to True to pass the CXR-14 backbone loaded above as `pretrained_cnn`.
# When False the argument is not passed at all, and `pretrained_cnn` is not
# referenced, so this cell also works if the pretrained cell was skipped.
use_cxr_pretrained = False
transfer_kwargs = ({'pretrained_cnn': pretrained_cnn.base_cnn}
                   if use_cxr_pretrained else {})

model = init_empty_model('resnet',
                         train_dataloader.dataset.labels,
                         multilabel=train_dataloader.dataset.multilabel,
                         imagenet=True,
                         freeze=False,
                         **transfer_kwargs
                        ).to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr=lr)

compiled_model = CompiledModel(model, optimizer)

## Train

In [None]:
%%time

train_model(run_name, compiled_model, train_dataloader, val_dataloader, n_epochs=40,
            loss_name='wbce',
            # print_metrics=['loss', 'acc', 'spec_covid', 'recall_covid'],
            print_metrics=['loss', 'prec_Cardiomegaly', 'recall_Cardiomegaly'],
           )

--------------------------------------------------
Training...


In [17]:
# Which dataset class backs the training loader.
type(train_dataloader.dataset)

mrg.datasets.cxr14.CXR14Dataset

In [59]:
# Inspect the backbone of the model currently held by `compiled_model`.
backbone = compiled_model.model.base_cnn
backbone

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [20]:
# Inspect the backbone of the model currently held by `compiled_model`.
# NOTE(review): identical to the earlier `compiled_model.model.base_cnn` cell;
# one of the two is redundant.
compiled_model.model.base_cnn

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

## Debugging

### Test metrics

In [95]:
from ignite.metrics import Accuracy, Recall, Precision

In [143]:
%run ./metrics/classification/__init__.py
%run ./metrics/classification/specificity.py
%run ./metrics/classification/accuracy.py
%run ./metrics/classification/hamming.py

In [144]:
# Multilabel metrics under test. Both share the same output transform (by its
# name, presumably it drops the loss and rounds predictions). The cells below
# call `update` directly, so they apply the transform by hand.
shared_transform = _transform_remove_loss_and_round
acc = MultilabelAccuracy(output_transform=shared_transform)
ham = Hamming(output_transform=shared_transform)

In [153]:
# Toy batch of 2 samples x 3 labels: `outputs` holds score-like values in
# [0, 1], `target` holds the 0/1 ground truth.
output_rows = [[0, 1, 1], [0.3, 0.7, 0.8]]
target_rows = [[0, 0, 1], [0, 1, 1]]
outputs = torch.tensor(output_rows)
target = torch.tensor(target_rows)

In [154]:
# `update` is called directly, so the output transform is applied here by hand.
acc.reset()
acc_batch = _transform_remove_loss_and_round((0, outputs, target))
acc.update(acc_batch)
acc.compute()

0.6666666666666666

In [155]:
# `update` is called directly, so the output transform is applied here by hand.
ham.reset()
ham_batch = _transform_remove_loss_and_round((0, outputs, target))
ham.update(ham_batch)
ham.compute()

0.16666666666666666

In [None]:
# Metrics evaluated below on the output of the single-class transform.
sp, rec, prec = Specificity(), Recall(), Precision()

In [None]:
# Transform mapping (loss, outputs, target) to (outputs, target) restricted to
# a single class; by its name, presumably a one-vs-rest view of class 0.
target_class = 0
fn = _get_transform_one_class(target_class)

In [None]:
# Toy multiclass batch: 3 samples x 3 class scores (presumably logits) with
# integer class targets. An earlier variant used
# outputs [[1, 2, 1, 0, 0]] and target [[1, 0, 1, 1, 2]].
class_scores = torch.tensor([
    [0, 20, -1],
    [-40, 2, 3],
    [17, 5, 6],
])
class_targets = torch.tensor([0, 0, 2])

# Reduce to the single-class view used by the binary metrics below.
outputs, target = fn((0, class_scores, class_targets))
outputs, target

In [None]:
sp.reset()
sp_batch = (outputs, target)
sp.update(sp_batch)
sp.compute()

In [None]:
rec.reset()
rec_batch = (outputs, target)
rec.update(rec_batch)
recall_value = rec.compute()
recall_value.item()

In [None]:
prec.reset()
prec_batch = (outputs, target)
prec.update(prec_batch)
precision_value = prec.compute()
precision_value.item()