In [1]:
!pip install lightning
!pip install -U transformers

Collecting lightning
  Downloading lightning-2.6.0-py3-none-any.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.9/44.9 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<4.0,>=2.1.0->lightning)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<4.0,>=2.1.0->lightning)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<4.0,>=2.1.0->lightning)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<4.0,>=2.1.0->lightning)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch<4.0,>=2.1.0->lightning)
  Downloading

In [2]:
from huggingface_hub import snapshot_download

# Mirror the fine-tuned model repository from the Hugging Face Hub into
# ./lightmedvlm; the call's return value (the local path) is shown as the cell output.
snapshot_download(
    local_dir="lightmedvlm",
    repo_id="huyhoangt2201/lightmedvlm-mimic-phase3-vqa-reduced",
)

Fetching 8 files:   0%|          | 0/8 [00:00<?, ?it/s]

logs/tensorboard/version_0/events.out.tf(…):   0%|          | 0.00/12.0k [00:00<?, ?B/s]

checkpoints/epoch=3-step=2468-loss=0.078(…):   0%|          | 0.00/2.31G [00:00<?, ?B/s]

val_pred.json: 0.00B [00:00, ?B/s]

val_ref.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

hparams.yaml: 0.00B [00:00, ?B/s]

.gitattributes: 0.00B [00:00, ?B/s]

metrics.csv: 0.00B [00:00, ?B/s]

hparams.yaml: 0.00B [00:00, ?B/s]

'/kaggle/working/lightmedvlm'

In [3]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning.pytorch as pl
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import SwinModel
from peft import get_peft_model, LoraConfig, TaskType
import torch.distributed as dist
from transformers import BertTokenizer, AutoImageProcessor
from PIL import Image
import numpy as np

