In [1]:
import argparse
import os
import sys
import time

import torch
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from tensorboardX import SummaryWriter

# from dataset import *
from torch.utils.data import DataLoader

sys.path.append("../")
import cfg
import func_2d.function as function
from conf import settings
from func_2d import utils as fn2dutils

# from models.discriminatorlayer import discriminator
from func_2d.dataset import REFUGE

import matplotlib.pyplot as plt
from pathlib import Path

In [5]:
# Root of the Medical-SAM2 checkout. Set the MEDSAM2_ROOT environment variable
# to run elsewhere instead of editing the hard-coded cluster path below.
MEDSAM2_ROOT = Path(
    os.environ.get("MEDSAM2_ROOT", "/hpc/mydata/saugat.kandel/sam2_projects/Medical-SAM2")
)

# Experiment configuration consumed by main(); presumably mirrors the flags
# parsed by the project's `cfg` module — verify field names against it.
default_args = argparse.Namespace(
    model_id="sam2",  # used as a sub-directory name for TensorBoard logs / checkpoints
    # NOTE(review): encoder says ViT-B while sam_config selects Hiera-T; confirm which is used.
    encoder="vit_b",
    exp_name="REFUGE_MedSAM2",  # prefix of the run's log directory
    vis=1,
    prompt="bbox",
    prompt_freq=2,
    val_freq=1,  # validate every `val_freq` epochs (and always on the last epoch)
    gpu=True,
    gpu_device=0,  # CUDA device index used for the model
    image_size=1024,  # inputs are resized to image_size x image_size
    out_size=1024,
    distributed="none",
    dataset="REFUGE",  # only "REFUGE" is supported by main()
    data_path=str(MEDSAM2_ROOT / "data" / "REFUGE"),
    sam_ckpt=str(MEDSAM2_ROOT / "pretrain" / "REFUGE_2d_pretrain.pth"),
    sam_config="sam2_hiera_t",
    video_length=2,
    b=16,  # batch size for both train and test loaders
    lr=1e-4,  # Adam learning rate
    weights="0",
    multimask_output=1,
    memory_bank_size=32,
)

In [6]:
def _build_refuge_loaders(args):
    """Return ``(train_loader, test_loader)`` for the REFUGE data at ``args.data_path``."""
    # Train and test share the same resize + to-tensor pipeline (no augmentation).
    transform_train = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
    ])
    transform_test = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
    ])

    refuge_train_dataset = REFUGE(args, args.data_path, transform=transform_train, mode="Training")
    refuge_test_dataset = REFUGE(args, args.data_path, transform=transform_test, mode="Test")

    train_loader = DataLoader(
        refuge_train_dataset, batch_size=args.b, shuffle=True, num_workers=4, pin_memory=True
    )
    test_loader = DataLoader(
        refuge_test_dataset, batch_size=args.b, shuffle=False, num_workers=4, pin_memory=True
    )
    return train_loader, test_loader


def _save_best_checkpoint(net, memory_bank_list, ckpt_dir):
    """Write the model state and the training memory bank into ``ckpt_dir``."""
    # NOTE: despite its name, "latest_epoch.pth" holds the best-DICE model so far.
    torch.save(
        {"model": net.state_dict(), "parameter": net._parameters},
        os.path.join(ckpt_dir, "latest_epoch.pth"),
    )
    torch.save(
        {"memory_bank": memory_bank_list},
        os.path.join(ckpt_dir, "memory_bank.pth"),
    )


def main(args):
    """Fine-tune the SAM2 network described by ``args`` on the REFUGE dataset.

    Builds the network and an Adam optimizer, creates the log directories and a
    TensorBoard writer, then alternates training and validation for
    ``settings.EPOCH`` epochs. Whenever validation DICE beats the best value seen
    after a training epoch, the model and the memory bank returned by
    ``function.train_sam`` are saved under ``args.path_helper["ckpt_path"]``.

    Args:
        args: experiment configuration namespace (see ``default_args``).
            ``args.path_helper`` is assigned as a side effect.

    Raises:
        ValueError: if ``args.dataset`` is not ``"REFUGE"``.
    """
    if args.dataset != "REFUGE":
        # Only REFUGE loaders are built below; fail before loading the model or
        # creating log directories rather than with a NameError later.
        raise ValueError(f"Unsupported dataset {args.dataset!r}; only 'REFUGE' is supported.")

    # Use bfloat16 autocast for the whole run, but scope it to this call so the
    # notebook kernel is not left in autocast mode (and re-runs don't stack it).
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        if torch.cuda.get_device_properties(args.gpu_device).major >= 8:
            # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

        GPUdevice = torch.device("cuda", args.gpu_device)

        # load pretrained model
        print("Loading model...")
        net = fn2dutils.get_network(args, gpu_device=GPUdevice)
        print("Loaded model.")

        # optimisation
        optimizer = optim.Adam(
            net.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False
        )

        args.path_helper = fn2dutils.set_log_dir("logs", args.exp_name)
        logger = fn2dutils.create_logger(args.path_helper["log_path"])
        logger.info(args)

        nice_train_loader, nice_test_loader = _build_refuge_loaders(args)

        # TensorBoard writer under settings.LOG_DIR/<model_id>/<timestamp>
        if not os.path.exists(settings.LOG_DIR):
            os.mkdir(settings.LOG_DIR)
        writer = SummaryWriter(log_dir=os.path.join(settings.LOG_DIR, args.model_id, settings.TIME_NOW))

        # Checkpoint folder kept for compatibility with the project layout;
        # NOTE(review): checkpoints are actually written to args.path_helper["ckpt_path"].
        checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.model_id, settings.TIME_NOW)
        os.makedirs(checkpoint_path, exist_ok=True)

        try:
            # NOTE(review): the pretrained baseline score does not seed best_dice,
            # so a fine-tuned checkpoint is saved even if it is worse than baseline.
            best_dice = 0.0

            for epoch in range(settings.EPOCH):
                if epoch == 0:
                    # Baseline: validate the pretrained model before any training.
                    tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
                    logger.info(f"Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.")

                # training
                net.train()
                time_start = time.time()
                loss, memory_bank_list = function.train_sam(
                    args, net, optimizer, nice_train_loader, epoch, writer
                )
                logger.info(f"Train loss: {loss} || @ epoch {epoch}.")
                time_end = time.time()
                print("time_for_training ", time_end - time_start)

                # validation (every val_freq epochs, and always on the final epoch)
                net.eval()
                if epoch % args.val_freq == 0 or epoch == settings.EPOCH - 1:
                    tol, (eiou, edice) = function.validation_sam(
                        args, nice_test_loader, epoch, net, writer, memory_bank_list=memory_bank_list
                    )
                    logger.info(f"Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.")

                    if edice > best_dice:
                        best_dice = edice
                        _save_best_checkpoint(net, memory_bank_list, args.path_helper["ckpt_path"])
        finally:
            # Flush/close TensorBoard even if training is interrupted or raises.
            writer.close()

In [7]:
# Launch fine-tuning with the configuration from the config cell.
# Note: main() sets default_args.path_helper as a side effect.
main(args=default_args)

Loading model...
Building SAM2 model...


INFO:root:Loaded checkpoint sucessfully
Loaded checkpoint sucessfully
INFO:root:Namespace(model_id='sam2', encoder='vit_b', exp_name='REFUGE_MedSAM2', vis=1, prompt='bbox', prompt_freq=2, val_freq=1, gpu=True, gpu_device=0, image_size=1024, out_size=1024, distributed='none', dataset='REFUGE', data_path='/hpc/mydata/saugat.kandel/sam2_projects/Medical-SAM2/data/REFUGE', sam_ckpt='/hpc/mydata/saugat.kandel/sam2_projects/Medical-SAM2/pretrain/REFUGE_2d_pretrain.pth', sam_config='sam2_hiera_t', video_length=2, b=16, lr=0.0001, weights='0', multimask_output=1, memory_bank_size=32, path_helper={'prefix': 'logs/REFUGE_MedSAM2_2024_09_18_14_18_10', 'ckpt_path': 'logs/REFUGE_MedSAM2_2024_09_18_14_18_10/Model', 'log_path': 'logs/REFUGE_MedSAM2_2024_09_18_14_18_10/Log', 'sample_path': 'logs/REFUGE_MedSAM2_2024_09_18_14_18_10/Samples'})
Namespace(model_id='sam2', encoder='vit_b', exp_name='REFUGE_MedSAM2', vis=1, prompt='bbox', prompt_freq=2, val_freq=1, gpu=True, gpu_device=0, image_size=1024, ou

Loaded model.


INFO:root:Total score: 0.10136875510215759, IOU: 0.7800176978687383, DICE: 0.8726226878166199 || @ epoch 0.
Total score: 0.10136875510215759, IOU: 0.7800176978687383, DICE: 0.8726226878166199 || @ epoch 0.
Total score: 0.10136875510215759, IOU: 0.7800176978687383, DICE: 0.8726226878166199 || @ epoch 0.
Epoch 0: 100%|██████████| 25/25 [00:41<00:00,  1.67s/img, loss (batch)=0.0565]
INFO:root:Train loss: 0.05358761891722679 || @ epoch 0.
Train loss: 0.05358761891722679 || @ epoch 0.
Train loss: 0.05358761891722679 || @ epoch 0.


time_for_training  41.649803161621094


INFO:root:Total score: 0.11239167302846909, IOU: 0.751289737579541, DICE: 0.8533355379104615 || @ epoch 0.
Total score: 0.11239167302846909, IOU: 0.751289737579541, DICE: 0.8533355379104615 || @ epoch 0.
Total score: 0.11239167302846909, IOU: 0.751289737579541, DICE: 0.8533355379104615 || @ epoch 0.
Epoch 1: 100%|██████████| 25/25 [00:40<00:00,  1.62s/img, loss (batch)=0.0534]
INFO:root:Train loss: 0.0524336439371109 || @ epoch 1.
Train loss: 0.0524336439371109 || @ epoch 1.
Train loss: 0.0524336439371109 || @ epoch 1.


time_for_training  40.460312366485596


INFO:root:Total score: 0.10884302854537964, IOU: 0.7605328927788314, DICE: 0.8593486881256104 || @ epoch 1.
Total score: 0.10884302854537964, IOU: 0.7605328927788314, DICE: 0.8593486881256104 || @ epoch 1.
Total score: 0.10884302854537964, IOU: 0.7605328927788314, DICE: 0.8593486881256104 || @ epoch 1.
Epoch 2: 100%|██████████| 25/25 [00:40<00:00,  1.62s/img, loss (batch)=0.0594]
INFO:root:Train loss: 0.052664262652397154 || @ epoch 2.
Train loss: 0.052664262652397154 || @ epoch 2.
Train loss: 0.052664262652397154 || @ epoch 2.


time_for_training  40.5927209854126


INFO:root:Total score: 0.10304282605648041, IOU: 0.7711213848346032, DICE: 0.8666448521614075 || @ epoch 2.
Total score: 0.10304282605648041, IOU: 0.7711213848346032, DICE: 0.8666448521614075 || @ epoch 2.
Total score: 0.10304282605648041, IOU: 0.7711213848346032, DICE: 0.8666448521614075 || @ epoch 2.
Epoch 3: 100%|██████████| 25/25 [00:40<00:00,  1.61s/img, loss (batch)=0.053] 
INFO:root:Train loss: 0.05112998634576797 || @ epoch 3.
Train loss: 0.05112998634576797 || @ epoch 3.
Train loss: 0.05112998634576797 || @ epoch 3.


time_for_training  40.16026186943054


INFO:root:Total score: 0.11173125356435776, IOU: 0.7544308090512163, DICE: 0.8553388071060181 || @ epoch 3.
Total score: 0.11173125356435776, IOU: 0.7544308090512163, DICE: 0.8553388071060181 || @ epoch 3.
Total score: 0.11173125356435776, IOU: 0.7544308090512163, DICE: 0.8553388071060181 || @ epoch 3.
Epoch 4: 100%|██████████| 25/25 [00:40<00:00,  1.61s/img, loss (batch)=0.0498]
INFO:root:Train loss: 0.05161822482943535 || @ epoch 4.
Train loss: 0.05161822482943535 || @ epoch 4.
Train loss: 0.05161822482943535 || @ epoch 4.


time_for_training  40.31901502609253


INFO:root:Total score: 0.11592018604278564, IOU: 0.7448775979055567, DICE: 0.8486284637451171 || @ epoch 4.
Total score: 0.11592018604278564, IOU: 0.7448775979055567, DICE: 0.8486284637451171 || @ epoch 4.
Total score: 0.11592018604278564, IOU: 0.7448775979055567, DICE: 0.8486284637451171 || @ epoch 4.
Epoch 5: 100%|██████████| 25/25 [00:40<00:00,  1.61s/img, loss (batch)=0.0489]
INFO:root:Train loss: 0.05195574343204498 || @ epoch 5.
Train loss: 0.05195574343204498 || @ epoch 5.
Train loss: 0.05195574343204498 || @ epoch 5.


time_for_training  40.14305782318115


INFO:root:Total score: 0.12372041493654251, IOU: 0.7535404752708869, DICE: 0.8541388702392578 || @ epoch 5.
Total score: 0.12372041493654251, IOU: 0.7535404752708869, DICE: 0.8541388702392578 || @ epoch 5.
Total score: 0.12372041493654251, IOU: 0.7535404752708869, DICE: 0.8541388702392578 || @ epoch 5.
Epoch 6: 100%|██████████| 25/25 [00:40<00:00,  1.61s/img, loss (batch)=0.0446]
INFO:root:Train loss: 0.051642042398452756 || @ epoch 6.
Train loss: 0.051642042398452756 || @ epoch 6.
Train loss: 0.051642042398452756 || @ epoch 6.


time_for_training  40.2776095867157


INFO:root:Total score: 0.1072559803724289, IOU: 0.7622483017891633, DICE: 0.8608567428588867 || @ epoch 6.
Total score: 0.1072559803724289, IOU: 0.7622483017891633, DICE: 0.8608567428588867 || @ epoch 6.
Total score: 0.1072559803724289, IOU: 0.7622483017891633, DICE: 0.8608567428588867 || @ epoch 6.
Epoch 7: 100%|██████████| 25/25 [00:40<00:00,  1.63s/img, loss (batch)=0.0464]
INFO:root:Train loss: 0.05052577704191208 || @ epoch 7.
Train loss: 0.05052577704191208 || @ epoch 7.
Train loss: 0.05052577704191208 || @ epoch 7.


time_for_training  40.820916175842285


INFO:root:Total score: 0.10586125403642654, IOU: 0.7656886798479491, DICE: 0.8631748056411743 || @ epoch 7.
Total score: 0.10586125403642654, IOU: 0.7656886798479491, DICE: 0.8631748056411743 || @ epoch 7.
Total score: 0.10586125403642654, IOU: 0.7656886798479491, DICE: 0.8631748056411743 || @ epoch 7.
Epoch 8: 100%|██████████| 25/25 [00:40<00:00,  1.61s/img, loss (batch)=0.0548]
INFO:root:Train loss: 0.051807394176721575 || @ epoch 8.
Train loss: 0.051807394176721575 || @ epoch 8.
Train loss: 0.051807394176721575 || @ epoch 8.


time_for_training  40.232962131500244


INFO:root:Total score: 0.11739873886108398, IOU: 0.7501241022234061, DICE: 0.8527576541900634 || @ epoch 8.
Total score: 0.11739873886108398, IOU: 0.7501241022234061, DICE: 0.8527576541900634 || @ epoch 8.
Total score: 0.11739873886108398, IOU: 0.7501241022234061, DICE: 0.8527576541900634 || @ epoch 8.
Epoch 9: 100%|██████████| 25/25 [00:40<00:00,  1.61s/img, loss (batch)=0.057] 
INFO:root:Train loss: 0.052034389972686765 || @ epoch 9.
Train loss: 0.052034389972686765 || @ epoch 9.
Train loss: 0.052034389972686765 || @ epoch 9.


time_for_training  40.358137369155884


INFO:root:Total score: 0.11921778321266174, IOU: 0.7491427643984896, DICE: 0.8515798544883728 || @ epoch 9.
Total score: 0.11921778321266174, IOU: 0.7491427643984896, DICE: 0.8515798544883728 || @ epoch 9.
Total score: 0.11921778321266174, IOU: 0.7491427643984896, DICE: 0.8515798544883728 || @ epoch 9.
Epoch 10: 100%|██████████| 25/25 [00:40<00:00,  1.60s/img, loss (batch)=0.0474]
INFO:root:Train loss: 0.051230194568634035 || @ epoch 10.
Train loss: 0.051230194568634035 || @ epoch 10.
Train loss: 0.051230194568634035 || @ epoch 10.


time_for_training  40.02678632736206


INFO:root:Total score: 0.11097963154315948, IOU: 0.7574772726214838, DICE: 0.8575374865531922 || @ epoch 10.
Total score: 0.11097963154315948, IOU: 0.7574772726214838, DICE: 0.8575374865531922 || @ epoch 10.
Total score: 0.11097963154315948, IOU: 0.7574772726214838, DICE: 0.8575374865531922 || @ epoch 10.
Epoch 11:   0%|          | 0/25 [00:00<?, ?img/s]