# Imports

In [None]:
# Expose only GPU 0 to this kernel. CUDA_VISIBLE_DEVICES only takes effect if it
# is set before CUDA is initialized, which is why it precedes the torch import.
%env CUDA_VISIBLE_DEVICES=0

In [None]:
import math
from collections import defaultdict, Counter

In [None]:
import torch
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [None]:
import matplotlib
# Notebook-wide figure defaults: white background and a wide 15x5-inch canvas
matplotlib.rcParams.update({
    'figure.facecolor': 'white',
    'figure.figsize': (15, 5),
})

In [None]:
import pandas as pd
# Show every DataFrame column instead of truncating wide tables
pd.set_option('display.max_columns', None)

In [None]:
%run ../utils/__init__.py
# NOTE(review): `logging` is not imported in any visible cell; presumably it (and
# config_logging) is brought into the namespace by the `%run ../utils/__init__.py`
# line above — verify.
config_logging(logging.INFO)

In [None]:
%run ../datasets/__init__.py
%run ../utils/nlp.py

# Utils

In [None]:
%run ../models/checkpoint/__init__.py
%run ../utils/files.py

In [None]:
# Device that all model inputs are moved to in the cells below. Which physical
# GPU 'cuda' maps to is restricted by the CUDA_VISIBLE_DEVICES setting at the top.
DEVICE = torch.device('cuda')

In [None]:
def get_vocab(metadata):
    """Return the vocabulary stored in a run's metadata.

    The vocab is searched for, in order, at the top level of ``metadata``,
    then inside its ``'dataset_kwargs'`` and ``'decoder_kwargs'`` sections;
    the first hit is returned. A missing section is treated as empty rather
    than raising KeyError, so later locations are still checked.

    Args:
        metadata: dict of run metadata (e.g. ``compiled_model.metadata``).

    Returns:
        The stored vocab object.

    Raises:
        Exception: if no vocab is found in any of the searched locations.
    """
    if 'vocab' in metadata:
        return metadata['vocab']
    for section in ('dataset_kwargs', 'decoder_kwargs'):
        # .get() so a run lacking one section still falls through to the next
        section_kwargs = metadata.get(section) or {}
        if 'vocab' in section_kwargs:
            return section_kwargs['vocab']
    raise Exception('Vocab not found in metadata')

In [None]:
def load_stuff_wrapper(run_name):
    """Load a compiled report-generation model and attach a ReportReader to it.

    Args:
        run_name: run identifier string (e.g. '0524_002837'), resolved as a
            non-debug 'rg' run.

    Returns:
        The compiled model, with its ``model`` switched to eval mode and an
        extra ``reader`` attribute holding a ReportReader built from the
        vocab found in the run's metadata.
    """
    compiled_model = load_compiled_model(RunId(run_name, debug=False, task='rg'))
    compiled_model.model.eval()

    # HACK: hang the reader on the compiled model so both travel together
    compiled_model.reader = ReportReader(get_vocab(compiled_model.metadata))
    return compiled_model

In [None]:
def is_hierarchical(metadata):
    """Return True when the run's decoder name contains the 'h-' marker."""
    decoder_name = metadata['decoder_kwargs']['decoder_name']
    return decoder_name.find('h-') != -1

