## Imports

In [None]:
import torch
import os
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms

In [None]:
# Use a white background for the figures drawn in this notebook.
import matplotlib
matplotlib.rcParams['figure.facecolor'] = 'white'

In [None]:
%run ../jsrt.py
%run ../../utils/images.py

## Load sample image

In [None]:
from torchvision import transforms

In [None]:
# Path of one sample image inside the dataset's 'images' folder.
# NOTE(review): DATASET_DIR is not defined in the visible cells; presumably it
# comes from `%run ../jsrt.py` — confirm.
images_dir = os.path.join(DATASET_DIR, 'images')
image_fpath = os.path.join(images_dir, 'JPCLN026.png')

In [None]:
# PIL-image -> tensor converter, applied to the sample image in the next cell.
tf = transforms.ToTensor()

In [None]:
# Open the sample as single-channel grayscale ('L'), convert it to a tensor
# in one step, and display the tensor's size.
image = tf(Image.open(image_fpath).convert('L'))
image.size()

In [None]:
# Mean and standard deviation of the sample tensor's values.
image.mean(), image.std()

In [None]:
# Draw the slice at index 0 of the tensor's first axis with a gray colormap.
plt.imshow(image[0], cmap='gray')

## Split

In [None]:
import os
import random

In [None]:
%run ../jsrt.py
%run ../../utils/__init__.py

In [None]:
# Folder with the dataset images.
images_dir = os.path.join(DATASET_DIR, 'images')

# Read the metadata CSV and take its 'study_id' column as the list of names.
# NOTE(review): `pd` is not imported in any visible cell above this one;
# presumably it enters the namespace through one of the %run scripts — confirm.
metadata_fpath = os.path.join(DATASET_DIR, 'jsrt_metadata.csv')
metadata = pd.read_csv(metadata_fpath)
images = list(metadata['study_id'])
# Total count vs. distinct count: the two differ if any study_id is repeated.
len(images), len(set(images))

In [None]:
# Split the image names into two folds.
# fold1: names whose fifth-from-last character is an odd digit
# (presumably the last digit of the image number when names end in '.png'
# — verify against the metadata).
fold1 = [
    name
    for name in images
    if int(name[-5]) % 2 != 0
]

# fold2: every name that is not in fold1. The filter compares against fold1
# itself rather than `train`, which is only assigned in a later cell and would
# raise NameError on a top-to-bottom run. A set keeps each lookup O(1).
fold1_names = set(fold1)
fold2 = [
    name for name in images
    if name not in fold1_names
]
len(fold1), len(fold2)

In [None]:
# Shuffle fold2 in place before it is cut into validation and test halves.
# NOTE(review): no random seed is set in the visible cells, so re-running this
# cell presumably yields a different val/test split each time — confirm.
random.shuffle(fold2)

In [None]:
# fold1 is used whole as the training set; fold2 is cut at its midpoint
# (rounded down) into validation (first part) and test (the remainder).
n_val = len(fold2) // 2
train, val, test = fold1, fold2[:n_val], fold2[n_val:]
len(train), len(val), len(test)

In [None]:
# Write each split's names to <DATASET_DIR>/splits/<split>.txt.
splits_dir = os.path.join(DATASET_DIR, 'splits')
# Make sure the target folder exists; this is a no-op when it already does.
# (Not visible here whether write_list_to_txt creates missing folders itself.)
os.makedirs(splits_dir, exist_ok=True)

# The loop variable is `split_images`, not `images`, so the notebook-level
# `images` list (read by the fold-building cell) is not overwritten.
for split, split_images in zip(['train', 'val', 'test'], [train, val, test]):
    filepath = os.path.join(splits_dir, f'{split}.txt')
    write_list_to_txt(split_images, filepath)

## Test Dataset class

In [None]:
%run ../jsrt.py
%run ../../utils/__init__.py

In [None]:
# Instantiate the dataset for the 'train' split with image_size=(1024, 1024)
# and display how many items it holds.
dataset = JSRTDataset('train', image_size=(1024, 1024))
len(dataset)

In [None]:
# Fetch the first item; the following cells inspect its fields.
item = dataset[0]

In [None]:
# Size of the item's image, as reported by its .size() method.
item.image.size()

In [None]:
# Size of the item's masks, as reported by its .size() method.
item.masks.size()

In [None]:
# One row, two panels: the image on the left (index 0 of its first axis,
# passed through tensor_to_range01, gray colormap) and the masks on the right.
# Axes are hidden on both panels.
fig, (image_ax, masks_ax) = plt.subplots(1, 2, figsize=(15, 5))

image_ax.imshow(tensor_to_range01(item.image[0]), cmap='gray')
image_ax.axis('off')

masks_ax.imshow(item.masks)
masks_ax.axis('off')

In [None]:
# Display the raw mask values of the item.
item.masks

In [None]:
# Check that the first entry of the dataset's transform pipeline is a Resize.
isinstance(dataset.transform.transforms[0], transforms.Resize)

In [None]:
# Check that the second entry of the dataset's transform pipeline is a ToTensor.
isinstance(dataset.transform.transforms[1], transforms.ToTensor)

In [None]:
# Display the dataset's full list of transforms.
dataset.transform.transforms

## Debug augmentation

### Check pil-tensor bug

In [None]:
%run ../tools/augmentation.py

