# Imports

In [None]:
# Limit CUDA to device index 1; this is set before `torch` is imported in the next cell.
%env CUDA_VISIBLE_DEVICES=1

In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
import matplotlib
# Render figures with a white background.
matplotlib.rcParams['figure.facecolor'] = 'white'

In [None]:
import pandas as pd
# Do not truncate the number of columns when displaying DataFrames.
pd.options.display.max_columns = None

# Load model

In [None]:
# Execute the checkpoint package's __init__.py inside this notebook's namespace.
# Presumably this is what defines load_compiled_model_classification used below — TODO confirm.
%run ../models/checkpoint/__init__.py

In [None]:
# Name of the training run whose checkpoint is loaded.
run_name = '1215_174443_cxr14_resnet-50-v2_lr0.0001_os_Cardiomegaly_normS_size256_sch-roc_auc-p5-f0.1'
# Forwarded as `debug=` to the loader in the next cell.
debug_run = False

In [None]:
compiled_model = load_compiled_model_classification(run_name, debug=debug_run)
# Display the model kwargs stored in the loaded checkpoint's metadata.
compiled_model.metadata['model_kwargs']

# Load data

In [None]:
%run ../datasets/__init__.py

In [None]:
dataset_kwargs = {
    'dataset_name': 'cxr14',
    'dataset_type': 'all',
    'max_samples': None,
}
dataloader = prepare_data_classification(**dataset_kwargs)
dataset = dataloader.dataset
len(dataset)

# Apply CLAHE normalization to images

In [None]:
import os
import numpy as np
from skimage.exposure import equalize_adapthist
from PIL import Image

## Apply to one sample

In [None]:
# Entry of dataset.label_index['FileName'] used for the single-image experiment.
idx = 30

In [None]:
# File name of the chosen sample and its full path inside the dataset image folder.
image_fname = dataset.label_index['FileName'][idx]
fpath = os.path.join(dataset.image_dir, image_fname)

In [None]:
%%time

# Load the image, apply CLAHE, then rescale the result to 8-bit.
image = Image.open(fpath)
image_np = np.array(image)
# equalize_adapthist returns a float image; per the scikit-image docs its values lie in [0, 1].
image_2 = equalize_adapthist(image_np)
# Scale to the 0-255 range and cast (truncating) to uint8.
image_3 = (image_2 * 255).astype(np.uint8)

# Experiment: converting the input to RGB first gave the same result.
# image = Image.open(fpath).convert('RGB')
# image_4 = equalize_adapthist(np.array(image))
# image_4 = (image_4 * 255).astype(np.uint8)

# Compare the array shapes before and after each step.
image_np.shape, image_2.shape, image_3.shape # , image_4.shape

In [None]:
# dtype of the raw CLAHE output.
image_2.dtype

In [None]:
# Value range of the raw CLAHE output.
image_2.min(), image_2.max()

In [None]:
# Side-by-side comparison of the original image, the raw CLAHE output,
# and the CLAHE output rescaled to uint8.
panels = [
    ('Original image', image_np),
    ('CLAHE (float64)', image_2),
    ('CLAHE (uint8)', image_3),
]

plt.figure(figsize=(15, 5))
for position, (panel_title, panel_image) in enumerate(panels, start=1):
    plt.subplot(1, len(panels), position)
    plt.title(panel_title)
    plt.imshow(panel_image, cmap='gray')
    plt.axis('off')

In [None]:
# Value ranges of the original image, the raw CLAHE output and its uint8 version.
image_np.min(), image_np.max(), image_2.min(), image_2.max(), image_3.min(), image_3.max()

In [None]:
# dtypes of the same three arrays.
image_np.dtype, image_2.dtype, image_3.dtype

## Apply to the whole dataset

In [None]:
from tqdm.auto import tqdm
from collections import defaultdict

In [None]:
# Materialize every file name listed in the label index; the loop in the next cell iterates over them.
image_fnames = list(dataset.label_index['FileName'])
len(image_fnames)

In [None]:
%%time

clahe_folder = dataset.image_dir + '-clahe'
wrong_images = defaultdict(list)

for image_fname in tqdm(image_fnames):
    fpath = os.path.join(dataset.image_dir, image_fname)

    image = Image.open(fpath).convert('L')
    image = np.array(image)
    if image.ndim != 2:
        wrong_images['n-dim-not-2'].append(image_fname)
        continue
    
    image = equalize_adapthist(image)
    image = (image * 255).astype(np.uint8)
    
    new_fpath = os.path.join(clahe_folder, image_fname)
    
    if os.path.isfile(new_fpath):
        raise Exception(f'Overriding previous file at {new_fpath}')

    image = Image.fromarray(image, mode='L')
    image.save(new_fpath)

In [None]:
# NOTE(review): when cells are run top to bottom, `image` is a PIL Image at this point
# (last assigned by Image.fromarray in the loop above), and `.dtype` is a NumPy array
# attribute. This cell looks like a leftover from a loop version that ended with an
# array — confirm intent.
image.dtype, image_2.dtype

In [None]:
# NOTE(review): same caveat as the previous cell — `.nbytes` is a NumPy array attribute.
image.nbytes == image_2.nbytes, image.nbytes, image_2.nbytes

In [None]:
# Re-read the source image of the last file the loop visited (`fpath`) as a NumPy array.
image = Image.open(fpath)
image = np.array(image)
image.dtype

In [None]:
# Size in bytes of the re-read source array.
image.nbytes

In [None]:
# Target path of the last file the loop visited.
new_fpath

In [None]:
# NOTE(review): `image` is a NumPy array after the re-read above, while `.mode` is a
# PIL Image attribute — in linear order this does not apply; confirm intent.
image.mode

In [None]:
plt.imshow(image)