In [None]:
def plot_atts_mean_and_std(att_dict, keys=None):
    """Plot the per-pixel mean and standard deviation of stacked attention maps.

    Draws one figure row per plotted key: the left panel is the mean over
    samples and the right panel the STD over samples.

    Args:
        att_dict: dict mapping a label to a tensor whose first dimension
            indexes samples; each reduced slice is drawn with ``imshow``
            (presumably (n_samples, H, W) — TODO confirm).
        keys: optional ordered list of labels to plot. Labels missing from
            ``att_dict`` are skipped. Defaults to all keys of ``att_dict``.
    """
    if keys is None:
        keys = list(att_dict.keys())
    else:
        keys = [key for key in keys if key in att_dict]

    if not keys:
        # Nothing to draw; avoid creating an empty, zero-height figure
        return

    # Size the grid by the keys actually plotted (not by att_dict) so that
    # filtering through ``keys`` does not leave blank rows.
    n_rows = len(keys)
    n_cols = 2
    plt.figure(figsize=(n_cols*7, n_rows*5))

    for i_key, key in enumerate(keys):
        att = att_dict[key]
        n_samples = att.size(0)

        panels = (
            ('mean', att.mean(dim=0)),
            ('STD', att.std(dim=0)),
        )
        for i_panel, (stat_name, heatmap) in enumerate(panels):
            plt.subplot(n_rows, n_cols, i_key * n_cols + i_panel + 1)
            plt.title(f'{key} ({stat_name}, samples={n_samples:,})')
            plt.imshow(heatmap.cpu().numpy())
            plt.axis('off')
            plt.colorbar()

# Analyze word-attention

## Load model and data

In [None]:
%run ../training/report_generation/flat.py

In [None]:
# Load the flat (word-level) model analyzed in this section.
# NOTE(review): load_model_wrapper is not defined in any visible cell —
# presumably it comes from the `%run ../training/report_generation/flat.py`
# cell above; verify.
RUN_ID, COMPILED_MODEL = load_model_wrapper('0513_145846')
METADATA = COMPILED_MODEL.metadata
METADATA.keys()

In [None]:
# NOTE(review): get_vocab_and_reader is not defined in any visible cell (only
# get_vocab is) — presumably it comes from a %run'd script; verify.
HIERARCHICAL = is_hierarchical(METADATA)
VOCAB, REPORT_READER = get_vocab_and_reader(METADATA)
len(VOCAB), HIERARCHICAL

In [None]:
# Dataset settings for the flat model. The key must be 'image_size'
# (underscore, as in the hierarchical section below): a hyphenated key cannot
# bind to an `image_size` keyword parameter.
dataset_kwargs = {
    'hierarchical': HIERARCHICAL,
    'dataset_name': 'iu-x-ray',
    'image_size': (256, 256),
    'max_samples': None,
    'norm_by_sample': True,
    'frontal_only': True,
    'shuffle': True,
    'vocab': VOCAB,
}
train_dataloader = prepare_data_report_generation(dataset_type='train', **dataset_kwargs)
val_dataloader = prepare_data_report_generation(dataset_type='val', **dataset_kwargs)
len(train_dataloader.dataset), len(val_dataloader.dataset)

## Show single examples

In [None]:
def get_sample(dataloader, idx, free=True, colorbar=False):
    """Run the flat model on one dataset item and plot its word-level attention.

    Prints the ground-truth and generated reports. Then draws a 4-column grid
    whose first panel is the input image; each following panel is the
    attention map of one generated word, titled with that word.

    Args:
        dataloader: dataloader whose ``dataset[idx]`` has ``.image`` and ``.report``.
        idx: index into ``dataloader.dataset``.
        free: forwarded to the model's ``free`` argument.
        colorbar: if True, add a colorbar next to every attention panel.
    """
    item = dataloader.dataset[idx]
    images = item.image.unsqueeze(0).to(DEVICE)
    reports = torch.tensor(item.report).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        gen_words, gen_att = COMPILED_MODEL.model(images, reports, free=free, max_words=100)

    report_gt = clean_gt_reports(reports)[0]
    report_gen = _clean_gen_reports(gen_words)[0]

    print('GROUND TRUTH:')
    print(REPORT_READER.idx_to_text(report_gt))
    print('-'*50)
    print('GENERATED:')
    print(REPORT_READER.idx_to_text(report_gen))

    # expected shape (n_words + 1, 16, 16) — TODO confirm; lengths are not asserted
    att_maps = gen_att.squeeze(0).cpu().numpy()

    image_to_show = tensor_to_range01(item.image).permute(1, 2, 0)
    n_cols = 4
    n_rows = math.ceil((len(report_gen) + 1) / n_cols)

    plt.figure(figsize=(n_cols*7, n_rows*5))

    # Panel 1: the input image itself
    plt.subplot(n_rows, n_cols, 1)
    plt.imshow(image_to_show)
    plt.axis('off')

    # Panels 2..: one attention map per generated word
    for panel, word in enumerate(report_gen, start=2):
        plt.subplot(n_rows, n_cols, panel)
        plt.title(REPORT_READER.idx_to_text([word]), fontsize=24)
        plt.imshow(att_maps[panel - 2])
        if colorbar:
            plt.colorbar()
        plt.axis('off')

