In [1]:
"""Training for Grounded SAM.
"""
import os
import torch
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
import torch
from torchvision import transforms
from PIL import Image
from transformers import SamProcessor

from linear_probe import LinearProbe
from utils import get_bounding_box

from model import myGroundingDino, myBiomedCLIP, mySAM
import dataset_mimic
import dataset_pascal

In [None]:
def train_adaptation(hyperparams, loss_fn=None):
    """Train GroundingDINO and SAM with the adaptation loss only.

    Meant to be combined later with a segmentation objective. Jointly
    optimized with Adam:

    - the GroundingDINO backbone and its ``img_linear`` / ``text_linear`` heads;
    - the full SAM model and its ``img_linear`` head.

    BiomedCLIP is switched to eval mode, and its parameters are not handed to
    the optimizer. Checkpoints are written once, after the final epoch.

    Args:
        hyperparams: dict with keys ``'lr'``, ``'batch_size'``,
            ``'num_epochs'``, ``'num_workers'``, ``'device'`` and
            ``'save_folder'``. A missing key raises ``KeyError``.
        loss_fn: callable ``loss_fn(batch, my_groundingdino, my_sam,
            my_biomedclip)`` returning a scalar tensor to minimize.
            ``batch`` is one item yielded by the MIMIC dataloader. It is
            presumably a dict with ``"image"``, ``"image_path"`` and
            ``"report"`` entries; verify against ``dataset_mimic.load_data``.
            BiomedCLIP is not optimized here, so ``loss_fn`` should probably
            run it under ``torch.no_grad()`` (TODO confirm intended use).

    Raises:
        NotImplementedError: if ``loss_fn`` is None. The adaptation loss is
            not implemented yet. This check runs before any data or model is
            loaded, so a run that cannot train fails immediately.
    """
    # Load hyperparameters
    lr = hyperparams['lr']
    batch_size = hyperparams['batch_size']
    num_epochs = hyperparams['num_epochs']
    # NOTE(review): read but never used; check whether
    # dataset_mimic.load_data accepts a num_workers argument.
    num_workers = hyperparams['num_workers']
    device = hyperparams['device']
    save_folder = hyperparams['save_folder']

    # The original placeholder (`loss = pass`) was a SyntaxError. Fail
    # explicitly and early instead of after loading data and three models.
    if loss_fn is None:
        raise NotImplementedError(
            "train_adaptation: the adaptation loss is not implemented; "
            "pass loss_fn(batch, my_groundingdino, my_sam, my_biomedclip)."
        )

    # Load data
    mimic_dataloader = dataset_mimic.load_data(batch_size=batch_size, tensor=True)
    print(mimic_dataloader)

    # Load models
    my_groundingdino = myGroundingDino(device=device)
    my_sam = mySAM(device=device)
    my_biomedclip = myBiomedCLIP(device=device)

    # Only GroundingDINO and SAM parameters are optimized; BiomedCLIP is frozen.
    groundingdino_params = (
        list(my_groundingdino.model.backbone.parameters())
        + list(my_groundingdino.img_linear.parameters())
        + list(my_groundingdino.text_linear.parameters())
    )
    sam_params = list(my_sam.model.parameters()) + list(my_sam.img_linear.parameters())
    optimizer = torch.optim.Adam(groundingdino_params + sam_params, lr=lr)

    # Trainable modules in train mode; BiomedCLIP in eval mode.
    my_groundingdino.model.backbone.train()
    my_groundingdino.img_linear.train()
    my_groundingdino.text_linear.train()
    my_sam.model.train()
    my_sam.img_linear.train()
    my_biomedclip.model.eval()

    # Training loop
    for epoch in range(num_epochs):
        progress = tqdm(mimic_dataloader, desc=f'Training @ epoch {epoch+1} of {num_epochs}')
        for batch in progress:
            optimizer.zero_grad()
            loss = loss_fn(batch, my_groundingdino, my_sam, my_biomedclip)
            loss.backward()
            optimizer.step()

    # Save trained models (only once, after the last epoch).
    my_groundingdino.save(ckpt_folder=save_folder)
    my_sam.save(ckpt_folder=save_folder)