# Imports

In [None]:
# Restrict this notebook to GPU 2; set before torch is imported in the next cell.
%env CUDA_VISIBLE_DEVICES=2

In [None]:
import os
import torch
import matplotlib.pyplot as plt

In [None]:
import matplotlib
# Use a white background for every matplotlib figure in this notebook.
matplotlib.rcParams['figure.facecolor'] = 'white'

In [None]:
# Default CUDA device; with CUDA_VISIBLE_DEVICES=2 this is the only visible GPU.
DEVICE = torch.device('cuda')

In [None]:
%run ../utils/__init__.py
%run ../models/checkpoint/__init__.py
%run ../utils/nlp.py

In [None]:
# Configure logging at INFO level. `config_logging` and `logging` are not
# imported in this notebook; presumably they come from the %run utils above — TODO confirm.
config_logging(logging.INFO)

# Load data

In [None]:
%run ../datasets/__init__.py

In [None]:
# Options for the dataset to segment (CXR14, 'all' split, 1024x1024 images,
# batches of 45). The exact meaning of each key is defined by
# prepare_data_classification, which is not visible here.
kwargs = {
    'dataset_name': 'cxr14',
    'dataset_type': 'all',
    'batch_size': 45,
    'image_format': 'L',
    'frontal_only': True,
    'norm_by_sample': True,
    'image_size': (1024, 1024),
}

dataloader = prepare_data_classification(**kwargs)
# Displayed: number of images that will be segmented.
len(dataloader.dataset)

# Load model

In [None]:
# Segmentation checkpoint to load; the comments below record which run was
# used for each mask version (compare with VERSION further down).
# v0 and v1:
# run_name = '1202_015907_jsrt_scan_lr0.0005_normS_size1024_wce1-6-3-3_aug10_sch-iou-p5-f0.5'

# v2:
run_name = '0412_080944_jsrt_scan_lr0.0005_normS_size1024_wce1-6-3-3_aug5-double_sch-iou-p5-f0.5'
# Forwarded as `debug=` to load_compiled_model_segmentation.
debug = False

In [None]:
# Load the segmentation run onto DEVICE and display its stored metadata.
compiled_model = load_compiled_model_segmentation(run_name, debug=debug, device=DEVICE)
compiled_model.metadata

# Segment images

Segment each image and save the resulting masks to disk.

## Functions

In [None]:
from PIL import Image
from tqdm.notebook import tqdm

In [None]:
%run ../utils/shapes.py

In [None]:
def calculate_output(batch):
    """Segment a batch of images and return the predicted class per pixel.

    Args:
        batch: object with an ``image`` tensor attribute; it is moved to the
            notebook-global ``DEVICE`` before running ``compiled_model.model``.

    Returns:
        CPU tensor with the index of the highest-scoring output channel
        (reduction over dim 1), shape (batch_size, height, width).
    """
    device_images = batch.image.to(DEVICE)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = compiled_model.model(device_images).detach().cpu()
    # logits shape: batch_size, n_labels, height, width

    # Keep only the indices of the per-pixel maximum (drop the max values).
    return logits.max(dim=1).indices

In [None]:
def assertions(mask, image_fname, expected_min=0, expected_max=3):
    """Print a warning when a mask's value range differs from the expected one.

    This is a soft check: it only prints and never raises, so a long
    segmentation run keeps going while the message names the offending image.

    Args:
        mask: array-like exposing ``.min()`` and ``.max()`` (e.g. a numpy array).
        image_fname: name used to prefix the printed messages.
        expected_min: value the mask minimum should equal (default 0).
        expected_max: value the mask maximum should equal (default 3).
    """
    min_value = mask.min()
    if min_value != expected_min:
        print(f'[{image_fname}] Minimum must be {expected_min}, got {min_value}')

    max_value = mask.max()
    if max_value != expected_max:
        # Report the offending maximum, not the minimum.
        print(f'[{image_fname}] Maximum must be {expected_max}, got {max_value}')

In [None]:
def get_filepath(masks_folder, image_name):
    """Return the path of the PNG mask file for ``image_name``.

    ``image_name`` is joined onto ``masks_folder``; the ``.png`` extension is
    appended only when the joined path does not already end with it.
    """
    candidate = os.path.join(masks_folder, image_name)
    return candidate if candidate.endswith('.png') else f'{candidate}.png'

In [None]:
def clean_image_fname(image_name):
    """Flatten an image file name/path into an extension-less mask name.

    Every occurrence of ``.png`` is removed first, then every occurrence of
    ``.jpg``; finally each ``/`` becomes ``-`` so nested paths map to a
    single flat file name.
    """
    for extension in ('.png', '.jpg'):
        image_name = image_name.replace(extension, '')
    return image_name.replace('/', '-')

## Run segmentation

### Define folders

In [None]:
# Mask version subfolder written by this notebook; keep in sync with run_name.
VERSION = 'v2'
# Passed to os.makedirs(exist_ok=...): when False, creating an existing folder fails.
EXIST_OK = True

