# Initial configuration

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch, os
print(torch.__version__)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
from travel.constants import DATA_CACHE_DIR, MODEL_CACHE_DIR
os.environ['HF_HOME'] = MODEL_CACHE_DIR

1.13.0+cu117
cuda


In [3]:
!nvidia-smi

Thu Feb  8 08:28:31 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.06              Driver Version: 545.23.06    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A40                     On  | 00000000:23:00.0 Off |                    0 |
|  0%   28C    P8              26W / 300W |      7MiB / 46068MiB |      0%   E. Process |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

# Utils

Need to move these into external code files.

In [None]:
from data.captaincook4d.constants import ANNOTATIONS_DIR, VIDEO_DIR, ERROR_CATEGORIES

from travel.data import MistakeDetectionExample
from travel.data.utils import generate_float_series
from travel.data.utils.image import BoundingBox, BoundingBoxCluster, get_preprocessed_image, draw_entity_boxes_on_image
from travel.data.utils.video import get_video, extract_frames
from travel.data.utils.text import simple_present_to_imperative
from travel.model.vqg import VQGOutputs
from travel.model.vqa import VQAOutputs, VQAResponse
from travel.model import MistakeDetectionEvaluator, HeuristicMistakeDetectionEvaluator, \
                         MistakeDetectionScorer, FrameVQAMistakeDetectionScorer

# CaptainCook4D Preliminary Empirical Experiments

## Gather Data

In [42]:
from travel.data.captaincook4d import CaptainCook4DDataset

dataset = CaptainCook4DDataset(debug_n_examples_per_class=20)

  0%|          | 1/335 [00:00<04:49,  1.15it/s]

Error examples: 1
Success examples: 6


  1%|          | 2/335 [00:01<03:46,  1.47it/s]

Error examples: 3
Success examples: 11


  1%|          | 3/335 [00:02<04:01,  1.38it/s]

Error examples: 8
Success examples: 13


  1%|          | 4/335 [00:04<07:09,  1.30s/it]

Error examples: 8
Success examples: 24


  1%|▏         | 5/335 [00:04<05:11,  1.06it/s]

Error examples: 11
Success examples: 24


  2%|▏         | 6/335 [00:05<04:23,  1.25it/s]

Error examples: 12
Success examples: 31


  2%|▏         | 7/335 [00:06<04:24,  1.24it/s]

Error examples: 17
Success examples: 35


  2%|▏         | 8/335 [00:06<04:24,  1.24it/s]

Error examples: 17
Success examples: 41


  3%|▎         | 9/335 [00:07<04:41,  1.16it/s]

Error examples: 21
Success examples: 47


  3%|▎         | 10/335 [00:09<05:19,  1.02it/s]

Error examples: 21
Success examples: 58


  3%|▎         | 11/335 [00:09<05:07,  1.05it/s]

Error examples: 25
Success examples: 67


  4%|▎         | 12/335 [00:11<05:58,  1.11s/it]

Error examples: 30
Success examples: 73


  4%|▍         | 13/335 [00:12<05:51,  1.09s/it]

Error examples: 34
Success examples: 80


  4%|▍         | 14/335 [00:13<05:16,  1.02it/s]

Error examples: 37
Success examples: 80


  4%|▍         | 15/335 [00:36<41:05,  7.70s/it]

Error examples: 39
Success examples: 87


  5%|▍         | 16/335 [00:42<38:36,  7.26s/it]

Error examples: 39
Success examples: 96


  5%|▌         | 17/335 [00:57<49:42,  9.38s/it]

Error examples: 41
Success examples: 100


  5%|▌         | 18/335 [01:08<53:34, 10.14s/it]

Error examples: 45
Success examples: 105


  6%|▌         | 19/335 [01:25<1:04:15, 12.20s/it]

Error examples: 53
Success examples: 106


  6%|▌         | 20/335 [01:46<1:16:52, 14.64s/it]

Error examples: 53
Success examples: 117


  6%|▋         | 21/335 [02:00<1:16:06, 14.54s/it]

Error examples: 53
Success examples: 128


  7%|▋         | 22/335 [02:15<1:17:10, 14.79s/it]

Error examples: 53
Success examples: 139


  7%|▋         | 23/335 [02:21<1:03:05, 12.13s/it]

Error examples: 58
Success examples: 144


  7%|▋         | 24/335 [02:37<1:08:12, 13.16s/it]

Error examples: 58
Success examples: 159


  7%|▋         | 25/335 [02:49<1:05:35, 12.70s/it]

Error examples: 63
Success examples: 161


  8%|▊         | 26/335 [02:59<1:01:36, 11.96s/it]

Error examples: 63
Success examples: 174


  8%|▊         | 27/335 [03:30<1:30:43, 17.68s/it]

Error examples: 63
Success examples: 193


  8%|▊         | 28/335 [03:42<1:21:59, 16.02s/it]

Error examples: 63
Success examples: 204


  9%|▊         | 29/335 [04:05<1:32:38, 18.16s/it]

Error examples: 63
Success examples: 224


  9%|▉         | 30/335 [04:22<1:29:50, 17.68s/it]

