In [11]:
# Change the following paths to the appropriate paths for your system

args = {
    'model_name_or_path': '/home/mila/s/sarvjeet-singh.ghotra/scratch/git/rl_ft_vlm_tools/pre_trained_models/idefics2-8b',
    'use_quant': True,
    'train_file': '/network/projects/aishwarya_lab/datasets/clevr-math/',
    'eval_batch_size': 1,
    'num_workers': 8,
    'val_idx_file': 'val_idx_800.csv',
    'test_idx_file': 'test_idx_800.csv',
    'eval_output_file': 'outputs/evals/idefics2_eval_800.json',
    'test_output_file': 'outputs/evals/idefics2_test_800.csv',
}

In [12]:
TEMPLATE_TO_ID = {
    'subtraction': 0,
    'addition': 1,
    'adversarial': 2,
    'subtraction-multihop': 3
}

ID_TO_TEMPLATE = {
    0: 'subtraction',
    1: 'addition',
    2: 'adversarial',
    3: 'subtraction-multihop'
}

## Model

In [13]:
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
        args['model_name_or_path'],
        do_image_splitting=False
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [14]:
import torch
from transformers import BitsAndBytesConfig
from transformers import Idefics2ForConditionalGeneration


bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16
            )

model = Idefics2ForConditionalGeneration.from_pretrained(
    args['model_name_or_path'],
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    quantization_config=bnb_config if args['use_quant'] else None,
)

Loading checkpoint shards: 100%|██████████| 7/7 [00:06<00:00,  1.04it/s]


## Eval Code

In [15]:
# Eval code
from tqdm import tqdm
import json


def evaluate_generation(args, model, dataset, dataloader, processor, tokenizer, output_file):
    model.eval()
    predictions = []
    targets = []
    templates = []
    idxs = []
    for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Evaluation Gen Loop'):
        batch = {k: v.to('cuda') for k, v in batch.items()}
        batch_templates = batch['templates']
        batch_idxs = batch['idxs']
        # remove 'templates' from batch
        batch.pop('templates')
        batch.pop('idxs')
        batch.pop('answers')
        output_ = model.generate(
                        **batch,
                        # max_length=args['max_input_length'],
                        # output_scores=True,
                        return_dict_in_generate=True,
                        num_beams=1,
                        use_cache=True,
                        do_sample=False,
                        pad_token_id=tokenizer.pad_token_id,
                        eos_token_id=tokenizer.eos_token_id,
                        max_new_tokens=128,
                    )

        generated_ids = output_.sequences
        # generated_ids = pad_across_processes(generated_ids, dim=1, pad_index=tokenizer.pad_token_id, pad_first=True)
        labels = batch['labels']
        
        generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
        
        # print("=======================")
        # print(generated_texts)
        # print("=======================")
        # print()

        pred_ans = [x.split("Assistant: ")[1][:-1].strip() for x in generated_texts]    # -1 removes "."
        # pred_ans = pred_ans.strip()

        # preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True).strip() for g in generated_ids]
        predictions.extend(pred_ans)
        # target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True).strip() for t in labels]
        targets.extend(labels)
        templates.extend(batch_templates)
        idxs.extend(batch_idxs)
        
    # predictions = predictions[:len(dataset)]
    # targets = targets[:len(dataset)]
    # templates = templates[:len(dataset)]

    results = []
    corr = 0
    total = 0
    temp_based_acc = {}
    
    for pred, tar, temp, idx in zip(predictions, targets, templates, idxs):
        tar = str(tar.detach().cpu().numpy())
        cur_res = {
            'pred': pred,
            'target': tar,
            'template': ID_TO_TEMPLATE[int(temp)],
            'idx': int(idx.detach().cpu().numpy())
        }
        results.append(cur_res)

        temp = str(temp.detach().cpu().numpy())
        if temp not in temp_based_acc:
            temp_based_acc[temp] = {}
            temp_based_acc[temp]['corr'] = 0
            temp_based_acc[temp]['total'] = 0
        
        if pred == tar:
            corr += 1
            temp_based_acc[temp]['corr'] += 1
        total += 1
        temp_based_acc[temp]['total'] += 1

    # save first before execute to trace error.
    res_path = output_file
    print(f"Saving results to {res_path}")
    with open(res_path, 'w') as f:
        json.dump(results, f, indent=2)

    # if args['wandb_log']:
    #     table = wandb.Table(dataframe=pd.DataFrame(results))
    #     wandb.log({"predictions": table})

    value_accuracy = corr / len(results) * 100
    print(f"[Eval Info] value_accuracy: {value_accuracy:.5g}%")
    # value_accuracy = torch.FloatTensor([value_accuracy]).to(accelerator.device)

    print(temp_based_acc)

    templates_value_accuracy = {}
    for temp, acc in temp_based_acc.items():
        acc = acc['corr'] / acc['total'] * 100
        temp_name = ID_TO_TEMPLATE[int(temp)]
        print(f"[Eval Info] value_accuracy on {temp_name} category: {acc:.5g}%")
        # templates_value_accuracy[temp_name] = torch.FloatTensor([acc]).to(accelerator.device)
    
    # Metric summary:
    out_stats = {'value_accuracy': value_accuracy}
    for temp_category, acc in templates_value_accuracy.items():
        out_stats[temp_category] = acc

    # model.train()
    return out_stats