In [None]:
# Fresh 'train' dataset with image_size=(1024, 1024) for the augmentation
# check; display its length.
dataset = JSRTDataset('train', image_size=(1024, 1024))
len(dataset)

In [None]:
# Wrap the dataset in an Augmentator (dont_shuffle=True) and display its length.
# NOTE(review): Augmentator presumably comes from `%run ../tools/augmentation.py` — confirm.
aug = Augmentator(dataset, dont_shuffle=True)
len(aug)

In [None]:
# Plot augmented samples, passing 0 as the second argument and masks=False.
# NOTE(review): the meaning of the second argument is not visible here — check plot_augmented_samples.
plot_augmented_samples(aug, 0, masks=False)

### Debug augmenting seg masks

In [None]:
%run ../common.py
%run ../../utils/__init__.py

In [None]:
%run ../tools/augmentation.py

In [None]:
# Fresh 'train' dataset with image_size=(1024, 1024) for the seg-mask
# augmentation check; display its length.
dataset = JSRTDataset('train', image_size=(1024, 1024))
len(dataset)

In [None]:
# Same wrapper as in the previous check, this time with seg_mask=True;
# display its length.
aug = Augmentator(dataset, dont_shuffle=True, seg_mask=True)
len(aug)

In [None]:
# Plot augmented samples, passing 10 as the second argument.
# NOTE(review): the meaning of the second argument is not visible here — check plot_augmented_samples.
plot_augmented_samples(aug, 10)

In [None]:
# Pull the masks of augmented items 10 and 11, then print each one's
# min, max, tensor type and size on its own line.
masks1, masks2 = aug[10].masks, aug[11].masks
for sample_masks in (masks1, masks2):
    print(
        sample_masks.min(),
        sample_masks.max(),
        sample_masks.type(),
        sample_masks.size(),
    )

## Compute mean and std

In [None]:
import os
import pandas as pd

In [None]:
%run ../jsrt.py
%run ../../utils/images.py

In [None]:
# Folder holding the images the statistics are computed over.
images_dir = os.path.join(DATASET_DIR, 'images')

# Disabled alternative: take the image names from the metadata CSV.
# The active code in the next cell reads them from the train split file instead.
# metadata_fpath = os.path.join(DATASET_DIR, 'jsrt_metadata.csv')
# metadata = pd.read_csv(metadata_fpath)

# image_names = metadata['study_id']

In [None]:
# Load the training image names: one name per line of splits/train.txt,
# with surrounding whitespace removed. Display how many were read.
fpath = os.path.join(DATASET_DIR, 'splits', 'train.txt')
with open(fpath) as f:
    image_names = list(map(str.strip, f))
len(image_names)

In [None]:
# Iterate over the training images (image_format='L') and compute their mean
# and std with n_channels=1; both values are displayed at the end.
# NOTE(review): ImageFolderIterator and compute_mean_std presumably come from
# `%run ../../utils/images.py` — confirm.
iterator = ImageFolderIterator(images_dir, image_names, image_format='L')
mean, std = compute_mean_std(iterator, show=True, n_channels=1)
mean, std

## Check image sizes

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Collect the distinct sizes of the dataset's images.
# Image.open is lazy: it reads the file header, so `.size` is available
# without decoding the pixel data or converting to RGB for every file.
sizes = set()

for image_name in dataset.images_names:
    image_fpath = os.path.join(dataset.images_dir, image_name)

    # The context manager closes the file as soon as its size is recorded.
    with Image.open(image_fpath) as image_file:
        sizes.add(image_file.size)

# Leave `image` bound to the last file, decoded as RGB, because the preview
# cell that follows displays it. Skipped when the dataset has no images.
if sizes:
    image = Image.open(image_fpath).convert('RGB')

sizes

In [None]:
# Preview whatever `image` currently holds; after the size-check cell above,
# that is the last image it opened.
plt.imshow(image)

All images are 2048 x 2048 pixels.

## Compute class balance

In [None]:
from collections import Counter
from tqdm.notebook import tqdm

In [None]:
# Take every entry of the images folder as an image name and display the count.
# NOTE(review): os.listdir returns all directory entries, so this assumes the
# folder contains only image files — confirm.
images_dir = os.path.join(DATASET_DIR, 'images')

image_names = os.listdir(images_dir)
len(image_names)

In [None]:
# Dataset built with the 'all' split and default arguments; display its length.
dataset = JSRTDataset('all')
len(dataset)

In [None]:
# Number of segmentation labels declared by the dataset.
n_labels = len(dataset.seg_labels)

In [None]:
# Tally, over all images, how many mask elements carry each value.
counts = Counter()

for image_name in tqdm(image_names):
    masks = dataset.get_masks(image_name)

    # Count the distinct values in one vectorized pass instead of visiting
    # every element from Python with a separate `.item()` call.
    values, frequencies = torch.unique(masks, return_counts=True)
    counts += Counter(dict(zip(values.tolist(), frequencies.tolist())))

# `masks` stays bound to the last image's masks; the percentage cell below
# reads its size.
counts

In [None]:
# Total number of mask elements across all images.
# This uses the size of the last `masks` processed, so it assumes every
# image's masks are 2-D and share that same size — TODO confirm.
height, width = masks.size()
total = height * width * len(image_names)

# Print the share of each label index as a percentage of all elements.
for i_label in range(n_labels):
    amount = counts[i_label] / total * 100
    print(f'Label: {i_label}, {amount:.1f}%')