# Imports

In [None]:
%env CUDA_VISIBLE_DEVICES=0

In [None]:
import torch
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [None]:
import matplotlib
matplotlib.rcParams['figure.facecolor'] = 'white'
matplotlib.rcParams['figure.figsize'] = (15, 5)

In [None]:
import pandas as pd
pd.options.display.max_columns = None

In [None]:
%run ../utils/nlp.py

In [None]:
%run ../utils/__init__.py
config_logging(logging.INFO)

# Utils

In [None]:
%run ../models/checkpoint/__init__.py
%run ../utils/files.py

In [None]:
def _load_model_wrapper(run_name):
    """Create the RunId for `run_name` and load its compiled model.

    Returns:
        Tuple `(run_id, compiled_model)`.
    """
    # NOTE(review): the positional `False, 'rg'` presumably mean
    # "not a debug run" and "report generation" -- confirm against RunId
    identifier = RunId(run_name, False, 'rg')
    return identifier, load_compiled_model_report_generation(identifier)

# Load model

In [None]:
# Device every batch is moved to before the forward pass (requires a CUDA GPU)
DEVICE = torch.device('cuda')

In [None]:
# Training run to inspect; previously used runs are kept commented out
# run_name = '0426_221511' # h-lstm-att
# run_name = '0426_143345' # h-lstm
run_name = '0507_111646'
run_id, compiled_model = _load_model_wrapper(run_name)
# Make the decoder return topic vectors too: the cells below unpack four
# outputs from the model (words, stops, scores, topics)
compiled_model.model.decoder.return_topics = True
run_id.name

In [None]:
# Switch the model to eval mode; assigning to `_` keeps the cell from
# displaying the returned value
_ = compiled_model.model.eval()

# Load data

In [None]:
%run ../datasets/__init__.py

In [None]:
# Arguments shared by the train and val dataloaders
dataset_kwargs = {
    'dataset_name': 'iu-x-ray',
    'max_samples': None,
    'hierarchical': True,
    'frontal_only': True,
    'image_size': (256, 256),
    'norm_by_sample': True,
    'batch_size': 20,
    # Reuse the vocabulary stored in the loaded model's metadata
    'vocab': compiled_model.metadata['dataset_kwargs']['vocab'],
}
train_dataloader = prepare_data_report_generation(dataset_type='train', **dataset_kwargs)
val_dataloader = prepare_data_report_generation(dataset_type='val', **dataset_kwargs)
# Display the number of samples in each split
len(train_dataloader.dataset), len(val_dataloader.dataset)

In [None]:
# Vocabulary as reported by the train dataset; displayed value is its size
VOCAB = train_dataloader.dataset.get_vocab()
len(VOCAB)

In [None]:
# Used below (via `idx_to_text`) to turn generated token indices into text
REPORT_READER = ReportReader(VOCAB)
REPORT_READER

In [None]:
# Index-aligned lists: COLORS[i] is the bar color used for ORGANS[i] in the
# organ plots further down
COLORS = ['red', 'green', 'brown', 'blue', 'cyan']
ORGANS = ['heart', 'lungs', 'thorax', 'all', 'neutral']

# Inspect sentence vectors

Plot distributions and write the embeddings to TensorBoard (TB).

## Distribution of number of sentences

How many reports with N sentences are generated?

### In datasets

In [None]:
def get_n_sentences_values(dataset):
    """Return the number of sentences of every report in `dataset`.

    Iterates `dataset.iter_reports_only()` and, for each report, counts the
    items that `sentence_iterator` yields from its `'tokens_idxs'` entry.

    Returns:
        List with one sentence count per report, in iteration order.
    """
    return [
        # Count lazily instead of materializing each report's sentence list
        sum(1 for _ in sentence_iterator(report['tokens_idxs']))
        for report in dataset.iter_reports_only()
    ]

In [None]:
# Ground-truth sentence counts for the val split (one value per report)
val_n_sentences = get_n_sentences_values(val_dataloader.dataset)
len(val_n_sentences)

In [None]:
# Ground-truth sentence counts for the train split (one value per report)
train_n_sentences = get_n_sentences_values(train_dataloader.dataset)
len(train_n_sentences)

### In predictions

In [None]:
%run ../training/report_generation/hierarchical.py

