## Systems Prep

In [2]:
import os
import json
import copy
from typing import Optional, List, Dict
from dataclasses import dataclass, field

import psutil
from torch.utils.data import Dataset
import transformers
from transformers.trainer_pt_utils import LabelSmoother
import torch
from PIL import Image

from conversation import get_conv_template

IGNORE_TOKEN_ID = LabelSmoother.ignore_index

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# print ram size

ram_gb = psutil.virtual_memory().total / 1e9
print(f"RAM: {ram_gb} GB")

Using device: cuda
RAM: 31.538241536 GB


In [4]:
RAW_VIDEO_DATA_PATH = '/home/saberwu2002/CS229-Project/local_data/MMTrail_processed/test/metas_video_convs.json'
RAW_AUDIO_DATA_PATH = '/home/saberwu2002/CS229-Project/local_data/MMTrail_processed/test/metas_audio_convs.json'

In [5]:
raw_video_data : List[dict] = json.load(open(RAW_VIDEO_DATA_PATH, 'r'))
print(f"Loaded {len(raw_video_data)} video data")
print(f"Example video data: {raw_video_data[0]}")

Loaded 80 video data
Example video data: {'id': '-2x2NMwBDzE', 'video_path': 'group_1/-2x2NMwBDzE.mp4', 'images_folder': 'images/-2x2NMwBDzE', 'audio_file': 'audios/-2x2NMwBDzE.wav', 'conversations': [{'from': 'human', 'value': "What is the video's main focus area?"}, {'from': 'assistant', 'value': 'A stunning woman dressed in traditional Indian attire stands against a colorful background, her neutral expression captivating the viewer as she poses gracefully, showcasing her beauty and poise.'}]}


In [6]:
raw_audio_data : List[dict] = json.load(open(RAW_AUDIO_DATA_PATH, 'r'))
print(f"Loaded {len(raw_audio_data)} audio data")
print(f"Example audio data: {raw_audio_data[0]}")

Loaded 80 audio data
Example audio data: {'id': '-2x2NMwBDzE', 'video_path': 'group_1/-2x2NMwBDzE.mp4', 'images_folder': 'images/-2x2NMwBDzE', 'audio_file': 'audios/-2x2NMwBDzE.wav', 'conversations': [{'from': 'human', 'value': 'What kind of imagery does the audio evoke?'}, {'from': 'assistant', 'value': 'The low quality recording features a wide harmonizing vocals singing over playback that consists of an acoustic rhythm guitar, groovy bass, shimmering shakers, punchy snare and kick hits. It sounds happy, joyful and fun - like something kids would listen to.'}]}


## Arguments

In [7]:
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    
@dataclass
class DataArguments:
    data_path: str = field(default=None,
                           metadata={"help": "Path to the training data."})

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    pass

