# Initial configuration

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch, os
print(torch.__version__)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
DATA_CACHE_DIR = "cache_dir"
MODEL_CACHE_DIR = "/scratch/chaijy_root/chaijy0/sstorks/.cache/huggingface"
os.environ['HF_HOME'] = MODEL_CACHE_DIR

1.13.0+cu117
cuda


In [3]:
!nvidia-smi

Thu Feb  1 11:27:12 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.06              Driver Version: 545.23.06    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A40                     On  | 00000000:1D:00.0 Off |                    0 |
|  0%   49C    P0              56W / 300W |      7MiB / 46068MiB |      0%   E. Process |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

# CaptainCook4D Toy Experiments Empirical Experiments

## Data loading

In [4]:
VIDEO_DIR = "/nfs/turbo/coe-chaijy-unreplicated/datasets/captaincook4d/data/captain_cook_4d/hololens/sync/pv" # Directory containing CaptainCook4D mp4s
ANNOTATIONS_DIR = "/nfs/turbo/coe-chaijy-unreplicated/datasets/captaincook4d/annotations"

Boilerplate code to load video frames from video files (from GPT4):

In [5]:
import cv2
import numpy as np

def get_video(video_path):
    # Open the video file
    cap = cv2.VideoCapture(video_path)

    if not cap.isOpened():
        raise IOError("Cannot open video file")
    
    return cap
    # remember to call cap.release() later

def extract_frames(cap, times):
    fps = cap.get(cv2.CAP_PROP_FPS)  # Frames per second
    frames = []

    for t in times:
        frame_number = int(t * fps)
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
        ret, frame = cap.read()

        if ret:
            # Convert to RGB
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)            
            frames.append(frame)
        else:
            print(f"Warning: Frame at time {t} seconds could not be read.")
            frames.append(None)

    return frames

Other utils functions:

In [6]:
def generate_float_series(start, end, step):
    # Ensure step is a positive float
    step = abs(step)

    # Initialize the series with the start value
    series = [start]

    # Generate numbers in the series
    while start + step <= end:
        start += step
        series.append(start)

    # Check if the end value is already in the series
    if series[-1] != end:
        series.append(end)

    return series

Data classes:

In [7]:
from torch.utils.data import Dataset
from PIL import Image
import json
from dataclasses import dataclass
from typing import Optional

ERROR_CATEGORIES = json.load(open(os.path.join(ANNOTATIONS_DIR, "annotation_json/error_category_idx.json"), "r"))

@dataclass
class MistakeDetectionExample:
     video_id: str
     step_id: int
     frames: list[Image]
     frame_times: list[float]
     action_description: str
     mistake: bool
     mistake_type: Optional[str] = None
     mistake_description: Optional[str] = None


Gather data:

In [8]:
import os, json
from pprint import pprint
from tqdm import tqdm
from PIL import Image

# Pick a sample video from CaptainCook4D
all_video_files = os.listdir(VIDEO_DIR)
video_paths = [f for f in all_video_files if f.endswith('.mp4')]
STEP_ANNOTATIONS = json.load(open(os.path.join(ANNOTATIONS_DIR, "annotation_json/complete_step_annotations.json"), "r"))
ERROR_ANNOTATIONS = json.load(open(os.path.join(ANNOTATIONS_DIR, "annotation_json/error_annotations.json"), "r"))
for error_annotation in ERROR_ANNOTATIONS:
    video_id = error_annotation['recording_id']
    STEP_ANNOTATIONS[video_id]["steps_errors"] = error_annotation["step_annotations"]