In [None]:
def compute_sentences_dist(dataloader):
    """Count how many sentences the model generates for each report.

    Runs `compiled_model.model` in free mode over every batch of `dataloader`
    and derives the counts from the stop outputs only (the generated words
    are not inspected): a sentence is counted while its stop value is below
    0.5, and counting ends at the first sentence whose stop value is not.

    Args:
        dataloader -- batches must expose `.images` and `.reports`

    Returns:
        List with one sentence count per generated report, in iteration
        order. A message is printed if its length differs from
        `len(dataloader.dataset)`.
    """
    n_sentences_dist = []

    for batch in tqdm(dataloader):
        images = batch.images.to(DEVICE)
        reports = batch.reports.to(DEVICE)

        with torch.no_grad():
            output = compiled_model.model(images, reports, free=True,
                                          max_words=100, max_sentences=100)
        # Four outputs are expected (words, stops, scores, topics);
        # only the stops are needed here
        _, gen_stops, _, _ = output

        # assumes gen_stops is (batch, n_sentences) -- TODO confirm
        keep_going = (gen_stops < 0.5).long()

        # cumprod zeroes everything after the first stop, so only the leading
        # run of "keep going" sentences is counted. A plain sum would also
        # count sentences emitted after a stop, unlike the break-at-first-stop
        # rule used when collecting sentences
        n_sentences = keep_going.cumprod(dim=1).sum(dim=1).tolist()

        n_sentences_dist.extend(n_sentences)

    if len(n_sentences_dist) != len(dataloader.dataset):
        print('Error: array does not match dataset size')
        print(f'arr-len={len(n_sentences_dist)} vs dataset-len={len(dataloader.dataset)}')
    return n_sentences_dist

In [None]:
# Sentence counts of the reports generated for the val split
val_n_sentences_gen = compute_sentences_dist(val_dataloader)
len(val_n_sentences_gen)

In [None]:
# Sentence counts of the reports generated for the train split
train_n_sentences_gen = compute_sentences_dist(train_dataloader)
len(train_n_sentences_gen)

### Plot dataset and predictions

In [None]:
def plot_n_sentences_distribution(n_sentences, split, gt=True, max_value=20):
    """Draw a histogram of sentences-per-report on the current pyplot axes.

    Args:
        n_sentences -- iterable with one sentence count per report
        split -- split name, only used in the title
        gt -- if True the title is tagged 'GT', otherwise 'GEN'
        max_value -- if not None, the histogram uses `max_value` bins over the
            range (0, max_value) and values above it are reported with a
            print; if None, 10 bins are used and nothing is reported
    """
    bounded = max_value is not None
    hist_options = (
        {'bins': max_value, 'range': (0, max_value)}
        if bounded
        else {'bins': 10}
    )

    title = f'{"GT" if gt else "GEN"}-{split}'
    plt.title(f'N sentences per report ({title})', fontsize=20)
    plt.hist(n_sentences, align='mid', **hist_options)
    plt.xlabel('N sentences', fontsize=15)
    plt.ylabel('Number of reports', fontsize=15)

    if not bounded:
        return

    # Values beyond the plotted range are not drawn, so report them instead
    outliers = [count for count in n_sentences if count > max_value]
    if not outliers:
        return

    n_outliers = len(outliers)
    min_o, max_o = min(outliers), max(outliers)
    print(f'{n_outliers} outliers found in {title}, from {min_o} to {max_o}')

In [None]:
# Upper bound of the plotted range; larger counts are reported as outliers
MAX_VALUE = 15

n_rows = 2
n_cols = 2
plt.figure(figsize=(n_cols*7, n_rows*5))

# Top row: ground-truth reports
plt.subplot(n_rows, n_cols, 1)
plot_n_sentences_distribution(train_n_sentences, 'train', gt=True, max_value=MAX_VALUE)

plt.subplot(n_rows, n_cols, 2)
plot_n_sentences_distribution(val_n_sentences, 'val', gt=True, max_value=MAX_VALUE)

# Bottom row: generated reports. These are the `*_gen` lists returned by
# compute_sentences_dist(); no `*_dist` variables are defined in this notebook
plt.subplot(n_rows, n_cols, 3)
plot_n_sentences_distribution(train_n_sentences_gen, 'train', gt=False, max_value=MAX_VALUE)

plt.subplot(n_rows, n_cols, 4)
plot_n_sentences_distribution(val_n_sentences_gen, 'val', gt=False, max_value=MAX_VALUE)

