In [1]:
# Load IPython's autoreload extension. Mode 2 reloads all imported modules before each cell
# executes, so edits to the project modules imported in the next cell are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
import os
import sys
import time
import argparse

import torch
import torch.optim as optim
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter

# from dataset import *
from torch.utils.data import DataLoader

sys.path.append("../")
import cfg
import func_2d.function as function
from conf import settings
from func_2d import utils as fn2dutils

# from models.discriminatorlayer import discriminator
from func_2d.dataset import REFUGE

In [12]:
# Root directory holding the Medical-SAM2 and segment-anything-2 checkouts. Set the
# SAM2_PROJECTS_ROOT environment variable to run on another machine; the fallback is the
# original cluster location, so the default paths below are unchanged.
SAM2_PROJECTS_ROOT = os.environ.get(
    "SAM2_PROJECTS_ROOT", "/hpc/mydata/saugat.kandel/sam2_projects"
).rstrip("/")

# Notebook stand-in for the command-line arguments the training code expects. Fields without a
# comment are consumed by project code outside this notebook.
default_args = argparse.Namespace(
    net="sam2",  # network name passed to get_network() in main()
    encoder="vit_b",
    exp_name="REFUGE_MedSAM2",  # used to name the log directory created in main()
    vis=1,
    prompt="bbox",
    prompt_freq=2,
    pretrain=f"{SAM2_PROJECTS_ROOT}/Medical-SAM2/pretrain/MedSAM2_pretrain.pth",
    val_freq=1,  # validate every `val_freq` epochs
    gpu=True,
    gpu_device=0,  # CUDA device index
    image_size=1024,  # images are resized to image_size x image_size
    out_size=1024,
    distributed="none",
    dataset="REFUGE",
    data_path=f"{SAM2_PROJECTS_ROOT}/Medical-SAM2/data/REFUGE",
    sam_ckpt=f"{SAM2_PROJECTS_ROOT}/segment-anything-2/checkpoints/sam2_hiera_small.pt",
    sam_config="sam2_hiera_s",
    video_length=2,
    b=4,  # batch size for both data loaders
    lr=1e-4,  # Adam learning rate
    weights="0",
    multimask_output=1,
    memory_bank_size=16,
)

In [13]:
def _build_refuge_loaders(args):
    """Create the REFUGE training and test data loaders.

    Args:
        args: Namespace providing ``dataset``, ``data_path``, ``image_size`` and ``b``
            (batch size).

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects.

    Raises:
        ValueError: If ``args.dataset`` is not ``"REFUGE"``, the only dataset wired up here.
    """
    if args.dataset != "REFUGE":
        # Fail with a clear message instead of a NameError on the loaders later on.
        raise ValueError(f"Unsupported dataset {args.dataset!r}; only 'REFUGE' is supported here.")

    # Both splits use the same preprocessing: resize to a square image, then convert to a tensor.
    transform_train = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
    ])
    transform_test = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
    ])

    refuge_train_dataset = REFUGE(args, args.data_path, transform=transform_train, mode="Training")
    refuge_test_dataset = REFUGE(args, args.data_path, transform=transform_test, mode="Test")

    # Only the training split is shuffled.
    nice_train_loader = DataLoader(
        refuge_train_dataset, batch_size=args.b, shuffle=True, num_workers=2, pin_memory=True
    )
    nice_test_loader = DataLoader(
        refuge_test_dataset, batch_size=args.b, shuffle=False, num_workers=2, pin_memory=True
    )
    return nice_train_loader, nice_test_loader


def _train_and_validate(args, net, optimizer, train_loader, test_loader, writer, logger):
    """Run the epoch loop: train each epoch, validate periodically, checkpoint on best Dice.

    Args:
        args: Namespace providing ``val_freq`` and ``path_helper["ckpt_path"]``.
        net: The network to train.
        optimizer: Optimizer over ``net``'s parameters.
        train_loader: Training data loader.
        test_loader: Validation data loader.
        writer: TensorBoard writer forwarded to the train/validation helpers.
        logger: Logger that receives the per-epoch loss and validation scores.
    """
    best_dice = 0.0

    for epoch in range(settings.EPOCH):
        if epoch == 0:
            # Baseline validation of the freshly loaded network before any training step.
            tol, (eiou, edice) = function.validation_sam(args, test_loader, epoch, net, writer)
            logger.info(f"Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.")

        # training
        net.train()
        time_start = time.time()
        loss = function.train_sam(args, net, optimizer, train_loader, epoch, writer)
        logger.info(f"Train loss: {loss} || @ epoch {epoch}.")
        time_end = time.time()
        print("time_for_training ", time_end - time_start)

        # validation: every `val_freq` epochs and always on the final epoch
        net.eval()
        if epoch % args.val_freq == 0 or epoch == settings.EPOCH - 1:
            tol, (eiou, edice) = function.validation_sam(args, test_loader, epoch, net, writer)
            logger.info(f"Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.")

            if edice > best_dice:
                best_dice = edice
                # Despite the file name, this is only overwritten when Dice improves, so it
                # holds the best-Dice weights seen so far.
                # NOTE(review): `net._parameters` is a private nn.Module attribute; it is kept
                # only so the checkpoint layout stays unchanged. The weights are in "model".
                torch.save(
                    {"model": net.state_dict(), "parameter": net._parameters},
                    os.path.join(args.path_helper["ckpt_path"], "latest_epoch.pth"),
                )


