# Imports

In [None]:
%env CUDA_VISIBLE_DEVICES=0

In [None]:
import torch
from torch import nn
from torchvision import models

In [None]:
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['figure.facecolor'] = 'white'
matplotlib.rcParams['figure.figsize'] = (15, 5)

In [None]:
import pandas as pd
pd.options.display.max_columns = None

In [None]:
%run ../utils/__init__.py
config_logging(logging.INFO)

# Debug model

## Feature extractor

In [None]:
# Instantiate VisualFeatureExtractor (defined outside the visible cells) with
# the 'resnet152' option and pretrained=True. The trailing bare name makes the
# notebook display the object's repr.
extractor = VisualFeatureExtractor('resnet152', pretrained=True)
extractor

In [None]:
# Batch dimension used for the dummy tensors built in the following cells.
BATCH_SIZE = 7

In [None]:
# Random stand-in input of shape (BATCH_SIZE, 3, 224, 224); only used to
# exercise the modules, not real image data.
images = torch.rand(BATCH_SIZE, 3, 224, 224)
images.size()

In [None]:
# The extractor call returns two tensors; their sizes are displayed as a
# quick shape check.
local_features, global_features = extractor(images)
local_features.size(), global_features.size()

## MLC

In [None]:
# Build the MLC module (defined outside the visible cells), sizing its input
# from the extractor's `out_features` attribute.
mlc = MLC(fc_in_features=extractor.out_features)
mlc

In [None]:
# MLC is fed the global features and returns two tensors; display their sizes.
tags, semantic_features = mlc(global_features)
tags.size(), semantic_features.size()

## Co-attention

In [None]:
# Hyperparameters shared by the co-attention and sentence-model cells below.
embed_size = 512
hidden_size = 512
k = 10
bn_momentum = 0.1

In [None]:
# Build the CoAttention module (defined outside the visible cells) using the
# hyperparameters above and the extractor's `out_features`.
coatt = CoAttention(version='v4',
                    embed_size=embed_size,
                    hidden_size=hidden_size,
                    visual_size=extractor.out_features,
                    k=k,
                    momentum=bn_momentum,
                   )

In [None]:
# All-zero tensor of shape (BATCH_SIZE, 1, hidden_size) used as the initial
# "previous hidden state" input.
prev_hidden_states = torch.zeros(BATCH_SIZE, 1, hidden_size)
prev_hidden_states.size()

In [None]:
# The co-attention call returns three tensors; only `ctx` is used by later
# cells, the other two are just inspected for their sizes here.
ctx, other_a, other_b = coatt(global_features, semantic_features, prev_hidden_states)
ctx.size(), other_a.size(), other_b.size()

## Sentence model

In [None]:
# Number of layers passed to SentenceLSTM below.
sentence_num_layers = 2

In [None]:
# Build the SentenceLSTM module (defined outside the visible cells), reusing
# embed_size / hidden_size / bn_momentum from the co-attention section.
sentence_model = SentenceLSTM(
    version='v1', embed_size=embed_size, hidden_size=hidden_size,
    num_layers=sentence_num_layers, dropout=0,
    momentum=bn_momentum)

In [None]:
# No recurrent state yet: None is passed on the first call.
sentence_states = None

In [None]:
# One sentence-model step. `sentence_states` is rebound to the returned state,
# so re-running this cell feeds the previous output state back in.
topic, p_stop, hidden_states, sentence_states = sentence_model(ctx,
                                                               prev_hidden_states,
                                                               sentence_states)

In [None]:
# Shape check of the four outputs; `sentence_states` is indexed, so it is a
# sequence whose first element is a tensor.
topic.size(), p_stop.size(), hidden_states.size(), sentence_states[0].size()

## Word model

In [None]:
# Execute the project's lstm.py in the notebook namespace; presumably this is
# where WordLSTM comes from -- TODO confirm.
%run ../models/report_generation/coatt/lstm.py

In [None]:
# Dummy vocabulary size for this section only.
vocab_size = 1000

In [None]:
word_model = WordLSTM(vocab_size=vocab_size,
                      embed_size=512,
                      hidden_size=512,
                      num_layers=1,
                      n_max=30)

In [None]:
# NOTE: this rebinds the notebook-wide BATCH_SIZE (set to 7 earlier) to 2 and
# overwrites `topic` from the sentence-model section with random data.
BATCH_SIZE = 2
topic = torch.rand(BATCH_SIZE, 1, 512)
# Random integer ids in [0, vocab_size - 2], shape (BATCH_SIZE, 3, 19).
captions = (torch.rand(BATCH_SIZE, 3, 19) * (vocab_size - 1)).long()

In [None]:
# Train-mode call: feed the first token of the first sentence of each sample.
word_model.train()
words = word_model(topic, captions[:, 0, :1])
words.size()

In [None]:
# Index of the maximum along the last dimension of the output.
words.argmax(-1)

In [None]:
# Eval-mode call with no caption input (None is passed in its place).
word_model.eval()
words = word_model(topic, None)
words.size()

In [None]:
words

## Losses

