# Imports

In [None]:
import pandas as pd
import os
from collections import defaultdict

In [None]:
%run ../utils/__init__.py
config_logging(logging.INFO)

In [None]:
%run ../utils/files.py

# Debug light-chexpert cache

## Load Cache file

In [None]:
%run ../metrics/report_generation/labeler_correctness/cache.py

In [None]:
FPATH = os.path.join(LABELER_CACHE_DIR, 'sentences_chexpert.csv')

In [None]:
df = pd.read_csv(FPATH)
print(len(df))
df.head()

In [None]:
sentences = list(df['sentences'])
len(sentences), len(set(sentences))

### Check empty sentences

In [None]:
# Number of sentences that have no tokens at all (empty or whitespace-only).
sum(not text.split() for text in sentences)

### Remove trailing dot

In [None]:
# Count how many sentences become identical once a trailing '.' token is
# dropped.
# repeated_sentences maps a reduced sentence to its number of occurrences;
# the default of 1 stands for the first occurrence, which is only recorded
# in reduced_sentences.
repeated_sentences = defaultdict(lambda: 1)
reduced_sentences = set()

for sentence in sentences:
    tokens = sentence.split()
    # Empty / whitespace-only sentences split into an empty list; indexing
    # [-1] on it would raise IndexError, so check for tokens first.
    if tokens and tokens[-1] == '.':
        tokens = tokens[:-1]
    reduced = ' '.join(tokens)
    if reduced in reduced_sentences:
        repeated_sentences[reduced] += 1

    reduced_sentences.add(reduced)

# (unique sentences after the reduction, total sentences)
len(reduced_sentences), len(sentences)

In [None]:
repeated_sentences

### Remove repeated tokens

In [None]:
def remove_duplicated_tokens(tokens):
    """Collapse runs of consecutive identical tokens into a single token.

    Only adjacent repetitions are removed: a token that shows up again
    later, after a different token, is kept. Returns a new list; `tokens`
    is not modified.
    """
    collapsed = []
    for position, current in enumerate(tokens):
        # Skip a token that merely repeats its immediate predecessor.
        if position > 0 and current == tokens[position - 1]:
            continue
        collapsed.append(current)
    return collapsed

In [None]:
remove_duplicated_tokens(['there', 'there', 'is', 'stable', 'there'])

In [None]:
# Count how many sentences become identical after dropping the filler
# tokens listed below and collapsing consecutive repeated tokens.
repeated_sentences = defaultdict(lambda: 1)
reduced_sentences = set()

for sentence in sentences:
    # Filter the filler tokens and hand the result straight to the helper.
    sentence = remove_duplicated_tokens([
        token
        for token in sentence.split()
        if token not in ('END', ',', '.', 'xxxx')
    ])
    sentence = ' '.join(sentence)

    if sentence in reduced_sentences:
        # Already seen: the default of 1 stands for the first occurrence.
        repeated_sentences[sentence] += 1
    else:
        reduced_sentences.add(sentence)

# (unique sentences after the reduction, total sentences)
len(reduced_sentences), len(sentences)

In [None]:
# (sentence, count) pairs, most repeated first.
sorted(repeated_sentences.items(), key=lambda pair: pair[1], reverse=True)

## Clean sentences

### Apply cleaning

In [None]:
%run ../metrics/report_generation/labeler_correctness/light_labeler.py

In [None]:
clean_sentence('there - there / &lt  asdf UNK'.split())

In [None]:
# Pass every cached sentence through clean_sentence as a list of tokens and
# join whatever it returns back into a single space-separated string.
clean_sentences = []
for raw_sentence in df['sentences']:
    cleaned_tokens = clean_sentence(raw_sentence.split())
    clean_sentences.append(' '.join(cleaned_tokens))

# (unique cleaned sentences, cleaned sentences, rows in df)
len(set(clean_sentences)), len(clean_sentences), len(df)

In [None]:
df['clean_sentences'] = clean_sentences
df.head()

### Remove duplicated sentences

In [None]:
cols = [c for c in df.columns if 'sentence' not in c]
len(cols)

In [None]:
unique_df = df.groupby('clean_sentences').first()
unique_df.head()

In [None]:
len(unique_df)

