## Imports

In [None]:
import torch
import os
import json
import matplotlib.pyplot as plt

In [None]:
import matplotlib
matplotlib.rcParams['figure.facecolor'] = 'white'

In [None]:
%run ../iu_xray.py

In [None]:
REPORTS_DIR = os.path.join(DATASET_DIR, 'reports')

## Load data

In [None]:
reports_fname = os.path.join(REPORTS_DIR, 'reports.json')
with open(reports_fname, 'r') as f:
    reports_as_dict = json.load(f)
    reports = list(reports_as_dict.values())
len(reports_as_dict), len(reports)

In [None]:
info_fname = os.path.join(DATASET_DIR, 'info.json')
with open(info_fname, 'r') as f:
    info = json.load(f)
len(info)

In [None]:
with open(info_fname, 'w') as f:
    json.dump(info, f)

In [None]:
info['marks']['rotated_left']

### Rotate images

NOTE: the images are already rotated!!

In [None]:
rotations = [
    ('left', -90),
    ('right', 90),
    ('bottom', 180),
]

In [None]:
# For every rotation mark, open each listed image and rotate it by the
# matching angle. The rotated image is only bound to `img`: the save call is
# commented out, so the files on disk are not modified by this cell.
# NOTE(review): `rotate` is called without `expand=True` -- confirm the
# resulting canvas size is what is wanted for non-square images.
for key, degrees in rotations:
    images_key = f'rotated_{key}'
    for image_name in info['marks'][images_key]:
        filepath = os.path.join(DATASET_DIR, 'images', image_name)
        img = Image.open(filepath).rotate(degrees)
        # img.save(filepath)

## Clean text

### Tokenize and clean

In [None]:
import re
from collections import defaultdict, Counter

In [None]:
# text = 'The previously<BR>described XXXX deformity'
text = """1. low lung volumes
2. exam limited on lateral: view by superimposed soft tissue and bony structures of the arm
3. lungs appear grossly clear . no evidence of pneumonia ."""
re.sub(r'< ?br ?\\?>', ' ', text.lower())

In [None]:
def remove_consecutive_dots(tokens):
    """Collapse every run of consecutive '.' tokens into a single '.'.

    All other tokens are kept, in their original order. Returns a new list.
    """
    deduped = []
    for tok in tokens:
        # Drop a dot that directly follows a dot that was already kept
        if tok == '.' and deduped and deduped[-1] == '.':
            continue
        deduped.append(tok)
    return deduped

In [None]:
remove_consecutive_dots(['.', '.', 'asdf', 'hello', '.', 'abc', '.', '.', 'c', '.'])

In [None]:
NUMBER_TOKEN = 'NUMBER'

def text_to_tokens(text):
    """Normalize a raw report text and split it into tokens.

    The text is lower-cased, markup remnants are removed, punctuation is
    separated from the surrounding words, numbers are replaced by
    ``NUMBER_TOKEN`` and consecutive dots are collapsed. A leading '.' is
    dropped and a trailing '.' is appended when missing.

    Args:
        text: report text as a string.

    Returns:
        List of string tokens. The list is empty when nothing is left after
        cleaning (e.g. for '' or '.'); otherwise it ends with '.'.
    """
    text = text.lower()
    # Remove escaped angle-bracket entities, optionally wrapped in [ ]
    text = re.sub(r'(\[)?&amp;[gl]t;(\])?', ' ', text)

    # PM or AM token
    text = re.sub(r'\s(a|p)\.?m\.?\s', r' \1m ', text)

    # Replace colons with a sentence separator
    text = re.sub(r':', ' . ', text)

    # Replace a semicolon, or a run of commas, with a single comma
    text = re.sub(r'(;|,+)', r',', text)

    # Replace decimal numbers and fractions (e.g. 1.5, 3/4) with the number token
    text = re.sub(r'\d+(\.|/)\d+', NUMBER_TOKEN, text)

    # Replace break line tag
    text = re.sub(r'< ?br ?\\?>', ' ', text)
    text = re.sub(r'[\[\]<>]', '', text) # Remove brackets [] <>
    text = re.sub(r'(\(|\))', r' \1 ', text) # Give space to parenthesis

    text = re.sub(r'\.+', r'.', text) # Replace multiple dots with one dot

    # Numbers used as enumerators, like "1. bla bla, 2. bla bla"
    text = re.sub(r'(\W|\A)\d+\.[^\d]', r' . ', text)

    # Add space between text and dot/comma/slash
    text = re.sub(r'([a-zA-Z0-9])(\.|,|/)', r'\1 \2', text)
    text = re.sub(r'(\.|,|/)([a-zA-Z0-9])', r'\1 \2', text)

    # Other numbers, with an optional suffix such as "st", "%" or "mm"
    text = re.sub(r'(\W|\A)\d+(a|st|nd|th|rd|\%|mm|xxxx)?', r'\1 {}'.format(NUMBER_TOKEN), text)

    # Remove apostrophe
    text = re.sub(r'(\w+)\'[st]?', r'\1 ', text) # XXXX't is a typo

    tokens = remove_consecutive_dots(text.split())
    if tokens and tokens[0] == '.':
        tokens = tokens[1:]

    # Nothing left to tokenize: return early instead of indexing an empty list
    if not tokens:
        return []

    if tokens[-1] != '.':
        tokens.append('.')
    return tokens

In [None]:
text_to_tokens("3 p.m. message xxxx' l10 l20. there's /11 3. mild clavicle: bilateral")

In [None]:
# Tokens dropped from every cleaned report
IGNORE_TOKENS = {'p.m.', 'pm', 'am'}
# Token -> number of reports containing it (each report counts once per token)
token_appearances = Counter()
# Error kind -> filenames of the affected reports
errors = defaultdict(list)

cleaned_reports_as_dict = {}

for report in reports:
    filename = report['filename']
    findings = report['findings']
    impression = report['impression']

    # Reports without images are discarded
    if len(report['images']) == 0:
        errors['no-images'].append(filename)
        continue

    # Use the findings as text, falling back to the impression
    if findings is not None:
        if impression is None:
            errors['impression-none'].append(filename)
        text = findings
    elif impression is not None:
        errors['findings-none'].append(filename)
        text = impression
    else:
        errors['text-none'].append(filename)
        continue

    # Clean and tokenize text
    tokens = [tok for tok in text_to_tokens(text) if tok not in IGNORE_TOKENS]
    token_appearances.update(dict.fromkeys(tokens, 1))

    # Copy of the report with the cleaned text attached
    cleaned_reports_as_dict[filename] = {**report, 'clean_text': ' '.join(tokens)}

print({kind: len(names) for kind, names in errors.items()})
print('Different tokens: ', len(token_appearances))
print('Tokens with more than 1 appearance: ',
      sum(1 for count in token_appearances.values() if count > 1))
len(cleaned_reports_as_dict), len(reports)

In [None]:
sorted([(k, v) for k, v in token_appearances.items() if re.search(':', k)],
       key=lambda x:x[1], reverse=False)

### Review errors

In [None]:
reports_as_dict[errors['no-images'][0]]

### Review specific tokens

In [None]:
reports[0]

In [None]:
# Collect the cleaned reports whose clean text matches `target`, which is
# used as a regular expression. Other patterns tried before:
#   r'\W\d\b'
#   r'xxxx opacity in the left midlung'
target = ':'
found = []

for report in cleaned_reports_as_dict.values():
    name = report['filename']
    findings = report['findings']
    impression = report['impression']

    clean = report.get('clean_text', None)
    if clean is None:
        # No cleaned text for this report: skip it explicitly instead of
        # passing None to re.search
        continue

    if re.search(target, clean):
        found.append((name, findings, impression, clean))

print(len(found))
found

### Save cleaned reports

NOTE: Save after image info below

In [None]:
fname = os.path.join(REPORTS_DIR, 'reports.clean.v2.json')
with open(fname, 'w') as f:
    json.dump(cleaned_reports_as_dict, f)

## Add side to image info (in cleaned reports)

TODO: move this to a script!!!

`side` can be one of (`frontal`, `lateral-left`, `lateral-right`)

