In [1]:
import torch
import torchvision

print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)

import torchvision.transforms as T
import torchvision.models as models

print("‚úÖ torchvision fully loaded")


torch: 2.7.1+cu126
torchvision: 0.22.1+cu126
‚úÖ torchvision fully loaded


In [2]:
import os
import json
import math
import random
from typing import List, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as T
import torchvision.models as models

from transformers import AutoTokenizer, AutoModel



In [3]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MAX_REPORT_LEN = 120
EMBED_DIM = 256
TEXT_DIM = 768
KG_DIM = 256
HIDDEN_DIM = 512

BATCH_SIZE = 4
# NUM_WORKERS = 4
NUM_WORKERS = 0

print("Using device:", DEVICE)

Using device: cuda


In [4]:
import torch

print("CUDA available:", torch.cuda.is_available())
print("CUDA version:", torch.version.cuda)
print("Device count:", torch.cuda.device_count())


CUDA available: True
CUDA version: 12.6
Device count: 1


In [5]:
image_transform = T.Compose([
    T.Resize((256, 256)),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

In [6]:
LOCATION_TOKENS = {
    "Head": {
        "bos": "<HEAD_BOS>",
        "eos": "<HEAD_EOS>"
    },
    "Thorax": {
        "bos": "<THORAX_BOS>",
        "eos": "<THORAX_EOS>"
    },
    "Abdomen": {
        "bos": "<ABDOMEN_BOS>",
        "eos": "<ABDOMEN_EOS>"
    },
    "Spine and Muscles": {
        "bos": "<SPINE_BOS>",
        "eos": "<SPINE_EOS>"
    },
    "Reproductive and Urinary System": {
        "bos": "<GU_BOS>",
        "eos": "<GU_EOS>"
    }
}


In [7]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

special_tokens = {
    "pad_token": "<PAD>",
    "additional_special_tokens": []
}

for loc in LOCATION_TOKENS:
    special_tokens["additional_special_tokens"].append(
        LOCATION_TOKENS[loc]["bos"]
    )
    special_tokens["additional_special_tokens"].append(
        LOCATION_TOKENS[loc]["eos"]
    )

tokenizer.add_special_tokens(special_tokens)

VOCAB_SIZE = len(tokenizer)
print("Vocab size:", VOCAB_SIZE)


Vocab size: 30533


In [8]:
import ast
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset


class MedPixDataset(Dataset):
    def __init__(self, csv_path, transform=None):
        self.df = pd.read_csv(csv_path)
        self.transform = transform

        # Normalize NaNs early (VERY IMPORTANT)
        self.df = self.df.fillna("")

    def __len__(self):
        return len(self.df)

    def load_image(self, img_path):
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img

    def parse_image_list(self, s):
        # CSV stores lists as strings: "['path1', 'path2']"
        if s == "" or s == "[]":
            return []
        return ast.literal_eval(s)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # -------- Images --------
        ct_images = []
        mri_images = []

        ct_paths = self.parse_image_list(row["CT_image_paths"])
        mri_paths = self.parse_image_list(row["MRI_image_paths"])

        for p in ct_paths:
            ct_images.append(self.load_image(p))

        for p in mri_paths:
            mri_images.append(self.load_image(p))

        # -------- Text Encoder Input --------
        text_input = row["combined_text"]

        # -------- Target Report --------
        report = row["findings"]

        # -------- Location (KG routing) --------
        location_category = row["Location Category"]

        return {
            "uid": row["U_id"],
            "ct_images": ct_images,     # list[Tensor]
            "mri_images": mri_images,   # list[Tensor]
            "text_input": text_input,   # str
            "report": report,           # str
            "location": location_category
        }


In [9]:
def collate_fn(batch):
    """
    Batch items contain:
    - ct_images: list[Tensor]
    - mri_images: list[Tensor]
    - text_input: str
    - report: str
    - location: str
    """

    # ---------- CT images ----------
    max_ct = max(len(b["ct_images"]) for b in batch)
    ct_imgs, ct_masks = [], []

    for b in batch:
        imgs = b["ct_images"]
        if len(imgs) == 0:
            dummy = torch.zeros(3, 224, 224)
            imgs = [dummy]

        pad = max_ct - len(imgs)
        imgs = imgs + [torch.zeros_like(imgs[0])] * pad
        mask = [1] * (len(imgs) - pad) + [0] * pad

        ct_imgs.append(torch.stack(imgs))
        ct_masks.append(torch.tensor(mask))

    ct_imgs = torch.stack(ct_imgs)      # (B, N_ct, 3, H, W)
    ct_masks = torch.stack(ct_masks)    # (B, N_ct)

    # ---------- MRI images ----------
    max_mri = max(len(b["mri_images"]) for b in batch)
    mri_imgs, mri_masks = [], []

    for b in batch:
        imgs = b["mri_images"]
        if len(imgs) == 0:
            dummy = torch.zeros(3, 224, 224)
            imgs = [dummy]

        pad = max_mri - len(imgs)
        imgs = imgs + [torch.zeros_like(imgs[0])] * pad
        mask = [1] * (len(imgs) - pad) + [0] * pad

        mri_imgs.append(torch.stack(imgs))
        mri_masks.append(torch.tensor(mask))

    mri_imgs = torch.stack(mri_imgs)    # (B, N_mri, 3, H, W)
    mri_masks = torch.stack(mri_masks)  # (B, N_mri)

    # ---------- Text encoder input ----------
    text_inputs = [b["text_input"] for b in batch]
    text_enc = tokenizer(
        text_inputs,
        padding=True,
        truncation=True,
        return_tensors="pt"
    )

    # ---------- Decoder target ----------
    reports = [b["report"] for b in batch]   # <-- RAW ground truth
    report_enc = tokenizer(
        reports,
        padding=True,
        truncation=True,
        max_length=MAX_REPORT_LEN,
        return_tensors="pt"
    )

    # ---------- Location ----------
    locations = [b["location"] for b in batch]

    return {
        "ct_images": ct_imgs,
        "ct_masks": ct_masks,
        "mri_images": mri_imgs,
        "mri_masks": mri_masks,
        "text_input_ids": text_enc["input_ids"],
        "text_attention_mask": text_enc["attention_mask"],
        "report_input_ids": report_enc["input_ids"],
        "report_attention_mask": report_enc["attention_mask"],
        "reports": reports,              # <-- RAW ground-truth strings
        "locations": locations
    }


In [10]:
ds = MedPixDataset(r"C:\fyp_manish_shyam_phase2\data\df_overall.csv", transform=image_transform)

sample = ds[0]
print(sample["uid"])
print("CT images:", len(sample["ct_images"]))
print("MRI images:", len(sample["mri_images"]))
print("Text length:", len(sample["text_input"]))
print("Report length:", len(sample["report"]))
print("Location:", sample["location"])


MPX1009
CT images: 2
MRI images: 0
Text length: 379
Report length: 152
Location: Reproductive and Urinary System


In [11]:
import numpy as np
print(np.__version__)


1.26.4


In [12]:
# Taking 2‚Äì3 samples manually for doing a small sanity check
batch_samples = [ds[i] for i in range(3)]

batch = collate_fn(batch_samples)
print("CT images shape:", batch["ct_images"].shape)
print("CT masks shape:", batch["ct_masks"].shape)

print("MRI images shape:", batch["mri_images"].shape)
print("MRI masks shape:", batch["mri_masks"].shape)

print("Text input ids shape:", batch["text_input_ids"].shape)
print("Report input ids shape:", batch["report_input_ids"].shape)

print("Raw reports count:", len(batch["reports"]))
print("First report preview:\n", batch["reports"][0][:200])

print("Locations:", batch["locations"])


CT images shape: torch.Size([3, 2, 3, 224, 224])
CT masks shape: torch.Size([3, 2])
MRI images shape: torch.Size([3, 1, 3, 224, 224])
MRI masks shape: torch.Size([3, 2])
Text input ids shape: torch.Size([3, 90])
Report input ids shape: torch.Size([3, 73])
Raw reports count: 3
First report preview:
 Bladder with thickened wall and diverticulum on the right. Diverticulum is mostly likely secondary to chronic outflow obstruction. Prostate enlargement.
Locations: ['Reproductive and Urinary System', 'Thorax', 'Reproductive and Urinary System']


In [13]:
from torch.utils.data import DataLoader, random_split

# ---- Load full dataset ----
full_dataset = MedPixDataset(
    r"C:\fyp_manish_shyam_phase2\data\df_overall.csv",
    transform=image_transform
)

# ---- 80 / 20 split ----
dataset_size = len(full_dataset)
train_size = int(0.95 * dataset_size)
test_size = dataset_size - train_size

# Reproducibility
generator = torch.Generator().manual_seed(42)

train_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, test_size],
    generator=generator
)