In [None]:
# Turn the group key ('clean_sentences') back into a regular column, discard
# the raw 'sentences' column and let the cleaned text take over its name.
unique_df = (
    unique_df
    .reset_index(drop=False)
    .drop(columns=['sentences'])
    .rename(columns={'clean_sentences': 'sentences'})
)
unique_df.head()

In [None]:
# NOTE: FPATH is the same file df was read from at the top of this section,
# so this overwrites the cache in place with the cleaned, de-duplicated
# table. No backup is made in this notebook.
unique_df.to_csv(FPATH, index=False)

# Debug holistic chexpert-labeler

In [None]:
%run ../metrics/report_generation/chexpert.py

In [None]:
# Fill values forwarded below as fill_empty / fill_uncertain.
FILL_EMPTY = -2
FILL_UNCERTAIN = -1
# Dataset whose ground-truth table is loaded; swap the active line to switch.
# dataset_name = 'iu-x-ray'
dataset_name = 'mimic-cxr'

# NOTE(review): _load_gt_df is presumably defined by the chexpert.py module
# %run in the previous cell — confirm. The result needs a 'Reports' column,
# which is read further down in this notebook.
gt_with_labels = _load_gt_df(dataset_name,
                             fill_uncertain=FILL_UNCERTAIN, fill_empty=FILL_EMPTY)
gt_with_labels.head(2)

In [None]:
%run ../metrics/report_generation/chexpert.py

In [None]:
# Ground-truth report texts, used as a source of sample inputs.
_gt_reports = list(gt_with_labels['Reports'])
some_report = _gt_reports[3]

# Mix of hand-written sentences and ground-truth reports to label.
# some_report appears three times on purpose-looking positions.
# NOTE(review): presumably the repetition exercises duplicate handling in
# the labeler built in the next cell — confirm.
reports = [
    'Cardiomegaly .',
    _gt_reports[10],
    some_report,
    some_report,
    'no pneumothorax .',
    'the cardiac silhouette and mediastinum size are within normal limits . there is no pulmonary edema . there is no focal consolidation . there are no xxxx of a pleural effusion . there is no evidence of pneumothorax .',
    some_report,
    'heart size is enlarged',
]

In [None]:
# Build the labeler as a stack of wrappers: each wrapper receives the
# previously built labeler as its first argument, so the last one assigned
# is the outermost and is the object called on the reports.
# NOTE(review): what each wrapper does (cache lookup against gt_with_labels,
# batching, de-duplication) is inferred from the class names only — confirm
# in chexpert.py.
labeler = ChexpertLabeler(fill_uncertain=FILL_UNCERTAIN, fill_empty=FILL_EMPTY,
                          caller_id='debugging')
labeler = CacheLookupLabeler(labeler, gt_with_labels)
labeler = NBatchesLabeler(labeler)
labeler = AvoidDuplicatedLabeler(labeler)

In [None]:
%%time

labels = labeler(reports)
labels.shape, labels, 

## Check example: compare the outputs of two runs

In [None]:
%run ../utils/files.py
%run ../metrics/report_generation/chexpert.py

In [None]:
run1, run2 = '0604_165400', '0604_165401'
# run1, run2 = '0609_194751', '0609_194752'

run_original = RunId(run1, False, 'rg')
run_copy = RunId(run2, False, 'rg')

In [None]:
def _load_file(run_id, fname):
    """Read a results CSV of a run and return it in a fixed row order.

    The file `fname` is looked up inside `get_results_folder(run_id)`. Rows
    are sorted by 'ground_truth' (and additionally by 'image_fname' when
    that column exists) and the index is reset to 0..n-1, so that tables
    loaded from different runs line up row by row.
    """
    table = pd.read_csv(os.path.join(get_results_folder(run_id), fname))
    has_image_fname = 'image_fname' in table.columns
    sort_keys = ['ground_truth', 'image_fname'] if has_image_fname else ['ground_truth']
    return table.sort_values(sort_keys).reset_index(drop=True)

def load_unlabeled(run_id):
    """Load the 'outputs-free.csv' table of a run, sorted by _load_file."""
    unlabeled_df = _load_file(run_id, fname='outputs-free.csv')
    return unlabeled_df