## Collator

In [16]:
from torch.utils.data import Dataset
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch
from PIL import Image


class MyDataCollatorFewShots:
    def __init__(self, processor):
        self.processor = processor
        self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
            processor.tokenizer.additional_special_tokens.index("<image>")
        ]
        self.few_shots_txt = []
        self.few_shots_img = [Image.open("outputs/add_2.png").convert("RGB"),
                              Image.open("outputs/sub_1.png").convert("RGB"),
                              Image.open("outputs/adv_1.png").convert("RGB"),
                              Image.open("outputs/sub_multihop_1.png").convert("RGB")]
        # {"type": "image"},
        # {"type": "text", "text": "Final answer = Total small matte blocks + 1 (small matte block).\nHow many small matte blocks are there? 2.\nAdd 1 small matte block to total small matte blocks: 2 + 1 = 3.\nFinal answer: 3. \nExample 2: "}, # Add
        # {"type": "image"},
        # {"type": "text", "text": "Final answer = Total cylinders - Total cyan cylinders.\nFind total cylinders: 2.\nFind cyan cylinders: 0.\nSubtract cyan cylinders from total cylinders: 2 - 0 = 2.\nFinal answer: 2.\n Example 4: "}, # Adv.
        # {"type": "image"},
        # {"type": "text", "text": "Final answer = Total objects - Total matte blocks - Total blue matte objects.\nFind total objects: 7.\nFind total matte blocks: 1.\nFind total blue matte objects: 1.\nSubtract matte blocks and blue matte objects from total objects: 7 - 1 - 1 = 5.\nFinal answer: 5. \nExample 5: "}, # Sub. Multihop

    def __call__(self, examples):
        texts = []
        images = []
        answers = []
        temps = []
        for example in examples:
            image = example["image"]
            question = example["query"]
            answer = example["answer"]
            # messages = [
            #     f"<image>From the image, {question} Answer: ",
            # ]
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Answer briefly."},
                        {"type": "image"},
                        {"type": "text", "text": {question}},
                    ]
                },
            ]
            #     {
            #         "role": "assistant",
            #         "content": [
            #         {"type": "text", "text": "Solution steps:\nAnswer = Total blocks - Total big cyan blocks.\nHow many blocks are there? 4.\nHow many big blocks are there? 2.\nHow many big blocks are there that are of cyan color? 1.\nSubtract 1 big cyan block from all the blocks: 4 - 1 = 3.\nFinal answer: 3."},
            #         ]
            #     },
            #     {
            #         "role": "user",
            #         "content": [
            #             {"type": "image"},
            #             {"type": "text", "text": question + "\nSolution steps:\nAnswer ="}
            #         ]
            #     },
            # ]
            
            # if example["split"] == "train":
            #     messages.append(
            #         {
            #             "role": "assistant",
            #             "content": [
            #                 {"type": "text", "text": answer}
            #             ]
            #         }
            #     )

            text = self.processor.apply_chat_template(messages, add_generation_prompt=False if example["split"] == "train" else True)
            texts.append(text.strip())  # Orig: texts.append(messages)
            images.append([image])
            answers.append(answer)

            if examples[0]['split'] != 'train':
                temps.append(TEMPLATE_TO_ID[example["template"]])

        batch = self.processor(text=texts, images=images, return_tensors="pt", padding=True)
        ans_tokens = torch.IntTensor(answers)

        labels = batch["input_ids"].clone()
        labels[labels == self.processor.tokenizer.pad_token_id] = self.image_token_id
        batch["labels"] = labels
        batch["answers"] = ans_tokens
        batch['idxs'] = torch.IntTensor([x['idx'] for x in examples])

        if examples[0]['split'] != 'train':
            batch["labels"] = ans_tokens
            batch["templates"] = torch.IntTensor(temps)

        return batch

## Dataset

In [17]:
from torch.utils.data import Dataset
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch


# class MyDataCollator:
#     def __init__(self, processor):
#         self.processor = processor
#         self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
#             processor.tokenizer.additional_special_tokens.index("<image>")
#         ]

