In [None]:
import os
import sys

# Put the directory two levels above the working directory on sys.path --
# presumably the repository root that holds the `dinov2` package imported
# later (TODO confirm). The entry is made absolute so a later os.chdir() cannot
# change what it refers to, and it is appended only once so re-running this
# cell does not keep growing sys.path with duplicates.
REPO_ROOT = os.path.abspath(os.path.join("..", ".."))
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

In [None]:
import os
import json
import numpy as np
import torch
import torchvision
import random
from PIL import Image
import matplotlib.pyplot as plt

from tqdm import tqdm
import time

In [None]:
# Directory holding one sub-folder per scan; get_image reads
# <root>/<scan_name>/image.npy from it.
root = "./datasets/lidc/LIDC_IDRI/test"
# os.listdir order is filesystem-dependent, so sort to make `scans[0]` the same
# scan on every machine; keep only directories so stray files (e.g. .DS_Store)
# are never treated as a scan.
scans = sorted(
    name for name in os.listdir(root)
    if os.path.isdir(os.path.join(root, name))
)

def get_image(scan_name, root_dir=None):
    """Load ``<root_dir>/<scan_name>/image.npy`` and return it as float32.

    Args:
        scan_name: Name of a scan sub-directory (e.g. an entry of ``scans``).
        root_dir: Directory containing the scan folders. Defaults to the
            module-level ``root`` so existing calls keep working unchanged.

    Returns:
        The stored array converted with ``astype(np.float32)`` (a new array).
    """
    if root_dir is None:
        # Reading a module-level name needs no `global` statement.
        root_dir = root
    path_to_image = os.path.join(root_dir, scan_name, "image.npy")
    return np.load(path_to_image).astype(np.float32)

def view_image(image, title=None):
    """Show ``image`` in grayscale with a colorbar in a new figure.

    Args:
        image: Anything ``np.array`` can convert to a 2-D array (this notebook
            passes PIL images and tensor slices); a gray colormap is used, so
            single-channel input is assumed.
        title: Optional axes title, so the figure is identifiable when the
            notebook is skimmed.
    """
    # Explicit figure/axes instead of the pyplot state machine: each call gets
    # its own figure and the colorbar is tied to this image's mappable.
    fig, ax = plt.subplots()
    mappable = ax.imshow(np.array(image), cmap="gray")
    fig.colorbar(mappable, ax=ax)
    if title is not None:
        ax.set_title(title)
    plt.show()


In [None]:
from dinov2.data import transforms
from dinov2.data.augmentations import DataAugmentationDINO

from dataclasses import dataclass, field


@dataclass
class CfgCrops:
    """Crop settings exposed to ``DataAugmentationDINO`` as ``cfg.crops``.

    Defaults are the values previously hard-coded in ``__init__``; any field
    can be overridden by keyword, e.g. ``CfgCrops(local_crops_number=4)``.
    How dinov2 interprets each value is defined there -- confirm before tuning.
    """

    global_crops_scale: tuple = (0.32, 1.0)
    local_crops_number: int = 8
    local_crops_scale: tuple = (0.05, 0.32)
    global_crops_size: int = 224
    local_crops_size: int = 96


@dataclass
class CfgNorm:
    """Scalar normalization constants exposed as ``cfg.norm``.

    A single mean/std pair -- presumably for single-channel images (TODO confirm).
    """

    mean: float = 0.124
    std: float = 0.121


@dataclass
class CfgAugments:
    """Top-level augmentation config handed to ``DataAugmentationDINO``.

    ``global_1``, ``global_2`` and ``local`` are lists of transform spec
    strings (e.g. ``"contrast_0.8_0.4"``); the meaning of the underscore-separated
    numbers is defined by dinov2's parser, not here. Each instance gets its own
    ``crops``/``norm`` objects and its own lists (``default_factory``), so
    mutating one config never leaks into another.
    """

    crops: CfgCrops = field(default_factory=CfgCrops)
    norm: CfgNorm = field(default_factory=CfgNorm)
    global_1: list = field(default_factory=lambda: [
        "rotation_0.8_90", "crop", "contrast_0.8_0.4", "brightness_0.8_0.4", "blur_1",
    ])
    global_2: list = field(default_factory=lambda: [
        "crop", "contrast_0.8_0.4", "brightness_0.8_0.4", "solarize_0.2_0.5_1",
        "noise_0.5_0.02_1", "blur_0.1",
    ])
    local: list = field(default_factory=lambda: [
        "crop", "contrast_0.8_0.4", "brightness_0.8_0.4", "blur_0.5",
    ])

