In [None]:


import os
import glob
import math
import re
import random
import logging
import warnings
from typing import List, Tuple, Dict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from PIL import Image
import torchvision.transforms as transforms
from tqdm import tqdm
import cv2


import segmentation_models_pytorch as smp

from peft import LoraConfig, get_peft_model, PeftModel
from transformers import (
    AutoProcessor,
    LlavaForConditionalGeneration,
    CLIPVisionModel,
    get_linear_schedule_with_warmup,
)
from torch.cuda.amp import GradScaler, autocast


warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
torch.backends.cuda.matmul.allow_tf32 = True


class JaccardLoss(nn.Module):

    def __init__(self, smooth=1e-6):
        super(JaccardLoss, self).__init__()
        self.smooth = smooth

    def forward(self, y_pred, y_true):

        y_pred_probs = torch.sigmoid(y_pred)


        y_pred_flat = y_pred_probs.view(-1)
        y_true_flat = y_true.view(-1)


        intersection = (y_pred_flat * y_true_flat).sum()
        total = (y_pred_flat + y_true_flat).sum()
        union = total - intersection


        iou = (intersection + self.smooth) / (union + self.smooth)
        return 1 - iou


class VLM_QASegDataset(Dataset):
    def __init__(self, image_paths: List[str], metadata_df: pd.DataFrame, is_train: bool = True):
        self.image_paths: List[str] = []
        self.mask_paths: List[str] = []
        self.questions: List[str] = []
        self.answers: List[str] = []
        self.is_train = is_train


        if self.is_train:

            self.image_transform = transforms.Compose([
                transforms.Resize((336, 336)),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
            ])
        else:

            self.image_transform = transforms.Compose([
                transforms.Resize((336, 336)),
            ])

        self.mask_transform = transforms.Compose([
            transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor()
        ])

        mdx = metadata_df.set_index("Patient")

        for img_path in tqdm(image_paths, desc="Initializing and Filtering Dataset"):
            mask_path = img_path.replace(".tif", "_mask.tif")
            if not os.path.exists(mask_path):
                continue


            try:
                mask_check_img = Image.open(mask_path).convert("L")
                mask_check_np = np.array(mask_check_img)
            except Exception as e:
                print(f"Warning: Could not read {mask_path}. Skipping. Error: {e}")
                continue


            if not np.any(mask_check_np > 0):
                continue


            pid_folder = os.path.basename(os.path.dirname(img_path))
            pid_key = "_".join(pid_folder.split("_")[0:3])
            if pid_key in mdx.index:
                row = mdx.loc[[pid_key]].iloc[0]
                grade = row.get("neoplasm_histologic_grade")
                if pd.notna(grade) and int(grade) in [1, 2]:
                    self.image_paths.append(img_path)
                    self.mask_paths.append(mask_path)
                    q = "First, identify the tumor in the image. Second, what is its histologic grade: one or two?"
                    a = f"The grade of the tumor is {'two' if int(grade) == 2 else 'one'}."
                    self.questions.append(q)
                    self.answers.append(a)

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int):

        image = Image.open(self.image_paths[idx]).convert("RGB")
        mask = Image.open(self.mask_paths[idx]).convert("L")


        image_np = np.array(image)
        mask_np = (np.array(mask) > 0).astype(np.uint8)
        contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(image_np, contours, -1, (0, 255, 255), thickness=2)
        image = Image.fromarray(image_np)


        image = self.image_transform(image)
        mask_tensor = self.mask_transform(mask)
        mask_tensor = (mask_tensor > 0).float()

        return image, mask_tensor, self.questions[idx], self.answers[idx]


def vlm_collate_fn_for_training(batch):
    images, masks, questions, answers = zip(*batch)
    masks_tensor = torch.stack(masks)
    return list(images), masks_tensor, list(questions), list(answers)

def vlm_collate_fn_for_evaluation(batch):
    images, masks, questions, answers = zip(*batch)
    masks_tensor = torch.stack(masks)
    return list(images), masks_tensor, list(questions), list(answers)


