In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from transformers import BertTokenizer, BertModel

In [None]:
class VQAModel(nn.Module):
    """Visual question answering classifier.

    Image features from a ResNet-50 backbone (classification layer removed) are
    concatenated with BERT's pooled question representation and passed through
    an MLP head that outputs one logit per candidate answer.

    Args:
        num_classes: Number of output logits (size of the answer vocabulary).
    """

    def __init__(self, num_classes=3000):
        super().__init__()

        # Image encoder: ResNet-50 with its final fc layer dropped, so the
        # output is the average-pooled feature map.
        # NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in
        # favour of `weights=models.ResNet50_Weights.DEFAULT`; kept here for
        # compatibility with older torchvision installs.
        resnet = models.resnet50(pretrained=True)
        image_dim = resnet.fc.in_features  # width of the pooled features (2048 for ResNet-50)
        self.image_encoder = nn.Sequential(*list(resnet.children())[:-1])

        # BERT for question encoding
        self.question_encoder = BertModel.from_pretrained('bert-base-uncased')
        text_dim = self.question_encoder.config.hidden_size  # 768 for bert-base

        # Fusion and classification layers (submodule names/order unchanged so
        # existing state_dicts still load).
        self.fusion = nn.Sequential(
            nn.Linear(image_dim + text_dim, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, image, question_tokens):
        """Compute answer logits for a batch of image/question pairs.

        Args:
            image: Batch of preprocessed images. Assumes a (batch, 3, H, W)
                normalized float tensor — TODO confirm against the data loader.
            question_tokens: Dict of tokenizer outputs (e.g. input_ids,
                attention_mask) passed straight to the BERT encoder as kwargs.

        Returns:
            Tensor of shape (batch, num_classes) with unnormalized logits.
        """
        # Image encoding: flatten the pooled (batch, C, 1, 1) map to (batch, C).
        img_features = torch.flatten(self.image_encoder(image), start_dim=1)

        # Question encoding: index 1 of BERT's output is the pooled output.
        question_features = self.question_encoder(**question_tokens)[1]

        # Concatenate modalities along the feature axis and classify.
        combined_features = torch.cat((img_features, question_features), dim=1)
        return self.fusion(combined_features)

In [None]:
class VQAPredictor:
    """Inference wrapper around VQAModel.

    Loads (optional) weights, preprocesses an image file and a question string,
    and returns the top-scoring answer with its softmax probability.

    Args:
        model_path: Optional path to a state_dict saved with `torch.save`. When
            omitted, the fusion head keeps its random initialisation, so
            predictions are not meaningful.
        device: Torch device string; defaults to 'cuda' when available.
        idx2ans: Optional mapping from class index to answer string. When
            omitted, a small placeholder vocabulary is used.
    """

    def __init__(self, model_path=None,
                 device='cuda' if torch.cuda.is_available() else 'cpu',
                 idx2ans=None):
        self.device = device
        self.model = VQAModel().to(device)
        if model_path:
            # map_location lets a checkpoint saved on GPU load on a CPU-only
            # machine (and vice versa).
            # NOTE(review): torch.load unpickles the file — only load trusted
            # checkpoints (consider weights_only=True on torch >= 1.13).
            state_dict = torch.load(model_path, map_location=device)
            self.model.load_state_dict(state_dict)
        self.model.eval()

        # Initialize tokenizer and image transforms (ImageNet normalization
        # statistics, matching the ResNet backbone).
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.image_transforms = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        # Answer vocabulary. NOTE(review): the default placeholder covers only
        # 10 of VQAModel's 3000 output classes; any other predicted index is
        # reported as "I don't know". Pass the real vocabulary via `idx2ans`.
        if idx2ans is None:
            idx2ans = {0: "yes", 1: "no", 2: "2", 3: "1", 4: "3", 5: "4",
                       6: "red", 7: "blue", 8: "green", 9: "white"}
        self.idx2ans = idx2ans

    def preprocess_image(self, image_path):
        """Load an image file as a normalized (1, 3, 224, 224) tensor on self.device."""
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image, so it remains valid after the file is closed.
        with Image.open(image_path) as img:
            rgb_image = img.convert('RGB')
        return self.image_transforms(rgb_image).unsqueeze(0).to(self.device)

    def preprocess_question(self, question):
        """Tokenize a question (padded/truncated to 128 tokens) and move tensors to self.device."""
        tokens = self.tokenizer(
            question,
            padding='max_length',
            max_length=128,
            truncation=True,
            return_tensors='pt'
        )
        return {k: v.to(self.device) for k, v in tokens.items()}

    def predict(self, image_path, question):
        """
        Make a prediction for a given image and question.

        Returns:
            dict with 'answer' (str; "I don't know" when the predicted index is
            not in `idx2ans`) and 'confidence' (softmax probability of that index).
        """
        # Preprocess inputs
        image = self.preprocess_image(image_path)
        question_tokens = self.preprocess_question(question)

        # Get model prediction; everything stays under no_grad so no autograd
        # graph is built.
        with torch.no_grad():
            logits = self.model(image, question_tokens)
            pred_idx = logits.argmax(dim=1).item()
            confidence = torch.softmax(logits, dim=1)[0, pred_idx].item()

        # Convert prediction to answer
        answer = self.idx2ans.get(pred_idx, "I don't know")

        return {
            'answer': answer,
            'confidence': confidence
        }

In [None]:
def train_vqa_model(train_loader, val_loader, num_epochs=10, lr=1e-4, model=None, device=None):
    """
    Train a VQA model with cross-entropy loss and report per-epoch validation metrics.

    Args:
        train_loader: Sized iterable of (images, questions, answers) batches, where
            `questions` is a dict of tensors (tokenizer output) and `answers`
            holds target class indices.
        val_loader: Same batch format as `train_loader`; used for evaluation only.
        num_epochs: Number of passes over `train_loader`.
        lr: Adam learning rate.
        model: Optional model to train (e.g. to resume training); a fresh
            `VQAModel()` is created when omitted.
        device: Optional torch device (or device string); defaults to CUDA when
            available, else CPU.

    Returns:
        The trained model, on `device`.

    Reported losses are means over batches. An empty loader yields NaN instead
    of raising ZeroDivisionError.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)
    if model is None:
        model = VQAModel()
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    def _to_device(batch):
        # Move one (images, questions-dict, answers) batch onto the target device.
        images, questions, answers = batch
        return (
            images.to(device),
            {k: v.to(device) for k, v in questions.items()},
            answers.to(device),
        )

    def _mean(total_value, count):
        # Guard against empty loaders / empty validation sets.
        return total_value / count if count else float('nan')

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        for batch_idx, batch in enumerate(train_loader):
            images, questions, answers = _to_device(batch)

            optimizer.zero_grad()
            outputs = model(images, questions)
            loss = criterion(outputs, answers)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            if batch_idx % 100 == 0:
                print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')

        # Validation phase
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in val_loader:
                images, questions, answers = _to_device(batch)

                outputs = model(images, questions)
                val_loss += criterion(outputs, answers).item()

                predicted = outputs.argmax(dim=1)
                total += answers.size(0)
                correct += predicted.eq(answers).sum().item()

        print(f'Epoch: {epoch}')
        print(f'Training Loss: {_mean(train_loss, len(train_loader)):.4f}')
        print(f'Validation Loss: {_mean(val_loss, len(val_loader)):.4f}')
        print(f'Validation Accuracy: {100. * _mean(correct, total):.2f}%')

    return model

In [None]:
if __name__ == "__main__":
    from pathlib import Path

    # Path to a trained state_dict. Without one, VQAPredictor keeps a randomly
    # initialised classifier head and the answer below is meaningless.
    MODEL_PATH = None
    EXAMPLE_IMAGE = Path("example_image.jpg")

    if MODEL_PATH is None:
        print("Warning: MODEL_PATH is not set; predictions come from an untrained classifier head.")

    # Initialize predictor
    predictor = VQAPredictor(model_path=MODEL_PATH)

    # Example prediction — skipped (instead of raising FileNotFoundError) when
    # the sample image is not present.
    if EXAMPLE_IMAGE.exists():
        result = predictor.predict(
            image_path=str(EXAMPLE_IMAGE),
            question="What color is the car?"
        )

        print(f"Answer: {result['answer']}")
        print(f"Confidence: {result['confidence']:.2f}")
    else:
        print(f"Example image not found: {EXAMPLE_IMAGE}")