In [1]:
# UNCOMMENT TO ADD PATH TO SAVE/CACHE TRANSFORMER MODELS
# import os
# os.environ['TRANSFORMERS_CACHE'] = "/jagupard28/scr0/xiluo-speech/multi-quantizer-experiments/hf/checkpoints/"

In [4]:
from transformers import AutoModel
import torchaudio
from torchaudio.models.wav2vec2.utils import import_huggingface_model

hf_model = AutoModel.from_pretrained('facebook/hubert-large-ls960-ft').to('cuda')
# hf_model = AutoModel.from_pretrained('facebook/hubert-base-ls960').to('cuda')

assert hf_model.__class__.__name__ in {"Wav2Vec2Model", "HubertModel"}

teacher_model = import_huggingface_model(hf_model).eval().to('cuda')

The model is not an instance of Wav2Vec2ForCTC. "lm_head" module is not imported.


In [5]:
teacher_model

Wav2Vec2Model(
  (feature_extractor): FeatureExtractor(
    (conv_layers): ModuleList(
      (0): ConvLayerBlock(
        (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
      )
      (1): ConvLayerBlock(
        (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
      )
      (2): ConvLayerBlock(
        (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
      )
      (3): ConvLayerBlock(
        (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
      )
      (4): ConvLayerBlock(
        (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
      )
      (5): ConvLayerBlock(
        (lay

In [7]:
from lhotse.recipes import download_librispeech, prepare_librispeech

download_librispeech(dataset_parts="mini_librispeech")
libri = prepare_librispeech(corpus_dir="LibriSpeech", output_dir="data/")

Dataset parts:   0%|                                      | 0/2 [00:00<?, ?it/s]
Distributing tasks: 0it [00:00, ?it/s][A
                                      [A
Processing:   0%|                                      | 0/1089 [00:00<?, ?it/s][A
Processing:  42%|███████████               | 461/1089 [00:00<00:00, 4604.04it/s][A
Dataset parts:  50%|███████████████               | 1/2 [00:00<00:00,  3.15it/s][A
Distributing tasks: 0it [00:00, ?it/s][A
Distributing tasks: 18it [00:00, 179.09it/s][A
                                            [A
Processing:   0%|                                      | 0/1519 [00:00<?, ?it/s][A
Processing:  59%|███████████████▍          | 901/1519 [00:00<00:00, 8999.37it/s][A
Dataset parts: 100%|██████████████████████████████| 2/2 [00:00<00:00,  2.58it/s][A


In [8]:
import torch

from lhotse import CutSet
from lhotse.dataset import BucketingSampler
from lhotse.dataset.input_strategies import AudioSamples
from torch.utils.data import DataLoader

class AudioSamplesDataset(torch.utils.data.Dataset):
    def __init__(self):
      self.collator = AudioSamples()

    def __getitem__(self, cuts: CutSet) -> dict:
        audio_padded, audio_lengths = self.collator(cuts)
        return { "audio_padded": audio_padded, "audio_lengths": audio_lengths }

cuts_train = CutSet.from_manifests(**libri["train-clean-5"])

train_sampler = BucketingSampler(
    cuts_train,
    max_duration=60,
    shuffle=True,
    drop_last=True
)

train_loader = DataLoader(
    AudioSamplesDataset(),
    sampler=train_sampler,
    batch_size=None,
    num_workers=1
)

In [9]:
batch = next(iter(train_loader))

batch

{'audio_padded': tensor([[-9.1553e-05,  9.1553e-05, -9.1553e-05,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 2.9907e-03,  3.1738e-03,  3.5706e-03,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-5.0354e-03, -5.1575e-03, -5.1270e-03,  ..., -3.3264e-03,
          -3.2959e-03, -3.3875e-03]]),
 'audio_lengths': tensor([239520, 242560, 243280], dtype=torch.int32)}

In [10]:
batch['audio_padded'].shape

torch.Size([3, 243280])

In [11]:
batch['audio_lengths']

tensor([239520, 242560, 243280], dtype=torch.int32)

In [12]:
import os
from tqdm import tqdm
from multi_quantization import QuantizerTrainer

In [13]:
# Which wav2vec 2/HuBERT transformer layer?
LAYER_OF_INTEREST=24
layer_index=LAYER_OF_INTEREST - 1

activations = []

# Register hook to trigger whenever forward() of nth layer is called
teacher_model.encoder.transformer.layers[layer_index].register_forward_hook(
    # Append outputs to list
    lambda teacher_model, inputs, outputs: activations.append(outputs.detach())
    # lambda teacher_model, inputs, outputs: print(outputs.shape)
  )


final_activations, feat_lens=teacher_model(batch['audio_padded'].to('cuda'), batch["audio_lengths"].to('cuda'))
int_activations=activations[0]
activations.clear()

trainer = QuantizerTrainer(
    dim=int_activations.shape[-1], bytes_per_frame=16, device=torch.device("cuda") #TODO: change dim to automatically get it from shape of the data
)

In [14]:
activations

[]

In [15]:
pbar = tqdm(total=len(cuts_train))

while not trainer.done():

    batch=next(iter(train_loader))

    with torch.no_grad():
        final_activations, feat_lens=teacher_model(batch['audio_padded'].to('cuda'), batch["audio_lengths"].to('cuda'))

    # Subset 0th item from list and clear list
    int_activations=activations[0]
    activations.clear()

    # Retrieve only non-pad frames
    non_pad_activations = []

    for item, final_frame in zip(int_activations, feat_lens):
        non_pad_activations.append(item[:final_frame])

    # Stack non-pad frames (1 frame=1 item for quantizer training)
    quantizer_train_batch = torch.cat(non_pad_activations, dim=0)

    trainer.step(quantizer_train_batch)
    pbar.update(1)

print("Done!")
pbar.close()

  index=(this_indexes//saved_K).expand(*this_indexes.shape[:-1], dim)) +
20001it [5:18:13,  1.05it/s]                                                    

Done!





In [16]:
import torch
quantizer = trainer.get_quantizer()
torch.save(quantizer.state_dict(), '/jagupard28/scr0/xiluo-speech/multi-quantizer-experiments/quantizers/full-layer24-16N-quantizer.pt')