def build_training_batch_cpu_main(images, masks, questions, answers, processor: AutoProcessor):
    prompts = [f"USER: <image>\n{q}\nASSISTANT:" for q in questions]
    full_texts = [
        f"USER: <image>\n{q}\nASSISTANT: {a}{processor.tokenizer.eos_token}"
        for q, a in zip(questions, answers)
    ]

    toks_prompt = processor(text=prompts, images=images, return_tensors="pt", padding=True)
    toks_full = processor(text=full_texts, images=images, return_tensors="pt", padding=True)

    labels = toks_full.input_ids.clone()
    prompt_lens = torch.sum(toks_prompt.attention_mask, dim=1)
    for i in range(labels.size(0)):
        labels[i, : prompt_lens[i]] = -100
    labels[labels == processor.tokenizer.pad_token_id] = -100

    batch_cpu = {
        "input_ids": toks_full.input_ids,
        "pixel_values": toks_full.pixel_values,
        "attention_mask": toks_full.attention_mask,
        "labels": labels,
        "seg_masks_gt": masks,
    }
    return batch_cpu


class LlavaWithSegmentationHead(nn.Module):
    def __init__(self, llava_model):
        super().__init__()
        self.llava = llava_model
        self.vision_tower = self.llava.vision_tower
        self.seg_model = smp.DeepLabV3Plus(
            encoder_name="resnet34",
            encoder_weights=None,
            in_channels=3,
            classes=1,
        )
        smp_encoder_channels = self.seg_model.encoder.out_channels
        self.projection = nn.ModuleList([
            nn.Conv2d(1024, smp_encoder_channels[1], kernel_size=1),
            nn.Conv2d(1024, smp_encoder_channels[2], kernel_size=1),
            nn.Conv2d(1024, smp_encoder_channels[3], kernel_size=1),
            nn.Conv2d(1024, smp_encoder_channels[4], kernel_size=1),
            nn.Conv2d(1024, smp_encoder_channels[5], kernel_size=1),
        ])

    def forward(self, input_ids, pixel_values, attention_mask, labels=None, **kwargs):
        image_features = self.vision_tower(pixel_values, output_hidden_states=True)
        image_features_grid_with_cls = image_features.hidden_states[-1]
        image_features_grid = image_features_grid_with_cls[:, 1:, :]

        batch_size = image_features_grid.shape[0]
        patch_grid_size = int(math.sqrt(image_features_grid.shape[1]))
        hidden_size = image_features_grid.shape[2]
        seg_features = image_features_grid.reshape(batch_size, patch_grid_size, patch_grid_size, hidden_size)
        seg_features = seg_features.permute(0, 3, 1, 2).contiguous()

        projected_features = [proj(seg_features) for proj in self.projection]
        scaled_projected_features = list(projected_features)
        scaled_projected_features[1] = F.interpolate(
            scaled_projected_features[1], scale_factor=4, mode='bilinear', align_corners=False
        )
        decoder_features = [None] + scaled_projected_features
        decoder_output = self.seg_model.decoder(decoder_features)
        seg_logits = self.seg_model.segmentation_head(decoder_output)
        seg_logits = F.interpolate(seg_logits, size=(336, 336), mode='bilinear', align_corners=False)

        vqa_output = self.llava(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True
        )

        return {
            "vqa_loss": vqa_output.loss,
            "vqa_logits": vqa_output.logits,
            "seg_logits": seg_logits.squeeze(1)
        }


def compute_iou(pred_mask, true_mask, threshold=0.5):
    with torch.no_grad():
        pred_mask = (torch.sigmoid(pred_mask) > threshold).float()
        true_mask = true_mask.float()
        intersection = (pred_mask * true_mask).sum(dim=(1, 2))
        union = pred_mask.sum(dim=(1, 2)) + true_mask.sum(dim=(1, 2)) - intersection
        iou = (intersection + 1e-6) / (union + 1e-6)
        return iou.mean().item()