In [None]:
label = (torch.rand(BATCH_SIZE, 156) > 0.5)

In [None]:
mse_criterion = nn.MSELoss()

In [None]:
batch_tag_loss = mse_criterion(tags, label).sum()
batch_tag_loss

In [None]:
prob_real = (torch.rand(BATCH_SIZE, ) > 0.5).long()

In [None]:
batch_stop_loss = words_criterion(p_stop.squeeze(), prob_real).sum()
batch_stop_loss

## Full model

In [None]:
%run ../models/report_generation/coatt/__init__.py

In [None]:
VOCAB_SIZE = 1234

In [None]:
model = CoAttModel(range(VOCAB_SIZE))
# model

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
BATCH_SIZE = 7
N_SENTENCES = 3
N_WORDS = 5
images = torch.rand(BATCH_SIZE, 3, 224, 224)
labels = (torch.rand(BATCH_SIZE, 14) > 0.5).long()
captions = (torch.rand(BATCH_SIZE, N_SENTENCES, N_WORDS) * (VOCAB_SIZE - 1)).long()
prob_real = torch.rand(BATCH_SIZE, N_SENTENCES).long()

In [None]:
out = model(images, captions, prob_real)
len(out)

In [None]:
words, tags, l1, l2 = out
words.size()

In [None]:
_, l1, l2 = out
l1.type(), l2.type()

In [None]:
prob_real.size()

In [None]:
def get_step_fn(model, optimizer=None, training=True,
                lambda_tag=1,
                lambda_stop=1,
                lambda_word=1,
                clip=0,
               ):
    """Build a per-batch step function for training or evaluating `model`.

    The returned callable takes `(engine, batch)`; the engine argument is
    ignored.

    Args:
        model: Callable as `model(images, reports, gt_stops)` returning
            `(words, tags, stop_loss, word_loss)`. Must expose
            `sentence_model` and `word_model` sub-modules when `clip > 0`.
        optimizer: Optimizer updating `model`; required when `training`.
        training: If True, backpropagate and take an optimizer step.
        lambda_tag: Weight of the tag (MSE) loss in the total loss.
        lambda_stop: Weight of the stop loss returned by the model.
        lambda_word: Weight of the word loss returned by the model.
        clip: Max gradient norm for the sentence and word models; values
            <= 0 disable clipping.

    Returns:
        A function `step_fn(engine, batch)` where `batch` is a 5-tuple
        `(images, _, labels, reports, gt_stops)`; it returns
        `{'loss': <detached total loss>}`.

    Raises:
        ValueError: If `training` is True and no optimizer is given.
    """
    # Fail fast: previously this only surfaced as an AttributeError on
    # `optimizer.step()` after a full forward/backward pass.
    if training and optimizer is None:
        raise ValueError('An optimizer is required when training=True')

    mse_criterion = nn.MSELoss()

    def step_fn(unused_engine, batch):
        images, _, labels, reports, gt_stops = batch

        # Forward pass; the generated words are not needed for the loss.
        _, tags, batch_stop_loss, batch_word_loss = model(images, reports, gt_stops)

        # Tags loss (MSELoss already reduces to a scalar mean).
        batch_tag_loss = mse_criterion(tags, labels.float())

        # Total loss: weighted sum of the three terms.
        batch_loss = (lambda_tag * batch_tag_loss
                      + lambda_stop * batch_stop_loss
                      + lambda_word * batch_word_loss)

        if optimizer is not None:
            optimizer.zero_grad()

        if training:
            batch_loss.backward()

            if clip > 0:
                # In-place variant; the un-suffixed clip_grad_norm is deprecated.
                torch.nn.utils.clip_grad_norm_(model.sentence_model.parameters(), clip)
                torch.nn.utils.clip_grad_norm_(model.word_model.parameters(), clip)

            optimizer.step()

        return {
            'loss': batch_loss.detach(),
        }
    return step_fn

In [None]:
# Build a training step bound to the full model and its Adam optimizer, using
# whichever get_step_fn definition was executed most recently.
step_fn = get_step_fn(model, optimizer)

