## VLBart PyReft Integration
This notebook is a preliminary attempt at integrating VLBart with PyReft.
### Instructions
1. Use the peterwz-llava branch of both pyvene and PyReft.
2. Go to pyreft/examples/vlbart/DoRA/image_video_text_understanding and install the package versions pinned in the requirements.txt there. Note that DoRA requires a much older version of transformers.
3. Download the dataset following the instructions in pyreft/examples/vlbart/DoRA/image_video_text_understanding/README.md: specifically, follow the Google Drive link and download the processed CLIP features. Put them in pyreft/examples/vlbart/DoRA/datasets/. This notebook only uses the VQA features.
4. In image_video_text_understanding/download_backbones.py, change the cache directory to the directory where you store the models.
5. Try running image_video_text_understanding/VL-T5/scripts/image/dora.sh to check that DoRA (the VLBart model) is installed successfully.
6. Run this notebook.
### Known Issues
1. Directly plugging the DoRA VLBart model in here results in a VQA score of only ~0.20.
2. Training is fast for the first few steps, then becomes very slow. I suspect this is related to the data-loading cache behavior. Loading the dataset in batches, instead of the lazy loading currently done through ReftDataloaderDataset, may be a better option.

In [None]:
import sys
import os

# Make the VL-T5 sources importable; the later `vqa_clip_data` and
# `multitask` imports presumably resolve from this directory.
_VL_T5_SRC = os.path.abspath(
    os.path.join(os.getcwd(), 'DoRA/image_video_text_understanding/VL-T5/src')
)
# Guard so that re-running this cell does not append duplicate entries.
if _VL_T5_SRC not in sys.path:
    sys.path.append(_VL_T5_SRC)

In [None]:
import vqa_clip_data