In [None]:
parser = transformers.HfArgumentParser(
    (ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses(
    args = [
        "--model_name_or_path", "/home/saberwu2002/CS229-Project/hf_ckp/vicuna-7b-v1.5",
        "--data_path", "/home/saberwu2002/CS229-Project/local_data/MMTrail_processed/test/metas_video_convs.json",
        "--output_dir", "/home/saberwu2002/CS229-Project/output/",
    ]
)

In [9]:
print(f"Model arguments: {model_args}\n")
print(f"Data arguments: {data_args}\n")
print(f"Training arguments: {training_args}")

Model arguments: ModelArguments(model_name_or_path='lmsys/vicuna-13b-v1.5')

Data arguments: DataArguments(data_path='/home/saberwu2002/CS229-Project/local_data/MMTrail_processed/test/metas_video_convs.json')

Training arguments: TrainingArguments(
_n_gpu=1,
accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None, 'use_configured_state': False},
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
batch_eval_metrics=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
dataloader_prefetch_factor=None,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=F

## Vicuna Tokenizer

In [10]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Specify your desired path
model_path = "/home/saberwu2002/CS229-Project/hf_ckp/vicuna-7b-v1.5"

In [11]:
# load vicuna tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

In [12]:
# model = AutoModelForCausalLM.from_pretrained(model_path)

## Dataset

In [13]:
IDX = 0
print(f"Example video data: {raw_video_data[IDX]}\n")

source = copy.deepcopy(raw_video_data[IDX]['conversations'])
print(f"Sources: {source}")

Example video data: {'id': '-2x2NMwBDzE', 'video_path': 'group_1/-2x2NMwBDzE.mp4', 'images_folder': 'images/-2x2NMwBDzE', 'audio_file': 'audios/-2x2NMwBDzE.wav', 'conversations': [{'from': 'human', 'value': "What is the video's main focus area?"}, {'from': 'assistant', 'value': 'A stunning woman dressed in traditional Indian attire stands against a colorful background, her neutral expression captivating the viewer as she poses gracefully, showcasing her beauty and poise.'}]}

Sources: [{'from': 'human', 'value': "What is the video's main focus area?"}, {'from': 'assistant', 'value': 'A stunning woman dressed in traditional Indian attire stands against a colorful background, her neutral expression captivating the viewer as she poses gracefully, showcasing her beauty and poise.'}]


In [14]:
conv = get_conv_template("vicuna_v1.1")
print(f"Conversation template: {conv}")

Conversation template: Conversation(name='vicuna_v1.1', system_template='{system_message}', system_message="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.", system_message_vision='', roles=('USER', 'ASSISTANT'), messages=[], offset=0, sep_style=<SeparatorStyle.ADD_COLON_TWO: 2>, sep=' ', sep2='</s>', stop_str=None, stop_token_ids=None, max_image_size_mb=None)


In [15]:
roles = {"human": conv.roles[0], "assistant": conv.roles[1]}
# Apply prompt templates
conversations = []
if roles[source[0]["from"]] != conv.roles[0]:
    # Skip the first one if it's not from human
    source = source[1:]
    
conv.messages = []
for j, sentence in enumerate(source):
    role = roles[sentence["from"]]
    assert role == conv.roles[j % 2], f"Role mismatch"
    conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
        
print(f"Conversations: {conversations}")

Conversations: ["A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: What is the video's main focus area? ASSISTANT: A stunning woman dressed in traditional Indian attire stands against a colorful background, her neutral expression captivating the viewer as she poses gracefully, showcasing her beauty and poise.</s>"]


In [16]:
# Tokenize conversations
input_ids = tokenizer(
    conversations,
    return_tensors="pt",
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True
).input_ids
targets = input_ids.clone()

print(f"Input IDs: {input_ids}")
print(f"Targets: {targets}")

Input IDs: tensor([[    1,   319, 13563,  ...,     0,     0,     0]])
Targets: tensor([[    1,   319, 13563,  ...,     0,     0,     0]])


In [19]:
# Mask targets. Only compute loss on the assistant outputs.
sep = conv.sep + conv.roles[1] + ": "
for conversation, target in zip(conversations, targets):
    total_len = int(target.ne(tokenizer.pad_token_id).sum())

    turns = conversation.split(conv.sep2)
    cur_len = 1
    target[:cur_len] = IGNORE_TOKEN_ID
    for i, turn in enumerate(turns):
        if turn == "":
            break
        turn_len = len(tokenizer(turn).input_ids)

        parts = turn.split(sep)
        if len(parts) != 2:
            break
        parts[0] += sep
        # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
        instruction_len = len(tokenizer(parts[0]).input_ids) - 2

        if i != 0 and not tokenizer.legacy:
            # The legacy and non-legacy modes handle special tokens differently
            instruction_len -= 1

        # Ignore the user instructions
        target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID
        cur_len += turn_len

        if i != 0 and not tokenizer.legacy:
            # The legacy and non-legacy modes handle special tokens differently
            cur_len -= 1

    target[cur_len:] = IGNORE_TOKEN_ID

    if True:  # Inspect and check the correctness of masking
        z = target.clone()
        z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
        print(tokenizer.decode(z))
        # exit()

    if cur_len < tokenizer.model_max_length:
        if cur_len != total_len:
            target[:] = IGNORE_TOKEN_ID
            print(
                f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                f" #turn = {len(turns) - 1}. (ignored)"
            )

<unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk>