# Imports

In [None]:
import torch
from tqdm.notebook import tqdm
from types import MethodType
import matplotlib.pyplot as plt

In [None]:
import matplotlib
# Set matplotlib's default figure background color to white.
matplotlib.rcParams['figure.facecolor'] = 'white'

In [None]:
%run ../utils/__init__.py
# Call the logging setup helper with the INFO level
# (presumably `config_logging` is defined by the `%run ../utils/__init__.py` above).
config_logging(logging.INFO)

# Functions

In [None]:
from collections import namedtuple

In [None]:
%run ../datasets/__init__.py

In [None]:
# Result record: a dataloader together with the per-batch positive counts
# that were computed by iterating over it.
BalanceDistribution = namedtuple(
    'BalanceDistribution',
    ('dataloader', 'n_positives'),
)

In [None]:
# Abbreviated display names for the longest label names, used to keep
# subplot titles compact. Labels without an entry keep their full name
# (the plotting code falls back to the original label).
shorter_label = dict((
    ('No Finding', 'NF'),
    ('Enlarged Cardiomediastinum', 'Enl Card'),
    ('Pleural Effusion', 'Pleural-E'),
    ('Pleural Other', 'Pleural-O'),
    ('Support Devices', 'Dev'),
))

In [None]:
def plot_distributions(dist, n_rows=3, n_cols=5, bins=10):
    """Draws one histogram per label with the positives found in each batch.

    The labels are the ones of `dist.dataloader.dataset.labels`, followed by an
    extra 'No Finding' entry. Each subplot title shows the (shortened) label
    name, the mean amount of positives per batch, and that mean as a percentage
    of the batch size. A red vertical line marks the mean.

    Args:
        dist -- BalanceDistribution, with:
            dataloader -- Dataloader used to calculate distributions
            n_positives -- tensor indexed as [batch, label], with the amount
                of positives by batch (one column per label plus 'No Finding')
        n_rows -- number of rows in the subplot grid
        n_cols -- number of columns in the subplot grid
        bins -- passed to `plt.hist`
    """
    plt.figure(figsize=(15, 10))

    loader = dist.dataloader
    counts = dist.n_positives.float()

    names = list(loader.dataset.labels) + ['No Finding']
    batch_size = loader.batch_size
    sampler_name = str(loader.sampler.__class__.__name__)

    plt.suptitle(f'BS={batch_size}, SAMPLER={sampler_name}')

    last_row = n_rows - 1

    for position, name in enumerate(names):
        grid_row, grid_col = divmod(position, n_cols)
        plt.subplot(n_rows, n_cols, position + 1)

        # TODO: allow plotting n_positives values across epochs,
        # i.e. plt.plot(t, n_positives[:, i_label])
        column = counts[:, position]

        heights, _, _ = plt.hist(column, bins=bins)

        # Axis labels only on the first column and on the last grid row
        if grid_col == 0:
            plt.ylabel('Frequency')
        if grid_row == last_row:
            plt.xlabel('Positive samples in a batch')

        # Mark the mean with a line as tall as the highest bar
        average = column.mean()
        plt.vlines(average, 0, heights.max(), color='red')

        short_name = shorter_label.get(name, name)
        plt.title(f'{short_name}, {average:.1f}, {average/batch_size*100:.1f}%')

In [None]:
def compute_average_balance(create_dataloader=prepare_data_classification, **kwargs):
    """Computes balance of labels in a dataloader.

    "Balance" is defined as the average amount of positive labels in a batch, per disease.
    An extra "no finding" count (samples of the batch without any positive label)
    is appended after the per-disease counts.

    Args:
        create_dataloader -- function to create the dataloader
        **kwargs -- passed to the `create_dataloader` function

    Returns:
        BalanceDistribution with the created dataloader and `n_positives`, a tensor
        of shape (n_batches, n_labels+1) with the amount of positives by batch.
    """
    dataloader = create_dataloader(**kwargs)

    positives_by_label = []

    # The image sanity-check below only needs to run for the first batch
    checked_for_monkeypatch = False

    # Iterate the dataloader itself (not `iter(...)`), so tqdm can read its length
    for batch in tqdm(dataloader):
        labels = batch.labels.sum(dim=0) # shape: n_labels

        if not checked_for_monkeypatch:
            # Images are expected to be the -1 placeholder (i.e. not actually loaded);
            # anything else suggests the dataset was not monkey patched
            if not (batch.image == -1).all().item():
                print(f'Warning: dataset may be loading images, images={batch.image}')
            checked_for_monkeypatch = True

        # Samples whose labels sum to 0 have no positive label
        no_finding_count = batch.labels.sum(dim=1) # shape: batch_size
        no_finding_count = (no_finding_count == 0).sum().unsqueeze(0) # shape: 1
        no_finding_count = no_finding_count.type(labels.dtype)

        labels = torch.cat((labels, no_finding_count), dim=0) # shape: n_labels+1

        positives_by_label.append(labels)

    positives_by_label = torch.stack(positives_by_label, dim=0)
    # shape: n_batches, n_labels+1

    print('Amount of positives by label, in average: ', positives_by_label.float().mean(dim=0).tolist())
    stats = {
        'sampler': str(dataloader.sampler.__class__.__name__),
        'n_samples': len(dataloader.dataset),
        'n_batches': len(dataloader),
        'batch_size': dataloader.batch_size,
    }
    print(' '.join(f'{k}={v}' for k, v in stats.items()))

    return BalanceDistribution(dataloader=dataloader, n_positives=positives_by_label)