class MLP(nn.Module):
    """Two-layer perceptron: Linear -> GELU -> Dropout -> Linear.

    Args:
        in_dim: Size of the last dimension of the input.
        inter_dim: Width of the hidden layer.
        out_dim: Size of the last dimension of the output.
        dropout: Dropout probability applied after the activation
            (defaults to 0.1, the value that used to be hard-coded).
    """

    def __init__(self, in_dim, inter_dim, out_dim, dropout=0.1):
        super().__init__()
        # Submodule names and creation order are kept stable so that existing
        # checkpoints (state_dict keys) and seeded initialisation still match.
        self.hidden_1 = nn.Linear(in_dim, inter_dim)
        self.act = nn.GELU()
        self.hidden_2 = nn.Linear(inter_dim, out_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Project ``x`` along its last dimension from ``in_dim`` to ``out_dim``."""
        hidden = self.dropout(self.act(self.hidden_1(x)))
        return self.hidden_2(hidden)


class LightMedVLMForInference(pl.LightningModule):
    """Vision-language model: Swin image encoder -> MLP projector -> causal LLM.

    Image features from the vision encoder are projected to the LLM hidden size,
    layer-normalised, spliced into a Qwen-style chat prompt (system prompt +
    user question + image + assistant header) and fed to the LLM as
    ``inputs_embeds``.
    """

    # Question used whenever the caller does not supply one.
    DEFAULT_QUESTION = "Describe the following image in detail."

    def __init__(
        self,
        vision_model: str = "microsoft/swin-base-patch4-window7-224",
        llm_model: str = "Qwen/Qwen3-0.6B",

        # For training setup
        vis_use_lora: bool = False,
        vis_r: int = 8,
        vis_alpha: int = 16,
        freeze_vm: bool = False,
        llm_use_lora: bool = False,
        llm_r: int = 8,
        llm_alpha: int = 16,
        lora_dropout: float = 0.1,
        low_resource: bool = False,
        max_length: int = 256
    ):
        """Build the vision encoder, the LLM, the projector and the prompt template.

        Args:
            vision_model: Hub id / path passed to ``SwinModel.from_pretrained``.
            llm_model: Hub id / path passed to the tokenizer and causal-LM loaders.
            vis_use_lora: Wrap the vision encoder with LoRA adapters.
            vis_r: LoRA rank for the vision encoder.
            vis_alpha: LoRA alpha for the vision encoder.
            freeze_vm: Freeze the vision encoder (only consulted when
                ``vis_use_lora`` is False).
            llm_use_lora: Wrap the LLM with LoRA adapters; otherwise all LLM
                parameters are frozen.
            llm_r: LoRA rank for the LLM.
            llm_alpha: LoRA alpha for the LLM.
            lora_dropout: Dropout used by both LoRA configurations.
            low_resource: Load the LLM in 8-bit with ``device_map="auto"``.
            max_length: Length the target text is padded/truncated to in ``forward``.
        """
        super().__init__()
        self.vision_model = vision_model
        self.llm_model = llm_model
        self.vis_use_lora = vis_use_lora
        self.vis_r = vis_r
        self.vis_alpha = vis_alpha
        self.freeze_vm = freeze_vm
        self.llm_use_lora = llm_use_lora
        self.llm_r = llm_r
        self.llm_alpha = llm_alpha
        self.lora_dropout = lora_dropout
        self.low_resource = low_resource
        self.max_length = max_length

        # Vision encoder setup
        print(f'Loading vision encoder: {self.vision_model}')
        self.visual_encoder = SwinModel.from_pretrained(self.vision_model)
        self.vit_feature_extractor = AutoImageProcessor.from_pretrained(self.vision_model)
        if self.vis_use_lora:
            peft_config_visual = LoraConfig(
                r=self.vis_r,
                lora_alpha=self.vis_alpha,
                target_modules=["query", "value"],
                lora_dropout=self.lora_dropout,
                bias="none",
                modules_to_save=["classifier"],
            )
            self.visual_encoder = get_peft_model(self.visual_encoder, peft_config_visual)
            self.visual_encoder.print_trainable_parameters()
            print('Loading vision encoder with LoRA -- Done')
        elif self.freeze_vm:
            for param in self.visual_encoder.parameters():
                param.requires_grad = False
            print(f'Loading Frozen vision encoder: {self.vision_model} -- Done')
        else:
            print(f'Loading Trainable vision encoder: {self.vision_model} -- Done')

        # LLM model setup
        print('Loading LLM model')
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.llm_model,
            use_fast=False,
            trust_remote_code=True
        )
        print(f"BOS token ID: {self.tokenizer.bos_token_id}")
        print(f"EOS token ID: {self.tokenizer.eos_token_id}")
        print(f"PAD token ID: {self.tokenizer.pad_token_id}")
        if self.low_resource:
            self.model = AutoModelForCausalLM.from_pretrained(
                self.llm_model,
                torch_dtype=torch.bfloat16,
                load_in_8bit=True,
                device_map="auto",
                trust_remote_code=True
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                self.llm_model,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True
            )

        # Keep a handle on the token-embedding layer; it is grabbed before any
        # LoRA wrapping, exactly as both original branches did.
        self.embed_tokens = self.model.get_input_embeddings()
        if self.llm_use_lora:
            peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                r=self.llm_r,
                lora_alpha=self.llm_alpha,
                lora_dropout=self.lora_dropout,
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
            )
            self.model = get_peft_model(self.model, peft_config)
            self.model.print_trainable_parameters()
            print('Loading LLM LoRA Done')
        else:
            for param in self.model.parameters():
                param.requires_grad = False
            print('Loading LLM Done')

        # Projector setup: vision feature size -> LLM hidden size
        self.proj = MLP(
            in_dim=self.visual_encoder.num_features,
            inter_dim=2048,
            out_dim=self.model.config.hidden_size
        )
        self.layer_norm = nn.LayerNorm(self.model.config.hidden_size)

        # System prompt setup
        self.end_sym = "<|im_end|>"
        # System prompt for VQA task
        self.system_prompt = "<|im_start|>system\nYou are a professional radiologist. Please answer the question based on the chest X-ray image and choose from the following two options: [yes, no].<|im_end|>\n"

        self.val_step_outputs = []
        self.test_step_outputs = []
        self.val_score = 0.0

    def encode_img(self, images):
        """Encode one or more image tensors and project them to the LLM width.

        Each element of ``images`` is passed to the vision encoder on its own;
        the resulting ``last_hidden_state`` tensors are stacked and averaged
        over the stacking dimension before projection.

        Args:
            images: Sequence of image tensors accepted by the vision encoder
                (``inference`` passes tensors with a leading batch dim of 1).

        Returns:
            Tuple ``(inputs, atts)``: the projected embeddings and an all-ones
            ``long`` mask covering every dimension of ``inputs`` but the last.
        """
        per_image = [
            self.visual_encoder(image)['last_hidden_state'] for image in images
        ]
        image_embeds = torch.stack(per_image).mean(0)

        inputs = self.proj(image_embeds)
        # Take the device from the tensor actually produced instead of a loop
        # variable that only exists when ``images`` is non-empty.
        atts = torch.ones(inputs.size()[:-1], dtype=torch.long, device=inputs.device)
        return inputs, atts

    def prompt_wrap(self, img_embeds, atts_img, questions):
        """
        Wrap image embeddings with Qwen-style prompt including the question.
        Format: {system_prompt} <user_start> {question} <image> <user_end> <assistant_start>

        Sequences shorter than the longest one in the batch are right-padded
        with zero embeddings and a zero attention mask.

        Args:
            img_embeds: Image embeddings, indexed along dim 0 per batch item.
            atts_img: Attention mask for images (only its dtype is used here).
            questions: List of questions (one per batch item); ``None`` entries
                fall back to ``DEFAULT_QUESTION``.

        Returns:
            Tuple ``(wrapped_embeds, wrapped_atts)`` for the whole batch.
        """
        device = img_embeds.device
        # The text after the image slot is the same for every item.
        p_after = "<|im_end|>\n<|im_start|>assistant\n"

        wrapped_embeds_list = []
        for idx in range(img_embeds.shape[0]):
            question = questions[idx] if questions[idx] is not None else self.DEFAULT_QUESTION

            # Build the two prompt halves directly rather than formatting a
            # string with an "<image>" placeholder and splitting on it: a
            # question that itself contains "<image>" would otherwise yield
            # more than two parts and fail to unpack.
            p_before = f"{self.system_prompt}<|im_start|>user\n{question} "

            p_before_tokens = self.tokenizer(
                p_before,
                return_tensors="pt",
                add_special_tokens=False
            ).to(device)
            p_after_tokens = self.tokenizer(
                p_after,
                return_tensors="pt",
                add_special_tokens=False
            ).to(device)

            # Prompt text is embedded without tracking gradients.
            with torch.no_grad():
                p_before_embeds = self.embed_tokens(p_before_tokens.input_ids)
                p_after_embeds = self.embed_tokens(p_after_tokens.input_ids)

            # Concatenate: [prompt_before] + [image] + [prompt_after]
            wrapped_embeds_list.append(
                torch.cat([p_before_embeds, img_embeds[idx:idx + 1], p_after_embeds], dim=1)
            )

        # Find max sequence length in the batch
        max_seq_len = max(embeds.shape[1] for embeds in wrapped_embeds_list)

        padded_embeds_list = []
        padded_atts_list = []
        for embeds in wrapped_embeds_list:
            seq_len = embeds.shape[1]
            atts = torch.ones(seq_len, device=device, dtype=atts_img.dtype)
            padding_size = max_seq_len - seq_len
            if padding_size > 0:
                # Pad embeddings with zeros and mark those positions as masked.
                embeds = torch.cat(
                    [embeds, embeds.new_zeros(embeds.shape[0], padding_size, embeds.shape[2])],
                    dim=1,
                )
                atts = torch.cat([atts, atts.new_zeros(padding_size)], dim=0)
            padded_embeds_list.append(embeds)
            padded_atts_list.append(atts)

        wrapped_img_embeds = torch.cat(padded_embeds_list, dim=0)
        wrapped_atts_img = torch.stack(padded_atts_list, dim=0)
        return wrapped_img_embeds, wrapped_atts_img

    def forward(self, samples):
        """Compute the language-modelling loss for a batch.

        Args:
            samples: Mapping with keys ``"image"`` (sequence of image tensors),
                ``"text"`` (list of target strings) and optionally
                ``"question"`` (list of question strings).

        Returns:
            ``{"loss": loss}`` where ``loss`` is the LLM loss computed on the
            target-text tokens only (prompt, image and padding are ignored).
        """
        image = samples["image"]
        questions = samples.get("question", [None] * len(samples["text"]))
        device = image[0].device

        img_embeds, atts_img = self.encode_img(image)
        img_embeds = self.layer_norm(img_embeds)

        # Wrap image with prompt (now includes question)
        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, questions)

        self.tokenizer.padding_side = "right"
        text = [t + self.end_sym for t in samples["text"]]

        # Tokenize target text.
        # ``save_hyperparameters()`` is never called in ``__init__``, so
        # ``self.hparams`` has no ``max_length``; use the attribute set there.
        to_regress_tokens = self.tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            add_special_tokens=False
        ).to(device)

        # Labels: padding positions are ignored by the loss (-100).
        targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == self.tokenizer.pad_token_id, -100
        )

        # Prompt + image positions are ignored by the loss as well.
        empty_targets = torch.full(
            (atts_img.shape[0], atts_img.shape[1]), -100, dtype=torch.long, device=device
        )
        targets = torch.cat([empty_targets, targets], dim=1)

        # Get text embeddings
        with torch.no_grad():
            to_regress_embeds = self.embed_tokens(to_regress_tokens.input_ids)

        # Concatenate all embeddings: [prompt+image] + [text]
        inputs_embeds = torch.cat([img_embeds, to_regress_embeds], dim=1)
        attention_mask = torch.cat([atts_img, to_regress_tokens.attention_mask], dim=1)

        # Forward through LLM
        outputs = self.model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
            use_cache=False
        )
        return {"loss": outputs.loss}

    def training_step(self, batch, batch_idx):
        """Run ``forward`` on the batch, log the result and return it."""
        result = self(batch)
        self.log_dict(result, prog_bar=True)
        return result

    def save_checkpoint(self, eval_res):
        """Save only the trainable parameters plus bookkeeping to a ``.pth`` file.

        Args:
            eval_res: Mapping with keys ``'ROUGE_L'``, ``'Bleu_4'`` and
                ``'CIDEr'``; the values are embedded in the file name.
        """
        current_epoch, global_step = self.trainer.current_epoch, self.trainer.global_step
        trainable_names = {
            name for name, param in self.named_parameters() if param.requires_grad
        }
        state_dict = self.state_dict()
        for key in list(state_dict.keys()):
            if key not in trainable_names:
                del state_dict[key]

        save_obj = {
            "state_dict": state_dict,
            "hyper_parameters": self.hparams,
            "pytorch-lightning_version": pl.__version__,
            "epoch": current_epoch,
            "global_step": global_step,
        }
        # ``savedmodel_path`` is not an ``__init__`` argument, so it is normally
        # absent from ``self.hparams``; fall back to the trainer's root directory
        # instead of raising AttributeError.
        save_root = getattr(self.hparams, "savedmodel_path", None) or self.trainer.default_root_dir
        checkpoint_dir = os.path.join(save_root, 'checkpoints')
        os.makedirs(checkpoint_dir, exist_ok=True)
        # NOTE(review): the file-name pattern is kept byte-for-byte ("rougle_l",
        # "{:3f}" = width 3, not 3 decimals) in case other tooling matches on it.
        save_to = os.path.join(
            checkpoint_dir,
            "checkpoint_epoch{}_step{}_rougle_l{:3}_bleu{:3f}_cider{:3f}.pth".format(
                current_epoch, global_step, eval_res['ROUGE_L'],eval_res['Bleu_4'], eval_res['CIDEr']
            ),
        )
        self.print("Saving checkpoint at step {} to {}.".format(global_step, save_to))
        torch.save(save_obj, save_to)

    def decode(self, output_token):
        """Decode one sequence of generated token ids to clean text.

        A leading PAD and/or BOS token is dropped, the text is cut at the first
        ``end_sym`` and any remaining Qwen special-token strings are removed.
        """
        pad_id = self.tokenizer.pad_token_id
        bos_id = self.tokenizer.bos_token_id  # may be None for this tokenizer

        # Remove special tokens at the beginning
        if len(output_token) > 0 and pad_id is not None and output_token[0] == pad_id:
            output_token = output_token[1:]
        if len(output_token) > 0 and bos_id is not None and output_token[0] == bos_id:
            output_token = output_token[1:]

        # Special tokens are deliberately kept in the decoded string so the
        # split on ``end_sym`` below can find the end of the answer.
        # (``add_special_tokens`` is an encoding option, not a decode option.)
        output_text = self.tokenizer.decode(output_token, skip_special_tokens=False)

        # Split at end symbol and clean up
        output_text = output_text.split(self.end_sym)[0].strip()

        # Remove Qwen special tokens
        for marker in ('<|im_start|>', '<|im_end|>', '<|endoftext|>', '<unk>'):
            output_text = output_text.replace(marker, '')

        return output_text

    def _parse_image(self, img):
        """Run the image processor on ``img`` and return the first ``pixel_values`` entry."""
        pixel_values = self.vit_feature_extractor(img, return_tensors="pt").pixel_values
        return pixel_values[0]

    @torch.no_grad()
    def inference(self, image_paths, question=None, beam_size=1, do_sample=False,
                 min_new_tokens=1, max_new_tokens=100, repetition_penalty=1.0,
                 length_penalty=1.0, temperature=1.0):
        """
        Generate answer from images and question.

        Args:
            image_paths: List of image paths (a single path is also accepted).
                The features of all listed images are averaged into one input.
            question: Question text (optional, defaults to general description)
            beam_size, do_sample, etc.: Generation parameters forwarded to
                ``self.model.generate``.

        Returns:
            List of decoded strings, one per generated sequence.
        """
        self.eval()

        # A bare string would otherwise be iterated character by character.
        if isinstance(image_paths, (str, os.PathLike)):
            image_paths = [image_paths]

        device = next(self.parameters()).device
        images = []
        for image_path in image_paths:
            with Image.open(image_path) as pil:
                # Always hand the processor a 3-channel uint8 array.
                array = np.array(pil.convert("RGB"), dtype=np.uint8)
            images.append(self._parse_image(array).unsqueeze(0).to(device))

        img_embeds, atts_img = self.encode_img(images)
        img_embeds = self.layer_norm(img_embeds)

        # Match the LLM's parameter dtype before mixing with its token embeddings.
        img_embeds = img_embeds.to(self.model.dtype)

        # Use the question in the prompt
        if question is None:
            question = self.DEFAULT_QUESTION

        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, [question])

        outputs = self.model.generate(
            inputs_embeds=img_embeds,
            attention_mask=atts_img,
            num_beams=beam_size,
            do_sample=do_sample,
            min_new_tokens=min_new_tokens,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

        return [self.decode(sequence) for sequence in outputs]

2025-12-12 02:24:25.591792: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765506266.001422      47 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765506266.119632      47 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

In [4]:
from lightning.fabric.utilities.data import AttributeDict

# Allow torch's safe unpickler to reconstruct Lightning's AttributeDict
# (presumably how the checkpoint stores its hyper-parameters — confirm if loading fails).
torch.serialization.add_safe_globals([AttributeDict])

# Path, relative to the working directory, of the Lightning .ckpt file inside
# the "lightmedvlm" folder downloaded earlier in this notebook.
ckpt_file = "lightmedvlm/checkpoints/epoch=3-step=2468-loss=0.0781.ckpt"
args = {
    "vision_model": "microsoft/swin-base-patch4-window7-224",
    "llm_model": "Qwen/Qwen3-0.6B"
}
# strict=False: keys missing from / unexpected in the checkpoint are tolerated.
model = LightMedVLMForInference.load_from_checkpoint(ckpt_file, strict=False, **args)

# Use the GPU when one is available instead of crashing on CPU-only kernels.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()

Loading vision encoder: microsoft/swin-base-patch4-window7-224


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/352M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/255 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loading Trainable vision encoder: microsoft/swin-base-patch4-window7-224 -- Done
Loading LLM model


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

BOS token ID: None
EOS token ID: 151645
PAD token ID: 151643


config.json:   0%|          | 0.00/726 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/1.50G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

trainable params: 2,293,760 || all params: 598,343,680 || trainable%: 0.3834
Loading LLM LoRA Done


LightMedVLMForInference(
  (visual_encoder): SwinModel(
    (embeddings): SwinEmbeddings(
      (patch_embeddings): SwinPatchEmbeddings(
        (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): SwinEncoder(
      (layers): ModuleList(
        (0): SwinStage(
          (blocks): ModuleList(
            (0): SwinLayer(
              (layernorm_before): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
              (attention): SwinAttention(
                (self): SwinSelfAttention(
                  (query): Linear(in_features=128, out_features=128, bias=True)
                  (key): Linear(in_features=128, out_features=128, bias=True)
                  (value): Linear(in_features=128, out_features=128, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
                )
                (output): SwinSelf

In [12]:
# Smoke test: one image and one question per call; the returned list of
# answers is displayed as the cell output.
model.inference(
    image_paths=[
        "/kaggle/input/mimic-700-images/iu_images/iu_images/CXR3030_IM-1405/0.png"
    ],
    question="Does the cardiomediastinal silhouette appear normal in the chest X-ray?",
)

['Yes, the cardiomediastinal silhouette appears normal in the chest X-ray.']

## Inference on the IU X-Ray VQA test set

In [6]:
from tqdm import tqdm
TEST_ANNOT = "/kaggle/input/mimic-700-images/test/test/vqa/iuxray_test.jsonl"
TEST_IMG_DIR = "/kaggle/input/mimic-700-images/iu_images/iu_images"

# Load the JSONL annotations (one JSON object per line). Blank lines are
# skipped so a trailing newline cannot crash json.loads, and the encoding is
# explicit so the result does not depend on the platform default.
with open(TEST_ANNOT, "r", encoding="utf-8") as f:
    lines = [line for line in f if line.strip()]

print("Số lượng test samples:", len(lines))

results = []

for line in tqdm(lines):
    item = json.loads(line)

    img_path = os.path.join(TEST_IMG_DIR, item["image"])

    # Strip the "<image>" placeholder from the question text; the model
    # inserts the image embeddings into the prompt itself.
    raw_question = item["question"]
    question = (
        raw_question
        .replace("\n<image>", "")
        .replace("<image>", "")
        .strip()
    )

    gt_answer = item["answer"]

    # Run the model on a single image/question pair and keep the first answer.
    pred = model.inference(
        image_paths=[img_path],
        question=question
    )[0]

    # Collect everything needed to score the prediction later.
    results.append({
        "id": item["question_id"],
        "image": item["image"],
        "question": question,
        "gt_answer": gt_answer,
        "pred_answer": pred
    })

OUT_FILE = "iuxray_vqa_predictions.json"

# ensure_ascii=False writes non-ASCII characters verbatim, so the file must be
# opened with an explicit UTF-8 encoding.
with open(OUT_FILE, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=4, ensure_ascii=False)

Số lượng test samples: 2573


  0%|          | 0/2573 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
100%|██████████| 2573/2573 [59:09<00:00,  1.38s/it]  


## Inference on the MIMIC-CXR VQA test set


In [5]:
from tqdm import tqdm
TEST_ANNOT = "/kaggle/input/mimic-700-images/test/test/vqa/mimic_test.jsonl"
TEST_IMG_DIR = "/kaggle/input/mimic-700-images/mimic_cxr_selected_224/mimic_cxr_selected_224"

# Load the JSONL annotations (one JSON object per line). Blank lines are
# skipped so a trailing newline cannot crash json.loads, and the encoding is
# explicit so the result does not depend on the platform default.
with open(TEST_ANNOT, "r", encoding="utf-8") as f:
    lines = [line for line in f if line.strip()]

print("Số lượng test samples:", len(lines))

results = []

for line in tqdm(lines):
    item = json.loads(line)

    img_path = os.path.join(TEST_IMG_DIR, item["image"])

    # Strip the "<image>" placeholder from the question text; the model
    # inserts the image embeddings into the prompt itself.
    raw_question = item["question"]
    question = (
        raw_question
        .replace("\n<image>", "")
        .replace("<image>", "")
        .strip()
    )

    gt_answer = item["answer"]

    # Run the model on a single image/question pair and keep the first answer.
    pred = model.inference(
        image_paths=[img_path],
        question=question
    )[0]

    # Collect everything needed to score the prediction later.
    results.append({
        "id": item["question_id"],
        "image": item["image"],
        "question": question,
        "gt_answer": gt_answer,
        "pred_answer": pred
    })

OUT_FILE = "mimic_vqa_predictions.json"

# ensure_ascii=False writes non-ASCII characters verbatim, so the file must be
# opened with an explicit UTF-8 encoding.
with open(OUT_FILE, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=4, ensure_ascii=False)

Số lượng test samples: 3470


  0%|          | 0/3470 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
100%|██████████| 3470/3470 [1:08:51<00:00,  1.19s/it]