In [None]:
# VinBigDataset stores these masks under 'organ-masks'; every other dataset uses 'masks'.
if dataloader.dataset.__class__.__name__ == 'VinBigDataset':
    FOLDER_NAME = 'organ-masks'
else:
    FOLDER_NAME = 'masks'
FOLDER_NAME

In [None]:
# Output folder: <dataset_dir>/<FOLDER_NAME>/<VERSION>, created if missing.
masks_folder = os.path.join(dataloader.dataset.dataset_dir, FOLDER_NAME, VERSION)
os.makedirs(masks_folder, exist_ok=EXIST_OK)

### Remove already calculated

Remove images whose masks were already calculated from the dataset, to avoid loading unnecessary images from disk.

In [None]:
# Cleaned names (no extension, '/' -> '-') of masks already saved in
# masks_folder; the segmentation loop skips these images.
already_calculated = set(
    clean_image_fname(image_name)
    for image_name in os.listdir(masks_folder)
)
len(already_calculated)

In [None]:
# For CheXpert, drop rows whose mask already exists, so those images are never
# loaded from disk. Compare in the cleaned-name space used to name the mask
# files (clean_image_fname) instead of inverting the cleaning: mapping every
# '-' back to '/' is lossy whenever the original path itself contains '-'.
if dataloader.dataset.__class__.__name__ in ('ChexpertDataset',):
    d = dataloader.dataset.label_index
    already_done = d['Path'].map(clean_image_fname).isin(already_calculated)
    dataloader.dataset.label_index = d.loc[~already_done].reset_index(drop=True)

# At cell level so the notebook actually displays the remaining dataset size.
len(dataloader.dataset)

### Segment!

In [None]:
%%capture output
%%time

# Segment every image in `dataloader`, keep only the largest shapes returned by
# get_largest_shapes, and save each result as a PNG in `masks_folder`.
# Images whose cleaned name is in `already_calculated` are skipped.
# NOTE(review): the tqdm bar `state` is never closed; consider state.close()
# after the loop.
state = tqdm(total=len(dataloader.dataset))
# Cleaned names of images for which at least one largest polygon has no coordinates.
errors = []

for batch in dataloader:
    # Skip the forward pass when every image in the batch already has a mask.
    if all(
        clean_image_fname(image_name) in already_calculated
        for image_name in batch.image_fname
    ):
        state.update(len(batch.image_fname))
        continue
    
    # calculate_output already returns a CPU tensor, so this .cpu() is a no-op.
    outputs = calculate_output(batch).cpu()
    
    for image_name, mask in zip(batch.image_fname, outputs):
        image_name = clean_image_fname(image_name)
        if image_name in already_calculated:
            state.update(1)
            continue
            # raise Exception('Overriding previous mask!')

        mask = mask.to(torch.uint8).numpy()
        # shape: height, width
        
        # Print a warning if the raw argmax mask does not have min 0 and max 3.
        assertions(mask, image_name)
        
        # Keep only largest shape:
        polygons = calculate_polygons(mask)
        largest_polygons = get_largest_shapes(polygons, name=image_name)
        # NOTE(review): assumes polygons_to_array returns a uint8 (H, W) array,
        # since Image.fromarray(..., mode='L') below expects 8-bit data -- TODO confirm.
        mask = polygons_to_array(largest_polygons, mask.shape)

        # Same range check on the mask rebuilt from the kept polygons.
        assertions(mask, image_name)
        
        # Flag images where some (coords, organ_idx) entry has no coordinates.
        if any(len(coords) == 0 for coords, organ_idx in largest_polygons):
            errors.append(image_name)

        # Save to file
        mask = Image.fromarray(mask, mode='L')
        out_fpath = get_filepath(masks_folder, image_name)
        mask.save(out_fpath)

        state.update(1)

In [None]:
# Print what the segmentation cell wrote to stdout (captured by `%%capture output`).
print(output)

In [None]:
# Quick visual check of the last item handled by the segmentation loop.
# NOTE(review): `mask` is a PIL Image if that item was saved, but the raw
# tensor if it was skipped as already calculated; if no batch reached the
# inner loop, `image_name`/`mask` are undefined here.
plt.title(image_name)
plt.imshow(mask)

### Load and plot images by name

Inspect the images that produced errors during segmentation.

In [None]:
%run ../utils/images.py

In [None]:
# Reload the same dataset with masks enabled, so items carry `.masks`.
# Use VERSION (not a hard-coded string) so the masks read back are exactly the
# version this notebook just wrote.
new_kwargs = {
    **kwargs,
    'masks': True,
    'masks_version': VERSION,
}
new_dataloader = prepare_data_classification(**new_kwargs)
len(new_dataloader.dataset)

