In [None]:
pip install segmentation-models-pytorch


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.1.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [None]:


import os
import glob
import math
import re
import random
import logging
import warnings
from typing import List, Tuple, Dict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from PIL import Image
import torchvision.transforms as transforms
from tqdm import tqdm


import segmentation_models_pytorch as smp

from peft import LoraConfig, get_peft_model, PeftModel
from transformers import (
    AutoProcessor,
    LlavaForConditionalGeneration,
    CLIPVisionModel,
    get_linear_schedule_with_warmup,
)
from torch.cuda.amp import GradScaler, autocast


warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
torch.backends.cuda.matmul.allow_tf32 = True


class JaccardLoss(nn.Module):

    def __init__(self, smooth=1e-6):
        super(JaccardLoss, self).__init__()
        self.smooth = smooth

    def forward(self, y_pred, y_true):

        y_pred_probs = torch.sigmoid(y_pred)


        y_pred_flat = y_pred_probs.view(-1)
        y_true_flat = y_true.view(-1)


        intersection = (y_pred_flat * y_true_flat).sum()
        total = (y_pred_flat + y_true_flat).sum()
        union = total - intersection


        iou = (intersection + self.smooth) / (union + self.smooth)
        return 1 - iou


class VLM_QASegDataset(Dataset):
    def __init__(self, image_paths: List[str], metadata_df: pd.DataFrame, is_train: bool = True):
        self.image_paths: List[str] = []
        self.mask_paths: List[str] = []
        self.questions: List[str] = []
        self.answers: List[str] = []


        self.image_transform = transforms.Compose([transforms.Resize((336, 336))])

        self.mask_transform = transforms.Compose([
            transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor()
        ])

        mdx = metadata_df.set_index("Patient")
        for img_path in image_paths:
            mask_path = img_path.replace(".tif", "_mask.tif")
            if not os.path.exists(mask_path):
                continue

            pid_folder = os.path.basename(os.path.dirname(img_path))
            pid_key = "_".join(pid_folder.split("_")[0:3])
            if pid_key in mdx.index:
                row = mdx.loc[[pid_key]].iloc[0]
                grade = row.get("neoplasm_histologic_grade")
                if pd.notna(grade) and int(grade) in [1, 2]:
                    self.image_paths.append(img_path)
                    self.mask_paths.append(mask_path)
                    q = "What is the histologic grade of the brain tumor in the MRI: one or two?"
                    a = f"The grade of the tumor is {'two' if int(grade) == 2 else 'one'}."
                    self.questions.append(q)
                    self.answers.append(a)

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int):

        image = Image.open(self.image_paths[idx]).convert("RGB")
        mask = Image.open(self.mask_paths[idx]).convert("L")


        image = self.image_transform(image)
        mask_tensor = self.mask_transform(mask)

        mask_tensor = (mask_tensor > 0).float()

        return image, mask_tensor, self.questions[idx], self.answers[idx]


def vlm_collate_fn_for_training(batch):
    images, masks, questions, answers = zip(*batch)

    masks_tensor = torch.stack(masks)
    return list(images), masks_tensor, list(questions), list(answers)

def vlm_collate_fn_for_evaluation(batch):
    images, masks, questions, answers = zip(*batch)
    masks_tensor = torch.stack(masks)
    return list(images), masks_tensor, list(questions), list(answers)


def build_training_batch_cpu_main(images, masks, questions, answers, processor: AutoProcessor):
    prompts = [f"USER: <image>\n{q}\nASSISTANT:" for q in questions]
    full_texts = [
        f"USER: <image>\n{q}\nASSISTANT: {a}{processor.tokenizer.eos_token}"
        for q, a in zip(questions, answers)
    ]

    toks_prompt = processor(text=prompts, images=images, return_tensors="pt", padding=True)
    toks_full = processor(text=full_texts, images=images, return_tensors="pt", padding=True)

    labels = toks_full.input_ids.clone()
    prompt_lens = torch.sum(toks_prompt.attention_mask, dim=1)
    for i in range(labels.size(0)):
        labels[i, : prompt_lens[i]] = -100
    labels[labels == processor.tokenizer.pad_token_id] = -100

    batch_cpu = {
        "input_ids": toks_full.input_ids,
        "pixel_values": toks_full.pixel_values,
        "attention_mask": toks_full.attention_mask,
        "labels": labels,
        "seg_masks_gt": masks,
    }
    return batch_cpu