In [None]:
# Hyper-parameters for the VL-T5 / DoRA VLBart pipeline, kept as a plain dict
# and turned into an attribute namespace (`args`) in the next cell.
# NOTE(review): these look copied from a DoRA training run (see 'project_name'
# and 'run_name'); many adapter/LoRA/DoRA flags are presumably unused here.
# Entries read directly in this notebook include 'backbone', 'batch_size',
# 'distributed', 'num_workers', 'train_topk', 'max_text_length',
# 'do_lower_case' and 'n_boxes'.
vqa_args = {'RefCOCO_BUTD': False,
 'RefCOCO_GT': False,
 'adam_beta1': 0.9,
 'adam_beta2': 0.999,
 'adam_eps': 1e-06,
 'add_adapter_cross_attn': True,
 'add_layer_norm_after_adapter': False,
 'add_layer_norm_before_adapter': False,
 'additional_visual_embedding_layers': 0,
 'answer_normalize': False,
 'backbone': 'facebook/bart-base',
 'batch_size': 512,  # also used as the HF per-device train/eval batch size
 'caption_cocoonly': True,
 'caption_only': False,
 'classifier': False,
 'clip_grad_norm': 5.0,
 'cls_task': 'tinyimagenet',
 'coco_only': False,
 'comment': '',
 'decoder_prompt_len': 0,
 'deepspeed': None,
 'distributed': False,
 'do_lower_case': False,
 'dora_simple': False,
 'downsample': True,
 'dropout': 0.1,
 'dry': False,
 'efficient_unique_hyper_net': False,
 'encoder_prompt_len': 0,
 'epochs': 20,
 'expand_vis_embedding': False,
 'factorized_phm': True,
 'feat_dim': 2048,
 'feature_type': 'RN101',
 'fp16': False,
 'freeze_bn_statistics': False,
 'freeze_ln_statistics': False,
 'from_scratch': False,
 'full_determinism': False,
 'gen_max_length': 20,
 'gpu': 0,
 'gradient_accumulation_steps': 1,
 'ground_upsample': 1,
 'ground_weight': 1,
 'hypercomplex_division': 4,
 'image_size': '(224,224)',
 'individual_vis_layer_norm': True,
 'itm_cocoonly': True,
 'lambda_z': 0.001,
 'load': None,
 'load_lxmert_qa': None,
 'local_rank': 0,
 'log_train_accuracy': False,
 'lora_alpha': 32,
 'lora_dim': 128,
 'lora_settings': True,
 'losses': 'lm,obj,attr,feat',
 'low_rank_rank': 1,
 'lr': 0.01,
 'max_n_boxes': 36,
 'max_text_length': 20,
 'mid_dim': 768,
 'multiGPU': True,
 'multitask_sampling': 'roundrobin',
 'n_boxes': 36,  # also added to intervention locations in the collator (image-token offset)
 'n_ground': 1,
 'n_image_tokens': 4,
 'no_prefix': False,
 'num_beams': 5,
 'num_workers': 4,
 'obj_mask_rate': 0.15,
 'oneddownsample': False,
 'optim': 'adamw',
 'optimizer': 'adamw',
 'oscar_tags': False,
 'output': 'snap/VLBart_multitask/tune+lr1e-2_plzplz2',
 'phm_init_range': 0.01,
 'phm_rank': 1,
 'pos_dim': 4,
 'post_prompt': '',
 'prefix': None,
 'project_name': 'RN101_LMsingle_dora_128_bs300_image224_lora_settings',
 'projected_task_embedding_dim': -1,
 'prompt': 'vqa: ',
 'raw_label': False,
 'reduction_factor': 16,
 'remove_bn_vis_adapter': False,
 'run_name': 'tune+lr1e-2_plzplz2',
 'seed': 9595,
 'share_down_sampler': False,
 'share_up_sampler': False,
 'share_vis_lang_layer_norm': False,
 'shared_phm_rule': True,
 'shared_phm_rule_over_tasks': False,
 'shuffle_boxes': False,
 'single_vqa_prefix': False,
 'sparse_sample': False,
 'submit': False,
 'tasks': 'vqa',
 'test': None,
 'test_answerable': False,
 'test_only': False,
 'testing': False,
 'tokenizer': None,
 'track_z': False,
 'train': 'train',
 'train_topk': -1,
 'unfreeze_batch_norms': False,
 'unfreeze_bias': False,
 'unfreeze_decoder_layer_norms': False,
 'unfreeze_encoder_layer_norms': False,
 'unfreeze_language_model': False,
 'unfreeze_layer_norms': False,
 'unfreeze_lm_head': False,
 'unfreeze_vis_encoder': False,
 'unfreeze_vis_last_layer': False,
 'unique_hyper_net': False,
 'use_adam_for_visual': False,
 'use_adapter': False,
 'use_attn_prefix': False,
 'use_compacter': False,
 'use_data_augmentation': False,
 'use_dora': False,
 'use_hyperformer': False,
 'use_lm_head_adapter': False,
 'use_lora': False,
 'use_lradapter': False,
 'use_separate_optimizer_for_visual': False,
 'use_single_adapter': False,
 'use_single_lora': False,
 'use_single_prompt': False,
 'use_tasks_prompts': True,
 'use_vis_adapter': False,
 'use_vis_layer_norm': True,
 'use_vis_order_embedding': True,
 'use_vision': True,
 'valid': 'valid',
 'valid_batch_size': 512,
 'valid_topk': -1,
 'vis_adapter_type': 'middle-bottleneck',
 'vis_lr': 0.0001,
 'vis_pointer': False,
 'vis_pooling_output': False,
 'vis_reduction_factor': 2,
 'vis_use_transformer': False,
 'vis_weight_decay': 0.01,
 'warmup_ratio': 0.1,
 'weight_decay': 0.01,
 'word_mask_rate': 0.15,
 'world_size': 1}

In [None]:
from types import SimpleNamespace
# Attribute-style access to the config (args.batch_size, args.backbone, ...),
# presumably mirroring the argparse namespace VL-T5 code normally receives.
args = SimpleNamespace(**vqa_args)

