In [None]:


import os
import glob
import math
import re
import random
import logging
import warnings
from typing import List, Tuple, Dict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from PIL import Image
import torchvision.transforms as transforms
from tqdm import tqdm
import cv2

import segmentation_models_pytorch as smp

from peft import LoraConfig, get_peft_model, PeftModel
from transformers import (
    AutoProcessor,
    LlavaForConditionalGeneration,
    CLIPVisionModel,
    get_linear_schedule_with_warmup,
)
from torch.cuda.amp import GradScaler, autocast


warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
torch.backends.cuda.matmul.allow_tf32 = True


class JaccardLoss(nn.Module):
    """
    Jaccard/Intersection over Union (IoU) Loss.
    A common loss function for segmentation tasks.
    """
    def __init__(self, smooth=1e-6):
        super(JaccardLoss, self).__init__()
        self.smooth = smooth

    def forward(self, y_pred, y_true):

        y_pred_probs = torch.sigmoid(y_pred)


        y_pred_flat = y_pred_probs.view(-1)
        y_true_flat = y_true.view(-1)


        intersection = (y_pred_flat * y_true_flat).sum()
        total = (y_pred_flat + y_true_flat).sum()
        union = total - intersection


        iou = (intersection + self.smooth) / (union + self.smooth)
        return 1 - iou


class TaskPacingManager:

    def __init__(self, num_tasks: int = 2, T: float = 2.0):
        self.T = T
        self.num_tasks = num_tasks
        self.weights = np.ones(num_tasks)
        self.prev_val_losses = None

    def update_weights(self, current_val_losses: List[float]):
        if self.prev_val_losses is None:
            # First epoch,  initial weights
            self.prev_val_losses = np.array(current_val_losses)
            return

        # 1. Calculate learning velocity (loss ratio) for each task
        # A higher velocity means faster learning (more loss reduction)
        velocities = self.prev_val_losses / (np.array(current_val_losses) + 1e-8)

        # 2.  Invert the velocities.
        # We want to give higher weight to the slower task (lower velocity).
        # Inverting makes the slowest task have the largest value.
        inv_velocities = 1.0 / (velocities + 1e-8)

        # 3. Calculate pacing factors using the inverted velocities
        pacing_factors = inv_velocities / np.sum(inv_velocities)

        # 4. Set new weights for the next epoch using a softmax function
        exp_terms = np.exp(pacing_factors / self.T)
        self.weights = (self.num_tasks * exp_terms / np.sum(exp_terms))*100

        # 5. Update previous losses for the next iteration
        self.prev_val_losses = np.array(current_val_losses)

    def get_weights(self) -> Tuple[float, float]:
        return self.weights[0], self.weights[1]



class VLM_QASegDataset(Dataset):
    def __init__(self, image_paths: List[str], metadata_df: pd.DataFrame, is_train: bool = True):
        self.image_paths: List[str] = []
        self.mask_paths: List[str] = []
        self.questions: List[str] = []
        self.answers: List[str] = []

        self.image_transform = transforms.Compose([transforms.Resize((336, 336))])
        self.mask_transform = transforms.Compose([
            transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor()
        ])

        mdx = metadata_df.set_index("Patient")
        for img_path in image_paths:
            mask_path = img_path.replace(".tif", "_mask.tif")
            if not os.path.exists(mask_path):
                continue

            pid_folder = os.path.basename(os.path.dirname(img_path))
            pid_key = "_".join(pid_folder.split("_")[0:3])
            if pid_key in mdx.index:
                row = mdx.loc[[pid_key]].iloc[0]
                grade = row.get("neoplasm_histologic_grade")
                if pd.notna(grade) and int(grade) in [1, 2]:
                    self.image_paths.append(img_path)
                    self.mask_paths.append(mask_path)
                    q = "The tumor is delineated by a yellow border. What is its histologic grade: one or two?"
                    a = f"The grade of the tumor is {'two' if int(grade) == 2 else 'one'}."
                    self.questions.append(q)
                    self.answers.append(a)

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        mask = Image.open(self.mask_paths[idx]).convert("L")


        image_np = np.array(image)
        mask_np = (np.array(mask) > 0).astype(np.uint8)
        contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(image_np, contours, -1, (255, 255, 0), thickness=2) # Yellow border
        image = Image.fromarray(image_np)

        image = self.image_transform(image)
        mask_tensor = self.mask_transform(mask)
        mask_tensor = (mask_tensor > 0).float()

        return image, mask_tensor, self.questions[idx], self.answers[idx]


def vlm_collate_fn_for_training(batch):
    images, masks, questions, answers = zip(*batch)
    masks_tensor = torch.stack(masks)
    return list(images), masks_tensor, list(questions), list(answers)

def vlm_collate_fn_for_evaluation(batch):
    images, masks, questions, answers = zip(*batch)
    masks_tensor = torch.stack(masks)
    return list(images), masks_tensor, list(questions), list(answers)


def build_training_batch_cpu_main(images, masks, questions, answers, processor: AutoProcessor):
    prompts = [f"USER: <image>\n{q}\nASSISTANT:" for q in questions]
    full_texts = [
        f"USER: <image>\n{q}\nASSISTANT: {a}{processor.tokenizer.eos_token}"
        for q, a in zip(questions, answers)
    ]

    toks_prompt = processor(text=prompts, images=images, return_tensors="pt", padding=True)
    toks_full = processor(text=full_texts, images=images, return_tensors="pt", padding=True)

    labels = toks_full.input_ids.clone()
    prompt_lens = torch.sum(toks_prompt.attention_mask, dim=1)
    for i in range(labels.size(0)):
        labels[i, : prompt_lens[i]] = -100
    labels[labels == processor.tokenizer.pad_token_id] = -100

    return {
        "input_ids": toks_full.input_ids,
        "pixel_values": toks_full.pixel_values,
        "attention_mask": toks_full.attention_mask,
        "labels": labels,
        "seg_masks_gt": masks,
    }


class LlavaWithSegmentationHead(nn.Module):
    def __init__(self, llava_model):
        super().__init__()
        self.llava = llava_model
        self.vision_tower = self.llava.vision_tower

        self.seg_model = smp.DeepLabV3Plus(
            encoder_name="resnet34", encoder_weights=None, in_channels=3, classes=1,
        )
        smp_encoder_channels = self.seg_model.encoder.out_channels
        self.projection = nn.ModuleList([
            nn.Conv2d(1024, smp_encoder_channels[i], kernel_size=1) for i in range(1, 6)
        ])

    def forward(self, input_ids, pixel_values, attention_mask, labels=None, **kwargs):
        image_features = self.vision_tower(pixel_values, output_hidden_states=True)
        image_features_grid_with_cls = image_features.hidden_states[-1]


        image_features_grid = image_features_grid_with_cls[:, 1:, :]
        batch_size, num_patches, hidden_size = image_features_grid.shape
        patch_grid_size = int(math.sqrt(num_patches))
        seg_features = image_features_grid.reshape(
            batch_size, patch_grid_size, patch_grid_size, hidden_size
        ).permute(0, 3, 1, 2).contiguous()

        projected_features = [proj(seg_features) for proj in self.projection]
        scaled_projected_features = list(projected_features)
        scaled_projected_features[1] = F.interpolate(
            scaled_projected_features[1], scale_factor=4, mode='bilinear', align_corners=False
        )
        decoder_features = [None] + scaled_projected_features
        decoder_output = self.seg_model.decoder(decoder_features)
        seg_logits = self.seg_model.segmentation_head(decoder_output)
        seg_logits = F.interpolate(seg_logits, size=(336, 336), mode='bilinear', align_corners=False)


        vqa_output = self.llava(
            input_ids=input_ids, pixel_values=pixel_values,
            attention_mask=attention_mask, labels=labels, return_dict=True
        )

        return {
            "vqa_loss": vqa_output.loss,
            "vqa_logits": vqa_output.logits,
            "seg_logits": seg_logits.squeeze(1)
        }


def compute_iou(pred_mask, true_mask, threshold=0.5):
    with torch.no_grad():
        pred_mask = (torch.sigmoid(pred_mask) > threshold).float()
        true_mask = true_mask.float()
        intersection = (pred_mask * true_mask).sum(dim=(1, 2))
        union = pred_mask.sum(dim=(1, 2)) + true_mask.sum(dim=(1, 2)) - intersection
        iou = (intersection + 1e-6) / (union + 1e-6)
        return iou.mean().item()

def run_evaluation(model, processor, data_loader: DataLoader, device, description="Evaluating"):
    model.eval()
    total_samples, vlm_correct, total_vqa_loss_sum, total_seg_loss_sum, total_loss_count, total_iou = 0, 0, 0.0, 0.0, 0, 0.0
    seg_loss_fn = JaccardLoss().to(device)

    with torch.no_grad():
        for batch in tqdm(data_loader, desc=description):
            images, masks_gt, questions, answers = batch
            masks_gt = masks_gt.to(device)

            # VQA Generation Accuracy
            prompts = [f"USER: <image>\n{q}\nASSISTANT:" for q in questions]
            with autocast():
                gen_inputs = processor(text=prompts, images=images, return_tensors="pt", padding=True).to(device)
                generated_ids = model.llava.generate(**gen_inputs, max_new_tokens=20, pad_token_id=processor.tokenizer.pad_token_id)
            decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)

            for i in range(len(decoded)):
                pred_span = decoded[i].split("ASSISTANT:")[-1].strip().lower()
                true_span = answers[i].lower()
                want_two = "two" in true_span
                has_one = "one" in pred_span or "1" in pred_span
                has_two = "two" in pred_span or "2" in pred_span
                if (want_two and has_two and not has_one) or ((not want_two) and has_one and not has_two):
                    vlm_correct += 1


            batch_cpu = build_training_batch_cpu_main(images, masks_gt.cpu(), questions, answers, processor)
            batch_gpu = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch_cpu.items()}
            with autocast():
                outputs = model(**batch_gpu)
                vqa_loss, seg_logits = outputs["vqa_loss"], outputs["seg_logits"]
                seg_loss = seg_loss_fn(seg_logits, batch_gpu["seg_masks_gt"].squeeze(1))

            if vqa_loss is not None: total_vqa_loss_sum += vqa_loss.item()
            if seg_loss is not None: total_seg_loss_sum += seg_loss.item()
            total_loss_count += 1
            total_iou += compute_iou(seg_logits, batch_gpu["seg_masks_gt"].squeeze(1))
            total_samples += len(answers)

    vlm_acc = (vlm_correct / total_samples) * 100 if total_samples else 0.0
    avg_vqa_loss = total_vqa_loss_sum / total_loss_count if total_loss_count else float("inf")
    avg_seg_loss = total_seg_loss_sum / total_loss_count if total_loss_count else float("inf")
    avg_iou = total_iou / total_loss_count if total_loss_count else 0.0

    print(f"\n--- Results for {description} ---")
    print(f"  - VLM Grade Accuracy (QA): {vlm_acc:.2f}%")
    print(f"  - Segmentation IoU:        {avg_iou:.4f}")
    print(f"  - Avg VQA Loss:            {avg_vqa_loss:.4f}")
    print(f"  - Avg Segmentation Loss:   {avg_seg_loss:.4f}")
    print("-" * 40)
    return vlm_acc, avg_iou, avg_vqa_loss, avg_seg_loss


def discover_lora_targets(llava_model, include_vision: bool = True) -> List[str]:
    text_keys = {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"}
    projector_keys = {"multi_modal_projector"}
    vision_keys = {"q_proj", "k_proj", "v_proj", "out_proj"}  # CLIP-style

    target_modules: set[str] = set()

    for name, module in llava_model.named_modules():

        if any(k in name for k in text_keys) and "language_model" in name:
            target_modules.add(name.split(".")[-1])

        if any(k in name for k in projector_keys):
            if hasattr(module, "weight") and getattr(module, "weight", None) is not None:
                target_modules.add(name.split(".")[-1])

        if include_vision and ("vision_tower" in name) and any(k in name for k in vision_keys):
            target_modules.add(name.split(".")[-1])

    return sorted(list(target_modules))


if __name__ == "__main__":
    config = {
        "device": "cuda:0" if torch.cuda.is_available() else "cpu",
        "base_path": "/home/ealam/Downloads/LGG dataset Cameron/lgg-mri-segmentation/kaggle_3m",
        "local_llava_path": "/home/ealam/Desktop/llava-1.5-7b-local",
        "save_path": "./llava-lora-multitask-delineated-dwa-tp",
        "csv_path": "/home/ealam/Downloads/LGG dataset Cameron/lgg-mri-segmentation/kaggle_3m/data.csv",
        "learning_rate": 1e-5,
        "batch_size": 2,
        "num_epochs": 25,
        "early_stopping_patience": 5,
        "seed": 42,
        "include_vision_lora": True,
        "dwa_temperature": 2.0,
        "num_workers": 4,
        "grad_clip_val": 1.0,
    }

    # Seeds
    torch.manual_seed(config["seed"])
    np.random.seed(config["seed"])
    random.seed(config["seed"])

    # Data
    print("Step 1: Gathering and splitting data...")
    all_image_paths = [p.replace("_mask.tif", ".tif") for p in glob.glob(os.path.join(config["base_path"], "*", "*_mask.tif"))]
    usable_paths, _ = train_test_split(all_image_paths, test_size=0.01, random_state=config["seed"])
    train_val_paths, test_paths = train_test_split(usable_paths, test_size=0.20, random_state=config["seed"])
    train_paths, val_paths = train_test_split(train_val_paths, test_size=0.20, random_state=config["seed"])
    print(f"Data split: {len(train_paths)} train, {len(val_paths)} val, {len(test_paths)} test.")

    # Model & Processor
    print("\nStep 2: Setting up multi-task model and processor...")
    DEVICE = config["device"]
    base_model = LlavaForConditionalGeneration.from_pretrained(
        config["local_llava_path"], torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
    processor = AutoProcessor.from_pretrained(config["local_llava_path"])
    if processor.tokenizer.pad_token is None:
        processor.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        base_model.resize_token_embeddings(len(processor.tokenizer))

    # LoRA Setup
    target_modules = discover_lora_targets(base_model, include_vision=config["include_vision_lora"])
    print("LoRA target modules:", target_modules)
    lora_cfg = LoraConfig(r=32, lora_alpha=64, target_modules=target_modules, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
    peft_model = get_peft_model(base_model, lora_cfg)
    multitask_model = LlavaWithSegmentationHead(peft_model).to(DEVICE)
    peft_model.print_trainable_parameters()

    # DataLoaders
    print("\nStep 3: Preparing DataLoaders...")
    metadata_df = pd.read_csv(config["csv_path"])
    train_ds = VLM_QASegDataset(train_paths, metadata_df, is_train=True)
    val_ds = VLM_QASegDataset(val_paths, metadata_df, is_train=False)
    test_ds = VLM_QASegDataset(test_paths, metadata_df, is_train=False)
    train_loader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True, num_workers=config["num_workers"], collate_fn=vlm_collate_fn_for_training)
    val_loader = DataLoader(val_ds, batch_size=config["batch_size"], shuffle=False, num_workers=config["num_workers"], collate_fn=vlm_collate_fn_for_evaluation)
    test_loader = DataLoader(test_ds, batch_size=config["batch_size"], shuffle=False, num_workers=config["num_workers"], collate_fn=vlm_collate_fn_for_evaluation)

    # Training
    print("\nStep 4: Starting multi-task fine-tuning with DWA-TP loss...")
    task_pacer = TaskPacingManager(T=config["dwa_temperature"])
    trainable_params = [p for p in multitask_model.parameters() if p.requires_grad]
    optimizer = AdamW(trainable_params, lr=config["learning_rate"])
    scaler = GradScaler()
    seg_loss_fn = JaccardLoss().to(DEVICE)
    num_training_steps = len(train_loader) * config["num_epochs"]
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(0.1 * num_training_steps), num_training_steps=num_training_steps)

    best_val_metric, patience = 0.0, 0

    for epoch in range(config["num_epochs"]):
        multitask_model.train()
        total_loss = 0.0
        w_vqa, w_seg = task_pacer.get_weights()
        print(f"--- Epoch {epoch+1} ---")
        print(f"Starting with weights -> VQA: {w_vqa:.4f}, Seg: {w_seg:.4f}")

        for images, masks, questions, answers in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
            batch_cpu = build_training_batch_cpu_main(images, masks, questions, answers, processor)
            batch_gpu = {k: v.to(DEVICE) if torch.is_tensor(v) else v for k, v in batch_cpu.items()}
            optimizer.zero_grad(set_to_none=True)

            with autocast():
                outputs = multitask_model(**batch_gpu)
                vqa_loss = outputs["vqa_loss"]
                seg_logits = outputs["seg_logits"]
                seg_loss = seg_loss_fn(seg_logits, batch_gpu["seg_masks_gt"].squeeze(1))
                combined_loss = (w_vqa * vqa_loss) + (w_seg * seg_loss)

            scaler.scale(combined_loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(trainable_params, config["grad_clip_val"])
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            total_loss += combined_loss.item()

        print(f"\nEpoch {epoch+1} Avg Combined Training Loss -> {total_loss / len(train_loader):.4f}")

        # Validation and Weight Update
        val_acc, val_iou, val_vqa_loss, val_seg_loss = run_evaluation(multitask_model, processor, val_loader, DEVICE, "Validation Set Eval")
        task_pacer.update_weights([val_vqa_loss, val_seg_loss])

        current_metric = val_acc + (val_iou * 100)
        if current_metric > best_val_metric:
            print(f"  -> New best validation metric ({current_metric:.2f}). Saving model...")
            best_val_metric = current_metric
            patience = 0
            save_dir = config["save_path"]
            os.makedirs(save_dir, exist_ok=True)
            torch.save(multitask_model.seg_model.state_dict(), os.path.join(save_dir, "seg_model.pth"))
            torch.save(multitask_model.projection.state_dict(), os.path.join(save_dir, "projection.pth"))
            multitask_model.llava.save_pretrained(os.path.join(save_dir, "llava_lora"))
            processor.save_pretrained(os.path.join(save_dir, "processor"))
        else:
            patience += 1
            print(f"  -> No improvement for {patience} epoch(s).")
            if patience >= config["early_stopping_patience"]:
                print("\n--- Early stopping triggered. ---")
                break
        print("=" * 80)

    # Final Evaluation
    print("\nStep 5: Loading best model for final evaluation...")
    save_path = config["save_path"]
    if os.path.exists(os.path.join(save_path, "seg_model.pth")):
        final_base_model = LlavaForConditionalGeneration.from_pretrained(config["local_llava_path"], torch_dtype=torch.float16)
        final_peft_model = PeftModel.from_pretrained(final_base_model, os.path.join(save_path, "llava_lora"))
        final_multitask_model = LlavaWithSegmentationHead(final_peft_model).to(DEVICE)

        final_multitask_model.seg_model.load_state_dict(torch.load(os.path.join(save_path, "seg_model.pth")))
        final_multitask_model.projection.load_state_dict(torch.load(os.path.join(save_path, "projection.pth")))
        final_processor = AutoProcessor.from_pretrained(os.path.join(save_path, "processor"))

        run_evaluation(final_multitask_model, final_processor, test_loader, DEVICE, "Final Test Evaluation")
    else:
        print("No model was saved.")





Step 1: Gathering and splitting data...
Data split: 2488 train, 623 val, 778 test.

Step 2: Setting up multi-task model and processor...


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

LoRA target modules: ['down_proj', 'gate_proj', 'k_proj', 'linear_1', 'linear_2', 'o_proj', 'out_proj', 'q_proj', 'up_proj', 'v_proj']
trainable params: 86,671,360 || all params: 7,150,098,432 || trainable%: 1.2122

Step 3: Preparing DataLoaders...

Step 4: Starting multi-task fine-tuning with DWA-TP loss...
--- Epoch 1 ---
Starting with weights -> VQA: 1.0000, Seg: 1.0000


Training Epoch 1: 100%|█████████████████████| 1216/1216 [12:27<00:00,  1.63it/s]



Epoch 1 Avg Combined Training Loss -> 1.3090


Validation Set Eval: 100%|████████████████████| 304/304 [03:52<00:00,  1.31it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 48.11%
  - Segmentation IoU:        0.0550
  - Avg VQA Loss:            0.0717
  - Avg Segmentation Loss:   0.9734
----------------------------------------
  -> New best validation metric (53.61). Saving model...
--- Epoch 2 ---
Starting with weights -> VQA: 1.0000, Seg: 1.0000


Training Epoch 2: 100%|█████████████████████| 1216/1216 [12:24<00:00,  1.63it/s]



Epoch 2 Avg Combined Training Loss -> 1.0165


Validation Set Eval: 100%|████████████████████| 304/304 [03:50<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 52.55%
  - Segmentation IoU:        0.7625
  - Avg VQA Loss:            0.0696
  - Avg Segmentation Loss:   0.9134
----------------------------------------
  -> New best validation metric (128.80). Saving model...
--- Epoch 3 ---
Starting with weights -> VQA: 100.4278, Seg: 99.5722


Training Epoch 3: 100%|█████████████████████| 1216/1216 [12:23<00:00,  1.64it/s]



Epoch 3 Avg Combined Training Loss -> 94.8619


Validation Set Eval: 100%|████████████████████| 304/304 [03:51<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 65.40%
  - Segmentation IoU:        0.8548
  - Avg VQA Loss:            0.0693
  - Avg Segmentation Loss:   0.8345
----------------------------------------
  -> New best validation metric (150.89). Saving model...
--- Epoch 4 ---
Starting with weights -> VQA: 101.0608, Seg: 98.9392


Training Epoch 4: 100%|█████████████████████| 1216/1216 [12:16<00:00,  1.65it/s]



Epoch 4 Avg Combined Training Loss -> 88.4436


Validation Set Eval: 100%|████████████████████| 304/304 [03:50<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 68.70%
  - Segmentation IoU:        0.8726
  - Avg VQA Loss:            0.0773
  - Avg Segmentation Loss:   0.7880
----------------------------------------
  -> New best validation metric (155.96). Saving model...
--- Epoch 5 ---
Starting with weights -> VQA: 102.0826, Seg: 97.9174


Training Epoch 5: 100%|█████████████████████| 1216/1216 [12:04<00:00,  1.68it/s]



Epoch 5 Avg Combined Training Loss -> 82.1020


Validation Set Eval: 100%|████████████████████| 304/304 [03:50<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 73.48%
  - Segmentation IoU:        0.8971
  - Avg VQA Loss:            0.1423
  - Avg Segmentation Loss:   0.7300
----------------------------------------
  -> New best validation metric (163.18). Saving model...
--- Epoch 6 ---
Starting with weights -> VQA: 108.2411, Seg: 91.7589


Training Epoch 6: 100%|█████████████████████| 1216/1216 [11:45<00:00,  1.72it/s]



Epoch 6 Avg Combined Training Loss -> 71.1762


Validation Set Eval: 100%|████████████████████| 304/304 [03:48<00:00,  1.33it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 74.96%
  - Segmentation IoU:        0.9140
  - Avg VQA Loss:            0.1291
  - Avg Segmentation Loss:   0.6856
----------------------------------------
  -> New best validation metric (166.36). Saving model...
--- Epoch 7 ---
Starting with weights -> VQA: 99.5733, Seg: 100.4267


Training Epoch 7: 100%|█████████████████████| 1216/1216 [11:45<00:00,  1.72it/s]



Epoch 7 Avg Combined Training Loss -> 71.3534


Validation Set Eval: 100%|████████████████████| 304/304 [03:49<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 78.58%
  - Segmentation IoU:        0.9295
  - Avg VQA Loss:            0.1294
  - Avg Segmentation Loss:   0.6336
----------------------------------------
  -> New best validation metric (171.53). Saving model...
--- Epoch 8 ---
Starting with weights -> VQA: 101.0116, Seg: 98.9884


Training Epoch 8: 100%|█████████████████████| 1216/1216 [11:45<00:00,  1.72it/s]



Epoch 8 Avg Combined Training Loss -> 66.0589


Validation Set Eval: 100%|████████████████████| 304/304 [03:51<00:00,  1.31it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 79.74%
  - Segmentation IoU:        0.9344
  - Avg VQA Loss:            0.1383
  - Avg Segmentation Loss:   0.5954
----------------------------------------
  -> New best validation metric (173.18). Saving model...
--- Epoch 9 ---
Starting with weights -> VQA: 101.6101, Seg: 98.3899


Training Epoch 9: 100%|█████████████████████| 1216/1216 [11:38<00:00,  1.74it/s]



Epoch 9 Avg Combined Training Loss -> 63.3935


Validation Set Eval: 100%|████████████████████| 304/304 [03:51<00:00,  1.31it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 79.74%
  - Segmentation IoU:        0.9396
  - Avg VQA Loss:            0.1127
  - Avg Segmentation Loss:   0.5799
----------------------------------------
  -> New best validation metric (173.69). Saving model...
--- Epoch 10 ---
Starting with weights -> VQA: 97.7795, Seg: 102.2205


Training Epoch 10: 100%|████████████████████| 1216/1216 [11:29<00:00,  1.76it/s]



Epoch 10 Avg Combined Training Loss -> 61.2625


Validation Set Eval: 100%|████████████████████| 304/304 [03:50<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 79.24%
  - Segmentation IoU:        0.9432
  - Avg VQA Loss:            0.1171
  - Avg Segmentation Loss:   0.5634
----------------------------------------
  -> No improvement for 1 epoch(s).
--- Epoch 11 ---
Starting with weights -> VQA: 100.8315, Seg: 99.1685


Training Epoch 11: 100%|████████████████████| 1216/1216 [11:28<00:00,  1.77it/s]



Epoch 11 Avg Combined Training Loss -> 58.2641


Validation Set Eval: 100%|████████████████████| 304/304 [03:50<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 81.55%
  - Segmentation IoU:        0.9444
  - Avg VQA Loss:            0.1539
  - Avg Segmentation Loss:   0.5512
----------------------------------------
  -> New best validation metric (175.99). Saving model...
--- Epoch 12 ---
Starting with weights -> VQA: 103.6618, Seg: 96.3382


Training Epoch 12: 100%|████████████████████| 1216/1216 [11:28<00:00,  1.77it/s]



Epoch 12 Avg Combined Training Loss -> 55.3920


Validation Set Eval: 100%|████████████████████| 304/304 [03:49<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 79.90%
  - Segmentation IoU:        0.9415
  - Avg VQA Loss:            0.1676
  - Avg Segmentation Loss:   0.5469
----------------------------------------
  -> No improvement for 1 epoch(s).
--- Epoch 13 ---
Starting with weights -> VQA: 101.1620, Seg: 98.8380


Training Epoch 13: 100%|████████████████████| 1216/1216 [11:18<00:00,  1.79it/s]



Epoch 13 Avg Combined Training Loss -> 55.1151


Validation Set Eval: 100%|████████████████████| 304/304 [03:49<00:00,  1.33it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 84.02%
  - Segmentation IoU:        0.9418
  - Avg VQA Loss:            0.1319
  - Avg Segmentation Loss:   0.5412
----------------------------------------
  -> New best validation metric (178.20). Saving model...
--- Epoch 14 ---
Starting with weights -> VQA: 97.1525, Seg: 102.8475


Training Epoch 14: 100%|████████████████████| 1216/1216 [11:18<00:00,  1.79it/s]



Epoch 14 Avg Combined Training Loss -> 55.5868


Validation Set Eval: 100%|████████████████████| 304/304 [03:48<00:00,  1.33it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 82.87%
  - Segmentation IoU:        0.9495
  - Avg VQA Loss:            0.1417
  - Avg Segmentation Loss:   0.5266
----------------------------------------
  -> No improvement for 1 epoch(s).
--- Epoch 15 ---
Starting with weights -> VQA: 101.2371, Seg: 98.7629


Training Epoch 15: 100%|████████████████████| 1216/1216 [11:16<00:00,  1.80it/s]



Epoch 15 Avg Combined Training Loss -> 53.0654


Validation Set Eval: 100%|████████████████████| 304/304 [03:50<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 84.51%
  - Segmentation IoU:        0.9461
  - Avg VQA Loss:            0.1308
  - Avg Segmentation Loss:   0.5294
----------------------------------------
  -> New best validation metric (179.12). Saving model...
--- Epoch 16 ---
Starting with weights -> VQA: 98.9301, Seg: 101.0699


Training Epoch 16: 100%|████████████████████| 1216/1216 [11:11<00:00,  1.81it/s]



Epoch 16 Avg Combined Training Loss -> 53.4818


Validation Set Eval: 100%|████████████████████| 304/304 [03:51<00:00,  1.31it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 84.51%
  - Segmentation IoU:        0.9458
  - Avg VQA Loss:            0.1489
  - Avg Segmentation Loss:   0.5276
----------------------------------------
  -> No improvement for 1 epoch(s).
--- Epoch 17 ---
Starting with weights -> VQA: 101.6650, Seg: 98.3350


Training Epoch 17: 100%|████████████████████| 1216/1216 [11:15<00:00,  1.80it/s]



Epoch 17 Avg Combined Training Loss -> 52.2585


Validation Set Eval: 100%|████████████████████| 304/304 [04:12<00:00,  1.20it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 84.02%
  - Segmentation IoU:        0.9470
  - Avg VQA Loss:            0.1136
  - Avg Segmentation Loss:   0.5239
----------------------------------------
  -> No improvement for 2 epoch(s).
--- Epoch 18 ---
Starting with weights -> VQA: 96.7201, Seg: 103.2799


Training Epoch 18: 100%|████████████████████| 1216/1216 [11:34<00:00,  1.75it/s]



Epoch 18 Avg Combined Training Loss -> 54.2568


Validation Set Eval: 100%|████████████████████| 304/304 [03:53<00:00,  1.30it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 85.83%
  - Segmentation IoU:        0.9472
  - Avg VQA Loss:            0.1586
  - Avg Segmentation Loss:   0.5251
----------------------------------------
  -> New best validation metric (180.55). Saving model...
--- Epoch 19 ---
Starting with weights -> VQA: 104.1034, Seg: 95.8966


Training Epoch 19: 100%|████████████████████| 1216/1216 [11:12<00:00,  1.81it/s]



Epoch 19 Avg Combined Training Loss -> 50.1848


Validation Set Eval: 100%|████████████████████| 304/304 [03:50<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 85.01%
  - Segmentation IoU:        0.9501
  - Avg VQA Loss:            0.1608
  - Avg Segmentation Loss:   0.5188
----------------------------------------
  -> No improvement for 1 epoch(s).
--- Epoch 20 ---
Starting with weights -> VQA: 100.3239, Seg: 99.6761


Training Epoch 20: 100%|████████████████████| 1216/1216 [11:11<00:00,  1.81it/s]



Epoch 20 Avg Combined Training Loss -> 52.2466


Validation Set Eval: 100%|████████████████████| 304/304 [03:50<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 84.51%
  - Segmentation IoU:        0.9470
  - Avg VQA Loss:            0.1830
  - Avg Segmentation Loss:   0.5225
----------------------------------------
  -> No improvement for 2 epoch(s).
--- Epoch 21 ---
Starting with weights -> VQA: 101.5275, Seg: 98.4725


Training Epoch 21: 100%|████████████████████| 1216/1216 [11:09<00:00,  1.82it/s]



Epoch 21 Avg Combined Training Loss -> 50.6505


Validation Set Eval: 100%|████████████████████| 304/304 [03:50<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 84.02%
  - Segmentation IoU:        0.9460
  - Avg VQA Loss:            0.2132
  - Avg Segmentation Loss:   0.5224
----------------------------------------
  -> No improvement for 3 epoch(s).
--- Epoch 22 ---
Starting with weights -> VQA: 101.9062, Seg: 98.0938


Training Epoch 22: 100%|████████████████████| 1216/1216 [11:07<00:00,  1.82it/s]



Epoch 22 Avg Combined Training Loss -> 49.8845


Validation Set Eval: 100%|████████████████████| 304/304 [03:48<00:00,  1.33it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 85.01%
  - Segmentation IoU:        0.9469
  - Avg VQA Loss:            0.1906
  - Avg Segmentation Loss:   0.5217
----------------------------------------
  -> No improvement for 4 epoch(s).
--- Epoch 23 ---
Starting with weights -> VQA: 98.6186, Seg: 101.3814


Training Epoch 23: 100%|████████████████████| 1216/1216 [11:07<00:00,  1.82it/s]



Epoch 23 Avg Combined Training Loss -> 51.2884


Validation Set Eval: 100%|████████████████████| 304/304 [03:50<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 86.66%
  - Segmentation IoU:        0.9513
  - Avg VQA Loss:            0.1657
  - Avg Segmentation Loss:   0.5159
----------------------------------------
  -> New best validation metric (181.79). Saving model...
--- Epoch 24 ---
Starting with weights -> VQA: 98.3895, Seg: 101.6105


Training Epoch 24: 100%|████████████████████| 1216/1216 [11:19<00:00,  1.79it/s]



Epoch 24 Avg Combined Training Loss -> 51.1916


Validation Set Eval: 100%|████████████████████| 304/304 [03:53<00:00,  1.30it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 86.66%
  - Segmentation IoU:        0.9470
  - Avg VQA Loss:            0.1814
  - Avg Segmentation Loss:   0.5212
----------------------------------------
  -> No improvement for 1 epoch(s).
--- Epoch 25 ---
Starting with weights -> VQA: 101.0095, Seg: 98.9905


Training Epoch 25: 100%|████████████████████| 1216/1216 [11:15<00:00,  1.80it/s]



Epoch 25 Avg Combined Training Loss -> 51.4321


Validation Set Eval: 100%|████████████████████| 304/304 [03:49<00:00,  1.32it/s]


--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA): 86.49%
  - Segmentation IoU:        0.9489
  - Avg VQA Loss:            0.1792
  - Avg Segmentation Loss:   0.5187
----------------------------------------
  -> No improvement for 2 epoch(s).

Step 5: Loading best model for final evaluation...





Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Final Test Evaluation: 100%|██████████████████| 382/382 [04:49<00:00,  1.32it/s]


--- Results for Final Test Evaluation ---
  - VLM Grade Accuracy (QA): 83.36%
  - Segmentation IoU:        0.9429
  - Avg VQA Loss:            0.2025
  - Avg Segmentation Loss:   0.4687
----------------------------------------



