## Imports

In [None]:
%env CUDA_VISIBLE_DEVICES=1

In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
import matplotlib
matplotlib.rcParams['figure.facecolor'] = 'white'

In [None]:
# Use the GPU when CUDA can see one; otherwise fall back to the CPU so the
# notebook still runs on machines without a visible CUDA device.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

In [None]:
%run ../utils/__init__.py

## Load data

In [None]:
%run ../datasets/__init__.py

### Load segmentation dataset (JSRT)

In [None]:
# Build the dataloader for the JSRT test split and show how many samples
# its dataset holds.
kwargs = dict(
    dataset_name='jsrt',
    dataset_type='test',
    norm_by_sample=True,
    image_size=(1024, 1024),
)
jsrt_dataloader = prepare_data_segmentation(**kwargs)
len(jsrt_dataloader.dataset)

In [None]:
# Name of the dataset class behind the JSRT dataloader.
type(jsrt_dataloader.dataset).__name__

In [None]:
# Fetch one JSRT sample and show the size of its image tensor.
item = jsrt_dataloader.dataset[1]
item.image.shape

In [None]:
# Display the first slice of the sample's image (presumably its only
# channel -- TODO confirm) in grayscale, with the axes hidden.
plt.imshow(item.image[0], cmap='gray')
plt.axis('off')

### Load IU x-ray

In [None]:
# Build the dataloader for the IU x-ray train split and show how many
# samples its dataset holds.
kwargs = dict(
    dataset_name='iu-x-ray',
    dataset_type='train',
    batch_size=10,
    image_format='L',
    frontal_only=True,
    norm_by_sample=True,
    image_size=(1024, 1024),
)

iu_dataloader = prepare_data_classification(**kwargs)
len(iu_dataloader.dataset)

In [None]:
# Fetch one IU x-ray sample and show the size of its image tensor.
item = iu_dataloader.dataset[1]
item.image.shape

In [None]:
# Display the first slice of the sample's image (presumably its only
# channel -- TODO confirm) in grayscale, with the axes hidden.
plt.imshow(item.image[0], cmap='gray')
plt.axis('off')

### Load CXR-14

In [None]:
# Build the dataloader for the CXR-14 train split and show how many
# samples its dataset holds.
kwargs = dict(
    dataset_name='cxr14',
    dataset_type='train',
    batch_size=10,
    image_format='L',
    norm_by_sample=True,
    image_size=(1024, 1024),
)

cxr14_dataloader = prepare_data_classification(**kwargs)
len(cxr14_dataloader.dataset)

### Load Covid-UC

In [None]:
# Build the dataloader for the Covid-UC train split and show how many
# samples its dataset holds.
kwargs = dict(
    dataset_name='covid-uc',
    dataset_type='train',
    batch_size=10,
    image_format='L',
    frontal_only=True,
    norm_by_sample=True,
    image_size=(1024, 1024),
)

coviduc_dataloader = prepare_data_classification(**kwargs)
len(coviduc_dataloader.dataset)

## Load model

In [None]:
%run ../models/checkpoint/__init__.py

In [None]:
# Names of the training runs whose segmentation models are loaded and
# compared in this notebook. Commented-out entries are older runs that are
# currently excluded from the comparison.
run_names = [
#     '1105_180035_jsrt_scan_lr0.0001_normD_size1024',
#     '1106_092037_jsrt_scan_lr0.0001_normD_size1024',
    '1106_180455_jsrt_scan_lr0.0005_normS_size1024_wce1-4-3-3_sch-iou-p5-f0.2',
    '1202_015907_jsrt_scan_lr0.0005_normS_size1024_wce1-6-3-3_aug10_sch-iou-p5-f0.5',
]
# Forwarded as `debug=` to load_compiled_model_segmentation.
debug = False