print(f"Total samples: {dataset_size}")
print(f"Train samples: {len(train_dataset)}")
print(f"Test samples:  {len(test_dataset)}")


Total samples: 671
Train samples: 637
Test samples:  34


In [14]:
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    collate_fn=collate_fn
    # , pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,          # No shuffle for test
    num_workers=NUM_WORKERS,
    collate_fn=collate_fn
    # , pin_memory=True
)

print("Train & Test loaders ready")

Train & Test loaders ready


In [15]:
import torchvision.models as models

class ImageEncoder(nn.Module):
    def __init__(self, embed_dim=EMBED_DIM):
        super().__init__()

        base = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

        for p in base.parameters():
            p.requires_grad = False

        base.fc = nn.Linear(base.fc.in_features, embed_dim)
        self.cnn = base

    def forward(self, x):
        """
        x: (B, N, 3, H, W)
        return: (B, N, D)
        """
        B, N, C, H, W = x.shape
        x = x.view(B * N, C, H, W)
        feats = self.cnn(x)
        feats = feats.view(B, N, -1)
        return feats


In [16]:
ct_encoder = ImageEncoder().to(DEVICE)
mri_encoder = ImageEncoder().to(DEVICE)

In [17]:

def masked_mean_pooling(feats, masks):
    masks = masks.unsqueeze(-1).float()   # (B, N, 1)
    summed = (feats * masks).sum(dim=1)
    denom = masks.sum(dim=1).clamp(min=1e-6)
    return summed / denom


In [18]:
from transformers import AutoModel
import torch.nn as nn

class TextEncoder(nn.Module):
    def __init__(self, model_name, embed_dim):
        super().__init__()
        self.lm = AutoModel.from_pretrained(model_name)

        hidden_dim = self.lm.config.hidden_size
        self.proj = nn.Linear(hidden_dim, embed_dim)

    def forward(self, input_ids, attention_mask):

        outputs = self.lm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=False
        )

        # Mean pooling over tokens (causal models do NOT have CLS)
        hidden = outputs.last_hidden_state  # (B, T, H)
        mask = attention_mask.unsqueeze(-1)

        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1)

        return self.proj(pooled)


In [19]:
TEXT_MODEL_NAME = "bert-base-uncased"

text_encoder = TextEncoder(
    model_name=TEXT_MODEL_NAME,
    embed_dim=EMBED_DIM
).to(DEVICE)

text_encoder.lm.resize_token_embeddings(len(tokenizer))


The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Embedding(30533, 768, padding_idx=0)

In [20]:
class GCNLayer(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, A_hat, X):
        """
        A_hat: (N, N) normalized adjacency
        X: (N, D)
        """
        return F.relu(self.linear(A_hat @ X))


class GCN(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, num_layers=2):
        super().__init__()
        layers = []
        dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim]

        for i in range(len(dims) - 1):
            layers.append(GCNLayer(dims[i], dims[i + 1]))

        self.layers = nn.ModuleList(layers)

    def forward(self, A_hat, X):
        for layer in self.layers:
            X = layer(A_hat, X)
        return X.mean(dim=0)   # graph-level embedding


In [21]:
def normalize_adjacency(A: torch.Tensor) -> torch.Tensor:
    """
    A: (N, N) raw adjacency matrix
    returns: (N, N) normalized adjacency with self-loops
    """
    device = A.device
    N = A.size(0)

    # Add self-loops
    A_tilde = A + torch.eye(N, device=device)

    # Degree
    D = A_tilde.sum(dim=1)

    # D^{-1/2}
    D_inv_sqrt = torch.pow(D, -0.5)
    D_inv_sqrt[torch.isinf(D_inv_sqrt)] = 0.0
    D_inv_sqrt = torch.diag(D_inv_sqrt)

    # Symmetric normalization
    A_hat = D_inv_sqrt @ A_tilde @ D_inv_sqrt
    return A_hat

In [22]:
import pandas as pd
import numpy as np

KG_LOCATION_MAP = {
    "Head": r"C:\fyp_manish_shyam_phase2\data\split_by_location_category_matrices\Head_matrix.csv",
    "Thorax": r"C:\fyp_manish_shyam_phase2\data\split_by_location_category_matrices\Thorax_matrix.csv",
    "Abdomen": r"C:\fyp_manish_shyam_phase2\data\split_by_location_category_matrices\Abdomen_matrix.csv",
    "Spine and Muscles": r"C:\fyp_manish_shyam_phase2\data\split_by_location_category_matrices\Spine_and_Muscles_matrix.csv",
    "Reproductive and Urinary System": r"C:\fyp_manish_shyam_phase2\data\split_by_location_category_matrices\Reproductive_and_Urinary_System_matrix.csv"
}


A_hat_dict = {}

# ---- Load and normalize adjacency matrices ----
for loc, path in KG_LOCATION_MAP.items():
    df = pd.read_csv(path, index_col=0)

    A = torch.tensor(
        df.values,
        dtype=torch.float32,
        device=DEVICE
    )

    A_hat = normalize_adjacency(A)
    A_hat_dict[loc] = A_hat

    print(f"{loc}: A_hat shape = {A_hat.shape}")

# ---- Create shared X_nodes (identity) ----
# Node count inferred from any adjacency matrix
example_loc = next(iter(A_hat_dict))
N_nodes = A_hat_dict[example_loc].shape[0]

X_nodes = torch.eye(N_nodes, device=DEVICE)

print("Shared X_nodes shape:", X_nodes.shape)



Head: A_hat shape = torch.Size([4400, 4400])
Thorax: A_hat shape = torch.Size([4400, 4400])
Abdomen: A_hat shape = torch.Size([4400, 4400])
Spine and Muscles: A_hat shape = torch.Size([4400, 4400])
Reproductive and Urinary System: A_hat shape = torch.Size([4400, 4400])
Shared X_nodes shape: torch.Size([4400, 4400])


In [23]:
class FeatureFusion(nn.Module):
    def __init__(self, embed_dim=EMBED_DIM, hidden_dim=HIDDEN_DIM):
        super().__init__()
        self.fc = nn.Linear(embed_dim * 4, hidden_dim)
        self.dropout = nn.Dropout(0.3)

    def forward(self, ct_feat, mri_feat, text_feat, kg_feat):
        """
        All inputs: (B, EMBED_DIM)
        Output: (B, HIDDEN_DIM)
        """
        fused = torch.cat(
            [ct_feat, mri_feat, text_feat, kg_feat],
            dim=-1
        )
        fused = self.dropout(fused)
        return self.fc(fused)


In [24]:
fusion = FeatureFusion().to(DEVICE)

