In [1]:
!pip install -q torch torchvision torchaudio
!pip install -q transformers datasets accelerate Pillow huggingface_hub
!pip install -q bert_score
!pip install -q --no-deps peft

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import os
import json
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPVisionModel, AutoImageProcessor
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from huggingface_hub import snapshot_download
import gc
from peft import LoraConfig, get_peft_model, TaskType
import warnings
from tqdm.auto import tqdm
import logging
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import Callback

warnings.filterwarnings('ignore')
os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
logging.getLogger('transformers').setLevel(logging.ERROR)

def apply_lora_to_qwen(
    qwen,
    r=8,
    alpha=16,
    dropout=0.05
):
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=r,
        lora_alpha=alpha,
        lora_dropout=dropout,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj"
        ]
    )

    qwen = get_peft_model(qwen, lora_config)
    qwen.print_trainable_parameters()
    return qwen

def load_qwen_frozen():
    model_name = "Qwen/Qwen1.5-0.5B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    qwen = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        device_map="cuda:0",
        load_in_8bit=False,
        low_cpu_mem_usage=True
    )
    qwen.config.pad_token_id = tokenizer.pad_token_id
    for p in qwen.parameters():
        p.requires_grad = False
    qwen.gradient_checkpointing_enable()
    qwen.enable_input_require_grads()
    return tokenizer, qwen

class VqaradDataset(Dataset):
    def __init__(self, image_processor, split='train', hf_repo_id='flaviagiammarino/vqa-rad', max_samples=None):
        self.image_processor = image_processor
        self.dataset = load_dataset(hf_repo_id, split=split, streaming=False)
        self.dataset = list(self.dataset)
        if max_samples is not None:
            self.dataset = self.dataset[:max_samples]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item['image'].convert("RGB")
        image = image.resize((224, 224), Image.LANCZOS)
        processed_output = self.image_processor(image, return_tensors="pt")
        if hasattr(processed_output, 'pixel_values'):
            processed_image = processed_output.pixel_values
        else:
            processed_image = processed_output
        if processed_image.dim() == 4:
            processed_image = processed_image.squeeze(0)
        image.close()
        del image
        return {
            "image": processed_image,
            "question": item['question'],
            "answer": item['answer']
        }

class VisionMLP(nn.Module):
    def __init__(self, vision_encoder, encoder_output_dim, mlp_output_dim, hidden_dim=1280):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.mlp = nn.Sequential(
            nn.Linear(encoder_output_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, mlp_output_dim),
            nn.LayerNorm(mlp_output_dim)
        )
        self._init_weights()

    def _init_weights(self):
        for module in self.mlp.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, pixel_values):
        with torch.no_grad():
            outputs = self.vision_encoder(pixel_values, output_hidden_states=False)
            patch_tokens = outputs.last_hidden_state[:, 1:, :]
            if torch.isnan(patch_tokens).any():
                patch_tokens = torch.nan_to_num(patch_tokens, nan=0.0)
        original_dtype = patch_tokens.dtype
        patch_tokens_fp32 = patch_tokens.float()
        mlp_output = self.mlp(patch_tokens_fp32)
        mlp_output = torch.clamp(mlp_output, min=-10.0, max=10.0)
        if torch.isnan(mlp_output).any():
            mlp_output = torch.nan_to_num(mlp_output, nan=0.0)
        return mlp_output.to(original_dtype)

def get_model(encoder_choice, mlp_output_dim, hidden_dim=1280):
    device = "cuda:0"
    model_paths = {
        'standard_clip': 'openai/clip-vit-base-patch32',
        'pubmedclip': 'flaviagiammarino/pubmed-clip-vit-base-patch32',
    }
    vision_encoder = CLIPVisionModel.from_pretrained(model_paths[encoder_choice]).to(device)
    image_processor = AutoImageProcessor.from_pretrained(model_paths[encoder_choice])
    encoder_output_dim = vision_encoder.config.hidden_size
    model = VisionMLP(
        vision_encoder=vision_encoder,
        encoder_output_dim=encoder_output_dim,
        mlp_output_dim=mlp_output_dim,
        hidden_dim=hidden_dim
    ).to(device)
    for param in model.vision_encoder.parameters():
        param.requires_grad = False
    return model, image_processor