# Disabled single-run alternatives (these ones set `run_name` and `debug`
# instead of `run_names`):
# run_name = '1106_165046_jsrt_scan_lr0.0001_normS_size1024_wce1-4-3-3'
# run_name = '1106_174749_jsrt_scan_lr0.0001_normS_size1024_wce1-4-3-3'
# run_name = '1106_175002_jsrt_scan_lr0.0001_normS_size1024_wce1-4-3-3'
# debug = True

In [None]:
# Load one compiled segmentation model per entry of `run_names`, keeping
# the same order, and place each of them on DEVICE.
compiled_models = [
    load_compiled_model_segmentation(name, debug=debug, device=DEVICE)
    for name in run_names
]

In [None]:
# Inspect the metadata stored with the first loaded model.
compiled_models[0].metadata

In [None]:
# `run_name` is not assigned anywhere at top level (the list comprehension
# that builds `compiled_models` does not leak its loop variable), so pick the
# run explicitly: here, the most recent entry of `run_names`.
run_name = run_names[-1]
compiled_model = load_compiled_model_segmentation(run_name, debug=debug, device=DEVICE)
compiled_model.metadata

## Examples

### Functions

In [None]:
import re

In [None]:
%run ../utils/nlp.py

In [None]:
def calculate_output(model, item):
    """Run `model` on a single item and return its per-pixel argmax map.

    The item's image is wrapped into a batch of one, moved to DEVICE and fed
    to the model with gradients disabled. The batch axis is then removed and,
    for every position, the index of the largest value along the first
    remaining axis is kept (presumably the class axis -- TODO confirm).

    Args:
        model: callable applied to the batched image tensor.
        item: object exposing the image tensor as `item.image`.

    Returns:
        Numpy array of indices, on the CPU.
    """
    batch = item.image.unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        scores = model(batch).squeeze(0)

    # `max` returns (values, indices); only the indices are needed.
    winners = scores.max(dim=0)[1]

    return winners.detach().cpu().numpy()  # height, width

In [None]:
def print_report(dataloader, idx):
    """Print the text of the report attached to sample `idx`.

    The sample is read from `dataloader.dataset`, and its `report` field is
    turned into text by a ReportReader built from the dataset's vocabulary.
    """
    samples = dataloader.dataset
    encoded_report = samples[idx].report

    reader = ReportReader(samples.get_vocab())
    print(reader.idx_to_text(encoded_report))

In [None]:
def shorter_name(run_name):
    """Return the first `dddd_dddddd` stamp found in `run_name`.

    When the name contains no such stamp, it is returned unchanged.
    """
    stamp = re.search(r'\d{4}_\d{6}', run_name)
    return stamp.group(0) if stamp else run_name

In [None]:
def _dataset_display_name(dataset):
    """Return the lowercased class name of `dataset` without a trailing 'dataset'."""
    name = dataset.__class__.__name__.lower()
    suffix = 'dataset'
    # `str.strip('dataset')` would remove any of the characters d/a/t/s/e from
    # both ends (turning 'jsrtdataset' into 'jsr'), so cut the suffix explicitly.
    if name.endswith(suffix) and len(name) > len(suffix):
        name = name[:-len(suffix)]
    return name


def plot_example(compiled_models, dataloader, sample_idx, titlesize=15, figsize=(15, 10)):
    """Plot one sample next to the segmentation predicted by each model.

    The first subplot shows the input image; every following subplot shows
    the output of one model, titled with the shortened run name. The full run
    name of each model is also printed.

    Args:
        compiled_models: a single compiled model or a list/tuple of them.
            Each one must expose `.model` and a dict-like `.metadata`.
        dataloader: dataloader whose `.dataset` is indexed with `sample_idx`.
        sample_idx: index of the sample to plot.
        titlesize: font size used for the subplot titles.
        figsize: (width, height) of the figure, forwarded to `plt.figure`.
    """
    dataset = dataloader.dataset
    item = dataset[sample_idx]

    if not isinstance(compiled_models, (tuple, list)):
        compiled_models = (compiled_models,)

    outputs = [
        (
            # Fall back to a positional label so that a model saved without a
            # run name can still be titled instead of crashing shorter_name.
            compiled_model.metadata.get('run_name') or f'model-{position}',
            calculate_output(compiled_model.model, item),
        )
        for position, compiled_model in enumerate(compiled_models)
    ]

    n_cols = 1 + len(outputs)

    plt.figure(figsize=figsize)
    plt.subplot(1, n_cols, 1)
    title = f'{_dataset_display_name(dataset)}, sample={sample_idx}'
    if dataset.__class__.__name__ == 'JSRTDataset':
        title += f', {dataset.dataset_type}'
    plt.title(title, fontsize=titlesize)
    plt.imshow(item.image[0], cmap='gray')
    plt.axis('off')

    for index, (run_name, output) in enumerate(outputs):
        print(run_name)
        # Subplot 1 is the input image, so model outputs start at position 2.
        plt.subplot(1, n_cols, index + 2)
        plt.title(shorter_name(run_name), fontsize=titlesize)
        plt.imshow(output)
        plt.axis('off')

### JSRT examples

In [None]:
# Compare the loaded models on JSRT sample 0 (plot_example takes the models
# as its first argument).
plot_example(compiled_models, jsrt_dataloader, 0)

In [None]:
# Compare the loaded models on JSRT sample 3.
plot_example(compiled_models, jsrt_dataloader, 3)

In [None]:
# Compare the loaded models on JSRT sample 20.
plot_example(compiled_models, jsrt_dataloader, 20)

In [None]:
# Compare the loaded models on JSRT sample 1, with larger titles.
plot_example(compiled_models, jsrt_dataloader, 1, titlesize=25)

### IU x-ray examples

In [None]:
# Print the report of IU x-ray sample 2, then plot the models' outputs for it.
idx = 2
print_report(iu_dataloader, idx)
plot_example(compiled_models, iu_dataloader, idx, titlesize=20)

In [None]:
# Print the report of IU x-ray sample 300, then plot the models' outputs for it.
idx = 300
print_report(iu_dataloader, idx)
plot_example(compiled_models, iu_dataloader, idx, titlesize=20)

In [None]:
# Same as the cells above, with a negative index.
# NOTE(review): assumes the dataset supports indexing from the end -- confirm.
idx = -50
print_report(iu_dataloader, idx)
plot_example(compiled_models, iu_dataloader, idx, titlesize=20)

In [None]:
# Print the report of IU x-ray sample 740, then plot the models' outputs for
# it (default title size).
idx = 740
print_report(iu_dataloader, idx)
plot_example(compiled_models, iu_dataloader, idx)

### CXR-14 examples

In [None]:
# Compare the loaded models on CXR-14 sample 0 (plot_example takes the models
# as its first argument).
plot_example(compiled_models, cxr14_dataloader, 0, titlesize=25)

In [None]:
# Compare the loaded models on CXR-14 sample 20.
plot_example(compiled_models, cxr14_dataloader, 20, titlesize=25)

In [None]:
# Compare the loaded models on CXR-14 sample 50.
plot_example(compiled_models, cxr14_dataloader, 50)

### Covid-UC examples

In [None]:
# Compare the loaded models on Covid-UC sample 2 (plot_example takes the
# models as its first argument).
plot_example(compiled_models, coviduc_dataloader, 2)

In [None]:
# Compare the loaded models on Covid-UC sample 15.
plot_example(compiled_models, coviduc_dataloader, 15)

In [None]:
# Compare the loaded models on Covid-UC sample 6.
plot_example(compiled_models, coviduc_dataloader, 6)

In [None]:
# Peek at the entry at position 1 of the dataset's metadata table.
# NOTE(review): `_metadata_df` is a private attribute (presumably a pandas
# DataFrame -- confirm), so this cell may break if the dataset class changes.
coviduc_dataloader.dataset._metadata_df.iloc[1]