In [None]:
%load_ext autoreload
%autoreload 2

import torch
from tqdm import tqdm

import torchvision.transforms as T
from pixel_level_contrastive_learning import PixelCL

In [None]:
from karies.models import MaskRCNN
from karies.config import MaskRCNNConfig, ModelConfig, Task, ModelTypes
from byol_pretrain.BYOL_MaskRCNN import BYOL_MaskRCNN_Class, MaskRCNNModelWrapper

# Settings shared by every model type; the Mask R-CNN specific keys are layered on top below.
base_config: ModelConfig = {
    # run identity
    "name": "BYOL-TEST-DELETE",
    "task": Task.training,
    # optimisation
    "optimizer": "adam",
    "learning_rate": 1e-4,
    "weight_decay": 1e-4,
    "batch_size": 8,
    "num_workers": 0,
    "device": "cuda",
    # checkpoint location (absolute cluster path; adjust when running elsewhere)
    "path_model": "/cluster/group/karies_2022/Simone/karies/karies-models/AAA_BYOL_test/BYOL/saved_pretrained_models/maskrcnn-weekend-test-batch-6/",
    "load_model_name": "pretrain_80_epochs.pth",
    "num_epochs": 100,
    # data
    "image_shape": [500, 500],
    "dataset": "dataset_4k",
    "labels_json": "labels_caries.json",
    "histogram_eq": False,
    "visualization_frequency": 1,
    "augmentations": [],
    "fix_random_seed": 42,
}

# Full Mask R-CNN configuration: the shared settings plus the detector-specific keys.
config: MaskRCNNConfig = {
    **base_config,
    "model_type": ModelTypes.MaskRCNN,
    "classes": 5,
    "iou_threshold": 0.1,
    "confidence_threshold": 0.1,
    "model_args": {},
    "loss_weights": [1.0] * 5,
}

# Build the detector with load=False (presumably: do not restore a checkpoint; confirm in karies.models).
m = MaskRCNN(config, load=False)

# Bundle the detector's input transform and its backbone into the wrapper handed to PixelCL below.
wrap = MaskRCNNModelWrapper(m.model.transform, m.model.backbone)

In [None]:
# Pixel-level contrastive learner built around the wrapped Mask R-CNN backbone.
# The per-argument notes below are the author's; the ones marked TODO have not been
# verified for this backbone / input size.
learner = PixelCL(
    wrap,
    image_size=(500, 500),
    hidden_layer_pixel="backbone.body.layer4",  # layer used for pixel-level learning (note said "8x8 feature map"; TODO confirm for 500x500 inputs)
    hidden_layer_instance=-1,                   # layer used for instance-level learning
    projection_size=256,                        # projection output size; 256 was used in the paper
    projection_hidden_size=2048,                # projection hidden size; the paper used 2048
    moving_average_decay=0.99,                  # exponential moving average decay of the target encoder
    ppm_num_layers=1,                           # layers of the transform function in the pixel propagation module; 1 was optimal
    ppm_gamma=2,                                # sharpness of the similarity in the pixel propagation module; 2 was optimal
    distance_thres=0.7,                         # paper value; assumes each feature-map pixel has diagonal distance 1 (author: still unclear)
    similarity_temperature=0.3,                 # temperature of the cosine similarity in the pixel contrastive loss
    alpha=1.,                                   # weight of the pixel propagation (pixpro) loss vs the pixel contrastive loss
    use_pixpro=True,                            # use pixel propagation instead of the pixel contrastive loss
    cutout_ratio_range=(0.6, 0.8),              # range the random cutout ratio is drawn from
)
# Move the learner to the GPU.
learner = learner.cuda()

In [None]:
from torch.cuda.amp import GradScaler

# Single switch for mixed precision: it enables the gradient scaler here and
# the autocast context inside the training loop.
mixed_precision = True

# Adam over all learner parameters, plus the gradient scaler that pairs with autocast.
opt = torch.optim.Adam(learner.parameters(), lr=0.0001)
scaler = GradScaler(enabled=mixed_precision)

# Release cached, currently unused GPU memory before starting the run.
torch.cuda.empty_cache()

def sample_batch_images(batch_size=2, image_size=(500, 500), device="cuda"):
    """Return a batch of random 3-channel images drawn from a standard normal.

    Called without arguments this yields, as before, a float tensor of shape
    ``(2, 3, 500, 500)`` on the CUDA device. The tensor is now allocated
    directly on ``device`` instead of being created on the host and copied.
    No seed is set here, so the values differ from run to run.

    Args:
        batch_size: number of images in the batch.
        image_size: ``(height, width)`` of each image.
        device: device the batch is created on, e.g. ``"cuda"`` or ``"cpu"``.

    Returns:
        A tensor of shape ``(batch_size, 3, height, width)`` on ``device``.
    """
    height, width = image_size
    return torch.randn(batch_size, 3, height, width, device=device)