success_examples = []
error_examples = []
for sample_video_path in tqdm(video_paths):
    sample_video_id = "_".join(sample_video_path.split('_')[:2])
    sample_video_path = os.path.join(VIDEO_DIR, sample_video_path)
    try:
        sample_video = get_video(sample_video_path)
    except:
        print(f"Warning: could not open video file: {sample_video_path}")
        continue

    # Load step annotations for it and display precondition/effect frames
    for step in STEP_ANNOTATIONS[sample_video_id]["steps_errors"]:
        # Extract some keyframes for the action
        step_duration = step['end_time'] - step['start_time']
        step_id = int(step['step_id'])
        
        # Some steps are skipped
        if step_duration < 0.1:
            continue

        adjusted_start = step['start_time'] + min(step_duration * 0.05, 0.5) # Adjust the start time to be later by a maximum of 0.5 seconds
        adjusted_end = step['end_time'] - min(step_duration * 0.3, 3) # Adjust the end time to be earlier by a maximum of 3 seconds
        SAMPLE_FREQUENCY = 4.0
        times = generate_float_series(adjusted_start, adjusted_end, SAMPLE_FREQUENCY) # ultimately, we'll want to look at every image frame in some regular interval to determine if there's a mistake
        frames = extract_frames(sample_video, times)
        frames = [Image.fromarray(frame) for frame in frames]

        verb, action_description = step['description'].split("-")[0], "-".join(step['description'].split("-")[1:])
        
        if "errors" in step and len(step["errors"]) > 0:               
            mistake_type = step['errors'][0]["tag"]
            mistake_description = step['errors'][0]['description']
            # altered_action_description = step['modified_description'] # NOTE: can use this later if needed
            
            # Start with only errors specific to a single step, not related to quantities
            # Preparation error involves the wrong object(s)
            # Technique error involves action being performed the wrong way
            if mistake_type not in ["Preparation Error", "Technique Error"]:
                continue
            
            if len(step['errors']) > 1:
                print("Warning: Some error information discarded from only using the first annotated error.")            
            
            error_examples.append(
                MistakeDetectionExample(
                    sample_video_id,
                    step_id,
                    frames,
                    [time - min(times) for time in times],
                    action_description,
                    True,
                    mistake_type,
                    mistake_description
                )
            )
            # pprint(error_examples[-1])
        else:
            success_examples.append(
                MistakeDetectionExample(
                    sample_video_id,
                    step_id,
                    frames,
                    [time - min(times) for time in times],
                    action_description,
                    False
                )
            )        
            # pprint(success_examples[-1])

    if len(error_examples) >= 20 and len(success_examples) >= 20:
        print("Collected at least 20 positive and negative examples!")
        break
    else:
        print("Error examples:", len(error_examples))
        print("Success examples:", len(success_examples))

    sample_video.release()

  0%|          | 1/335 [00:00<04:40,  1.19it/s]

Error examples: 1
Success examples: 6


  1%|          | 2/335 [00:01<03:44,  1.48it/s]

Error examples: 3
Success examples: 11


  1%|          | 3/335 [00:02<04:00,  1.38it/s]

Error examples: 8
Success examples: 13


  1%|          | 4/335 [00:04<06:36,  1.20s/it]

Error examples: 8
Success examples: 24


  1%|▏         | 5/335 [00:04<04:48,  1.14it/s]

Error examples: 11
Success examples: 24


  2%|▏         | 6/335 [00:04<04:09,  1.32it/s]

Error examples: 12
Success examples: 31


  2%|▏         | 7/335 [00:05<04:16,  1.28it/s]

Error examples: 17
Success examples: 35


  2%|▏         | 8/335 [00:06<04:22,  1.25it/s]

Error examples: 17
Success examples: 41


  2%|▏         | 8/335 [00:07<05:09,  1.06it/s]

Collected at least 20 positive and negative examples!





## Model setup

## Step 1: VQG with LLaMA for Recipe Steps

Load model:

In [15]:
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

LM_NAME = "meta-llama/Llama-2-7b-hf"
model = pipeline("text-generation", 
                 model=LM_NAME, 
                 token="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", 
                 model_kwargs={"load_in_8bit": True})

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [16]:
model.model.generation_config.top_p = None
model.model.generation_config.temperature = None
model.model.generation_config.do_sample = False
model.tokenizer.pad_token_id = model.model.config.eos_token_id
model.tokenizer.padding_side = "left"

Load recipe steps:

In [17]:
from pprint import pprint
import json

