In [2]:
# multi threaded 
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from scipy import ndimage
# config

# Patch geometry (voxels per axis): patches of `pre_patch_size` are cut with a
# stride of `pre_patch_size - overlap`, then randomly cropped to `patch_size`.
overlap = 20
pre_patch_size = [96, 96, 96]
patch_size = [64, 64, 64]
# Seed every RNG in use so the synthetic volume, patch filtering and
# augmentations are reproducible under Restart & Run All.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
# initialize 3d image 256*256*256, filled with random values
image = np.random.rand(256, 256, 256)
# Patches whose mean (8-bit scale) exceeds intensity_thres are always kept;
# darker patches are kept only with probability 1 - random_thres.
intensity_thres = 20
random_thres = 0.7

assert all(0 <= overlap < p for p in pre_patch_size), "overlap must be smaller than pre_patch_size"
assert all(o <= p for o, p in zip(patch_size, pre_patch_size)), "patch_size must fit inside pre_patch_size"


In [None]:
# augmentation functions

# input: channel (high res, low res), H, W, D
def translation(image, pre_size=None, out_size=None):
    """Random translation: crop an ``out_size`` window at a random offset.

    The offset along each spatial axis is drawn uniformly from
    ``[0, pre_size - out_size]``, further limited so the window always stays
    inside the input. Cropping acts on the LAST three axes, so both a
    (H, W, D) patch and a channel-first (C, H, W, D) stack are supported;
    leading axes are kept whole, so every channel receives the same shift.

    Args:
        image: numpy array with at least 3 dimensions, spatial axes last.
        pre_size: per-axis extent the shift range is based on; defaults to the
            module-level ``pre_patch_size``.
        out_size: per-axis output extent; defaults to the module-level
            ``patch_size``.

    Returns:
        A view of ``image`` whose last three axes have shape ``out_size``.

    Raises:
        ValueError: if ``image`` has fewer than 3 dimensions, or if a spatial
            axis is shorter than ``out_size`` (the crop would otherwise come
            out silently smaller than requested).
    """
    pre_size = pre_patch_size if pre_size is None else pre_size
    out_size = patch_size if out_size is None else out_size
    if image.ndim < 3:
        raise ValueError(f"expected at least 3 dimensions, got shape {image.shape}")

    spatial_shape = image.shape[-3:]
    starts = []
    for extent, pre, out in zip(spatial_shape, pre_size, out_size):
        if extent < out:
            raise ValueError(
                f"spatial shape {spatial_shape} is smaller than crop size {tuple(out_size)}"
            )
        # The shift range comes from pre_size; min() keeps the window in bounds
        # when the input (e.g. a border patch) is smaller than pre_size.
        starts.append(np.random.randint(0, min(pre, extent) - out + 1))

    window = tuple(slice(start, start + out) for start, out in zip(starts, out_size))
    return image[(Ellipsis,) + window]