In [None]:
# Inspect validation item 100, with a colorbar on every word-attention panel
get_sample(val_dataloader, 100, colorbar=True)

## Distribution for many samples

In [None]:
def iterate_word_atts(dataloader, max_words=None):
    """Yield (word, attention_map) pairs for words generated over a dataloader.

    Runs COMPILED_MODEL in free-generation mode on every batch and pairs each
    generated word (decoded to text) with the attention map produced for it.

    Args:
        dataloader: yields batches with ``.images`` and ``.reports``.
        max_words: stop after yielding this many pairs; None iterates the
            whole dataloader, and a value <= 0 yields nothing.

    Yields:
        (word, att_map): the word text from REPORT_READER and the matching
        slice of the model's attention output (a tensor still on DEVICE).
    """
    if max_words is not None and max_words <= 0:
        return

    counter = 0
    # The context manager closes the progress bar even when we stop early or
    # the caller abandons the generator.
    with tqdm(dataloader) as pbar:
        for batch in pbar:
            images = batch.images.to(DEVICE)
            reports = batch.reports.to(DEVICE)

            with torch.no_grad():
                gen_words, gen_att = COMPILED_MODEL.model(images, reports, free=True, max_words=100)

            reports_gen = _clean_gen_reports(gen_words)

            # One (cleaned report, attention stack) pair per sample in the batch
            for report, att in zip(reports_gen, gen_att):
                for word_idx, att_map in zip(report, att):
                    yield REPORT_READER.idx_to_text([word_idx]), att_map

                    counter += 1
                    if max_words is not None and counter >= max_words:
                        return

### Sample words from the whole dataset

In [None]:
# Stack the attention maps of (up to) the first 10,000 generated words
all_att = torch.stack(
    [att_map for _, att_map in iterate_word_atts(val_dataloader, 10000)],
    dim=0,
)
all_att.size()

In [None]:
# Pixel-wise mean and STD over every sampled word attention map
all_att_mean = all_att.mean(dim=0).cpu().numpy()
all_att_std = all_att.std(dim=0).cpu().numpy()

for i_panel, (panel_title, stat_map) in enumerate(
    (('Mean', all_att_mean), ('STD', all_att_std)), start=1,
):
    plt.subplot(1, 2, i_panel)
    plt.title(panel_title)
    plt.imshow(stat_map)
    plt.colorbar()

### Group by organs

In [None]:
from collections import defaultdict

In [None]:
%run ../datasets/common/sentences2organs/compute.py

In [None]:
# Every vocabulary entry, in the vocab's iteration order
WORDS = [*VOCAB]
len(WORDS)

In [None]:
# Map every vocabulary word to its main organ. The second value returned by
# find_organs_for_sentences is stored under a non-stdlib name so this global
# does not shadow Python's `warnings` module in the notebook namespace.
_one_hot, _organ_warnings = find_organs_for_sentences(WORDS)
MAIN_ORGAN_BY_WORD = {
    word: get_main_organ(one_hot, word, _organ_warnings)
    for one_hot, word in zip(_one_hot, WORDS)
}
len(MAIN_ORGAN_BY_WORD)

In [None]:
# Group generated-word attention maps by the main organ of each word.
# iterate_word_atts also accepts a max_words cap (e.g. 10000) for a quicker pass.
all_att = defaultdict(list)

for word, att in iterate_word_atts(val_dataloader):
    organ = MAIN_ORGAN_BY_WORD[word]
    all_att[organ].append(att)

# Stack each organ's list of maps into one tensor (first dim = sample)
all_att = dict(
    (organ_name, torch.stack(maps, dim=0))
    for organ_name, maps in all_att.items()
)
all_att['heart'].size()

In [None]:
# Per-organ mean/STD of word attention maps, in MAIN_ORGANS order. The shared
# helper skips organs that received no samples; indexing `all_att` (a plain
# dict at this point) directly with every MAIN_ORGANS entry raises KeyError.
plot_atts_mean_and_std(all_att, keys=MAIN_ORGANS)

### Group by relevant words

In [None]:
from collections import defaultdict

In [None]:
# Words whose attention maps are collected and plotted below (order matters:
# the plotting cell takes the first few)
selected_words = (
    'lungs lung heart thorax cardiomegaly airspace '
    'right left bilateral bibasilar'
).split()

In [None]:
# Collect attention maps for the selected words only
all_att = defaultdict(list)

for word, att in iterate_word_atts(val_dataloader):
    if word in selected_words:
        all_att[word].append(att)

all_att = {
    k: torch.stack(a, dim=0)
    for k, a in all_att.items()
}
# Shape per word that was actually generated. Indexing one fixed word here
# would raise KeyError (after a full dataset pass) if the model never produced it.
{k: a.size() for k, a in all_att.items()}

In [None]:
# Mean/STD attention maps for (up to) the first `max_words` selected words that
# the model actually generated. Words absent from all_att are skipped instead
# of raising KeyError.
max_words = 3
words_to_plot = [word for word in selected_words if word in all_att][:max_words]

n_rows = len(words_to_plot)
n_cols = 2
if words_to_plot:  # no figure when there is nothing to draw
    plt.figure(figsize=(n_cols*7, n_rows*5))

for i_word, word in enumerate(words_to_plot):
    att = all_att[word]

    n_samples = att.size(0)

    plt.subplot(n_rows, n_cols, i_word * 2 + 1)
    plt.title(f'{word} (mean, samples={n_samples:,})')
    plt.imshow(att.mean(dim=0).cpu().numpy())
    plt.axis('off')
    plt.colorbar()

    plt.subplot(n_rows, n_cols, i_word * 2 + 2)
    plt.title(f'{word} (STD, samples={n_samples:,})')
    plt.imshow(att.std(dim=0).cpu().numpy())
    plt.axis('off')
    plt.colorbar()

In [None]:
# Show up to 20 individual attention maps for one word, three per row
word = 'lung'
att = all_att[word]

n_samples = min(att.size(0), 20)

n_cols = min(3, n_samples)
n_rows = math.ceil(n_samples / n_cols)
plt.figure(figsize=(7*n_cols, n_rows*5))
plt.suptitle(f'word={word} (samples={n_samples})', fontsize=18)

for i_sample, att_sample in enumerate(att[:n_samples]):
    plt.subplot(n_rows, n_cols, i_sample + 1)
    plt.title(f'Sample {i_sample}', fontsize=16)
    plt.imshow(att_sample.cpu().numpy())
    plt.axis('off')
    plt.colorbar()

# Analyze sentence attention

## Load models and data

In [None]:
%run ../training/report_generation/hierarchical.py

In [None]:
# Hierarchical runs to analyze. The commented-out lines are earlier candidate
# runs kept for reference.
# NOTE: a later cell reassigns COMPILED_MODEL_LR (run '0519_205343'); which run
# it holds depends on cell execution order. COMPILED_MODEL_ISIZE is paired below
# with a 512x512-image dataloader.
# COMPILED_MODEL1 = load_stuff_wrapper('0513_200618') # 
# COMPILED_MODEL2 = load_stuff_wrapper('0518_213120') # with supervise-attention
# COMPILED_MODEL_BASE = load_stuff_wrapper('0523_031527')
# COMPILED_MODEL_LRATT = load_stuff_wrapper('0525_232238')
COMPILED_MODEL_LR = load_stuff_wrapper('0524_002837')
COMPILED_MODEL_ISIZE = load_stuff_wrapper('0526_190114')