In [25]:
def get_kg_embeddings(locations, gcn, X_nodes, A_hat_dict):
    """
    locations: list[str], length B
    returns: (B, KG_DIM)
    """

    device = X_nodes.device

    # 1. Compute KG embedding ONCE per unique location
    unique_locations = set(locations)
    location_to_embedding = {}

    for loc in unique_locations:
        A_hat = A_hat_dict[loc]              # (N, N)
        kg_emb = gcn(A_hat, X_nodes)         # (KG_DIM,)
        location_to_embedding[loc] = kg_emb

    # 2. Assign embedding to each sample
    kg_embeds = [
        location_to_embedding[loc] for loc in locations
    ]

    return torch.stack(kg_embeds).to(device)   # (B, KG_DIM)


In [26]:
class ReportDecoderLSTM(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, locations):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)

        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            batch_first=True
        )

        # üîë Location-specific output heads
        self.heads = nn.ModuleDict({
            loc: nn.Linear(hidden_dim, vocab_size)
            for loc in locations
        })

    def forward(self, fused_feat, input_ids, locations):
        """
        fused_feat: (B, HIDDEN_DIM)
        input_ids: (B, T)
        locations: list[str] length B
        """

        emb = self.embedding(input_ids)          # (B, T, D)

        h0 = fused_feat.unsqueeze(0)             # (1, B, H)
        c0 = torch.zeros_like(h0)

        out, _ = self.lstm(emb, (h0, c0))        # (B, T, H)

        # üî• Apply correct head PER SAMPLE
        logits = torch.zeros(
            out.size(0), out.size(1), self.heads[locations[0]].out_features,
            device=out.device
        )

        for i, loc in enumerate(locations):
            logits[i] = self.heads[loc](out[i])

        return logits


In [27]:
decoder = ReportDecoderLSTM(
    vocab_size=VOCAB_SIZE,
    embed_dim=EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
    locations=list(LOCATION_TOKENS.keys())
).to(DEVICE)


In [28]:
KG_IN_DIM = X_nodes.shape[1]   # number of node features
KG_HIDDEN_DIM = 256            # you can tune this
KG_OUT_DIM = EMBED_DIM         # must match fusion input

gcn = GCN(
    in_dim=KG_IN_DIM,
    hidden_dim=KG_HIDDEN_DIM,
    out_dim=KG_OUT_DIM,
    num_layers=2
).to(DEVICE)

print("GCN initialized")


GCN initialized


In [29]:
FORBIDDEN = {
    "Head": [
        # Thorax
        "lung", "lobe", "bronchus", "pleura", "pulmonary",
        "heart", "cardiac", "pericardium",
        "aorta", "ivc", "svc",

        # Abdomen
        "liver", "hepatic", "spleen", "pancreas",
        "bowel", "colon", "jejunum", "ileum", "duodenum",

        # GU
        "kidney", "renal", "ureter", "bladder",
        "uterus", "ovary", "prostate"
    ],

    "Thorax": [
        # Abdomen
        "liver", "hepatic", "spleen", "pancreas",
        "bowel", "colon", "jejunum", "ileum", "duodenum",

        # GU
        "kidney", "renal", "ureter", "bladder",
        "uterus", "ovary", "prostate"
    ],

    "Abdomen": [
        # Thorax
        "lung", "lobe", "bronchus", "pleura",
        "heart", "cardiac", "pericardium",

        # Neuro / Head
        "brain", "cerebral", "ventricle",
        "orbit", "parotid", "thyroid"
    ],

    "Spine and Muscles": [
        # Thorax
        "lung", "pleura", "heart", "pericardium",

        # Abdomen
        "liver", "hepatic", "spleen", "pancreas",
        "bowel", "colon",

        # GU
        "kidney", "renal", "ureter", "bladder",
        "uterus", "ovary", "prostate"
    ],

    "Reproductive and Urinary System": [
        # Thorax
        "lung", "lobe", "bronchus", "pleura",
        "heart", "pericardium",

        # Neuro / Head
        "brain", "cerebral", "ventricle",
        "spinal cord"
    ]
}


In [30]:
def freeze_module(module):
    for p in module.parameters():
        p.requires_grad = False


In [31]:
# Freeze all encoders and fusion
freeze_module(ct_encoder)
freeze_module(mri_encoder)
freeze_module(text_encoder)
freeze_module(gcn)
freeze_module(fusion)

# Ensure decoder is trainable
for p in decoder.parameters():
    p.requires_grad = True


In [32]:
criterion = nn.CrossEntropyLoss(
    ignore_index=tokenizer.pad_token_id
)

params = [p for p in decoder.parameters() if p.requires_grad]

optimizer = torch.optim.AdamW(
    params,
    lr=3e-4,
    weight_decay=1e-4
)

scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=5,
    gamma=0.5
)

In [33]:
def count_trainable(name, module):
    n = sum(p.numel() for p in module.parameters() if p.requires_grad)
    print(f"{name}: {n:,} trainable params")

count_trainable("CT Encoder", ct_encoder)
count_trainable("MRI Encoder", mri_encoder)
count_trainable("Text Encoder", text_encoder)
count_trainable("GCN", gcn)
count_trainable("Fusion", fusion)
count_trainable("Decoder", decoder)


CT Encoder: 0 trainable params
MRI Encoder: 0 trainable params
Text Encoder: 0 trainable params
GCN: 0 trainable params
Fusion: 0 trainable params
Decoder: 87,710,553 trainable params


In [34]:
from tqdm import tqdm

def train_one_epoch(train_loader):
    decoder.train()
    
    ct_encoder.eval()
    mri_encoder.eval()
    text_encoder.eval()
    gcn.eval()
    fusion.eval()


    total_loss = 0.0

    pbar = tqdm(
        train_loader,
        desc="Training",
        total=len(train_loader),
        leave=True
    )

    for batch in pbar:
        optimizer.zero_grad()

        # =========================
        # Move tensors
        # =========================
        ct_imgs = batch["ct_images"].to(DEVICE)
        ct_masks = batch["ct_masks"].to(DEVICE)

        mri_imgs = batch["mri_images"].to(DEVICE)
        mri_masks = batch["mri_masks"].to(DEVICE)

        text_ids = batch["text_input_ids"].to(DEVICE)
        text_mask = batch["text_attention_mask"].to(DEVICE)

        report_ids = batch["report_input_ids"].to(DEVICE)
        locations = batch["locations"]

        # =========================
        # Encode modalities
        # =========================
        ct_feats = ct_encoder(ct_imgs)
        mri_feats = mri_encoder(mri_imgs)

        ct_pooled = masked_mean_pooling(ct_feats, ct_masks)
        mri_pooled = masked_mean_pooling(mri_feats, mri_masks)

        text_feat = text_encoder(text_ids, text_mask)

        kg_feat = get_kg_embeddings(
            locations, gcn, X_nodes, A_hat_dict
        )

        fused_feat = fusion(
            ct_pooled, mri_pooled, text_feat, kg_feat
        )  # (B, HIDDEN_DIM)

        # ======================================================
        # üîë Inject LOCATION-SPECIFIC BOS (PRESERVED)
        # ======================================================
        B = report_ids.size(0)

        bos_ids = torch.tensor(
            [
                tokenizer.convert_tokens_to_ids(
                    LOCATION_TOKENS[loc]["bos"]
                )
                for loc in locations
            ],
            device=report_ids.device
        ).unsqueeze(1)  # (B, 1)

        # Prepend BOS to reports
        report_ids = torch.cat([bos_ids, report_ids], dim=1)

        # ======================================================
        # üî• Teacher forcing (CORRECT SHIFT)
        # ======================================================
        decoder_inputs = report_ids[:, :-1]   # includes BOS
        targets = report_ids[:, 1:]           # next-token targets

        logits = decoder(
            fused_feat,
            decoder_inputs,
            locations
        )


        loss = criterion(
            logits.reshape(-1, VOCAB_SIZE),
            targets.reshape(-1)
        )

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pbar.set_postfix(loss=f"{loss.item():.4f}")

    scheduler.step()
    return total_loss / len(train_loader)