def rotation(image):
    """Randomly rotate by a multiple of 90 degrees in one spatial plane.

    A plane is chosen uniformly among the three pairs of the last three axes
    and an angle uniformly from {90, -90, 0, 180} degrees. ``np.rot90`` only
    permutes indices, so voxel values are preserved exactly (no spline
    interpolation). Leading axes (e.g. a channel axis) are left untouched,
    so every channel gets the same rotation.

    The spatial shape is preserved when the two rotated axes have equal
    length (true for cubic patches); otherwise those two lengths are swapped.

    Args:
        image: numpy array with at least 3 dimensions, spatial axes last.

    Returns:
        The rotated array, C-contiguous so ``torch.as_tensor`` can wrap it
        (torch rejects the negative strides ``np.rot90`` views can have).

    Raises:
        ValueError: if ``image`` has fewer than 3 dimensions.
    """
    if image.ndim < 3:
        raise ValueError(f"expected at least 3 dimensions, got shape {image.shape}")
    planes = [(0, 1), (0, 2), (1, 2)]
    offset = image.ndim - 3  # skip any leading (channel) axes
    first, second = planes[np.random.randint(len(planes))]
    angle = np.random.choice([90, -90, 0, 180])
    rotated = np.rot90(image, k=int(angle) // 90, axes=(offset + first, offset + second))
    return np.ascontiguousarray(rotated)

def flip(image, p=0.5):
    """With probability ``p``, mirror ``image`` along one random spatial axis.

    The axis is drawn uniformly from the last three axes; leading axes
    (e.g. a channel axis) are never flipped, so channels stay paired.

    Args:
        image: numpy array with at least 3 dimensions, spatial axes last.
        p: probability of applying the flip (default 0.5).

    Returns:
        Either ``image`` unchanged or a C-contiguous mirrored copy. The copy
        is needed because ``np.flip`` returns a negative-stride view, which
        ``torch.as_tensor`` (and thus default DataLoader collation) rejects.

    Raises:
        ValueError: if ``image`` has fewer than 3 dimensions.
    """
    if image.ndim < 3:
        raise ValueError(f"expected at least 3 dimensions, got shape {image.shape}")
    if np.random.rand() < p:
        axis = np.random.choice([0, 1, 2]) + (image.ndim - 3)
        image = np.ascontiguousarray(np.flip(image, axis=axis))
    return image

# TODO: add a downsampling augmentation (scale factor possibly sampled from [0.6, 1])


In [3]:
class Image:
    """Wraps a volume: rescales it to 8 bits and extracts overlapping patches.

    Construction rescales the volume to uint8 and cuts it into full-size
    patches of ``pre_patch_size`` voxels with a stride of
    ``pre_patch_size - overlap``; each patch is kept or dropped by
    ``filter_patch``. The module-level config values ``pre_patch_size``,
    ``overlap``, ``intensity_thres`` and ``random_thres`` are read at call time.

    Attributes:
        image: the rescaled uint8 volume.
        high_res_patch: kept patches (views into ``image``).
        low_res_patch: reserved for degraded counterparts; not populated yet.
        augmentations: augmentation callables stored for later use (not applied yet).
    """

    def __init__(self, image, augmentations=None):
        """Rescale ``image`` to 8 bits and extract its patches.

        Args:
            image: numpy volume; its first three axes are the ones patched.
            augmentations: optional iterable of augmentation callables. It is
                copied into a fresh list so no list object is shared between
                instances or with the caller.
        """
        self.high_res_patch = []
        self.low_res_patch = []
        self.image = image
        self.augmentations = list(augmentations) if augmentations is not None else []

        self.convert_to_8bit()
        self.get_patch()

    def convert_to_8bit(self):
        """Min-max rescale ``self.image`` to uint8 values in [0, 255].

        A constant volume (max == min) has no range to stretch and maps to all
        zeros; dividing by a zero range would give NaN, whose uint8 cast is
        undefined.
        """
        img_min = np.min(self.image)
        img_max = np.max(self.image)
        value_range = img_max - img_min
        if value_range == 0:
            self.image = np.zeros(np.shape(self.image), dtype=np.uint8)
            return
        self.image = ((self.image - img_min) / value_range * 255).astype(np.uint8)

    @staticmethod
    def _patch_starts(size, window, step):
        """Start offsets of full ``window``-long patches along an axis of length ``size``.

        Offsets advance by ``step``. If that grid stops short of the axis end,
        one extra patch flush with the border is added so every voxel is
        covered. Returns ``[]`` when not even one full patch fits.

        Raises:
            ValueError: if ``step`` is not positive (overlap >= patch size).
        """
        if step <= 0:
            raise ValueError(f"patch step must be positive, got {step}; is overlap >= pre_patch_size?")
        if size < window:
            return []
        starts = list(range(0, size - window + 1, step))
        if starts[-1] + window < size:
            starts.append(size - window)
        return starts

    def get_patch(self):
        """Extract overlapping full-size patches, keeping those accepted by ``filter_patch``.

        Only complete ``pre_patch_size`` patches are produced. Stepping to the
        end of each axis would also yield truncated border slabs (e.g. 28
        voxels wide for a 256-voxel axis with stride 76) that cannot be
        batched. A last patch flush with each border covers those voxels instead.
        """
        axis_starts = [
            self._patch_starts(self.image.shape[ax], pre_patch_size[ax], pre_patch_size[ax] - overlap)
            for ax in range(3)
        ]
        for i in axis_starts[0]:
            for j in axis_starts[1]:
                for k in axis_starts[2]:
                    patch = self.image[i:i + pre_patch_size[0], j:j + pre_patch_size[1], k:k + pre_patch_size[2]]
                    avg_intensity = np.mean(patch)
                    if self.filter_patch(avg_intensity):
                        self.high_res_patch.append(patch)

                        # apply augmentations
                        # for aug in self.augmentations:
                        #     patch = aug(patch)
                        # downsampling after augmentation
                        # self.low_res_patch.append()

    def filter_patch(self, avg_intensity):
        """Decide whether to keep a patch given its mean (8-bit scale) intensity.

        Patches brighter than ``intensity_thres`` are always kept. Darker ones
        are kept at random with probability ``1 - random_thres``, so some
        background context survives.
        """
        if avg_intensity > intensity_thres:
            return True
        return np.random.rand() > random_thres

    def save_patches(self, image):
        """Persist the extracted patches to disk (TODO: include the source filename).

        Raises:
            NotImplementedError: always; saving is not implemented yet. The
                original stub had no body, which made the class a SyntaxError.
        """
        raise NotImplementedError("save_patches is not implemented yet")

    # def downsample(self, image):
    #     """Downsamples the image by 1/3 and upsamples it back to original size."""
    #     image = ndimage.zoom(image, 1/3)
    #     image = ndimage.zoom(image, 3)
    #     return image


In [None]:
# Plan: apply augmentations when the DataLoader fetches samples.
# Implement them ourselves: translation, rotation (multiples of 90 degrees), flipping/mirroring, and small brightness/contrast changes.

In [None]:
# Dataset
# TODO: apply the augmentations inside the dataset
# Input: DICOM (.dcm) volumes => output: .npy patches


class CustomImageDataset(Dataset):
    """PyTorch ``Dataset`` serving (high-res, low-res) patch pairs from one volume.

    Patch extraction is delegated to ``Image``; this class only indexes the
    resulting lists.
    """

    def __init__(self, image, augmentations=None):
        """Extract patches from ``image``.

        Args:
            image: volume handed straight to ``Image``.
            augmentations: optional augmentation callables forwarded to
                ``Image``; defaults to None (no augmentations).
        """
        self.processed_image = Image(image, augmentations)
        self.high_res_patch = self.processed_image.high_res_patch
        self.low_res_patch = self.processed_image.low_res_patch

    def __len__(self):
        """Number of kept high-resolution patches."""
        return len(self.high_res_patch)

    def __getitem__(self, idx):
        """Return the ``(high_res, low_res)`` patch pair at ``idx``.

        Raises:
            IndexError: if ``idx`` is out of range, or if no low-resolution
                patch exists for it; the message then reports both list
                lengths. IndexError (not another type) is kept so Python's
                sequence-iteration protocol still terminates normally.
        """
        high_res = self.high_res_patch[idx]
        try:
            low_res = self.low_res_patch[idx]
        except IndexError:
            raise IndexError(
                f"no low-resolution patch for index {idx}: "
                f"{len(self.low_res_patch)} low-res vs {len(self.high_res_patch)} high-res patches"
            ) from None
        return high_res, low_res

### Notes

- Candidate augmentation libraries: MONAI, Kornia (https://github.com/kornia/kornia), volumentations.
