In [None]:
!pip install -q git+https://github.com/huggingface/peft.git transformers

In [None]:
import torch
import pandas as pd
from PIL import Image
from transformers import Blip2ForConditionalGeneration, Blip2Processor
from peft import LoraConfig, get_peft_model
from torch.utils.data import Dataset, DataLoader

def load_model(
    model_name="Salesforce/blip2-opt-2.7b",
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=None,
):
    """Load a BLIP-2 checkpoint in float16 and wrap it with LoRA adapters.

    Args:
        model_name: hub id or local path of the BLIP-2 checkpoint; used for
            both the model and its processor.
        r: LoRA rank.
        lora_alpha: LoRA scaling factor.
        lora_dropout: dropout applied inside the LoRA layers.
        target_modules: names of the modules that receive adapters. Defaults
            to ``["q_proj", "v_proj"]`` (presumably the attention query/value
            projections of the OPT language model — TODO confirm which
            submodules actually match).

    Returns:
        ``(peft_model, processor)``.
    """
    # None sentinel instead of a list default, which would be shared across calls.
    if target_modules is None:
        target_modules = ["q_proj", "v_proj"]

    model = Blip2ForConditionalGeneration.from_pretrained(
        model_name, torch_dtype=torch.float16
    )
    processor = Blip2Processor.from_pretrained(model_name)

    lora_config = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=list(target_modules),
    )

    model = get_peft_model(model, lora_config)
    return model, processor

class Blip2Dataset(Dataset):
    """Image + entailment-question rows rendered as BLIP-2 training captions.

    The CSV is read with pandas; each row must provide the columns ``image``
    (file name relative to ``image_dir``), ``EM``, ``EN`` and ``label``.
    A ``label`` equal to 1 is rendered as the answer "Yes", anything else as
    "No". Rows whose image file is missing yield ``None`` so a collate
    function can drop them.
    """

    def __init__(self, csv_path, processor, image_dir):
        """Read ``csv_path`` eagerly and keep ``processor``/``image_dir`` for item creation."""
        self.data = pd.read_csv(csv_path)
        self.processor = processor
        self.image_dir = image_dir

    def __len__(self):
        """Number of CSV rows (including rows whose image may be missing)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the tokenized caption, pixel values and label for row ``idx``.

        Returns ``None`` when the row's image file does not exist.
        """
        row = self.data.iloc[idx]
        image_path = f"{self.image_dir}/{row['image']}"
        em, en, label = row['EM'], row['EN'], row['label']

        try:
            # ``with`` closes the file handle deterministically; convert()
            # returns an independent in-memory copy that outlives the file.
            with Image.open(image_path) as img:
                image = img.convert("RGB").resize((224, 224))
        except FileNotFoundError:
            return None

        answer = "Yes" if label == 1 else "No"
        caption = f"Question: does {em} entail {en}? Yes or no only. Answer: {answer}"
        # A single caption has nothing to pad against, so input_ids length
        # varies from row to row; padding to a common length must happen when
        # the batch is collated.
        inputs = self.processor(images=image, text=caption, return_tensors="pt")

        return {
            "input_ids": inputs.input_ids.squeeze(0),
            "pixel_values": inputs.pixel_values.squeeze(0),
            "label": torch.tensor(label, dtype=torch.long),
        }

def collate_fn(batch, pad_token_id=1):
    """Collate dataset items into one padded batch, dropping ``None`` items.

    Each item is tokenized on its own, so ``input_ids`` generally differ in
    length; stacking them directly raises. They are right-padded to the
    longest sequence in the batch instead, and an ``attention_mask``
    (1 = real token, 0 = padding) is returned alongside them.

    Args:
        batch: list of dicts with a 1-D ``input_ids`` tensor, a
            ``pixel_values`` tensor and a scalar ``label`` tensor, or ``None``
            for rows that were skipped.
        pad_token_id: id written into padded positions. Defaults to 1,
            presumably the ``<pad>`` id of the OPT tokenizer — TODO confirm
            against ``processor.tokenizer.pad_token_id``.

    Returns:
        Dict with ``input_ids`` and ``attention_mask`` of shape (B, L),
        stacked ``pixel_values`` and ``labels`` of shape (B,), or ``None``
        when every item was dropped.
    """
    batch = [b for b in batch if b is not None]
    if not batch:
        return None

    sequences = [b["input_ids"] for b in batch]
    input_ids = torch.nn.utils.rnn.pad_sequence(
        sequences, batch_first=True, padding_value=pad_token_id
    )
    # Padding a tensor of ones with zeros marks exactly the real positions.
    attention_mask = torch.nn.utils.rnn.pad_sequence(
        [torch.ones_like(s) for s in sequences], batch_first=True, padding_value=0
    )
    pixel_values = torch.stack([b["pixel_values"] for b in batch])
    labels = torch.stack([b["label"] for b in batch])
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "pixel_values": pixel_values,
        "labels": labels,
    }