In [35]:
NUM_EPOCHS = 15

for epoch in range(NUM_EPOCHS):
    train_loss = train_one_epoch(train_loader)
    print(
        f"Epoch {epoch+1}/{NUM_EPOCHS} | "
        f"Train Loss: {train_loss:.4f} | "
    )


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:43<00:00,  3.67it/s, loss=7.4073]


Epoch 1/15 | Train Loss: 7.8797 | 


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:43<00:00,  3.66it/s, loss=5.0432]


Epoch 2/15 | Train Loss: 6.1610 | 


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:43<00:00,  3.66it/s, loss=5.5728]


Epoch 3/15 | Train Loss: 5.8749 | 


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:39<00:00,  4.06it/s, loss=5.2977]


Epoch 4/15 | Train Loss: 5.6311 | 


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:40<00:00,  3.98it/s, loss=5.8605]


Epoch 5/15 | Train Loss: 5.4052 | 


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:41<00:00,  3.88it/s, loss=4.7466]


Epoch 6/15 | Train Loss: 5.1918 | 


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:42<00:00,  3.79it/s, loss=5.5792]


Epoch 7/15 | Train Loss: 5.0731 | 


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:41<00:00,  3.83it/s, loss=5.7146]


Epoch 8/15 | Train Loss: 4.9675 | 


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:42<00:00,  3.76it/s, loss=5.3108]


Epoch 9/15 | Train Loss: 4.8507 | 


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:42<00:00,  3.77it/s, loss=4.3164]


Epoch 10/15 | Train Loss: 4.7479 | 


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:42<00:00,  3.73it/s, loss=4.4691]


Epoch 11/15 | Train Loss: 4.6430 | 


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:43<00:00,  3.69it/s, loss=4.7357]


Epoch 12/15 | Train Loss: 4.5953 | 


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:41<00:00,  3.85it/s, loss=5.5242]


Epoch 13/15 | Train Loss: 4.5440 | 


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:42<00:00,  3.74it/s, loss=5.3071]


Epoch 14/15 | Train Loss: 4.4798 | 


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:43<00:00,  3.71it/s, loss=4.1264]

Epoch 15/15 | Train Loss: 4.4291 | 





In [47]:
@torch.no_grad()
def generate_report(
    fused_feat,
    location,
    max_len=MAX_REPORT_LEN,
    min_len=20
):
    decoder.eval()

    bos_id = tokenizer.convert_tokens_to_ids(
        LOCATION_TOKENS[location]["bos"]
    )
    eos_id = tokenizer.convert_tokens_to_ids(
        LOCATION_TOKENS[location]["eos"]
    )

    generated_ids = [bos_id]
    word_count = 0

    for _ in range(max_len * 2):

        input_ids = torch.tensor(
            generated_ids,
            dtype=torch.long,
            device=DEVICE
        ).unsqueeze(0)   # (1, T)

        # üîë CORRECT decoder call
        logits = decoder(
            fused_feat,          # (1, HIDDEN_DIM)
            input_ids,           # (1, T)
            [location]           # list[str] of length 1
        )  # (1, T, vocab)

        step_logits = logits[0, -1]  # (vocab,)

        # üîí Apply forbidden-token masking
        forbidden_words = FORBIDDEN[location]
        forbidden_ids = tokenizer.convert_tokens_to_ids(forbidden_words)
        forbidden_ids = [i for i in forbidden_ids if i != tokenizer.unk_token_id]

        step_logits[forbidden_ids] = -1e9

        next_id = torch.argmax(step_logits).item()

        if next_id == eos_id and word_count < min_len:
            continue

        generated_ids.append(next_id)
        word_count += 1

        if next_id == eos_id or word_count >= max_len:
            break

    return tokenizer.decode(
        generated_ids,
        skip_special_tokens=True
    )


# @torch.no_grad()
# def generate_report(
#     fused_feat,
#     location,
#     max_len=MAX_REPORT_LEN,
#     min_len=20
# ):
#     decoder.eval()

#     bos_id = tokenizer.convert_tokens_to_ids(
#         LOCATION_TOKENS[location]["bos"]
#     )
#     eos_id = tokenizer.convert_tokens_to_ids(
#         LOCATION_TOKENS[location]["eos"]
#     )

#     generated_ids = [bos_id]
#     word_count = 0

#     for _ in range(max_len * 2):

#         input_ids = torch.tensor(
#             generated_ids,
#             dtype=torch.long,
#             device=DEVICE
#         ).unsqueeze(0)   # (1, T)

#         logits = decoder(
#             fused_feat,
#             decoder_inputs,
#             locations
#         )

#         forbidden_words = FORBIDDEN[location]
#         forbidden_ids = tokenizer.convert_tokens_to_ids(forbidden_words)
#         logits[forbidden_ids] = -1e9

#         next_id = torch.argmax(logits[0, -1]).item()

#         if next_id == eos_id and word_count < min_len:
#             continue

#         generated_ids.append(next_id)
#         word_count += 1

#         if next_id == eos_id or word_count >= max_len:
#             break

#     return tokenizer.decode(
#         generated_ids,
#         skip_special_tokens=True
#     )


In [48]:
from tqdm import tqdm

@torch.no_grad()
def run_inference(test_loader):
    ct_encoder.eval()
    mri_encoder.eval()
    text_encoder.eval()
    gcn.eval()
    fusion.eval()
    decoder.eval()   # ‚úÖ GRU decoder only

    generated_reports = []
    ground_truth_reports = []
    locations_all = []

    pbar = tqdm(
        test_loader,
        desc="Generating reports",
        total=len(test_loader),
        leave=True
    )

    for batch in pbar:
        # =========================
        # Move tensors
        # =========================
        ct_imgs = batch["ct_images"].to(DEVICE)
        ct_masks = batch["ct_masks"].to(DEVICE)

        mri_imgs = batch["mri_images"].to(DEVICE)
        mri_masks = batch["mri_masks"].to(DEVICE)

        text_ids = batch["text_input_ids"].to(DEVICE)
        text_mask = batch["text_attention_mask"].to(DEVICE)

        report_ids = batch["report_input_ids"]   # CPU OK
        locations = batch["locations"]

        # =========================
        # Encode modalities
        # =========================
        ct_feats = ct_encoder(ct_imgs)
        mri_feats = mri_encoder(mri_imgs)

        ct_pooled = masked_mean_pooling(ct_feats, ct_masks)
        mri_pooled = masked_mean_pooling(mri_feats, mri_masks)

        text_feat = text_encoder(text_ids, text_mask)

        kg_feat = get_kg_embeddings(
            locations, gcn, X_nodes, A_hat_dict
        )

        fused_feats = fusion(
            ct_pooled, mri_pooled, text_feat, kg_feat
        )   # (B, HIDDEN_DIM)

        # =========================
        # Generate per sample
        # =========================
        for i in range(fused_feats.size(0)):

            gen_report = generate_report(
                fused_feats[i].unsqueeze(0),
                locations[i]
            )

            # ==================================================
            # üîë Align GT format with training (KEEP THIS)
            # ==================================================
            bos_token = LOCATION_TOKENS[locations[i]]["bos"]
            gt_report = bos_token + " " + tokenizer.decode(
                report_ids[i],
                skip_special_tokens=True
            )

            generated_reports.append(gen_report)
            ground_truth_reports.append(gt_report)
            locations_all.append(locations[i])

    return generated_reports, ground_truth_reports, locations_all


In [49]:
import pandas as pd

generated_reports, ground_truth_reports, locations_all = run_inference(test_loader)


Generating reports: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:10<00:00,  1.14s/it]


In [51]:
results_df = pd.DataFrame({
    "location": locations_all,
    "generated_report": generated_reports,
    "ground_truth_report": ground_truth_reports
})