class LlavaWithSegmentationHead(nn.Module):
    def __init__(self, llava_model):
        super().__init__()
        self.llava = llava_model

        self.vision_tower = self.llava.vision_tower

        self.seg_model = smp.DeepLabV3Plus(
            encoder_name="resnet34",
            encoder_weights=None, # We are not using the encoder directly
            in_channels=3,
            classes=1,
        )

        # Get the expected channel sizes from the smp encoder.
        # For resnet34, this is typically a tuple of 6: (3, 64, 64, 128, 256, 512)
        smp_encoder_channels = self.seg_model.encoder.out_channels

        # We need to project CLIP's single 1024-dim output to the channel sizes
        # of all 5 feature stages of the original ResNet encoder.
        # We skip smp_encoder_channels[0] as it's the input image channels.
        self.projection = nn.ModuleList([
            nn.Conv2d(1024, smp_encoder_channels[1], kernel_size=1), # Stage 1 -> 64 channels
            nn.Conv2d(1024, smp_encoder_channels[2], kernel_size=1), # Stage 2 -> 64 channels
            nn.Conv2d(1024, smp_encoder_channels[3], kernel_size=1), # Stage 3 -> 128 channels
            nn.Conv2d(1024, smp_encoder_channels[4], kernel_size=1), # Stage 4 -> 256 channels
            nn.Conv2d(1024, smp_encoder_channels[5], kernel_size=1), # Stage 5 -> 512 channels
        ])


    def forward(self, input_ids, pixel_values, attention_mask, labels=None, **kwargs):
        image_features = self.vision_tower(pixel_values, output_hidden_states=True)
        image_features_grid_with_cls = image_features.hidden_states[-1]


        image_features_grid = image_features_grid_with_cls[:, 1:, :]

        batch_size = image_features_grid.shape[0]
        patch_grid_size = int(math.sqrt(image_features_grid.shape[1]))
        hidden_size = image_features_grid.shape[2]
        seg_features = image_features_grid.reshape(batch_size, patch_grid_size, patch_grid_size, hidden_size)
        seg_features = seg_features.permute(0, 3, 1, 2).contiguous()

        # Project the single feature map to the multiple resolutions the decoder expects
        projected_features = [proj(seg_features) for proj in self.projection]

        scaled_projected_features = list(projected_features)


        scaled_projected_features[1] = F.interpolate(
            scaled_projected_features[1],
            scale_factor=4,
            mode='bilinear',
            align_corners=False
        )


        decoder_features = [None] + scaled_projected_features


        decoder_output = self.seg_model.decoder(decoder_features)


        seg_logits = self.seg_model.segmentation_head(decoder_output)

        seg_logits = F.interpolate(seg_logits, size=(336, 336), mode='bilinear', align_corners=False)


        vqa_output = self.llava(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True
        )

        return {
            "vqa_loss": vqa_output.loss,
            "vqa_logits": vqa_output.logits,
            "seg_logits": seg_logits.squeeze(1)
        }

def compute_iou(pred_mask, true_mask, threshold=0.5):
    with torch.no_grad():
        pred_mask = (torch.sigmoid(pred_mask) > threshold).float()
        true_mask = true_mask.float()

        intersection = (pred_mask * true_mask).sum(dim=(1, 2))
        union = pred_mask.sum(dim=(1, 2)) + true_mask.sum(dim=(1, 2)) - intersection

        iou = (intersection + 1e-6) / (union + 1e-6)
        return iou.mean().item()

