In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch
from snac import SNAC
from src.modules.fsq_vqvae import FSQVAE
from src.modules.vqvae import VQVae
from pathlib import Path
import torchaudio
import yaml
from train_tokenizer import VQVAEModule
from src.dataset import Dataset

# Pretrained SNAC codec ("snac_32khz" checkpoint), put in eval mode and moved to the GPU.
snac_checkpoint = "hubertsiuzdak/snac_32khz"
model = SNAC.from_pretrained(snac_checkpoint)
model = model.eval().cuda()
# audio = torch.randn(1, 1, 32000).cuda()  # placeholder for actual audio with shape (B, 1, T)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Load one clip from disk and round-trip it through the SNAC codec.
# The codec checkpoint loaded above is "snac_32khz", so the audio is brought to 32 kHz
# before encoding; feeding any other rate would mismatch the checkpoint.
SNAC_SAMPLE_RATE = 32000

waveform, sample_rate = torchaudio.load("dataset/audio/_a-pB_5eRt0_7.wav")
# Convert to mono if stereo (average over the channel dimension, keeping it as size 1)
if waveform.shape[0] > 1:
    waveform = torch.mean(waveform, dim=0, keepdim=True)
# Resample to 32 kHz if needed
if sample_rate != SNAC_SAMPLE_RATE:
    resampler = torchaudio.transforms.Resample(sample_rate, SNAC_SAMPLE_RATE)
    waveform = resampler(waveform)
# Add batch dimension and move to GPU
audio = waveform.unsqueeze(0).cuda()

# inference_mode: no autograd bookkeeping is needed for a pure encode/decode pass.
with torch.inference_mode():
    codes = model.encode(audio)
    audio_hat = model.decode(codes)

In [17]:
def load_vqvae(model_path: Path) -> "tuple[VQVae, dict]":
    """
    Load a VQ-VAE from a checkpoint together with the feature table of its run.

    The run configuration is read from ``wandb/latest-run/files/config.yaml``.
    For ``.ckpt`` files that directory is looked up in the parent of the
    checkpoint's directory; for any other suffix, in the checkpoint's own
    directory.

    Args:
        model_path (Path): Path to the model checkpoint file. ``.ckpt`` files are
            loaded through ``VQVAEModule`` (weights under ``'state_dict'``);
            any other file is loaded directly into ``VQVae`` as a state dict.

    Returns:
        tuple: ``(vqvae, feats_enabled)`` where ``vqvae`` is the model moved to
        CUDA and set to eval mode, and ``feats_enabled`` is the
        ``config['feats_enabled']['value']`` entry of the run config.
    """
    is_module_ckpt = model_path.suffix == '.ckpt'

    # NOTE: torch.load unpickles the file; only use it on trusted checkpoints.
    # Tensors are materialised on CPU first; the model is moved to CUDA once below.
    pretrained = torch.load(model_path, map_location="cpu")

    run_root = model_path.parent.parent if is_module_ckpt else model_path.parent
    config_path = run_root / 'wandb' / 'latest-run' / 'files' / 'config.yaml'

    # Context manager so the config file handle is closed deterministically.
    with open(config_path, "r") as config_file:
        config = yaml.safe_load(config_file)
    feats_enabled = config['feats_enabled']['value']

    # Show which features the run had enabled, in sorted order.
    print([feat for feat in sorted(feats_enabled) if feats_enabled[feat]['enabled']])

    if is_module_ckpt:
        # NOTE(review): "vqvae" is unwrapped via ['value'] but "losses" is passed
        # as-is — confirm whether it should also be config["losses"]['value'].
        vqvae_module = VQVAEModule(vqvae_config=config["vqvae"]['value'], losses_config=config["losses"])
        vqvae_module.load_state_dict(pretrained['state_dict'])
        vqvae = vqvae_module.vqvae
    else:
        vqvae = VQVae(**config["vqvae"]['value'])
        vqvae.load_state_dict(pretrained)

    vqvae.to("cuda")
    vqvae.eval()

    return vqvae, feats_enabled

In [139]:
# Load the "rest" FSQ model and its enabled-feature table from a training checkpoint.
# NOTE(review): the saved output of this cell records a RuntimeError while loading the state dict
# (last weight dimension 2 in the checkpoint vs 3 in the current model for encoder.branch1.2.0 / 3.0),
# so a fresh run with the current model code does not produce `rest_fsq` / `rest_feats`.
rest_fsq, rest_feats = load_vqvae(Path("outputs/rest_fsq_D4/checkpoints/checkpoint_epoch=1059.ckpt"))