#     def __call__(self, examples):
#         texts = []
#         images = []
#         answers = []
#         temps = []
#         for example in examples:
#             image = example["image"]
#             question = example["query"]
#             answer = example["answer"]
#             messages = [
#                 {
#                     "role": "user",
#                     "content": [
#                         {"type": "text", "text": "Answer briefly."},
#                         {"type": "image"},
#                         {"type": "text", "text": question}
#                     ]
#                 }
#             ]
#             if example["split"] == "train":
#                 messages.append(
#                     {
#                         "role": "assistant",
#                         "content": [
#                             {"type": "text", "text": answer}
#                         ]
#                     }
#                 )
#             text = self.processor.apply_chat_template(messages, add_generation_prompt=False if example["split"] == "train" else True)
#             texts.append(text.strip())
#             images.append([image])
#             answers.append(answer)

#             if examples[0]['split'] != 'train':
#                 temps.append(TEMPLATE_TO_ID[example["template"]])

#         batch = self.processor(text=texts, images=images, return_tensors="pt", padding=True)
#         ans_tokens = torch.IntTensor(answers)

#         labels = batch["input_ids"].clone()
#         labels[labels == self.processor.tokenizer.pad_token_id] = self.image_token_id
#         batch["labels"] = labels
#         batch["answers"] = ans_tokens
#         batch['idxs'] = torch.IntTensor([x['idx'] for x in examples])

#         if examples[0]['split'] != 'train':
#             batch["labels"] = ans_tokens
#             batch["templates"] = torch.IntTensor(temps)

#         return batch


class ClverMathDataset(Dataset):
    def __init__(self, data_dir, split, idx_file=None):
        self.data_dir = data_dir
        self.dataset = load_dataset('dali-does/clevr-math',
                               cache_dir=data_dir,
                               split=split)
        self.split = split.lower()
        if idx_file is not None:
            with open(idx_file, 'r') as f:
                self.indices = f.read().splitlines()
        else:
            self.indices = list(range(len(self.dataset)))

        self.indices = [int(x) for x in self.indices]
        self.len = len(self.indices)

        # self.len = len(self.dataset)
        # if self.split != 'train':
        #     # self.dataset = self.dataset[:100]
        #     self.len = 2400

    def __getitem__(self, idx):
        idx = self.indices[idx]
        ques = self.dataset[idx]['question']
        label = self.dataset[idx]['label']
        img = self.dataset[idx]['image'].convert('RGB')
        template = self.dataset[idx]['template']
        dp = {
            'query': ques,
            'image': img,
            'answer': label,
            'split': self.split,
            'idx': idx,
        }
        if self.split != 'train':
            dp['template'] = template
        return dp

    def __len__(self):
        return self.len


def prepare_datasets_and_data_loaders_idefics2(args, processor):
    data_collator = MyDataCollatorFewShots(processor)

    val_dataset = ClverMathDataset(args['train_file'], 'validation', args['val_idx_file'])
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=args['eval_batch_size'],
        collate_fn=data_collator,
        num_workers=args['num_workers'],
        pin_memory=True,
        shuffle=False,
        drop_last=False
    )

    test_dataset = ClverMathDataset(args['train_file'], 'test', args['test_idx_file'])
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args['eval_batch_size'],
        collate_fn=data_collator,
        num_workers=args['num_workers'],
        pin_memory=True,
        shuffle=False,
        drop_last=False
    )

    return val_dataloader, test_dataloader



## Rest

In [18]:
val_dl, test_dl = prepare_datasets_and_data_loaders_idefics2(args, processor)

In [19]:
evaluate_generation(args, model, None, val_dl, processor, processor.tokenizer, args['eval_output_file'])

Evaluation Gen Loop:   0%|          | 0/801 [00:00<?, ?it/s]
No chat template is defined for this tokenizer - using the default template for the LlamaTokenizerFast class. If the default is not appropriate for your model, please set `tokenizer.chat_template` to an appropriate template. See https://huggingface.co/docs/transformers/main/chat_templating for more information.


No chat template is defined for this tokenizer - using the default template for the LlamaTokenizerFast class. If the default is not appropriate for your model, please set `tokenizer.chat_template` to an appropriate template. See https://huggingface.co/docs/transformers/main/chat_templating for more information.


No chat template is defined for this tokenizer - using the default template for the LlamaTokenizerFast class. If the default is not appropriate for your model, please set `tokenizer.chat_template` to an appropriate template. See https://huggingface.co/docs/transformers/main/chat_templating for more informati

Saving results to outputs/evals/idefics2_eval_800.json
[Eval Info] value_accuracy: 94.632%
{'2': {'corr': 86, 'total': 89}, '1': {'corr': 270, 'total': 282}, '3': {'corr': 80, 'total': 92}, '0': {'corr': 322, 'total': 338}}
[Eval Info] value_accuracy on adversarial category: 96.629%
[Eval Info] value_accuracy on addition category: 95.745%
[Eval Info] value_accuracy on subtraction-multihop category: 86.957%
[Eval Info] value_accuracy on subtraction category: 95.266%





