In [None]:
from cnn.input import (
    Dataloader,
    Dataset,
    get_list_of_patients,
    get_split_deterministic,
    get_training_augmentation,
    get_validation_augmentation,
)

import matplotlib.pyplot as plt
import random

In [None]:
# Root directory of the preprocessed Task09 spleen data, relative to the CWD.
data_path = (
    "spleen_dataset/data/"
    "Task09_Spleen_preprocessed/"
)

In [None]:
# Collect the patient identifiers available under `data_path`.
patients = get_list_of_patients(data_path)
# Fail fast with a clear message if the path is wrong or empty. Otherwise the
# problem would likely surface later as a less obvious error (e.g. in the split).
assert len(patients) > 0, f"no patients found under {data_path!r}"
print(patients)

In [None]:
# Partition the patients into train/val using fold 0 of a 5-way split. The fixed
# random_state presumably makes the partition reproducible across runs.
split = get_split_deterministic(patients, fold=0, num_splits=5, random_state=12345)
train_patients, val_patients = split
for subset in (train_patients, val_patients):
    print(subset)

In [None]:
# Patch shape (presumably H x W x channels -- TODO confirm) and samples per batch.
patch_size, batch_size = (128, 128, 1), 32

In [None]:
# Training split: build the per-patient dataset, then wrap it in a Dataloader
# that applies the training augmentation for `patch_size`.
train_dataset = Dataset(
    data_path=data_path,
    patients=train_patients,
    only_non_empty_slices=True,
)
train_augmentation = get_training_augmentation(patch_size)
train_dataloader = Dataloader(
    dataset=train_dataset,
    augmentation=train_augmentation,
    batch_size=batch_size,
    skip_slices=0,  # NOTE(review): presumably "skip no slices" -- confirm
)
# Presumably the number of batches, since indexing the loader yields a batch.
print(len(train_dataloader))

In [None]:
# Validation split: build the per-patient dataset, then wrap it in a Dataloader
# that applies the validation augmentation for `patch_size`.
val_dataset = Dataset(
    data_path=data_path,
    patients=val_patients,
    only_non_empty_slices=True,
)
val_augmentation = get_validation_augmentation(patch_size)
val_dataloader = Dataloader(
    dataset=val_dataset,
    augmentation=val_augmentation,
    batch_size=batch_size,
    skip_slices=0,  # NOTE(review): presumably "skip no slices" -- confirm
)
# Presumably the number of batches, since indexing the loader yields a batch.
print(len(val_dataloader))

In [None]:
# Show one random (image, label) pair read straight from the Dataset. The
# augmentation is handed to the Dataloader, not the Dataset, so this is
# presumably the un-augmented sample.
# `sample_id` (not `id`) avoids shadowing the `id` builtin.
sample_id = random.randint(0, len(train_dataset) - 1)
image, label = train_dataset[sample_id]

print(image.shape)
print(label.shape)

# One row with two panels instead of a 2x2 grid with two empty quadrants.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7.5))
ax1.imshow(image)
ax1.set_title(f"image (train sample {sample_id})")
ax2.imshow(label)
ax2.set_title("label")
plt.show()  # render the figure without echoing the AxesImage repr

In [None]:
# Show one sample from a random training batch. The Dataloader was given the
# training augmentation, so this is presumably the augmented view.
# `batch_id` (not `id`) avoids shadowing the `id` builtin.
batch_id = random.randint(0, len(train_dataloader) - 1)
images, labels = train_dataloader[batch_id]

# Pick a random sample rather than always the first, so repeated runs
# inspect the whole batch.
sample_id = random.randint(0, len(images) - 1)
image = images[sample_id]
label = labels[sample_id]

# One row with two panels instead of a 2x2 grid with two empty quadrants.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7.5))
ax1.imshow(image)
ax1.set_title(f"image (train batch {batch_id}, sample {sample_id})")
ax2.imshow(label)
ax2.set_title("label")
plt.show()  # render the figure without echoing the AxesImage repr

In [None]:
# Root directory of the preprocessed Task05 prostate data, relative to the CWD.
# Rebinding `data_path` points every following cell at the prostate data.
data_path = (
    "prostate_dataset/data/"
    "Task05_Prostate_preprocessed/"
)

