In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 
import torchaudio
from torch.utils.data import DataLoader, Dataset
import os
import numpy as np

In [2]:
class AudioDataset(Dataset):
    def __init__(self):
        input_manifest = "/raid/home/rajivratn/hemant_rajivratn/librispeech/data/manifest/train-clean-100.tsv"

        # Read the first line to get the root directory
        with open(input_manifest, "r") as infile:
            root_dir = infile.readline().strip()  # First line is the root directory

        # Define valid duration range
        min_duration = 32000  # 2 seconds
        max_duration = 250000  # 15.625 seconds

        # Dictionary to store filtered samples per speaker
        filtered_samples_by_speaker = {}

        with open(input_manifest, "r") as infile:
            infile.readline()  # Skip header (already read root_dir)
            for line in infile:
                parts = line.strip().split("\t")
                if len(parts) != 2:
                    continue
                file_name, duration = parts
                duration = int(duration)

                if min_duration <= duration <= max_duration:
                    full_path = os.path.join(root_dir, file_name)
                    speaker_id = file_name.split("_")[1]  # Extract speaker ID
                    
                    if speaker_id not in filtered_samples_by_speaker:
                        filtered_samples_by_speaker[speaker_id] = []
                    
                    filtered_samples_by_speaker[speaker_id].append((full_path, duration))

        self.diff_speakers = len(filtered_samples_by_speaker)
        print(f"Total speakers: {self.diff_speakers}")
        # a tuple with path, speaker, duration
        filtered_samples = []
        count = 0
        for k in filtered_samples_by_speaker:
            count += 1
            for i in filtered_samples_by_speaker[k]:
                filtered_samples.append((i[0], count, i[1]))
            #     break
            # if len(filtered_samples) == 40: 
            # break
            
        print(f"Total samples: {len(filtered_samples)}, Total speakers: {count}")
        # Sort by duration
        filtered_samples.sort(key=lambda x: x[-1])

        self.dataset = filtered_samples
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        path, speaker, duration = self.dataset[idx]
        waveform, sample_rate = torchaudio.load(path)
        assert sample_rate == 16000, "Sampling rate must be 16000"
        return waveform, speaker, duration

# Create the dataset and dataloader
dataset = AudioDataset()

# create a collate function to truncate the audio files to minimum length
def collate_fn(batches):
    min_dur = min([batch[2] for batch in batches])
    batch = [batch[0][:, :min_dur] for batch in batches]
    batch = torch.stack(batch)
    return batch.squeeze(1), [batch[1] for batch in batches]

dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=1, collate_fn=collate_fn)

Total speakers: 585
Total samples: 24473, Total speakers: 585


In [3]:
# import models
from encoder import Encoder, Downsampling
from vocab import FrozenVocabulary, get_closest_vocab, merge_similar_indices
from decoder import Upsampling, Decoder, calculate_params
from codec import Codec

print("All imports are successful")

class Spk_Embed(nn.Module):
    def __init__(self, num_speakers=100, spk_embed_dim=256):
        super(Spk_Embed, self).__init__()
        self.spk_embed = nn.Embedding(num_speakers, spk_embed_dim)
        
    def forward(self, speaker):
        return self.spk_embed(speaker)
    
# params
hidden_dim = 256
spk_embed_dim = 768
num_speakers = dataset.diff_speakers

# models 
spk_embed = Spk_Embed(num_speakers=num_speakers, spk_embed_dim=spk_embed_dim)
encoder = Encoder() # frozen
downsampling = Downsampling()
vocab = FrozenVocabulary(path="vocab.pth") # frozen
upsampling = Upsampling(inp_dim=int(768+spk_embed_dim), hidden_dim=hidden_dim)
decoder = Decoder(hidden_dim=hidden_dim, out_dim=1024, num_blocks=5, kernel_size=11)
codec = Codec() # frozen
vocab_embeddings, char_to_idx, idx_to_char = vocab.embeddings, vocab.char_to_idx, vocab.idx_to_char

print(idx_to_char)
print(f"Paraeters of spk_embed: {calculate_params(spk_embed)}")
print(f"Paraeters of downsampling: {calculate_params(downsampling)}")
print(f"Paraeters of upsampling: {calculate_params(upsampling)}")
print(f"Paraeters of decoder: {calculate_params(decoder)}")

print("Models are initialized")


# Set the models to gpu
device = torch.device("cuda")
encoder = encoder.to(device)
downsampling = downsampling.to(device)
vocab_embeddings = vocab_embeddings.to(device)
decoder = decoder.to(device)
upsampling = upsampling.to(device)
codec.model = codec.model.to(device)
spk_embed = spk_embed.to(device)

# freeze the encoder, and codec
for param in codec.model.parameters():
    param.requires_grad = False   
vocab_embeddings.requires_grad = False

# Training loop
downsampling.train()
decoder.train()
upsampling.train()
codec.model.eval()
spk_embed.train()
1

  from .autonotebook import tqdm as notebook_tqdm


All imports are successful


  return self.fget.__get__(instance, owner)()


