# Imports

In [None]:
# Make only GPU 0 visible to this notebook.
%env CUDA_VISIBLE_DEVICES=0

In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
import matplotlib
# White figure background and a wide default figure size for all plots below.
matplotlib.rcParams['figure.facecolor'] = 'white'
matplotlib.rcParams['figure.figsize'] = (15, 5)

In [None]:
import pandas as pd
# Show every column when displaying DataFrames.
pd.options.display.max_columns = None

In [None]:
# Load the project utilities into the notebook namespace. `config_logging` and
# `logging` are not imported in any visible cell; presumably this script
# provides them — TODO confirm.
%run ../utils/__init__.py
config_logging(logging.INFO)

# Function to load a trained model

In [None]:
%run ../models/checkpoint/__init__.py
%run ../utils/files.py

In [None]:
def _load_model_wrapper(run_name):
    """Load a report-generation ('rg') model from its run name.

    The ``RunId`` is built as ``RunId(run_name, False, 'rg')`` and passed to
    ``load_compiled_model_report_generation``.

    Returns:
        tuple: ``(run_id, compiled_model)``.
    """
    run_id = RunId(run_name, False, 'rg')
    return run_id, load_compiled_model_report_generation(run_id)

# Load data

In [None]:
%run ../datasets/__init__.py

In [None]:
# Keyword arguments for `prepare_data_report_generation` (presumably defined by
# the `%run ../datasets/__init__.py` cell). Later cells only use the dataset's
# vocabulary (`get_vocab`) and report tokens (`iter_reports_only`).
dataset_kwargs = {
    'dataset_name': 'iu-x-ray',
    'dataset_type': 'train',
    'max_samples': None,
    'hierarchical': True,
    'frontal_only': True,
    'image_size': (256, 256),
    'norm_by_sample': True,
    'batch_size': 20,
}
dataloader = prepare_data_report_generation(**dataset_kwargs)
dataset = dataloader.dataset
len(dataset)

# Inspect word embeddings

Write word embeddings to TensorBoard (TB) to inspect them in the Embedding Projector.

## Utils

In [None]:
def plot_emb_distribution(emb, title='Embeddings distribution'):
    """Draw a 50-bin histogram of every value in ``emb`` on the current axes.

    Args:
        emb: array with ``ndim`` and ``flatten`` (e.g. a NumPy array). Inputs
            with more than one dimension are flattened before plotting.
        title: plot title.
    """
    values = emb.flatten() if emb.ndim > 1 else emb
    plt.hist(values, bins=50)
    plt.title(title, fontsize=20)
    plt.xlabel('Embedding value')
    plt.ylabel('Count')

## Select a subsample of words

In [None]:
from collections import Counter

In [None]:
vocab = dataset.get_vocab()
idx_to_word = {v:k for k, v in vocab.items()}
len(vocab)

In [None]:
token_appearances = Counter()
for report in dataset.iter_reports_only():
    for token in report['tokens_idxs']:
        token_appearances[token] += 1
len(token_appearances)

### The K most frequent words

In [None]:
def top_k_words(k, counter=None):
    """Return the token indices of the ``k`` most frequent tokens.

    Args:
        k: number of tokens to return. Fewer are returned when the counter
            holds fewer distinct tokens.
        counter: ``collections.Counter`` mapping token index -> count.
            Defaults to the notebook-level ``token_appearances``.

    Returns:
        tuple: token indices ordered by decreasing count. Ties keep the
        counter's insertion order. Returns an empty tuple (instead of raising
        IndexError) when the counter is empty or ``k`` is 0.
    """
    if counter is None:
        counter = token_appearances
    # most_common(k) is documented as equivalent to a stable
    # sorted(..., reverse=True)[:k], so ordering matches the old lambda.
    return tuple(token for token, _ in counter.most_common(k))

In [None]:
# Keep the 800 most frequent token indices.
word_tokens = top_k_words(800)
len(word_tokens)

In [None]:
# Map the selected token indices back to their words.
words = [idx_to_word[token] for token in word_tokens]
len(words), words[:10]

In [None]:
# Suffix that identifies this word selection in TensorBoard tags.
EMBEDDING_NAME = str(len(words))

### All words (alternative to the top-K selection above; run only one of the two)

In [None]:
# Every vocabulary word except the special tokens. Running these cells
# overrides `words`, `word_tokens` and `EMBEDDING_NAME` from the top-K cells.
words = [w for w in idx_to_word.values() if w not in ('PAD', 'START', 'END', 'UNK')]
len(words), words[:10]

In [None]:
word_tokens = [vocab[word] for word in words]
len(word_tokens)

In [None]:
EMBEDDING_NAME = 'all'