# Classification

In [None]:
%run ../datasets/__init__.py
%run ../datasets/common/__init__.py

In [None]:
# Monkey patch the method, to not load images
def getitem_labelsonly(self, idx):
    """Returns a BatchItem holding only the integer labels of sample `idx`."""
    sample_row = self.label_index.iloc[idx]
    label_values = sample_row[self.labels]

    return BatchItem(labels=label_values.to_numpy().astype('int'))

CXR14Dataset.__getitem__ = getitem_labelsonly

In [None]:
# Keyword arguments shared by every CXR14 balance computation below
CXR_14_KWARGS = dict(
    dataset_name='cxr14',
    dataset_type='train',
)

In [None]:
# Label balance per batch for the CXR14 kwargs, batch size 40, no sampler options passed
dist_40 = compute_average_balance(batch_size=40, **CXR_14_KWARGS)
plot_distributions(dist_40)

In [None]:
# Positives per batch of the dist_40 run, normalized by batch_size * n_diseases
dist = dist_40

batch_size = dist.dataloader.batch_size
n_diseases = len(dist.dataloader.dataset.labels)

d = dist.n_positives
d = d.sum(axis=1) # shape: n_batches
# NOTE(review): n_positives, as built by compute_average_balance, also carries the
# extra "no finding" column; it is part of the sum above but not of n_diseases.
# Confirm this is intended.
d = d / (batch_size * n_diseases)

d

In [None]:
# Histogram of the current `d` values (one value per batch)
plt.hist(d)

In [None]:
# Same computation with batch size 40, this time requesting balanced_sampler=True
dist_40_balanced = compute_average_balance(batch_size=40, balanced_sampler=True, **CXR_14_KWARGS)
plot_distributions(dist_40_balanced)

In [None]:
# Fetch a single batch from the balanced dataloader, for manual inspection
batch = next(iter(dist_40_balanced.dataloader))

In [None]:
# Display the labels of the current `batch`
batch.labels

In [None]:
# Positives per batch of the balanced run, normalized by batch_size * n_diseases
dist = dist_40_balanced

batch_size = dist.dataloader.batch_size
n_diseases = len(dist.dataloader.dataset.labels)

d = dist.n_positives
d = d.sum(axis=1) # shape: n_batches
# NOTE(review): n_positives, as built by compute_average_balance, also carries the
# extra "no finding" column; it is part of the sum above but not of n_diseases.
# Confirm this is intended.
d = d / (batch_size * n_diseases)

# Amount of batches whose normalized value is below 0.5
(d < 0.5).sum()

In [None]:
# Label balance per batch for the CXR14 kwargs, batch size 100.
# The assignment must run: `dist_100` is not defined anywhere else.
dist_100 = compute_average_balance(batch_size=100, **CXR_14_KWARGS)
# plot_distributions reads the batch size from the dataloader; its second
# positional parameter is n_rows, so the batch size must not be passed here.
plot_distributions(dist_100)

In [None]:
# Label balance per batch for the CXR14 kwargs, batch size 40, with oversample=True
bs = 40
dist_40_os = compute_average_balance(batch_size=bs, oversample=True, **CXR_14_KWARGS)
# plot_distributions reads the batch size from the dataloader; its second
# positional parameter is n_rows, so `bs` must not be passed here.
plot_distributions(dist_40_os)