def load_dataset(
    csv_path="dataset.csv",
    image_dir="CS4248_project/dataset_following_elco_split",
    batch_size=8,
    shuffle=True,
    processor=None,
    model_name="Salesforce/blip2-opt-2.7b",
):
    """Build the training DataLoader over ``Blip2Dataset`` with ``collate_fn``.

    Args:
        csv_path: CSV file with the training rows.
        image_dir: directory the CSV's ``image`` column is relative to.
        batch_size: number of rows per batch (before missing-image rows are
            dropped by the collate function).
        shuffle: whether the DataLoader reshuffles every epoch.
        processor: an already-loaded processor to reuse. When ``None``, one is
            loaded from ``model_name``, so callers that already hold a
            processor can avoid loading it a second time.
        model_name: checkpoint the processor is loaded from when
            ``processor`` is not given.

    Returns:
        A ``torch.utils.data.DataLoader``.
    """
    if processor is None:
        processor = Blip2Processor.from_pretrained(model_name)
    dataset = Blip2Dataset(csv_path, processor, image_dir)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)

def train(model, train_loader, epochs=3, lr=5e-5):
    """Fine-tune ``model`` with a language-modeling loss on the caption tokens.

    Each batch is a dict with ``input_ids`` and ``pixel_values`` and,
    optionally, ``attention_mask``; ``None`` batches are skipped. The caption
    ids double as the language-model targets. Where an attention mask is
    provided, padded positions (mask == 0) are set to -100 in the targets so
    they are excluded from the loss. The batch's class ``labels`` entry is
    not used.

    Args:
        model: called as ``model(input_ids=..., pixel_values=...,
            attention_mask=..., labels=...)`` and expected to return an
            object with a ``.loss`` tensor.
        train_loader: iterable of batches as described above.
        epochs: number of passes over ``train_loader``.
        lr: AdamW learning rate.

    Returns:
        List holding the mean training loss of each epoch (``nan`` for an
        epoch in which no batch was usable).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()
    # Optimize only parameters that require grad; with a PEFT-wrapped model
    # these are presumably just the adapter weights.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=lr)

    history = []
    for epoch in range(epochs):
        total_loss, steps = 0.0, 0
        for batch in train_loader:
            if batch is None:
                continue
            input_ids = batch["input_ids"].to(device)
            pixel_values = batch["pixel_values"].to(device)
            attention_mask = batch.get("attention_mask")

            # Clone so masking never writes into the batch's input_ids.
            lm_labels = input_ids.clone()
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)
                lm_labels = lm_labels.masked_fill(attention_mask == 0, -100)

            outputs = model(
                input_ids=input_ids,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                labels=lm_labels,
            )
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            steps += 1

        if steps:
            mean_loss = total_loss / steps
            print(f"Epoch {epoch+1}, Loss: {mean_loss}")
        else:
            # Previously this path raised NameError on the unbound `loss`.
            mean_loss = float("nan")
            print(f"Epoch {epoch+1}: no usable batches")
        history.append(mean_loss)
    return history

def inference(model, processor, test_csv, image_dir, max_new_tokens=10):
    """Predict yes/no entailment for each CSV row and report accuracy and F1.

    Each row's ``image`` (relative to ``image_dir``) is paired with the
    question "does EM entail EN?"; rows whose image file is missing are
    skipped. A generated answer that starts with "yes" counts as label 1,
    anything else as 0.

    Args:
        model: generation model, moved to CUDA when available and put in
            eval mode.
        processor: matching processor, used for encoding and decoding.
        test_csv: CSV with ``image``, ``EM``, ``EN`` and ``label`` columns.
        image_dir: directory the ``image`` column is relative to.
        max_new_tokens: generation budget per example.

    Returns:
        ``(accuracy, f1)``, or ``None`` when no row could be evaluated
        (e.g. every image file was missing).
    """
    from sklearn.metrics import accuracy_score, f1_score

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    model.to(device)

    data = pd.read_csv(test_csv)
    y_true, y_pred = [], []

    for _, row in data.iterrows():
        image_path = f"{image_dir}/{row['image']}"
        em, en, label = row['EM'], row['EN'], row['label']

        try:
            with Image.open(image_path) as img:
                image = img.convert("RGB").resize((224, 224))
        except FileNotFoundError:
            continue

        prompt = f"Question: does {em} entail {en}? Yes or no only. Answer:"
        # NOTE(review): casting inputs to float16 assumes the model weights are
        # float16, which is how load_model loads them.
        inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)

        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip().lower()
        # Depending on the transformers version, the decoded sequence may echo
        # the prompt, whose "Yes or no only" would make a plain `"yes" in text`
        # check true for every example. Keep only what follows the last
        # "answer:" (the whole text when absent), then look at how it starts.
        answer = generated_text.rpartition("answer:")[2].strip()
        pred = 1 if answer.startswith("yes") else 0

        y_true.append(label)
        y_pred.append(pred)

    if not y_true:
        print("No test examples could be evaluated.")
        return None

    accuracy = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)

    print(f"Accuracy: {accuracy:.4f}")
    print(f"F1 Score: {f1:.4f}")
    return accuracy, f1

# Driver: fine-tune, save the model and processor, then evaluate on the test split.
DATA_ROOT = "CS4248_project/dataset_following_elco_split"
OUTPUT_DIR = "./blip2_lora_finetuned"
TEST_CSV = DATA_ROOT + "/generated_img_dataset/test.csv"

model, processor = load_model()
train_loader = load_dataset()
train(model, train_loader)
for artifact in (model, processor):
    artifact.save_pretrained(OUTPUT_DIR)

inference(model, processor, TEST_CSV, DATA_ROOT)