## Get organs and diseases

In [None]:
%run ../datasets/common/sentences2organs/compute.py
%run ../datasets/common/constants.py

In [None]:
# Organ vector for each word, plus the matcher's warnings. Words listed under
# 'all-empty' presumably match no organ at all — TODO confirm.
# NOTE(review): the name `warnings` shadows the stdlib module in this namespace.
organs_onehot, warnings = find_organs_for_sentences(words)
neutral_words = set(warnings['all-empty'])
len(organs_onehot), len(neutral_words)

In [None]:
# Projector metadata: one row per word (same order as `words`) with its main
# organ followed by the organ-vector columns; `header` names the columns.
header = ['word', 'organ', *JSRT_ORGANS]
metadata = [
    (word, get_main_organ(one_hot, word, warnings), *one_hot)
    for word, one_hot in zip(words, organs_onehot)
]
len(header), len(metadata)

### Keep only words with a non-neutral organ

In [None]:
metadata = [
    t
    for t in metadata
    if t[1] != 'neutral'
]
len(metadata)

In [None]:
words = tuple(zip(*metadata))[0]
len(words)

In [None]:
word_tokens = [vocab[tup[0]] for tup in metadata]
len(word_tokens)

In [None]:
EMBEDDING_NAME = 'non-neutral'

## Write from model to TB

In [None]:
%run ../tensorboard/__init__.py

In [None]:
# Token indices as a CUDA tensor; `calculate_embeddings` reads this global
# by default.
word_tokens = torch.tensor(word_tokens, device='cuda')
word_tokens.size()

In [None]:
def calculate_embeddings(compiled_model, tokens=None, embedding_dim=100):
    """Look up word embeddings in a compiled model's decoder.

    Args:
        compiled_model: object exposing ``model.decoder.word_embeddings``, a
            layer callable on a tensor of token indices.
        tokens: token indices to embed (tensor or sequence of ints). Defaults
            to the notebook-level ``word_tokens``.
        embedding_dim: expected embedding size. The result shape is checked
            against it.

    Returns:
        numpy.ndarray of shape ``(len(tokens), embedding_dim)``.

    Raises:
        AssertionError: if the result does not have the expected shape.
    """
    if tokens is None:
        tokens = word_tokens

    layer = compiled_model.model.decoder.word_embeddings
    # Put the indices on the layer's device so both CPU and GPU models work.
    # A tensor already on that device is returned as-is, without a copy.
    weight = getattr(layer, 'weight', None)
    device = weight.device if weight is not None else None
    tokens = torch.as_tensor(tokens, device=device)

    with torch.no_grad():
        embeddings = layer(tokens)

    embeddings = embeddings.cpu().numpy()
    assert embeddings.shape == (len(tokens), embedding_dim), embeddings.shape

    return embeddings

In [None]:
def write_embeddings_from_compiled_model(run_id, compiled_model, embeddings,
                                         embedding_metadata=None,
                                         metadata_header=None,
                                         tag=None):
    """Write ``embeddings`` to the run's TensorBoard log folder.

    Args:
        run_id: ``RunId`` of the model. Must match
            ``compiled_model.metadata['run_id']``.
        compiled_model: the loaded model. Its current epoch is used as the
            global step.
        embeddings: one row per word.
        embedding_metadata: per-row labels. Defaults to the notebook-level
            ``metadata``.
        metadata_header: column names for the labels. Defaults to the
            notebook-level ``header``.
        tag: TensorBoard tag. Defaults to ``word_embeddings_<number of rows>``.

    Raises:
        AssertionError: if ``run_id`` does not match the compiled model.
    """
    assert compiled_model.metadata['run_id'] == run_id.to_dict()

    if embedding_metadata is None:
        embedding_metadata = metadata
    if metadata_header is None:
        metadata_header = header
    if tag is None:
        tag = f'word_embeddings_{len(embeddings)}'

    writer = SummaryWriter(get_tb_large_log_folder(run_id))
    try:
        writer.add_embedding(
            embeddings,
            metadata=embedding_metadata,
            metadata_header=metadata_header,
            global_step=compiled_model.get_current_epoch(),
            tag=tag,
        )
    finally:
        # Release the writer even if add_embedding raises.
        writer.close()

### Calculate embeddings for the four base models

In [None]:
# Load each base model and embed `word_tokens`. `run_id` and `compiled_model`
# are rebound by every cell, so only the last loaded model stays bound.
# Uncomment the write call to log a model's embeddings to TensorBoard.
run_id, compiled_model = _load_model_wrapper('0426_143345')
emb_h = calculate_embeddings(compiled_model)
# write_embeddings_from_compiled_model(run_id, compiled_model, emb_h)
emb_h.shape

