In [None]:
%load_ext autoreload
%autoreload 2

# Make the local `byol_pretrain` checkout importable from this notebook.
# NOTE(review): this is a hard-coded absolute path on the cluster; the notebook
# will not find the package on another machine unless this line is adjusted or
# the package is installed into the environment.
import sys
sys.path.append("/cluster/group/karies_2022/Simone/karies/karies-models/AAA_BYOL_test/BYOL/byol_pretrain")

from pathlib import Path
import torch
from byol_pretrain.BYOL import BYOL_Class
from byol_pretrain.toy_model import Model
from byol_pretrain.utils import get_loaders_CIFAR10, get_loaders_STL10

# Toy model

In [None]:
# BYOL pre-training of the toy model.
#
# All run settings are declared first. `device` is the single source of truth
# for placement: it is used both to move the model and as the `device` argument
# of the trainer, so editing it in one place cannot leave the two out of sync.
device = "cuda"
batch_size = 50
pretrain_epochs = 100
test = 100  # else None to use all the data
log_dir = "./logs/proj_test/cifar10/DELETE/"
lin_evaluation_frequency = 20

model = Model(3).to(device)

# Dataset selection: CIFAR-10 is active. To switch to STL-10, comment out the
# two CIFAR-10 lines and uncomment the two STL-10 lines so that `img_size`
# changes together with the loaders.
loader, val_loader = get_loaders_CIFAR10(batch_size, test=test)
img_size = (32, 32)

# loader, val_loader = get_loaders_STL10(batch_size, test=test)
# img_size = (96, 96)

byol = BYOL_Class(
    model,
    loader,
    val_loader,
    pretrain_epochs,
    log_dir,
    input_dims=3,
    img_dims=img_size,
    hidden_features=4096,
    device=device,
    lin_evaluation_frequency=lin_evaluation_frequency,
)

# Start the pre-training run (long-running cell).
byol.pretrain()

# MaskRCNN

In [None]:
from karies.models import MaskRCNN
from karies.config import MaskRCNNConfig, ModelConfig, Task, ModelTypes
from byol_pretrain.BYOL_MaskRCNN import BYOL_MaskRCNN_Class, MaskRCNNModelWrapper

# Settings shared by every model type; the MaskRCNN-specific keys are layered
# on top of this dictionary below.
base_config: ModelConfig = {
    # -- run identity / optimisation --
    "name": "BYOL-TEST-DELETE",
    "task": Task.training,
    "optimizer": "adam",
    "learning_rate": 0.0001,
    "weight_decay": 0.0001,
    "batch_size": 8,
    "num_workers": 0,
    "device": "cuda",
    # -- checkpoint location (NOTE(review): absolute cluster path) --
    "path_model": "/cluster/group/karies_2022/Simone/karies/karies-models/AAA_BYOL_test/BYOL/saved_pretrained_models/maskrcnn-weekend-test-batch-6/",
    "load_model_name": "pretrain_80_epochs.pth",
    # -- schedule / data --
    "num_epochs": 100,
    "image_shape": [768, 1024],
    "dataset": "dataset_test",
    "labels_json": "labels_caries.json",
    "histogram_eq": False,
    "visualization_frequency": 1,
    "augmentations": [],
    "fix_random_seed": 42,
}

# Full MaskRCNN configuration: the shared settings plus detector-specific keys.
config: MaskRCNNConfig = {
    **base_config,
    "model_type": ModelTypes.MaskRCNN,
    "classes": 5,
    "iou_threshold": 0.1,
    "confidence_threshold": 0.1,
    "model_args": {},
    "loss_weights": [1.0, 1.0, 1.0, 1.0, 1.0],
}

# Build the model with load=False, then wrap its `transform` and `backbone`
# submodules so they can be handed to the BYOL trainer in the next cell.
maskrcnn = MaskRCNN(config, load=False)
wrap = MaskRCNNModelWrapper(maskrcnn.model.transform, maskrcnn.model.backbone)

# Data loaders are obtained from the model object itself.
loader, val_loader = maskrcnn.get_data_loaders()

In [None]:
# Settings for the MaskRCNN BYOL pre-training run.
pretrain_epochs = 100
log_dir = Path("./logs/LOAD-DELETE")
device = "cuda"
lin_evaluation_frequency = None  # no linear-evaluation frequency is set for this run