def run_evaluation(model, processor, data_loader: DataLoader, device, description="Evaluating"):
    model.eval()
    total_samples, vlm_correct, total_loss_count = 0, 0, 0
    total_vqa_loss_sum, total_seg_loss_sum, total_iou = 0.0, 0.0, 0.0
    seg_loss_fn = JaccardLoss().to(device)

    with torch.no_grad():
        for batch in tqdm(data_loader, desc=description):
            images, masks_gt, questions, answers = batch
            masks_gt = masks_gt.to(device)
            prompts = [f"USER: <image>\n{q}\nASSISTANT:" for q in questions]

            with autocast():
                gen_inputs = processor(text=prompts, images=images, return_tensors="pt", padding=True).to(device)
                generated_ids = model.llava.generate(
                    **gen_inputs, max_new_tokens=20, pad_token_id=processor.tokenizer.pad_token_id
                )
            decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)

            for i in range(len(decoded)):
                pred_span = decoded[i].split("ASSISTANT:")[-1].strip().lower()
                true_span = answers[i].lower()
                want_two = "two" in true_span
                has_one = "one" in pred_span or "1" in pred_span
                has_two = "two" in pred_span or "2" in pred_span
                ok = (want_two and has_two and not has_one) or ((not want_two) and has_one and not has_two)
                if ok:
                    vlm_correct += 1

            batch_cpu = build_training_batch_cpu_main(images, masks_gt.cpu(), questions, answers, processor)
            batch_gpu = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch_cpu.items()}

            with autocast():
                outputs = model(**batch_gpu)
                vqa_loss, seg_logits = outputs["vqa_loss"], outputs["seg_logits"]
                seg_loss = seg_loss_fn(seg_logits, batch_gpu["seg_masks_gt"].squeeze(1))

            if vqa_loss is not None: total_vqa_loss_sum += vqa_loss.item()
            if seg_loss is not None: total_seg_loss_sum += seg_loss.item()
            total_loss_count += 1
            total_iou += compute_iou(seg_logits, batch_gpu["seg_masks_gt"].squeeze(1))
            total_samples += len(answers)

    vlm_acc = (vlm_correct / total_samples) * 100 if total_samples else 0.0
    avg_vqa_loss = total_vqa_loss_sum / total_loss_count if total_loss_count else float("inf")
    avg_seg_loss = total_seg_loss_sum / total_loss_count if total_loss_count else float("inf")
    avg_iou = total_iou / total_loss_count if total_loss_count else 0.0
    ppl = math.exp(avg_vqa_loss) if avg_vqa_loss < 50 else float("inf")

    print(f"\n--- Results for {description} ---")
    print(f"  - VLM Grade Accuracy (QA):      {vlm_acc:.2f}%")
    print(f"  - Perplexity (teacher-forced): {ppl:.4f}")
    print(f"  - Segmentation IoU:             {avg_iou:.4f}")
    print(f"  - Avg Segmentation Loss:        {avg_seg_loss:.4f}")
    print("-" * 40)
    return vlm_acc, avg_iou