save_path = r"C:\fyp_manish_shyam_phase2\results\multihead\generated_vs_gt_reports_decoder_only_unfrozen.csv"
results_df.to_csv(save_path, index=False)

print(f"Saved results to: {save_path}")
print("Total samples:", len(results_df))


Saved results to: C:\fyp_manish_shyam_phase2\results\multihead\generated_vs_gt_reports_decoder_only_unfrozen.csv
Total samples: 34


In [52]:
# =========================
# STAGE 2: Unfreeze Fusion
# =========================

# Keep encoders frozen
freeze_module(ct_encoder)
freeze_module(mri_encoder)
freeze_module(text_encoder)
freeze_module(gcn)

# Unfreeze fusion
for p in fusion.parameters():
    p.requires_grad = True

# Decoder already trainable
for p in decoder.parameters():
    p.requires_grad = True


In [53]:
print("=== Trainable Parameters Check ===")
count_trainable("CT Encoder", ct_encoder)
count_trainable("MRI Encoder", mri_encoder)
count_trainable("Text Encoder", text_encoder)
count_trainable("GCN", gcn)
count_trainable("Fusion", fusion)
count_trainable("Decoder", decoder)


=== Trainable Parameters Check ===
CT Encoder: 0 trainable params
MRI Encoder: 0 trainable params
Text Encoder: 0 trainable params
GCN: 0 trainable params
Fusion: 524,800 trainable params
Decoder: 87,710,553 trainable params


In [54]:
# =========================
# Optimizer for Stage 2
# =========================

stage2_params = []

stage2_params += [p for p in fusion.parameters() if p.requires_grad]
stage2_params += [p for p in decoder.parameters() if p.requires_grad]

optimizer = torch.optim.AdamW(
    stage2_params,
    lr=1e-4,          # üîë LOWER LR
    weight_decay=1e-4
)

scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=5,
    gamma=0.5
)

print("Stage-2 optimizer ready")


Stage-2 optimizer ready


In [55]:
from tqdm import tqdm

def train_one_epoch_stage2(train_loader):
    fusion.train()
    decoder.train()

    ct_encoder.eval()
    mri_encoder.eval()
    text_encoder.eval()
    gcn.eval()

    total_loss = 0.0

    pbar = tqdm(
        train_loader,
        desc="Stage-2 Training (Fusion + Decoder)",
        total=len(train_loader),
        leave=True
    )

    for batch in pbar:
        optimizer.zero_grad()

        # =========================
        # Move tensors
        # =========================
        ct_imgs = batch["ct_images"].to(DEVICE)
        ct_masks = batch["ct_masks"].to(DEVICE)

        mri_imgs = batch["mri_images"].to(DEVICE)
        mri_masks = batch["mri_masks"].to(DEVICE)

        text_ids = batch["text_input_ids"].to(DEVICE)
        text_mask = batch["text_attention_mask"].to(DEVICE)

        report_ids = batch["report_input_ids"].to(DEVICE)
        locations = batch["locations"]

        # =========================
        # Encode modalities
        # =========================
        with torch.no_grad():
            ct_feats = ct_encoder(ct_imgs)
            mri_feats = mri_encoder(mri_imgs)

            ct_pooled = masked_mean_pooling(ct_feats, ct_masks)
            mri_pooled = masked_mean_pooling(mri_feats, mri_masks)

            text_feat = text_encoder(text_ids, text_mask)

            kg_feat = get_kg_embeddings(
                locations, gcn, X_nodes, A_hat_dict
            )

        fused_feat = fusion(
            ct_pooled, mri_pooled, text_feat, kg_feat
        )

        # =========================
        # Location-specific BOS
        # =========================
        B = report_ids.size(0)

        bos_ids = torch.tensor(
            [
                tokenizer.convert_tokens_to_ids(
                    LOCATION_TOKENS[loc]["bos"]
                )
                for loc in locations
            ],
            device=report_ids.device
        ).unsqueeze(1)

        report_ids = torch.cat([bos_ids, report_ids], dim=1)

        decoder_inputs = report_ids[:, :-1]
        targets = report_ids[:, 1:]
        
        logits = decoder(
            fused_feat,
            decoder_inputs,
            locations
        )

        loss = criterion(
            logits.reshape(-1, VOCAB_SIZE),
            targets.reshape(-1)
        )

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pbar.set_postfix(loss=f"{loss.item():.4f}")

    scheduler.step()
    return total_loss / len(train_loader)


In [56]:
STAGE2_EPOCHS = 10

for epoch in range(STAGE2_EPOCHS):
    loss = train_one_epoch_stage2(train_loader)
    print(
        f"[Stage-2] Epoch {epoch+1}/{STAGE2_EPOCHS} | "
        f"Loss: {loss:.4f}"
    )


Stage-2 Training (Fusion + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:42<00:00,  3.75it/s, loss=3.3447]


[Stage-2] Epoch 1/10 | Loss: 4.4271


Stage-2 Training (Fusion + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:42<00:00,  3.73it/s, loss=3.7776]


[Stage-2] Epoch 2/10 | Loss: 4.3188


Stage-2 Training (Fusion + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:44<00:00,  3.61it/s, loss=3.6472]


[Stage-2] Epoch 3/10 | Loss: 4.2353


Stage-2 Training (Fusion + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:42<00:00,  3.73it/s, loss=4.8699]


[Stage-2] Epoch 4/10 | Loss: 4.1563


Stage-2 Training (Fusion + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:43<00:00,  3.71it/s, loss=4.0666]


[Stage-2] Epoch 5/10 | Loss: 4.0734


Stage-2 Training (Fusion + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:42<00:00,  3.80it/s, loss=4.9297]


[Stage-2] Epoch 6/10 | Loss: 4.0093


Stage-2 Training (Fusion + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:42<00:00,  3.77it/s, loss=2.9259]


[Stage-2] Epoch 7/10 | Loss: 3.9669


Stage-2 Training (Fusion + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:43<00:00,  3.70it/s, loss=3.5885]


[Stage-2] Epoch 8/10 | Loss: 3.9275


Stage-2 Training (Fusion + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:43<00:00,  3.69it/s, loss=3.7411]


[Stage-2] Epoch 9/10 | Loss: 3.8954


Stage-2 Training (Fusion + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:44<00:00,  3.62it/s, loss=3.6983]

[Stage-2] Epoch 10/10 | Loss: 3.8623





In [57]:
import pandas as pd

generated_reports, ground_truth_reports, locations_all = run_inference(test_loader)


Generating reports: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:10<00:00,  1.14s/it]


In [58]:
results_df = pd.DataFrame({
    "location": locations_all,
    "generated_report": generated_reports,
    "ground_truth_report": ground_truth_reports
})

save_path = r"C:\fyp_manish_shyam_phase2\results\multihead\generated_vs_gt_reports_fusion_unfrozen.csv"
results_df.to_csv(save_path, index=False)

print(f"Saved results to: {save_path}")
print("Total samples:", len(results_df))


Saved results to: C:\fyp_manish_shyam_phase2\results\multihead\generated_vs_gt_reports_fusion_unfrozen.csv
Total samples: 34


In [None]:
AIzaSyC0Xq7VNHUQDREmLB7DP97tsxz0SmKYPwQ

In [65]:
# !pip install --quiet google-generativeai
!pip uninstall -y google-generativeai
!pip install -U google-genai


Found existing installation: google-generativeai 0.8.6
Uninstalling google-generativeai-0.8.6:
  Successfully uninstalled google-generativeai-0.8.6
Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting google-genai
  Downloading google_genai-1.56.0-py3-none-any.whl.metadata (53 kB)
Collecting httpx<1.0.0,>=0.28.1 (from google-genai)
  Downloading httpx-0.28.1-py3-none-any.whl.metadata (7.1 kB)
Collecting tenacity<9.2.0,>=8.2.3 (from google-genai)
  Downloading tenacity-9.1.2-py3-none-any.whl.metadata (1.2 kB)
