In [None]:
from datasets import load_dataset
from transformers import ViltConfig

def preprocess_data(example, config):
    """Attach soft answer targets to one dataset example.

    Every distinct answer string found in ``example["answers"]`` is counted.
    Answers that appear in ``config.label2id`` are converted to a label id and
    a score of ``min(1.0, count / 3)``; answers outside that mapping are
    skipped.

    Args:
        example: Mapping with an ``"answers"`` entry, an iterable of mappings
            that each provide an ``"answer"`` value.
        config: Object exposing a ``label2id`` mapping keyed by answer text.

    Returns:
        The same ``example`` object with two parallel lists added under
        ``"labels"`` and ``"scores"``, ordered by first occurrence of each
        answer.
    """
    # Count how many times each answer text was given.
    tally = {}
    for entry in example["answers"]:
        text = entry["answer"]
        if text in tally:
            tally[text] += 1
        else:
            tally[text] = 1

    # Keep only answers the label vocabulary knows, pairing id with score.
    known = [
        (config.label2id[text], min(1.0, occurrences / 3))
        for text, occurrences in tally.items()
        if text in config.label2id
    ]

    example["labels"] = [label_id for label_id, _ in known]
    example["scores"] = [score for _, score in known]
    return example

# The fine-tuned VQA checkpoint's config supplies the label2id / id2label
# answer vocabulary used by preprocess_data.
config_name = "dandelin/vilt-b32-finetuned-vqa"
config = ViltConfig.from_pretrained(config_name)

dataset = load_dataset("s076923/vqa-v2-test")

# Add "labels" and "scores" to every example of the "test" split, one example
# at a time; `config` is forwarded to preprocess_data as a keyword argument.
processed_dataset = dataset["test"].map(
    preprocess_data,
    fn_kwargs={"config": config},
    batched=False,
)
print(processed_dataset[0])

In [None]:
import torch
from torch.utils.data import Dataset
from transformers import ViltProcessor

