In [1]:
from segment_anything import sam_model_registry
from sam_lora_image_encoder import LoRA_Sam
from base_dataset import RadioGalaxyNET
from torchvision.transforms import v2 as T
from torch.nn.modules.loss import CrossEntropyLoss
from utils import DiceLoss
from torchvision.transforms import InterpolationMode
import matplotlib.pyplot as plt
import torch
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm


In [3]:
def show_mask(mask, ax, random_color=True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([1, 1, 0, 1])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) 

In [3]:
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda"
rank = 4

sam = sam_model_registry[model_type](
    num_classes = 4,                                                           
    checkpoint=sam_checkpoint,
    image_size = 512)

net = LoRA_Sam(sam[0], 4)
#.cuda()

KeyboardInterrupt: 

In [2]:
train_dir = "RadioGalaxyNET/train"
train_coco = "RadioGalaxyNET/annotations/train.json"
# Note other transformations are not implemented as of this moment.
def get_transform(train):
    transforms = []
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
    transforms.append(T.Resize((1024, 1024), interpolation=InterpolationMode.BILINEAR))
    transforms.append(T.ToDtype(torch.float, scale=True))
    transforms.append(T.ToPureTensor())
    return T.Compose(transforms)

# Define the dataset class
train_dataset = RadioGalaxyNET(root=train_dir,
                          annFile=train_coco,
                          transforms=get_transform(train=True)
                          ) 


loading annotations into memory...
Done (t=0.12s)
creating index...
index created!


In [3]:
def collate_fn(batch):
    images = []
    annotations = []
    for img, ann in batch:
        images.append(img)
        annotations.append(ann)
    return images, annotations

# Batch size
train_batch_size = 1

# Define DataLoader for some reason my dataloader can only be = 0, please try different number and let me know how you guys goes. 
trainloader = torch.utils.data.DataLoader(train_dataset,
                                          batch_size=train_batch_size,
                                          shuffle=True,
                                          num_workers=0,
                                          collate_fn=collate_fn)

In [51]:

# Channel 0 is the background 

In [56]:
iter_data = iter(trainloader)
image, label = next(iter_data)



# Create the function that converts the Panoptic mask into Semantic segmentation
def pan_to_sem(label):
    t = torch.empty(1, 5, 1024, 1024).int()
    masks = label['masks']
    masks = (masks != 0).int()
    classes = label['labels']

    for idx in range(len(classes)):
        # Add the object of a given class into the tensor storing the semantic segmentation mask.
        t[:,classes[idx]] += masks[idx]
        # Convert anything above 0 to 1. 
        t[:,classes[idx]] = (t[:,classes[idx]] > 0).int()
    # Create the background class and set it to 0.
    summed_tensor = torch.sum(t, dim=1)
    t[:,0] = (summed_tensor == 0).int()
    return t

(tensor([], dtype=torch.int64), tensor([], dtype=torch.int64), tensor([], dtype=torch.int64), tensor([], dtype=torch.int64))


In [7]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
net.to(device)
net.train()
params = [p for p in net.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

In [8]:
def calc_loss(outputs, low_res_label_batch, ce_loss, dice_loss, dice_weight:float=0.8):
    low_res_logits = outputs['masks']
    loss_ce = ce_loss(low_res_logits, low_res_label_batch[:].long())
    loss_dice = dice_loss(low_res_logits, low_res_label_batch, softmax=True)
    loss = (1 - dice_weight) * loss_ce + dice_weight * loss_dice
    return loss, loss_ce, loss_dice

In [9]:
iter_num = 0
max_epoch = 2
img_size = 1024
len_dataloader = len(trainloader)
iterator = tqdm(range(max_epoch))
multimask_output = True

ce_loss = CrossEntropyLoss()
dice_loss = DiceLoss(4 + 1)
for epoch_num in iterator:
    i = 0   
    for imgs, annotations in trainloader:
        i += 1
        imgs = list(img for img in imgs)
        imgs = torch.stack(imgs, dim=0).to(device)
        masks_list = [annotation["masks"] for annotation in annotations if "masks" in annotation]
        masks_list = torch.stack(masks_list, dim=0).to(device)

        # Forward pass
        outputs = net(imgs, multimask_output, img_size)
        loss, loss_ce, loss_dice = calc_loss(outputs, masks_list , ce_loss, dice_loss)
        optimizer.zero_grad() 
        loss.backward()
        optimizer.step()
        print(f'Iteration: {i}/{len_dataloader}, Loss: {loss}')

  0%|          | 0/2 [00:01<?, ?it/s]


RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of size: : [1, 1, 1024, 1024]

In [None]:
torch.cuda.empty_cache()

In [33]:
input = train_dataset[0][0].unsqueeze(0).cuda()
mask_gt = train_dataset[0][1]['masks']
        

In [35]:
#print(imgs.shape)
output = net(input, True, 1024)

In [39]:
output['masks'].shape

torch.Size([1, 5, 1024, 1024])