In [1]:
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import torch

from collections import Counter
from tqdm.auto import tqdm

# Import your data loading utilities and model
from mGPT.data.build_data import build_data
from mGPT.models.build_model import build_model
from mGPT.archs.mgpt_vq import VQVae
from mGPT.config import get_module_config
from omegaconf import OmegaConf

### Load configs

In [2]:
# Load config files in the same way as parse_args()
OmegaConf.register_new_resolver("eval", eval)

In [3]:
cfg_assets = OmegaConf.load('./configs/assets.yaml')
cfg_base = OmegaConf.load(os.path.join(cfg_assets.CONFIG_FOLDER, 'default.yaml'))
cfg_exp = OmegaConf.merge(cfg_base, OmegaConf.load('configs/codebook_experiments/config_h3d_stage1.yaml'))

# Load module configs if not full config
if not cfg_exp.FULL_CONFIG:
    print("Loading full config...")
    cfg_exp = get_module_config(cfg_exp, cfg_assets.CONFIG_FOLDER)

# Merge with assets config which contains the dataset paths
cfg = OmegaConf.merge(cfg_exp, cfg_assets)

# Override some config values for testing
cfg.TRAIN.BATCH_SIZE = 32
cfg.TRAIN.NUM_WORKERS = 2
cfg.DEBUG = False
cfg.DEVICE = [0]

Loading full config...


In [4]:
# Load from checkpoint
cfg.TRAIN.PRETRAINED = 'experiments/mgpt/Codebook_VQVAE_Usage/checkpoints/epoch=9.ckpt'

In [5]:

# Initialize data and model
datamodule = build_data(cfg)
datamodule.setup('fit')  # Prepare the data

mGPT.data.HumanML3D HumanML3DDataModule


Output()

Pointer Pointing at 0


Output()

Pointer Pointing at 0


Pointer Pointing at 0


In [6]:
vqvae = VQVae(
    nfeats=263,
    code_num=512,
    code_dim=512,
    output_emb_width=512,
    down_t=2,
    stride_t=2,
    width=512,
    depth=3,
    dilation_growth_rate=3,
    activation='relu'
)

# Load pretrained weights with proper key matching
# Load pretrained weights with proper key matching
if cfg.TRAIN.PRETRAINED:
    print("Loading pretrained weights from: ", cfg.TRAIN.PRETRAINED)
    state_dict = torch.load(cfg.TRAIN.PRETRAINED, map_location='cpu')['state_dict']
    
    # Debug: Print original keys
    print("\nOriginal keys containing 'codebook':")
    codebook_keys = [k for k in state_dict.keys() if 'codebook' in k]
    print(codebook_keys)
    
    # Create a new state dict with correct keys
    new_state_dict = {}
    for k, v in state_dict.items():
        # Handle both 'motion_vae.' and 'vae.' prefixes
        if k.startswith('motion_vae.'):
            new_key = k.replace('motion_vae.', '')
            new_state_dict[new_key] = v
        elif k.startswith('vae.'):
            new_key = k.replace('vae.', '')
            new_state_dict[new_key] = v
    
    # Debug: Print new keys
    print("\nMapped keys containing 'codebook':")
    new_codebook_keys = [k for k in new_state_dict.keys() if 'codebook' in k]
    print(new_codebook_keys)
    
    # Try loading with strict=False first to see what matches
    incompatible_keys = vqvae.load_state_dict(new_state_dict, strict=False)
    vqvae.to('cuda')
    vqvae.training = False
    vqvae.quantizer.training = False

Loading pretrained weights from:  experiments/mgpt/Codebook_VQVAE_Usage/checkpoints/epoch=9.ckpt

Original keys containing 'codebook':
['vae.quantizer.codebook']

Mapped keys containing 'codebook':
['quantizer.codebook']


In [41]:

# Get a small subset of data
train_loader = datamodule.train_dataloader()
eval_batch = next(iter(train_loader))  # Get just one batch
eval_batch['motion'] = eval_batch['motion'].to('cuda')

In [43]:
vqvae(eval_batch['motion'])

(tensor([[[-9.8387e-03, -3.3771e-01, -5.2731e+00,  ...,  4.2882e+00,
            5.2595e+00,  4.5748e+00],
          [ 4.9087e-02, -1.7745e-01, -9.3925e+00,  ...,  7.8541e+00,
            8.6929e+00,  7.9850e+00],
          [-9.1842e-02, -7.2390e-02, -1.1847e+01,  ...,  9.7935e+00,
            1.0588e+01,  9.6801e+00],
          ...,
          [-6.2316e-02, -6.0175e-02, -1.1515e+01,  ...,  9.8404e+00,
            1.0355e+01,  9.3999e+00],
          [ 2.6104e-02,  6.5609e-02, -9.0484e+00,  ...,  7.9223e+00,
            8.3383e+00,  7.6678e+00],
          [-9.7985e-02,  2.8476e-01, -5.1183e+00,  ...,  4.7841e+00,
            4.5934e+00,  4.3042e+00]],
 
         [[ 1.5146e-02, -3.5860e-01, -5.0380e+00,  ...,  4.2354e+00,
            5.4761e+00,  4.7863e+00],
          [ 6.6775e-02, -2.4182e-01, -9.0430e+00,  ...,  7.7067e+00,
            9.1282e+00,  8.3848e+00],
          [-5.6464e-02, -1.7559e-01, -1.1417e+01,  ...,  9.6067e+00,
            1.1170e+01,  1.0240e+01],
          ...,
    

In [8]:
seen_tokens = set()

In [37]:
with torch.no_grad():
    batches = 0
    for batch in train_loader:
        batches += 1
        codes, _ = vqvae.encode(batch['motion'].to('cuda'))
        seen_tokens.update(codes.flatten().cpu().numpy())
        # codes = vqvae(batch['motion'].to('cuda'))
print(batches)
print(len(seen_tokens)/512*100)

628
63.671875


In [38]:
# 
with torch.no_grad():
    for batch in train_loader:
        x_r, loss, perplexity = vqvae(batch['motion'].to('cuda'))

In [39]:
vqvae.quantizer.get_token_usage_stats()

{'val/unique_tokens': 339,
 'val/total_tokens': 1296704,
 'val/codebook_usage_percent': 66.2109375}

In [None]:
print(len(seen_tokens))

In [14]:
stats = vqvae.quantizer.get_token_usage_stats()


In [None]:
freqs = stats['val/token_frequencies'].cpu().numpy()

threshold = 600
mean_freq = np.mean(freqs)
freqs = freqs[freqs > threshold]

print("mean freq: ", mean_freq)
print(f"Usage percentage: {(len(freqs) / 512)*100: .2f}%")

plt.figure(figsize=(15, 5))
plt.bar(range(len(freqs)), freqs[freqs > threshold])
plt.title('Codebook Usage Distribution')
plt.xlabel('Codebook Index (sorted by usage)')
plt.ylabel('Usage Count')
plt.show()


In [None]:
# Function to analyze codebook usage

with torch.no_grad():
    # Forward pass
    print("Input shape: ", eval_batch['motion'].shape)
    x_out, loss, perplexity = vqvae(eval_batch['motion'])


In [None]:
perplexity