def run_evaluation(model, processor, data_loader: DataLoader, device, description="Evaluating"):
    model.eval()
    total_samples = 0
    vlm_correct = 0
    total_vqa_loss_sum = 0.0
    total_seg_loss_sum = 0.0
    total_loss_count = 0
    total_iou = 0.0

    seg_loss_fn = JaccardLoss().to(device)

    with torch.no_grad():
        for batch in tqdm(data_loader, desc=description):
            images, masks_gt, questions, answers = batch
            masks_gt = masks_gt.to(device)

            # VQA Generation Accuracy
            prompts = [f"USER: <image>\n{q}\nASSISTANT:" for q in questions]
            with autocast():
                gen_inputs = processor(text=prompts, images=images, return_tensors="pt", padding=True).to(device)
                generated_ids = model.llava.generate(
                    **gen_inputs,
                    max_new_tokens=20,
                    pad_token_id=processor.tokenizer.pad_token_id,
                )
            decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)

            for i in range(len(decoded)):
                pred_span = decoded[i].split("ASSISTANT:")[-1].strip().lower()
                true_span = answers[i].lower()
                want_two = "two" in true_span
                has_one = "one" in pred_span or "1" in pred_span
                has_two = "two" in pred_span or "2" in pred_span
                ok = (want_two and has_two and not has_one) or ((not want_two) and has_one and not has_two)
                if ok:
                    vlm_correct += 1

            # VQA Loss + Segmentation Loss + IoU
            batch_cpu = build_training_batch_cpu_main(images, masks_gt.cpu(), questions, answers, processor)
            batch_gpu = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch_cpu.items()}

            with autocast():
                outputs = model(**batch_gpu)
                vqa_loss = outputs["vqa_loss"]
                seg_logits = outputs["seg_logits"]
                seg_loss = seg_loss_fn(seg_logits, batch_gpu["seg_masks_gt"].squeeze(1))

            if vqa_loss is not None:
                total_vqa_loss_sum += vqa_loss.item()
            if seg_loss is not None:
                total_seg_loss_sum += seg_loss.item()
            total_loss_count += 1
            total_iou += compute_iou(seg_logits, batch_gpu["seg_masks_gt"].squeeze(1))
            total_samples += len(answers)

    vlm_acc = (vlm_correct / total_samples) * 100 if total_samples else 0.0
    avg_vqa_loss = total_vqa_loss_sum / total_loss_count if total_loss_count else float("inf")
    avg_seg_loss = total_seg_loss_sum / total_loss_count if total_loss_count else float("inf")
    avg_iou = total_iou / total_loss_count if total_loss_count else 0.0
    ppl = math.exp(avg_vqa_loss) if avg_vqa_loss < 50 else float("inf")

    print(f"\n--- Results for {description} ---")
    print(f"  - VLM Grade Accuracy (QA):      {vlm_acc:.2f}%")
    print(f"  - Perplexity (teacher-forced):  {ppl:.4f}")
    print(f"  - Segmentation IoU:             {avg_iou:.4f}")
    print(f"  - Avg Segmentation Loss:        {avg_seg_loss:.4f}")
    print("-" * 40)
    return vlm_acc, avg_iou


def discover_lora_targets(llava_model, include_vision: bool = True) -> List[str]:
    text_keys = {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"}
    projector_keys = {"multi_modal_projector"}
    vision_keys = {"q_proj", "k_proj", "v_proj", "out_proj"}  # CLIP-style

    target_modules: set[str] = set()

    for name, module in llava_model.named_modules():

        if any(k in name for k in text_keys) and "language_model" in name:
            target_modules.add(name.split(".")[-1])

        if any(k in name for k in projector_keys):
            if hasattr(module, "weight") and getattr(module, "weight", None) is not None:
                target_modules.add(name.split(".")[-1])

        if include_vision and ("vision_tower" in name) and any(k in name for k in vision_keys):
            target_modules.add(name.split(".")[-1])

    return sorted(list(target_modules))


if __name__ == "__main__":
    config = {
        "device": "cuda:2" if torch.cuda.is_available() else "cpu",
        "base_path": "/home/ealam/Downloads/LGG dataset Cameron/lgg-mri-segmentation/kaggle_3m",
        "local_llava_path": "/home/ealam/Desktop/llava-1.5-7b-local",
        "save_path": "./llava-lora-multitask",
        "csv_path": "/home/ealam/Downloads/LGG dataset Cameron/lgg-mri-segmentation/kaggle_3m/data.csv",
        "learning_rate": 1e-5,
        "batch_size": 2,
        "num_epochs": 25,
        "early_stopping_patience": 5,
        "seed": 42,
        "include_vision_lora": True,
        "seg_loss_weight": 0.5,
        "num_workers": 0,
        "grad_clip_val": 1.0,
    }

    # Seeds
    torch.manual_seed(config["seed"])
    np.random.seed(config["seed"])
    random.seed(config["seed"])

    # Data Gathering
    print("Step 1: Gathering and splitting data...")
    all_image_paths = [p.replace("_mask.tif", ".tif") for p in glob.glob(os.path.join(config["base_path"], "*", "*_mask.tif"))]
    all_image_paths = [p for p in all_image_paths if os.path.exists(p)]
    print(f"Found {len(all_image_paths)} total images.")


    usable_paths, unused_paths = train_test_split(all_image_paths, test_size=0.01, random_state=config["seed"])
    print(f"Setting aside {len(unused_paths)} images. Using the remaining {len(usable_paths)} for this experiment.")

    train_val_paths, test_paths = train_test_split(usable_paths, test_size=0.20, random_state=config["seed"])
    train_paths, val_paths = train_test_split(train_val_paths, test_size=0.20, random_state=config["seed"])
    print(f"Splitting usable data into {len(train_paths)} training, {len(val_paths)} validation, and {len(test_paths)} test samples.")


    # Model & Processor
    print("\nStep 2: Setting up multi-task model and processor...")
    DEVICE = config["device"]
    base_model = LlavaForConditionalGeneration.from_pretrained(
        config["local_llava_path"], torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
    processor = AutoProcessor.from_pretrained(config["local_llava_path"])
    if processor.tokenizer.pad_token is None:
        processor.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        base_model.resize_token_embeddings(len(processor.tokenizer))

    # LoRA Setup
    target_modules = discover_lora_targets(base_model, include_vision=config["include_vision_lora"])
    print("LoRA target modules:", target_modules)
    lora_cfg = LoraConfig(
        r=32, lora_alpha=64,
        target_modules=target_modules,
        lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
    )

    peft_model = get_peft_model(base_model, lora_cfg)

    #  Full Custom Model
    multitask_model = LlavaWithSegmentationHead(peft_model).to(DEVICE)
    peft_model.print_trainable_parameters()


    print("\nStep 3: Preparing DataLoaders...")
    metadata_df = pd.read_csv(config["csv_path"])
    train_ds = VLM_QASegDataset(train_paths, metadata_df, is_train=True)
    val_ds = VLM_QASegDataset(val_paths, metadata_df, is_train=False)
    test_ds = VLM_QASegDataset(test_paths, metadata_df, is_train=False)

    train_loader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True, num_workers=config["num_workers"], collate_fn=vlm_collate_fn_for_training)
    val_loader = DataLoader(val_ds, batch_size=config["batch_size"], shuffle=False, num_workers=config["num_workers"], collate_fn=vlm_collate_fn_for_evaluation)
    test_loader = DataLoader(test_ds, batch_size=config["batch_size"], shuffle=False, num_workers=config["num_workers"], collate_fn=vlm_collate_fn_for_evaluation)


    print("\nStep 4: Starting multi-task fine-tuning...")
    trainable_params = [p for p in multitask_model.parameters() if p.requires_grad]
    optimizer = AdamW(trainable_params, lr=config["learning_rate"])
    scaler = GradScaler()
    seg_loss_fn = JaccardLoss().to(DEVICE)


    num_training_steps = len(train_loader) * config["num_epochs"]
    num_warmup_steps = int(0.1 * num_training_steps)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
    )

    best_val_metric = 0.0
    patience = 0

    for epoch in range(config["num_epochs"]):
        multitask_model.train()
        total_loss = 0.0

        for images, masks, questions, answers in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
            batch_cpu = build_training_batch_cpu_main(images, masks, questions, answers, processor)
            batch_gpu = {k: v.to(DEVICE) if torch.is_tensor(v) else v for k, v in batch_cpu.items()}
            optimizer.zero_grad(set_to_none=True)

            with autocast():
                outputs = multitask_model(**batch_gpu)
                vqa_loss = outputs["vqa_loss"]
                seg_logits = outputs["seg_logits"]
                seg_loss = seg_loss_fn(seg_logits, batch_gpu["seg_masks_gt"].squeeze(1))

                combined_loss = vqa_loss + config["seg_loss_weight"] * seg_loss

            scaler.scale(combined_loss).backward()


            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(trainable_params, config["grad_clip_val"])

            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            total_loss += combined_loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"\nEpoch {epoch+1} Avg Combined Loss -> {avg_loss:.4f}")


        val_acc, val_iou = run_evaluation(multitask_model, processor, val_loader, DEVICE, "Validation Set Eval")

        current_metric = val_acc + (val_iou * 100)

        if current_metric > best_val_metric:
            print(f"  -> New best validation metric ({current_metric:.2f}). Saving model...")
            best_val_metric = current_metric
            patience = 0

            save_dir = config["save_path"]
            os.makedirs(save_dir, exist_ok=True)

            torch.save(multitask_model.seg_model.state_dict(), os.path.join(save_dir, "seg_model.pth"))
            torch.save(multitask_model.projection.state_dict(), os.path.join(save_dir, "projection.pth"))
            multitask_model.llava.save_pretrained(os.path.join(save_dir, "llava_lora"))
            processor.save_pretrained(os.path.join(save_dir, "processor"))
        else:
            patience += 1
            print(f"  -> No improvement for {patience} epoch(s).")
            if patience >= config["early_stopping_patience"]:
                print("\n--- Early stopping triggered. ---")
                break
        print("=" * 80)

    # Final Evaluation
    print("\nStep 5: Loading best model for final evaluation...")
    save_path = config["save_path"]
    if os.path.exists(os.path.join(save_path, "seg_model.pth")):
        final_base_model = LlavaForConditionalGeneration.from_pretrained(config["local_llava_path"], torch_dtype=torch.float16)
        final_peft_model = PeftModel.from_pretrained(final_base_model, os.path.join(save_path, "llava_lora"))
        final_multitask_model = LlavaWithSegmentationHead(final_peft_model).to(DEVICE)


        final_multitask_model.seg_model.load_state_dict(torch.load(os.path.join(save_path, "seg_model.pth")))
        final_multitask_model.projection.load_state_dict(torch.load(os.path.join(save_path, "projection.pth")))
        final_processor = AutoProcessor.from_pretrained(os.path.join(save_path, "processor"))

        run_evaluation(final_multitask_model, final_processor, test_loader, DEVICE, "Final Test Evaluation")
    else:
        print("No model was saved.")


Step 1: Gathering and splitting data...
Found 3929 total images.
Setting aside 40 images. Using the remaining 3889 for this experiment.
Splitting usable data into 2488 training, 623 validation, and 778 test samples.

Step 2: Setting up multi-task model and processor...


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

LoRA target modules: ['down_proj', 'gate_proj', 'k_proj', 'linear_1', 'linear_2', 'o_proj', 'out_proj', 'q_proj', 'up_proj', 'v_proj']
trainable params: 86,671,360 || all params: 7,150,098,432 || trainable%: 1.2122

Step 3: Preparing DataLoaders...

Step 4: Starting multi-task fine-tuning...


Training Epoch 1: 100%|█████████████████████| 1216/1216 [12:32<00:00,  1.62it/s]



Epoch 1 Avg Combined Loss -> 0.8633


Validation Set Eval: 100%|████████████████████| 304/304 [03:51<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      48.27%
  - Perplexity (teacher-forced):  1.0770
  - Segmentation IoU:             0.0390
  - Avg Segmentation Loss:        0.9755
----------------------------------------
  -> New best validation metric (52.17). Saving model...


Training Epoch 2: 100%|█████████████████████| 1216/1216 [12:28<00:00,  1.62it/s]



Epoch 2 Avg Combined Loss -> 0.5485


Validation Set Eval: 100%|████████████████████| 304/304 [03:50<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      64.58%
  - Perplexity (teacher-forced):  1.0635
  - Segmentation IoU:             0.5402
  - Avg Segmentation Loss:        0.9424
----------------------------------------
  -> New best validation metric (118.60). Saving model...


Training Epoch 3: 100%|█████████████████████| 1216/1216 [12:22<00:00,  1.64it/s]



Epoch 3 Avg Combined Loss -> 0.5234


Validation Set Eval: 100%|████████████████████| 304/304 [03:50<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      74.30%
  - Perplexity (teacher-forced):  1.0612
  - Segmentation IoU:             0.6498
  - Avg Segmentation Loss:        0.9030
----------------------------------------
  -> New best validation metric (139.28). Saving model...


Training Epoch 4: 100%|█████████████████████| 1216/1216 [12:12<00:00,  1.66it/s]



Epoch 4 Avg Combined Loss -> 0.4869


Validation Set Eval: 100%|████████████████████| 304/304 [03:50<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      79.90%
  - Perplexity (teacher-forced):  1.0712
  - Segmentation IoU:             0.7190
  - Avg Segmentation Loss:        0.8555
----------------------------------------
  -> New best validation metric (151.80). Saving model...


Training Epoch 5: 100%|█████████████████████| 1216/1216 [12:04<00:00,  1.68it/s]



Epoch 5 Avg Combined Loss -> 0.4437


Validation Set Eval: 100%|████████████████████| 304/304 [03:50<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      81.05%
  - Perplexity (teacher-forced):  1.0779
  - Segmentation IoU:             0.7814
  - Avg Segmentation Loss:        0.7919
----------------------------------------
  -> New best validation metric (159.19). Saving model...


Training Epoch 6: 100%|█████████████████████| 1216/1216 [11:56<00:00,  1.70it/s]



Epoch 6 Avg Combined Loss -> 0.4096


Validation Set Eval: 100%|████████████████████| 304/304 [03:52<00:00,  1.31it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      81.55%
  - Perplexity (teacher-forced):  1.0769
  - Segmentation IoU:             0.7983
  - Avg Segmentation Loss:        0.7397
----------------------------------------
  -> New best validation metric (161.38). Saving model...


Training Epoch 7: 100%|█████████████████████| 1216/1216 [11:43<00:00,  1.73it/s]



Epoch 7 Avg Combined Loss -> 0.3791


Validation Set Eval: 100%|████████████████████| 304/304 [03:49<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      86.00%
  - Perplexity (teacher-forced):  1.0633
  - Segmentation IoU:             0.8577
  - Avg Segmentation Loss:        0.7061
----------------------------------------
  -> New best validation metric (171.77). Saving model...


Training Epoch 8: 100%|█████████████████████| 1216/1216 [11:31<00:00,  1.76it/s]



Epoch 8 Avg Combined Loss -> 0.3508


Validation Set Eval: 100%|████████████████████| 304/304 [03:49<00:00,  1.33it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      84.51%
  - Perplexity (teacher-forced):  1.0862
  - Segmentation IoU:             0.8745
  - Avg Segmentation Loss:        0.6773
----------------------------------------
  -> New best validation metric (171.96). Saving model...


Training Epoch 9: 100%|█████████████████████| 1216/1216 [11:29<00:00,  1.76it/s]



Epoch 9 Avg Combined Loss -> 0.3418


Validation Set Eval: 100%|████████████████████| 304/304 [03:50<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      84.18%
  - Perplexity (teacher-forced):  1.0883
  - Segmentation IoU:             0.8627
  - Avg Segmentation Loss:        0.6557
----------------------------------------
  -> No improvement for 1 epoch(s).


Training Epoch 10: 100%|████████████████████| 1216/1216 [11:31<00:00,  1.76it/s]



Epoch 10 Avg Combined Loss -> 0.3242


Validation Set Eval: 100%|████████████████████| 304/304 [03:51<00:00,  1.31it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      85.34%
  - Perplexity (teacher-forced):  1.0995
  - Segmentation IoU:             0.8805
  - Avg Segmentation Loss:        0.6440
----------------------------------------
  -> New best validation metric (173.39). Saving model...


Training Epoch 11: 100%|████████████████████| 1216/1216 [11:36<00:00,  1.75it/s]



Epoch 11 Avg Combined Loss -> 0.3201


Validation Set Eval: 100%|████████████████████| 304/304 [03:52<00:00,  1.31it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      83.86%
  - Perplexity (teacher-forced):  1.1450
  - Segmentation IoU:             0.8800
  - Avg Segmentation Loss:        0.6309
----------------------------------------
  -> No improvement for 1 epoch(s).


Training Epoch 12: 100%|████████████████████| 1216/1216 [11:24<00:00,  1.78it/s]



Epoch 12 Avg Combined Loss -> 0.3120


Validation Set Eval: 100%|████████████████████| 304/304 [03:51<00:00,  1.31it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      84.35%
  - Perplexity (teacher-forced):  1.1407
  - Segmentation IoU:             0.8799
  - Avg Segmentation Loss:        0.6272
----------------------------------------
  -> No improvement for 2 epoch(s).


Training Epoch 13: 100%|████████████████████| 1216/1216 [11:28<00:00,  1.76it/s]



Epoch 13 Avg Combined Loss -> 0.3048


Validation Set Eval: 100%|████████████████████| 304/304 [03:52<00:00,  1.31it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      75.78%
  - Perplexity (teacher-forced):  1.2223
  - Segmentation IoU:             0.8813
  - Avg Segmentation Loss:        0.6179
----------------------------------------
  -> No improvement for 3 epoch(s).


Training Epoch 14: 100%|████████████████████| 1216/1216 [11:24<00:00,  1.78it/s]



Epoch 14 Avg Combined Loss -> 0.2942


Validation Set Eval: 100%|████████████████████| 304/304 [03:51<00:00,  1.31it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      84.51%
  - Perplexity (teacher-forced):  1.1587
  - Segmentation IoU:             0.8831
  - Avg Segmentation Loss:        0.6143
----------------------------------------
  -> No improvement for 4 epoch(s).


Training Epoch 15: 100%|████████████████████| 1216/1216 [11:25<00:00,  1.77it/s]



Epoch 15 Avg Combined Loss -> 0.2934


Validation Set Eval: 100%|████████████████████| 304/304 [03:52<00:00,  1.31it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      85.50%
  - Perplexity (teacher-forced):  1.1445
  - Segmentation IoU:             0.8851
  - Avg Segmentation Loss:        0.6152
----------------------------------------
  -> New best validation metric (174.01). Saving model...


Training Epoch 16: 100%|████████████████████| 1216/1216 [11:23<00:00,  1.78it/s]



Epoch 16 Avg Combined Loss -> 0.2905


Validation Set Eval: 100%|████████████████████| 304/304 [03:52<00:00,  1.31it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      85.17%
  - Perplexity (teacher-forced):  1.1197
  - Segmentation IoU:             0.8850
  - Avg Segmentation Loss:        0.6128
----------------------------------------
  -> No improvement for 1 epoch(s).


Training Epoch 17: 100%|████████████████████| 1216/1216 [11:22<00:00,  1.78it/s]



Epoch 17 Avg Combined Loss -> 0.2900


Validation Set Eval: 100%|████████████████████| 304/304 [03:52<00:00,  1.31it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      85.50%
  - Perplexity (teacher-forced):  1.1394
  - Segmentation IoU:             0.8800
  - Avg Segmentation Loss:        0.6089
----------------------------------------
  -> No improvement for 2 epoch(s).


Training Epoch 18: 100%|████████████████████| 1216/1216 [11:21<00:00,  1.78it/s]



Epoch 18 Avg Combined Loss -> 0.2864


Validation Set Eval: 100%|████████████████████| 304/304 [03:52<00:00,  1.31it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      84.68%
  - Perplexity (teacher-forced):  1.1550
  - Segmentation IoU:             0.8854
  - Avg Segmentation Loss:        0.6116
----------------------------------------
  -> No improvement for 3 epoch(s).


Training Epoch 19: 100%|████████████████████| 1216/1216 [11:17<00:00,  1.79it/s]



Epoch 19 Avg Combined Loss -> 0.2842


Validation Set Eval: 100%|████████████████████| 304/304 [03:51<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      84.35%
  - Perplexity (teacher-forced):  1.1595
  - Segmentation IoU:             0.8866
  - Avg Segmentation Loss:        0.6079
----------------------------------------
  -> No improvement for 4 epoch(s).


Training Epoch 20: 100%|████████████████████| 1216/1216 [11:14<00:00,  1.80it/s]



Epoch 20 Avg Combined Loss -> 0.2839


Validation Set Eval: 100%|████████████████████| 304/304 [03:50<00:00,  1.32it/s]


--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      85.17%
  - Perplexity (teacher-forced):  1.1632
  - Segmentation IoU:             0.8808
  - Avg Segmentation Loss:        0.6078
----------------------------------------
  -> No improvement for 5 epoch(s).

--- Early stopping triggered. ---

Step 5: Loading best model for final evaluation...





Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Final Test Evaluation: 100%|██████████████████| 382/382 [04:49<00:00,  1.32it/s]


--- Results for Final Test Evaluation ---
  - VLM Grade Accuracy (QA):      84.27%
  - Perplexity (teacher-forced):  1.1521
  - Segmentation IoU:             0.8658
  - Avg Segmentation Loss:        0.5733
----------------------------------------