plt.tight_layout()

## Vector topics

Plot and analyze in TB

### Compute vectors and save to file

TODO: wrap this in a function?

#### Compute vectors

In [None]:
def load_sentences_and_topic_vectors_(dataloader, all_sentences, all_vectors,
                                      all_metadata,
                                      max_amount=None):
    """Generate reports and collect every sentence with its topic vector.

    The three `all_*` lists are extended in place; nothing is returned.
    For each generated report, sentences are collected in order until the
    first one whose stop value is above 0.5.

    Args:
        dataloader -- batches must expose `.images` and `.reports`; its
            dataset must expose `.dataset_type`, stored as the split name
        all_sentences -- list receiving the decoded sentence texts
        all_vectors -- list receiving one topic tensor per sentence
        all_metadata -- list receiving `(sentence, position, split)` tuples
        max_amount -- if not None, stop after the first batch that brings the
            number of collected sentences to at least this value (the check
            runs once per batch, so slightly more may be collected)
    """
    n_sentences_added = 0
    split = dataloader.dataset.dataset_type

    for batch in tqdm(dataloader):
        images = batch.images.to(DEVICE)
        reports = batch.reports.to(DEVICE)

        with torch.no_grad():
            output = compiled_model.model(images, reports, free=True,
                                          max_words=100, max_sentences=100)
        gen_words, gen_stops, _, gen_topics = output
        gen_words = gen_words.argmax(-1) # shape: bs, n_sentences, n_words
        gen_stops = (gen_stops > 0.5).type(torch.uint8) # shape: bs, n_sentences

        for report, stops, topics in zip(gen_words, gen_stops, gen_topics):
            for i_sentence, (sentence, should_stop, topic) in enumerate(zip(report, stops, topics)):
                if should_stop:
                    break

                # Cut the sentence right after its first end-of-sentence token
                # (token included); keep it whole if there is none
                dot_positions, = (sentence == END_OF_SENTENCE_IDX).nonzero(as_tuple=True)
                if len(dot_positions) == 0:
                    first_dot = len(sentence)
                else:
                    first_dot = dot_positions[0].item() + 1
                sentence = sentence[:first_dot].tolist()
                sentence = REPORT_READER.idx_to_text(sentence)

                all_sentences.append(sentence)
                # NOTE(review): `topic` is stored as produced by the model
                # (same device, a slice of the batch output) -- confirm this
                # does not hold too much GPU memory on large datasets
                all_vectors.append(topic)
                all_metadata.append((sentence, i_sentence, split))

                n_sentences_added += 1

        # Compare against the `max_amount` argument
        if max_amount is not None and n_sentences_added >= max_amount:
            print(f'Stopped at {n_sentences_added}')
            break

In [None]:
# Display the number of samples per split before collecting sentences
len(train_dataloader.dataset), len(val_dataloader.dataset)

In [None]:
# Accumulators filled in place by load_sentences_and_topic_vectors_();
# the three lists stay index-aligned (one entry per collected sentence)
ALL_SENTENCES = []
ALL_VECTORS = []
ALL_METADATA = []

In [None]:
# Collect sentences and topic vectors generated for the train split
load_sentences_and_topic_vectors_(train_dataloader,
                                  ALL_SENTENCES, ALL_VECTORS, ALL_METADATA)
len(ALL_SENTENCES), len(ALL_VECTORS), len(ALL_METADATA)

In [None]:
# Append the val split to the same accumulators
load_sentences_and_topic_vectors_(val_dataloader,
                                  ALL_SENTENCES, ALL_VECTORS, ALL_METADATA)
len(ALL_SENTENCES), len(ALL_VECTORS), len(ALL_METADATA)

In [None]:
# Turn the list of per-sentence tensors into one tensor with a new leading
# dimension. Not re-runnable: after this, ALL_VECTORS is no longer a list
ALL_VECTORS = torch.stack(ALL_VECTORS, dim=0)
ALL_VECTORS.size()

#### Create dataframe

With sentences and metadata (topic vectors are added later)

In [None]:
# One row per collected sentence: its text, position in the report and split
SENTENCES_DF = pd.DataFrame(ALL_METADATA, columns=['sentence', 'position', 'split'])
print(len(SENTENCES_DF))
SENTENCES_DF.head()