In [None]:
def get_step_fn(model, optimizer=None, training=True,
                lambda_tag=1,
                lambda_stop=1,
                lambda_word=1,
                clip=0,
               ):
    """Build a per-batch step function for training or evaluating `model`.

    This variant calls the model as `model(images, labels, reports, gt_stops)`
    and expects three outputs `(tags, stop_loss, word_loss)`.
    NOTE(review): the other `get_step_fn` in this notebook uses a 3-argument /
    4-output model interface -- confirm which one matches the current model.

    Args:
        model: Callable with the interface described above. Must expose
            `sentence_model` and `word_model` sub-modules when `clip > 0`.
        optimizer: Optimizer updating `model`; required when `training`.
        training: If True, backpropagate and take an optimizer step.
        lambda_tag: Weight of the tag (MSE) loss in the total loss.
        lambda_stop: Weight of the stop loss returned by the model.
        lambda_word: Weight of the word loss returned by the model.
        clip: Max gradient norm for the sentence and word models; values
            <= 0 disable clipping.

    Returns:
        A function `step_fn(engine, batch)` where `batch` is a 5-tuple
        `(images, _, labels, reports, gt_stops)`; it returns
        `{'loss': <detached total loss>}`.

    Raises:
        ValueError: If `training` is True and no optimizer is given.
    """
    # Fail fast: previously this only surfaced as an AttributeError on
    # `optimizer.step()` after a full forward/backward pass.
    if training and optimizer is None:
        raise ValueError('An optimizer is required when training=True')

    mse_criterion = nn.MSELoss()

    def step_fn(unused_engine, batch):
        images, _, labels, reports, gt_stops = batch

        # Forward pass. Use the tensors from `batch`: the previous code read
        # the notebook globals `label` and `captions`, silently ignoring the
        # batch that was actually passed in.
        tags, batch_stop_loss, batch_word_loss = model(images, labels, reports, gt_stops)

        # Tags loss (MSELoss already reduces to a scalar mean).
        batch_tag_loss = mse_criterion(tags, labels.float())

        # Total loss: weighted sum of the three terms.
        batch_loss = (lambda_tag * batch_tag_loss
                      + lambda_stop * batch_stop_loss
                      + lambda_word * batch_word_loss)

        if optimizer is not None:
            optimizer.zero_grad()

        if training:
            batch_loss.backward()

            if clip > 0:
                # In-place variant; the un-suffixed clip_grad_norm is deprecated.
                torch.nn.utils.clip_grad_norm_(model.sentence_model.parameters(), clip)
                torch.nn.utils.clip_grad_norm_(model.word_model.parameters(), clip)

            optimizer.step()

        return {
            'loss': batch_loss.detach(),
        }
    return step_fn

In [None]:
# Run one step on the dummy batch; the engine argument and the second batch
# slot are unused, so None is passed for both.
# NOTE(review): `step_fn` is the closure created earlier; redefining
# get_step_fn does not rebind it, so re-create step_fn to test a new definition.
step_fn(None, (images, None, labels, captions, prob_real))

## Their dataloader

In [None]:
# Execute the third-party dataset.py in the notebook namespace. With -n,
# __name__ is not set to '__main__', so the script's main block is skipped.
# Presumably this provides `transforms`, `pickle` and `get_loader` used
# below -- TODO confirm.
%run -n ../../../software/Medical-Report-Generation/utils/dataset.py

In [None]:
# Image preprocessing pipeline: resize, random 224 crop, random horizontal
# flip, tensor conversion and per-channel normalisation. The mean/std values
# match the commonly used ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225))])

In [None]:
# Machine-specific absolute path to the third-party project's data folder.
new_data_dir = '/home/pdpino/software/Medical-Report-Generation/data/new_data'

In [None]:
# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only load vocab.pkl from a trusted source.
with open(f'{new_data_dir}/vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)

In [None]:
# Build the third-party loader over the training file list, in batches of 10
# and without shuffling.
data_loader = get_loader(
    image_dir='/mnt/workspace/iu-x-ray/dataset/images',
    caption_json=f'{new_data_dir}/captions.json',
    file_list=f'{new_data_dir}/train_data.txt',
    vocabulary=vocab,
    transform=transform,
    batch_size=10,
    shuffle=False,
)

In [None]:
# Explicit iterator so batches can be pulled one at a time with next().
dataloader = iter(data_loader)

In [None]:
batch = next(dataloader)

In [None]:
# A batch is a 5-tuple. `captions` and `stops` arrive as NumPy arrays and are
# converted to tensors here; this overwrites the notebook-level `images`,
# `labels` and `captions` from earlier sections.
images, ids, labels, captions, stops = batch
captions = torch.from_numpy(captions)
stops = torch.from_numpy(stops)
images.size(), len(ids), labels.size(), captions.size(), stops.size()

In [None]:
stops

In [None]:
# Displays a long-typed copy; `captions` itself is not modified.
captions.long()

# Load model

In [None]:
# Execute the project's checkpoint and file utilities in the notebook
# namespace; presumably these provide load_compiled_model and RunId -- TODO
# confirm.
%run ../models/checkpoint/__init__.py
%run ../utils/files.py

In [None]:
# Identifier of the saved run to load (hard-coded run name, non-debug, 'cls'
# task).
run_id = RunId('1215_174443', debug=False, task='cls')

In [None]:
# Load the saved model and display the 'model_kwargs' entry of its metadata.
compiled_model = load_compiled_model(run_id)
compiled_model.metadata['model_kwargs']

# Load data

In [None]:
# Execute the project's datasets package init in the notebook namespace;
# presumably this provides prepare_data_classification -- TODO confirm.
%run ../datasets/__init__.py

In [None]:
# Keyword arguments forwarded to prepare_data_classification.
dataset_kwargs = {
    'dataset_name': 'cxr14',
    'dataset_type': 'train',
    'max_samples': None,
}
# NOTE: rebinds `dataloader`, which previously held the third-party iterator.
dataloader = prepare_data_classification(**dataset_kwargs)
dataset = dataloader.dataset
len(dataset)

# Do something ...