In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import torch
import torchaudio

from snac import SNAC
from transformers import AutoTokenizer

from src import MotionDataset
from src import TokenizerModule
from src.full_dataset import SNACMotionTextDataset

from IPython.display import Audio

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Load the three tokenizers used to build the combined sequences:
# the SNAC audio codec ("snac_24khz" checkpoint, switched to eval mode and moved to CUDA),
# the project's motion tokenizer, and the Orpheus text tokenizer.
audio_tokenizer = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
audio_tokenizer = audio_tokenizer.eval().cuda()

motion_tokenizer = TokenizerModule.from_pretrained("InternalCan/tokenizer_module")

text_tokenizer = AutoTokenizer.from_pretrained("canopylabs/orpheus-3b-0.1-pretrained")

In [None]:
# Validation split of the combined audio/motion/text dataset (val_split=0.1, seed=2).
# compute_stats=False presumably skips recomputing dataset statistics — TODO confirm.
full_ds = SNACMotionTextDataset(
    "dataset",
    split="val",
    val_split=0.1,
    seed=2,
    compute_stats=False,
    device='cuda',
)

In [5]:
# Fetch the first item as a quick indexing check; `sample` is reassigned by the next cell.
sample = full_ds[0]

In [16]:
# Inspect one validation item end to end: resample its motion, then tokenize
# motion, audio and text so the streams can be combined in the following cells.
SAMPLE_INDEX = 134
# Second argument to `resample_item`; presumably the target motion frame rate — TODO confirm.
MOTION_RESAMPLE_RATE = 24

sample = full_ds[SAMPLE_INDEX]

audio = sample['audio']
motion = full_ds.resample_item(sample['motion'], MOTION_RESAMPLE_RATE)

with torch.no_grad():
    motion_features = motion_tokenizer.sample_to_features(motion)
    motion_tokens = motion_tokenizer.features_to_codes(motion_features)
    # unsqueeze(0) adds a leading dimension (presumably the batch axis SNAC expects).
    audio_tokens = audio_tokenizer.encode(audio.unsqueeze(0).to('cuda'))

# Text tokenization involves no autograd, so it does not need the no_grad context.
text_tokens = text_tokenizer.encode(sample['text'], add_special_tokens=True)

In [24]:
# Combine the audio and motion token streams and bundle them with the text tokens
# in the mapping layout that `create_input_ids` consumes.
# NOTE(review): `merge_tokens` is not defined or imported in the visible cells;
# make sure it is in scope before running this cell on a fresh kernel.
frame_tokens = merge_tokens(audio_tokens, motion_tokens)
example = dict(speech_motion_tokens=frame_tokens, text_tokens=text_tokens)

In [21]:
# Special-token IDs for the sequence layout built by `create_input_ids`.
# 128000 / 128009 are presumably the base text tokenizer's begin-of-text and
# end-of-text/turn IDs — TODO confirm against `text_tokenizer`.
tokeniser_length = 128256
start_of_text = 128000
end_of_text = 128009

# Seven custom markers take the consecutive IDs
# tokeniser_length + 1 through tokeniser_length + 7, in this order.
(
    start_of_speech,
    end_of_speech,
    start_of_human,
    end_of_human,
    start_of_ai,
    end_of_ai,
    pad_token,
) = range(tokeniser_length + 1, tokeniser_length + 8)

In [9]:
def create_input_ids(example):
    """Build a causal-LM training example from text and speech/motion tokens.

    Layout::

        [start_of_human] text_tokens [end_of_text] [end_of_human]
        [start_of_ai] [start_of_speech] speech_motion_tokens [end_of_speech] [end_of_ai]

    Args:
        example: Mapping with ``'text_tokens'`` (a sequence of ints, or anything
            with ``.tolist()``) and ``'speech_motion_tokens'`` (an object with
            ``.flatten().tolist()``, e.g. a tensor or ndarray).

    Returns:
        dict with ``'input_ids'``, ``'labels'`` (an independent copy of
        ``input_ids``) and ``'attention_mask'`` (all ones, same length).

    ``example`` is not modified, so repeated calls on the same example give
    identical results.
    """
    raw_text = example['text_tokens']
    # Build a new list rather than appending to example['text_tokens'] in place;
    # in-place appends would add another end_of_text on every call.
    text_tokens = raw_text.tolist() if hasattr(raw_text, 'tolist') else list(raw_text)
    speech_motion_tokens = example['speech_motion_tokens'].flatten().tolist()

    input_ids = (
        [start_of_human]
        + text_tokens
        + [end_of_text]
        + [end_of_human]
        + [start_of_ai]
        + [start_of_speech]
        + speech_motion_tokens
        + [end_of_speech]
        + [end_of_ai]
    )

    return {
        "input_ids": input_ids,
        # A separate list, so masking labels later (e.g. writing -100 in place)
        # cannot silently change input_ids as well.
        "labels": list(input_ids),
        "attention_mask": [1] * len(input_ids),
    }