In [None]:
# Both runs must share one vocabulary: the dataloaders below are built from a
# single REPORT_READER.vocab and get_sample checks vocab equality per model.
assert COMPILED_MODEL_ISIZE.reader.vocab == COMPILED_MODEL_LR.reader.vocab

In [None]:
# Older run trained with supervise-attention. Display its metadata with the
# top-level 'vocab' dropped and any nested 'vocab' replaced by its length.
COMPILED_MODEL_OLD = load_stuff_wrapper('0120_140940')
{
    key: (
        {
            sub_key: len(sub_value) if sub_key == 'vocab' else sub_value
            for sub_key, sub_value in value.items()
        }
        if isinstance(value, dict)
        else value
    )
    for key, value in COMPILED_MODEL_OLD.metadata.items()
    if key != 'vocab'
}

In [None]:
# Presumably the counterpart run trained without attention supervision (per the
# variable name) — confirm against the run's metadata.
COMPILED_MODEL_NO_SUPERV = load_stuff_wrapper('0518_225305')

In [None]:
# __att-weights
COMPILED_MODEL_BASE = load_stuff_wrapper('0519_215144')
COMPILED_MODEL_LR = load_stuff_wrapper('0519_205343')

In [None]:
# Additional runs for comparison (the names suggest a lambda variant and an
# LR-att variant — verify against each run's metadata).
COMPILED_MODEL_LAMBDA = load_stuff_wrapper('0519_233237')
COMPILED_MODEL_LRATT = load_stuff_wrapper('0520_005342')

In [None]:
# Reader shared by the hierarchical-section cells; its vocab builds the
# dataloaders below. It comes from whichever COMPILED_MODEL_LR was loaded last.
REPORT_READER = COMPILED_MODEL_LR.reader

In [None]:
# Hierarchical IU X-Ray dataloaders at 256x256, sorted and not shuffled so that
# sample indices are stable across runs
dataset_kwargs = dict(
    hierarchical=True,
    dataset_name='iu-x-ray',
    image_size=(256, 256),
    max_samples=None,
    norm_by_sample=True,
    frontal_only=True,
    shuffle=False,
    sort_samples=True,
    vocab=REPORT_READER.vocab,
)
train_dataloader, val_dataloader = (
    prepare_data_report_generation(dataset_type=split, **dataset_kwargs)
    for split in ('train', 'val')
)
len(train_dataloader.dataset), len(val_dataloader.dataset)

In [None]:
# Same settings, but at 512x512 images for COMPILED_MODEL_ISIZE
kwargs = {**dataset_kwargs, 'image_size': (512, 512)}
train_dataloader_isize = prepare_data_report_generation(dataset_type='train', **kwargs)
val_dataloader_isize = prepare_data_report_generation(dataset_type='val', **kwargs)
len(train_dataloader_isize.dataset), len(val_dataloader_isize.dataset)

In [None]:
# Same settings, but with the old run's own vocab and training image size
kwargs = {
    **dataset_kwargs,
    'vocab': COMPILED_MODEL_OLD.reader.vocab,
    'image_size': COMPILED_MODEL_OLD.metadata['dataset_kwargs']['image_size'],
}
train_dataloader_old = prepare_data_report_generation(dataset_type='train', **kwargs)
val_dataloader_old = prepare_data_report_generation(dataset_type='val', **kwargs)
len(train_dataloader_old.dataset), len(val_dataloader_old.dataset)

## Individual samples

In [None]:
%run ../datasets/common/sentences2organs/compute.py
%run ../datasets/common/constants.py

