In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from snac import SNAC
from src.modules.fsq_vqvae import FSQVAE
from src.modules.vqvae import VQVae
from pathlib import Path
import yaml
from train_tokenizer import VQVAEModule
from src.dataset import Dataset

# Pretrained SNAC codec from the "snac_32khz" checkpoint, switched to eval mode
# and moved to the GPU. Input audio is expected with shape (B, 1, T),
# e.g. torch.randn(1, 1, 32000).cuda() as a placeholder.
model = SNAC.from_pretrained("hubertsiuzdak/snac_32khz")
model = model.eval()
model = model.cuda()

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load a reference clip from disk and round-trip it through the SNAC codec.
import torchaudio

# `model` is the "snac_32khz" checkpoint, so the codec must be fed 32 kHz audio.
# (The previous code resampled to 24 kHz, contradicting both the checkpoint
# name and its own "Resample to 32kHz" comment.)
SNAC_SAMPLE_RATE = 32000

waveform, sample_rate = torchaudio.load("dataset/audio/_a-pB_5eRt0_7.wav")
# Down-mix multi-channel audio to mono by averaging over the channel axis.
if waveform.shape[0] > 1:
    waveform = torch.mean(waveform, dim=0, keepdim=True)
# Resample only when the file's rate differs from the codec's rate.
if sample_rate != SNAC_SAMPLE_RATE:
    resampler = torchaudio.transforms.Resample(sample_rate, SNAC_SAMPLE_RATE)
    waveform = resampler(waveform)
# Add a batch dimension -> (1, 1, T) and move to the GPU.
audio = waveform.unsqueeze(0).cuda()

# Encode to codes and decode back; inference_mode disables autograd tracking.
with torch.inference_mode():
    codes = model.encode(audio)
    audio_hat = model.decode(codes)

In [4]:
def load_vqvae(model_path: Path) -> tuple[VQVae, dict]:
    """
    Load a trained VQ-VAE together with the feature configuration of its run.

    Two checkpoint layouts are handled, selected by the file suffix:

    * ``.ckpt``: the run directory is two levels above the checkpoint file;
      weights are loaded into a ``VQVAEModule`` from the checkpoint's
      ``'state_dict'`` entry and its ``.vqvae`` attribute is returned.
    * anything else: the run directory is the checkpoint's own directory and
      the file is loaded directly as a ``VQVae`` state dict.

    In both cases the config is read from
    ``<run_dir>/wandb/latest-run/files/config.yaml``.

    Args:
        model_path (Path): Path to the model checkpoint file.

    Returns:
        tuple: ``(vqvae, feats_enabled)`` where ``vqvae`` is the model moved to
        CUDA and put in eval mode, and ``feats_enabled`` is the
        ``config['feats_enabled']['value']`` mapping.

    Side effects:
        Prints the sorted names of the features whose ``'enabled'`` flag is set.
    """
    is_module_ckpt = model_path.suffix == '.ckpt'
    pretrained = torch.load(model_path)

    # Only the location of the run directory differs between the two layouts.
    run_dir = model_path.parent.parent if is_module_ckpt else model_path.parent
    config_path = run_dir / 'wandb' / 'latest-run' / 'files' / 'config.yaml'

    # Use a context manager so the config file handle is closed promptly.
    with open(config_path, "r") as config_file:
        config = yaml.safe_load(config_file)
    feats_enabled = config['feats_enabled']['value']

    print([feat for feat in sorted(feats_enabled) if feats_enabled[feat]['enabled']])

    if is_module_ckpt:
        # NOTE(review): "vqvae" is unwrapped with ['value'] while "losses" is
        # passed as the raw config entry -- confirm this is what VQVAEModule expects.
        vqvae_module = VQVAEModule(vqvae_config=config["vqvae"]['value'], losses_config=config["losses"])
        vqvae_module.load_state_dict(pretrained['state_dict'])
        vqvae = vqvae_module.vqvae
    else:
        vqvae = VQVae(**config["vqvae"]['value'])
        vqvae.load_state_dict(pretrained)

    vqvae.to("cuda")
    vqvae.eval()

    return vqvae, feats_enabled

In [5]:
# Load the tokenizer from the rest_fsq_D4 run along with its enabled-feature config.
rest_fsq, rest_feats = load_vqvae(
    Path("outputs/rest_fsq_D4/checkpoints/checkpoint_epoch=1059.ckpt")
)
# Re-annotate for editor/type-checker support; the value itself is unchanged.
rest_fsq: FSQVAE = rest_fsq

['c_eyes_lst', 'c_lip_lst', 'kp', 't', 'x_s']
FSQ Config:
Levels: [4, 4, 4, 4]
Output Emb Width: 4
Num Quantizers: 1
Using FSQ
Codebook size: 256


In [6]:
# Build the "eval" split of the dataset rooted at "dataset". compute_stats=False
# is passed; the recorded log shows statistics being read from a precomputed pickle.
ds = Dataset("dataset", split="eval", compute_stats=False)

