# Imports and setup

## Install and import dependencies

In [4]:
!pip install -r requirements.txt

253.36s - pydevd: Sending message related to process being replaced timed-out after 5 seconds


Defaulting to user installation because normal site-packages is not writeable
Collecting numpy==1.24.1
  Downloading numpy-1.24.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.3/17.3 MB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.24.0
    Uninstalling numpy-1.24.0:
      Successfully uninstalled numpy-1.24.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gradio 4.38.1 requires urllib3~=2.0, but you have urllib3 1.26.20 which is incompatible.[0m[31m
[0mSuccessfully installed numpy-1.24.1


In [5]:
import os
import torch
import monai
from tqdm import tqdm
from statistics import mean
from torch.utils.data import DataLoader
# from torchvision import datasets, transforms
from torch.optim import Adam
import torch.nn as nn
from torch.nn.functional import threshold, normalize
import src.utils as utils

from src.brats_dataset import BratsDataset, collate_fn
from src.brats_processor import Samprocessor, find_slices


from src.segment_anything import build_sam_vit_b, SamPredictor
from src.lora import LoRA_sam
import matplotlib.pyplot as plt
import yaml
import torch.nn.functional as F

print("Setup complete")

Setup complete


## Configure paths, model, and datasets

In [6]:
# Load the config file.
# NOTE(review): yaml.Loader can construct arbitrary Python objects from tagged
# YAML. That is acceptable for a trusted local config; switch to
# yaml.safe_load if config.yaml may ever come from an untrusted source.
with open("./config.yaml", "r") as ymlfile:
    config_file = yaml.load(ymlfile, Loader=yaml.Loader)

# Dataset locations.
# NOTE(review): these are absolute, machine-specific paths, and training and
# validation currently point at the same directory - confirm this is intended.
train_dataset_path = "/home/peter/Documents/Code/samseg/src/testSamples/MEDIUM_Samples"
valid_dataset_path = "/home/peter/Documents/Code/samseg/src/testSamples/MEDIUM_Samples"

# Output locations: latest checkpoint, per-epoch loss CSV, periodic backups.
out_dir = "./train-out"

latest_ckpt_path = os.path.join(out_dir, 'latest_ckpt.pth.tar')
training_loss_path = os.path.join(out_dir, 'training_loss.csv')
backup_ckpts_dir = os.path.join(out_dir, 'backup_ckpts')
if not os.path.exists(backup_ckpts_dir):
    # makedirs also creates out_dir; exist_ok guards against the directory
    # appearing between the check above and this call.
    os.makedirs(backup_ckpts_dir, exist_ok=True)
    # Equivalent of `chmod a+rwx` without going through a shell: add rwx for
    # user/group/other on top of whatever permission bits are already set.
    current_mode = os.stat(backup_ckpts_dir).st_mode & 0o7777
    os.chmod(backup_ckpts_dir, current_mode | 0o777)

# Load SAM from the configured checkpoint and wrap it with LoRA at the configured rank
sam = build_sam_vit_b(checkpoint=config_file["SAM"]["CHECKPOINT"])
sam_lora = LoRA_sam(sam, config_file["SAM"]["RANK"])
model = sam_lora.sam

# Build the input processor and the datasets.
# NOTE(review): the validation dataset is also created with the "train" split
# argument - confirm whether a dedicated validation mode should be used here.
processor = Samprocessor(model)
train_ds = BratsDataset(train_dataset_path, "train")
valid_ds = BratsDataset(valid_dataset_path, "train")

# Create dataloaders (one volume per batch; only the training loader shuffles)
train_dataloader = DataLoader(train_ds, batch_size=1, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(valid_ds, batch_size=1, shuffle=False, collate_fn=collate_fn)

## Setting hyperparameters

In [7]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Training schedule: epoch count comes from the config; a backup checkpoint is
# written whenever the epoch index is a multiple of backup_interval
num_epochs = config_file["TRAIN"]["NUM_EPOCHS"]
backup_interval = 5

# Segmentation objective applied to the model outputs in the training loop
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

# Adam over the image encoder's parameters, without weight decay
optimizer = Adam(model.image_encoder.parameters(), lr=1e-4, weight_decay=0)

# Stored in every checkpoint; not otherwise used by the visible training loop
loss_weights = [0.4, 0.7]
loss_functions = [nn.MSELoss(), nn.CrossEntropyLoss()]

# Training Loop

In [8]:
# Fine-tune the LoRA-wrapped SAM model.
# For every volume yielded by the dataloader, the slices returned by
# find_slices are visited one by one; each slice gets one forward pass per
# entry of batch[0][1], an averaged DiceCE loss, and one optimizer step.
# After each epoch the mean loss is logged to CSV and a checkpoint is written.
model.train()
model.to(device)

total_loss = []  # not populated by this loop; kept so the name stays defined

for epoch in range(num_epochs):
    epoch_losses = []

    for batch in tqdm(train_dataloader):
        torch.cuda.empty_cache()

        # The dataloader uses batch_size=1, so batch[0] is the only sample.
        # NOTE(review): batch[0][1] is iterated as the per-scan images and
        # batch[0][2] is used as the segmentation mask - confirm against
        # BratsDataset / collate_fn.
        scans = batch[0][1]
        segmentation = batch[0][2]

        slice_idx = find_slices((segmentation > 0).float())
        slice_losses = []
        for idx in slice_idx:
            outputs = []
            for image in scans:
                batched_input = [processor(image, segmentation, idx)]
                outputs.extend(model(batched_input=batched_input,
                                     multimask_output=False))

            # As in the original code, the ground truth is taken from the input
            # built for the last scan (every input uses the same mask and slice).
            stk_gt, stk_out = utils.stacking_batch(batched_input, outputs)
            stk_out = stk_out.squeeze(1)

            # We need to get the [B, C, H, W] starting from [H, W], then provide
            # one ground-truth copy per prediction stack (previously a fixed 4).
            stk_gt = stk_gt.unsqueeze(1)
            stk_gt = stk_gt.repeat(stk_out.shape[0], 1, 1, 1)
            stk_gt = stk_gt.float().to(device)

            # One loss per scan type, averaged into a single scalar
            per_scan_losses = [seg_loss(pred, gt) for pred, gt in zip(stk_out, stk_gt)]
            loss_avg = torch.mean(torch.stack(per_scan_losses))

            # Optimize parameters with the average loss
            optimizer.zero_grad()
            loss_avg.backward()
            optimizer.step()

            # Store a plain float so the autograd graph of this step is not
            # kept alive until the end of the volume.
            slice_losses.append(loss_avg.item())

        # A volume with no selected slices contributes nothing to the epoch
        # mean (previously torch.stack raised on the empty list).
        if slice_losses:
            epoch_losses.append(mean(slice_losses))

        torch.cuda.empty_cache()

    # statistics.mean raises on an empty sequence, so report NaN for an epoch
    # in which no slice was trained.
    epoch_mean_loss = mean(epoch_losses) if epoch_losses else float("nan")
    print(f'EPOCH: {epoch}; Mean training loss: {epoch_mean_loss}')
    utils.save_tloss_csv(training_loss_path, epoch, epoch_mean_loss)
    print("Saving checkpoint...")
    # NOTE(review): 'model' pickles the whole module next to its state_dict,
    # which roughly doubles the file size and ties the checkpoint to the
    # current source layout; kept so existing checkpoint consumers still work.
    checkpoint = {
        'epoch': epoch,
        'model_sd': model.state_dict(),
        'optim_sd': optimizer.state_dict(),
        'model': model,
        'loss_functions': loss_functions,
        'loss_weights': loss_weights,
    }
    torch.save(checkpoint, latest_ckpt_path)
    if epoch % backup_interval == 0:
        torch.save(checkpoint, os.path.join(backup_ckpts_dir, f'epoch{epoch}.pth.tar'))
    print('Checkpoint saved successfully.')

# Save the parameters of the model in safetensors format
rank = config_file["SAM"]["RANK"]
sam_lora.save_lora_parameters(f"lora_rank{rank}.safetensors")

  0%|          | 0/3 [00:00<?, ?it/s]

BraTS-MEN-00133-000:
Loss: <map object at 0x7f55e3947370>


  0%|          | 0/3 [02:03<?, ?it/s]


KeyboardInterrupt: 