RECIPE_STEPS = json.load(open(os.path.join(ANNOTATIONS_DIR, "annotation_json/step_idx_description.json"), "r"))
RECIPE_STEPS = {int(k): "-".join(v.split("-")[1:]).strip() for k, v in RECIPE_STEPS.items()}

pprint(RECIPE_STEPS)

{1: 'Pour 1 egg into the ramekin cup',
 2: 'Place the egg from the cup over the lettuce',
 3: 'Coat a 6-oz. ramekin cup with cooking spray',
 4: 'Microwave the ramekin cup uncovered on high for 30 seconds',
 5: 'sprinkle 1 tablespoon of cheese on cup',
 6: 'Top cup with 1 tablespoon of salsa',
 7: 'replace the top of the English muffin',
 8: 'Continue to Microwave for 15-30 more seconds or until the egg is almost '
    'set',
 9: 'Line the bottom piece of the English muffin with lettuce',
 10: 'Microwave just until cheese melts, about 10 seconds',
 11: 'stir the ramekin cup',
 12: 'Cut the English muffin into two pieces with a knife',
 14: 'Peel 1 garlic clove',
 15: 'Pour the sauces over the meatballs',
 16: 'Cut 1/8 garlic clove',
 17: 'Peel one medium onion',
 18: 'Stir the contents in the microwave with a spoon',
 19: 'Slice 1/8 medium onion',
 20: 'Microwave the plate, covered, on high for 1.5 minutes',
 21: 'Place 5 meatballs in a Microwave-safe plate',
 22: 'Cut 1/4 medium carro

Generate success verification questions:

In [24]:
import torch
from transformers.pipelines.pt_utils import KeyDataset
import os
from dataclasses import dataclass, field
from dataclasses_json import dataclass_json
from enum import Enum
import pickle

class VQAResponse(Enum):
    No = 0
    Yes = 1

# TODO: add json serialization later with dataclass_json
@dataclass_json
@dataclass
class VQGOutputs:
    """Dataclass to hold all LM outputs from visual question generation (VQG)."""
    step_id: int
    step: str
    target_object: str
    state_description: str
    questions: list[str]
    answers_str: list[str]
    answers: list[VQAResponse] = field(default_factory=list)
    
    def __post_init__(self):
        """Validation steps to ensure every QA-pair is valid and every question has an answer."""
        for answer in self.answers_str:
            try: 
                self.answers.append(VQAResponse[answer])
            except:
                raise ValueError(f"Unrecognized VQA answer could not be accepted by VQAResponse class: {answer}")
            
        assert len(self.questions) == len(self.answers), "VQGOutputs received mismatched number of questions and answers."
        
        for question in self.questions:
            if not question.strip().endswith("?"):
                print(f"Warning: Question '{question}' doesn't appear to be a question.")
                      
            
USE_VQG_CACHE = False

if not USE_VQG_CACHE or not os.path.exists(os.path.join(DATA_CACHE_DIR, "vqg_outputs.json")):

    # TODO: source recipe steps from elsewhere; find a comprehensive set that covers variety of state changes?
    example1 = 'The recipe step is "Spoon the mixture from the bowl onto the bread". To visually verify that this step is complete, what are 2 questions we could ask about an image of a target object and their expected answers?\n' \
               'Target object: bread\n' \
               'Expected state: After this step, all of the bread should be covered in mixture.\n' \
               '1. Is there mixture on the bread? Yes\n' \
               '2. Is there any bread without mixture on top of it? No' \

    example2 = 'The recipe step is "Roll the tortilla into a thin, log shape about 1 inch thick. Make sure no filling leaks out.". To visually verify that this step is complete, what are 2 questions we could ask about an image of a target object and their expected answers?\n' \
               'Target object: tortilla\n' \
               'Expected state: After this step, the tortilla should be rolled tightly into a thin log with no filling leaking out.\n' \
               '1. Is the tortilla in a thin log shape? Yes\n' \
               '2. Is there any filling leaking out of the tortilla? No'

    example3 = 'The recipe step is "Fold the coffee filter into quarters". To visually verify that this step is complete, what are 2 questions we could ask about an image of a target object and their expected answers?\n' \
               'Target object: coffee filter\n' \
               'Expected state: After this step, the coffee filter should be folded into a quarter circle.\n' \
               '1. Is the coffee filter in a quarter circle? Yes\n' \
               '2. Is the coffee filter folded? Yes' \

    prompts = []
    with torch.no_grad():
        for step_id, step in RECIPE_STEPS.items():
            test = f'The recipe step is "{step}". To visually verify that this step is complete, what are 2 questions we could ask about an image of a target object and their expected answers?\n'
            prompt = "\n\n".join([example1, example2, example3, test])
            prompts.append({"step_id": step_id, "step": step, "prompt": prompt})

    vqg_outputs = {}
    prompt_idx = 0
    for out in tqdm(model(KeyDataset(prompts, "prompt"), 
                     batch_size=16, 
                     max_new_tokens=128, 
                     return_full_text=False, 
                     truncation="do_not_truncate"),
                   total=len(prompts)):
        inp = prompts[prompt_idx]

        step_id = int(inp['step_id'])
        step = inp['step']

        text = out[0]['generated_text']
        
        # print("===========================================================================")
        # print(text)
        text_fixed = text.replace("Љ", "").strip() # Hack: sometimes output from LLaMA 2 starts with Љ and whitespace characters
        
        # Parse reported target object and questions and answers
        try:
            target_object = text_fixed.split("\n")[0].split("Target object: ")[0].strip()
            state_description = text_fixed.split("\n")[1].replace("Expected state: After this step, ", "").strip()
            questions_answers = [(q_a.split("?")[0].strip() + "?", q_a.split("?")[1].strip()) for q_a in text_fixed.split("\n")[2:4]] # NOTE: only extract k=2 questions and answers; can adjust this as needed later
            questions = [q for q, _ in questions_answers]          
            answers = [a for _, a in questions_answers]
            output = VQGOutputs(step_id,
                                step,
                                target_object,
                                state_description,
                                questions,
                                answers)
        except:
            print(text_fixed)
            raise

        vqg_outputs[step_id] = output
        prompt_idx += 1

        # Early stopping for debugging
        if prompt_idx >= 20:
            break
else:
    with open(os.path.join(DATA_CACHE_DIR, "vqg_outputs.json"), "r") as f:
        vqg_outputs_json = f.read()
        vqg_outputs = VQGOutputs.schema().loads(vqg_outputs_json, many=True)
                      
for step_id, output in vqg_outputs.items():
    print(RECIPE_STEPS[step_id])
    pprint(output)
    print('===================')

  5%|▌         | 19/350 [00:47<13:56,  2.53s/it] 

Coat a 6-oz. ramekin cup with cooking spray
VQGOutputs(step_id=3,
           step='Coat a 6-oz. ramekin cup with cooking spray',
           target_object='',
           state_description='the ramekin cup should be coated with cooking '
                             'spray.',
           questions=['1. Is the ramekin cup coated with cooking spray?',
                      '2. Is the ramekin cup coated with cooking spray?'],
           answers_str=['Yes', 'Yes'],
           answers=[<VQAResponse.Yes: 1>, <VQAResponse.Yes: 1>])
Pour 1 egg into the ramekin cup
VQGOutputs(step_id=1,
           step='Pour 1 egg into the ramekin cup',
           target_object='',
           state_description='the ramekin cup should have one egg in it.',
           questions=['1. Is there an egg in the ramekin cup?',
                      '2. Is there more than one egg in the ramekin cup?'],
           answers_str=['Yes', 'No'],
           answers=[<VQAResponse.Yes: 1>, <VQAResponse.No: 0>])
Microwave the ramekin




In [11]:
# Sort the dicts
vqg_keys = sorted(vqg_outputs.keys())

vqg_outputs_new = {}
for key in vqg_keys:
    vqg_outputs_new[key] = vqg_outputs[key]
vqg_outputs = vqg_outputs_new

with open(os.path.join(DATA_CACHE_DIR, "vqg_outputs.json"), "w") as f:
    vqg_outputs_json = VQGOutputs.schema().dumps(list(vqg_outputs.values()), many=True)
    f.write(vqg_outputs_json)

## Step 2: VQA with LLaVA

Load model:

In [12]:
# Setup code grabbed from docs: https://huggingface.co/docs/transformers/model_doc/llava#transformers.LlavaForConditionalGeneration
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir=DATA_CACHE_DIR, load_in_8bit=True)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Ask success verification questions per frame:

