In [2]:
import torch
from pathlib import Path
from tqdm import tqdm
from PIL import Image
from transformers import BlipForQuestionAnswering, BlipProcessor
from torch.utils.data import DataLoader, Dataset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
MAX_LENGTH=20
BATCH_SIZE=8
NUM_EPOCHS=20
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
BASE_DIR=Path('D:\\project\\student_resource 3')
%cd {BASE_DIR}

cuda
D:\project\student_resource 3


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [3]:
# model_name = "Salesforce/blip-vqa-base"
# model = BlipForQuestionAnswering.from_pretrained(model_name)
# processor = BlipProcessor.from_pretrained(model_name)

In [4]:
save_path = f"{BASE_DIR}/models/finetuned_blip_model"
model = BlipForQuestionAnswering.from_pretrained(save_path)
processor = BlipProcessor.from_pretrained(save_path)

In [5]:
class VQADataset(Dataset):
    def __init__(self, dataframe, processor, max_length=MAX_LENGTH):
        self.dataframe = dataframe
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        image = Image.open(row["image_path"]).convert("RGB")
        question = row["question"]
        answer = row["answer"]

        inputs = self.processor(
            text=question,
            images=image,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        inputs["labels"] = self.processor.tokenizer(
            answer,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids

        for key in inputs:
            inputs[key] = inputs[key].squeeze(0)

        return inputs

In [6]:
import pandas as pd

df = pd.read_csv(f"{BASE_DIR}/data/train_vqa.csv")
train_dataset = VQADataset(df, processor)

In [7]:
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [8]:
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
model = model.to(device)

In [9]:
import os

ckpt_dir = "./checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)

In [10]:
for epoch in range(21, 21 + NUM_EPOCHS):
    model.train()
    total_loss = 0
    for batch in tqdm(train_dataloader):
        input_ids = batch["input_ids"].to(device)
        pixel_values = batch["pixel_values"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = outputs.loss
        total_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}, Loss: {avg_loss}")
    checkpoint_path = os.path.join(ckpt_dir, f"epoch_{epoch+1}.pth")
    torch.save(
        {
            "epoch": epoch + 1,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": avg_loss,
        },
        checkpoint_path,
    )
    print(f"Checkpoint saved at {checkpoint_path}")
    print()

  0%|          | 0/2000 [00:00<?, ?it/s]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
100%|██████████| 2000/2000 [29:56<00:00,  1.11it/s]


Epoch 22/20, Loss: 0.3875201988518238
Checkpoint saved at ./checkpoints\epoch_22.pth



100%|██████████| 2000/2000 [25:56<00:00,  1.28it/s]


Epoch 23/20, Loss: 0.382313860565424
Checkpoint saved at ./checkpoints\epoch_23.pth



100%|██████████| 2000/2000 [28:45<00:00,  1.16it/s]


Epoch 24/20, Loss: 0.3881159516274929
Checkpoint saved at ./checkpoints\epoch_24.pth



100%|██████████| 2000/2000 [31:00<00:00,  1.08it/s]


Epoch 25/20, Loss: 0.38054754415154457
Checkpoint saved at ./checkpoints\epoch_25.pth



100%|██████████| 2000/2000 [25:34<00:00,  1.30it/s]


Epoch 26/20, Loss: 0.3765271860510111
Checkpoint saved at ./checkpoints\epoch_26.pth



100%|██████████| 2000/2000 [26:39<00:00,  1.25it/s]


Epoch 27/20, Loss: 0.3742283678650856
Checkpoint saved at ./checkpoints\epoch_27.pth



100%|██████████| 2000/2000 [25:36<00:00,  1.30it/s]


Epoch 28/20, Loss: 0.37472163151204585
Checkpoint saved at ./checkpoints\epoch_28.pth



100%|██████████| 2000/2000 [25:50<00:00,  1.29it/s]


Epoch 29/20, Loss: 0.3701839331239462
Checkpoint saved at ./checkpoints\epoch_29.pth



100%|██████████| 2000/2000 [26:55<00:00,  1.24it/s]


Epoch 30/20, Loss: 0.36726969315111635
Checkpoint saved at ./checkpoints\epoch_30.pth



100%|██████████| 2000/2000 [28:13<00:00,  1.18it/s]


Epoch 31/20, Loss: 0.3650258024483919
Checkpoint saved at ./checkpoints\epoch_31.pth



100%|██████████| 2000/2000 [26:26<00:00,  1.26it/s]


Epoch 32/20, Loss: 0.3625302555710077
Checkpoint saved at ./checkpoints\epoch_32.pth



100%|██████████| 2000/2000 [25:58<00:00,  1.28it/s]


Epoch 33/20, Loss: 0.8288284711241722
Checkpoint saved at ./checkpoints\epoch_33.pth



100%|██████████| 2000/2000 [26:43<00:00,  1.25it/s]


Epoch 34/20, Loss: 1.6551950470805168
Checkpoint saved at ./checkpoints\epoch_34.pth



100%|██████████| 2000/2000 [25:39<00:00,  1.30it/s]


Epoch 35/20, Loss: 1.6526521760821342
Checkpoint saved at ./checkpoints\epoch_35.pth



100%|██████████| 2000/2000 [25:52<00:00,  1.29it/s]


Epoch 36/20, Loss: 1.6341559911370278
Checkpoint saved at ./checkpoints\epoch_36.pth



100%|██████████| 2000/2000 [25:37<00:00,  1.30it/s]


Epoch 37/20, Loss: 1.6266684315800666
Checkpoint saved at ./checkpoints\epoch_37.pth



100%|██████████| 2000/2000 [25:30<00:00,  1.31it/s]


Epoch 38/20, Loss: 1.6194599456191063
Checkpoint saved at ./checkpoints\epoch_38.pth



100%|██████████| 2000/2000 [25:27<00:00,  1.31it/s]


Epoch 39/20, Loss: 1.6156916198134423
Checkpoint saved at ./checkpoints\epoch_39.pth



100%|██████████| 2000/2000 [25:28<00:00,  1.31it/s]


Epoch 40/20, Loss: 1.6138958581089973
Checkpoint saved at ./checkpoints\epoch_40.pth



100%|██████████| 2000/2000 [25:29<00:00,  1.31it/s]


Epoch 41/20, Loss: 1.6139339218735695
Checkpoint saved at ./checkpoints\epoch_41.pth



In [11]:
save_path = f"{BASE_DIR}/models/finetuned_blip_model"
model.save_pretrained(save_path)
processor.save_pretrained(save_path)

[]

# Testing

In [5]:
class VQATestDataset(Dataset):
    def __init__(self, dataframe, processor, max_length=MAX_LENGTH):
        self.dataframe = dataframe
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        image = Image.open(row["image_path"]).convert("RGB")
        question = row["question"]
        answer = row["answer"]  # Ground-truth answer for evaluation

        inputs = self.processor(
            text=question,
            images=image,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        for key in inputs:
            inputs[key] = inputs[key].squeeze(0)

        return {
            "pixel_values": inputs["pixel_values"],
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "question": question,
            "answer": answer,  # For evaluation purposes
        }

In [None]:
import pandas as pd

test_df = pd.read_csv(f"{BASE_DIR}/data/test_vqa.csv")

In [8]:
test_dataset = VQATestDataset(test_df, processor)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [9]:
from evaluate import load

metric = load("sacrebleu")

In [None]:
model = BlipForQuestionAnswering.from_pretrained(save_path)
processor = BlipProcessor.from_pretrained(save_path)

In [20]:
model = model.to(device)

In [None]:
model.load_state_dict(
    torch.load(f"{BASE_DIR}/checkpoints/epoch_31.pth")["model_state_dict"]
)

  model.load_state_dict(torch.load(f"{BASE_DIR}/checkpoints/epoch_31.pth")['model_state_dict'])


<All keys matched successfully>

In [None]:
model = model.to(device)

In [14]:
model.eval()

predictions = []
references = []

with torch.no_grad():
    for batch in tqdm(test_dataloader):
        pixel_values = batch["pixel_values"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # Generate predictions
        outputs = model.generate(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            max_length=MAX_LENGTH,  # Ensure this matches your model configuration
            num_beams=5,  # Beam search for better quality answers
        )

        # Decode predictions and ground-truth answers
        predicted_answer = processor.tokenizer.batch_decode(
            outputs, skip_special_tokens=True
        )
        ground_truth_answer = batch["answer"]

        predictions.append(predicted_answer)
        references.append(ground_truth_answer)  # List of lists for BLEU metric

        metric.add_batch(predictions=predicted_answer, references=ground_truth_answer)

        # print(f"Q: {batch['question'][0]}")
        # print(f"Predicted: {predicted_answer}")
        # print(f"Ground Truth: {ground_truth_answer}\n")

# Evaluate the BLEU score
results = metric.compute()
print(f"BLEU Score: {results['score']}")

  0%|          | 0/500 [00:00<?, ?it/s]

100%|██████████| 500/500 [06:41<00:00,  1.25it/s]


BLEU Score: 0.10159281170372051


In [15]:
import numpy as np
from itertools import chain

In [17]:
results_df = pd.DataFrame()
results_df["labels"] = list(chain.from_iterable(references))
results_df["predictions"] = list(chain.from_iterable(predictions))
results_df.head()

Unnamed: 0,labels,predictions
0,6.5 kilogram,200. 0 millilitre
1,96.0 watt,10. 0 centimetre
2,230.0 millimetre,5. 0 volt
3,10.0 centimetre,12. 0 centimetre
4,60.0 watt,"[ 220. 0, 240. 0 ] volt"


In [18]:
results_df.to_csv(f"{BASE_DIR}/results/finetuned_blip.csv", index=False)