## Imports

In [None]:
%env CUDA_VISIBLE_DEVICES=1

In [None]:
import torch
from torch import nn

## Load model

In [None]:
%run ../models/checkpoint/__init__.py

In [None]:
# Run whose checkpoint is loaded by the next cell; swap in the commented
# line to load the other run instead.
# run_name = '1215_174443_cxr14_resnet-50-v2_lr0.0001_os_Cardiomegaly_normS_size256_sch-roc_auc-p5-f0.1'
run_name = '1215_222128_cxr14_mobilenet-v2_lr0.0001_os_Cardiomegaly_normS_size256_sch-roc_auc-p5-f0.1'
# Forwarded as `debug=` to load_compiled_model_classification.
debug_run = False

In [None]:
# Load the compiled model for `run_name`. The loader is presumably brought
# into scope by the `%run ../models/checkpoint/__init__.py` cell — verify.
compiled_model = load_compiled_model_classification(run_name, debug=debug_run)
# Cell output: the model kwargs recorded in the loaded metadata.
compiled_model.metadata['model_kwargs']

## Load data

In [None]:
%run ../datasets/__init__.py

In [None]:
# Arguments for prepare_data_classification. `masks` is enabled because
# `step_fn` reads `data_batch.masks`.
# NOTE(review): `run_name` above says size256 while this requests (128, 128)
# images — confirm the mismatch is intended.
dataset_kwargs = {
    'dataset_name': 'cxr14',
    'dataset_type': 'train',
    'max_samples': None,
    'labels': ['Cardiomegaly'],
    'masks': True,
    'image_size': (128, 128),
    'batch_size': 2,
}
dataloader = prepare_data_classification(**dataset_kwargs)
# Kept as a global: later cells read `dataset.labels` and `dataset.multilabel`.
dataset = dataloader.dataset
# Cell output: number of samples in the dataset.
len(dataset)

## Train with HINT

### Prepare model

In [None]:
from torchviz import make_dot

In [None]:
%run ../training/classification/grad_cam.py
%run ../datasets/common/utils.py
%run ../losses/out_of_target.py

In [None]:
# %run ../tensorboard/__init__.py
# run_name = 'hint'
# tb_writer = TBWriter(run_name, 'cls', debug=True)
# tb_writer.write_graph(model, data_batch.image.to(device))
# tb_writer.close()

In [None]:
class SimpleModel(nn.Module):
    """Tiny CNN classifier used to debug the training step.

    A single convolution, a global max-pool over the spatial dimensions and
    a linear head that emits `n_labels` values per sample.
    """

    def __init__(self, n_labels=14):
        super().__init__()

        # Submodules are created in the same order as before (conv, pool,
        # head) so parameter ordering and random initialisation are unchanged.
        self.features = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=2)
        self.global_pool = nn.Sequential(
            nn.AdaptiveMaxPool2d(output_size=(1, 1)),
            nn.Flatten(),
        )
        self.prediction = nn.Linear(in_features=10, out_features=n_labels)

    def forward(self, x):
        """Map an image batch (batch, 3, H, W) to outputs (batch, n_labels)."""
        feature_maps = self.features(x)
        pooled = self.global_pool(feature_maps)
        return self.prediction(pooled)

In [None]:
# Fall back to the CPU when no CUDA device is visible, so the debugging
# cells below still run on hosts without a usable GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# `step_fn` reads this flag to toggle train mode, grad tracking and the
# optimizer step.
training = True
# `step_fn` uses this to cast labels to float (multilabel) or long.
multilabel = dataset.multilabel
multilabel

In [None]:
# Classification loss applied to the raw model outputs in `step_fn`.
loss_fn = nn.BCEWithLogitsLoss()
# Intended HINT loss (presumably from ../losses/out_of_target.py — verify).
hint_loss_fn = OutOfTargetSumLoss()
# NOTE: the rebinding below replaces the loss created on the previous line;
# delete it to train with OutOfTargetSumLoss.
hint_loss_fn = lambda x, y: (x - y).mean() # Simple loss to debug

In [None]:
# Take the model and optimizer from the loaded checkpoint (third element is
# discarded) and build the attribution object for it.
# NOTE: the SimpleModel cell below rebinds `model`, `optimizer` and
# `grad_cam` when it is run.
model, optimizer, _ = compiled_model.get_elements()
grad_cam = create_grad_cam(compiled_model)
grad_cam