class VQADataset(Dataset):
    """Torch dataset wrapping preprocessed VQA records.

    Each item is produced by running the stored processor on the record's
    ``"image"`` and ``"question"`` and adding a dense ``"labels"`` vector
    built from the record's ``"labels"`` / ``"scores"`` lists.

    Args:
        dataset: Indexable collection of records with ``"image"``,
            ``"question"``, ``"labels"`` and ``"scores"`` entries.
        config: Object exposing ``id2label``; its length sets the size of the
            target vector.
        processor: Callable accepting ``images``, ``text``, ``padding``,
            ``truncation`` and ``return_tensors`` keyword arguments and
            returning a mapping of batched values.
    """

    def __init__(self, dataset, config, processor):
        self.dataset = dataset
        self.config = config
        self.processor = processor

    def __len__(self):
        """Return the number of wrapped records."""
        return len(self.dataset)

    def _soft_targets(self, label_ids, label_scores):
        """Build a zero vector with ``label_scores`` written at ``label_ids``."""
        target_vector = torch.zeros(len(self.config.id2label))
        target_vector[label_ids] = torch.tensor(label_scores)
        return target_vector

    def __getitem__(self, idx):
        """Return the processed inputs plus ``"labels"`` for record ``idx``."""
        record = self.dataset[idx]

        batched = self.processor(
            images=record["image"],
            text=record["question"],
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # The processor output carries a leading batch axis; keep element 0.
        item = {}
        for name, value in batched.items():
            item[name] = value[0]

        item["labels"] = self._soft_targets(record["labels"], record["scores"])
        return item

model_name = "dandelin/vilt-b32-mlm"
processor = ViltProcessor.from_pretrained(model_name)
vqa_dataset = VQADataset(dataset=processed_dataset, config=config, processor=processor)

# Sanity-check the first item. Fetch it once: every vqa_dataset[i] access
# re-runs the processor on the image and question.
first_item = vqa_dataset[0]
print(first_item.keys())
print(processor.decode(first_item["input_ids"]))

# flatten() keeps the index list 1-D even when exactly one label is non-zero.
# squeeze() would collapse that case to a 0-d tensor whose tolist() is a bare
# int, and the comprehension below would then fail to iterate over it.
labels = torch.nonzero(first_item["labels"]).flatten().tolist()
print([config.id2label[label] for label in labels])

In [None]:
from transformers import ViltForQuestionAnswering

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the checkpoint named by model_name with the answer vocabulary taken
# from `config`, then move the model to the chosen device.
model = ViltForQuestionAnswering.from_pretrained(
    model_name,
    id2label=config.id2label,
    label2id=config.label2id,
)
model = model.to(device)
print(model.classifier)

In [None]:
from torch.utils.data import DataLoader

def collate_fn(batch):
    """Merge a list of dataset items into one batch dictionary.

    ``pixel_values`` are passed through ``processor.image_processor.pad``,
    whose ``pixel_values`` and ``pixel_mask`` outputs are used as-is; all
    other fields are combined with ``torch.stack``.

    Args:
        batch: Sequence of mappings, each with ``"input_ids"``,
            ``"pixel_values"``, ``"attention_mask"``, ``"token_type_ids"``
            and ``"labels"`` entries.

    Returns:
        A dict with keys ``input_ids``, ``attention_mask``,
        ``token_type_ids``, ``pixel_values``, ``pixel_mask`` and ``labels``.
    """
    # Gather each field across the batch before doing any tensor work.
    stacked_fields = ("input_ids", "attention_mask", "token_type_ids", "labels")
    columns = {name: [item[name] for item in batch] for name in stacked_fields}
    images = [item["pixel_values"] for item in batch]

    # Relies on the module-level `processor` for image padding.
    padded = processor.image_processor.pad(images, return_tensors="pt")

    return {
        "input_ids": torch.stack(columns["input_ids"]),
        "attention_mask": torch.stack(columns["attention_mask"]),
        "token_type_ids": torch.stack(columns["token_type_ids"]),
        "pixel_values": padded["pixel_values"],
        "pixel_mask": padded["pixel_mask"],
        "labels": torch.stack(columns["labels"]),
    }

# Small, unshuffled loader used to inspect what collate_fn produces.
loader_options = {
    "collate_fn": collate_fn,
    "batch_size": 4,
    "shuffle": False,
}
dataloader = DataLoader(vqa_dataset, **loader_options)

# Take the first batch and report the shape of every entry.
batch = next(iter(dataloader))
for key in batch:
    value = batch[key]
    print(f"{key}: {value.shape}")

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

def unnormalize_image(pixel_values, image_mean, image_std):
    """Undo per-channel normalisation and return a displayable uint8 image.

    Computes ``pixel_values * std + mean`` per channel, scales by 255, rounds
    to the nearest integer and clips to ``[0, 255]`` before casting. Rounding
    avoids losing an intensity level to floating-point error, and clipping
    prevents out-of-range values from wrapping around in the ``uint8`` cast.

    Args:
        pixel_values: Array indexed as ``(channel, row, column)``; the result
            is transposed with axes ``(1, 2, 0)``.
        image_mean: Sequence with one mean per channel.
        image_std: Sequence with one standard deviation per channel.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` with the channel axis last.
    """
    # Reshape the per-channel statistics so they broadcast over rows/columns.
    std = np.asarray(image_std, dtype=np.float64)[:, None, None]
    mean = np.asarray(image_mean, dtype=np.float64)[:, None, None]

    restored = np.asarray(pixel_values, dtype=np.float64) * std + mean

    # Round first, then clip, so the cast below can never overflow or wrap.
    levels = np.clip(np.rint(restored * 255.0), 0, 255)
    return levels.astype(np.uint8).transpose(1, 2, 0)

# Which element of the previously fetched `batch` to inspect.
batch_idx = 1

# Recover a displayable image using the processor's own mean / std.
image_processor = processor.image_processor
image = unnormalize_image(
    batch["pixel_values"][batch_idx].numpy(),
    image_processor.image_mean,
    image_processor.image_std,
)

question_ids = batch["input_ids"][batch_idx]
print("Question:", processor.decode(question_ids))

# Indices of the non-zero target entries, mapped to their answer strings.
labels = torch.nonzero(batch["labels"][batch_idx]).flatten().tolist()
label_names = [config.id2label[label_id] for label_id in labels]
print("Possible answers:", label_names)

plt.imshow(Image.fromarray(image))
plt.show()

In [None]:
from transformers import Trainer, TrainingArguments

# Hyperparameters collected in one place; progress is logged every 20 steps.
training_options = {
    "output_dir": "VQA",
    "num_train_epochs": 20,
    "per_device_train_batch_size": 8,
    "learning_rate": 1e-4,
    "weight_decay": 0.01,
    "logging_strategy": "steps",
    "logging_steps": 20,
    "seed": 42,
}
training_args = TrainingArguments(**training_options)

# Batches are assembled by collate_fn, the same collator used by the DataLoader.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=vqa_dataset,
    data_collator=collate_fn,
)

trainer.train()

In [None]:
# Pick one dataset item and show what the model will be asked.
sample_index = 4
sample = vqa_dataset[sample_index]

print("Sample keys:", sample.keys())
print("Question:", processor.decode(sample["input_ids"]))

# Add a leading batch axis to every tensor and move it to the model's device.
sample = {name: tensor.unsqueeze(0).to(device) for name, tensor in sample.items()}

# Forward pass in eval mode with gradient tracking disabled.
model.eval()
with torch.no_grad():
    outputs = model(**sample)

# Apply a sigmoid to the logits and keep the five highest-scoring classes.
logits = outputs.logits
predicted_probs = logits.sigmoid()
top_probs, top_classes = predicted_probs.topk(5)

top_probs = top_probs.squeeze().tolist()
top_classes = top_classes.squeeze().tolist()
for prob, class_idx in zip(top_probs, top_classes):
    answer = model.config.id2label[class_idx]
    print(f"Answer: {answer:<7} Probability: {prob:.4f}")

# Show the image that belongs to the same dataset item.
image_processor = processor.image_processor
unnormalized_image = unnormalize_image(
    vqa_dataset[sample_index]["pixel_values"].numpy(),
    image_processor.image_mean,
    image_processor.image_std,
)

plt.imshow(Image.fromarray(unnormalized_image))
plt.show()