Collecting distro<2,>=1.7.0 (from google-genai)
  Downloading distro-1.9.0-py3-none-any.whl.metadata (6.8 kB)
Downloading google_genai-1.56.0-py3-none-any.whl (426 kB)
Downloading distro-1.9.0-py3-none-any.whl (20 kB)
Downloading httpx-0.28.1-py3-none-any.whl (73 kB)
Downloading tenacity-9.1.2-py3-none-any.whl (28 kB)
Installing collected packages: tenacity, distro, httpx, g

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
fief-client 0.20.0 requires httpx<0.28.0,>=0.21.3, but you have httpx 0.28.1 which is incompatible.

[notice] A new release of pip is available: 25.2 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [66]:
# import os
# import google.generativeai as genai

# # Option 1: Direct (NOT recommended for shared systems)
# GEMINI_API_KEY = "AIzaSyC0Xq7VNHUQDREmLB7DP97tsxz0SmKYPwQ"

# # Option 2: Environment variable (recommended)
# # os.environ["GEMINI_API_KEY"] = "YOUR_API_KEY_HERE"

# genai.configure(api_key=GEMINI_API_KEY)


import os
from google import genai

# Recommended even on shared systems (Jupyter-safe)
os.environ["GEMINI_API_KEY"] = "AIzaSyC0Xq7VNHUQDREmLB7DP97tsxz0SmKYPwQ"

client = genai.Client()


In [79]:
# model = genai.GenerativeModel(
#     model_name="gemini-1.5-pro",
#     generation_config={
#         "temperature": 0.1,        # LOW creativity
#         "top_p": 0.9,
#         "top_k": 40,
#         "max_output_tokens": 512
#     }
# )

MODEL_NAME = "models/gemini-2.5-flash"


In [90]:
# SYSTEM_PROMPT = """
# You are a clinical radiology report refinement assistant.

# STRICT RULES:
# 1. DO NOT add new medical findings.
# 2. DO NOT infer diseases, abnormalities, or diagnoses.
# 3. DO NOT speculate or guess missing information.
# 4. DO NOT introduce anatomy or observations not explicitly present.
# 5. DO NOT change clinical meaning.
# 6. Preserve negations exactly (e.g., "no evidence of").
# 7. If the input is vague or incomplete, keep it vague.
# 8. If unsure, return the original text unchanged.

# ALLOWED:
# - Improve grammar and clarity
# - Improve clinical phrasing
# - Remove redundancy
# - Improve sentence flow
# - Standardize terminology

# If any rule conflicts, prioritize safety and factual consistency.
# """


# SYSTEM_PROMPT = """
# You are a clinical radiology report refinement assistant.

# STRICT RULES (MUST FOLLOW):
# 1. DO NOT add new findings, diseases, abnormalities, or diagnoses.
# 2. DO NOT infer, speculate, or assume missing information.
# 3. DO NOT introduce new anatomy or observations not explicitly present.
# 4. DO NOT change the clinical meaning or certainty of any statement.
# 5. Preserve all negations exactly (e.g., ‚Äúno evidence of‚Äù, ‚Äúabsence of‚Äù).
# 6. If information is missing or unclear, keep it unchanged and do not elaborate.
# 7. If refinement would violate any rule, return the original text verbatim.

# LENGTH AND COMPLETENESS CONSTRAINTS:
# 8. The refined report MUST contain at least 50 words.
# 9. If the original report is shorter, you may:
#    - Rephrase sentences
#    - Combine or split sentences
#    - Improve clinical phrasing
#    - Reiterate existing information in a clearer medical style
#    BUT you must NOT introduce new facts.
# 10. The final sentence MUST be grammatically complete and end with proper punctuation.
# 11. Do NOT leave any sentence unfinished or abruptly terminated.

# ALLOWED OPERATIONS:
# - Grammar correction
# - Clinical phrasing improvement
# - Redundancy removal or controlled repetition
# - Sentence flow and coherence improvement
# - Terminology standardization

# OUTPUT FORMAT:
# - Return ONLY the refined report text.
# - Do NOT add headings, explanations, bullet points, or commentary.
# """

SYSTEM_PROMPT = """
You are a clinical radiology report refinement assistant.

STRICT SAFETY RULES (MUST FOLLOW):
1. DO NOT add any findings, diseases, abnormalities, diagnoses, or anatomy
   that are NOT explicitly present in the provided inputs.
2. DO NOT infer, speculate, or assume missing information.
3. DO NOT change the clinical meaning, certainty, or negation of any statement.
4. Preserve all negations exactly (e.g., ‚Äúno evidence of‚Äù, ‚Äúabsence of‚Äù).
5. If any refinement would violate these rules, return the original text verbatim.

GROUND TRUTH USAGE RULE:
6. If REFERENCE FINDINGS / GROUND TRUTH is provided, you MAY ONLY use
   information explicitly stated in the ground truth to:
   - Expand sentences
   - Rephrase content
   - Improve clarity and structure
7. DO NOT introduce content that is not present in either the generated report
   or the ground truth.

LENGTH AND COMPLETENESS CONSTRAINTS (MANDATORY):
8. The refined report MUST contain AT LEAST 50 words.
9. If the generated report is shorter than 50 words, you MUST increase length
   by rephrasing, restructuring, or elaborating ONLY USING GROUND TRUTH content.
10. If no ground truth is provided and the report is under 50 words,
    you may carefully rephrase or restate existing content WITHOUT adding facts.
11. The final sentence MUST be grammatically complete and end with proper punctuation.
12. DO NOT leave any sentence unfinished or abruptly terminated.

ALLOWED OPERATIONS:
- Grammar correction
- Clinical phrasing improvement
- Controlled repetition for clarity
- Sentence flow and coherence improvement
- Terminology standardization

OUTPUT FORMAT:
- Return ONLY the refined report text.
- Do NOT add headings, bullet points, explanations, or commentary.
"""



In [91]:
# def refine_medical_report(
#     generated_report: str,
#     reference_findings: str = None
# ):
#     """
#     generated_report: model-generated report to refine
#     reference_findings: OPTIONAL factual findings or labels (if available)
#     """

#     prompt = f"""
# {SYSTEM_PROMPT}

# ORIGINAL GENERATED REPORT:
# \"\"\"
# {generated_report}
# \"\"\"
# """

#     if reference_findings:
#         prompt += f"""
# REFERENCE FINDINGS (GROUND TRUTH ‚Äì DO NOT EXCEED):
# \"\"\"
# {reference_findings}
# \"\"\"
# """

#     prompt += """
# TASK:
# Refine the ORIGINAL GENERATED REPORT while strictly following the rules.
# Return ONLY the refined report text.
# """

#     response = model.generate_content(prompt)
#     return response.text.strip()


def refine_medical_report(
    generated_report: str,
    reference_findings: str = None
):
    prompt = f"""
{SYSTEM_PROMPT}

ORIGINAL GENERATED REPORT:
\"\"\"
{generated_report}
\"\"\"
"""

    if reference_findings:
        prompt += f"""
REFERENCE FINDINGS (GROUND TRUTH ‚Äî DO NOT EXCEED):
\"\"\"
{reference_findings}
\"\"\"
"""

    prompt += "\nRefine the report while strictly obeying the rules."

    response = client.models.generate_content(
        model=MODEL_NAME,
        contents=prompt,
        config={
            "temperature": 0.1,   # üîë low creativity
            "max_output_tokens": 512,
        }
    )

    return response.text.strip()


In [92]:

refined = refine_medical_report(generated_reports[19])
print(refined)


There is T2 hyperintense signal intensity noted within the right frontal lobe. Additionally, T


In [93]:
ground_truth_reports[19]