# Smoke test: 50 optimisation steps on freshly sampled random batches.
for _ in tqdm(range(50)):
    images = sample_batch_images()
    # Forward pass under autocast; float16 is only used when mixed_precision is True.
    with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=mixed_precision):
        loss = learner(images) # author's note: with zero positive pixel pairs the loss equals the instance-level loss (TODO confirm in PixelCL)

    # Clear old gradients, backpropagate the scaled loss, then let the scaler
    # perform the optimizer step and refresh its scale factor. Keep this order.
    opt.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()
    learner.update_moving_average() # update moving average of target encoder


In [None]:
import math
import torch
import torch.nn.functional as F

def default(val, def_val):
    """Return ``val`` unless it is ``None``, in which case return ``def_val``.

    Only ``None`` triggers the fallback; other falsy values such as ``0`` or
    ``""`` are returned unchanged.
    """
    if val is None:
        return def_val
    return val

def cutout_and_resize(image, coordinates, output_size = None, mode = 'nearest'):
    """Crop a window out of ``image`` and resample it to ``output_size``.

    Args:
        image: tensor whose last two axes are cropped as ``image[:, :, y, x]``
            (assumed to be a 4-D ``(batch, channels, height, width)`` tensor;
            TODO confirm against callers).
        coordinates: ``((y0, y1), (x0, x1))`` slice bounds of the window.
        output_size: size of the last two axes of the result; when ``None``
            the original ``image.shape[2:]`` is used.
        mode: interpolation mode forwarded to ``F.interpolate``.

    Returns:
        The cropped window resized to ``output_size``.
    """
    # Fall back to the input's own spatial size when no target size is given.
    target_size = image.shape[2:] if output_size is None else output_size

    rows, cols = coordinates
    top, bottom = rows
    left, right = cols

    window = image[:, :, top:bottom, left:right]
    return F.interpolate(window, size = target_size, mode = mode)

image_h, image_w = 50, 50
proj_image_h, proj_image_w = 25, 25

# Row (y) and column (x) index of every pixel of the original image.
# indexing="ij" spells out the row-major layout that was previously implied;
# without it torch.meshgrid emits a deprecation warning on recent PyTorch.
grid_rows, grid_cols = torch.meshgrid(
    torch.arange(image_h),
    torch.arange(image_w),
    indexing="ij",
)
# Stack into a single float tensor of shape (1, 2, image_h, image_w):
# channel 0 holds the row indices, channel 1 the column indices.
coordinates = torch.stack((grid_rows, grid_cols)).unsqueeze(0).float()
# Divide by the image diagonal so positions are expressed relative to the image size.
coordinates /= math.sqrt(image_h ** 2 + image_w ** 2)

# Rescale the relative positions to the projected (feature-map) resolution.
coordinates[:, 0] *= proj_image_h
coordinates[:, 1] *= proj_image_w

# Next step, left disabled: cut the area of interest out of the coordinate grid, since the
# views have been cropped to a smaller window. The names used below (cutout_coordinates_*,
# proj_image_shape, self.coord_cutout_interpolate_mode) are not defined in this cell.
# proj_coors_one = cutout_and_resize(coordinates, cutout_coordinates_one, output_size = proj_image_shape, mode = self.coord_cutout_interpolate_mode)
# proj_coors_two = cutout_and_resize(coordinates, cutout_coordinates_two, output_size = proj_image_shape, mode = self.coord_cutout_interpolate_mode)

# Display the column-coordinate channel.
coordinates[0, 1]

In [None]:
import torch
import torch.nn as nn
# Two small random sets of 2-D coordinates, one row per "pixel".
proj_coors_one = torch.randn(3, 2)
proj_coors_two = torch.randn(3, 2)

print("Initial shapes:")
print("proj_coors_one: ", proj_coors_one.shape)
print("proj_coors_two: ", proj_coors_two.shape)

# Euclidean (p=2) distance between corresponding rows of two tensors.
pdist = nn.PairwiseDistance(p=2)
num_pixels = proj_coors_one.shape[0]

# Pair every row of set one with every row of set two:
# set one repeats each row num_pixels times in a row (a, a, a, b, b, b, ...),
# set two cycles through all of its rows (a, b, c, a, b, c, ...).
proj_coors_one_expanded = proj_coors_one.unsqueeze(1).expand(-1, num_pixels, -1).flatten(0, 1)
proj_coors_two_expanded = proj_coors_two.unsqueeze(0).expand(num_pixels, -1, -1).flatten(0, 1)

print("\nTransformed shapes:")
print("proj_coors_one_expanded: ", proj_coors_one_expanded.shape)
print("proj_coors_two_expanded: ", proj_coors_two_expanded.shape)

# One distance per (row of one, row of two) pair, as a flat vector of
# num_pixels * num_pixels entries; entry i * num_pixels + j is pair (i, j).
distance_matrix = pdist(proj_coors_one_expanded, proj_coors_two_expanded)

print("\nOutput shape:")
print("distance_matrix: ", distance_matrix.shape)


In [None]:
# Display the flat vector of pairwise distances computed in the previous cell.
distance_matrix