In [None]:
run_id, compiled_model = _load_model_wrapper('0426_221511')
emb_h_att = calculate_embeddings(compiled_model)
# write_embeddings_from_compiled_model(run_id, compiled_model, emb_h_att)
emb_h_att.shape

In [None]:
# NOTE(review): the trailing run name is presumably an alternative checkpoint
# for this model — TODO confirm which one should be used.
run_id, compiled_model = _load_model_wrapper('0417_132754') # 0501_201357
emb_flat = calculate_embeddings(compiled_model)
# write_embeddings_from_compiled_model(run_id, compiled_model, emb_flat)
emb_flat.shape

In [None]:
# NOTE(review): as above, the trailing run name is presumably an alternative
# checkpoint — TODO confirm.
run_id, compiled_model = _load_model_wrapper('0418_102603')  # 0501_212955
emb_flat_att = calculate_embeddings(compiled_model)
# write_embeddings_from_compiled_model(run_id, compiled_model, emb_flat_att)
emb_flat_att.shape

### Plot distributions

In [None]:
# Histogram of embedding values for each base model, on a 2x2 grid.
n_rows, n_cols = 2, 2
panels = [
    (emb_h, 'h-lstm'),
    (emb_h_att, 'h-lstm-att'),
    (emb_flat, 'lstm'),
    (emb_flat_att, 'lstm-att'),
]

plt.figure(figsize=(n_cols * 7, n_rows * 5))
for position, (panel_emb, panel_title) in enumerate(panels, start=1):
    plt.subplot(n_rows, n_cols, position)
    plot_emb_distribution(panel_emb, panel_title)

## Write from random layer

In [None]:
emb_layer = nn.Embedding(len(vocab), 100, 0).to('cuda')
with torch.no_grad():
    embeddings = emb_layer(word_tokens)
    
embeddings.size()

In [None]:
embeddings = embeddings.cpu().numpy()
embeddings.shape

In [None]:
run_id = RunId('random_word_embedding', False, 'rg')
writer = SummaryWriter(get_tb_large_log_folder(run_id))

In [None]:
writer.add_embedding(
    embeddings,
    metadata=metadata,
    metadata_header=header,
    global_step=0,
    tag=f'word_embeddings_{len(words)}',
)

In [None]:
writer.close()

## Glove

In [None]:
import torchtext

In [None]:
glove = torchtext.vocab.GloVe(name='6B', dim=100)

In [None]:
missing_words = [word for word in words if word not in glove.stoi]
len(missing_words)

In [None]:
embeddings = glove.get_vecs_by_tokens(words)
embeddings.size()

In [None]:
embeddings = embeddings.cpu().numpy()
embeddings.shape

In [None]:
run_id = RunId('glove', False, 'rg')
writer = SummaryWriter(get_tb_large_log_folder(run_id))

writer.add_embedding(
    embeddings,
    metadata=metadata,
    metadata_header=header,
    global_step=0,
    tag=f'word_embeddings_{len(words)}',
)

writer.close()

In [None]:
flat_embeddings = embeddings.flatten()
flat_embeddings.shape

In [None]:
plot_emb_distribution(flat_embeddings)

## Load rad-glove

In [None]:
%run ../models/report_generation/word_embedding.py

In [None]:
# `RadGlove` presumably comes from the `%run` of
# ../models/report_generation/word_embedding.py. The cells below use len(),
# `in` and `[word]` lookups on it.
radglove = RadGlove()
len(radglove)

In [None]:
# Sanity check: vector for a common word.
radglove['number']

### Write rad-glove to TB

In [None]:
# Words that have no rad-glove vector.
missing_words = [word for word in words if word not in radglove]
len(missing_words)

In [None]:
# One row per word, in `words` order, so rows line up with `metadata`. Missing
# words get an all-zero vector.
# NOTE(review): the size 100 is hard-coded and must match rad-glove's vector
# size, otherwise torch.stack fails on mismatched shapes.
embeddings = torch.stack([
    radglove[token] if token in radglove else torch.zeros(100)
    for token in words
], dim=0)
embeddings.size()

In [None]:
embeddings = embeddings.cpu().numpy()
embeddings.shape

In [None]:
run_id = RunId('radglove', False, 'rg')
writer = SummaryWriter(get_tb_large_log_folder(run_id))

writer.add_embedding(
    embeddings,
    metadata=metadata,
    metadata_header=header,
    global_step=0,
    tag=f'word_embeddings_{EMBEDDING_NAME}',
)

writer.close()

In [None]:
plot_emb_distribution(embeddings, 'radglove')