(doesn't try to focus VLM on target object yet)

In [13]:
import random
from pprint import pprint
import torch

examples = success_examples[:20] + error_examples[:20]
# examples = examples[:1] # Just for debug purposes

prompt_template = "USER: <image>\n{question} (yes/no) ASSISTANT: "
NO_ID = processor.tokenizer("No", add_special_tokens=False)['input_ids'][0]
YES_ID = processor.tokenizer("Yes", add_special_tokens=False)['input_ids'][0]

vqa_outputs = []
with torch.no_grad():
    for example in tqdm(examples):
        this_vqa_outputs = []
        
        step_id = example.step_id
        
        questions = vqg_outputs[step_id].questions
        prompts = [prompt_template.format(question=question) for question in questions]
        expected_answers = vqg_outputs[step_id].answers
                           
        # TODO: make more efficient for full evaluation; will need to mess around with padding, ensure padding token is on correct side
        for frame in example.frames:
            for prompt, expected_answer in zip(prompts, expected_answers):
                inputs = processor(text=prompt, images=frame, return_tensors="pt").to(device)

                # Generate
                logits = model(**inputs).logits[0] # (seq length, vocab size)
                no_logit = logits[-1, NO_ID]
                yes_logit = logits[-1, YES_ID]
                probs = torch.softmax(torch.stack((no_logit, yes_logit), dim=0), dim=0).detach().cpu()
                
                # TODO: save confidences in a new VQAOutputs dataclass?
                if probs[0] < 0.5:
                    pred = "Yes"
                else:
                    pred = "No" 
                pred = VQAResponse[pred]
                
                this_vqa_outputs.append((frame, prompt, probs, pred, expected_answer))
                
        vqa_outputs.append(this_vqa_outputs)

100%|██████████| 40/40 [06:59<00:00, 10.49s/it]


## Step 3: Evaluate VQA Outputs

In [14]:
from collections import Counter
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

mistake_predictions = []
mistake_labels = []
for example, outputs in zip(examples, vqa_outputs):
    this_mistake_predictions = []
    for frame, prompt, probs, pred, expected_answer in outputs:
        if pred != expected_answer:
            predicted_mistake = True
        else:
            predicted_mistake = False
        this_mistake_predictions.append(predicted_mistake)
    
    mistake_labels.append(example.mistake)
    
    mistake_predictions.append(this_mistake_predictions)
    
# Heuristic: for last 10% of frames, take majority prediction of mistake/success
# In the future, can prompt LLaMA again for this information?
agg_preds = []
for mistake_pred, mistake_label in zip(mistake_predictions, mistake_labels):
    last_n = max(int(len(mistake_pred) * 0.1), 1) # Round up to 1
    mistake_pred = Counter(mistake_pred[-last_n:])
    mistake_pred, _ = mistake_pred.most_common()[0]
    
    agg_preds.append(mistake_pred)
    
metrics = {}
metrics['accuracy'] = accuracy_score(mistake_labels, agg_preds)
metrics['precision'] = precision_score(mistake_labels, agg_preds)
metrics['recall'] = recall_score(mistake_labels, agg_preds)
metrics['f1'] = f1_score(mistake_labels, agg_preds)
    
print("Metrics:")
pprint(metrics)

Metrics:
{'accuracy': 0.575,
 'f1': 0.6046511627906976,
 'precision': 0.5652173913043478,
 'recall': 0.65}