Error examples: 63
Success examples: 230


  9%|▉         | 31/335 [04:49<1:44:27, 20.62s/it]

Error examples: 63
Success examples: 241


 10%|▉         | 32/335 [05:00<1:29:38, 17.75s/it]

Error examples: 63
Success examples: 255


 10%|▉         | 33/335 [05:20<1:32:17, 18.34s/it]

Error examples: 67
Success examples: 258


 10%|█         | 34/335 [05:46<1:44:20, 20.80s/it]

Error examples: 67
Success examples: 269


 10%|█         | 35/335 [06:12<1:51:21, 22.27s/it]

Error examples: 67
Success examples: 292


 11%|█         | 36/335 [06:23<1:33:10, 18.70s/it]

Error examples: 70
Success examples: 300


 11%|█         | 37/335 [06:38<1:28:46, 17.87s/it]

Error examples: 74
Success examples: 303


 11%|█▏        | 38/335 [06:50<1:19:14, 16.01s/it]

Error examples: 75
Success examples: 305


 12%|█▏        | 39/335 [07:09<1:23:45, 16.98s/it]

Error examples: 84
Success examples: 310


 12%|█▏        | 40/335 [07:32<1:31:38, 18.64s/it]

Error examples: 84
Success examples: 321


 12%|█▏        | 41/335 [07:58<1:42:03, 20.83s/it]

Error examples: 85
Success examples: 337


 13%|█▎        | 42/335 [08:17<1:39:25, 20.36s/it]

Error examples: 90
Success examples: 338


 13%|█▎        | 43/335 [08:30<1:28:50, 18.25s/it]

Error examples: 90
Success examples: 349


 13%|█▎        | 44/335 [08:45<1:22:54, 17.09s/it]

Error examples: 90
Success examples: 356


 13%|█▎        | 45/335 [09:01<1:20:45, 16.71s/it]

Error examples: 91
Success examples: 374


 14%|█▎        | 46/335 [09:45<1:59:48, 24.87s/it]

Error examples: 92
Success examples: 376


 14%|█▍        | 47/335 [10:06<1:53:51, 23.72s/it]

Error examples: 92
Success examples: 387


 14%|█▍        | 48/335 [10:22<1:42:57, 21.52s/it]

Error examples: 94
Success examples: 393


 15%|█▍        | 49/335 [10:50<1:52:25, 23.59s/it]

Error examples: 94
Success examples: 410


 15%|█▍        | 50/335 [11:17<1:55:51, 24.39s/it]

Error examples: 94
Success examples: 429


 15%|█▌        | 51/335 [11:33<1:44:15, 22.03s/it]

Error examples: 94
Success examples: 441


 16%|█▌        | 52/335 [11:48<1:33:12, 19.76s/it]

Error examples: 99
Success examples: 444


 16%|█▌        | 52/335 [12:17<1:06:53, 14.18s/it]

Collected at least 20 positive and negative examples!





## Step 1: VQG with LLaMA for Recipe Steps

Load model:

In [12]:
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

LM_NAME = "meta-llama/Llama-2-7b-hf"
lm = pipeline("text-generation", 
                 model=LM_NAME, 
                 token="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", 
                 model_kwargs={"load_in_8bit": True})
# lm = lm.to(device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [13]:
lm.model.generation_config.top_p = None
lm.model.generation_config.temperature = None
lm.model.generation_config.do_sample = False
lm.tokenizer.pad_token_id = lm.model.config.eos_token_id
lm.tokenizer.padding_side = "left"

Load recipe steps:

In [14]:
from travel.data.captaincook4d.constants import RECIPE_STEPS

Generate success verification questions:

In [15]:
import torch
from transformers.pipelines.pt_utils import KeyDataset
import os
from dataclasses import dataclass, field
from dataclasses_json import dataclass_json
from enum import Enum
import pickle
import json
                     
USE_VQG_CACHE = False