{'value_accuracy': 94.63171036204744}

## Fewshots iterations

In [None]:
val_dl, test_dl = prepare_datasets_and_data_loaders_idefics2(args, processor)
example = val_dl.dataset[1]
orig_query = example['query']
print(orig_query)
# example['query'] = "From the image, subtract all blue balls and all gray blocks. How many balls are left in the image?"
# example['query'] = "From the image, subtract all cylinders. How many objects are left?"
example['query'] = "How many cylinders are there in the image?"
print("Example: ", example)
few_shot_example = {
    # 'query': "From the image, subtract all balls. How many objects are left? Answer: 6.",
    'query': "How many spheres are there in the image? Answer: 3.",
    'image': Image.open("outputs/tmp.png").convert("RGB"),
}

In [None]:
texts = []
images = []
answers = []

image_token_id = processor.tokenizer.additional_special_tokens_ids[
            processor.tokenizer.additional_special_tokens.index("<image>")
        ]

image = example["image"]
question = example["query"]
answer = example["answer"]
messages = [
    f"<image>{few_shot_example['query']}<image>{question} Answer:",
]
# messages = [
#     {
#         "role": "user",
#         "content": [
#             {"type": "text", "text": "Think step by step."},
#             {"type": "image"},
#             {"type": "text", "text": "Subtract all big cyan blocks. How many blocks are left?"},
#         ]
#     },
#     {
#         "role": "assistant",
#         "content": [
#         {"type": "text", "text": "Solution steps:\nAnswer = Total blocks - Total big cyan blocks.\nHow many blocks are there? 4.\nHow many big blocks are there? 2.\nHow many big blocks are there that are of cyan color? 1.\nSubtract 1 big cyan block from all the blocks: 4 - 1 = 3.\nFinal answer: 3."},
#         ]
#     },
#     {
#         "role": "user",
#         "content": [
#             {"type": "image"},
#             {"type": "text", "text": question + "\nSolution steps:\nAnswer ="}
#         ]
#     },
# ]

# if example["split"] == "train":
#     messages.append(
#         {
#             "role": "assistant",
#             "content": [
#                 {"type": "text", "text": answer}
#             ]
#         }
#     )

# text = self.processor.apply_chat_template(messages, add_generation_prompt=False if example["split"] == "train" else True)
texts.extend(messages)  # Orig: texts.append(messages)
images.append([few_shot_example['image'], image])
answers.append(answer)

# if examples[0]['split'] != 'train':
#     temps.append(TEMPLATE_TO_ID[example["template"]])

batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
ans_tokens = torch.IntTensor(answers)

labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = image_token_id
# batch["labels"] = labels
# batch["answers"] = ans_tokens
# batch['idxs'] = torch.IntTensor([x['idx'] for x in examples])

# if examples[0]['split'] != 'train':
#     batch["labels"] = ans_tokens
#     batch["templates"] = torch.IntTensor(temps)

In [None]:
output_ = model.generate(
                        **batch,
                        # max_length=args['max_input_length'],
                        # output_scores=True,
                        return_dict_in_generate=True,
                        num_beams=1,
                        use_cache=True,
                        do_sample=False,
                        pad_token_id=processor.tokenizer.pad_token_id,
                        eos_token_id=processor.tokenizer.eos_token_id,
                        max_new_tokens=64,
                    )

generated_ids = output_.sequences
# generated_ids = pad_across_processes(generated_ids, dim=1, pad_index=tokenizer.pad_token_id, pad_first=True)
# labels = batch['labels']

generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)

print("=======================")
print(generated_texts)
print("=======================")
print()

In [None]:
image.show()

## Debugging

In [13]:
val_len = 119202
test_len = 7955
sample_size = 800

idx = list(range(val_len))
import random
random.shuffle(idx)

val_idx = idx[:sample_size]
test_idx = idx[:sample_size]

# save idices to a csv file
import pandas as pd
val_df = pd.DataFrame(val_idx)
test_df = pd.DataFrame(test_idx)
val_df.to_csv(f'val_idx_{sample_size}.csv', index=False)
test_df.to_csv(f'test_idx_{sample_size}.csv', index=False)


In [None]:
# 800
# [Eval Info] value_accuracy: 95.63%
# {'2': {'corr': 86, 'total': 89}, '1': {'corr': 275, 'total': 282}, '3': {'corr': 79, 'total': 92}, '0': {'corr': 326, 'total': 338}}
# [Eval Info] value_accuracy on adversarial category: 96.629%
# [Eval Info] value_accuracy on addition category: 97.518%
# [Eval Info] value_accuracy on subtraction-multihop category: 85.87%
# [Eval Info] value_accuracy on subtraction category: 96.45%