def parse_input_ids(input_ids):
    """
    Parse input_ids back into text tokens and speech_motion tokens.

    Inverse of the layout produced by ``create_input_ids``::

        [start_of_human] text [end_of_text] [end_of_human]
        [start_of_ai] [start_of_speech] speech_motion [end_of_speech] [end_of_ai]

    Args:
        input_ids: 1-D sequence of token IDs — a list, or anything with
            ``.tolist()`` such as a 1-D tensor or ndarray.

    Returns:
        dict: ``'text_tokens'`` (trailing end_of_text removed) and
        ``'speech_motion_tokens'``, both as flat lists.

    Raises:
        ValueError: if a special token is missing or the markers are out of order.
    """
    ids = input_ids.tolist() if hasattr(input_ids, "tolist") else list(input_ids)

    def _find(token_id, name, start):
        # Search only from `start`, so the expected marker order is enforced and
        # an out-of-order marker raises instead of yielding an empty/wrong slice.
        try:
            return ids.index(token_id, start)
        except ValueError as e:
            raise ValueError(
                f"Missing expected special tokens in input_ids: {name} ({token_id}) "
                f"not found at or after position {start}"
            ) from e

    start_human_idx = _find(start_of_human, "start_of_human", 0)
    end_human_idx = _find(end_of_human, "end_of_human", start_human_idx + 1)
    start_ai_idx = _find(start_of_ai, "start_of_ai", end_human_idx + 1)
    start_speech_idx = _find(start_of_speech, "start_of_speech", start_ai_idx + 1)
    end_speech_idx = _find(end_of_speech, "end_of_speech", start_speech_idx + 1)
    # end_of_ai is only validated; its position is not needed for the slices.
    _find(end_of_ai, "end_of_ai", end_speech_idx + 1)

    # Text tokens sit between start_of_human and end_of_human; drop the trailing end_of_text.
    text_tokens = ids[start_human_idx + 1:end_human_idx]
    if text_tokens and text_tokens[-1] == end_of_text:
        text_tokens = text_tokens[:-1]

    # Speech/motion tokens sit between start_of_speech and end_of_speech.
    speech_motion_tokens = ids[start_speech_idx + 1:end_speech_idx]

    return {
        'text_tokens': text_tokens,
        'speech_motion_tokens': speech_motion_tokens
    }

In [10]:
# Turn the assembled example into model inputs. Depends on `example` from the
# merge cell; the NameError recorded below came from running this cell before it.
input_ids = create_input_ids(example=example)

NameError: name 'example' is not defined

In [12]:
from datasets import Dataset
import torch
from pathlib import Path
import pickle
from tqdm import tqdm

# Collect the pre-tokenized examples saved as .pt files and package them as a
# Hugging Face dataset with an `input_ids` column and a `metadata` column.
token_dir = Path("dataset/tokens")
# Sorted so row order is deterministic: Path.glob yields files in filesystem order.
pt_files = sorted(token_dir.glob("*.pt"))

all_input_ids = []
all_metadata = []
failed_files = []    # files that raised while loading/processing
skipped_files = []   # files that loaded but have no 'input_ids' entry

for pt_file in tqdm(pt_files):
    try:
        # NOTE: torch.load unpickles the file — only use on trusted, locally produced data.
        # map_location="cpu" avoids materializing GPU-saved tensors on the GPU.
        data = torch.load(pt_file, map_location="cpu")
        if 'input_ids' not in data:
            skipped_files.append(pt_file)
            continue
        # Copy so the loaded object is not mutated; tolerate a missing or None 'metadata'.
        metadata = dict(data.get('metadata') or {})
        metadata['file_path'] = str(pt_file)
    except Exception as e:
        print(f"Error loading {pt_file}: {e}")
        failed_files.append(pt_file)
        continue
    # Append both columns only after the file was fully processed, so they stay aligned.
    all_input_ids.append(data['input_ids'])
    all_metadata.append(metadata)

print(
    f"Loaded {len(all_input_ids)} of {len(pt_files)} files "
    f"({len(skipped_files)} without 'input_ids', {len(failed_files)} failed)"
)
if not all_input_ids:
    raise RuntimeError(f"No usable .pt files found in {token_dir}")

# Create HuggingFace dataset
dataset_dict = {
    'input_ids': all_input_ids,
    'metadata': all_metadata
}

hf_dataset = Dataset.from_dict(dataset_dict)
print(f"Created dataset with {len(hf_dataset)} samples")

# Save dataset locally
hf_dataset.save_to_disk("dataset/hf_dataset")

# The upload to the Hugging Face Hub happens in the next cell (push_to_hub).

100%|██████████| 19988/19988 [01:22<00:00, 241.35it/s]


Created dataset with 19988 samples


Saving the dataset (1/1 shards): 100%|██████████| 19988/19988 [00:00<00:00, 69379.89 examples/s] 


In [13]:
# Upload the dataset built in the previous cell to the Hugging Face Hub;
# the returned CommitInfo is shown as this cell's output.
hf_dataset.push_to_hub(repo_id="InternalCan/snac_motion_text_dataset")

Creating parquet from Arrow format: 100%|██████████| 20/20 [00:00<00:00, 25.87ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:03<00:00,  3.11s/it]


CommitInfo(commit_url='https://huggingface.co/datasets/InternalCan/snac_motion_text_dataset/commit/3aa604478f1daa2165c0e15fc6d6612831b0acfd', commit_message='Upload dataset', commit_description='', oid='3aa604478f1daa2165c0e15fc6d6612831b0acfd', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/InternalCan/snac_motion_text_dataset', endpoint='https://huggingface.co', repo_type='dataset', repo_id='InternalCan/snac_motion_text_dataset'), pr_revision=None, pr_num=None)