VQG_FILENAME = os.path.join(DATA_CACHE_DIR, "vqg_outputs.json")
if not USE_VQG_CACHE or not os.path.exists(VQG_FILENAME):

    # NOTE: this process makes an assumption that a single recipe step is all we need to determine the step requirements;
    # there are some exceptions to this assumption, e.g., "cook the pan until the colour is darkened" where the target
    # object is unclear just from a single step
    
    # TODO: source recipe steps from elsewhere; find a comprehensive set that covers variety of state changes?
    task_graph1 = os.path.join(ANNOTATIONS_DIR, "task_graphs/capresebruschetta.json")
    procedure_id1 = 345
    step_idx = 4
    example1 = 'The recipe step is "Spoon the mixture from the bowl onto the bread". To visually verify that this step is complete, what are 2 questions we could ask about an image of a target object and their expected answers?\n' \
               'Target object: bread\n' \
               '1. Is there mixture on the bread? Yes\n' \
               '2. Is there any bread without mixture on top of it? No' \

    example2 = 'The recipe step is "Roll the tortilla into a thin, log shape about 1 inch thick. Make sure no filling leaks out.". To visually verify that this step is complete, what are 2 questions we could ask about an image of a target object and their expected answers?\n' \
               'Target object: tortilla\n' \
               '1. Is the tortilla in a thin log shape? Yes\n' \
               '2. Is there any filling leaking out of the tortilla? No'

    example3 = 'The recipe step is "Fold the coffee filter into quarters". To visually verify that this step is complete, what are 2 questions we could ask about an image of a target object and their expected answers?\n' \
               'Target object: coffee filter\n' \
               '1. Is the coffee filter in a quarter circle? Yes\n' \
               '2. Is the coffee filter folded? Yes' \

    prompts = []
    with torch.no_grad():
        for step_id, step in RECIPE_STEPS.items():
            test = f'The recipe step is "{step}". To visually verify that this step is complete, what are 2 questions we could ask about an image of a target object and their expected answers?\n'
            prompt = "\n\n".join([example1, example2, example3, test])
            prompts.append({"step_id": step_id, "step": step, "prompt": prompt})

    vqg_outputs = {}
    prompt_idx = 0
    for out in tqdm(lm(KeyDataset(prompts, "prompt"), 
                     batch_size=32, 
                     max_new_tokens=64, 
                     return_full_text=False, 
                     truncation="do_not_truncate"),
                   total=len(prompts)):
        inp = prompts[prompt_idx]

        step_id = int(inp['step_id'])
        step = inp['step']

        text = out[0]['generated_text']
        
        # print("===========================================================================")
        # print(text)
        # Hack: sometimes output from LLaMA 2 starts with Љ and whitespace characters, and sometimes this character replaces the first "T" in "Target object:"
        text_fixed = text.replace("Љ", "").strip() 
        if not text_fixed.startswith("Target object:") and ":" in text_fixed:
            text_fixed = "Target object: " + ":".join(text_fixed.split(":")[1:]).strip()
        
        # Parse reported target object and questions and answers
        try:
            target_object = text_fixed.split("\n")[0].split("Target object: ")[1].strip()
            # print(target_object)
            # state_description = text_fixed.split("\n")[1].replace("Expected state: After this step, ", "").strip()
            questions_answers = [(q_a.split("?")[0].strip() + "?", q_a.split("?")[1].strip()) for q_a in text_fixed.split("\n")[1:3]] # NOTE: only extract k=2 questions and answers; can adjust this as needed later
            questions = [q[2:].strip() for q, _ in questions_answers]          
            answers = [a for _, a in questions_answers]
            output = VQGOutputs(step_id,
                                step,
                                target_object,
                                # state_description,
                                questions,
                                answers)
        except:
            print(text)
            print('=====')
            print(text_fixed)
            raise

        vqg_outputs[step_id] = output
        prompt_idx += 1

        # # Early stopping for debugging
        # if prompt_idx >= 20:
        #     break
else:
    vqg_outputs = pickle.load(open(VQG_FILENAME, "rb"))
    
#     vqg_outputs_json = json.load(open(os.path.join(DATA_CACHE_DIR, "vqg_outputs.json"), "r"))
    
#     with open(os.path.join(DATA_CACHE_DIR, "vqg_outputs.json"), "r") as f:
#         vqg_outputs_json = json.load(f)
#         vqg_outputs = {k: VQGOutputs.from_dict(v) for k, v in vqg_outputs_json.items()}
          

100%|██████████| 350/350 [03:04<00:00,  1.90it/s]


#### Print and save outputs

In [None]:
for step_id, output in vqg_outputs.items():
    print(RECIPE_STEPS[step_id])
    pprint(output)
    print('===================')

In [16]:
import json
from dataclasses import asdict
import pickle

# TODO: cache in a more interpretable way later
pickle.dump(vqg_outputs, open(VQG_FILENAME, "wb"))

# Sort the dict
# vqg_keys = sorted(vqg_outputs.keys())
# vqg_outputs_new = {}
# for key in vqg_keys:
#     vqg_outputs_new[key] = vqg_outputs[key]

# with open(os.path.join(DATA_CACHE_DIR, "vqg_outputs.json"), "w") as f:
#     for key in vqg_outputs_new:
#         vqg_outputs_new[key] = VQGOutputs.to_dict(vqg_outputs_new[key])
#     json.dump(vqg_outputs_new, f, indent=4)

## Step 2: VQA with LLaVA

Load models and dataset:

In [17]:
del lm

In [18]:
# Setup code grabbed from docs: https://huggingface.co/docs/transformers/model_doc/llava#transformers.LlavaForConditionalGeneration
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
vlm_processor = AutoProcessor.from_pretrained(MODEL_NAME)
vlm = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir=DATA_CACHE_DIR, load_in_8bit=True)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [19]:
import requests
from PIL import Image
import torch

from transformers import OwlViTProcessor, OwlViTForObjectDetection
from transformers import Owlv2Processor, Owlv2ForObjectDetection

DETECTOR_NAME = "google/owlv2-base-patch16"
detector_processor = Owlv2Processor.from_pretrained(DETECTOR_NAME)
detector = Owlv2ForObjectDetection.from_pretrained(DETECTOR_NAME, load_in_8bit=True)