In [None]:
def get_sample(compiled_model, dataloader, idx, free=True, colorbar=False):
    """Run a hierarchical model on one item and plot its sentence-level attention.

    Prints the ground-truth and generated reports one sentence per line, each
    tagged with the JSRT organs detected in it. Then draws a 4-column grid whose
    first panel is the input image; each following panel is the attention map
    of one generated sentence.

    Args:
        compiled_model: wrapper exposing ``.model``, ``.reader`` and ``.run_id``.
        dataloader: dataloader whose dataset vocab must equal the model's.
        idx: index into ``dataloader.dataset``.
        free: forwarded to the model's ``free`` argument.
        colorbar: if True, add a colorbar next to every attention panel.

    Returns:
        The raw model output tuple (words, stops, att_scores, topics).
    """
    assert compiled_model.reader.vocab == dataloader.dataset.get_vocab()
    reader = compiled_model.reader

    def _print_report(report):
        for i, sentence_idxs in enumerate(sentence_iterator(report)):
            sentence = reader.idx_to_text(sentence_idxs)
            # Shorten organ names for the compact tag, e.g. drop ' lung'
            short_names = [
                organ.replace('ground', '').replace(' lung', '')
                for organ, presence in zip(JSRT_ORGANS, _find_organs_for_sentence(sentence))
                if presence
            ]
            organs = '/'.join(short_names)
            print(f'{i} [{organs:>10}]: {sentence}')

    print(f'Testing run: {compiled_model.run_id}')

    item = dataloader.dataset[idx]
    images = item.image.unsqueeze(0).to(DEVICE)
    reports = torch.tensor(item.report).unsqueeze(0).to(DEVICE)

    # Reports are not fed to the model here (None); generation is from the image
    with torch.no_grad():
        output = compiled_model.model(images, None, free=free,
                                      max_sentences=30, max_words=100)
    words, stops, att_scores, _topics = output

    report_gt = _flatten_gt_reports(reports)[0]
    report_gen = _flatten_gen_reports(words, stops)[0]

    print('GROUND TRUTH:')
    _print_report(report_gt)
    print('-'*20)
    print('GENERATED:')
    _print_report(report_gen)

    n_sentences = len(list(sentence_iterator(report_gen)))
    # expected shape (n_sentences, 16, 16) — TODO confirm; lengths are not asserted
    att_maps = att_scores.squeeze(0).cpu().numpy()

    image_to_show = tensor_to_range01(item.image).permute(1, 2, 0)
    n_cols = 4
    n_rows = math.ceil((n_sentences + 1) / n_cols)

    plt.figure(figsize=(n_cols*7, n_rows*5))

    # Panel 1: the input image itself
    plt.subplot(n_rows, n_cols, 1)
    plt.imshow(image_to_show)
    plt.axis('off')

    # Panels 2..: one attention map per generated sentence (axes left visible)
    for i_sent in range(n_sentences):
        plt.subplot(n_rows, n_cols, i_sent + 2)
        plt.title(f'Sentence {i_sent}', fontsize=18)
        plt.imshow(att_maps[i_sent])
        if colorbar:
            plt.colorbar()

    return output

In [None]:
# COMPILED_MODEL_LR on the 256x256 validation set, item 100
out = get_sample(COMPILED_MODEL_LR, val_dataloader, 100)

In [None]:
# COMPILED_MODEL_ISIZE needs the 512x512 validation dataloader; item 100
out = get_sample(COMPILED_MODEL_ISIZE, val_dataloader_isize, 100)

In [None]:
# Old run, evaluated with its own-vocab dataloader; item 10
out = get_sample(COMPILED_MODEL_OLD, val_dataloader_old, 10)

## Group by properties

### Utils

In [None]:
from itertools import zip_longest