In [None]:
# VQA training loader from the VL-T5 data pipeline (Karpathy train split).
vqa_train_loader = vqa_clip_data.get_loader(
    args,
    split='karpathy_train',
    mode='train',
    batch_size=args.batch_size,
    distributed=args.distributed,
    gpu=0,
    workers=args.num_workers,
    topk=args.train_topk,
)
# Collected in a list of training loaders (only the VQA loader here).
train_loaders = [vqa_train_loader]

In [None]:
from pyreft.dataset import ReftDataset, ReftDataloaderDataset
from pyreft import (
    ReftTrainerForCausalLM, 
    ReftDataCollator,
    LoreftIntervention,
    TaskType,
    ReftConfig,
    get_reft_model,
)
import torch

In [None]:
class VLBartDataset(ReftDataloaderDataset):
    """ReFT dataset that wraps an existing VL-T5 VQA ``DataLoader``.

    Items come from ``dataloader.dataset`` and are already tokenized by
    VL-T5, so no raw-text tokenization happens here. ``input_ids`` is the
    input field and ``target_ids`` the label field; both are set in
    :meth:`preprocess`, so they do not need to be passed as kwargs.
    """

    def load_dataset(self):
        """Expose the wrapped loader's dataset, optionally truncated.

        Also records VL-T5's ``collate_fn``, the fields to pad, and the pad
        mode used by the ReFT base class.

        Returns:
            The (possibly truncated) dataset.
        """
        self.task_dataset = self.dataloader.dataset
        self.collate_fn = self.task_dataset.collate_fn
        self.fields_to_pad = ["input_ids", "target_ids"]
        self.pad_mode = "first"

        # Keep the FIRST max_n_example items (deterministic, not a random
        # sample). Clamp to the dataset size so Subset never holds indices
        # that would fail later on access.
        if self.max_n_example is not None:
            n_keep = min(self.max_n_example, len(self.task_dataset))
            self.task_dataset = torch.utils.data.Subset(
                self.task_dataset, list(range(n_keep))
            )

        # Raw-item pointer is kept only for non-train splits.
        self.raw_dataset = self.task_dataset if self.data_split != "train" else None
        return self.task_dataset

    def preprocess(self, kwargs):
        """Fix the input/label field names used by the ReFT base class."""
        self.input_field = "input_ids"
        self.label_field = "target_ids"

    def tokenize(self, data_item):
        """Adapt an already-tokenized VL-T5 item for ReFT.

        Args:
            data_item: one item (dict) from the wrapped VL-T5 dataset.

        Returns:
            ``(result, last_position)``. ``result`` is a shallow copy of
            ``data_item`` with ``input_length``/``target_length`` increased by
            one and a decoded ``instruction`` string added. ``last_position``
            is the index of the last input token.
        """
        result = {**data_item}
        # NOTE(review): the +1 presumably accounts for the pad token inserted
        # by the ReFT base class (pad_mode="first"); confirm.
        result["input_length"] += 1
        result["target_length"] += 1
        # Uses the notebook-global `tokenizer`, which must already be defined
        # when items are tokenized.
        result["instruction"] = tokenizer.decode(result["input_ids"], skip_special_tokens=True)

        # TODO: whether to add "-1"?
        last_position = len(data_item[self.input_field]) - 1
        return result, last_position

In [None]:
from transformers import BartTokenizer, TrainingArguments
# Tokenizer for the BART backbone (args.backbone). It is used as a notebook
# global, including inside VLBartDataset.tokenize.
# NOTE(review): whether BartTokenizer honors `max_length` / `do_lower_case`
# passed this way is not verified here.
tokenizer = BartTokenizer.from_pretrained(
    args.backbone,
    max_length=args.max_text_length,
    do_lower_case=args.do_lower_case
)

In [None]:
# Layer indices that receive an intervention (0 through 5).
layers = list(range(6))
# pyreft position spec; presumably the first 7 and last 7 tokens.
position = "f7+l7"

