In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path
import re
import shutil
import json
import h5py
import time
import numpy  as np
import nibabel as nib
import SimpleITK as sitk
import matplotlib.pyplot as plt
from scipy.stats import norm
import midasmednet.dataset
from torch.utils.data import Dataset, DataLoader
from midasmednet.landmarks import LandmarkTrainer
from midasmednet.utils.config import Config
from batchgenerators.augmentations.crop_and_pad_augmentations import crop

If you use TorchIO for your research, please cite the following paper:

Pérez-García et al., TorchIO: a Python library for efficient loading,
preprocessing, augmentation and patch-based sampling of medical images
in deep learning
(https://arxiv.org/abs/2003.04696)


In [None]:
class FaceLandmarksDataset(Dataset):

    def __init__(self, csv_file, root_dir, transform=None):
        dim = len(img.shape)
        cropped_patch_size = patch_size - 2*patch_overlap
        n_patches = np.ceil(img_size/cropped_patch_size).astype(int)
        overhead = cropped_patch_size - img_size % cropped_patch_size
        padded_img = np.pad(img, [[patch_overlap, patch_overlap + overhead[0]],
                                  [patch_overlap, patch_overlap + overhead[1]],
                                  [patch_overlap, patch_overlap + overhead[2]]])

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
       
        return sample

In [465]:
def patch_generator(img, patch_size, patch_overlap):
    dim = len(img.shape)
    cropped_patch_size = patch_size - 2*patch_overlap
    n_patches = np.ceil(img_size/cropped_patch_size).astype(int)
    overhead = cropped_patch_size - img_size % cropped_patch_size
    padded_img = np.pad(img, [[patch_overlap, patch_overlap + overhead[0]],
                              [patch_overlap, patch_overlap + overhead[1]],
                              [patch_overlap, patch_overlap + overhead[2]]])
    pos = [np.arange(0, n_patches[k])*cropped_patch_size[k] for k in range(dim)]
    count = -1
    for p0 in pos[0]:
        for p1 in pos[1]:
            for p2 in pos[2]:
                idx = [p0, p1, p2]
                idx_end = idx + patch_size
                count += 1
                yield padded_img[idx[0]:idx_end[0], idx[1]:idx_end[1], idx[2]:idx_end[2]], count

In [466]:
patch_overlap = 1
for p in patch_generator(img, patch_size, patch_overlap):
    print(p)

(array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 1, ..., 2, 4, 3],
        [0, 1, 0, ..., 3, 1, 2],
        ...,
        [0, 3, 4, ..., 2, 0, 3],
        [0, 1, 2, ..., 3, 4, 1],
        [0, 3, 2, ..., 2, 3, 4]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 1, ..., 4, 1, 1],
        [0, 1, 3, ..., 2, 0, 2],
        ...,
        [0, 4, 3, ..., 0, 3, 0],
        [0, 1, 4, ..., 2, 0, 4],
        [0, 0, 0, ..., 0, 0, 3]],

       ...,

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 1, 2, ..., 4, 3, 3],
        [0, 3, 0, ..., 3, 2, 3],
        ...,
        [0, 1, 4, ..., 1, 3, 3],
        [0, 4, 1, ..., 0, 2, 3],
        [0, 1, 0, ..., 1, 0, 2]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 1, 2, ..., 1, 0, 0],
        [0, 0, 1, ..., 1, 1, 1],
        ...,
        [0, 1, 3, ...,

In [464]:
t = time.perf_counter()
patches = [p for p,_ in patch_generator(img, patch_size, 1)]
print(len(patches))
print(time.perf_counter() - t)

cropped_patch_size = patch_size-2*overlap
n = np.ceil(img_size / cropped_patch_size).astype(int)
overhead = cropped_patch_size - img_size % cropped_patch_size
cropped_patches = [p[overlap:-overlap, overlap:-overlap, overlap:-overlap] for p in patches]
nested_patches = [[[cropped_patches[j*n[1]*n[2]+k*n[2]+l] for l in range(n[2])] for k in range(n[1])] for j in range(n[0])]  
out = np.block(nested_patches)[:img_size[0], :img_size[1], :img_size[2]]

print(time.perf_counter() - t)
np.all(out==img)

3
[90 90 90] 1
[52 52 64]
48
0.4396502749295905
0.8213649479439482


True

In [364]:
np.block(cropped_patches)

array([[1, 1, 2, 0, 0, 4, 0, 3, 1, 2, 4, 4, 2, 4, 4, 0, 3, 2, 4, 4, 4, 0,
        4, 3, 3, 1, 0, 1, 3, 0, 1, 3, 0, 0, 4, 0, 1, 2, 1, 0, 3, 2, 0, 1,
        2, 1, 1, 2, 1, 1, 2, 1, 4, 0, 4, 1, 0, 3, 0, 0, 0, 0, 4, 4, 1, 0,
        4, 1, 0, 0, 2, 3],
       [4, 0, 1, 3, 1, 0, 4, 0, 0, 1, 1, 3, 1, 1, 3, 0, 1, 3, 2, 2, 1, 1,
        4, 0, 1, 2, 2, 1, 1, 3, 1, 1, 3, 4, 0, 1, 4, 1, 0, 3, 0, 0, 0, 0,
        4, 4, 1, 0, 4, 1, 0, 0, 2, 3, 0, 3, 4, 0, 1, 1, 3, 2, 4, 0, 0, 2,
        0, 0, 2, 0, 2, 1],
       [1, 1, 2, 4, 3, 4, 2, 4, 2, 4, 2, 4, 4, 2, 4, 0, 1, 0, 0, 4, 4, 2,
        3, 4, 0, 1, 1, 0, 2, 4, 0, 2, 4, 4, 0, 0, 0, 3, 4, 0, 1, 1, 3, 2,
        4, 0, 0, 2, 0, 0, 2, 0, 2, 1, 2, 2, 2, 2, 3, 2, 3, 1, 1, 1, 1, 4,
        1, 1, 4, 0, 0, 0],
       [4, 1, 4, 1, 0, 1, 1, 4, 0, 0, 2, 0, 0, 2, 0, 3, 0, 2, 4, 3, 0, 2,
        2, 1, 2, 2, 2, 3, 1, 1, 3, 1, 1, 1, 0, 3, 2, 2, 2, 2, 3, 2, 3, 1,
        1, 1, 1, 4, 1, 1, 4, 0, 0, 0, 3, 3, 0, 0, 0, 1, 1, 4, 3, 1, 1, 0,
        1, 1, 0, 0, 0, 2],
    

In [385]:
len(img.shape)

3