{0: '<pad>', 1: ' ', 2: "'", 3: 'A', 4: 'B', 5: 'C', 6: 'D', 7: 'E', 8: 'F', 9: 'G', 10: 'H', 11: 'I', 12: 'J', 13: 'K', 14: 'L', 15: 'M', 16: 'N', 17: 'O', 18: 'P', 19: 'Q', 20: 'R', 21: 'S', 22: 'T', 23: 'U', 24: 'V', 25: 'W', 26: 'X', 27: 'Y', 28: 'Z', 29: '<sil>', 30: '<bos>', 31: '<eos>'}
Paraeters of spk_embed: 0.44928
Paraeters of downsampling: 1.771008
Paraeters of upsampling: 1.179904
Paraeters of decoder: 7.478272
Models are initialized


1

In [4]:
# Set the models to training mode
encoder.train()
for param in encoder.named_parameters():
    param[1].requires_grad = False
    continue
    if "model.encoder.layers.8" in param[0] or "model.encoder.layers.11" in param[0]:
        param[1].requires_grad = True
    else:
        param[1].requires_grad = False
        
optimizer = optim.Adam(
    list(downsampling.parameters()) + list(decoder.parameters()) + list(upsampling.parameters()) + list(spk_embed.parameters()),
    # list(downsampling.parameters()) + list(decoder.parameters()) + list(upsampling.parameters()) + list(encoder.parameters()) + list(spk_embed.parameters()),
    lr=0.0005)

In [10]:
encoder.model

