In [1]:
%load_ext autoreload
%autoreload 2

import torch
from tqdm import tqdm
import torchvision.transforms as T
from pixel_level_contrastive_learning import PixelCL
from karies.models import MaskRCNN
from pixpro_utils import MaskRCNNModelWrapper
import sys

sys.path.append("/cluster/group/karies_2022/Simone/karies/karies-models/")
from configs.MaskRCNN.default_model_setting_config import model_config as config

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import copy

# Per-run overrides applied on top of the default Mask R-CNN settings.
run_overrides = {
    "name": "PIXPRO-TEST-DELETE",
    "batch_size": 2,
    "device": "cuda",
    "dataset": "dataset_test",
    "augmentations": [],  # no augmentations for this run
    "fix_random_seed": 42,
    "image_shape": [768, 1024],
}

# `config` starts out as the module-level `model_config` object imported from
# `configs.MaskRCNN.default_model_setting_config`. Copy it before applying the
# overrides so they do not mutate that shared object; Python caches the module,
# so in-place writes would survive re-running the import cell.
# A shallow copy is enough because only top-level keys are reassigned, and
# `copy.copy` keeps the original config type.
config = copy.copy(config)
for key, value in run_overrides.items():
    config[key] = value

# Build the Mask R-CNN. `load=False` presumably skips loading saved weights — TODO confirm.
m = MaskRCNN(config, load=False)

# Hand only the model's input transform and backbone to the PixPro learner.
wrap = MaskRCNNModelWrapper(
    m.model.transform,
    m.model.backbone,
)

In [5]:
from pixpro_utils import PretrainingDataset
from karies.config import DatasetPartition

# Self-supervised pretraining data over the training partition, cropped to 112 px.
pretraining_dataset = PretrainingDataset(config, DatasetPartition.train, crop_size=112)

# NOTE(review): the method name is plural, but its result is iterated directly,
# so it presumably returns a single DataLoader. Confirm against pixpro_utils.
dl = pretraining_dataset.get_data_loaders()

# Fetch one batch to smoke-test the input pipeline.
# NOTE(review): neither `dl` nor `x` is consumed by the training loop in this
# notebook, which feeds random tensors instead.
x = next(iter(dl))

In [None]:
# Hyper-parameters for the PixPro learner wrapped around the Mask R-CNN
# transform + backbone (`wrap`).
pixel_cl_kwargs = dict(
    # Input resolution the learner is configured for. The noise batches in the
    # training cell use the same 500x500 size.
    image_size=(500, 500),
    # Named submodule whose output feeds the pixel-level branch (presumably the
    # last ResNet stage inside the backbone's `body`).
    # NOTE(review): the feature-map size here depends on the resolution that
    # reaches the backbone, which the wrapped `transform` may change. Verify it
    # rather than assuming 8x8.
    hidden_layer_pixel='backbone.body.layer4',
    # Layer index used for the instance-level branch. -1 presumably selects the
    # wrapper's last child module — TODO confirm against PixelCL's layer lookup.
    hidden_layer_instance=-1,
    projection_size=256,            # projection output size; the paper used 256
    projection_hidden_size=2048,    # projection hidden size; the paper used 2048
    moving_average_decay=0.99,      # EMA decay of the target encoder
    ppm_num_layers=1,               # layers in the pixel propagation module transform; 1 was reported optimal
    ppm_gamma=2,                    # similarity sharpness in the pixel propagation module; 2 was reported optimal
    distance_thres=0.7,             # paper value; assumes each feature-map pixel has a diagonal distance of 1 (still unclear)
    similarity_temperature=0.3,     # cosine-similarity temperature for the pixel contrastive loss
    alpha=1.,                       # weight of the pixel propagation (pixpro) loss vs the pixel CL loss
    use_pixpro=True,                # pixel propagation instead of the pixel contrast loss (the default, best reported)
    cutout_ratio_range=(0.6, 0.8),  # the random cutout ratio is drawn from this range
)

learner = PixelCL(wrap, **pixel_cl_kwargs).cuda()

In [None]:
from torch.cuda.amp import GradScaler

# Short smoke-test training run for the PixPro learner, using fp16 autocast with
# dynamic loss scaling. The tunables are named here instead of being buried in the loop.
mixed_precision = True  # toggles both autocast and GradScaler
NUM_STEPS = 50
LEARNING_RATE = 1e-4
BATCH_SIZE = 2
INPUT_HW = (500, 500)   # same as the `image_size` given to PixelCL

scaler = GradScaler(enabled=mixed_precision)
opt = torch.optim.Adam(learner.parameters(), lr=LEARNING_RATE)

torch.cuda.empty_cache()


def sample_batch_images(batch_size=BATCH_SIZE, image_hw=INPUT_HW, device="cuda"):
    """Return a (batch_size, 3, H, W) batch of standard-normal noise on `device`.

    The defaults reproduce the original 2x3x500x500 CUDA batch.

    TODO: this trains on synthetic noise. The pretraining loader `dl` built in
    an earlier cell is not used yet.
    """
    height, width = image_hw
    return torch.randn(batch_size, 3, height, width).to(device)


losses = []  # per-step loss values, kept so the run can be inspected/plotted afterwards
progress = tqdm(range(NUM_STEPS))
for _ in progress:
    images = sample_batch_images()
    with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=mixed_precision):
        # Original author's note: if there are no positive pixel pairs, the loss
        # equals the instance-level loss.
        loss = learner(images)

    opt.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(opt)   # GradScaler skips the optimizer step if the unscaled grads contain inf/NaN
    scaler.update()
    learner.update_moving_average()  # update the moving average of the target encoder

    # Surface the loss; previously the run gave no signal on whether it was learning.
    losses.append(loss.item())
    progress.set_postfix(loss=f"{losses[-1]:.4f}")