In [None]:
def iter_attention_maps(compiled_model, dataloader, free=True, max_samples=None):
    """Yield per-sentence attention maps of a hierarchical model over a dataloader.

    For each report, ground-truth sentences, generated sentences and the
    model's per-sentence attention maps are paired by position. Shorter
    sequences are padded with ``[]``: a padded sentence decodes to the reader's
    text for an empty index list, and a padded attention entry is ``[]``
    rather than a tensor. Callers should therefore check
    ``isinstance(att, torch.Tensor)``.

    Args:
        compiled_model: wrapper with ``.model`` and, normally, ``.reader``.
            Sentences are decoded with the model's own reader so runs with
            different vocabularies decode correctly. Falls back to the global
            REPORT_READER when no reader is attached.
        dataloader: yields batches with ``.images`` and ``.reports``.
        free: forwarded to the model's ``free`` argument.
        max_samples: stop after yielding this many sentence tuples. None
            iterates the whole dataloader; a value <= 0 yields nothing.

    Yields:
        (gt_sentence_text, gen_sentence_text, att, i_sentence)
    """
    if max_samples is not None and max_samples <= 0:
        return

    reader = getattr(compiled_model, 'reader', None)
    if reader is None:
        reader = REPORT_READER

    counter = 0
    # Progress is counted in reports (matching the dataset length). The
    # context manager closes the bar even when iteration stops early.
    with tqdm(total=len(dataloader.dataset)) as pbar:
        for batch in dataloader:
            images = batch.images.to(DEVICE)
            reports = batch.reports.to(DEVICE)

            with torch.no_grad():
                output = compiled_model.model(images, reports, free=free,
                                              max_sentences=30, max_words=100)
                words, stops, att_scores, _ = output

            reports_gt = _flatten_gt_reports(reports)
            reports_gen = _flatten_gen_reports(words, stops)

            for gt_report, gen_report, atts in zip(reports_gt, reports_gen, att_scores):
                gt_sentences = list(sentence_iterator(gt_report))
                gen_sentences = list(sentence_iterator(gen_report))

                for i_sentence, (gt_sent, gen_sent, att) in enumerate(zip_longest(
                    gt_sentences,
                    gen_sentences,
                    atts,
                    fillvalue=[],
                )):
                    yield (
                        reader.idx_to_text(gt_sent),
                        reader.idx_to_text(gen_sent),
                        att,
                        i_sentence,
                    )

                    counter += 1
                    if max_samples is not None and counter >= max_samples:
                        return

                pbar.update(1)

### Group by organs

In [None]:
%run ../datasets/common/sentences2organs/compute.py

In [None]:
def get_att_by_organ(compiled_model, dataloader, correct_organ_only=False, **kwargs):
    """Group a hierarchical model's sentence attention maps by main organ.

    Each ground-truth sentence is assigned its main organ, and the attention
    map at that sentence position is collected under that organ.

    Args:
        compiled_model: forwarded to iter_attention_maps.
        dataloader: forwarded to iter_attention_maps.
        correct_organ_only: if True, keep a sentence only when the generated
            sentence at the same position has the same main organ as the
            ground-truth sentence.
        **kwargs: extra arguments for iter_attention_maps (e.g. ``max_samples``).

    Returns:
        (att_by_organ, sentences_by_organ). The first maps organ -> attention
        maps stacked along a new first dimension. The second maps organ ->
        list of the matching ground-truth sentence texts.
    """
    att_by_organ = defaultdict(list)
    sentences_by_organ = defaultdict(list)

    # gen_sent must be unpacked: the correct_organ_only branch below reads it
    for gt_sent, gen_sent, att, _ in iter_attention_maps(compiled_model, dataloader, **kwargs):
        # Padding entries from iter_attention_maps carry no attention tensor
        if not isinstance(att, torch.Tensor):
            continue

        organs_onehot = _find_organs_for_sentence(gt_sent)
        organ = get_main_organ(organs_onehot, gt_sent)

        if correct_organ_only:
            gen_organ = get_main_organ(_find_organs_for_sentence(gen_sent), gen_sent)
            if organ != gen_organ:
                continue

        att_by_organ[organ].append(att)
        sentences_by_organ[organ].append(gt_sent)

    att_by_organ = {
        organ: torch.stack(maps, dim=0)
        for organ, maps in att_by_organ.items()
    }
    return att_by_organ, sentences_by_organ