['c_eyes_lst', 'c_lip_lst', 'kp', 't', 'x_s']
FSQ Config:
Levels: [4, 4, 4, 4]
Output Emb Width: 4
Num Quantizers: 1
Using FSQ
Codebook size: 256


RuntimeError: Error(s) in loading state_dict for VQVAEModule:
	size mismatch for vqvae.encoder.branch1.2.0.weight: copying a param with shape torch.Size([512, 512, 2]) from checkpoint, the shape in current model is torch.Size([512, 512, 3]).
	size mismatch for vqvae.encoder.branch1.3.0.weight: copying a param with shape torch.Size([512, 512, 2]) from checkpoint, the shape in current model is torch.Size([512, 512, 3]).

In [13]:
# Evaluation split of the dataset; with compute_stats=False the saved output shows
# precomputed statistics being loaded from dataset/stats_all.pkl.
ds = Dataset("dataset", split="eval", compute_stats=False)

Loading precomputed statistics from dataset/stats_all.pkl
Loaded feature-wise statistics successfully
Loaded 4090 eval samples


In [14]:
def prepare_features(sample, feats_enabled, only_lips, device="cuda", max_seq_len=300):
    """
    Concatenate the enabled features of one sample into a single tensor.

    Each enabled feature is truncated to ``seq_len`` frames, flattened to
    ``(1, seq_len, -1)`` and concatenated along the last dimension, in the
    iteration order of ``feats_enabled``.

    Args:
        sample (dict): Sample data. Must contain ``'kp'`` (its first dimension
            gives the frame count) and one entry per enabled feature.
        feats_enabled (dict): Maps feature name -> metadata dict with an
            ``'enabled'`` flag.
        only_lips (bool): Only affects ``'exp'`` / ``'exp_velocity'``: if True,
            indices ``15:`` along dimension 2 are kept, otherwise indices ``:15``.
        device (str): Device to move the result to (default: "cuda").
        max_seq_len (int): Upper bound on the number of frames kept (default: 300).

    Returns:
        tuple: ``(features, dims)`` where ``features`` has shape
        ``(1, seq_len, total_dim)`` and ``dims`` maps each enabled feature name
        to its flattened width.
    """
    seq_len = min(sample['kp'].shape[0], max_seq_len)

    # Collect the flattened per-feature tensors and remember each one's width.
    feature_tensors = []
    dims = {}

    for feat, metadata in feats_enabled.items():
        if not metadata['enabled']:
            continue

        print(f"Using {feat}")
        if feat in ("exp", "exp_velocity"):
            # These features are split at index 15 along dimension 2.
            if only_lips:
                feature = sample[feat][:seq_len, :, 15:, :].reshape(1, seq_len, -1)
            else:
                feature = sample[feat][:seq_len, :, :15, :].reshape(1, seq_len, -1)
        else:
            feature = sample[feat][:seq_len, ...].reshape(1, seq_len, -1)

        dims[feat] = feature.shape[-1]
        feature_tensors.append(feature)

    if feature_tensors:
        features = torch.concat(feature_tensors, dim=2)
    else:
        # Nothing enabled: keep the (1, seq_len, ·) layout with a zero-width last dimension.
        features = torch.empty((1, seq_len, 0))

    features = features.to(device)
    print("dims: ", dims)
    print("Total dims: ", features.shape[-1])

    return features, dims

In [15]:
# Pick one evaluation sample and derive the id / path of its video from the pickle path.
sample = ds[3364]

sample_meta = sample['metadata']
pickle_path = sample_meta['pickle_path']
# File name without directories, cut at the first dot.
vid_id = pickle_path.rsplit("/", 1)[-1].split(".", 1)[0]
vid_path = f"dataset/train/{vid_id}.mp4"

# Feature tensor for the "rest" model, built from the feature table loaded with its checkpoint.
rest_features, rest_dims = prepare_features(sample, rest_feats, only_lips=False)

Using c_eyes_lst
Using c_lip_lst
Using kp
Using t
Using x_s
dims:  {'c_eyes_lst': 2, 'c_lip_lst': 1, 'kp': 63, 't': 3, 'x_s': 63}
Total dims:  132