In [None]:
# errors = [
#     '00000003_001.png', '00000013_005.png', '00000010_000.png', '00000007_000.png',
#     '00000005_003.png', '00000001_001.png',
# ]

In [None]:
# Map the cleaned names collected in `errors` back to the dataset's own key
# values (the column used to look rows up below).
# Matching is done in the cleaned-name space instead of inverting
# clean_image_fname. That inversion is lossy: the cleaning strips '.png'/'.jpg'
# and turns '/' into '-', so a FileName with an extension never equals its
# cleaned form, and a Path containing '-' does not round-trip.
# TODO: move this to dataset.get_item_by_name() ???
if dataloader.dataset.__class__.__name__ in ('ChexpertDataset',):
    KEY = 'Path'
else:
    KEY = 'FileName'
error_names = set(errors)
image_names = [
    name
    for name in new_dataloader.dataset.label_index[KEY]
    if clean_image_fname(name) in error_names
]
len(image_names)

In [None]:
# Row indexes (in new_dataloader's label_index) of the images in image_names.
df = new_dataloader.dataset.label_index
rows = df.loc[df[KEY].isin(image_names)]
indexes = list(rows.index)
indexes

In [None]:
# One row per error image: the image on the left, its reloaded mask on the right.
n_rows = len(indexes)
n_cols = 2

plt.figure(figsize=(n_cols*3, n_rows*3))

for counter, index in enumerate(indexes):
    item = new_dataloader.dataset[index]
    
    # Number of subplots before this row; subplot numbers are 1-based (+1, +2 below).
    row_index = counter * n_cols
    
    plt.subplot(n_rows, n_cols, row_index + 1)
    plt.title(item.image_fname)
    # First channel of the image tensor, rescaled to [0, 1] for display.
    plt.imshow(tensor_to_range01(item.image[0]), cmap='gray')
    
    # squeeze_masks may yield None; only draw the mask panel when it does not.
    mask = squeeze_masks(item.masks)
    if mask is not None:
        plt.subplot(n_rows, n_cols, row_index + 2)
        plt.imshow(mask)

#     plt.subplot(1, 2, 2)
#     plt.gca().invert_yaxis()
#     for coords, value in largest_polygons2:
#         if len(coords) == 0:
#             continue
#         x_values, y_values = zip(*coords)
#         plt.plot(x_values, y_values)

# Debug dataset loading

## CXR14, IU, and similar datasets

In [None]:
%run ../utils/common.py
%run ../utils/images.py

In [None]:
%run ../datasets/iu_xray.py

# IU X-Ray, all frontal images at 1024x1024, loaded with masks for inspection.
dataset = IUXRayDataset('all', image_size=(1024, 1024), frontal_only=True, masks=True)
len(dataset)

In [None]:
%run ../datasets/cxr14.py

# CXR14 with masks, read from the 'v1' mask version.
# NOTE(review): this notebook writes VERSION ('v2'); confirm loading 'v1' here is intended.
dataset = CXR14Dataset('all', image_size=(1024, 1024), frontal_only=True,
                       masks=True, masks_version='v1')
len(dataset)

In [None]:
# Inspect one sample: image size, masks size and file name.
item = dataset[200]
item.image.size(), item.masks.size(), item.image_fname

In [None]:
# Show the image (channels moved last for imshow) next to its masks,
# collapsed by squeeze_masks into a single image.
plt.figure(figsize=(15, 5))

plt.subplot(1, 2, 1)
plt.imshow(tensor_to_range01(item.image).permute(1, 2, 0))
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(squeeze_masks(item.masks))
plt.axis('off')

## VinBig-like datasets

In [None]:
import math

In [None]:
%run ../datasets/vinbig.py

# VinBig train split with masks; `fallback_organs` is a VinBigDataset option
# whose exact effect is defined in ../datasets/vinbig.py (not visible here).
dataset = VinBigDataset('train', image_size=(1024, 1024), frontal_only=True,
                        masks=True, fallback_organs=True)
len(dataset)

In [None]:
# Inspect one sample: masks size and labels shape.
item = dataset[14]
item.masks.size(), item.labels.shape

In [None]:
# Grid with the image first, then one panel per label showing item.masks[i]
# and that label's ground-truth value.
n_plots = len(dataset.labels) + 1

n_rows = 3
n_cols = math.ceil(n_plots / n_rows)

plt.figure(figsize=(n_cols*5, n_rows*5))

plt.subplot(n_rows, n_cols, 1)
plt.imshow(tensor_to_range01(item.image).permute(1, 2, 0))
plt.title('CXR')
plt.axis('off')

# Subplot numbers are 1-based and slot 1 holds the image, hence disease_i + 2.
for disease_i, (disease_name, gt_value) in enumerate(zip(dataset.labels, item.labels)):
    plt.subplot(n_rows, n_cols, disease_i+2)
    plt.imshow(item.masks[disease_i])
    plt.title(f'{disease_name} (gt={gt_value})')
    plt.axis('off')