# UNet with a ResNet-50 Backbone

Train Dataset Size: 

Val Dataset Size: 

Test Dataset Size:

In [None]:
# Pin torch and timm to specific versions for this notebook.
# NOTE(review): the other packages imported in the next cell (backbones_unet,
# torchsummaryX, wandb, and the local convert_coco_ann_to_mask module) are NOT
# installed here and are assumed to be available already -- confirm.
%pip install torch==1.13.1
%pip install timm==0.6.12

In [9]:
from backbones_unet.model.unet import Unet
from backbones_unet.utils.dataset import SemanticSegmentationDataset
from backbones_unet.model.losses import DiceLoss
from backbones_unet.utils.trainer import Trainer
from torchsummaryX import summary
from torch.utils.data import Dataset, DataLoader
from convert_coco_ann_to_mask import convert_coco_to_mask

import torch
# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Test Installation: build a U-Net without naming a backbone and run a single
# random (1, 3, 64, 64) tensor through it, printing the prediction.
dummy_input = torch.rand((1, 3, 64, 64))
model = Unet(num_classes=1, in_channels=3)  # no backbone given -> library default (Resnet50 per the original note)
print(model.predict(dummy_input))

In [None]:
# Hyper-parameters for this run; the same dict is passed to wandb.init as the run
# config. Feel free to add more items here.
config = dict(
    lr=2e-3,
    epochs=100,
    batch_size=2,  # Increase if your device can handle it
    num_classes=1,
    truncated_normal_mean=0,
    truncated_normal_std=0.2,
)

In [None]:
# Data locations used to build the Dataset/DataLoader objects later on.

# COCO annotation file for the training split. It is passed as `input_json` to
# convert_coco_to_mask, so it must be a JSON file -- it previously pointed at the
# train *images* directory by mistake.
# TODO: the filename below is a placeholder; point it at your actual COCO export.
annotation_json_path = 'example_data/train/annotations.json'
train_img_path = 'example_data/train/images'
train_mask_path = 'example_data/train/masks'

val_img_path = 'example_data/val/images'
val_mask_path = 'example_data/val/masks'

## Extract Masks from the COCO annotations (if not already done)

In [None]:
# Generate training mask images from the COCO annotations, but only when the
# mask folder is missing or empty, so re-running this cell does not redo the work.
# NOTE(review): only the training split is converted here; validation masks are
# assumed to already exist in val_mask_path.
import os

if os.path.isdir(train_mask_path) and os.listdir(train_mask_path):
    print(f"Masks already present in {train_mask_path!r}; skipping conversion.")
else:
    convert_coco_to_mask(
        input_json=annotation_json_path,
        image_folder=train_img_path,
        output_folder=train_mask_path,
    )

In [None]:
# Wrap the image/mask folders in datasets and batch them for training/validation.
train_dataset = SemanticSegmentationDataset(train_img_path, train_mask_path)
val_dataset = SemanticSegmentationDataset(val_img_path, val_mask_path)

# Batch size comes from `config` (it was hard-coded to 2 here, so editing
# config["batch_size"] had no effect). Only the training split is shuffled so
# each epoch sees a new batch order; validation order stays deterministic.
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"])

In [None]:
# Build the segmentation model.
# NOTE(review): the notebook title says ResNet-50, but this cell requests the
# 'convnext_base' backbone -- confirm which one is intended.
unet_kwargs = dict(
    backbone='convnext_base',           # backbone network name
    in_channels=3,                      # input channels (1 for gray-scale images, 3 for RGB, etc.)
    num_classes=config["num_classes"],  # output channels (number of classes in your dataset)
)
model = Unet(**unet_kwargs)

In [None]:
# Authenticate with Weights & Biases and start a run.

import wandb

# SECURITY: never hard-code an API key in a notebook. With no `key` argument,
# wandb.login() reads the WANDB_API_KEY environment variable or previously saved
# credentials, and prompts interactively otherwise. The key that used to be
# committed here should be revoked/rotated at wandb.ai/settings.
wandb.login()

run = wandb.init(
    name="Trial_1",  # wandb creates a random run name if you skip this field
    reinit=True,  # allows reinitializing runs when you re-run this cell
    # id="...",  # insert a specific run id here if you want to resume a previous run
    # resume="must",  # needed to resume a previous run; comment out reinit=True when using this
    project="IDL_Project_Segmentation",  # wandb project (created automatically if it does not exist)
    config=config,  # wandb config for your run
)

In [None]:
# Path to a saved model checkpoint; empty means "none". It is not referenced by
# the training cell shown here.
checkpoint_path: str = ""

In [None]:
# Optimize only the trainable parameters, then fit the model with a Dice loss.
params = [p for p in model.parameters() if p.requires_grad]

# Learning rate and epoch count come from `config`, so the values logged to wandb
# match the ones actually used. They were previously hard-coded here as lr=1e-4
# and epochs=10, while config (and therefore wandb) reported 2e-3 / 100.
optimizer = torch.optim.AdamW(params, lr=config["lr"])

trainer = Trainer(
    model,                    # UNet model with pretrained backbone
    criterion=DiceLoss(),     # loss function for model convergence
    optimizer=optimizer,      # updates the model parameters from the gradients
    epochs=config["epochs"],  # number of epochs for model training
)

trainer.fit(train_loader, val_loader)