## VLBart PyReft Integration
This is a preliminary attempt at integrating VLBart with PyReft.
### Instructions
1. Use the peterwz-llava branch of both Pyvene and PyReft.
2. Go to pyreft/examples/vlbart/DoRA/image_video_text_understanding and install the packages at the versions listed in the requirements.txt there. Note that DoRA requires a much older transformers version.
3. Download the dataset following the instructions in pyreft/examples/vlbart/DoRA/image_video_text_understanding/README.md: go to the Google Drive link and download the processed CLIP features, then put them in pyreft/examples/vlbart/DoRA/datasets/. This notebook only uses the VQA features.
4. In image_video_text_understanding/download_backbones.py, change the cache directory to the directory where you store the models.
5. Try running image_video_text_understanding/VL-T5/scripts/image/dora.sh to check that DoRA (the VLBart model) is installed correctly.
6. Run this notebook.
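
Because step 2 depends on matching package versions, a quick check like the following may help before running anything else. It is only a sketch: `parse_pins` and `find_mismatches` are hypothetical helpers, the parser only understands plain `name==version` lines, and the example pin is a placeholder rather than a version taken from the real requirements.txt.

```python
from importlib import metadata


def parse_pins(requirements_text):
    """Return {package: version} for lines of the form `name==version`."""
    pins = {}
    for raw in requirements_text.splitlines():
        line = raw.split("#", 1)[0].strip()  # drop comments and whitespace
        if "==" not in line:
            continue  # skip blank or unpinned lines
        name, version = line.split("==", 1)
        pins[name.strip().lower()] = version.strip()
    return pins


def find_mismatches(pins):
    """Compare pinned versions with what is installed in this environment."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != wanted:
            mismatches[name] = (wanted, installed)
    return mismatches


# Placeholder pin for illustration only; read the real file instead, e.g.
# parse_pins(open("requirements.txt").read())
example = "transformers==0.0.0  # placeholder version\n\nsome-unpinned-package\n"
print(find_mismatches(parse_pins(example)))
```

An empty dictionary means every pinned package is installed at the requested version.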
### Known Issues
1. Directly plugging the DoRA VLBart model in here resulted in a VQA score of roughly 0.20.
2. Training is fast for the first few steps, then becomes very slow. I suspect this is related to the caching behavior of the data loading. Loading the dataset in batches, instead of the lazy loading we currently use with ReftDataloaderDataset, may be a better option.
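
To test the data-loading suspicion in issue 2, one could time how long each fetch from the loader takes. The sketch below uses a hypothetical helper, `time_fetches`, and a stand-in iterable; it is not part of PyReft or VL-T5.

```python
import time


def time_fetches(iterable, max_items=50):
    """Return the seconds spent waiting for each of the first `max_items` items."""
    waits = []
    iterator = iter(iterable)
    for _ in range(max_items):
        start = time.perf_counter()
        try:
            next(iterator)
        except StopIteration:
            break  # the iterable had fewer than max_items items
        waits.append(time.perf_counter() - start)
    return waits


# Stand-in iterable; in the notebook one could pass a DataLoader instead.
waits = time_fetches(range(1000), max_items=20)
print(f"first 5 mean: {sum(waits[:5]) / 5:.6f}s, last 5 mean: {sum(waits[-5:]) / 5:.6f}s")
```

If later fetches take much longer than early ones, the loader is a likely bottleneck. If fetch times stay flat, the slowdown more likely lies in the training step itself.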

In [1]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), 'DoRA/image_video_text_understanding/VL-T5/src')))

In [2]:
import vqa_clip_data

In [3]:
batch_size = 128
vqa_args = {'RefCOCO_BUTD': False,
 'RefCOCO_GT': False,
 'adam_beta1': 0.9,
 'adam_beta2': 0.999,
 'adam_eps': 1e-06,
 'add_adapter_cross_attn': True,
 'add_layer_norm_after_adapter': False,
 'add_layer_norm_before_adapter': False,
 'additional_visual_embedding_layers': 0,
 'answer_normalize': False,
 'backbone': 'facebook/bart-base',
 'batch_size': batch_size,
 'caption_cocoonly': True,
 'caption_only': False,
 'classifier': False,
 'clip_grad_norm': 5.0,
 'cls_task': 'tinyimagenet',
 'coco_only': False,
 'comment': '',
 'decoder_prompt_len': 0,
 'deepspeed': None,
 'distributed': False,
 'do_lower_case': False,
 'dora_simple': False,
 'downsample': True,
 'dropout': 0.00,
 'dry': False,
 'efficient_unique_hyper_net': False,
 'encoder_prompt_len': 0,
 'epochs': 20, # 100
 'expand_vis_embedding': False,
 'factorized_phm': True,
 'feat_dim': 2048,
 'feature_type': 'RN101', # RN101
 'fp16': False,
 'freeze_bn_statistics': False,
 'freeze_ln_statistics': False,
 'from_scratch': False,
 'full_determinism': False,
 'gen_max_length': 20,
 'gpu': 0,
 'gradient_accumulation_steps': 1,
 'ground_upsample': 1,
 'ground_weight': 1,
 'hypercomplex_division': 4,
 'image_size': '(224,224)',
 'individual_vis_layer_norm': True,
 'is_wandb': False, # True
 'itm_cocoonly': True,
 'lambda_z': 0.001,
 'load': None,
 'load_lxmert_qa': None,
 'local_rank': 0,
 'log_train_accuracy': False,
 'lora_alpha': 32,
 'lora_dim': 128,
 'lora_settings': True,
 'losses': 'lm,obj,attr,feat',
 'low_rank_rank': 1,
 'lr': 1e-2,
 'lr_scheduler_type': "linear",
 'max_n_boxes': 36,
 'max_n_train_examples': 50000, # 1000
 'max_n_eval_examples': 2000, # 1000
 'max_text_length': 20,
 'mid_dim': 768,
 'multiGPU': True,
 'multitask_sampling': 'roundrobin',
 'n_boxes': 36,
 'n_ground': 1,
 'n_image_tokens': 4,
 'no_prefix': False,
 'num_beams': 5,
 'num_workers': 4,
 'obj_mask_rate': 0.15,
 'oneddownsample': False,
 'optim': 'adamw',
 'optimizer': 'adamw',
 'oscar_tags': False,
 'output': 'snap/VLBart_multitask/tune+lr1e-2_1e-2',
 'phm_init_range': 0.01,
 'phm_rank': 1,
 'pos_dim': 4,
 'position': 'f11+l11',
 'post_prompt': '',
 'prefix': None,
 'project_name': 'RN101_LMsingle_dora_128_bs300_image224_lora_settings',
 'projected_task_embedding_dim': -1,
 'prompt': 'vqa: ',
 'rank': 1,
 'raw_label': False,
 'reduction_factor': 16,
 'remove_bn_vis_adapter': False,
 'run_name': 'tune+lr1e-2_plzplz2',
 'seed': 9595,
 'share_down_sampler': False,
 'share_up_sampler': False,
 'share_vis_lang_layer_norm': False,
 'share_weights': True,
 'shared_phm_rule': True,
 'shared_phm_rule_over_tasks': False,
 'shuffle_boxes': False,
 'single_vqa_prefix': False,
 'sparse_sample': False,
 'submit': False,
 'tasks': 'vqa',
 'test': None,
 'test_answerable': False,
 'test_only': False,
 'testing': False,
 'tokenizer': None,
 'track_z': False,
 'train': 'train',
 'train_topk': -1,
 'unfreeze_batch_norms': False,
 'unfreeze_bias': False,
 'unfreeze_decoder_layer_norms': False,
 'unfreeze_encoder_layer_norms': False,
 'unfreeze_language_model': False,
 'unfreeze_layer_norms': False,
 'unfreeze_lm_head': False,
 'unfreeze_vis_encoder': False,
 'unfreeze_vis_last_layer': False,
 'unique_hyper_net': False,
 'use_adam_for_visual': False,
 'use_adapter': False,
 'use_attn_prefix': False,
 'use_compacter': False,
 'use_data_augmentation': False,
 'use_dora': False,
 'use_hyperformer': False,
 'use_lm_head_adapter': False,
 'use_lora': False,
 'use_lradapter': False,
 'use_separate_optimizer_for_visual': False,
 'use_single_adapter': False,
 'use_single_lora': False,
 'use_single_prompt': False,
 'use_tasks_prompts': True,
 'use_vis_adapter': False,
 'use_vis_layer_norm': True,
 'use_vis_order_embedding': True,
 'use_vision': True,
 'valid': 'valid',
 'valid_batch_size': batch_size,
 'valid_topk': -1,
 'vis_adapter_type': 'middle-bottleneck',
 'vis_lr': 0.00001,
 'vis_pointer': False,
 'vis_pooling_output': False,
 'vis_reduction_factor': 2,
 'vis_use_transformer': False,
 'vis_weight_decay': 0.01,
 'warmup_ratio': 0.00,
 'wandb_proj': "Reft",
 'wandb_name': "peterwz",
 'wandb_dir': "wandb",
 'weight_decay': 0.005,
 'word_mask_rate': 0.15,
 'world_size': 1}

In [4]:
from types import SimpleNamespace
args = SimpleNamespace(**vqa_args)

In [5]:
train_loaders = []
vqa_train_loader = vqa_clip_data.get_loader(
    args,
    split='karpathy_train', mode='train', batch_size=args.batch_size,
    distributed=args.distributed, gpu=0,
    workers=args.num_workers,
    topk=args.train_topk,
)
vqa_val_loader = vqa_clip_data.get_loader(
    args,
    split='karpathy_val', mode='val', batch_size=args.batch_size,
    distributed=args.distributed, gpu=0,
    workers=args.num_workers,
    topk=args.train_topk,
)
train_loaders.append(vqa_train_loader)

Load 605102 data from split(s) karpathy_train.
# Answers: 3129
Data sources:  ['karpathy_train']
Loaded 605102 data from karpathy_train
# all sentences: 605102




Load 26729 data from split(s) karpathy_val.
# Answers: 3129
Data sources:  ['karpathy_val']
Loaded 26729 data from karpathy_val
# all sentences: 26729


In [6]:
from multitask import Trainer
trainer = Trainer(args, vqa_train_loader, None, None, train=True)
# trainer.train()

Building Model at GPU 0
Model Launching at GPU 0
model.encoder.visual_embedding.feat_embedding.0.weight is trainable...
model.encoder.visual_embedding.feat_embedding.0.bias is trainable...
model.encoder.visual_embedding.feat_embedding.1.weight is trainable...
model.encoder.visual_embedding.feat_embedding.1.bias is trainable...
model.encoder.visual_embedding.absolute_vis_pos_embedding.0.weight is trainable...
model.encoder.visual_embedding.absolute_vis_pos_embedding.0.bias is trainable...
model.encoder.visual_embedding.absolute_vis_pos_embedding.1.weight is trainable...
model.encoder.visual_embedding.absolute_vis_pos_embedding.1.bias is trainable...
model.encoder.visual_embedding.img_order_embedding.weight is trainable...
VLBartMultiTask(
  (model): VLBartModel(
    (shared): Embedding(50465, 768)
    (encoder): JointEncoder(
      (embed_tokens): Embedding(50465, 768)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768, padding_idx=1)
      (layers): ModuleList(
        (



In [7]:
from pyreft.dataset import ReftDataset, ReftDataloaderDataset
from pyreft import (
    ReftTrainerForCausalLM, 
    ReftDataCollator,
    LoreftIntervention,
    TaskType,
    ReftConfig,
    get_reft_model,
)
import torch

In [8]:
from transformers import BartTokenizer, TrainingArguments
tokenizer = trainer.tokenizer

In [9]:
# class VLBartDataset(ReftDataloaderDataset):
class VLBartDataset(ReftDataset):
    """
    A ReftDataset backed by the dataset of an existing VL-T5 dataloader
    (here the VQA loader). The items already carry input_ids and
    target_ids, so nothing is loaded from HF and no text field is
    tokenized again.
    """
    def __init__(
        self, task, 
        tokenizer,
        data_split="train", dataloader=None, 
        max_n_example=None,
        **kwargs,
    ):
        self.dataloader = dataloader
        super(VLBartDataset, self).__init__(task, "", tokenizer, data_split, None, 42, max_n_example,
        **kwargs)

    
    def load_dataset(self):
        """Use the dataset held by the wrapped dataloader (or a portion of it)."""

        self.task_dataset = self.dataloader.dataset
        self.collate_fn = self.task_dataset.collate_fn
        self.fields_to_pad = ["input_ids", "target_ids"]
        self.pad_mode = "none"

        # keep only the first max_n_example examples if specified
        if self.max_n_example is not None:
            self.task_dataset = torch.utils.data.Subset(self.task_dataset, list(range(self.max_n_example)))

        # keep a pointer to the raw dataset for access to raw strings (non-train splits only)
        self.raw_dataset = self.task_dataset if self.data_split != "train" else None
        return self.task_dataset

    def preprocess(self, kwargs):
        self.input_field = "input_ids"
        self.label_field = "target_ids"

    def tokenize(self, data_item):
        result = {**data_item}
        # result["input_length"] += 1
        # result["target_length"] += 1
        result["instruction"] = self.task + ": " + result["sent"]
        # print("Instruction", result["instruction"])
        # print("Sent", result["sent"])

        # TODO: whether to add "-1"?
        last_position = len(data_item[self.input_field]) - 1
        return result, last_position

In [10]:
layers = [0,1,2,3,4,5]
position = args.position
from pyreft.dataset import parse_positions
first_pos, last_pos = parse_positions(position)
if "+" in position and not args.share_weights:
    layers += layers
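
The cell above relies on the position string 'f11+l11'. The sketch below shows one reading of it: intervene on the first 11 and the last 11 positions. This is an assumption about the convention. `parse_position_string` is a hypothetical stand-in, not pyreft's `parse_positions`, and `count_interventions` only mirrors the doubling rule written in the cell.

```python
def parse_position_string(position):
    """Split a string such as 'f11+l11' into (first_n, last_n)."""
    first_n, last_n = 0, 0
    for part in position.split("+"):
        part = part.strip()
        if part.startswith("f"):
            first_n = int(part[1:])
        elif part.startswith("l"):
            last_n = int(part[1:])
        else:
            raise ValueError(f"unrecognized position part: {part!r}")
    return first_n, last_n


def count_interventions(layers, position, share_weights):
    """Double the layer list only when both a prefix and a suffix are used
    and their weights are not shared, as in the notebook cell."""
    layers = list(layers)
    if "+" in position and not share_weights:
        layers += layers
    return len(layers)


print(parse_position_string("f11+l11"))                                # (11, 11)
print(count_interventions(range(6), "f11+l11", share_weights=True))   # 6
print(count_interventions(range(6), "f11+l11", share_weights=False))  # 12
```

With share_weights set to True, as in this notebook, the six encoder layers give six interventions.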

In [11]:
train_dataset = VLBartDataset(
    "vqa", 
    tokenizer, data_split="train", 
    dataloader=vqa_train_loader,
    max_n_example=args.max_n_train_examples,
    **{"num_interventions": len(layers), "position": position, 
       "share_weights": args.share_weights, "test_split": "validation",
      "last_offset": args.n_boxes}
)
eval_dataset = VLBartDataset(
    "vqa", 
    tokenizer, data_split="val", 
    dataloader=vqa_val_loader,
    max_n_example=args.max_n_eval_examples,
    **{"num_interventions": len(layers), "position": position, 
       "share_weights": args.share_weights, "test_split": "validation",
      "last_offset": args.n_boxes}
)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50000/50000 [01:53<00:00, 440.46it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:03<00:00, 505.23it/s]


In [12]:
model = trainer.model

In [13]:
print(model.config)

BartConfig {
  "RefCOCO_BUTD": false,
  "RefCOCO_GT": false,
  "_name_or_path": "facebook/bart-base",
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "adam_beta1": 0.9,
  "adam_beta2": 0.999,
  "adam_eps": 1e-06,
  "adapter_config": null,
  "add_adapter_cross_attn": true,
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "add_layer_norm_after_adapter": false,
  "add_layer_norm_before_adapter": false,
  "additional_visual_embedding_layers": 0,
  "answer_normalize": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.0,
  "backbone": "facebook/bart-base",
  "batch_size": 128,
  "bos_token_id": 0,
  "caption_cocoonly": true,
  "caption_only": false,
  "classif_dropout": 0.1,
  "classifier": false,
  "classifier_dropout": 0.0,
  "clip_grad_norm": 5.0,
  "cls_task": "tinyimagenet",
  "coco_only": false,
  "comment": "",
  "d_model": 768,
  "decoder_attention_heads": 12,
  "decoder_ffn_dim": 3072,
  "decoder_layerdrop": 0.0,
  "decoder_layer

In [14]:
train_dataset.collate_fn

<bound method VQAFineTuneDataset.collate_fn of <vqa_clip_data.VQAFineTuneDataset object at 0x7f8899c1f220>>

In [15]:
print(vqa_train_loader.dataset[0].keys())
print(vqa_train_loader.dataset[0]["all_answers"])
print(tokenizer.decode(vqa_train_loader.dataset[0]["target_ids"], skip_special_tokens=True))

dict_keys(['args', 'img_id', 'vis_feats', 'boxes', 'question_id', 'sent', 'input_ids', 'input_length', 'is_topk_optimal', 'label', 'answer', 'score', 'all_answers', 'target_ids', 'target_length'])
['net']
net


In [16]:
# from transformers import DataCollatorForSeq2Seq
# data_collator_fn = DataCollatorForSeq2Seq(
#     tokenizer=tokenizer,
#     model=model,
#     label_pad_token_id=-100,
#     padding="longest"
# )
import transformers
def keep_intervention_locations(datum):
    new_data = {}
    new_data["input_ids"] = datum["input_ids"]
    # new_data["instruction"] = datum["instruction"]
    new_data["intervention_locations"] = datum["intervention_locations"]
    new_data["attention_mask"] = datum["attention_mask"]
    # print(new_data["input_ids"].shape, new_data["attention_mask"])
    return new_data

def custom_collate_fn(data):
    collate_fn_1 = train_dataset.collate_fn
    collate_fn_2 = transformers.DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        label_pad_token_id=-100,
        padding="longest"
    )
    # for item in data:
    #     print(item["input_ids"].shape)
    output_1 = collate_fn_1(data)
    custom_data = [keep_intervention_locations(item) for item in data]
    output_2 = collate_fn_2(custom_data)
    output = output_1
    output["intervention_locations"] = output_2["intervention_locations"]
    # print(output["intervention_locations"].shape)
    # print(torch.max(output["intervention_locations"]))
    # Offset image tokens' concatenation
    # print(output["intervention_locations"])
    # output["intervention_locations"][:,:,-last_pos:] += args.n_boxes
    # output["intervention_locations"] -= 1
    # print(torch.max(output["intervention_locations"]))
    # print(output["intervention_locations"])

    # output["id"] = output_2["id"]
    # output["labels"] = output_2["labels"]
    
    # output["attention_mask"] = output_2["attention_mask"]
    # del output["attention_mask"]

    ids = []
    instructions = []
    for d in data:
        ids.append(d["id"])
        instructions.append(d["instruction"])
    import numpy as np
    output["id"] = np.array(ids)
    output["instruction"] = instructions
    
    output["logits"] = output["labels"]
    output["labels"] = output["target_ids"]
    # output["instruction"] = tokenizer.batch_decode(output["input_ids"], skip_special_tokens=True)
    # print("Output Keys:", output.keys())
    
    # print("Input IDs:", output["input_ids"], tokenizer.batch_decode(output["input_ids"], skip_special_tokens=True))
    # print("Labels:", output["labels"].shape)
    # labels = [[token for token in sequence if token != -100] for sequence in output["labels"].tolist()]
    # print("Labels:", tokenizer.batch_decode(labels, skip_special_tokens=True))
    # print("Question IDs:", output["question_ids"])
    # print("Answers:", output["answers"])
    # print("All answers:", output["all_answers"])
    # print("Scores:", output["scores"])

    return output

data_collator = ReftDataCollator(data_collator=custom_collate_fn)

In [17]:
rank = args.rank


In [18]:
representations = [{
    "layer": l, "component": "block_output",
    "low_rank_dimension": rank,
    "intervention": LoreftIntervention(
        embed_dim=model.config.d_model, low_rank_dimension=rank,
        dropout=args.dropout, dtype=torch.float32, act_fn=None, device="cuda",
        add_bias=True
    )
} for l in layers]
task_type=TaskType.CAUSAL_LM

reft_config = ReftConfig(representations=representations)
empty_reft_config = ReftConfig(representations=[])

In [19]:
reft_model = get_reft_model(model, reft_config)
empty_reft_model = get_reft_model(model, empty_reft_config)
empty_reft_model.print_trainable_parameters()
reft_model.print_trainable_parameters()

trainable intervention params: 0 || trainable model params: 0
model params: 141,156,864 || trainable%: 0.0
trainable intervention params: 9,222 || trainable model params: 0
model params: 141,156,864 || trainable%: 0.006533157324889281
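
The 9,222 trainable intervention parameters reported above can be reproduced by hand. The sketch assumes each rank-r intervention holds an r x d projection plus an r x d learned-source weight with an r-dimensional bias, i.e. 2rd + r parameters. That layout is an assumption about LoreftIntervention on this branch, but it matches the printed count.

```python
d_model = 768               # model.config.d_model in the printed config
rank = 1                    # args.rank
num_interventions = 6       # len(layers) with share_weights=True
model_params = 141_156_864  # from the output above

# Assumed layout: projection (r*d) + learned source weight (r*d) + bias (r)
per_intervention = 2 * rank * d_model + rank
total = num_interventions * per_intervention
percent = 100 * total / model_params

print(per_intervention)  # 1537
print(total)             # 9222
print(percent)
```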


In [20]:
training_args = TrainingArguments(
    output_dir="random",
    run_name="random",
    num_train_epochs=args.epochs,
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    gradient_accumulation_steps=1,
    evaluation_strategy="no",
    # evaluation_strategy="epoch",
    save_strategy="no",
    metric_for_best_model=None,
    load_best_model_at_end=False,
    logging_strategy="epoch",
    save_total_limit=1, # for GLUE, it will save 2 at max.
    logging_steps=1,
    learning_rate=args.lr,
    # learning_rate=1e-4,
    warmup_ratio=args.warmup_ratio,
    optim="adamw_torch",
    weight_decay=args.weight_decay,
    # lr_scheduler="none",
    lr_scheduler_type=args.lr_scheduler_type,
    report_to="wandb" if args.is_wandb else "none",
    use_cpu=False,
    seed=42,
    # until HF supports ReFT, this remains False! :)
    remove_unused_columns=False
)

In [21]:
from pyvene import IntervenableModel
# from overrides import overrides

class MyTrainer(ReftTrainerForCausalLM):
    # @overrides
    def training_step(self, model, batch):
        # print("My trainer step")
        batch = self._prepare_inputs(batch)

        # print("Batch:", batch.keys())
        device = batch['input_ids'].device

        batch = model.model.vis_forward(batch, device)
        task = batch["task"]

        vis_feats = batch['vis_feats']
        input_ids = batch['input_ids']
        vis_pos = batch['boxes']

        lm_labels = batch["target_ids"].to(device)

        inputs = {**batch}
        inputs["return_dict"] = True
        inputs["reduce_loss"] = False
        inputs["vis_inputs"] = (vis_feats, vis_pos)
        # print(inputs.keys())

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        del inputs
        torch.cuda.empty_cache()

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.use_apex:
            from apex import amp  # imported lazily; only needed when apex is enabled
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss)

        return loss.detach() / self.args.gradient_accumulation_steps

    def compute_loss(
        self,
        intervenable: IntervenableModel,
        inputs,
        return_outputs=False
    ):
        
        lm_labels = inputs["target_ids"]
        # print("KEYS:", inputs.keys())
        # print("LABELS:", lm_labels)
        # print("SCORES:", inputs["scores"])
        # print("LOCS:", inputs["intervention_locations"])
        # print("INPUT_IDS:", inputs["input_ids"])
        # print("VIS_INPUTS:", inputs["vis_inputs"][0].shape, inputs["vis_inputs"][1].shape)
        
        _, cf_outputs = intervenable(
            {
                "input_ids": inputs["input_ids"],
                # "attention_mask": inputs["attention_mask"],
                "vis_inputs": inputs["vis_inputs"],
                "task": "vqa",
                
            },
            unit_locations={"sources->base": (
                None,
                inputs["intervention_locations"].permute(1, 0, 2).tolist()
            )},
            labels=inputs["target_ids"],
            subspaces=None,
        )
        # return
        loss = cf_outputs.loss
        # print("CF OUTPUTS:", cf_outputs.keys(), len(cf_outputs["loss"]))
        
        
        lm_mask = (lm_labels != -100).float()
        # print("LM MASK:", lm_mask)
        # print("SCORES:", inputs["scores"])
        B, L = lm_labels.size()

        loss = loss.view(B, L) * lm_mask

        loss = loss.sum(dim=1) / lm_mask.sum(dim=1).clamp(min=1)  # B

        loss = loss * inputs["scores"]

        loss = loss.mean()
        return loss
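
The arithmetic at the end of compute_loss can be followed on a tiny example. The sketch below repeats the same steps in plain Python: mask out -100 labels, average per example, weight by the VQA score, then average over the batch. All numbers are made up, and `masked_weighted_loss` is a hypothetical helper.

```python
IGNORE_INDEX = -100


def masked_weighted_loss(token_losses, labels, scores):
    """token_losses and labels are B x L nested lists; scores has length B."""
    per_example = []
    for losses, labs, score in zip(token_losses, labels, scores):
        mask = [1.0 if lab != IGNORE_INDEX else 0.0 for lab in labs]
        total = sum(l * m for l, m in zip(losses, mask))
        n_tokens = max(sum(mask), 1.0)  # same role as .clamp(min=1)
        per_example.append(score * total / n_tokens)
    return sum(per_example) / len(per_example)


token_losses = [[2.0, 4.0, 9.0], [1.0, 5.0, 5.0]]
labels = [[10, 11, -100], [12, -100, -100]]
scores = [1.0, 0.5]
print(masked_weighted_loss(token_losses, labels, scores))  # 1.75
```

The first example averages to 3.0 over its two unmasked tokens. The second contributes 1.0 * 0.5 = 0.5, so the batch mean is 1.75.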
        


In [22]:
trainer = MyTrainer(
    model=reft_model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=None,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None)
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [23]:
print(reft_model.model)

VLBartMultiTask(
  (model): VLBartModel(
    (shared): Embedding(50465, 768)
    (encoder): JointEncoder(
      (embed_tokens): Embedding(50465, 768)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768, padding_idx=1)
      (layers): ModuleList(
        (0-5): 6 x BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine

In [24]:
if args.is_wandb:
    import wandb
    run = wandb.init(
        project=f"{args.wandb_proj}_vqa", 
        entity=args.wandb_name,
        name=args.run_name,
        dir=args.wandb_dir,
    )
    run.summary.update(vars(args))
    n_params = reft_model.count_parameters(include_model=False)
    wandb.log(
        {"train/n_params": n_params})

In [25]:
trainer.train()

{'loss': 0.9145, 'learning_rate': 0.0095, 'epoch': 1.0}
{'loss': 0.8003, 'learning_rate': 0.009000000000000001, 'epoch': 2.0}
{'loss': 0.791, 'learning_rate': 0.0085, 'epoch': 3.0}
{'loss': 0.7841, 'learning_rate': 0.008, 'epoch': 4.0}
{'loss': 0.7779, 'learning_rate': 0.0075, 'epoch': 5.0}
{'loss': 0.7756, 'learning_rate': 0.006999999999999999, 'epoch': 6.0}
{'loss': 0.7706, 'learning_rate': 0.006500000000000001, 'epoch': 7.0}
{'loss': 0.7655, 'learning_rate': 0.006, 'epoch': 8.0}
{'loss': 0.7598, 'learning_rate': 0.0055000000000000005, 'epoch': 9.0}
{'loss': 0.7573, 'learning_rate': 0.005, 'epoch': 10.0}
{'loss': 0.7507, 'learning_rate': 0.0045000000000000005, 'epoch': 11.0}
{'loss': 0.7453, 'learning_rate': 0.004, 'epoch': 12.0}
{'loss': 0.7405, 'learning_rate': 0.0034999999999999996, 'epoch': 13.0}
{'loss': 0.7329, 'learning_rate': 0.003, 'epoch': 14.0}
{'loss': 0.7273, 'learning_rate': 0.0025, 'epoch': 15.0}
{'loss': 0.7196, 'learning_rate': 0.002, 'epoch': 16.0}
{'loss': 0.713, '

TrainOutput(global_step=7820, training_loss=0.7558738923133792, metrics={'train_runtime': 2498.3883, 'train_samples_per_second': 400.258, 'train_steps_per_second': 3.13, 'train_loss': 0.7558738923133792, 'epoch': 20.0})
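
The step count and the logged learning rates above follow from the settings in this notebook. The sketch assumes a single device, no gradient accumulation, a partial last batch that is kept, and linear decay to zero without warmup.

```python
import math

n_train = 50_000  # args.max_n_train_examples
batch_size = 128  # args.batch_size
epochs = 20       # args.epochs
base_lr = 1e-2    # args.lr

steps_per_epoch = math.ceil(n_train / batch_size)
total_steps = steps_per_epoch * epochs


def linear_lr(step):
    """Learning rate after `step` optimizer steps under linear decay to zero."""
    return base_lr * (total_steps - step) / total_steps


print(steps_per_epoch)                            # 391
print(total_steps)                                # 7820
print(round(linear_lr(steps_per_epoch), 6))       # 0.0095
print(round(linear_lr(16 * steps_per_epoch), 6))  # 0.002
```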

In [26]:
# tokenizer("tree")

In [27]:
# import pyreft
# reft_model = pyreft.ReftModel.load(
#     "temp-outputs", model
# )
# reft_model.set_device("cuda")

In [28]:
reft_model.model.eval()
for k,v in reft_model.interventions.items():
    _ = v[0].eval()


In [29]:
from compute_metrics import compute_metrics
generations, stats = compute_metrics(
    "vqa", "vqa", reft_model, tokenizer, eval_dataset, eval_dataset,
    '', 'test', 1, # batch_size
    data_collator,
    split=False, greedy_decoding=True, temperature=1.0, top_p=None, top_k=None
)


  5%|█████▍                                                                                                  | 104/2000 [00:04<01:35, 19.90it/s, em=0.279]

vqa: What is this piece of furniture used for?  |  wood  |  sitting


 10%|██████████▌                                                                                              | 202/2000 [00:10<01:41, 17.78it/s, em=0.27]

vqa: How many orange cones are there?  |  1  |  0


 15%|███████████████▊                                                                                        | 303/2000 [00:15<01:34, 17.91it/s, em=0.273]

vqa: Has the skier fallen?  |  no  |  yes


 20%|████████████████████▉                                                                                   | 403/2000 [00:20<01:06, 24.07it/s, em=0.267]

vqa: What color is the bed sheets?  |  white  |  blue


 25%|██████████████████████████▎                                                                             | 505/2000 [00:24<01:04, 23.09it/s, em=0.281]

vqa: What is she wearing on her head?  |  scarf  |  hat


 50%|███████████████████████████████████████████████████▋                                                   | 1003/2000 [00:46<00:39, 25.01it/s, em=0.294]

vqa: The crosswalk sign is indicating what?  |  stop  |  walk


 55%|████████████████████████████████████████████████████████▉                                              | 1105/2000 [00:51<00:40, 22.32it/s, em=0.287]

vqa: How many windows are in the right side of the plane?  |  2  |  40


 60%|█████████████████████████████████████████████████████████████▊                                         | 1201/2000 [00:56<00:35, 22.41it/s, em=0.281]

vqa: What is behind the bike?  |  sand  |  bag


 70%|████████████████████████████████████████████████████████████████████████▎                              | 1403/2000 [01:05<00:24, 24.32it/s, em=0.288]

vqa: What is in the pregnant woman's belly?  |  nothing  |  baby


 75%|█████████████████████████████████████████████████████████████████████████████▌                         | 1505/2000 [01:09<00:20, 23.87it/s, em=0.282]

vqa: How many kids are holding game controllers?  |  1  |  4


 80%|██████████████████████████████████████████████████████████████████████████████████▌                    | 1603/2000 [01:14<00:20, 19.18it/s, em=0.281]

vqa: What condition is the water in?  |  snow  |  wavy


 85%|███████████████████████████████████████████████████████████████████████████████████████▊               | 1704/2000 [01:20<00:14, 20.54it/s, em=0.282]

vqa: Can you see the hook up for the train?  |  no  |  yes


 90%|████████████████████████████████████████████████████████████████████████████████████████████▊          | 1803/2000 [01:25<00:13, 14.66it/s, em=0.278]

vqa: What are the bears sitting on?  |  tree  |  cart


 95%|██████████████████████████████████████████████████████████████████████████████████████████████████     | 1903/2000 [01:32<00:05, 17.31it/s, em=0.277]

vqa: How many colors are represented in this scene?  |  2  |  10


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [01:38<00:00, 20.34it/s, em=0.279]


In [30]:
# eval_dataset[3]["answer"]

In [31]:
# generations[3]

In [32]:
if args.is_wandb:
    wandb.log(stats)
else:
    print(stats)

{'eval/vqa': 0.2785}


In [33]:
# reft_model.save('temp-outputs')

### Next Steps:

1. Speed up data loading (open-ended performance problem)
2. Check the intervention locations for VL-BART
3. Evaluate the fine-tuned model on the VQA eval/test splits
4. Manually validate the fine-tuned model's outputs
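
For step 2, a small bounds check could be a starting point. The sketch assumes the encoder sees the text tokens followed by n_boxes visual tokens, which still needs to be verified against VL-BART's JointEncoder. `find_out_of_range` is a hypothetical helper and the example locations are made up.

```python
def find_out_of_range(locations, text_len, n_visual):
    """Return (intervention_index, position) pairs outside the joint sequence.

    locations: one list of token positions per intervention.
    Assumes text_len text tokens plus n_visual visual tokens in the encoder.
    """
    seq_len = text_len + n_visual
    bad = []
    for i, positions in enumerate(locations):
        for pos in positions:
            if not 0 <= pos < seq_len:
                bad.append((i, pos))
    return bad


# Made-up example: 2 interventions, 8 text tokens, 36 visual tokens.
locations = [[0, 1, 2, 41, 42, 43], [0, 1, 2, 42, 43, 44]]
print(find_out_of_range(locations, text_len=8, n_visual=36))  # [(1, 44)]
```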