In [None]:
REPORTS_JSON_VERSION = 'reports.clean.v2.json'
fname = os.path.join(REPORTS_DIR, REPORTS_JSON_VERSION)
with open(fname, 'r') as f:
    clean_reports = json.load(f)
len(clean_reports)

In [None]:
wrong_images = set(info['marks']['wrong'])
broken_images = set(info['marks']['broken'])

In [None]:
# Annotate every image entry of every report with its side and with whether
# it is marked as wrong or broken. Entries are updated in place.
for report_dict in clean_reports.values():
    for image_info in report_dict['images']:
        image_name = f"{image_info['id']}.png"
        image_info.update(
            side=info['classification'][image_name],
            wrong=image_name in wrong_images,
            broken=image_name in broken_images,
        )

len(clean_reports)

In [None]:
fname = os.path.join(REPORTS_DIR, REPORTS_JSON_VERSION)
with open(fname, 'w') as f:
    json.dump(clean_reports, f)

## Save common vocab

In [None]:
%run ../vocab/__init__.py
%run ../iu_xray.py

In [None]:
train_dataset = IUXRayDataset(dataset_type='train', recompute_vocab=True)
len(train_dataset)

In [None]:
vocab = train_dataset.get_vocab()
len(vocab)

In [None]:
prev_vocab = load_vocab('iu_xray')
len(prev_vocab), len(vocab)

In [None]:
idx = 1220
a = [(k, v) for k, v in vocab.items() if v == idx][0]
b = [(k, v) for k, v in prev_vocab.items() if v == idx][0]
a, b

In [None]:
save_vocab('iu_xray', vocab)

## Calculate image normalization

In [None]:
import torch
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

In [None]:
%run ../../utils/images.py

In [None]:
image_folder = os.path.join(DATASET_DIR, 'images')

In [None]:
dataset = IUXRayDataset('train')
len(dataset)

In [None]:
train_images = [
    i if i.endswith('.png') else f'{i}.png'
    for i in [r['image_name'] for r in dataset.reports]
]
len(train_images)

In [None]:
mean, std = compute_mean_std(ImageFolderIterator(image_folder, train_images), show=True)
mean, std

### Plot average image

In [None]:
from torchvision import transforms

In [None]:
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

In [None]:
# Sum the resized training images and divide by their count to obtain the
# pixel-wise average image.
# `train_images` is the list built earlier in this section; the previous name
# `image_names` is not assigned anywhere in the visible notebook.
summed = torch.zeros(3, 256, 256)

for image_name in tqdm(train_images):
    fpath = os.path.join(image_folder, image_name)
    # The context manager releases the file handle once the tensor is built
    with Image.open(fpath) as raw_image:
        image = transform(raw_image.convert('RGB'))
    summed += image

summed /= len(train_images)

In [None]:
average_image = summed.mean(dim=0)
average_image.size()

In [None]:
plt.imshow(average_image, cmap='gray')

## Test `IUXRayDataset` class

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from torchvision import transforms

In [None]:
%run ../iu_xray.py

In [None]:
dataset = IUXRayDataset(dataset_type='all', masks=True, frontal_only=True)
len(dataset), len(dataset.word_to_idx)

In [None]:
item = dataset[0]
image = item.image
labels = item.labels
report = item.report
image.size(), labels.size(), len(report)

In [None]:
item.masks.size()

In [None]:
# Draw one panel per organ mask, side by side in a single row. The number of
# columns follows JSRT_ORGANS instead of being hard-coded.
plt.figure(figsize=(8, 5))

n_organs = len(JSRT_ORGANS)
for index, organ in enumerate(JSRT_ORGANS):
    plt.subplot(1, n_organs, index + 1)
    plt.imshow(item.masks[index])
    plt.title(organ)
    plt.axis('off')

In [None]:
plt.imshow(image.permute(1, 2, 0))

In [None]:
dataset.get_labels_presence_for('Cardiomegaly')

### Review different image shapes

In [None]:
# Collect the distinct image shapes found in the dataset
shapes = set()