HubertModel(
  (feature_extractor): HubertFeatureEncoder(
    (conv_layers): ModuleList(
      (0): HubertGroupNormConvLayer(
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (activation): GELUActivation()
        (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
      )
      (1-4): 4 x HubertNoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (5-6): 2 x HubertNoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
    )
  )
  (feature_projection): HubertFeatureProjection(
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (projection): Linear(in_features=512, out_features=768, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): HubertEncoder(
    (pos_conv_embed): HubertPositionalConvEmbedding(
      (conv): Para

In [21]:
encoder.model.feature_extractor(waveform)

tensor([[[ 1.1097e-03,  1.0716e-03,  8.7021e-04,  ...,  7.2552e-04,
           2.0577e-03,  1.5775e-03],
         [ 1.2447e-03,  1.2546e-03,  1.0578e-03,  ...,  7.4357e-04,
           2.1386e-03,  1.7433e-03],
         [ 1.2679e-03,  1.2649e-03,  1.0902e-03,  ...,  7.2007e-04,
           2.1590e-03,  1.7740e-03],
         ...,
         [ 3.4858e-02,  2.0660e-02,  2.3154e-02,  ...,  2.3416e-02,
           3.1290e-02,  7.9780e-03],
         [ 1.3558e-03,  1.3255e-03,  1.2078e-03,  ...,  7.9715e-04,
           2.1935e-03,  1.8330e-03],
         [ 1.1462e-03,  1.2029e-03,  8.8880e-04,  ...,  5.9375e-04,
           2.0590e-03,  1.6928e-03]],

        [[ 1.0257e-03,  1.1311e-03,  7.5927e-04,  ...,  6.2689e-04,
           1.6387e-03,  2.7738e-03],
         [ 1.2818e-03,  1.2915e-03,  1.0284e-03,  ...,  8.5428e-04,
           1.8988e-03,  2.9862e-03],
         [ 1.2790e-03,  1.3763e-03,  1.0400e-03,  ...,  8.1125e-04,
           1.8305e-03,  2.9647e-03],
         ...,
         [ 8.0596e-03,  2

In [22]:
encoder.model.feature_projection(encoder.model.feature_extractor(waveform).transpose(1,2))

tensor([[[ 0.1857,  0.0769,  0.3977,  ..., -0.8687,  0.2623,  0.0000],
         [ 0.0000,  0.2378,  0.5511,  ..., -1.8676,  0.6503,  0.0643],
         [ 0.0884,  0.0252,  0.4305,  ..., -1.1239,  0.4569,  0.0000],
         ...,
         [-0.6440, -0.0602,  0.4956,  ...,  1.6554, -0.2846, -0.0311],
         [ 0.2597,  0.1765,  0.0596,  ..., -0.3161, -0.5654,  0.2011],
         [-0.0552, -0.1849, -0.3153,  ...,  0.3303, -0.0757,  0.6304]],

        [[-0.1254, -0.0303,  0.0689,  ..., -0.4109,  0.1515,  0.3123],
         [-0.4545, -0.1595,  0.2852,  ..., -0.8969,  0.7662, -0.0000],
         [-1.3589, -0.2886,  0.3215,  ..., -0.1405, -0.0090, -0.2662],
         ...,
         [ 0.7510, -0.0685, -0.2636,  ..., -0.2571,  0.0000, -0.6380],
         [ 1.1083,  0.0932,  0.0091,  ..., -0.1345,  0.7107, -1.4082],
         [ 1.6174,  0.2708, -0.0517,  ..., -0.9167,  1.6382, -0.0056]],

        [[ 0.3804, -0.0479,  0.0997,  ..., -1.2887, -0.2636,  0.3358],
         [ 0.2789, -0.1658, -0.0662,  ..., -0

In [20]:
encoder.model.encoder(encoder.model.feature_projection(encoder.model.feature_extractor(waveform).transpose(1,2)))['last_hidden_state']

tensor([[[ 3.2652e-02,  1.5366e-01,  2.5070e-01,  ..., -2.7075e-01,
           5.5647e-03, -5.6789e-01],
         [ 2.3165e-01,  4.2435e-02,  1.5181e-01,  ..., -8.0031e-02,
           2.1582e-02, -8.7625e-01],
         [ 4.6178e-02,  3.1645e-01,  8.7990e-02,  ..., -3.1208e-01,
           5.1273e-02, -1.1649e+00],
         ...,
         [ 4.5507e-03,  1.5237e-01, -1.6684e-01,  ..., -2.7230e-01,
          -7.5107e-02, -1.0434e+00],
         [ 4.0861e-03,  3.1525e-01,  9.1576e-02,  ..., -2.1161e-01,
          -8.1003e-02,  6.7370e-01],
         [ 3.0743e-01,  2.6195e-01,  6.8893e-03,  ...,  1.9722e-04,
          -7.1389e-02,  6.9221e-01]],

        [[ 9.6288e-03,  1.2530e-01,  2.9277e-01,  ...,  7.2687e-02,
          -7.4341e-02, -4.0891e-02],
         [ 4.4819e-02, -3.3666e-02,  3.1662e-01,  ..., -2.2837e-02,
           7.7811e-02, -5.3007e-01],
         [-1.0689e-01, -1.7001e-01,  5.6261e-01,  ...,  1.8448e-01,
          -2.3359e-01, -2.9243e-01],
         ...,
         [-1.6722e-01,  1

In [11]:
def merge_tensors(t1, t2):
    t2 = t2.unsqueeze(-1)  # Reshape to (batch, features, 1)
    return torch.cat([t1, t2.expand(-1, -1, t1.shape[-1])], dim=1)

# start training
num_epochs = 10000
for epoch in range(num_epochs):
    optimizer.zero_grad()
    running_loss = 0.0
    for iteration, data in enumerate(dataloader):
        # data
        waveform, speaker = data
        waveform = waveform.to(device) 
        speaker = torch.tensor(speaker).to(device)
        
        # Forward pass
        with torch.no_grad():
            encoder_output = encoder(waveform)
        downsampling_output = downsampling(encoder_output) # torch.Size([32, 768, 172])
        # Get the closest vocab embeddings
        commitment_loss, vocab_output, indices = get_closest_vocab(downsampling_output, vocab_embeddings)
        
        # add speaker embeddings
        speaker = spk_embed(speaker)
        vocab_output = merge_tensors(vocab_output, speaker)

        
        # Upsampling
        upsampling_output = upsampling(vocab_output)
        # Decoder
        decoder_output = decoder(upsampling_output).contiguous() # torch.Size([32, 1024, 172])
        
        # Codec
        with torch.no_grad():
            codec_output = codec.encode(waveform).detach().contiguous()
        
        # Ensure same sequence length for ground truth and output
        min_seq_len = min(codec_output.shape[-1], decoder_output.shape[-1])    
        codec_output = codec_output[:, :, :min_seq_len]
        decoder_output = decoder_output[:, :, :min_seq_len]    

        # Compute the loss
        l2_loss = F.mse_loss(decoder_output, codec_output)
        commitment_loss *= 10
        loss =  l2_loss + commitment_loss
        
        # Backward pass
        loss.backward()
        # Update weights
        optimizer.step()
        # empty cache
        torch.cuda.empty_cache()  
        optimizer.zero_grad()
        
        running_loss += loss.item()
        
        # print for every 10 iterations
        if iteration % 10 == 0:
            print(f"Indices: {indices[0]}")
            print(f"Epoch: {epoch}, Iteration: {iteration}/{len(dataloader)}, Loss: {running_loss/(iteration+1)}, commit_loss: {commitment_loss.item()}, l2_loss: {l2_loss.item()}")

Indices: tensor([20, 20, 20, 20, 20,  5, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
        20, 20,  4,  5,  4, 20, 20,  5, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
        20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,  5, 20, 20, 20,
        20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
        20, 20, 20, 20, 20, 20, 20, 20, 20,  5,  5, 20], device='cuda:0')
Epoch: 0, Iteration: 0/765, Loss: 17.163265228271484, commit_loss: 9.549200057983398, l2_loss: 7.614066123962402


KeyboardInterrupt: 

In [None]:
ind = indices
"".join([idx_to_char[i] for i in merge_similar_indices(ind)[0]]) #.replace("<sil>", " ")

In [None]:
# Decode audio signal
y = codec.model.decode(decoder_output[1:2,:,:]).cpu().detach().numpy()

# play the numpy array as audio using ipython.display.Audio
import IPython.display as ipd
ipd.Audio(y[0,0,:], rate=16000)  # load a NumPy array

In [None]:
import IPython.display as ipd
ipd.Audio(waveform[1,:].cpu().detach().numpy(), rate=16000)  # load a NumPy array