def discover_lora_targets(llava_model, include_vision: bool = True) -> List[str]:
    text_keys = {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"}
    projector_keys = {"multi_modal_projector"}
    vision_keys = {"q_proj", "k_proj", "v_proj", "out_proj"}
    target_modules: set[str] = set()
    for name, module in llava_model.named_modules():
        if any(k in name for k in text_keys) and "language_model" in name:
            target_modules.add(name.split(".")[-1])
        if any(k in name for k in projector_keys):
            if hasattr(module, "weight") and getattr(module, "weight", None) is not None:
                target_modules.add(name.split(".")[-1])
        if include_vision and ("vision_tower" in name) and any(k in name for k in vision_keys):
            target_modules.add(name.split(".")[-1])
    return sorted(list(target_modules))


if __name__ == "__main__":
    config = {
        "device": "cuda:0" if torch.cuda.is_available() else "cpu",
        "base_path": "/home/ealam/Downloads/LGG dataset Cameron/lgg-mri-segmentation/kaggle_3m",
        "local_llava_path": "/home/ealam/Desktop/llava-1.5-7b-local",
        "save_path": "./llava-lora-multitask-delineated-augmented-filtered",
        "csv_path": "/home/ealam/Downloads/LGG dataset Cameron/lgg-mri-segmentation/kaggle_3m/data.csv",
        "learning_rate": 1e-5,
        "batch_size": 2,
        "num_epochs": 25,
        "early_stopping_patience": 5,
        "seed": 42,
        "include_vision_lora": True,
        "seg_loss_weight": 0.25,
        "num_workers": 4,
        "grad_clip_val": 1.0,
    }


    torch.manual_seed(config["seed"])
    np.random.seed(config["seed"])
    random.seed(config["seed"])


    print("Step 1: Gathering and splitting data...")
    all_image_paths = [p.replace("_mask.tif", ".tif") for p in glob.glob(os.path.join(config["base_path"], "*", "*_mask.tif"))]
    all_image_paths = [p for p in all_image_paths if os.path.exists(p)]
    usable_paths, _ = train_test_split(all_image_paths, test_size=0.01, random_state=config["seed"])
    train_val_paths, test_paths = train_test_split(usable_paths, test_size=0.20, random_state=config["seed"])
    train_paths, val_paths = train_test_split(train_val_paths, test_size=0.20, random_state=config["seed"])


    print("\nStep 2: Setting up multi-task model and processor...")
    DEVICE = config["device"]
    base_model = LlavaForConditionalGeneration.from_pretrained(
        config["local_llava_path"], torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
    processor = AutoProcessor.from_pretrained(config["local_llava_path"])
    if processor.tokenizer.pad_token is None:
        processor.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        base_model.resize_token_embeddings(len(processor.tokenizer))

    target_modules = discover_lora_targets(base_model, include_vision=config["include_vision_lora"])
    print("LoRA target modules:", target_modules)
    lora_cfg = LoraConfig(
        r=32, lora_alpha=64, target_modules=target_modules,
        lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
    )
    peft_model = get_peft_model(base_model, lora_cfg)
    multitask_model = LlavaWithSegmentationHead(peft_model).to(DEVICE)
    peft_model.print_trainable_parameters()


    print("\nStep 3: Preparing DataLoaders...")
    metadata_df = pd.read_csv(config["csv_path"])
    train_ds = VLM_QASegDataset(train_paths, metadata_df, is_train=True)
    val_ds = VLM_QASegDataset(val_paths, metadata_df, is_train=False)
    test_ds = VLM_QASegDataset(test_paths, metadata_df, is_train=False)

    print(f"\nData splits after filtering: {len(train_ds)} training, {len(val_ds)} validation, and {len(test_ds)} test samples.")


    train_loader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True, num_workers=config["num_workers"], collate_fn=vlm_collate_fn_for_training, drop_last=True)


    val_loader = DataLoader(val_ds, batch_size=config["batch_size"], shuffle=False, num_workers=config["num_workers"], collate_fn=vlm_collate_fn_for_evaluation)
    test_loader = DataLoader(test_ds, batch_size=config["batch_size"], shuffle=False, num_workers=config["num_workers"], collate_fn=vlm_collate_fn_for_evaluation)


    print("\nStep 4: Starting multi-task fine-tuning...")
    trainable_params = [p for p in multitask_model.parameters() if p.requires_grad]
    optimizer = AdamW(trainable_params, lr=config["learning_rate"])
    scaler = GradScaler()
    seg_loss_fn = JaccardLoss().to(DEVICE)

    num_training_steps = len(train_loader) * config["num_epochs"]
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=int(0.1 * num_training_steps), num_training_steps=num_training_steps
    )

    best_val_metric, patience = 0.0, 0
    for epoch in range(config["num_epochs"]):
        multitask_model.train()
        total_loss = 0.0
        for images, masks, questions, answers in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
            batch_cpu = build_training_batch_cpu_main(images, masks, questions, answers, processor)
            batch_gpu = {k: v.to(DEVICE) if torch.is_tensor(v) else v for k, v in batch_cpu.items()}
            optimizer.zero_grad(set_to_none=True)

            with autocast():
                outputs = multitask_model(**batch_gpu)
                vqa_loss = outputs["vqa_loss"]
                seg_logits = outputs["seg_logits"]
                seg_loss = seg_loss_fn(seg_logits, batch_gpu["seg_masks_gt"].squeeze(1))
                combined_loss = vqa_loss + config["seg_loss_weight"] * seg_loss

            scaler.scale(combined_loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(trainable_params, config["grad_clip_val"])
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            total_loss += combined_loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"\nEpoch {epoch+1} Avg Combined Loss -> {avg_loss:.4f}")

        val_acc, val_iou = run_evaluation(multitask_model, processor, val_loader, DEVICE, "Validation Set Eval")
        current_metric = val_acc + (val_iou * 100)

        if current_metric > best_val_metric:
            print(f"  -> New best validation metric ({current_metric:.2f}). Saving model...")
            best_val_metric = current_metric
            patience = 0
            save_dir = config["save_path"]
            os.makedirs(save_dir, exist_ok=True)
            torch.save(multitask_model.seg_model.state_dict(), os.path.join(save_dir, "seg_model.pth"))
            torch.save(multitask_model.projection.state_dict(), os.path.join(save_dir, "projection.pth"))
            multitask_model.llava.save_pretrained(os.path.join(save_dir, "llava_lora"))
            processor.save_pretrained(os.path.join(save_dir, "processor"))
        else:
            patience += 1
            print(f"  -> No improvement for {patience} epoch(s).")
            if patience >= config["early_stopping_patience"]:
                print("\n--- Early stopping triggered. ---")
                break
        print("=" * 80)


    print("\nStep 5: Loading best model for final evaluation...")
    save_path = config["save_path"]
    if os.path.exists(os.path.join(save_path, "seg_model.pth")):
        final_base_model = LlavaForConditionalGeneration.from_pretrained(config["local_llava_path"], torch_dtype=torch.float16)
        final_peft_model = PeftModel.from_pretrained(final_base_model, os.path.join(save_path, "llava_lora"))
        final_multitask_model = LlavaWithSegmentationHead(final_peft_model).to(DEVICE)
        final_multitask_model.seg_model.load_state_dict(torch.load(os.path.join(save_path, "seg_model.pth")))
        final_multitask_model.projection.load_state_dict(torch.load(os.path.join(save_path, "projection.pth")))
        final_processor = AutoProcessor.from_pretrained(os.path.join(save_path, "processor"))

        run_evaluation(final_multitask_model, final_processor, test_loader, DEVICE, "Final Test Evaluation")
    else:
        print("No model was saved.")