In [None]:
# Same as the oversample run, additionally passing oversample_label=1
# (its meaning is defined by the dataloader factory)
bs = 40
dist = compute_average_balance(
    batch_size=bs,
    oversample=True,
    oversample_label=1,
    **CXR_14_KWARGS,
)
# plot_distributions reads the batch size from the dataloader; its second
# positional parameter is n_rows, so `bs` must not be passed here.
plot_distributions(dist)

# Report generation

In [None]:
from functools import partial

In [None]:
%run ../datasets/__init__.py
%run ../datasets/common/__init__.py
%run ../training/report_generation/flat.py

In [None]:
def getitem_ignoreimages(self, idx):
    """Returns the BatchItem of report `idx`, with a -1 placeholder as image.

    The labels are looked up in `self.labels_by_report` by the report filename,
    and the report content is its `tokens_idxs` entry.
    """
    entry = self.reports[idx]

    return BatchItem(
        labels=self.labels_by_report[entry['filename']],
        report=entry['tokens_idxs'],
        image=torch.tensor(-1),
    )

IUXRayDataset.__getitem__ = getitem_ignoreimages

### Classification-wise (i.e. labels)

In [None]:
# Keyword arguments for the IU X-ray balance computation below
IU_KWARGS = dict(
    dataset_name='iu-x-ray',
    dataset_type='train',
    # create_dataloader=partial(
    #     prepare_data_report_generation,
    #     create_dataloader_fn=create_flat_dataloader,
    # ),
)

In [None]:
# Label balance per batch for the IU X-ray kwargs, batch size 10
dist_10 = compute_average_balance(batch_size=10, **IU_KWARGS)
plot_distributions(dist_10)

### Sentences

In [None]:
import os
from collections import Counter, defaultdict
import numpy as np
import pandas as pd

In [None]:
%run ../utils/nlp.py
%run ../datasets/common/constants.py

In [None]:
# Shuffled report-generation dataloader over the IU X-ray train split, 20 items per batch
dataloader = prepare_data_report_generation(create_flat_dataloader,
                                            dataset_name='iu-x-ray',
                                            dataset_type='train',
                                            batch_size=20,
                                            shuffle=True,
                                           )
# Display the dataset size
len(dataloader.dataset)

#### Load stuff to get sentences labels

In [None]:
# Per-sentence labels, indexed by the sentence text and restricted to the
# CHEXPERT_LABELS columns
fpath = os.path.join(dataloader.dataset.reports_dir, 'sentences_with_chexpert_labels.csv')
sentence_labels_df = (
    pd.read_csv(fpath, index_col='sentences')[CHEXPERT_LABELS]
    # Binarize the values: -1 is mapped to 1 and -2 is mapped to 0
    .replace(-1, 1)
    .replace(-2, 0)
)
sentence_labels_df.head()

In [None]:
# Map each sentence (the frame index) to the list of its label values, in column order
sentence_to_labels = sentence_labels_df.transpose().to_dict(orient='list')
sentence_to_labels

In [None]:
# Reader built from the dataset vocabulary; its `idx_to_text` is used below
# (presumably to turn token indices back into text)
report_reader = ReportReader(dataloader.dataset.get_vocab())

#### Evaluate in dataloader

In [None]:
# Walks the whole dataloader once and, for each batch, collects:
#   * the number of distinct sentences found in its reports
#   * the sum of the sentence-level labels, plus a computed no-finding count

# Number of distinct sentences seen in each batch
different_sentences_per_batch = []
# Sentences that could not be handled, grouped by error kind
errors = defaultdict(list)

# One entry per batch with the accumulated labels
labels_by_batch = []

for batch in tqdm(iter(dataloader)):
    sentences_counter = Counter()
    # One slot per CheXpert label, plus a trailing slot for the computed no-finding
    batch_labels = np.zeros(len(CHEXPERT_LABELS) + 1)

    for report in batch.reports:
        for sentence in sentence_iterator(report):

            # The result is used as the key to look up sentence_to_labels
            # (presumably the sentence as text)
            sentence = report_reader.idx_to_text(sentence)

            # Count sentences
            sentences_counter[sentence] += 1

            # Count labels
            labels = sentence_to_labels.get(sentence, None)
            if labels is None:
                # Sentence is missing from the labels file: record it and skip it
                errors['no-labels-found'].append(sentence)
                continue
            # No-finding when every label but the first and the last one is 0
            # NOTE(review): presumably those two are 'No Finding' and 'Support Devices'
            # in CHEXPERT_LABELS; confirm the ordering
            no_finding = int(all(l == 0 for l in labels[1:-1]))
            labels = np.array(labels + [no_finding]) # shape: n_diseases + 1

            batch_labels += labels

    # Accumulate labels
    labels_by_batch.append(batch_labels)

    # Count sentences
    n_sentences_in_batch = len(sentences_counter)
    different_sentences_per_batch.append(n_sentences_in_batch)