def load_labeled(run_id):
    """Load the 'outputs-labeled-free.csv' table of a run plus its labels.

    Returns a tuple (df, gt_labels, gen_labels): the table as sorted by
    _load_file, and two int8 numpy arrays built from the columns selected
    by labels_with_suffix('gt') and labels_with_suffix('gen').
    """
    labeled_df = _load_file(run_id, 'outputs-labeled-free.csv')

    # NOTE(review): `np` is not imported in this notebook's import cell;
    # presumably one of the %run modules brings it into scope — confirm.
    def _as_int8(suffix):
        return labeled_df[labels_with_suffix(suffix)].to_numpy().astype(np.int8)

    return labeled_df, _as_int8('gt'), _as_int8('gen')

In [None]:
df1 = load_unlabeled(run_original)
df2 = load_unlabeled(run_copy)
(df1 == df2).all()

In [None]:
dfl1, gt1, gen1 = load_labeled(run_original)
dfl2, gt2, gen2 = load_labeled(run_copy)
len(dfl1) == len(dfl2)

In [None]:
(dfl1['generated'] == df1['generated']).all(), (dfl2['generated'] == df2['generated']).all()

In [None]:
(dfl1 == dfl2).all()

In [None]:
(gt1 == gt2).all()

In [None]:
(gen1 == gen2).all()

In [None]:
different_samples = (gen1 != gen2).any(axis=1)

In [None]:
d = dfl1[different_samples]['generated']
d

In [None]:
from collections import Counter

In [None]:
Counter(Counter(d).values())

In [None]:
gen2[different_samples]

In [None]:
gen1[(gen1 != gen2).any(axis=1)]

In [None]:
df1[(df1['generated'] != df2['generated'])]

# Debug MIRQI

In [None]:
run_id = RunId('0121_210044', False, 'rg')

In [None]:
# Results file to inspect; swap the active line to load the MIRQI-annotated
# outputs instead of the plain ones.
# name = 'outputs-mirqi-free.csv'
name = 'outputs-free.csv'

# The file is read from the results folder of run_id (set in the previous cell).
fpath = os.path.join(get_results_folder(run_id), name)
df = pd.read_csv(fpath)
df.head(2)

In [None]:
# list(df['attributes-gen'])

In [None]:
%run ../metrics/report_generation/mirqi.py

In [None]:
from collections import Counter

In [None]:
Counter(df['dataset_type'])

In [None]:
%%time

attributes = _call_mirqi_to_reports(df['ground_truth'])
attributes.shape

In [None]:
len(df), len(df['filename'])

In [None]:
df.head()

In [None]:
attributes

In [None]:
df_small = df[:20]
# df_small = df_small[['filename', 'epoch', 'dataset_type', ]]
df_small.head()

In [None]:
df2 = apply_mirqi_to_df(df_small, run_id.short_name, batches=2)

In [None]:
df2.head()

In [None]:
%run -n ../eval_report_generation_mirqi.py

In [None]:
run_id = RunId('0428_133018', True, 'rg')

In [None]:
%%time

evaluate_run(run_id,
             override=False,
             max_samples=20,
             free=True)

In [None]:
# Both lists are required: len(gt) on the last line here and
# MIRQI(gt, gen) / MIRQI_v2(gt, gen) in the following cell. With the gt
# assignment commented out the cell fails with NameError on a fresh kernel.
# NOTE(review): assumes df2 carries an 'attributes-gt' column next to
# 'attributes-gen' — confirm against apply_mirqi_to_df.
gt = _attributes_to_list(df2['attributes-gt'])
gen = _attributes_to_list(df2['attributes-gen'])
len(gt), len(gen)

In [None]:
scores_v1 = MIRQI(gt, gen)
scores_v2 = MIRQI_v2(gt, gen)
scores_v2

In [None]:
# Add the entries of scores_v1 and then of scores_v2 as columns of a copy of
# df2; on a name clash the scores_v2 value wins because it is assigned last.
df3 = df2.assign(**scores_v1).assign(**scores_v2)
df3.head()

In [None]:
metrics = _calculate_metrics_dict(df3)
metrics

In [None]:
reports = [
    'no pneumothorax .',
    'cardiomegaly .',
    'no active disease .',
]
labels = apply_mirqi_to_reports(reports)
labels.shape