In [None]:
# Keyword arguments shared by the train and eval ReFT datasets.
_common_dataset_kwargs = {
    "num_interventions": len(layers),
    "position": position,
    "share_weights": True,
    "test_split": "validation",
}

train_dataset = VLBartDataset(
    "vqa",
    tokenizer,
    data_split="train",
    dataloader=vqa_train_loader,
    max_n_example=1000,
    **_common_dataset_kwargs,
)

# NOTE(review): the eval set is built from the *training* loader, and both
# datasets keep the first N items, so these 100 items are a subset of the
# 1000 training items above. This is not a held-out split.
eval_dataset = VLBartDataset(
    "vqa",
    tokenizer,
    data_split="val",
    dataloader=vqa_train_loader,
    max_n_example=100,
    **_common_dataset_kwargs,
)

In [None]:
from multitask import Trainer
# VL-T5's own Trainer is instantiated to obtain the configured VLBart model
# (`trainer.model` is read in the next cell); its training loop is not run.
# NOTE: the name `trainer` is rebound to the ReFT trainer later on.
trainer = Trainer(args, vqa_train_loader, None, None, train=True)
# trainer.train()

In [None]:
# VLBart model built by the VL-T5 Trainer; this is the base model ReFT wraps.
model = trainer.model

In [None]:
# Inspect the model config (d_model sets the intervention embedding size).
print(model.config)

In [None]:
# Display the VL-T5 collate function VLBartDataset took from the wrapped dataset.
train_dataset.collate_fn

In [None]:
# Inspect the keys of one raw VL-T5 VQA item.
print(vqa_train_loader.dataset[0].keys())

In [None]:
import numpy as np
import transformers


def keep_intervention_locations(datum):
    """Return only the fields of a ReFT item that the seq2seq collator pads.

    Args:
        datum: one dataset item (dict) produced by ``VLBartDataset``.

    Returns:
        A new dict with ``input_ids``, ``intervention_locations`` and
        ``attention_mask`` taken from ``datum``.
    """
    return {
        "input_ids": datum["input_ids"],
        "intervention_locations": datum["intervention_locations"],
        "attention_mask": datum["attention_mask"],
    }


# Built once here rather than on every batch: its configuration never changes,
# and `tokenizer` / `model` are already defined by earlier cells.
_seq2seq_collator = transformers.DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    label_pad_token_id=-100,
    padding="longest",
)


def custom_collate_fn(data):
    """Collate ReFT VQA items into one batch for VLBart + pyreft.

    Combines VL-T5's own collate function with a seq2seq collator. The
    seq2seq collator pads ``intervention_locations`` and ``attention_mask``.

    Args:
        data: list of dataset items from ``VLBartDataset``.

    Returns:
        VL-T5's batch dict, with these changes:
        - ``intervention_locations`` and ``attention_mask`` come from the
          seq2seq collator.
        - ``id`` (numpy array) and ``instruction`` (list of str) are added.
        - ``logits`` holds VL-T5's original ``labels`` entry.
        - ``labels`` is set to ``target_ids``.
    """
    output = train_dataset.collate_fn(data)
    padded = _seq2seq_collator([keep_intervention_locations(item) for item in data])

    output["intervention_locations"] = padded["intervention_locations"]
    # Offset for the image tokens concatenated with the text tokens.
    # NOTE(review): only the LAST position column of each intervention is
    # shifted by args.n_boxes. Whether the other positions need the same
    # offset depends on where VLBart places the image tokens; confirm.
    output["intervention_locations"][:, :, -1] += args.n_boxes
    output["attention_mask"] = padded["attention_mask"]

    output["id"] = np.array([d["id"] for d in data])
    output["instruction"] = [d["instruction"] for d in data]

    # VL-T5's `labels` entry is moved to `logits`. The ReFT trainer passes
    # `labels` (here: target_ids) to the model.
    output["logits"] = output["labels"]
    output["labels"] = output["target_ids"]
    return output