In [45]:
# Encode the prepared features with the "rest" model; the result is unpacked into two values here.
z_hat, indices = rest_fsq.encode(rest_features)

In [51]:
# Convert `indices` back to codes via the model's quantizer.
# NOTE: this rebinds `codes`, which the audio cell earlier used for the SNAC codes.
codes = rest_fsq.quantizer.indices_to_codes(indices)

In [98]:
# Load the FSQ experiment config; the context manager closes the file handle after parsing.
with open('configs/fsq_config.yaml') as config_file:
    config = yaml.safe_load(config_file)

In [133]:
# Build a VQVAEModule directly from the YAML config; no checkpoint is loaded in this cell.
test = VQVAEModule(vqvae_config=config["vqvae"], losses_config=config["losses"])

FSQ Config:
Levels: [4, 4, 4, 4, 4, 4, 4, 4]
Output Emb Width: 8
Num Quantizers: 1
Using FSQ
Codebook size: 65536


In [134]:
# Take the module's inner VQ-VAE and move it to the GPU.
# NOTE(review): unlike load_vqvae, this does not call .eval() — confirm that is intended.
fsq = test.vqvae.to("cuda")

In [135]:
# NOTE(review): an earlier cell unpacks `rest_fsq.encode(...)` into `(z_hat, indices)`, while here the
# whole return value is bound to `indices` — confirm what `encode` returns for this model.
indices = fsq.encode(rest_features)

In [136]:
# Convert `indices` to codes via the quantizer (rebinds `codes`).
codes = fsq.quantizer.indices_to_codes(indices)

In [137]:
# NOTE: overwrites `codes` with its preprocessed form, so re-running this cell on its own
# applies `preprocess` a second time.
codes = fsq.preprocess(codes)

In [138]:
# Run the decoder and display only the output shape; the saved output is torch.Size([1, 132, 100]).
fsq.decoder(codes).shape

torch.Size([1, 132, 100])

In [None]:
# NOTE: repeats the earlier `ds = Dataset(...)` cell verbatim (evaluation split, compute_stats=False).
ds = Dataset("dataset", split="eval", compute_stats=False)

Loading precomputed statistics from dataset/stats_all.pkl
Loaded feature-wise statistics successfully
Loaded 4090 eval samples


In [None]:
def prepare_features(sample, feats_enabled, only_lips, device="cuda", max_seq_len=300):
    """
    Concatenate the enabled features of one sample into a single tensor.

    Each enabled feature is truncated to ``seq_len`` frames, flattened to
    ``(1, seq_len, -1)`` and concatenated along the last dimension, in the
    iteration order of ``feats_enabled``.

    Args:
        sample (dict): Sample data. Must contain ``'kp'`` (its first dimension
            gives the frame count) and one entry per enabled feature.
        feats_enabled (dict): Maps feature name -> metadata dict with an
            ``'enabled'`` flag.
        only_lips (bool): Only affects ``'exp'`` / ``'exp_velocity'``: if True,
            indices ``15:`` along dimension 2 are kept, otherwise indices ``:15``.
        device (str): Device to move the result to (default: "cuda").
        max_seq_len (int): Upper bound on the number of frames kept (default: 300).

    Returns:
        tuple: ``(features, dims)`` where ``features`` has shape
        ``(1, seq_len, total_dim)`` and ``dims`` maps each enabled feature name
        to its flattened width.
    """
    seq_len = min(sample['kp'].shape[0], max_seq_len)

    # Collect the flattened per-feature tensors and remember each one's width.
    feature_tensors = []
    dims = {}

    for feat, metadata in feats_enabled.items():
        if not metadata['enabled']:
            continue

        print(f"Using {feat}")
        if feat in ("exp", "exp_velocity"):
            # These features are split at index 15 along dimension 2.
            if only_lips:
                feature = sample[feat][:seq_len, :, 15:, :].reshape(1, seq_len, -1)
            else:
                feature = sample[feat][:seq_len, :, :15, :].reshape(1, seq_len, -1)
        else:
            feature = sample[feat][:seq_len, ...].reshape(1, seq_len, -1)

        dims[feat] = feature.shape[-1]
        feature_tensors.append(feature)

    if feature_tensors:
        features = torch.concat(feature_tensors, dim=2)
    else:
        # Nothing enabled: keep the (1, seq_len, ·) layout with a zero-width last dimension.
        features = torch.empty((1, seq_len, 0))

    features = features.to(device)
    print("dims: ", dims)
    print("Total dims: ", features.shape[-1])

    return features, dims