In [None]:
# Number of collected sentences per position
Counter(SENTENCES_DF['position'])

#### Add organs per sentence

In [None]:
%run ../datasets/common/sentences2organs/compute.py
%run ../datasets/common/constants.py

In [None]:
# NOTE(review): presumably `organs_onehot` has one row per sentence and
# `warnings['all-empty']` lists sentences matching no organ -- confirm
# against find_organs_for_sentences
organs_onehot, warnings = find_organs_for_sentences(ALL_SENTENCES)
neutral_sentences = set(warnings['all-empty'])
len(organs_onehot), len(neutral_sentences)

In [None]:
# Guard so re-running the cell does not append the organ columns twice
# (assumes 'heart' is one of JSRT_ORGANS -- TODO confirm)
if 'heart' not in SENTENCES_DF.columns:
    SENTENCES_DF = pd.concat([
        SENTENCES_DF,
        pd.DataFrame(organs_onehot, columns=JSRT_ORGANS)], axis=1)
    # The column-wise concat must not have added rows
    assert len(SENTENCES_DF) == len(ALL_SENTENCES)
SENTENCES_DF.head()

In [None]:
# Single organ label per sentence, as decided by get_main_organ(); this is
# the column the organ plots below group by
SENTENCES_DF['organ'] = [
    get_main_organ(one_hot, sentence, warnings)
    for sentence, one_hot in zip(ALL_SENTENCES, organs_onehot)
]
SENTENCES_DF.head()

In [None]:
# Number of collected sentences per main organ
Counter(SENTENCES_DF['organ'])

#### Add diseases per sentence

In [None]:
# CSV read below; presumably holds cached CheXpert labels per sentence -- verify
CACHE_FPATH = os.path.join(WORKSPACE_DIR, 'cache', 'labeler', 'sentences_chexpert.csv')

In [None]:
cache_df = pd.read_csv(CACHE_FPATH)
# Keep only the cached rows whose sentence was collected above
cache_df = cache_df.loc[cache_df['sentences'].isin(set(ALL_SENTENCES))]
print(len(cache_df))
cache_df.head()

In [None]:
# Guard so re-running the cell does not merge the disease columns twice
if 'No Finding' not in SENTENCES_DF.columns:
    # Left merge keeps every collected sentence, including those not in the cache
    SENTENCES_DF = SENTENCES_DF.merge(cache_df, left_on='sentence', right_on='sentences', how='left')
    # Fills every NaN in the frame with -3; presumably -3 marks sentences
    # without a cached label -- confirm
    SENTENCES_DF.fillna(-3, inplace=True)
    # Would fail if the cache had duplicated sentences (the merge adds rows then)
    assert len(SENTENCES_DF) == len(ALL_SENTENCES)
    SENTENCES_DF = SENTENCES_DF.astype({d: 'int8' for d in CHEXPERT_DISEASES})
    # Drop the join key brought in from cache_df; 'sentence' already holds it
    del SENTENCES_DF['sentences']

print(len(SENTENCES_DF))
SENTENCES_DF.head()

In [None]:
# Distribution of the 'No Finding' values across collected sentences
Counter(SENTENCES_DF['No Finding'])

In [None]:
# The cache frame is no longer used after the merge above
del cache_df

#### Add topic vectors

In [None]:
# Sanity check: row count of the dataframe vs. size of the vectors tensor
len(SENTENCES_DF), ALL_VECTORS.size()

In [None]:
# One 'emb<i>' column per component of the topic vectors
columns = [f'emb{i}' for i in range(ALL_VECTORS.size(1))]
# Guard so re-running the cell does not append the embedding columns twice
if 'emb0' not in SENTENCES_DF.columns:
    SENTENCES_DF = pd.concat([
            SENTENCES_DF,
            pd.DataFrame(ALL_VECTORS.cpu().numpy(), columns=columns)], axis=1)
    # The column-wise concat must not have added rows
    assert len(SENTENCES_DF) == len(ALL_SENTENCES)
    print('Concatenated')

print(len(SENTENCES_DF))
SENTENCES_DF.head()

#### Save sentences to file

In [None]:
# Output CSV lives in this run's results folder; create the folder if missing
fpath = os.path.join(get_results_folder(run_id), 'sentence_vectors.csv')
folder = os.path.dirname(fpath)
os.makedirs(folder, exist_ok=True)

