# In [None]:
# vqa_clip_gpt2.py

import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, GPT2LMHeadModel, GPT2Tokenizer
import requests
from io import BytesIO

def load_models():
    """Load the pretrained CLIP and GPT-2 components used by the demo.

    Returns:
        A 4-tuple ``(clip_model, clip_processor, gpt_model, gpt_tokenizer)``.
    """
    clip_checkpoint = "openai/clip-vit-base-patch32"
    gpt_checkpoint = "gpt2"

    # CLIP: encodes images and text.
    clip = CLIPModel.from_pretrained(clip_checkpoint)
    clip_proc = CLIPProcessor.from_pretrained(clip_checkpoint)

    # GPT-2: language model that produces the answer text.
    gpt = GPT2LMHeadModel.from_pretrained(gpt_checkpoint)
    gpt_tok = GPT2Tokenizer.from_pretrained(gpt_checkpoint)

    return clip, clip_proc, gpt, gpt_tok

def preprocess_image(url, timeout=30):
    """Download an image over HTTP(S) and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block indefinitely on a stalled
            connection.

    Returns:
        The decoded image, converted to ``"RGB"`` mode.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection failures or timeouts.
        PIL.UnidentifiedImageError: If the response body is not an image.
    """
    response = requests.get(url, timeout=timeout)
    # Surface HTTP failures clearly instead of letting PIL choke on an
    # HTML error page.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")

def generate_answer(image, question, clip_model, clip_processor, gpt_model, gpt_tokenizer,
                    captions=None, max_new_tokens=50):
    """Answer ``question`` about ``image`` with GPT-2, optionally grounded by CLIP.

    GPT-2 cannot consume CLIP embeddings directly: nothing here maps CLIP's
    embedding space into GPT-2's input space. The image can therefore only
    reach the language model as text. When ``captions`` is given, CLIP
    selects the caption that best matches the image, and that caption is
    prepended to the prompt as context. When ``captions`` is omitted or
    empty, the answer is generated from the question alone and ``image``
    does not influence it.

    Args:
        image: PIL image the question is about.
        question: Natural-language question.
        clip_model: CLIP model used to rank ``captions`` against ``image``.
        clip_processor: CLIP processor matching ``clip_model``.
        gpt_model: GPT-2 language model that generates the answer.
        gpt_tokenizer: Tokenizer matching ``gpt_model``.
        captions: Optional iterable of candidate image descriptions. The one
            CLIP scores highest against ``image`` is added to the prompt.
        max_new_tokens: Maximum number of tokens GPT-2 may generate.

    Returns:
        The first line of text GPT-2 generates after the ``"A:"`` prompt,
        with surrounding whitespace stripped.
    """
    candidates = list(captions) if captions is not None else []

    context = ""
    if candidates:
        caption = _best_caption(image, candidates, clip_model, clip_processor)
        context = f"Context: {caption}\n"

    prompt = f"{context}Q: {question}\nA:"
    return _complete(prompt, gpt_model, gpt_tokenizer, max_new_tokens)


def _best_caption(image, captions, clip_model, clip_processor):
    """Return the entry of ``captions`` that CLIP scores highest for ``image``."""
    inputs = clip_processor(text=captions, images=image, return_tensors="pt", padding=True)
    inputs = inputs.to(clip_model.device)
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        # logits_per_image is (num_images, num_texts); there is one image here.
        logits = clip_model(**inputs).logits_per_image
    return captions[int(logits[0].argmax())]


def _complete(prompt, gpt_model, gpt_tokenizer, max_new_tokens):
    """Generate a continuation of ``prompt`` and return its first line."""
    encoded = gpt_tokenizer(prompt, return_tensors="pt").to(gpt_model.device)
    output = gpt_model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_new_tokens=max_new_tokens,
        # GPT-2 defines no pad token, so reuse EOS for padding.
        pad_token_id=gpt_tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens. The prompt text, or any later
    # "A:" the model invents, then cannot leak into or replace the answer.
    new_tokens = output[0, encoded.input_ids.shape[1]:]
    text = gpt_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    # GPT-2 tends to continue with further Q/A pairs; keep only the first line.
    return text.split("\n", 1)[0].strip()

def main():
    """Run the demo: load models, fetch the sample image, answer one question."""
    print("Loading models...")
    models = load_models()

    # Demo inputs.
    question = "What is the man doing?"
    image = preprocess_image(
        "https://raw.githubusercontent.com/zhoubolei/zhou/images/visual_question.jpg"
    )

    print("Generating answer...")
    # models is (clip_model, clip_processor, gpt_model, gpt_tokenizer),
    # the same order generate_answer expects after image and question.
    answer = generate_answer(image, question, *models)

    for line in (f"Q: {question}", f"A: {answer}"):
        print(line)

if __name__ == "__main__":
    # Run the demo only when executed as a script (or run as a notebook
    # cell, where __name__ is also "__main__"), not when imported.
    main()