data_collator = ReftDataCollator(data_collator=custom_collate_fn)

In [None]:
# Low-rank dimension and dropout shared by every LoReFT intervention.
rank, dropout = 1, 0.05


In [None]:
def _loreft_representation(layer):
    """Build one pyreft representation: a LoReFT intervention on `layer`'s block output."""
    intervention = LoreftIntervention(
        embed_dim=model.config.d_model,
        low_rank_dimension=rank,
        dropout=dropout,
        dtype=torch.float32,
        act_fn=None,
        device="cuda",
        add_bias=True,
    )
    return {
        "layer": layer,
        "component": "block_output",
        "low_rank_dimension": rank,
        "intervention": intervention,
    }


representations = [_loreft_representation(layer) for layer in layers]
# NOTE(review): not referenced elsewhere in this notebook.
task_type = TaskType.CAUSAL_LM

reft_config = ReftConfig(representations=representations)
empty_reft_config = ReftConfig(representations=[])

In [None]:
# Wrap the base model with the LoReFT interventions. The empty config wraps
# the same base model with no interventions and is used here only to compare
# trainable-parameter counts.
reft_model = get_reft_model(model, reft_config)
empty_reft_model = get_reft_model(model, empty_reft_config)
empty_reft_model.print_trainable_parameters()
reft_model.print_trainable_parameters()

In [None]:
# Hugging Face training arguments for the ReFT run.
# NOTE(review): learning_rate (4e-2) differs from args.lr (0.01), and the batch
# size reuses args.batch_size (512). `evaluation_strategy` is the older
# argument name; newer transformers releases call it `eval_strategy`.
training_args = TrainingArguments(
    output_dir="random",
    run_name="random",
    num_train_epochs=100,
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    gradient_accumulation_steps=1,
    evaluation_strategy="no",
    # evaluation_strategy="epoch",
    save_strategy="no",
    metric_for_best_model=None,
    load_best_model_at_end=False,
    logging_strategy="steps",
    save_total_limit=1, # has no effect while save_strategy="no"
    logging_steps=1,
    learning_rate=4e-2,
    warmup_ratio=0.1,
    optim="adamw_torch",
    weight_decay=0.01,
    report_to="none",
    use_cpu=False,
    seed=42,
    # until HF supports ReFT, this remains False! :)
    # (keeps non-model keys such as intervention_locations in each batch)
    remove_unused_columns=False
)

In [None]:
from pyvene import IntervenableModel