'<HEAD_BOS> ‚Ä¢ heterogeneous signal lesion - mixed iso and hyperintense on t1wi ‚Ä¢ hyperintense on t2wi - in the region of the right lenticular nucleus, anterior limb of the right internal capsule, and external capsule ‚Ä¢ hypointense rim on t1wi that ‚Äú blooms ‚Äù on t2wi ‚Ä¢ no mass effect ‚Ä¢ minimal enhancement seen post - gadolinium'

In [78]:
from google import genai
import os

client = genai.Client()

models = client.models.list()

for m in models:
    print(m.name)


models/embedding-gecko-001
models/gemini-2.5-flash
models/gemini-2.5-pro
models/gemini-2.0-flash-exp
models/gemini-2.0-flash
models/gemini-2.0-flash-001
models/gemini-2.0-flash-exp-image-generation
models/gemini-2.0-flash-lite-001
models/gemini-2.0-flash-lite
models/gemini-2.0-flash-lite-preview-02-05
models/gemini-2.0-flash-lite-preview
models/gemini-exp-1206
models/gemini-2.5-flash-preview-tts
models/gemini-2.5-pro-preview-tts
models/gemma-3-1b-it
models/gemma-3-4b-it
models/gemma-3-12b-it
models/gemma-3-27b-it
models/gemma-3n-e4b-it
models/gemma-3n-e2b-it
models/gemini-flash-latest
models/gemini-flash-lite-latest
models/gemini-pro-latest
models/gemini-2.5-flash-lite
models/gemini-2.5-flash-image-preview
models/gemini-2.5-flash-image
models/gemini-2.5-flash-preview-09-2025
models/gemini-2.5-flash-lite-preview-09-2025
models/gemini-3-pro-preview
models/gemini-3-flash-preview
models/gemini-3-pro-image-preview
models/nano-banana-pro-preview
models/gemini-robotics-er-1.5-preview
models/g

In [61]:
# =========================
# STAGE 3: Unfreeze GCN
# =========================

# Keep encoders frozen
freeze_module(ct_encoder)
freeze_module(mri_encoder)
freeze_module(text_encoder)

# Unfreeze GCN
for p in gcn.parameters():
    p.requires_grad = True

# Fusion + Decoder remain trainable
for p in fusion.parameters():
    p.requires_grad = True

for p in decoder.parameters():
    p.requires_grad = True


In [62]:
print("=== Stage-3 Trainable Params Check ===")
count_trainable("CT Encoder", ct_encoder)
count_trainable("MRI Encoder", mri_encoder)
count_trainable("Text Encoder", text_encoder)
count_trainable("GCN", gcn)
count_trainable("Fusion", fusion)
count_trainable("Decoder", decoder)


=== Stage-3 Trainable Params Check ===
CT Encoder: 0 trainable params
MRI Encoder: 0 trainable params
Text Encoder: 0 trainable params
GCN: 1,192,448 trainable params
Fusion: 524,800 trainable params
Decoder: 25,056,837 trainable params


In [63]:
# =========================
# Optimizer for Stage 3
# =========================

stage3_params = []

stage3_params += [p for p in gcn.parameters() if p.requires_grad]
stage3_params += [p for p in fusion.parameters() if p.requires_grad]
stage3_params += [p for p in decoder.parameters() if p.requires_grad]

optimizer = torch.optim.AdamW(
    stage3_params,
    lr=5e-5,          # üîë LOWER LR (KG is sensitive)
    weight_decay=1e-4
)

scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=5,
    gamma=0.5
)

print("Stage-3 optimizer ready")


Stage-3 optimizer ready


In [64]:
from tqdm import tqdm

def train_one_epoch_stage3(train_loader):
    gcn.train()
    fusion.train()
    decoder.train()

    ct_encoder.eval()
    mri_encoder.eval()
    text_encoder.eval()

    total_loss = 0.0

    pbar = tqdm(
        train_loader,
        desc="Stage-3 Training (GCN + Fusion + Decoder)",
        total=len(train_loader),
        leave=True
    )

    for batch in pbar:
        optimizer.zero_grad()

        # =========================
        # Move tensors
        # =========================
        ct_imgs = batch["ct_images"].to(DEVICE)
        ct_masks = batch["ct_masks"].to(DEVICE)

        mri_imgs = batch["mri_images"].to(DEVICE)
        mri_masks = batch["mri_masks"].to(DEVICE)

        text_ids = batch["text_input_ids"].to(DEVICE)
        text_mask = batch["text_attention_mask"].to(DEVICE)

        report_ids = batch["report_input_ids"].to(DEVICE)
        locations = batch["locations"]

        # =========================
        # Encode frozen modalities
        # =========================
        with torch.no_grad():
            ct_feats = ct_encoder(ct_imgs)
            mri_feats = mri_encoder(mri_imgs)

            ct_pooled = masked_mean_pooling(ct_feats, ct_masks)
            mri_pooled = masked_mean_pooling(mri_feats, mri_masks)

            text_feat = text_encoder(text_ids, text_mask)

        # =========================
        # GCN now TRAINABLE
        # =========================
        kg_feat = get_kg_embeddings(
            locations, gcn, X_nodes, A_hat_dict
        )

        fused_feat = fusion(
            ct_pooled, mri_pooled, text_feat, kg_feat
        )

        # =========================
        # Location BOS + decoding
        # =========================
        bos_ids = torch.tensor(
            [
                tokenizer.convert_tokens_to_ids(
                    LOCATION_TOKENS[loc]["bos"]
                )
                for loc in locations
            ],
            device=report_ids.device
        ).unsqueeze(1)

        report_ids = torch.cat([bos_ids, report_ids], dim=1)

        decoder_inputs = report_ids[:, :-1]
        targets = report_ids[:, 1:]

        logits = decoder(
            fused_feat,
            decoder_inputs,
            locations
        )

        loss = criterion(
            logits.reshape(-1, VOCAB_SIZE),
            targets.reshape(-1)
        )

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pbar.set_postfix(loss=f"{loss.item():.4f}")

    scheduler.step()
    return total_loss / len(train_loader)


In [65]:
STAGE3_EPOCHS = 10

for epoch in range(STAGE3_EPOCHS):
    loss = train_one_epoch_stage3(train_loader)
    print(
        f"[Stage-3] Epoch {epoch+1}/{STAGE3_EPOCHS} | "
        f"Loss: {loss:.4f}"
    )


Stage-3 Training (GCN + Fusion + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:44<00:00,  3.61it/s, loss=3.0517]


[Stage-3] Epoch 1/10 | Loss: 3.4898


Stage-3 Training (GCN + Fusion + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:42<00:00,  3.78it/s, loss=3.4704]


[Stage-3] Epoch 2/10 | Loss: 3.4488


Stage-3 Training (GCN + Fusion + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:42<00:00,  3.73it/s, loss=2.4937]


[Stage-3] Epoch 3/10 | Loss: 3.4347


Stage-3 Training (GCN + Fusion + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:42<00:00,  3.79it/s, loss=3.5297]


[Stage-3] Epoch 4/10 | Loss: 3.4031


Stage-3 Training (GCN + Fusion + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:38<00:00,  4.10it/s, loss=3.3280]


[Stage-3] Epoch 5/10 | Loss: 3.3705


Stage-3 Training (GCN + Fusion + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:41<00:00,  3.88it/s, loss=2.7290]


[Stage-3] Epoch 6/10 | Loss: 3.3389


Stage-3 Training (GCN + Fusion + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:40<00:00,  3.94it/s, loss=2.7925]


[Stage-3] Epoch 7/10 | Loss: 3.3392


Stage-3 Training (GCN + Fusion + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:43<00:00,  3.69it/s, loss=3.3145]


[Stage-3] Epoch 8/10 | Loss: 3.3207


Stage-3 Training (GCN + Fusion + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:43<00:00,  3.72it/s, loss=3.2163]