class MultiModalVQA(nn.Module):
    def __init__(self, vision_mlp_model, qwen, tokenizer):
        super().__init__()
        self.vision = vision_mlp_model
        self.qwen = qwen
        self.tokenizer = tokenizer
        self.device = torch.device("cuda:0")

    def get_base_model(self):
        if hasattr(self.qwen, "get_input_embeddings"):
            return self.qwen
        if hasattr(self.qwen, "base_model"):
            return self.qwen.base_model
        raise RuntimeError("Cannot locate base language model")

    def forward(self, images, questions, answers=None):
        device = self.device
        dtype = self.qwen.dtype
        patch_embeds = self.vision(images)
        if torch.isnan(patch_embeds).any():
            patch_embeds = torch.nan_to_num(patch_embeds, nan=0.0)
        patch_embeds = patch_embeds.to(dtype).to(device)
        B, P, H = patch_embeds.shape
        q_tok = self.tokenizer(
            questions,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=64
        ).to(device)
        

        base_model = self.get_base_model()
        text_embeds = base_model.get_input_embeddings()(q_tok.input_ids)

        
        if torch.isnan(text_embeds).any():
            text_embeds = torch.nan_to_num(text_embeds, nan=0.0)
        text_mask = q_tok.attention_mask.to(device)
        inputs_embeds = torch.cat([patch_embeds, text_embeds], dim=1)
        patch_mask = torch.ones((B, P), device=device, dtype=text_mask.dtype)
        attention_mask = torch.cat([patch_mask, text_mask], dim=1)
        labels = None
        if answers is not None:
            a_tok = self.tokenizer(
                answers,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=32
            ).to(device)
            labels_list = []
            for i in range(B):
                vis_pad = torch.full((P,), -100, device=device, dtype=torch.long)
                q_len = text_mask[i].sum().item()
                q_pad = torch.full((q_len,), -100, device=device, dtype=torch.long)
                ans_len = (a_tok.input_ids[i] != self.tokenizer.pad_token_id).sum().item()
                ans_tokens = a_tok.input_ids[i, :ans_len]
                label_seq = torch.cat([vis_pad, q_pad, ans_tokens], dim=0)
                labels_list.append(label_seq)
            max_len = max(len(l) for l in labels_list)
            labels = torch.full((B, max_len), -100, device=device, dtype=torch.long)
            for i, label_seq in enumerate(labels_list):
                labels[i, :len(label_seq)] = label_seq
            a_embeds = base_model.get_input_embeddings()(a_tok.input_ids).to(dtype).to(device)
            if torch.isnan(a_embeds).any():
                a_embeds = torch.nan_to_num(a_embeds, nan=0.0)
            a_mask = (a_tok.input_ids != self.tokenizer.pad_token_id).long().to(device)
            inputs_embeds = torch.cat([inputs_embeds, a_embeds], dim=1)
            attention_mask = torch.cat([attention_mask, a_mask], dim=1)
            if labels.size(1) != inputs_embeds.size(1):
                if labels.size(1) > inputs_embeds.size(1):
                    pad_len = labels.size(1) - inputs_embeds.size(1)
                    pad_emb = torch.zeros(B, pad_len, H, device=device, dtype=dtype)
                    inputs_embeds = torch.cat([inputs_embeds, pad_emb], dim=1)
                    pad_mask = torch.zeros(B, pad_len, device=device, dtype=attention_mask.dtype)
                    attention_mask = torch.cat([attention_mask, pad_mask], dim=1)
                elif inputs_embeds.size(1) > labels.size(1):
                    pad_len = inputs_embeds.size(1) - labels.size(1)
                    pad_labels = torch.full((B, pad_len), -100, device=device, dtype=torch.long)
                    labels = torch.cat([labels, pad_labels], dim=1)
        if torch.isnan(inputs_embeds).any():
            inputs_embeds = torch.nan_to_num(inputs_embeds, nan=0.0)
        out = self.qwen(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels
        )
        return out

class LitMultiModalVQA(pl.LightningModule):
    def __init__(
        self,
        model: MultiModalVQA,
        lr: float,
        stage: str = "mlp"
    ):
        super().__init__()
        self.add_module("model", model)  # Register as submodule
        self.lr = lr
        self.stage = stage
        self.save_hyperparameters(ignore=["model"])

    def forward(self, images, questions, answers=None):
        return self.model(images, questions, answers)

    def training_step(self, batch, batch_idx):
        images = batch["image"]
        questions = batch["question"]
        answers = batch["answer"]

        out = self(images, questions, answers)
        loss = out.loss

        self.log(
            "train_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            batch_size=images.size(0)
        )
        return loss

    def validation_step(self, batch, batch_idx):
        images = batch["image"]
        questions = batch["question"]
        answers = batch["answer"]

        out = self(images, questions, answers)
        loss = out.loss

        self.log(
            "val_loss",
            loss,
            on_epoch=True,
            prog_bar=True,
            batch_size=images.size(0)
        )
        return loss

    def configure_optimizers(self):
        if self.stage == "mlp":
            params = self.model.vision.mlp.parameters()
        else:
            params = self.model.qwen.parameters()

        optimizer = torch.optim.AdamW(
            params,
            lr=self.lr,
            weight_decay=0.01
        )
        return optimizer

def get_dataloader(dataset_choice, image_processor, batch_size=2, max_samples=None):
    if dataset_choice == 'vqa_rad':
        dataset = VqaradDataset(image_processor=image_processor, split='train', max_samples=max_samples)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=False
    )

def save_checkpoint(model, epoch, checkpoint_dir="checkpoints", stage="mlp"):
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f"{stage}_epoch_{epoch}.pt")
    
    if stage == "mlp":
        torch.save({
            'epoch': epoch,
            'vision_mlp_state_dict': model.vision.state_dict(),
            'model_state_dict': model.state_dict(),
        }, checkpoint_path)
    else:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
        }, checkpoint_path)
    
    print(f"Checkpoint saved: {checkpoint_path}")

class BertF1Callback(Callback):
    def __init__(self, test_dataloader, device="cuda:0", num_examples=3):
        super().__init__()
        self.test_dataloader = test_dataloader
        self.device = device
        self.num_examples = num_examples

    @torch.no_grad()
    def on_epoch_end(self, trainer, pl_module):
        pl_module.eval()
        model = pl_module.model
        preds, refs, questions = [], [], []

        for batch_idx, batch in enumerate(self.test_dataloader):
            images = batch["image"].to(self.device)
            batch_questions = batch["question"]
            batch_answers = batch["answer"]
            batch_preds = generate_answer(model, images, batch_questions, device=self.device)
            preds.extend(batch_preds)
            refs.extend(batch_answers)
            questions.extend(batch_questions)
            if len(preds) >= self.num_examples:
                break

        print("\nSample predictions (Epoch {}):".format(trainer.current_epoch + 1))
        for i in range(min(self.num_examples, len(preds))):
            print(f"\n  Example {i + 1}:")
            print(f"    Question:   {questions[i]}")
            print(f"    Predicted:  {preds[i]}")
            print(f"    Reference:  {refs[i]}")

        pl_module.train()

class CheckpointEveryEpoch(Callback):
    def __init__(self, stage):
        super().__init__()
        self.stage = stage

    def on_train_epoch_end(self, trainer, pl_module):
        save_checkpoint(pl_module.model, trainer.current_epoch + 1, stage=self.stage)

@torch.no_grad()
def generate_answer(model, images, questions, max_length=32, device="cuda:0"):
    model.eval()
    batch_size = images.size(0)
    results = []

    base_model = model.get_base_model()

    for i in range(batch_size):
        single_image = images[i:i+1]

        single_image = single_image.to(device=device, dtype=next(model.vision.parameters()).dtype)
        if torch.isnan(single_image).any():
            single_image = torch.nan_to_num(single_image, nan=0.0)

        patch_embeds = model.vision(single_image)
        patch_embeds = patch_embeds.to(model.qwen.dtype).to(device)
        if torch.isnan(patch_embeds).any():
            patch_embeds = torch.nan_to_num(patch_embeds, nan=0.0)

        q_tok = model.tokenizer(
            [questions[i]],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=64
        ).to(device)

        text_embeds = base_model.get_input_embeddings()(q_tok.input_ids).to(model.qwen.dtype)
        inputs_embeds = torch.cat([patch_embeds, text_embeds], dim=1)
        patch_mask = torch.ones((1, patch_embeds.size(1)), device=device)
        attention_mask = torch.cat([patch_mask, q_tok.attention_mask], dim=1)

        out_ids = model.qwen.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=7,
            min_new_tokens=1,
            do_sample=False,
            num_beams=1,
            repetition_penalty=1.5,
            no_repeat_ngram_size=2,
            eos_token_id=model.tokenizer.eos_token_id,
            pad_token_id=model.tokenizer.pad_token_id
        )

        text = model.tokenizer.decode(out_ids[0], skip_special_tokens=True)
        if text.startswith(questions[i]):
            text = text[len(questions[i]):].strip()
        text = text.split('.')[0].split('!')[0].split('?')[0].strip()
        if not text:
            text = "unknown"

        results.append(text)

        del patch_embeds, text_embeds, inputs_embeds, attention_mask, out_ids, q_tok
        torch.cuda.empty_cache()

    return results

@torch.no_grad()
def evaluate_bert_f1_and_examples(model, dataloader, device="cuda:0", num_examples=3):
    model = model.to(device)  # Ensure model is on the correct device
    model.eval()
    preds, refs, questions = [], [], []

    for batch_idx, batch in enumerate(dataloader):
        images = batch["image"].to(device)
        batch_questions = batch["question"]
        batch_answers = batch["answer"]

        batch_preds = generate_answer(model, images, batch_questions, device=device)
        preds.extend(batch_preds)
        refs.extend(batch_answers)
        questions.extend(batch_questions)

        if len(preds) >= num_examples:
            break

    print("\nSample predictions:")
    for i in range(min(num_examples, len(preds))):
        print(f"\n  Example {i + 1}:")
        print(f"    Question:   {questions[i]}")
        print(f"    Predicted:  {preds[i]}")
        print(f"    Reference:  {refs[i]}")

    print(f"\nSkipping BERTScore F1 (library unavailable)")
    return 0.0

if __name__ == "__main__":
    gc.collect()
    torch.cuda.empty_cache()
    print("Loading models...")
    tokenizer, qwen = load_qwen_frozen()
    vision_model, image_processor = get_model(
        encoder_choice="pubmedclip",
        mlp_output_dim=qwen.config.hidden_size,
        hidden_dim=2048
    )
    model = MultiModalVQA(
        vision_mlp_model=vision_model,
        qwen=qwen,
        tokenizer=tokenizer
    )

    train_loader = get_dataloader("vqa_rad", image_processor, batch_size=16)
    test_loader = get_dataloader("vqa_rad", image_processor, batch_size=1, max_samples=20)

    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        save_top_k=1,
        filename="best-{epoch}-{val_loss:.4f}",
        save_weights_only=False
    )

    # ----------------- MLP Stage -----------------
    print("\nStarting MLP training...")
    lit_model = LitMultiModalVQA(
        model=model,
        lr=3e-5,
        stage="mlp"
    )

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,
        precision="16-mixed",
        max_epochs=3,
        callbacks=[checkpoint_callback, BertF1Callback(test_loader), CheckpointEveryEpoch("mlp")],
        log_every_n_steps=10,
    )

    trainer.fit(
        lit_model,
        train_dataloaders=train_loader,
        val_dataloaders=test_loader,
    )
    model_mlp = lit_model.model
    model_mlp.to("cuda:0")
    evaluate_bert_f1_and_examples(model_mlp, test_loader, device="cuda:0", num_examples=3)

    # ----------------- LoRA Stage -----------------
    print("\nApplying LoRA to Qwen...")
    model.qwen = apply_lora_to_qwen(model.qwen)

    lit_model_lora = LitMultiModalVQA(
        model=model,
        lr=1e-4,
        stage="lora"
    )

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,
        precision="16-mixed",
        max_epochs=2,
        callbacks=[checkpoint_callback, BertF1Callback(test_loader), CheckpointEveryEpoch("lora")],
        log_every_n_steps=10,
    )

    trainer.fit(
        lit_model_lora,
        train_dataloaders=train_loader,
        val_dataloaders=test_loader,
    )

    model_lora = lit_model_lora.model
    model_lora.to("cuda:0")
    evaluate_bert_f1_and_examples(model_lora, test_loader, device="cuda:0", num_examples=3)


    print("\n" + "="*80)
    print("Loading checkpoint and making predictions...")
    print("="*80)
    
    del model_lora, lit_model_lora, trainer
    gc.collect()
    torch.cuda.empty_cache()
    
    tokenizer_new, qwen_new = load_qwen_frozen()
    vision_model_new, image_processor_new = get_model(
        encoder_choice="pubmedclip",
        mlp_output_dim=qwen_new.config.hidden_size,
        hidden_dim=2048
    )
    model_new = MultiModalVQA(vision_model_new, qwen_new, tokenizer_new)
    model_new.qwen = apply_lora_to_qwen(model_new.qwen)
    
    checkpoint_path = "checkpoints/lora_epoch_1.pt"
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location="cuda:0")
        model_new.load_state_dict(checkpoint['model_state_dict'])
        print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
    
    model_new.to("cuda:0").eval()
    test_loader_inference = get_dataloader("vqa_rad", image_processor_new, batch_size=1, max_samples=10)
    
    all_preds, all_refs, all_questions = [], [], []
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader_inference):
            images = batch["image"].to("cuda:0")
            questions = batch["question"]
            answers = batch["answer"]
            
            predictions = generate_answer(model_new, images, questions, device="cuda:0")
            all_preds.extend(predictions)
            all_refs.extend(answers)
            all_questions.extend(questions)
            
            print(f"\nExample {batch_idx + 1}:")
            print(f"  Question:  {questions[0]}")
            print(f"  Predicted: {predictions[0]}")
            print(f"  Reference: {answers[0]}")
            print(f"  Match:     {'✓' if predictions[0].lower().strip() == answers[0].lower().strip() else '✗'}")
    
    exact_matches = sum(1 for p, r in zip(all_preds, all_refs) if p.lower().strip() == r.lower().strip())
    print(f"\nExact match: {exact_matches}/{len(all_preds)} ({exact_matches/len(all_preds)*100:.2f}%)")
    
    with open("predictions_output.json", 'w') as f:
        json.dump([{"question": q, "predicted": p, "reference": r} 
                   for q, p, r in zip(all_questions, all_preds, all_refs)], f, indent=2)

2025-12-21 16:41:03.683487: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1766335263.872443      55 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1766335263.927572      55 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1766335264.370470      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1766335264.370502      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1766335264.370505      55 computation_placer.cc:177] computation placer alr

Loading models...


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/661 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.24G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/138 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001-eb8844602202be(…):   0%|          | 0.00/24.2M [00:00<?, ?B/s]

data/test-00000-of-00001-e5bc3d208bb4dee(…):   0%|          | 0.00/10.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1793 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/451 [00:00<?, ? examples/s]

INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores



Starting MLP training...


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Output()

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=3` reached.


INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)



Sample predictions:

  Example 1:
    Question:   how was this image taken?
    Predicted:  mri-ct, 1
    Reference:  mri

  Example 2:
    Question:   are the lungs affected?
    Predicted:  no, lung affected yes
l
    Reference:  no

  Example 3:
    Question:   are the lungs normal appearing?
    Predicted:  no, yes
    Reference:  no

Skipping BERTScore F1 (library unavailable)

Applying LoRA to Qwen...
trainable params: 3,784,704 || all params: 467,772,416 || trainable%: 0.8091


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Output()

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached.



Sample predictions:

  Example 1:
    Question:   how was this image taken?
    Predicted:  air with air of the and,
    Reference:  mri

  Example 2:
    Question:   where is the mass?
    Predicted:  bilateral lobe the left side
    Reference:  pineal region

  Example 3:
    Question:   which organ system is abnormal in this image?
    Predicted:  card liver and the ofth kidney
    Reference:  cardiovascular

Skipping BERTScore F1 (library unavailable)

Loading checkpoint and making predictions...
trainable params: 3,784,704 || all params: 467,772,416 || trainable%: 0.8091
Loaded checkpoint from epoch 1

Example 1:
  Question:  is this image in the transverse plane?
  Predicted: no,yes
  Reference: yes
  Match:     ✗

Example 2:
  Question:  how was this image taken?
  Predicted: axialicaliacalip
  Reference: mri
  Match:     ✗

Example 3:
  Question:  what is the location of the mass?
  Predicted: right side of the brain and right
  Reference: pineal region
  Match:     ✗

Example

In [3]:
!pip install -q aiogram

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m698.4/698.4 kB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ypy-websocket 0.8.4 requires aiofiles<23,>=22.1.0, but you have aiofiles 25.1.0 which is incompatible.
gradio 5.49.1 requires aiofiles<25.0,>=22.0, but you have aiofiles 25.1.0 which is incompatible.
gradio 5.49.1 requires pydantic<2.12,>=2.0, but you have pydantic 2.12.5 which is incompatible.[0m[31m
[0m

In [None]:
import asyncio
from aiogram import Bot, Dispatcher, types
from aiogram.filters import CommandStart
from aiogram.types import FSInputFile
import nest_asyncio


def load_model():
    gc.collect()
    torch.cuda.empty_cache()
    tokenizer, qwen = load_qwen_frozen()
    vision_model, image_processor = get_model(
        encoder_choice="pubmedclip",
        mlp_output_dim=qwen.config.hidden_size,
        hidden_dim=2048
    )
    model = MultiModalVQA(
        vision_mlp_model=vision_model,
        qwen=qwen,
        tokenizer=tokenizer
    )
    model.qwen = apply_lora_to_qwen(model.qwen)
    
    checkpoint_path = "checkpoints/lora_epoch_1.pt"
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location="cuda:0")
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
    else:
        raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}. Please train the model first.")
    
    model.to("cuda:0").eval()
    return model, image_processor

# Telegram Bot Code with aiogram
TOKEN = '6539903139:AAF3IRqIXQuoIOQZs7iMis83PuYp9q7eTQA'  

model, image_processor = load_model()

bot = Bot(token=TOKEN)
dp = Dispatcher()

@dp.message(CommandStart())
async def start(message: types.Message):
    await message.reply('Send me an image with a caption as the question, and I\'ll answer using the VQA model!')

@dp.message()
async def handle_photo(message: types.Message):
    print("Received message from user:", message.from_user.id)
    print("Message has photo:", bool(message.photo))
    print("Message has caption:", bool(message.caption))
    if message.photo and message.caption:
        try:
            print("Starting photo processing...")
            # Download the photo
            photo = message.photo[-1]  # Get the highest resolution
            print("Selected photo size:", photo.width, "x", photo.height)
            file = await bot.get_file(photo.file_id)
            print("Got file info:", file.file_id, file.file_path)
            photo_path = 'temp_image.jpg'
            await bot.download_file(file.file_path, photo_path)
            print("Downloaded photo to:", photo_path)
            
            # Process the image
            print("Opening image...")
            image = Image.open(photo_path).convert("RGB")
            print("Resizing image...")
            image = image.resize((224, 224), Image.LANCZOS)
            print("Processing image with image_processor...")
            processed_image = image_processor(image, return_tensors="pt").pixel_values.to("cuda:0")
            print("Processed image shape:", processed_image.shape)
            
            # Get the question (caption)
            question = message.caption
            print("Question:", question)
            
            # Generate answer
            print("Generating answer...")
            prediction = generate_answer(model, processed_image, [question])[0]
            print("Generated prediction:", prediction)
            
            # Reply
            print("Sending reply...")
            await message.reply(f"Question: {question}\nAnswer: {prediction}")
            print("Reply sent.")
            
            # Clean up
            print("Cleaning up...")
            os.remove(photo_path)
            del processed_image
            torch.cuda.empty_cache()
            print("Cleanup complete.")
        except Exception as e:
            print("Error during processing:", str(e))
            await message.reply(f"Error processing photo: {str(e)}")
    else:
        print("Message does not have photo or caption. Sending reminder.")
        await message.reply('Please send an image with a caption as the question.')

async def main():
    await dp.start_polling(bot)

if __name__ == '__main__':
    nest_asyncio.apply()
    asyncio.run(main())

trainable params: 3,784,704 || all params: 467,772,416 || trainable%: 0.8091
Loaded checkpoint from epoch 1
Received message from user: 5410612788
Message has photo: True
Message has caption: True
Starting photo processing...
Selected photo size: 514 x 626
Got file info: AgACAgIAAxkBAAOBaUgb0TyG5lt7ywn5p3w0Gz2fEZcAAqgMaxsMSkBKwcLn6qu5bksBAAMCAAN4AAM2BA photos/file_0.jpg
Downloaded photo to: temp_image.jpg
Opening image...
Resizing image...
Processing image with image_processor...
Processed image shape: torch.Size([1, 3, 224, 224])
Question: hello how are you
Generating answer...
Generated prediction: axial andyes,no
Sending reply...
Reply sent.
Cleaning up...
Cleanup complete.
Received message from user: 5410612788
Message has photo: True
Message has caption: True
Starting photo processing...
Selected photo size: 800 x 757
Got file info: AgACAgIAAxkBAAOOaUglOqHpDGe6ZEaHfrfteO16cX4AAhkNaxsMSkBKfQPsH6joXBgBAAMCAAN4AAM2BA photos/file_1.jpg
Downloaded photo to: temp_image.jpg
Opening image