In [None]:
# Persist sentences, metadata and embeddings without the index column
SENTENCES_DF.to_csv(fpath, index=False)

### Load pre-computed sentences and embeddings

In [None]:
# Same path the save cell writes to; redefined so this section can be run
# without recomputing the vectors
fpath = os.path.join(get_results_folder(run_id), 'sentence_vectors.csv')

In [None]:
# Replaces any SENTENCES_DF computed above with the version stored on disk
SENTENCES_DF = pd.read_csv(fpath)
print(len(SENTENCES_DF))
SENTENCES_DF.head()

### Position analysis

#### Plot basic position distribution

How many sentences are generated in X position?

In [None]:
def plot_positions_histogram(positions, upper_group=10,
                             titlesize=18, labelsize=15,
                             title='Distribution of sentence positions',
                             barcolor=None,
                            ):
    """Plots a histogram of a positions array on the current pyplot axes.

    Each position gets its own bar. If the largest position exceeds
    `upper_group`, all positions >= `upper_group` share one final bar whose
    tick label ends with '+'.

    Args:
        positions -- array with numbers indicating sentences positions
            (non-negative integers); may be empty
        upper_group -- positions larger or equal to this will be grouped in one bin
        titlesize -- font size of the title
        labelsize -- font size of the axis labels
        title -- title of the plot
        barcolor -- passed as `color` to `plt.bar`
    """
    plt.title(title, fontsize=titlesize)
    plt.xlabel('Sentence position', fontsize=labelsize)
    plt.ylabel('N sentences', fontsize=labelsize)

    # len() rather than truthiness: `positions` may be a pandas Series or a
    # numpy array, whose truth value is ambiguous
    if len(positions) == 0:
        # max() would raise on empty input; leave an empty, titled plot
        return

    max_position = max(positions)
    if upper_group >= max_position:
        # np.histogram takes bin EDGES: n + 1 edges make n bins. To get one
        # unit-wide bin for each position 0..max_position the edges must run
        # from 0 to max_position + 1 inclusive
        bins = range(0, max_position + 2)
        last_one_grouped = False
    else:
        # Unit-wide bins for 0..upper_group-1, then one closed bin
        # [upper_group, max_position] holding everything else
        bins = list(range(0, upper_group + 1)) + [max_position]
        last_one_grouped = True
    hist, bins = np.histogram(positions, bins=bins)
    # Keep the left edge of every bin as the bar position
    bins = bins[:-1]

    plt.bar(bins, hist, color=barcolor)

    xlabels = [str(i) for i in range(len(bins))]
    if last_one_grouped:
        xlabels[-1] = f'{xlabels[-1]}+'
    _ = plt.xticks(range(len(xlabels)), xlabels)

In [None]:
# All sentence positions (both splits); displayed value is the number of
# distinct positions
positions = SENTENCES_DF['position']
len(set(positions)) # , Counter(positions)

In [None]:
plt.figure(figsize=(8, 5))

# Positions >= 11 are grouped into a single bar
plot_positions_histogram(positions, 11)

#### Plot organs distribution by sentence position

How many sentences about each organ are generated in position X?

In [None]:
# Positions 0-5 get one row each; the (6, 200) tuple is a range bucket whose
# upper bound is replaced below by the largest position found in the split
plot_positions = [0, 1, 2, 3, 4, 5, (6,200)]

n_rows = len(plot_positions)
n_cols = 2

plt.figure(figsize=(n_cols * 5, n_rows * 3))

# One column per split, one row per position (or position range)
for i_split, split in enumerate(('train', 'val')):
    sub_df = SENTENCES_DF.loc[SENTENCES_DF['split'] == split]
    
    for i_position, position in enumerate(plot_positions):
        if isinstance(position, tuple):
            # Range bucket: the declared upper bound is ignored in favor of
            # the actual maximum, which is also what the title shows
            lower, _ = position
            actual_upper = max(sub_df['position'])
            position = (lower, actual_upper)
            condition = (sub_df['position'] >= lower) & (sub_df['position'] <= actual_upper)
        else:
            condition = sub_df['position'] == position
        rows = sub_df.loc[condition]
        
        # Counter returns 0 for organs absent from `rows`, so `amounts`
        # always lines up with ORGANS
        counter = Counter(rows['organ'])
        amounts = [counter[o] for o in ORGANS]
        # organs, amounts = zip(*sorted(.items()))
        
        # Subplot indices are 1-based and run row by row
        plt_index = i_position * n_cols + i_split + 1
        plt.subplot(n_rows, n_cols, plt_index)
        plt.bar(ORGANS, amounts, color=COLORS)
        plt.title(f'Organs in {split} position={position}', fontsize=16)
        plt.ylabel('Number of sentences', fontsize=15)
        