In [44]:
examples = success_examples[:100] + error_examples[:100]
# examples = examples[:1] # Just for debug purposes

### VQG->VQA

First, use OWL to filter out frames where target object isn't visible:

In [21]:
from tqdm import tqdm
import dataclasses
from collections import defaultdict

SHOW_OUTPUTS = False

with torch.no_grad():
    
    all_frames = [frame for example in examples for frame in example.frames]
    all_texts = [f"a photo of {'an' if vqg_outputs[example.procedure_id].target_object[0] in ['a','e','i','o','u'] else 'a'} {vqg_outputs[example.procedure_id].target_object}" for example in examples for frame in example.frames]
    
    batch_size = 8
    all_results = []
    all_padded_images = []
    for i in tqdm(range(0, len(all_frames), batch_size), desc="detecting objects"):
        # Prepare the batch
        batch_frames = all_frames[i:i+batch_size]
        batch_texts = all_texts[i:i+batch_size]
        
        inputs = detector_processor(text=batch_texts, images=batch_frames, return_tensors="pt").to(device)
        outputs = detector(**inputs)
        inputs = inputs.to("cpu")  
        
        padded_images = [get_preprocessed_image(inputs.pixel_values[j].detach().to('cpu')) for j in range(len(batch_frames))]

        # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
        target_sizes = torch.Tensor([pi.size[::-1] for pi in padded_images])
        # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
        results = detector_processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes, threshold=0.2)
        all_results += results
        all_padded_images += padded_images
        
    texts = all_texts
    results = all_results
    padded_images = all_padded_images

    example_frame_idx = 0
    filtered_examples = []  
    filtered_out_frames = 0
    for example in tqdm(examples, desc="filtering frames"):        
        step_id = example.procedure_id
        vqg_output = vqg_outputs[step_id]
        target_object = vqg_output.target_object

        filtered_frames = []
        filtered_frame_times = []
        for i in range(example_frame_idx, example_frame_idx + len(example.frames)):
            
            text = texts[i]
            boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]

            if len(boxes) > 0:
                filtered_frames.append(example.frames[i - example_frame_idx])
                filtered_frame_times.append(example.frame_times[i - example_frame_idx])
            else:
                filtered_out_frames += 1
            
            if SHOW_OUTPUTS:
                entities = defaultdict(list)
                for box, score, label in zip(boxes, scores, labels):
                    box = BoundingBox(*[round(j, 2) for j in box.tolist()], float(score.detach().cpu()))

                    # normalize bounding box coordinates
                    box_norm = box.normalize(padded_images[example_frame_idx + i].width, padded_images[example_frame_idx + i].height)
                    # print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")

                    entities[text[label]].append(box_norm)

                entities = [(entity, None, BoundingBoxCluster(boxes).get_merged_boxes()) for entity, boxes in entities.items()]
                plt.figure()
                draw_entity_boxes_on_image(padded_images[example_frame_idx + i].resize((padded_images[example_frame_idx + i].width*2, padded_images[example_frame_idx + i].height*2)), entities, show=True)
        
        new_example = dataclasses.replace(example)
        new_example.frames = filtered_frames
        new_example.frame_times = filtered_frame_times
        filtered_examples.append(new_example)
        example_frame_idx += len(example.frames)
        
print(f"{filtered_out_frames} / {len(padded_images)} frames filtered out for not containing target object")

detecting objects: 100%|██████████| 119/119 [05:26<00:00,  2.74s/it]
filtering frames: 100%|██████████| 68/68 [00:00<00:00, 24360.49it/s]

504 / 945 frames filtered out for not containing target object





Ask success verification questions per frame:

(doesn't try to focus VLM on target object yet)

In [45]:
import random
from pprint import pprint
import torch

prompt_template = "USER: <image>\n{question} (yes/no) ASSISTANT: "
NO_ID = vlm_processor.tokenizer("No", add_special_tokens=False)['input_ids'][0]
YES_ID = vlm_processor.tokenizer("Yes", add_special_tokens=False)['input_ids'][0]
RESPONSE_TOKEN_IDS = {
    VQAResponse["No"]: NO_ID, 
    VQAResponse["Yes"]: YES_ID
}

vqa_outputs = []
with torch.no_grad():
    # for example in tqdm(filtered_examples):
    for example in tqdm(examples):
        example_vqa_outputs = []
        
        step_id = example.procedure_id
        
        questions = vqg_outputs[step_id].questions
        prompts = [prompt_template.format(question=question.strip()) for question in questions]
        expected_answers = vqg_outputs[step_id].answers
                           
        for frame in example.frames:
            frame_vqa_outputs = []
            
            for prompt, expected_answer in zip(prompts, expected_answers):
                inputs = vlm_processor(text=prompt, images=frame, return_tensors="pt").to(device)

                # Forward pass
                logits = vlm(**inputs).logits[0] # (seq. length, vocab size)
                logits = logits[-1].detach().cpu() # just logits for last input (next generation)
                
                frame_vqa_outputs.append(
                    VQAOutputs(
                        step_id,
                        frame,
                        prompt,
                        expected_answer,
                        RESPONSE_TOKEN_IDS,
                        logits,        
                    )
                )
            example_vqa_outputs.append(frame_vqa_outputs)

        vqa_outputs.append(example_vqa_outputs)

100%|██████████| 200/200 [30:53<00:00,  9.27s/it] 


### Baseline: Direct VQA

In [39]:
import random
from pprint import pprint
import torch

prompt_template = 'USER: <image>\nThe current goal is to "{step}". Did the person successfully do this? (yes/no) ASSISTANT: '
NO_ID = vlm_processor.tokenizer("No", add_special_tokens=False)['input_ids'][0]
YES_ID = vlm_processor.tokenizer("Yes", add_special_tokens=False)['input_ids'][0]
RESPONSE_TOKEN_IDS = {
    VQAResponse["No"]: NO_ID, 
    VQAResponse["Yes"]: YES_ID
}

vqa_outputs = []
with torch.no_grad():
    for example in tqdm(examples):
        this_vqa_outputs = []
        
        step_id = example.procedure_id
        step = example.procedure_description
        
        prompt = prompt_template.format(step=step)
        expected_answer = VQAResponse["Yes"]
        
        for frame in example.frames:
            inputs = vlm_processor(text=prompt, images=frame, return_tensors="pt").to(device)

            # Forward pass
            logits = vlm(**inputs).logits[0] # (seq length, vocab size)
            logits = logits[-1].detach().cpu()

            this_vqa_outputs.append(
                [VQAOutputs(
                    step_id,
                    frame,
                    prompt,
                    expected_answer,
                    RESPONSE_TOKEN_IDS,
                    logits,        
                )]
            )        
            
        vqa_outputs.append(this_vqa_outputs)

100%|██████████| 68/68 [06:01<00:00,  5.32s/it]


## Step 3: Evaluate VQA Outputs

In [53]:
from collections import Counter

evaluator = HeuristicMistakeDetectionEvaluator(examples, vqa_outputs)

metrics = evaluator.get_mistake_detection_metrics()
print("Metrics:")
pprint(metrics)

VQAResponse.Yes VQAResponse.Yes
VQAResponse.Yes VQAResponse.No
VQAResponse.Yes VQAResponse.Yes
VQAResponse.Yes VQAResponse.No
VQAResponse.No VQAResponse.Yes
VQAResponse.Yes VQAResponse.No
VQAResponse.Yes VQAResponse.Yes
VQAResponse.Yes VQAResponse.No
VQAResponse.No VQAResponse.Yes
VQAResponse.Yes VQAResponse.No
VQAResponse.Yes VQAResponse.Yes
VQAResponse.Yes VQAResponse.No
VQAResponse.Yes VQAResponse.Yes
VQAResponse.Yes VQAResponse.No
VQAResponse.Yes VQAResponse.Yes
VQAResponse.Yes VQAResponse.No
VQAResponse.Yes VQAResponse.Yes
VQAResponse.Yes VQAResponse.No
VQAResponse.Yes VQAResponse.Yes
VQAResponse.Yes VQAResponse.No
VQAResponse.Yes VQAResponse.Yes
VQAResponse.Yes VQAResponse.No
VQAResponse.No VQAResponse.Yes
VQAResponse.Yes VQAResponse.No
VQAResponse.Yes VQAResponse.Yes
VQAResponse.Yes VQAResponse.No
[True, True, True, True, True, True, True, True, True, True, True, True, True]
VQAResponse.Yes VQAResponse.Yes
VQAResponse.Yes VQAResponse.No
VQAResponse.Yes VQAResponse.Yes
VQARespons

In [47]:
pprint(vqa_outputs)

[[[VQAOutputs(procedure_id=17,
              frame=<PIL.Image.Image image mode=RGB size=640x360 at 0x154DDD9673A0>,
              prompt='USER: <image>\n'
                     'Is there an onion in the image? (yes/no) ASSISTANT: ',
              expected_answer=<VQAResponse.Yes: 1>,
              response_token_ids={<VQAResponse.Yes: 1>: 3869,
                                  <VQAResponse.No: 0>: 1939},
              logits=tensor([-2.4805e+00, -1.5420e+00,  9.6484e+00,  ..., -4.8065e-03,
        -7.5798e-03,  6.8998e-04]),
              answer_probs={<VQAResponse.Yes: 1>: 0.6959583,
                            <VQAResponse.No: 0>: 0.30404168},
              predicted_answer=<VQAResponse.Yes: 1>),
   VQAOutputs(procedure_id=17,
              frame=<PIL.Image.Image image mode=RGB size=640x360 at 0x154DDD9673A0>,
              prompt='USER: <image>\n'
                     'Is the onion peehren? (yes/no) ASSISTANT: ',
              expected_answer=<VQAResponse.No: 0>,
              respo

# RLHF/DPO Preliminary Experiments

## Testing VQA Scorer

Simple test of VQA scorer:

In [14]:
del lm

NameError: name 'lm' is not defined

In [38]:
del scorer

In [23]:
import dataclasses

MODEL_NAME = "llava-hf/llava-1.5-7b-hf"

# Select last frame from each CaptainCook4D video clip
examples = success_examples[:20] + error_examples[:20]
filtered_examples = []
for example in examples:
    filtered_frames = example.frames[-1]
    new_example = dataclasses.replace(example)
    new_example.frames = [example.frames[-1]]
    new_example.frame_times = [example.frames[-1]]
    filtered_examples.append(new_example)
examples = filtered_examples

scorer = FrameVQAMistakeDetectionScorer(MODEL_NAME)
logits_errors = scorer(examples, vqg_outputs)
# vqa_outputs = scorer(examples, vqg_outputs, return_vqa_outputs=True)

pprint(examples[14])
pprint(vqg_outputs[examples[14].procedure_id])

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

running VQA: 100%|██████████| 80/80 [00:29<00:00,  2.69it/s]

MistakeDetectionExample(video_id='29_7',
                        procedure_id=342,
                        frames=[<PIL.Image.Image image mode=RGB size=640x360 at 0x14F508543B50>],
                        frame_times=[<PIL.Image.Image image mode=RGB size=640x360 at 0x14F508543B50>],
                        procedure_description='1/8 cup shredded mozzarella to '
                                              'a bowl',
                        mistake=False,
                        mistake_type=None,
                        mistake_description=None)
VQGOutputs(procedure_id=342,
           procedure_description='1/8 cup shredded mozzarella to a bowl',
           target_object='bowl',
           questions=['Is there mozzarella in the bowl?',
                      'Is there any cheese in the bowl?'],
           answers_str=['Yes', 'Yes'],
           answers=[<VQAResponse.Yes: 1>, <VQAResponse.Yes: 1>])





In [21]:
print(logits_errors.shape)
pprint(logits_errors)

torch.Size([40, 2, 2])
tensor([[[-4.1604e-05,  0.0000e+00],
         [ 1.0000e+00,  1.0000e+00]],

        [[-1.1921e-07, -2.3842e-07],
         [ 1.0000e+00,  1.0000e+00]],

        [[-7.2718e-06, -3.5763e-07],
         [-1.4067e-05, -9.5367e-07]],

        [[-2.3842e-07, -2.3842e-07],
         [ 1.0000e+00,  1.0000e+00]],

        [[-1.1921e-07, -3.5763e-07],
         [-1.1921e-07, -3.5763e-07]],

        [[-3.5763e-07, -1.1921e-07],
         [-2.3842e-07, -1.1921e-07]],

        [[-5.9605e-07, -2.3842e-07],
         [-1.1921e-07, -1.1921e-07]],

        [[-1.1921e-07, -5.9605e-07],
         [ 1.0000e+00,  1.0000e+00]],

        [[ 0.0000e+00, -1.1921e-07],
         [ 0.0000e+00, -1.1921e-07]],

        [[ 0.0000e+00, -1.1921e-07],
         [ 1.0000e+00,  1.0000e+00]],

        [[-1.1921e-07, -1.1921e-07],
         [-1.1921e-07, -2.3842e-07]],

        [[-3.5763e-07, -1.1921e-07],
         [-2.3842e-07, -1.1921e-07]],

        [[-8.3447e-07, -2.3842e-07],
         [-5.9605e-07, -2.38

## Data Collection from Ego4D

Load LLaMA to propose modifications to Ego4D narrations:

In [1]:
from transformers import pipeline

LM_NAME = "meta-llama/Llama-2-7b-hf"
lm = pipeline("text-generation", 
                 model=LM_NAME, 
                 token="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", 
                 model_kwargs={"load_in_8bit": True})
pprint(lm.model.generation_config)
lm.tokenizer.pad_token_id = lm.model.config.eos_token_id
lm.tokenizer.padding_side = "left"
NEWLINE_TOKEN_ID = lm.tokenizer("\n", add_special_tokens=False)['input_ids'][1]
lm.model.generation_config.eos_token_id = NEWLINE_TOKEN_ID

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x14b1ed35fa60>>
Traceback (most recent call last):
  File "/home/sstorks/.cache/pypoetry/virtualenvs/travel-l8Q4DA9E-py3.10/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 


RuntimeError: KeyboardInterrupt: <MESSAGE UNAVAILABLE DUE TO ANOTHER EXCEPTION>

MESSAGE UNAVAILABLE DUE TO EXCEPTION: KeyboardInterrupt: <EMPTY MESSAGE>

In [None]:
object_change_template = '"{sentence}": {old_noun} -> {new_noun}\n'
object_change_examples = [
    object_change_template.format(sentence="Pour the water into a bowl", old_noun="bowl", new_noun="vase"),
    object_change_template.format(sentence="Slice the tomatoes into small cubes", old_noun="tomatoes", new_noun="peppers"),
    object_change_template.format(sentence="Place the box on the shelf", old_noun="box", new_noun="bucket"),
    object_change_template.format(sentence="Mix the ingredients using a spoon", old_noun="spoon", new_noun="fork"),
    object_change_template.format(sentence="Unscrew the bolt on the bicycle", old_noun="bicycle", new_noun="motorcycle"),
]
object_change_prompt_template = "Propose a plausible different noun to replace a noun in each of the below sentences. The new noun should not be a hypernym or hyponym of the original noun.\n"
object_change_prompt_template += "\n".join(object_change_examples) + '\n"{sentence}":'

action_change_template = '"{sentence}": {new_sentence}\n'
action_change_examples = [
    action_change_template.format(sentence="Pour the water into a bowl", new_sentence="Pour the water out of the bowl"),
    action_change_template.format(sentence="Slice the tomatoes into small cubes", new_sentence="Add the tomatoes into a blender"),
    action_change_template.format(sentence="Place the box on the shelf", new_sentence="Take the box from the shelf"),
    action_change_template.format(sentence="Mix the ingredients using a spoon", new_sentence="Use the spoon to scoop the ingredients onto the tortilla"),
    action_change_template.format(sentence="Unscrew the bolt on the bicycle", new_sentence="Screw the bolt into the bicycle"),
]
action_change_prompt_template = "Modify each sentence below so that it describes a different action applied to the same objects.\n"
action_change_prompt_template += "\n".join(action_change_examples) + '\n"{sentence}":'

Generate mistake detection data from Ego4D:

In [6]:
from travel.data.ego4d import Ego4dFHOMainDataset

EGO4D_ANNOTATION_PATH = "/nfs/turbo/coe-chaijy-unreplicated/datasets/ego4d/v2/annotations/fho_main.json"
# EGO4D_SPLIT_PATH = "/nfs/turbo/coe-chaijy-unreplicated/datasets/ego4d/splits/fho_main_train.json"
EGO4D_SPLIT_PATH = "/nfs/turbo/coe-chaijy/generated-data/ego4d/v2/splits/fho_main_val.json"
EGO4D_VIDEO_PATH = "/nfs/turbo/coe-chaijy-unreplicated/datasets/ego4d/v2/full_scale"

ego4d = Ego4dFHOMainDataset(
    EGO4D_ANNOTATION_PATH,
    EGO4D_SPLIT_PATH,
    EGO4D_VIDEO_PATH
)
print(f"{len(ego4d)} clips loaded")

21634 clips loaded


In [8]:
from pprint import pprint
for clip in ego4d:
    pprint(clip)
    # break

{'aug_index': 0,
 'clip_index': 1,
 'fps': 30.0,
 'future_occurrences': 12,
 'narration_text': '#C C turns the brush in the tin of paint.',
 'pnr_frame': 689,
 'post_frame': 748,
 'pre_15': 536,
 'pre_30': 521,
 'pre_45': 506,
 'pre_frame': 551,
 'previous_occurrences': 0,
 'structured_noun': 'paintbrush',
 'structured_verb': 'turn_(spin,_rotate,_flip,_turn_over)',
 'video': tensor([[[[129, 130, 130,  ..., 236, 236, 236],
          [133, 134, 134,  ..., 236, 236, 236],
          [136, 137, 137,  ..., 236, 236, 236],
          ...,
          [ 28,  28,  28,  ..., 114, 117, 116],
          [ 28,  28,  28,  ..., 114, 117, 116],
          [ 28,  28,  28,  ..., 114, 117, 116]],

         [[154, 150, 143,  ..., 236, 236, 236],
          [156, 150, 147,  ..., 236, 236, 236],
          [158, 155, 152,  ..., 236, 236, 236],
          ...,
          [ 30,  31,  30,  ..., 111, 116, 117],
          [ 28,  28,  28,  ..., 111, 116, 117],
          [ 28,  28,  28,  ..., 111, 116, 117]],

         [[1

KeyboardInterrupt: 

In [None]:
from pprint import pprint
from torchvision.transforms.functional import to_pil_image
from torch.nn.functional import cosine_similarity
from PIL import Image
import matplotlib.pyplot as plt
from travel.data.ego4d import clean_narration_text
import spacy
from tqdm import tqdm

nlp = spacy.load('en_core_web_sm')

clip_idx = 0
positive_examples = []
hard_negative_examples = []
wrong_object_negative_examples = []
wrong_action_negative_examples = []

# ego4d.reset()
SIMILARITY_THRESHOLD = 0.95
for clip in tqdm(ego4d):
    print(clip.keys())
    
    # Convert narration text to imperative form to match the sentence structure of recipes and task instructions    
    instruction_text = clean_narration_text(clip['narration_text']) # Replace symbols in narration text with words
    instruction_text = simple_present_to_imperative(nlp, instruction_text)

    # clip['video'] shape: (C, # frames, H, W)
    precondition_frame_t, effect_frame_t = clip['video'][:,0], clip['video'][:,-1] # (C, H, W)
    precondition_frame, effect_frame = to_pil_image(precondition_frame_t), to_pil_image(effect_frame_t)
       
    # Omit examples where precondition and effect frame are overly similar
    precondition_effect_similarity = cosine_similarity(precondition_frame_t.flatten().float(), effect_frame_t.flatten().float(), dim=0).detach().numpy()
    if precondition_effect_similarity >= SIMILARITY_THRESHOLD:
        continue    
    
    # Generate positive example from effect frame
    positive_examples.append(MistakeDetectionExample(
        task_name="ego4d",
        video_id=clip['video_uid'],
        procedure_id=clip['clip_index'], # procedure_id is just clip_id from Ego4D - can adjust this to be some index of Ego4D structured verbs if needed later
        frames=[effect_frame],
        frame_times=[clip['post_frame'] / clip['fps']],
        procedure_description=instruction_text,
        mistake=False,
    ))
    
    # Generate hard negative example from precondition frame
    if clip['previous_occurrences'] < 2:
        hard_negative_examples.append(MistakeDetectionExample(
            task_name="ego4d",
            video_id=clip['video_uid'],
            procedure_id=clip['clip_index'], # procedure_id is just clip_id from Ego4D - can adjust this to be some index of Ego4D structured verbs if needed later
            frames=[precondition_frame],
            frame_times=[clip['pre_frame'] / clip['fps']],
            procedure_description=instruction_text,
            mistake=True,
            mistake_type="Action Incomplete",
        ))
    # TODO: current solution filters out repeated actions from appearing in ego4d; could consider filtering out actions that occur more than N times in the video, or condense actions
    # Need to think of a better policy for how to filter out/condense repetitive actions
    
    # Generate negative examples by perturbing nouns and verbs in the instruction text
    # TODO: Move this step to later to enable batching
    # TODO: Switch to using GPT-4 later? Might get higher-quality results
    proposed_object_change = lm(object_change_prompt_template.format(sentence=instruction_text),
                                max_new_tokens=16,
                                return_full_text=False)[0]['generated_text'].replace("\n","")
    proposed_object_change = [o.strip() for o in proposed_object_change.split("->")]
    instruction_text_new_object = instruction_text.replace(proposed_object_change[0], proposed_object_change[1])
    print(instruction_text_new_object)
    wrong_object_negative_examples.append(MistakeDetectionExample(
        task_name="ego4d",
        video_id=clip['video_uid'],
        procedure_id=clip['clip_index'],
        frames=[effect_frame],
        frame_times=[clip['post_frame'] / clip['fps']],
        procedure_description=instruction_text_new_object,
        mistake=True,
        mistake_type="Wrong Object",
    ))
    
    proposed_action_change = lm(action_change_prompt_template.format(sentence=instruction_text),
                                max_new_tokens=32,
                                return_full_text=False)[0]['generated_text'].replace("\n","").strip()
    instruction_text_new_action = proposed_action_change
    print(instruction_text_new_action)
    wrong_action_negative_examples.append(MistakeDetectionExample(
        task_name="ego4d",
        video_id=clip['video_uid'],
        procedure_id=clip['clip_index'],
        frames=[effect_frame],
        frame_times=[clip['post_frame'] / clip['fps']],
        procedure_description=instruction_text_new_action,
        mistake=True,
        mistake_type="Wrong Action",
    ))

    fig, axarr = plt.subplots(1, 2, figsize=(10, 4))
    fig.suptitle(instruction_text)
    axarr[0].imshow(precondition_frame)
    axarr[1].imshow(effect_frame)
    plt.show()

    # Stop early for debugging purpose
    clip_idx += 1
    if clip_idx >= 20:
        break

## Debugging RL-VQA-F Pipeline

NOTE: may need to request 2 GPUs/8*10GB CPUs in session for this.

In [4]:
from trl import DPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from peft import LoraModel, LoraConfig
from datasets import Dataset

# TODO: before testing DPO, need to do the following steps:
# 1) VQG -> candidate questions for each recipe step; use nonzero temperature to generate at least 2 candidate question sets
# 2) Score question sets using the scorer
# 3) Build scored question sets into DPO dataset to interface with DPOTrainer
# 4) Run DPOTrainer

train_dataset = Dataset.from_dict({
    "prompt": [
        "hello",
        "how are you",
        "What is your name?",
        "What is your name?",
        "Which is the best programming language?",
        "Which is the best programming language?",
        "Which is the best programming language?",
    ],
    "chosen": [
        "hi nice to meet you",
        "I am fine",
        "My name is Mary",
        "My name is Mary",
        "Python",
        "Python",
        "Java",
    ],
    "rejected": [
        "leave me alone",
        "I am not fine",
        "Whats it to you?",
        "I dont have a name",
        "Javascript",
        "C++",
        "C++",
    ],
})

In [10]:
LM_NAME = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(LM_NAME)
model = AutoModelForCausalLM.from_pretrained(LM_NAME, device_map="auto")
print(model.device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

cuda:0


In [11]:
config = LoraConfig(
    task_type="SEQ_2_SEQ_LM",
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.01,
)
model = LoraModel(model, config, "default")

In [None]:
# TODO: below runs out of CPU memory

training_args = TrainingArguments(output_dir="./dpo_output")

dpo_trainer = DPOTrainer(
    model=model,
    args=training_args,
    beta=0.1,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    max_length=64,
    max_prompt_length=128,
    max_target_length=64
)

In [None]:
dpo_trainer.train()