In [None]:
# Pick one evaluation sample and derive the id / path of its video from the pickle path.
sample = ds[3364]

sample_meta = sample['metadata']
pickle_path = sample_meta['pickle_path']
# File name without directories, cut at the first dot.
vid_id = pickle_path.rsplit("/", 1)[-1].split(".", 1)[0]
vid_path = f"dataset/train/{vid_id}.mp4"

# Feature tensor for the "rest" model, built from the feature table loaded with its checkpoint.
rest_features, rest_dims = prepare_features(sample, rest_feats, only_lips=False)

Using c_eyes_lst
Using c_lip_lst
Using kp
Using t
Using x_s
dims:  {'c_eyes_lst': 2, 'c_lip_lst': 1, 'kp': 63, 't': 3, 'x_s': 63}
Total dims:  132


In [None]:
# Per-FSQ settings as (levels, dimensions). Despite the name these are configs, not
# index ranges; the ranges are derived from them by calculate_fsq_ranges.
# Insertion order matters: it fixes the order in which the global ranges are laid out.
fsq_index_ranges = dict(
    rest=(4, 4),
    rot_scale=(4, 4),
    exp=(4, 10),
    lip=(4, 10),
)

def calculate_fsq_ranges(fsq_configs):
    """
    Lay the codebooks of several FSQ models out back-to-back in one global index space.

    Args:
        fsq_configs: Dict mapping FSQ name -> (L, D), where L is the number of
                    levels and D the number of dimensions; the codebook size
                    of each entry is L ** D

    Returns:
        Dict mapping FSQ name -> (start_idx, end_idx), with start_idx inclusive
        and end_idx exclusive, following the iteration order of fsq_configs
    """
    # Codebook size per FSQ, in input order.
    sizes = {name: levels ** dims for name, (levels, dims) in fsq_configs.items()}

    spans = {}
    offset = 0
    for name, size in sizes.items():
        # Each range starts where the previous one ended.
        spans[name] = (offset, offset + size)
        offset += size

    return spans

def map_code_to_fsq(code_index, fsq_ranges):
    """
    Resolve a global code index to the FSQ that owns it and the index within that FSQ.

    Args:
        code_index: Global code index
        fsq_ranges: Dict mapping FSQ name -> (start_idx, end_idx); start_idx is
                    inclusive, end_idx exclusive

    Returns:
        Tuple of (fsq_name, local_code_index), where local_code_index is
        code_index minus the owning range's start_idx

    Raises:
        ValueError: If code_index lies in none of the ranges
    """
    for name, bounds in fsq_ranges.items():
        lower, upper = bounds
        if not (lower <= code_index < upper):
            continue
        # First matching range wins; shift into that FSQ's local index space.
        return name, code_index - lower

    raise ValueError(f"Code index {code_index} is out of range for all FSQ models")

# Derive the global (start, end) index range of every FSQ from its (levels, dims) entry.
fsq_ranges = calculate_fsq_ranges(fsq_index_ranges)

print("FSQ Ranges:")
for name, bounds in fsq_ranges.items():
    start, end = bounds
    codebook_size = end - start
    print(f"{name}: {start} - {end} (size: {codebook_size})")

# Spot-check the global -> (fsq, local) mapping, including both edges of the first two ranges.
print("\nTesting code mapping:")
test_codes = [0, 255, 256, 511, 512, 1000000, 1500000]
for code in test_codes:
    try:
        fsq_name, local_code = map_code_to_fsq(code, fsq_ranges)
    except ValueError as e:
        # An index outside every range is reported rather than aborting the loop.
        print(f"Code {code} -> {e}")
    else:
        print(f"Code {code} -> {fsq_name} (local: {local_code})")


FSQ Ranges:
rest: 0 - 256 (size: 256)
rot_scale: 256 - 512 (size: 256)
exp: 512 - 1049088 (size: 1048576)
lip: 1049088 - 2097664 (size: 1048576)

Testing code mapping:
Code 0 -> rest (local: 0)
Code 255 -> rest (local: 255)
Code 256 -> rot_scale (local: 0)
Code 511 -> rot_scale (local: 255)
Code 512 -> exp (local: 0)
Code 1000000 -> exp (local: 999488)
Code 1500000 -> lip (local: 450912)