plt.tight_layout()

#### Plot position distribution by organ

Given organ X, in what positions is X described?

In [None]:
n_rows = len(ORGANS)
n_cols = 2

plt.figure(figsize=(n_cols * 5, n_rows * 3))

# One column per split, one row per organ
for i_split, split in enumerate(('train', 'val')):
    sub_df = SENTENCES_DF.loc[SENTENCES_DF['split'] == split]
    
    for i_organ, (organ, color) in enumerate(zip(ORGANS, COLORS)):
        rows = sub_df.loc[sub_df['organ'] == organ]
        
        # NOTE: rebinds the notebook-level `positions` defined in an earlier cell
        positions = rows['position']
        
        # Subplot indices are 1-based and run row by row
        plt_index = i_organ * n_cols + i_split + 1
        plt.subplot(n_rows, n_cols, plt_index)
        title = f'Sentence positions for {organ} ({split})'
        # Positions >= 10 are grouped into a single bar
        plot_positions_histogram(positions, 10,
                                 title=title, barcolor=color,
                                 labelsize=15,
                                )

plt.tight_layout()

### Write to TB

In [None]:
%run ../tensorboard/__init__.py

In [None]:
# Display the run the embeddings will be written for
run_id

In [None]:
def write_tb(split, sample=None, dryrun=True):
    """Write the sentence embeddings of one split to TensorBoard.

    Reads the notebook-level `SENTENCES_DF`, keeps the rows whose `split`
    column equals `split`, and writes their embedding columns with
    `SummaryWriter.add_embedding`, attaching position, organ and disease
    columns as per-vector metadata. `SENTENCES_DF` itself is not modified.

    Args:
        split -- value of the `split` column to select (e.g. 'train', 'val')
        sample -- if not None, number of rows to draw with `DataFrame.sample`
        dryrun -- if True, only print what would be written and return
    """
    df = SENTENCES_DF.loc[SENTENCES_DF['split'] == split]

    if sample is not None:
        df = df.sample(sample)

    emb_cols = [c for c in SENTENCES_DF.columns if 'emb' in c]
    embeddings = df[emb_cols].to_numpy()
    # NOTE(review): 100 is hard-coded; presumably the topic-vector size of
    # the loaded model -- confirm before using with another model
    assert embeddings.shape == (len(df), 100), f'Got {embeddings.shape}'

    # Group larger position values into one bin.
    # The result is rebound instead of using inplace=True: `df` is a selection
    # of SENTENCES_DF, and mutating a selection in place is chained assignment
    # (pandas warns, and whether the parent is touched is not guaranteed)
    group_greater_than = 8
    replace_with = f'{group_greater_than}+'
    df = df.replace(
        {'position': {k: replace_with for k in range(group_greater_than, 200)}},
    )

    header = ['position', *JSRT_ORGANS, 'organ', *CHEXPERT_DISEASES]
    metadata = df[header].to_numpy()
    # Every metadata value is written as a string
    metadata = [tuple(map(str, x)) for x in metadata]

    tag = f'sentence_embeddings_{split}_{len(embeddings)}'
    if dryrun:
        print(f'Would write: {len(embeddings):,} vectors, tag={tag}')
        return

    writer = SummaryWriter(get_tb_large_log_folder(run_id))
    try:
        writer.add_embedding(
            embeddings,
            metadata=metadata,
            metadata_header=header,
            global_step=compiled_model.get_current_epoch(),
            tag=tag,
        )
        print(f'Written {len(embeddings):,} vectors, tag={tag}')
    finally:
        # Close the writer even if add_embedding fails
        writer.close()

In [None]:
# Write a random sample of 2000 train sentences (dryrun disabled: really writes)
write_tb('train', sample=2000, dryrun=False)

In [None]:
# Write every val sentence (no sampling; dryrun disabled: really writes)
write_tb('val', dryrun=False)