In [None]:
%load_ext autoreload
%autoreload 2

import torch
from tqdm import tqdm

# from byol_pretrain.utils import get_loaders_STL10
from torchvision.models import resnet50

import torchvision.transforms as T
from pixel_level_contrastive_learning import PixelCL

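# Smoke test: PixelCL self-supervised pretraining of a MaskRCNN backbone (the STL10 loader import above is disabled; the training cell below feeds random tensors)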

In [None]:
from karies.models import MaskRCNN
from karies.config import MaskRCNNConfig, ModelConfig, Task, ModelTypes
from byol_pretrain.BYOL_MaskRCNN import BYOL_MaskRCNN_Class, MaskRCNNModelWrapper

# Settings shared by all model types.
# NOTE(review): "path_model" is an absolute cluster path; make it configurable before running elsewhere.
base_config: ModelConfig = {
    # run identity
    "name": "BYOL-TEST-DELETE",
    "task": Task.training,
    # optimisation / training loop
    "optimizer": "adam",
    "learning_rate": 1e-4,
    "weight_decay": 1e-4,
    "batch_size": 8,
    "num_workers": 0,
    "device": "cuda",
    # checkpoint location — presumably not loaded here since the model is built with load=False below (confirm)
    "path_model": "/cluster/group/karies_2022/Simone/karies/karies-models/AAA_BYOL_test/BYOL/saved_pretrained_models/maskrcnn-weekend-test-batch-6/",
    "load_model_name": "pretrain_80_epochs.pth",
    # schedule / data
    "num_epochs": 100,
    "image_shape": [768, 768],
    "dataset": "dataset_4k",
    "labels_json": "labels_caries.json",
    "histogram_eq": False,
    "visualization_frequency": 1,
    "augmentations": [],
    "fix_random_seed": 42,
}

# MaskRCNN-specific settings layered on top of the shared ones.
config: MaskRCNNConfig = {
    **base_config,
    "model_type": ModelTypes.MaskRCNN,
    "classes": 5,
    "iou_threshold": 0.1,
    "confidence_threshold": 0.1,
    "model_args": {},
    "loss_weights": [1.0] * 5,
}

m = MaskRCNN(config, load=False)

# Hand only the detector's input transform and backbone to the self-supervised learner.
detector = m.model
wrap = MaskRCNNModelWrapper(detector.transform, detector.backbone)

In [None]:
# Keep the first loader returned by the model (presumably the training split); the second is
# unused below. It is bound to a named throwaway rather than `_`, because assigning `_` in
# IPython shadows the last-output history variable.
l, _unused_loader = m.get_data_loaders()

In [None]:
# Dataset size behind the first loader (presumably the training split).
num_train_samples = len(l.dataset)
num_train_samples

In [None]:
# Wrap the MaskRCNN transform + backbone in a PixelCL (pixel-level contrastive / PixPro) learner.
# image_size is derived from config["image_shape"] so the learner's assumed input size matches
# the resolution configured for this run (and the batches sampled in the training cell).
learner = PixelCL(
    wrap,
    image_size = tuple(config["image_shape"]),
    # Pixel-level features are taken from layer4. Their spatial size depends on the input
    # resolution (it is not a fixed 8x8 at 768px inputs).
    # NOTE(review): m.model.transform may also resize inputs — confirm the effective size.
    hidden_layer_pixel = 'backbone.body.layer4',
    hidden_layer_instance = -1,     # leads to output for instance-level learning
    projection_size = 256,          # size of projection output, 256 was used in the paper
    projection_hidden_size = 2048,  # size of projection hidden dimension, paper used 2048
    moving_average_decay = 0.99,    # exponential moving average decay of target encoder
    ppm_num_layers = 1,             # number of layers for transform function in the pixel propagation module, 1 was optimal
    ppm_gamma = 2,                  # sharpness of the similarity in the pixel propagation module, already at optimal value of 2
    distance_thres = 0.7,           # ideal value is 0.7, as indicated in the paper, which makes the assumption of each feature map's pixel diagonal distance to be 1 (still unclear)
    similarity_temperature = 0.3,   # temperature for the cosine similarity for the pixel contrastive loss
    alpha = 1.,                     # weight of the pixel-level loss term (NOTE(review): verify which term it scales in PixelCL.forward)
    use_pixpro = True,              # do pixel pro instead of pixel contrast loss, defaults to pixpro, since it is the best one
    cutout_ratio_range = (0.6, 0.8) # a random ratio is selected from this range for the random cutout
).to(config["device"])

In [None]:
from torch.cuda.amp import GradScaler

# Smoke-test pretraining loop. The learner is fed random tensors (not the data loader built
# above), so this only checks that forward / AMP backward / EMA updates run end to end.
NUM_STEPS = 100_000      # number of optimisation steps; lower this for a quick check
SAMPLE_BATCH_SIZE = 2    # images per random batch
LOG_EVERY = 100          # update the progress-bar loss every N steps (.item() forces a GPU sync)

# Make the random input batches reproducible with the seed declared in the config.
torch.manual_seed(config["fix_random_seed"])

scaler = GradScaler()
opt = torch.optim.Adam(learner.parameters(), lr=config["learning_rate"])

torch.cuda.empty_cache()

def sample_batch_images(batch_size=SAMPLE_BATCH_SIZE, image_shape=None):
    """Return a batch of standard-normal noise images on the configured device.

    Args:
        batch_size: number of images in the batch.
        image_shape: (height, width); defaults to ``config["image_shape"]``.

    Returns:
        Tensor of shape (batch_size, 3, height, width).
    """
    height, width = config["image_shape"] if image_shape is None else image_shape
    # Allocate directly on the device instead of creating on CPU and copying.
    return torch.randn(batch_size, 3, height, width, device=config["device"])

progress = tqdm(range(NUM_STEPS))
for step in progress:
    images = sample_batch_images()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        # If there are no positive pixel pairs, the loss reduces to the instance-level loss.
        loss = learner(images)

    opt.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()
    learner.update_moving_average()  # update moving average of target encoder

    if step % LOG_EVERY == 0:
        progress.set_postfix(loss=loss.item())


In [None]:
# Summarise the MaskRCNN weights (parameter name -> shape) instead of dumping every tensor
# into the notebook output.
# NOTE(review): these presumably reflect the PixelCL updates only because `wrap` references
# m.model.backbone rather than a copy — confirm against MaskRCNNModelWrapper / PixelCL.
pretrained_state = m.model.state_dict()
{name: tuple(tensor.shape) for name, tensor in pretrained_state.items()}