def main(args):
    """Train and validate the network selected by ``args`` on the REFUGE dataset.

    Builds the network and an Adam optimizer, sets up log and checkpoint directories and a
    TensorBoard writer, then runs ``settings.EPOCH`` epochs of training with periodic
    validation. The whole run executes under CUDA bfloat16 autocast.

    Side effects: sets ``args.path_helper``; creates directories under ``"logs"``,
    ``settings.LOG_DIR`` and ``settings.CHECKPOINT_PATH``; may enable TF32 globally.

    Args:
        args: Namespace of run options (see ``default_args`` for the fields used).

    Raises:
        ValueError: If ``args.dataset`` is not ``"REFUGE"``.
    """
    GPUdevice = torch.device("cuda", args.gpu_device)

    # Check the device we will actually run on, not a hard-coded device 0.
    if torch.cuda.get_device_properties(GPUdevice).major >= 8:
        # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # Use bfloat16 for the entire run. A `with` block guarantees autocast is switched off again
    # when main() returns or is interrupted, so it does not leak into later notebook cells.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        net = fn2dutils.get_network(
            args, args.net, use_gpu=args.gpu, gpu_device=GPUdevice, distribution=args.distributed
        )

        # optimisation
        optimizer = optim.Adam(
            net.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False
        )

        # per-run log directory and logger
        args.path_helper = fn2dutils.set_log_dir("logs", args.exp_name)
        logger = fn2dutils.create_logger(args.path_helper["log_path"])
        logger.info(args)

        # segmentation data
        nice_train_loader, nice_test_loader = _build_refuge_loaders(args)

        # checkpoint directory and tensorboard writer
        checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net, settings.TIME_NOW)
        os.makedirs(settings.LOG_DIR, exist_ok=True)
        writer = SummaryWriter(log_dir=os.path.join(settings.LOG_DIR, args.net, settings.TIME_NOW))
        os.makedirs(checkpoint_path, exist_ok=True)

        try:
            _train_and_validate(
                args, net, optimizer, nice_train_loader, nice_test_loader, writer, logger
            )
        finally:
            # Flush and close the event file even if training fails or is interrupted.
            writer.close()

In [14]:
# Launch training with the notebook defaults defined in `default_args`.
# NOTE(review): the recorded output below logs pretrain=None and ends in a KeyboardInterrupt,
# which does not match `default_args` as currently written. It comes from an earlier kernel
# state, so re-run after Restart Kernel before trusting it.
main(default_args)

INFO:root:Loaded checkpoint sucessfully
Loaded checkpoint sucessfully
INFO:root:Namespace(net='sam2', encoder='vit_b', exp_name='REFUGE_MedSAM2', vis=1, prompt='bbox', prompt_freq=2, pretrain=None, val_freq=1, gpu=True, gpu_device=0, image_size=1024, out_size=1024, distributed='none', dataset='REFUGE', data_path='/hpc/mydata/saugat.kandel/sam2_projects/Medical-SAM2/data/REFUGE', sam_ckpt='/hpc/mydata/saugat.kandel/sam2_projects/segment-anything-2/checkpoints/sam2_hiera_small.pt', sam_config='sam2_hiera_s', video_length=2, b=4, lr=0.0001, weights='0', multimask_output=1, memory_bank_size=16, path_helper={'prefix': 'logs/REFUGE_MedSAM2_2024_09_02_23_22_22', 'ckpt_path': 'logs/REFUGE_MedSAM2_2024_09_02_23_22_22/Model', 'log_path': 'logs/REFUGE_MedSAM2_2024_09_02_23_22_22/Log', 'sample_path': 'logs/REFUGE_MedSAM2_2024_09_02_23_22_22/Samples'})
Namespace(net='sam2', encoder='vit_b', exp_name='REFUGE_MedSAM2', vis=1, prompt='bbox', prompt_freq=2, pretrain=None, val_freq=1, gpu=True, gpu_devi

KeyboardInterrupt: 

In [11]:
!ls /hpc/mydata/saugat.kandel/sam2_projects/segment-anything-2/Medical-SAM2/data/REFUGE/

ls: cannot access '/hpc/mydata/saugat.kandel/sam2_projects/segment-anything-2/Medical-SAM2/data/REFUGE/': No such file or directory
