# Imports

In [None]:
import os
import torch
import matplotlib.pyplot as plt

In [None]:
import matplotlib

# Use a white figure background for every plot drawn in this notebook.
matplotlib.rcParams.update({'figure.facecolor': 'white'})

In [None]:
# Expose only physical GPU 1 to CUDA in this kernel.
# NOTE: CUDA only honours this if it is set before CUDA is first initialized.
%env CUDA_VISIBLE_DEVICES=1

In [None]:
# Sanity check: print the value seen by the shell.
!echo $CUDA_VISIBLE_DEVICES

In [None]:
# Default CUDA device; with the variable above this is the single visible GPU.
DEVICE = torch.device('cuda')

In [None]:
%run ../utils/__init__.py
%run ../datasets/__init__.py
%run ../models/checkpoint/__init__.py
%run ../utils/nlp.py

# Load data

In [None]:
# Options forwarded to prepare_data_classification.
kwargs = dict(
    dataset_name='cxr14',  # other dataset used with this notebook: iu-x-ray
    dataset_type='all',
    batch_size=25,
    image_format='L',
    frontal_only=True,
    norm_by_sample=True,
    image_size=(1024, 1024),
)

# NOTE(review): the name says "iu" but the loader currently reads CXR14;
# the name is kept because the cells below refer to it.
iu_dataloader = prepare_data_classification(**kwargs)
len(iu_dataloader.dataset)

# Load model

In [None]:
debug = False  # forwarded to load_compiled_model_segmentation
# Identifier of the trained segmentation run whose checkpoint is loaded below.
run_name = '1202_015907_jsrt_scan_lr0.0005_normS_size1024_wce1-6-3-3_aug10_sch-iou-p5-f0.5'

In [None]:
# Load the trained segmentation model for `run_name` and show its metadata.
compiled_model = load_compiled_model_segmentation(
    run_name,
    device=DEVICE,
    debug=debug,
)
compiled_model.metadata

# Segment images

Run the segmentation model on every image and save the predicted masks as PNG files.

In [None]:
from PIL import Image
from tqdm.notebook import tqdm

In [None]:
# from medai.datasets.iu_xray import DATASET_DIR
from medai.datasets.cxr14 import DATASET_DIR
# DATASET_DIR

In [None]:
def calculate_output(batch, model=None, device=None):
    """Predict a per-pixel label map for a batch of images.

    Args:
        batch: object with an ``image`` tensor attribute (as yielded by the
            dataloader in this notebook).
        model: module to run; defaults to ``compiled_model.model``.
        device: device to run on; defaults to ``DEVICE``.

    Returns:
        CPU ``int64`` tensor of shape ``(batch_size, height, width)`` holding,
        for each pixel, the index of the highest-scoring label.
    """
    if model is None:
        model = compiled_model.model
    if device is None:
        device = DEVICE

    images = batch.image.to(device)

    with torch.no_grad():
        # no_grad already prevents graph construction, so no .detach() is needed
        logits = model(images)
        # shape: batch_size, n_labels, height, width

        # Reduce over labels on the device so only the label map (not all
        # n_labels score channels) is copied to host memory.
        labels = logits.max(dim=1).indices
        # shape: batch_size, height, width

    return labels.cpu()

In [None]:
# Predicted masks go into a `masks` sub-folder of the dataset directory.
masks_folder = os.path.join(DATASET_DIR, 'masks')
os.makedirs(name=masks_folder, exist_ok=True)  # no error if it already exists

In [None]:
def assertions(mask):
    """Sanity-check a predicted label mask before it is saved.

    Args:
        mask: label map supporting ``.min()`` / ``.max()`` (in this notebook,
            the ``uint8`` numpy array built in the save loop).

    Raises:
        AssertionError: if the smallest label is not 0 or the largest is not 3.
            NOTE(review): this also rejects any mask in which label 3 never
            appears.
    """
    min_value = mask.min()
    assert min_value == 0, f'Minimum must be zero, got {min_value}'

    max_value = mask.max()
    assert max_value == 3, f'Maximum must be three, got {max_value}'

In [None]:
def get_filepath(image_name):
    """Return the path inside ``masks_folder`` where the mask of ``image_name`` is saved.

    A ``.png`` suffix is appended unless the name already ends with one.
    NOTE(review): sub-directories contained in ``image_name`` are not created here.
    """
    filename = image_name if image_name.endswith('.png') else image_name + '.png'
    return os.path.join(masks_folder, filename)

In [None]:
# Segment every image in the dataset and write each predicted label map as a PNG.
# The progress bar is a context manager so it is closed even if an assertion
# fails partway through the dataset.
with tqdm(total=len(iu_dataloader.dataset)) as state:
    for batch in iu_dataloader:
        # calculate_output already returns a CPU tensor
        outputs = calculate_output(batch)

        for image_name, mask in zip(batch.image_fname, outputs):
            # Class indices (checked to be 0..3 below) fit in an 8-bit grayscale image
            mask = mask.to(torch.uint8).numpy()
            # shape: height, width

            assertions(mask)

            mask = Image.fromarray(mask, mode='L')
            mask.save(get_filepath(image_name))

            state.update(1)

In [None]:
# Preview the last mask saved by the loop above; `mask` is still bound to that PIL image.
plt.imshow(mask)

# Debug load dataset

In [None]:
%run ../utils/common.py

In [None]:
%run ../datasets/iu_xray.py

# Debug: build the IU X-ray dataset with masks enabled and report its size.
dataset = IUXRayDataset(
    'all',
    masks=True,
    frontal_only=True,
    image_size=(1024, 1024),
)
len(dataset)

In [None]:
%run ../datasets/cxr14.py

# Debug: build the CXR14 dataset with masks enabled and report its size.
dataset = CXR14Dataset(
    'all',
    masks=True,
    frontal_only=True,
    image_size=(1024, 1024),
)
len(dataset)

In [None]:
def squeeze_masks(masks):
    """Collapse a stack of per-organ masks into a single label map.

    Args:
        masks: either a 2-D ``(height, width)`` tensor, returned unchanged, or
            a stack whose first dimension indexes organs, presumably
            ``(n_organs, height, width)`` with non-overlapping binary masks.

    Returns:
        A tensor where each pixel is the sum of ``i * masks[i]`` over organ
        channels ``i``, i.e. the organ index for non-overlapping masks
        (channel 0 always contributes 0).
    """
    if masks.dim() == 2:
        return masks

    n_organs = masks.size(0)
    # Weight channel i by i (broadcast over height/width) so the sum yields labels
    multiplier = torch.arange(n_organs, device=masks.device).view(-1, 1, 1)
    return (multiplier * masks).sum(dim=0)

In [None]:
# Pick one sample and show the shapes of its image and mask tensors.
item = dataset[102]
item.image.shape, item.masks.shape

In [None]:
# Side-by-side view of the sample: input image (left) and its label mask (right).
plt.figure(figsize=(15, 5))

# Left panel: image passed through tensor_to_range01, then permuted (1, 2, 0) so the
# first axis becomes last; assumes item.image is (channels, height, width) — TODO confirm.
# NOTE(review): imshow may reject an (H, W, 1) array if the image is single-channel.
plt.subplot(1, 2, 1)
plt.imshow(tensor_to_range01(item.image).permute(1, 2, 0))
plt.axis('off')

# Right panel: label map produced by squeeze_masks from the per-organ masks
plt.subplot(1, 2, 2)
plt.imshow(squeeze_masks(item.masks))
plt.axis('off')