In [None]:
# NOTE(review): COMPILED_MODEL4 and COMPILED_MODEL3 are not defined in any
# visible cell (the loaded models are COMPILED_MODEL_LR/_ISIZE/_OLD/_NO_SUPERV/
# _BASE/_LAMBDA/_LRATT). Unless they are defined elsewhere, this cell raises
# NameError — point it at the intended runs. The plotting cell below labels
# att_by_organ_2 as "supervised".
att_by_organ_1, _ = get_att_by_organ(COMPILED_MODEL4, val_dataloader, max_samples=None)
att_by_organ_2, _ = get_att_by_organ(COMPILED_MODEL3, val_dataloader, max_samples=None)

In [None]:
# Compare per-organ attention statistics of two models side by side.
# Columns 1-2: mean/STD for att_by_organ_1. Columns 3-4: mean/STD for
# att_by_organ_2 (labelled "supervised"). Organs present in neither are skipped.
organs = [
    organ
    for organ in MAIN_ORGANS
    if organ in att_by_organ_1 or organ in att_by_organ_2
]

# One row per plotted organ (sizing by MAIN_ORGANS would leave blank rows)
n_rows = len(organs)
n_cols = 4


def _plot_organ_heatmap(heatmap, i_row, col, title):
    """Draw one heatmap at 0-based row ``i_row`` and 1-based column ``col``."""
    plt.subplot(n_rows, n_cols, i_row * n_cols + col)
    plt.title(title, fontsize=18)
    plt.imshow(heatmap.cpu().numpy())
    plt.axis('off')
    plt.colorbar()


if organs:  # no figure when there is nothing to draw
    plt.figure(figsize=(n_cols*7, n_rows*5))

for i_organ, organ in enumerate(organs):
    for att_dict, first_col, prefix in (
        (att_by_organ_1, 1, ''),
        (att_by_organ_2, 3, 'supervised: '),
    ):
        if organ not in att_dict:
            continue
        att = att_dict[organ]
        n_samples = att.size(0)
        _plot_organ_heatmap(att.mean(dim=0), i_organ, first_col,
                            f'{prefix}{organ} (mean, samples={n_samples:,})')
        _plot_organ_heatmap(att.std(dim=0), i_organ, first_col + 1,
                            f'{prefix}{organ} (STD, samples={n_samples:,})')

### Group by position

In [None]:
from ipywidgets import interact
import ipywidgets as widgets

In [None]:
# Group sentence attention maps by sentence position within the report.
# Positions 0-5 get their own bucket; 6 and beyond share the 'pos 6+' bucket.
# NOTE(review): COMPILED_MODEL4 is not defined in any visible cell — verify
# which run this is meant to analyze.
att_by_position = defaultdict(list)

for _, _, att, position in iter_attention_maps(COMPILED_MODEL4, val_dataloader,
                                               max_samples=None):
    # Skip padding entries that carry no attention tensor
    if not isinstance(att, torch.Tensor):
        continue
        
    if position >= 6:
        position = '6+'
        
    att_by_position[f'pos {position}'].append(att)
    
# Stack each bucket into one tensor (first dim = sample)
att_by_position = {
    k: torch.stack(v, dim=0)
    for k, v in att_by_position.items()
}
att_by_position.keys()

In [None]:
# Mean/STD attention per sentence-position bucket, one row per bucket
plot_atts_mean_and_std(att_by_position)

In [None]:
def plot_samples_in_slider(i_position):
    """Browse the individual attention maps of one sentence position with a slider.

    Args:
        i_position: position label as used in the ``att_by_position`` keys
            (``'pos {i_position}'``), e.g. ``0`` or ``'6+'``.
    """
    position_key = f'pos {i_position}'
    position_atts = att_by_position[position_key]

    def _show(idx):
        heatmap = position_atts[idx].cpu().numpy()
        plt.suptitle(position_key)
        plt.title(f'Sample {idx}')
        plt.imshow(heatmap)

    slider = widgets.IntSlider(min=0, max=len(position_atts) - 1, step=1, value=0)
    interact(_show, idx=slider)

In [None]:
# Browse individual attention maps of the first sentence in each report
plot_samples_in_slider(0)