labels_by_batch = np.array(labels_by_batch) # shape: n_batches, (n_diseases+1)

# Move NF to the first label
# (the computed no-finding count overwrites the first column, then the last column is dropped)
labels_by_batch[:,0] = labels_by_batch[:,-1]
labels_by_batch = np.delete(labels_by_batch, -1, 1) # shape: n_batches, n_diseases

n_errors = {k:len(v) for k, v in errors.items()}

# Summary: errors by kind, mean number of distinct sentences per batch, labels shape
n_errors, np.mean(different_sentences_per_batch), labels_by_batch.shape

In [None]:
bins = 15

n_rows = 3
n_cols = 5

# One histogram per CheXpert label, with the values of its labels_by_batch column,
# laid out on an n_rows x n_cols grid
plt.figure(figsize=(15, 10))

for i_label, label_name in enumerate(CHEXPERT_LABELS):
    grid_row, grid_col = divmod(i_label, n_cols)
    plt.subplot(n_rows, n_cols, i_label + 1)

    plt.hist(labels_by_batch[:, i_label], bins=bins)
    plt.title(label_name)

    # Axis labels only on the first column and on the last grid row
    if grid_col == 0:
        plt.ylabel('Frequency')
    if grid_row == n_rows - 1:
        plt.xlabel('Number of positives')

In [None]:
# Per batch, the sum over every column except the first one (the NF column)
labels_by_batch[:, 1:].sum(axis=1)

#### Manual inspection

In [None]:
# Re-create the shuffled IU X-ray train dataloader (20 items per batch) for manual inspection
dataloader = prepare_data_report_generation(create_flat_dataloader,
                                            dataset_name='iu-x-ray',
                                            dataset_type='train',
                                            batch_size=20,
                                            shuffle=True,
                                           )
# Display the dataset size
len(dataloader.dataset)

In [None]:
# Keep a single iterator, so batches can be pulled from it one at a time with next(d)
# NOTE: this rebinds `d`, which earlier cells use for the normalized positives
d = iter(dataloader)

In [None]:
# Pull the next batch from the iterator and pass each of its reports through
# report_reader.idx_to_text, to read them manually
reports = [
    report_reader.idx_to_text(r)
    for r in next(d).reports
]
reports

In [None]:
# Number of distinct sentences per batch, plotted against the batch index
t = range(len(different_sentences_per_batch))
plt.plot(t, different_sentences_per_batch)

plt.xlabel('Batch i')
plt.ylabel('Different sentences')

TODO:
* For each sentence, look up its labels in `sentences_with_chexpert_labels.csv`
* Plot the labels seen across the batches (i.e. labels != NF vs. batch_i)

* The same can be done per report (use `dataset.labels_by_report`)

In [None]:
# Amount of positives per label in each batch of the current `dataloader`,
# plus an extra "no finding" count (samples of the batch without any positive label)
n_batches = len(dataloader)
n_labels = len(dataloader.dataset.labels)

positives_by_label = []

# The image sanity-check below only needs to run for the first batch
checked_for_monkeypatch = False

# Iterate the dataloader itself (not `iter(...)`), so tqdm can read its length
for batch in tqdm(dataloader):
    labels = batch.labels.sum(dim=0) # shape: n_labels

    if not checked_for_monkeypatch:
        # Images are expected to be the -1 placeholder (i.e. not actually loaded);
        # anything else suggests the dataset was not monkey patched
        if not (batch.image == -1).all().item():
            print(f'Warning: dataset may be loading images, images={batch.image}')
        checked_for_monkeypatch = True

    # Samples whose labels sum to 0 have no positive label
    no_finding_count = batch.labels.sum(dim=1) # shape: batch_size
    no_finding_count = (no_finding_count == 0).sum().unsqueeze(0) # shape: 1
    no_finding_count = no_finding_count.type(labels.dtype)

    labels = torch.cat((labels, no_finding_count), dim=0) # shape: n_labels+1

    positives_by_label.append(labels)

positives_by_label = torch.stack(positives_by_label, dim=0)
# shape: n_batches, n_labels+1

print('Amount of positives by label, in average: ', positives_by_label.float().mean(dim=0))
print('Batch size: ', dataloader.batch_size)