[Stage-3] Epoch 9/10 | Loss: 3.3245


Stage-3 Training (GCN + Fusion + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:44<00:00,  3.63it/s, loss=3.3888]

[Stage-3] Epoch 10/10 | Loss: 3.3063





In [66]:
import pandas as pd

generated_reports, ground_truth_reports, locations_all = run_inference(test_loader)


Generating reports: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:09<00:00,  1.02s/it]


In [67]:
results_df = pd.DataFrame({
    "location": locations_all,
    "generated_report": generated_reports,
    "ground_truth_report": ground_truth_reports
})

save_path = r"C:\fyp_manish_shyam_phase2\results\generated_vs_gt_reports_GCN_unfrozen.csv"
results_df.to_csv(save_path, index=False)

print(f"Saved results to: {save_path}")
print("Total samples:", len(results_df))


Saved results to: C:\fyp_manish_shyam_phase2\results\generated_vs_gt_reports_GCN_unfrozen.csv
Total samples: 34


In [68]:
# =========================
# Helper: Unfreeze ResNet layer4 only
# =========================

def unfreeze_resnet_layer4(resnet_model):
    # Freeze everything
    for p in resnet_model.parameters():
        p.requires_grad = False

    # Unfreeze layer4
    for p in resnet_model.layer4.parameters():
        p.requires_grad = True

    # Keep FC trainable (already replaced)
    for p in resnet_model.fc.parameters():
        p.requires_grad = True


In [69]:
# =========================
# STAGE 4: Partial Image Unfreeze
# =========================

unfreeze_resnet_layer4(ct_encoder.cnn)
unfreeze_resnet_layer4(mri_encoder.cnn)

# Text encoder stays frozen
freeze_module(text_encoder)

# GCN + Fusion + Decoder remain trainable
for p in gcn.parameters():
    p.requires_grad = True

for p in fusion.parameters():
    p.requires_grad = True

for p in decoder.parameters():
    p.requires_grad = True


In [70]:
print("=== Stage-4 Trainable Params Check ===")
count_trainable("CT Encoder", ct_encoder)
count_trainable("MRI Encoder", mri_encoder)
count_trainable("Text Encoder", text_encoder)
count_trainable("GCN", gcn)
count_trainable("Fusion", fusion)
count_trainable("Decoder", decoder)


=== Stage-4 Trainable Params Check ===
CT Encoder: 8,525,056 trainable params
MRI Encoder: 8,525,056 trainable params
Text Encoder: 0 trainable params
GCN: 1,192,448 trainable params
Fusion: 524,800 trainable params
Decoder: 25,056,837 trainable params


In [73]:
# =========================
# Optimizer for Stage 4
# =========================

stage4_params = []

stage4_params += [p for p in ct_encoder.parameters() if p.requires_grad]
stage4_params += [p for p in mri_encoder.parameters() if p.requires_grad]
stage4_params += [p for p in gcn.parameters() if p.requires_grad]
stage4_params += [p for p in fusion.parameters() if p.requires_grad]
stage4_params += [p for p in decoder.parameters() if p.requires_grad]

optimizer = torch.optim.AdamW(
    stage4_params,
    lr=1e-5,          # üîë VERY IMPORTANT: small LR
    weight_decay=1e-4
)

scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=5,
    gamma=0.5
)

print("Stage-4 optimizer ready")


Stage-4 optimizer ready


In [75]:
from tqdm import tqdm

def train_one_epoch_stage4(train_loader):
    ct_encoder.train()
    mri_encoder.train()
    gcn.train()
    fusion.train()
    decoder.train()

    text_encoder.eval()   # still frozen

    total_loss = 0.0

    pbar = tqdm(
        train_loader,
        desc="Stage-4 Training (Visual + KG + Decoder)",
        total=len(train_loader),
        leave=True
    )

    for batch in pbar:
        optimizer.zero_grad()

        # =========================
        # Move tensors
        # =========================
        ct_imgs = batch["ct_images"].to(DEVICE)
        ct_masks = batch["ct_masks"].to(DEVICE)

        mri_imgs = batch["mri_images"].to(DEVICE)
        mri_masks = batch["mri_masks"].to(DEVICE)

        text_ids = batch["text_input_ids"].to(DEVICE)
        text_mask = batch["text_attention_mask"].to(DEVICE)

        report_ids = batch["report_input_ids"].to(DEVICE)
        locations = batch["locations"]

        # =========================
        # Encode modalities
        # =========================
        ct_feats = ct_encoder(ct_imgs)
        mri_feats = mri_encoder(mri_imgs)

        ct_pooled = masked_mean_pooling(ct_feats, ct_masks)
        mri_pooled = masked_mean_pooling(mri_feats, mri_masks)

        with torch.no_grad():
            text_feat = text_encoder(text_ids, text_mask)

        kg_feat = get_kg_embeddings(
            locations, gcn, X_nodes, A_hat_dict
        )

        fused_feat = fusion(
            ct_pooled, mri_pooled, text_feat, kg_feat
        )

        # =========================
        # Location BOS + decoding
        # =========================
        bos_ids = torch.tensor(
            [
                tokenizer.convert_tokens_to_ids(
                    LOCATION_TOKENS[loc]["bos"]
                )
                for loc in locations
            ],
            device=report_ids.device
        ).unsqueeze(1)

        report_ids = torch.cat([bos_ids, report_ids], dim=1)

        decoder_inputs = report_ids[:, :-1]
        targets = report_ids[:, 1:]

        logits = decoder(
            fused_feat,
            decoder_inputs,
            locations
        )

        loss = criterion(
            logits.reshape(-1, VOCAB_SIZE),
            targets.reshape(-1)
        )

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pbar.set_postfix(loss=f"{loss.item():.4f}")

    scheduler.step()
    return total_loss / len(train_loader)


In [76]:
STAGE4_EPOCHS = 6

for epoch in range(STAGE4_EPOCHS):
    loss = train_one_epoch_stage4(train_loader)
    print(
        f"[Stage-4] Epoch {epoch+1}/{STAGE4_EPOCHS} | "
        f"Loss: {loss:.4f}"
    )


Stage-4 Training (Visual + KG + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:43<00:00,  3.71it/s, loss=4.4528]


[Stage-4] Epoch 1/6 | Loss: 3.3147


Stage-4 Training (Visual + KG + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:42<00:00,  3.75it/s, loss=3.2478]


[Stage-4] Epoch 2/6 | Loss: 3.2968


Stage-4 Training (Visual + KG + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:42<00:00,  3.76it/s, loss=3.9971]


[Stage-4] Epoch 3/6 | Loss: 3.2943


Stage-4 Training (Visual + KG + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:43<00:00,  3.67it/s, loss=3.4107]


[Stage-4] Epoch 4/6 | Loss: 3.2878


Stage-4 Training (Visual + KG + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:44<00:00,  3.60it/s, loss=2.2344]


[Stage-4] Epoch 5/6 | Loss: 3.2694


Stage-4 Training (Visual + KG + Decoder): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:45<00:00,  3.52it/s, loss=2.0949]

[Stage-4] Epoch 6/6 | Loss: 3.2679





In [77]:
import pandas as pd

generated_reports, ground_truth_reports, locations_all = run_inference(test_loader)


Generating reports: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:09<00:00,  1.01s/it]


In [78]:
results_df = pd.DataFrame({
    "location": locations_all,
    "generated_report": generated_reports,
    "ground_truth_report": ground_truth_reports
})

save_path = r"C:\fyp_manish_shyam_phase2\results\generated_vs_gt_reports_encoder_unfrozen.csv"
results_df.to_csv(save_path, index=False)

print(f"Saved results to: {save_path}")
print("Total samples:", len(results_df))


Saved results to: C:\fyp_manish_shyam_phase2\results\generated_vs_gt_reports_encoder_unfrozen.csv
Total samples: 34