In [None]:
# Swap in the tiny debug model; this rebinds `model`, `optimizer` and
# `grad_cam` set by the checkpoint cell above.
model = SimpleModel(len(dataset.labels)).to(device)
# `optim` is never imported in this notebook, so reach Adam through the
# `torch` namespace, which is.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Attribute against the output of the model's single conv layer.
grad_cam = LayerGradCam(model, model.features)
model

### Run step_fn

To debug, copy the `step_fn` from `training/classification/__init__.py` into the cell below.

In [None]:
def step_fn(data_batch):
    """Run one classification + HINT step on `data_batch`.

    Relies on notebook globals: `model`, `optimizer`, `grad_cam`, `dataset`,
    `device`, `training`, `multilabel`, `loss_fn` and `hint_loss_fn`.

    Args:
        data_batch: batch exposing `image`, `labels` and `masks` tensors.

    Returns:
        Tuple `(cl_loss, hint_loss, total_loss, images)`, where `images` is
        the input batch after being moved to `device`.
    """
    images = data_batch.image.to(device)
    # shape: batch_size, channels=3, height, width

    labels = data_batch.labels.to(device)
    # shape(multilabel=True): batch_size, n_labels
    # shape(multilabel=False): batch_size

    masks = reduce_masks_for_diseases(dataset.labels, data_batch.masks.to(device))
    # shape: batch_size, n_labels, height, width

    # Enable training
    model.train(training)
    # NOTE: called as a function (not a context manager), so the grad mode
    # stays set after this step returns.
    torch.set_grad_enabled(training)

    # zero the parameter gradients
    if training:
        optimizer.zero_grad()

    # Forward
    outputs = model(images)
    if isinstance(outputs, tuple):
        # Keep only the first element of a tuple output. This previously read
        # `output_tuple[0]`, an undefined name that raised NameError.
        outputs = outputs[0]

    # outputs shape: batch_size, n_labels

    if multilabel:
        labels = labels.float()
    else:
        labels = labels.long()

    # Gradients w.r.t. the input are only enabled while the attributions
    # are computed.
    images.requires_grad = True

    grad_cam_attrs = calculate_attributions_for_labels(
        grad_cam, images, dataset.labels,
        relu=True, create_graph=True,
    )
    # shape: (batch_size, n_labels, height, width)

    images.requires_grad = False

    # Compute HINT loss
    hint_loss = hint_loss_fn(grad_cam_attrs, masks)

    # Compute classification loss
    cl_loss = loss_fn(outputs, labels)

    total_loss = cl_loss + hint_loss

    if training:
        total_loss.backward()
        optimizer.step()

    return cl_loss, hint_loss, total_loss, images

In [None]:
# Pull a single batch to drive `step_fn` by hand.
data_batch = next(iter(dataloader))
# Cell output: sizes of the image and mask tensors in the batch.
data_batch.image.size(), data_batch.masks.size()

In [None]:
# Run one step; the returned images are kept as `i` for the graph cells below.
cl_loss, hint_loss, total_loss, i = step_fn(data_batch)
# Cell output: the three losses as Python scalars.
cl_loss.item(), hint_loss.item(), total_loss.item()

In [None]:
# NOTE: `_print_grads` is defined in a later cell; run that cell first.
_print_grads(model)

In [None]:
_print_grads(model)

In [None]:
def _print_grads(model):
    """Print the mean gradient of each weight and bias of `model`.

    Covers `model.features` then `model.prediction`, weight before bias,
    one value per line.
    """
    for layer in (model.features, model.prediction):
        for tensor in (layer.weight, layer.bias):
            print(tensor.grad.mean())

In [None]:
# Name -> tensor mapping passed to torchviz as `params`.
params = dict(model.named_parameters())

In [None]:
# Also register the input batch returned by `step_fn`.
params['images'] = i

In [None]:
# Graph of the classification loss alone.
make_dot(cl_loss, params=params)

In [None]:
# Graph of the HINT loss alone.
make_dot(hint_loss, params=params)

In [None]:
# Graph of the summed loss that `step_fn` backpropagates.
make_dot(total_loss, params=params)