class MyTrainer(ReftTrainerForCausalLM):
    """ReFT trainer for VLBart on VQA.

    ``training_step`` runs VLBart's ``vis_forward`` hook on the batch before
    the intervened forward pass. ``compute_loss`` turns a per-token loss into
    a VQA-score-weighted per-example loss.
    """

    def training_step(self, model, batch, num_items_in_batch=None):
        """Run one optimization step on ``batch``.

        Args:
            model: the ReFT/pyvene intervenable model. Its ``.model`` is the
                underlying VLBart, which must expose ``vis_forward``.
            batch: a collated batch from the data collator.
            num_items_in_batch: accepted and ignored, for compatibility with
                newer ``transformers`` versions that pass it positionally.

        Returns:
            The detached loss divided by ``gradient_accumulation_steps``.
        """
        batch = self._prepare_inputs(batch)
        device = batch["input_ids"].device

        # VLBart-specific hook; it is expected to return a batch that contains
        # 'vis_feats' and 'boxes' (both read below).
        batch = model.model.vis_forward(batch, device)

        inputs = {**batch}
        inputs["return_dict"] = True
        inputs["reduce_loss"] = False
        inputs["vis_inputs"] = (batch["vis_feats"], batch["boxes"])

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        del inputs
        # NOTE(review): emptying the CUDA cache on every step is costly and may
        # contribute to the slowdown listed under Known Issues.
        torch.cuda.empty_cache()

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.use_apex:
            # Imported here because `amp` is only needed (and only installed)
            # when apex mixed precision is enabled.
            from apex import amp

            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss)

        return loss.detach() / self.args.gradient_accumulation_steps

    def compute_loss(
        self,
        intervenable: IntervenableModel,
        inputs,
        return_outputs=False,
        num_items_in_batch=None,
    ):
        """Compute the VQA-score-weighted sequence loss.

        Args:
            intervenable: the intervenable model to run.
            inputs: batch dict with input_ids, attention_mask, vis_inputs,
                intervention_locations, labels, target_ids, scores, and
                optionally subspaces.
            return_outputs: when True, also return the model outputs, as the
                Hugging Face ``Trainer`` contract expects.
            num_items_in_batch: accepted and ignored, for newer transformers.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        lm_labels = inputs["target_ids"]
        _, cf_outputs = intervenable(
            {
                "input_ids": inputs["input_ids"],
                "attention_mask": inputs["attention_mask"],
                "vis_inputs": inputs["vis_inputs"],
                "task": "vqa",
            },
            unit_locations={"sources->base": (
                None,
                inputs["intervention_locations"].permute(1, 0, 2).tolist(),
            )},
            labels=inputs["labels"],
            subspaces=(
                inputs["subspaces"].permute(1, 0, 2).tolist()
                if "subspaces" in inputs else None
            ),
        )

        # Assumes the model returns an unreduced per-token loss with B * L
        # elements; the view below fails otherwise. TODO confirm.
        token_loss = cf_outputs.loss
        lm_mask = (lm_labels != -100).float()
        B, L = lm_labels.size()

        token_loss = token_loss.view(B, L) * lm_mask
        # Mean over non-ignored tokens; clamp avoids division by zero.
        seq_loss = token_loss.sum(dim=1) / lm_mask.sum(dim=1).clamp(min=1)  # B
        loss = (seq_loss * inputs["scores"]).mean()

        return (loss, cf_outputs) if return_outputs else loss
        


In [None]:
# ReFT trainer over the intervened VLBart model with the custom collator.
# eval_dataset is passed, but evaluation is disabled in training_args ("no").
trainer = MyTrainer(
    model=reft_model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=None,
)

In [None]:
# Show the module structure of the wrapped base model.
print(reft_model.model)

In [None]:
# Run ReFT training with MyTrainer.
trainer.train()

In [None]:
# tokenizer("tree")

In [None]:
# import pyreft
# reft_model = pyreft.ReftModel.load(
#     "temp-outputs", model
# )
# reft_model.set_device("cuda")

In [None]:
# Put the base model and every intervention module into eval mode.
# Each entry in `interventions` is indexed with [0] to reach its module.
reft_model.model.eval()
for entry in reft_model.interventions.values():
    entry[0].eval()


In [None]:
from compute_metrics import compute_metrics
# Generate answers and score them.
# NOTE(review): this evaluates on `train_dataset` (passed twice), i.e. on
# training items, not on eval_dataset or a held-out VQA split. The meaning of
# the positional arguments follows compute_metrics' signature, which is not
# visible in this notebook.
generations, stats = compute_metrics(
    "vqa", "vqa", reft_model, tokenizer, train_dataset, train_dataset,
    '', 'test', 64, # batch_size
    data_collator,
    split=False, greedy_decoding=True, temperature=1.0, top_p=None, top_k=None
)


In [None]:
# eval_dataset[3]["answer"]

In [None]:
# generations[3]

In [None]:
# Display the metrics returned by compute_metrics.
stats

In [None]:
# reft_model.save('temp-outputs')

### Next Steps:

1. Speed up data loading [open-ended performance problem]
2. Check the intervention locations for VL-BART
3. Evaluate the fine-tuned model's performance on the VQA eval/test sets
4. Manually validate the fine-tuned model