Loading precomputed statistics from dataset/stats_all.pkl
Loaded feature-wise statistics successfully
Loaded 4090 eval samples


In [7]:
def prepare_features(sample, feats_enabled, only_lips, device="cuda", max_seq_len=300):
    """
    Build a single feature tensor from the enabled entries of a sample.

    Each enabled feature is clipped to the first ``seq_len`` frames, flattened
    to ``(1, seq_len, -1)``, and the results are concatenated along the last
    axis in the iteration order of ``feats_enabled``.

    Args:
        sample (dict): Sample data; must contain ``'kp'`` (its first dimension
            gives the frame count) and every enabled feature name as a tensor.
        feats_enabled (dict): Maps feature name -> metadata dict with an
            ``'enabled'`` flag.
        only_lips (bool): Only affects ``"exp"`` and ``"exp_velocity"``. When
            true, indices ``15:`` of their third axis are kept; otherwise
            indices ``:15``.
            NOTE(review): presumably 15 splits lip vs. non-lip entries --
            confirm against the feature layout.
        device (str): Device to move the result to (default: "cuda").
        max_seq_len (int): Maximum number of frames to keep (default: 300).

    Returns:
        tuple: ``(features, dims)`` where ``features`` has shape
        ``(1, seq_len, total_dim)`` on ``device`` and ``dims`` maps each used
        feature name to its flattened width.
    """
    # The frame count is taken from 'kp' and capped at max_seq_len.
    seq_len = min(sample['kp'].shape[0], max_seq_len)

    feature_tensors = []
    dims = {}

    for feat, metadata in feats_enabled.items():
        if not metadata['enabled']:
            continue
        print(f"Using {feat}")

        clipped = sample[feat][:seq_len]
        if feat in ("exp", "exp_velocity"):
            # These features are split on their third axis at index 15.
            clipped = clipped[:, :, 15:, :] if only_lips else clipped[:, :, :15, :]

        feature = clipped.reshape(1, seq_len, -1)
        dims[feat] = feature.shape[-1]
        feature_tensors.append(feature)

    if feature_tensors:
        features = torch.concat(feature_tensors, dim=2)
    else:
        # Nothing enabled: return a zero-width tensor with the right leading dims.
        features = torch.empty((1, seq_len, 0))

    features = features.to(device)
    print("dims: ", dims)
    print("Total dims: ", features.shape[-1])

    return features, dims

In [8]:
# Pick one evaluation sample to inspect.
sample = ds[3364]

# The video id is the pickle's file name up to its first "."; the matching
# clip path is built from that id.
pickle_path = sample['metadata']['pickle_path']
vid_id = pickle_path.rsplit("/", 1)[-1].partition(".")[0]
vid_path = f"dataset/train/{vid_id}.mp4"

# Assemble the model input from the features this tokenizer was trained with.
rest_features, rest_dims = prepare_features(
    sample,
    rest_feats,
    only_lips=False,
)

Using c_eyes_lst
Using c_lip_lst
Using kp
Using t
Using x_s
dims:  {'c_eyes_lst': 2, 'c_lip_lst': 1, 'kp': 63, 't': 3, 'x_s': 63}
Total dims:  132


In [9]:
# Display the metadata dict stored with the selected sample.
sample['metadata']

{'pickle_path': 'dataset/pickles/VTKRMo7HYkY_0.pkl',
 'n_frames': 100,
 'output_fps': 23.98}

In [10]:
# Encode the prepared feature tensor with the loaded tokenizer.
# The recorded output reports an index tensor of shape (1, 98, 1) for this sample.
indices = rest_fsq.encode(rest_features)

Indices shape: torch.Size([1, 98, 1])


In [12]:
# Show the raw index tensor.
# NOTE(review): this dumps every element; a slice or `.shape` would keep the output short.
# NOTE(review): the recorded execution count (12) precedes the next cell's (11),
# so these cells were run out of order -- re-run top to bottom before sharing.
indices

tensor([[[ 10],
         [ 91],
         [182],
         [119],
         [162],
         [ 81],
         [ 15],
         [255],
         [241],
         [243],
         [227],
         [171],
         [207],
         [207],
         [250],
         [242],
         [163],
         [231],
         [158],
         [ 95],
         [251],
         [242],
         [167],
         [163],
         [ 86],
         [ 79],
         [155],
         [183],
         [167],
         [146],
         [ 86],
         [ 91],
         [183],
         [150],
         [211],
         [178],
         [ 27],
         [155],
         [182],
         [227],
         [162],
         [ 98],
         [ 95],
         [187],
         [166],
         [215],
         [ 97],
         [ 38],
         [ 75],
         [ 91],
         [115],
         [167],
         [101],
         [ 23],
         [ 75],
         [103],
         [119],
         [150],
         [161],
         [102],
         [ 87],
         [103],
        

In [11]:
# Map the indices back to codes through the quantizer and check the result's shape.
rest_fsq.quantizer.indices_to_codes(indices).shape

torch.Size([1, 4, 98, 1])