Step 1: Gathering and splitting data...

Step 2: Setting up multi-task model and processor...


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

LoRA target modules: ['down_proj', 'gate_proj', 'k_proj', 'linear_1', 'linear_2', 'o_proj', 'out_proj', 'q_proj', 'up_proj', 'v_proj']
trainable params: 86,671,360 || all params: 7,150,098,432 || trainable%: 1.2122

Step 3: Preparing DataLoaders...


Initializing and Filtering Dataset: 100%|█| 2488/2488 [00:01<00:00, 1926.00it/s]
Initializing and Filtering Dataset: 100%|███| 623/623 [00:00<00:00, 1793.23it/s]
Initializing and Filtering Dataset: 100%|███| 778/778 [00:00<00:00, 1965.01it/s]



Data splits after filtering: 833 training, 214 validation, and 283 test samples.

Step 4: Starting multi-task fine-tuning...


Training Epoch 1: 100%|███████████████████████| 416/416 [04:17<00:00,  1.62it/s]



Epoch 1 Avg Combined Loss -> 0.8729


Validation Set Eval: 100%|████████████████████| 107/107 [01:21<00:00,  1.31it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      47.20%
  - Perplexity (teacher-forced): 1.0813
  - Segmentation IoU:             0.1256
  - Avg Segmentation Loss:        0.9323
----------------------------------------
  -> New best validation metric (59.76). Saving model...


Training Epoch 2: 100%|███████████████████████| 416/416 [04:16<00:00,  1.62it/s]



Epoch 2 Avg Combined Loss -> 0.2969


Validation Set Eval: 100%|████████████████████| 107/107 [01:21<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      52.34%
  - Perplexity (teacher-forced): 1.0714
  - Segmentation IoU:             0.2860
  - Avg Segmentation Loss:        0.8813
----------------------------------------
  -> New best validation metric (80.94). Saving model...


Training Epoch 3: 100%|███████████████████████| 416/416 [04:16<00:00,  1.62it/s]



Epoch 3 Avg Combined Loss -> 0.2711


Validation Set Eval: 100%|████████████████████| 107/107 [01:21<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      69.16%
  - Perplexity (teacher-forced): 1.0627
  - Segmentation IoU:             0.5370
  - Avg Segmentation Loss:        0.7794
----------------------------------------
  -> New best validation metric (122.85). Saving model...


Training Epoch 4: 100%|███████████████████████| 416/416 [04:14<00:00,  1.64it/s]



Epoch 4 Avg Combined Loss -> 0.2472


Validation Set Eval: 100%|████████████████████| 107/107 [01:20<00:00,  1.33it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      70.09%
  - Perplexity (teacher-forced): 1.0677
  - Segmentation IoU:             0.6167
  - Avg Segmentation Loss:        0.7027
----------------------------------------
  -> New best validation metric (131.77). Saving model...


Training Epoch 5: 100%|███████████████████████| 416/416 [04:12<00:00,  1.65it/s]



Epoch 5 Avg Combined Loss -> 0.2153


Validation Set Eval: 100%|████████████████████| 107/107 [01:20<00:00,  1.33it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      75.70%
  - Perplexity (teacher-forced): 1.0593
  - Segmentation IoU:             0.6722
  - Avg Segmentation Loss:        0.6364
----------------------------------------
  -> New best validation metric (142.92). Saving model...


Training Epoch 6: 100%|███████████████████████| 416/416 [04:08<00:00,  1.67it/s]



Epoch 6 Avg Combined Loss -> 0.1930


Validation Set Eval: 100%|████████████████████| 107/107 [01:21<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      78.97%
  - Perplexity (teacher-forced): 1.0774
  - Segmentation IoU:             0.7188
  - Avg Segmentation Loss:        0.5657
----------------------------------------
  -> New best validation metric (150.85). Saving model...


Training Epoch 7: 100%|███████████████████████| 416/416 [04:05<00:00,  1.70it/s]



Epoch 7 Avg Combined Loss -> 0.1709


Validation Set Eval: 100%|████████████████████| 107/107 [01:20<00:00,  1.33it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      83.18%
  - Perplexity (teacher-forced): 1.0721
  - Segmentation IoU:             0.7397
  - Avg Segmentation Loss:        0.5091
----------------------------------------
  -> New best validation metric (157.15). Saving model...


Training Epoch 8: 100%|███████████████████████| 416/416 [04:03<00:00,  1.71it/s]



Epoch 8 Avg Combined Loss -> 0.1549


Validation Set Eval: 100%|████████████████████| 107/107 [01:20<00:00,  1.33it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      78.50%
  - Perplexity (teacher-forced): 1.0952
  - Segmentation IoU:             0.7629
  - Avg Segmentation Loss:        0.4512
----------------------------------------
  -> No improvement for 1 epoch(s).


Training Epoch 9: 100%|███████████████████████| 416/416 [04:02<00:00,  1.71it/s]



Epoch 9 Avg Combined Loss -> 0.1408


Validation Set Eval: 100%|████████████████████| 107/107 [01:21<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      72.43%
  - Perplexity (teacher-forced): 1.1660
  - Segmentation IoU:             0.7722
  - Avg Segmentation Loss:        0.4149
----------------------------------------
  -> No improvement for 2 epoch(s).


Training Epoch 10: 100%|██████████████████████| 416/416 [04:00<00:00,  1.73it/s]



Epoch 10 Avg Combined Loss -> 0.1255


Validation Set Eval: 100%|████████████████████| 107/107 [01:21<00:00,  1.31it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      85.51%
  - Perplexity (teacher-forced): 1.0798
  - Segmentation IoU:             0.7810
  - Avg Segmentation Loss:        0.3727
----------------------------------------
  -> New best validation metric (163.62). Saving model...


Training Epoch 11: 100%|██████████████████████| 416/416 [03:58<00:00,  1.74it/s]



Epoch 11 Avg Combined Loss -> 0.1072


Validation Set Eval: 100%|████████████████████| 107/107 [01:21<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      77.57%
  - Perplexity (teacher-forced): 1.1266
  - Segmentation IoU:             0.7913
  - Avg Segmentation Loss:        0.3367
----------------------------------------
  -> No improvement for 1 epoch(s).


Training Epoch 12: 100%|██████████████████████| 416/416 [03:59<00:00,  1.73it/s]



Epoch 12 Avg Combined Loss -> 0.1057


Validation Set Eval: 100%|████████████████████| 107/107 [01:21<00:00,  1.31it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      86.45%
  - Perplexity (teacher-forced): 1.0923
  - Segmentation IoU:             0.7984
  - Avg Segmentation Loss:        0.3125
----------------------------------------
  -> New best validation metric (166.29). Saving model...


Training Epoch 13: 100%|██████████████████████| 416/416 [03:56<00:00,  1.76it/s]



Epoch 13 Avg Combined Loss -> 0.0951


Validation Set Eval: 100%|████████████████████| 107/107 [01:21<00:00,  1.31it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      85.51%
  - Perplexity (teacher-forced): 1.1089
  - Segmentation IoU:             0.8055
  - Avg Segmentation Loss:        0.2885
----------------------------------------
  -> No improvement for 1 epoch(s).


Training Epoch 14: 100%|██████████████████████| 416/416 [03:56<00:00,  1.76it/s]



Epoch 14 Avg Combined Loss -> 0.0916


Validation Set Eval: 100%|████████████████████| 107/107 [01:21<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      84.11%
  - Perplexity (teacher-forced): 1.1103
  - Segmentation IoU:             0.8070
  - Avg Segmentation Loss:        0.2770
----------------------------------------
  -> No improvement for 2 epoch(s).


Training Epoch 15: 100%|██████████████████████| 416/416 [03:56<00:00,  1.76it/s]



Epoch 15 Avg Combined Loss -> 0.0844


Validation Set Eval: 100%|████████████████████| 107/107 [01:20<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      86.92%
  - Perplexity (teacher-forced): 1.1128
  - Segmentation IoU:             0.8130
  - Avg Segmentation Loss:        0.2565
----------------------------------------
  -> New best validation metric (168.22). Saving model...


Training Epoch 16: 100%|██████████████████████| 416/416 [03:55<00:00,  1.76it/s]



Epoch 16 Avg Combined Loss -> 0.0797


Validation Set Eval: 100%|████████████████████| 107/107 [01:21<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      88.32%
  - Perplexity (teacher-forced): 1.1051
  - Segmentation IoU:             0.8187
  - Avg Segmentation Loss:        0.2468
----------------------------------------
  -> New best validation metric (170.18). Saving model...


Training Epoch 17: 100%|██████████████████████| 416/416 [03:54<00:00,  1.77it/s]



Epoch 17 Avg Combined Loss -> 0.0759


Validation Set Eval: 100%|████████████████████| 107/107 [01:20<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      76.17%
  - Perplexity (teacher-forced): 1.1944
  - Segmentation IoU:             0.8212
  - Avg Segmentation Loss:        0.2373
----------------------------------------
  -> No improvement for 1 epoch(s).


Training Epoch 18: 100%|██████████████████████| 416/416 [03:55<00:00,  1.77it/s]



Epoch 18 Avg Combined Loss -> 0.0698


Validation Set Eval: 100%|████████████████████| 107/107 [01:20<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      84.58%
  - Perplexity (teacher-forced): 1.1218
  - Segmentation IoU:             0.8248
  - Avg Segmentation Loss:        0.2294
----------------------------------------
  -> No improvement for 2 epoch(s).


Training Epoch 19: 100%|██████████████████████| 416/416 [03:55<00:00,  1.77it/s]



Epoch 19 Avg Combined Loss -> 0.0702


Validation Set Eval: 100%|████████████████████| 107/107 [01:21<00:00,  1.31it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      85.51%
  - Perplexity (teacher-forced): 1.1273
  - Segmentation IoU:             0.8289
  - Avg Segmentation Loss:        0.2195
----------------------------------------
  -> No improvement for 3 epoch(s).


Training Epoch 20: 100%|██████████████████████| 416/416 [03:55<00:00,  1.77it/s]



Epoch 20 Avg Combined Loss -> 0.0652


Validation Set Eval: 100%|████████████████████| 107/107 [01:21<00:00,  1.32it/s]



--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      85.51%
  - Perplexity (teacher-forced): 1.1243
  - Segmentation IoU:             0.8295
  - Avg Segmentation Loss:        0.2158
----------------------------------------
  -> No improvement for 4 epoch(s).


Training Epoch 21: 100%|██████████████████████| 416/416 [03:55<00:00,  1.77it/s]



Epoch 21 Avg Combined Loss -> 0.0652


Validation Set Eval: 100%|████████████████████| 107/107 [01:21<00:00,  1.32it/s]


--- Results for Validation Set Eval ---
  - VLM Grade Accuracy (QA):      85.98%
  - Perplexity (teacher-forced): 1.1099
  - Segmentation IoU:             0.8334
  - Avg Segmentation Loss:        0.2091
----------------------------------------
  -> No improvement for 5 epoch(s).

--- Early stopping triggered. ---

Step 5: Loading best model for final evaluation...





Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Final Test Evaluation: 100%|██████████████████| 142/142 [01:47<00:00,  1.31it/s]


--- Results for Final Test Evaluation ---
  - VLM Grade Accuracy (QA):      84.45%
  - Perplexity (teacher-forced): 1.1266
  - Segmentation IoU:             0.8025
  - Avg Segmentation Loss:        0.2650
----------------------------------------