for idx in range(len(dataset)):
    # Read the image by attribute, like the other cells do: items expose at
    # least image, labels, report and masks, so two-name unpacking
    # (`image, _ = item`) does not fit them.
    image = dataset[idx].image
    shapes.add(tuple(image.shape))

len(shapes)

In [None]:
shapes

### Load single images

In [None]:
fname = DATASET_DIR + '/images/CXR5_IM-2117-1003002.png'
img = Image.open(fname)
img_tensor = transforms.ToTensor()(img)
img.size, img_tensor.size()

## Inspect tags

In [None]:
from collections import defaultdict

In [None]:
# Count the occurrences of each manual tag across all reports
counter = defaultdict(int)
for report in reports:
    for tag in report['tags_manual']:
        counter[tag] += 1

In [None]:
len(reports)

In [None]:
sorted(((k, v) for k, v in counter.items()), key=lambda x:x[1], reverse=True)

## Get sample reports

For the LatinX in AI workshop

In [None]:
import matplotlib.pyplot as plt

In [None]:
import numpy as np
from pycocoevalcap.bleu import bleu_scorer
from pycocoevalcap.rouge import rouge

In [None]:
%run ../common.py
%run ../iu_xray.py
%run ../../utils/nlp.py
%run ../../utils/__init__.py

In [None]:
CONSTANT_REPORT = """the heart is normal in size . the mediastinum is unremarkable . 
the lungs are clear .
there is no pneumothorax or pleural effusion . no focal airspace disease .
no pleural effusion or pneumothorax ."""

In [None]:
dataset = IUXRayDataset(dataset_type='all')
report_reader = ReportReader(dataset.get_vocab())
len(dataset)

In [None]:
# Show the image of the dataset item at GT_IDX and print its report decoded
# to text by `report_reader`.
# NOTE: GT_IDX is assigned in a later cell of this notebook; run that cell
# first, otherwise this one fails with a NameError.
idx = GT_IDX
item = dataset[idx]
# Channels are moved to the last axis for plotting
# (presumably `arr_to_range` rescales the values for display -- confirm)
image = arr_to_range(item.image.permute(1, 2, 0))
report_base = report_reader.idx_to_text(item.report)
plt.imshow(image)
plt.axis('off')
print(report_base)

In [None]:
GT_IDX = 7289

In [None]:
# Phrases that must all appear in a report (other candidates are commented out)
target = [
    'the cardiac silhouette is enlarged',
    # 'the lungs are hyper',
    # 'the heart is',
]
# Phrases that must not appear in a report
not_target = [
    # 'the lungs are clear',
#     'the mediastinum is unremarkable',
#     'the mediastinum is stable',
#     'the mediastinum is normal',
#     'the mediastinum is within normal limits',
]
found = []
found_names = set()
for idx, entry in enumerate(dataset.reports):
    filename = entry['filename']
    text = report_reader.idx_to_text(entry['tokens_idxs'])
    if any(phrase not in text for phrase in target):
        continue
    if any(phrase in text for phrase in not_target):
        continue
    # Keep only the first matching entry of each report file
    if filename in found_names:
        continue
    found_names.add(filename)
    found.append((idx, text))
len(found)

In [None]:
found[5]

In [None]:
gen = 'the heart is enlarged. the mediastinum is unremarkable . the lungs are hyperinflated with mildly coarsened interstitial markings . '
# the lungs are hyperexpanded
# the lungs are hyperinflated with mildly coarsened interstitial markings
# the lungs are hyperinflated with biapical pleural-parenchymal scarring and upward retraction of the xxxx

In [None]:
def measure_bleu_rouge(gen, gt):
    """Print BLEU-1..4, their mean, and ROUGE-L for one generated report.

    Args:
        gen: generated report text.
        gt: ground-truth report text, used as the only reference.
    """
    bleu = bleu_scorer.BleuScorer(n=4)
    bleu += (gen, [gt])
    # compute_score returns a pair; only its first element is reported
    bleu_1_4 = bleu.compute_score()[0]

    rouge_score = rouge.Rouge().calc_score([gen], [gt])

    print('BLEU 1-4: ', bleu_1_4)
    print('BLEU: ', np.mean(bleu_1_4))
    print('ROUGE-L: ', rouge_score)

In [None]:
report_1 = """the heart is normal in size . the mediastinum is unremarkable . 
the lungs are clear ."""
report_2 = """the heart is normal . the mediastinum is otherwise unremarkable . 
lungs are both clear ."""
measure_bleu_rouge(report_1, report_2)

In [None]:
report = report_reader.idx_to_text(dataset[GT_IDX].report)
report

In [None]:
gt = """the cardiac silhouette is enlarged .
the lungs are hyperexpanded with flattening of the bilateral hemidiaphragms .
no pneumothorax or pleural effusion ."""
# the lungs are hyperinflated with mildly coarsened interstitial markings .
# with flattening of the bilateral hemidiaphragms 

In [None]:
gen = """the cardiac silhouette is normal in size .
the lungs are clear .
no pneumothorax or pleural effusion ."""

In [None]:
measure_bleu_rouge(gen, gt)

In [None]:
gt = "the cardiac silhouette is enlarged . the lungs are hyperexpanded with flattening of the bilateral hemidiaphragms . no pneumothorax or pleural effusion ."
gen = "the cardiac silhouette is normal in size and configuration . the lungs are clear . no pneumothorax or pleural effusion ."
measure_bleu_rouge(gen, gt)

In [None]:
measure_bleu_rouge(gen, gt)

## Check no-findings vs labels==0

In [None]:
from collections import defaultdict

In [None]:
chexpert_path = os.path.join(REPORTS_DIR, 'reports_with_chexpert_labels.csv')
mirqi_path = os.path.join(REPORTS_DIR, 'reports_with_mirqi_labels.csv')

In [None]:
chexpert_df = pd.read_csv(chexpert_path, index_col=0)
chexpert_df.replace(-1, 1, inplace=True)
chexpert_df.replace(-2, 0, inplace=True)
chexpert_df.head()

In [None]:
mirqi_df = pd.read_csv(mirqi_path, index_col=0)
mirqi_df.drop(columns=['attributes-gen', 'MIRQI-r', 'MIRQI-p', 'MIRQI-f'], inplace=True)
mirqi_df.rename(columns={'attributes-gt': 'attributes'}, inplace=True)
mirqi_df.replace(-1, 1, inplace=True)
mirqi_df.replace(-2, 0, inplace=True)
mirqi_df.head()

In [None]:
base_columns = set(['filename', 'Reports', 'attributes'])
MIRQI_LABELS = [c for c in mirqi_df.columns if c not in base_columns]

In [None]:
len(chexpert_df), len(mirqi_df)

In [None]:
df = chexpert_df.merge(mirqi_df, on='filename', suffixes=['_chx', '_mirqi'])
print(len(df))
df.head()

In [None]:
# Group the reports by how their 'No Finding' label relates to the other labels
reports_by_condition = defaultdict(set)

for index, row in chexpert_df.iterrows():
    entry = (index, row['filename'], row['Reports'])
    labels = row[CHEXPERT_LABELS]

    # The slice drops the first and last label (no-findings and
    # support-devices, assuming that is the order of CHEXPERT_LABELS)
    others_all_zero = all(value == 0 for value in labels[1:-1])

    if labels['No Finding'] == 1:
        reports_by_condition['no-findings-1'].add(entry)
        if not others_all_zero:
            # Marked as no-finding while another label is non-zero
            reports_by_condition['inconsistent'].add(entry)
    elif others_all_zero:
        reports_by_condition['no-findings-absent'].add(entry)

    # Reports where no label at all is equal to 1
    if not any(value == 1 for value in labels):
        reports_by_condition['no-1s'].add(entry)

[(k, len(v)) for k, v in reports_by_condition.items()]

In [None]:
l = list(reports_by_condition['no-findings-absent'])
l[:5]

In [None]:
mirqi_df.loc[mirqi_df['filename'] == '256.xml'][MIRQI_LABELS]

In [None]:
l = list(reports_by_condition['no-1s'])
l[:10]

In [None]:
l = list(reports_by_condition['no-findings-1'])
l[:10]