In [None]:
# Collect the patient identifiers available under `data_path`.
patients = get_list_of_patients(data_path)
# Fail fast with a clear message if the path is wrong or empty. Otherwise the
# problem would likely surface later as a less obvious error (e.g. in the split).
assert len(patients) > 0, f"no patients found under {data_path!r}"
print(patients)

In [None]:
# Patch shape (presumably H x W x channels, two channels here -- TODO confirm)
# and samples per batch.
patch_size, batch_size = (128, 128, 2), 32

In [None]:
# Partition the patients into train/val using fold 0 of a 5-way split. The fixed
# random_state presumably makes the partition reproducible across runs.
split = get_split_deterministic(patients, fold=0, num_splits=5, random_state=12345)
train_patients, val_patients = split
for subset in (train_patients, val_patients):
    print(subset)

In [None]:
# Training split: build the per-patient dataset, then wrap it in a Dataloader
# that applies the training augmentation for `patch_size`.
train_dataset = Dataset(
    data_path=data_path,
    patients=train_patients,
    only_non_empty_slices=True,
)
train_augmentation = get_training_augmentation(patch_size)
train_dataloader = Dataloader(
    dataset=train_dataset,
    augmentation=train_augmentation,
    batch_size=batch_size,
    skip_slices=0,  # NOTE(review): presumably "skip no slices" -- confirm
)
# Presumably the number of batches, since indexing the loader yields a batch.
print(len(train_dataloader))

In [None]:
# Validation split: build the per-patient dataset, then wrap it in a Dataloader
# that applies the validation augmentation for `patch_size`.
val_dataset = Dataset(
    data_path=data_path,
    patients=val_patients,
    only_non_empty_slices=True,
)
val_augmentation = get_validation_augmentation(patch_size)
val_dataloader = Dataloader(
    dataset=val_dataset,
    augmentation=val_augmentation,
    batch_size=batch_size,
    skip_slices=0,  # NOTE(review): presumably "skip no slices" -- confirm
)
# Presumably the number of batches, since indexing the loader yields a batch.
print(len(val_dataloader))

In [None]:
# Show one random (image, label) pair read straight from the Dataset. The
# augmentation is handed to the Dataloader, not the Dataset, so this is
# presumably the un-augmented sample.
# `sample_id` (not `id`) avoids shadowing the `id` builtin.
sample_id = random.randint(0, len(train_dataset) - 1)
image, label = train_dataset[sample_id]

# A single row of three panels: a wide, short figure instead of 15x15, which
# left most of the canvas blank.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
ax1.imshow(image[..., 0])
ax1.set_title(f"image channel 0 (train sample {sample_id})")
ax2.imshow(image[..., 1])
ax2.set_title("image channel 1")
ax3.imshow(label)
ax3.set_title("label")
plt.show()  # render the figure without echoing the AxesImage repr

In [None]:
# Show one sample from a random training batch. The Dataloader was given the
# training augmentation, so this is presumably the augmented view.
# `batch_id` (not `id`) avoids shadowing the `id` builtin.
batch_id = random.randint(0, len(train_dataloader) - 1)
images, labels = train_dataloader[batch_id]

# Pick a random sample rather than always the first, so repeated runs
# inspect the whole batch.
sample_id = random.randint(0, len(images) - 1)
image = images[sample_id]
label = labels[sample_id]

# A single row of three panels: a wide, short figure instead of 15x15, which
# left most of the canvas blank.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
ax1.imshow(image[..., 0])
ax1.set_title(f"image channel 0 (train batch {batch_id}, sample {sample_id})")
ax2.imshow(image[..., 1])
ax2.set_title("image channel 1")
ax3.imshow(label)
ax3.set_title("label")
plt.show()  # render the figure without echoing the AxesImage repr

In [None]:
# Show one random sample from a random validation batch (after the
# validation augmentation the Dataloader was built with).
batch_id = random.randint(0, len(val_dataloader) - 1)
images, labels = val_dataloader[batch_id]

sample_id = random.randint(0, len(images) - 1)
image = images[sample_id]
label = labels[sample_id]

# A single row of three panels: a wide, short figure instead of 15x15, which
# left most of the canvas blank.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
ax1.imshow(image[..., 0])
ax1.set_title(f"image channel 0 (val batch {batch_id}, sample {sample_id})")
ax2.imshow(image[..., 1])
ax2.set_title("image channel 1")
ax3.imshow(label)
ax3.set_title("label")
plt.show()  # render the figure without echoing the AxesImage repr