# `wrap`, `loader`, `val_loader` and `maskrcnn` come from the previous cell.
byol = BYOL_MaskRCNN_Class(
    wrap,
    loader,
    val_loader,
    pretrain_epochs,
    log_dir,
    input_dims=3,
    img_dims=(700, 700),
    hidden_features=2048,
    device=device,
    lin_evaluation_frequency=lin_evaluation_frequency,
    mixed_precision_enabled=True,
    # The full MaskRCNN object is passed alongside the wrapped backbone;
    # presumably used when saving weights -- TODO confirm in BYOL_MaskRCNN.
    original_maskrcnn=maskrcnn,
    save_weights_every=5,
)

# Start the pre-training run (long-running cell).
byol.pretrain()

# UNet

In [None]:
from karies.models import U_Net_Model

from karies import ModelConfig
from karies.config import Augmentation, Task, UNetConfig
from karies.config.config_class import ModelTypes

from byol_pretrain.BYOL_UNet import UNetModelWrapper, BYOL_UNet_Class

# Settings shared by every model type; the U-Net-specific keys are layered on
# top of this dictionary below.
base_config: ModelConfig = {
    # -- run identity / optimisation --
    "name": "BYOL-TEST-DELETE",
    "learning_rate": 0.0001,
    "batch_size": 4,
    "num_epochs": 80,
    "histogram_eq": True,
    "path_model": "",
    "task": Task.training,
    "optimizer": "adam",
    "weight_decay": 0.0001,
    "num_workers": 0,
    "device": "cuda",
    # -- data --
    "image_shape": [768, 1024],
    "dataset": "dataset_4k",
    "labels_json": "labels_caries.json",
    "visualization_frequency": 1,
    "fix_random_seed": 42,
    # -- augmentation list, each entry with the values it is configured with --
    "augmentations": [
        {"augmentation": Augmentation.gaussian_blur, "values": [9, 0.9, 1.5]},
        {"augmentation": Augmentation.rotation, "values": (-10, 10)},
        {"augmentation": Augmentation.horizontal_flip},
        {"augmentation": Augmentation.nothing},
        {"augmentation": Augmentation.elastic_transform, "values": [100.0, 10.0]},
        {"augmentation": Augmentation.random_affine, "values": [(-5, 5), (0.9, 1.1), (0, 10)]},
    ],
}

# Full U-Net configuration. "model_type" carries the same value as "unet_type"
# and is kept as the last key.
model_config: UNetConfig = {
    **base_config,
    "model_args": {"in_channels": 1, "classes": 1, "decoder_attention_type": None},
    "loss_func": "iou",  # "dice_ce", "gen_dice" or "iou"
    "unet_type": ModelTypes.U_Net,
    "model_type": ModelTypes.U_Net,
}

# Also read by the pre-training cell that follows.
device = "cuda"

# Build the model with load=False and wrap its encoder for the BYOL trainer.
unet = U_Net_Model(model_config, load=False)
unet_wrap = UNetModelWrapper(unet.model.encoder)

# Data loaders are obtained from the model object itself.
loader, val_loader = unet.get_data_loaders()

In [None]:
# Settings for the U-Net BYOL pre-training run.
# NOTE: `device` is not assigned in this cell; it must already be defined by an
# earlier cell (the U-Net configuration cell sets it).
pretrain_epochs = 100
log_dir = Path("./logs/DELETE-UNET")
lin_evaluation_frequency = None  # no linear-evaluation frequency is set for this run

# `unet_wrap`, `loader`, `val_loader` and `unet` come from the previous cell.
byol = BYOL_UNet_Class(
    unet_wrap,
    loader,
    val_loader,
    pretrain_epochs,
    log_dir,
    input_dims=1,
    img_dims=(650, 650),
    hidden_features=4096,
    device=device,
    lin_evaluation_frequency=lin_evaluation_frequency,
    mixed_precision_enabled=True,
    # The full U-Net object is passed alongside the wrapped encoder;
    # presumably used when saving weights -- TODO confirm in BYOL_UNet.
    original_unet=unet,
    save_weights_every=5,
)

# Release cached CUDA memory before starting the long-running pre-training.
torch.cuda.empty_cache()
byol.pretrain()