# Build the DINO multi-crop augmentation pipeline from a fresh default config.
augment_cfg = CfgAugments()
augments = DataAugmentationDINO(augment_cfg)

In [None]:
# Load the first listed scan and keep index 0 along its leading axis --
# presumably one 2-D slice of the volume (TODO confirm the array layout).
scan_volume = get_image(scans[0])
loaded_image = scan_volume[0]
loaded_image.dtype

In [None]:
# Wrap the float32 slice as a 32-bit float ("F" mode) PIL image, the form fed
# to the augmentation transforms, and display it.
image = Image.fromarray(loaded_image, mode="F")
view_image(image=image)

In [None]:
# Seed the RNGs the augmentations may draw from so the crops are reproducible
# on Restart & Run All: torchvision transforms use torch's generator, the
# RandomGaussianBlur in this notebook uses `random`, and numpy is seeded in
# case dinov2's own transforms use it (TODO confirm which ones they use).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

global_crop_1 = augments.global_transfo1(image)
global_crop_2 = augments.global_transfo2(image)
local_crops = [augments.local_transfo(image) for _ in range(augments.local_crops_number)]

In [None]:
# First global crop, index 0 along its leading axis (presumably the channel).
view_image(image=global_crop_1[0])

In [None]:
# Second global crop, index 0 along its leading axis (presumably the channel).
view_image(image=global_crop_2[0])

In [None]:
# Second local crop (list index 1), index 0 along its leading axis.
view_image(image=local_crops[1][0])

In [None]:
class RandomGaussianBlur:
    """Apply torchvision's ``GaussianBlur`` to an input with probability ``p``.

    Args:
        p: Probability in [0, 1] of blurring a given input; 1.0 always blurs,
            0.0 never does.
        kernel_size: Kernel size forwarded to ``GaussianBlur`` (default 9, the
            value previously hard-coded).
        sigma: ``(min, max)`` range forwarded to ``GaussianBlur``, which picks
            sigma uniformly from it on each call (default ``(0.1, 2.0)``).

    Raises:
        ValueError: if ``p`` is not within [0, 1] (including NaN).
    """

    def __init__(self, p, kernel_size=9, sigma=(0.1, 2.0)):
        self.p = float(p)
        # Out-of-range values would silently behave like 0 or 1; reject them.
        if not 0.0 <= self.p <= 1.0:
            raise ValueError(f"p must be in [0, 1], got {self.p}")
        self.transform = torchvision.transforms.GaussianBlur(kernel_size=kernel_size, sigma=sigma)

    def __call__(self, img):
        """Return ``self.transform(img)`` with probability ``p``, else ``img`` unchanged."""
        # random.random() lies in [0, 1), so p == 1.0 always blurs and p == 0.0 never does.
        if random.random() < self.p:
            return self.transform(img)
        return img
# p=1.0: every call applies the Gaussian blur.
blur = RandomGaussianBlur(p=1.0)


In [None]:
# NOTE(review): stray debugging cell. It advances Python's global `random`
# state, so removing it changes the draws seen by later `random`-based
# transforms (e.g. RandomGaussianBlur). uniform(0.0, 1.0) == 0.0 + 1.0 * random().
random.uniform(0.0, 1.0)

In [None]:
# Shape of index 0 (along the leading axis) of the first local crop tensor.
local_crops[0][0].size()

In [None]:
# Blur the whole first local crop, then show index 0 along its leading axis
# (presumably the channel).
blurred_image = blur(img=local_crops[0])
view_image(image=blurred_image[0])

In [None]:
# Did the blur change the crop? `blurred_image` was computed from the full
# `local_crops[0]`, so compare against that same tensor. The previous
# comparison with the slice `local_crops[0][0]` always returned False because
# torch.equal requires matching shapes, whatever the pixel values.
torch.equal(blurred_image, local_crops[0])

In [None]:
# Inspect the raw values of the first local crop (torch abbreviates the repr
# of large tensors, so only edge values are shown).
local_crops[0]

In [None]:
# Blur only index 0 (along the leading axis) of the first local crop and display the result.
blur(img=local_crops[0][0])