In [1]:
# Training script for the full model.
# The model has six components:
# 1. Encoder module
# 2. Downsampling module
# 3. Frozen Vocabulary module
# 4. Upsampling module 
# 5. Decoder module 
# 6. Frozen neural audio codec module used for generating the ground truth

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 
import torchaudio
from torch.utils.data import DataLoader, Dataset
import os
import numpy as np

# import models
from encoder import Encoder, Downsampling
from vocab import FrozenVocabulary, get_closest_vocab, merge_similar_indices
from decoder import Upsampling, Decoder, calculate_params
from codec import Codec

print("All imports are successful")

# ---- Hyperparameters ----
hidden_dim = 768

# ---- Model components ----
# The construction order is kept exactly as written: each constructor may draw
# from the global RNG while initialising its weights.
encoder = Encoder()  # labelled "frozen" here; a later cell re-enables gradients for layers 10 and 11
downsampling = Downsampling()
vocab = FrozenVocabulary(path="vocab.pth")  # frozen
upsampling = Upsampling(inp_dim=768, hidden_dim=hidden_dim)
decoder = Decoder(
    hidden_dim=hidden_dim,
    out_dim=1024,
    num_blocks=5,
    kernel_size=7,
)
codec = Codec()  # frozen

# Pull the embedding table and both lookup dictionaries off the vocabulary object.
vocab_embeddings = vocab.embeddings
char_to_idx = vocab.char_to_idx
idx_to_char = vocab.idx_to_char
print(idx_to_char)

# Report the size of each newly built module as returned by calculate_params.
print(f"Paraeters of downsampling: {calculate_params(downsampling)}")
print(f"Paraeters of upsampling: {calculate_params(upsampling)}")
print(f"Paraeters of decoder: {calculate_params(decoder)}")

print("Models are initialized")

  from .autonotebook import tqdm as notebook_tqdm


All imports are successful


  return self.fget.__get__(instance, owner)()


{0: '<pad>', 1: ' ', 2: "'", 3: 'A', 4: 'B', 5: 'C', 6: 'D', 7: 'E', 8: 'F', 9: 'G', 10: 'H', 11: 'I', 12: 'J', 13: 'K', 14: 'L', 15: 'M', 16: 'N', 17: 'O', 18: 'P', 19: 'Q', 20: 'R', 21: 'S', 22: 'T', 23: 'U', 24: 'V', 25: 'W', 26: 'X', 27: 'Y', 28: 'Z', 29: '<sil>', 30: '<bos>', 31: '<eos>'}
Paraeters of downsampling: 1.771776
Paraeters of upsampling: 1.771776
Paraeters of decoder: 42.098176
Models are initialized


In [2]:
# Pick the compute device.
# GPU index 2 stays the preferred device (as originally hard-coded), but it is
# only used when that ordinal exists on this machine. With CUDA available but
# fewer than three GPUs, "cuda:2" would fail with an invalid device ordinal,
# so fall back to the first GPU; without CUDA, fall back to the CPU.
if torch.cuda.is_available():
    preferred_gpu = 2
    gpu_index = preferred_gpu if torch.cuda.device_count() > preferred_gpu else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

# Move every component (and the vocabulary embedding table) to the chosen device.
encoder = encoder.to(device)
downsampling = downsampling.to(device)
vocab_embeddings = vocab_embeddings.to(device)
decoder = decoder.to(device)
upsampling = upsampling.to(device)
# The codec wrapper exposes its network as `.model`; only that attribute is moved.
codec.model = codec.model.to(device)

In [44]:
import os

input_manifest = "/raid/home/rajivratn/hemant_rajivratn/librispeech/data/manifest/train-clean-100.tsv"

# Valid duration range, in samples (the "seconds" figures assume 16 kHz audio)
min_duration = 32000  # 2 seconds
max_duration = 128000  # 8 seconds

# Dictionary mapping a group key (see NOTE below) to a list of
# (full_path, duration_in_samples) tuples that passed the duration filter.
filtered_samples_by_speaker = {}

# Single pass over the manifest: the first line is the audio root directory,
# every following line is "<file_name>\t<duration>".
with open(input_manifest, "r") as infile:
    root_dir = infile.readline().strip()
    for line in infile:
        parts = line.strip().split("\t")
        if len(parts) != 2:
            continue  # skip blank or malformed rows
        file_name, duration = parts
        duration = int(duration)

        if not (min_duration <= duration <= max_duration):
            continue

        full_path = os.path.join(root_dir, file_name)
        # NOTE(review): field 1 of the "_"-separated file name is used as the
        # speaker ID. For a name such as "train-clean-100_358_730_56.wav" this
        # yields "358". LibriSpeech's own naming is speaker-chapter-utterance,
        # so please confirm that field 1 (and not field 2) is the speaker in
        # this manifest's naming scheme. Left unchanged because later cells
        # look up the key '358'.
        speaker_id = file_name.split("_")[1]

        filtered_samples_by_speaker.setdefault(speaker_id, []).append((full_path, duration))


In [45]:
# Inspect which group keys were collected by the manifest-filtering cell.
filtered_samples_by_speaker.keys()

dict_keys(['358', '274117', '135914', '15045', '133695', '121082', '140048', '122819', '127786', '15220', '47824', '172357', '156745', '271888', '152918', '283493', '42010', '48852', '147987', '28452', '25947', '186183', '143879', '124992', '135842', '55211', '152257', '132655', '61334', '145015', '13009', '121914', '139310', '7763', '29405', '121119', '67168', '128982', '123349', '129061', '130551', '11217', '359', '152900', '34600', '126305', '135815', '19397', '61803', '29116', '125237', '62556', '410', '92432', '34669', '18515', '131231', '294887', '145724', '123720', '5143', '130739', '102518', '274346', '5371', '186175', '11691', '261139', '130880', '248638', '41615', '121342', '39621', '126791', '12530', '56168', '39938', '274553', '122615', '59157', '130898', '220959', '142933', '172359', '123857', '219', '86737', '132847', '142371', '130746', '145706', '91187', '283452', '122442', '171115', '123719', '123516', '130697', '102519', '41616', '124404', '21625', '105661', '135887',

In [51]:
# Peek at the third stored entry for key '358': a (full_path, duration_in_samples) tuple.
filtered_samples_by_speaker['358'][2]

('/raid/home/rajivratn/hemant_rajivratn/librispeech/data/train/audio/train-clean-100_358_730_56.wav',
 66560)

In [55]:


# Listen to one sample: the file path stored in the third entry for key '358'.
import IPython.display as ipd

# As the last expression of the cell, ipd.Audio renders an inline player for that file.
ipd.Audio(filtered_samples_by_speaker['358'][2][0])

In [40]:
# For every key, report how many clips it holds and their combined length.
# Durations are stored in samples; dividing by 16000 and then 60 yields minutes.
for k, clips in filtered_samples_by_speaker.items():
    t = sum(entry[1] for entry in clips)
    print(f"Speaker {k} has {len(filtered_samples_by_speaker[k])} samples with total duration in minutes: {t/16000/60}")

Speaker 44.wav has 45 samples with total duration in minutes: 3.788334375
Speaker 7.wav has 66 samples with total duration in minutes: 5.697084375
Speaker 6.wav has 65 samples with total duration in minutes: 5.433083333333333
Speaker 26.wav has 61 samples with total duration in minutes: 5.097498958333333
Speaker 81.wav has 12 samples with total duration in minutes: 1.1793333333333333
Speaker 63.wav has 16 samples with total duration in minutes: 1.3421666666666667
Speaker 35.wav has 62 samples with total duration in minutes: 5.043251041666666
Speaker 30.wav has 63 samples with total duration in minutes: 5.137666666666666
Speaker 77.wav has 18 samples with total duration in minutes: 1.3990833333333332
Speaker 85.wav has 12 samples with total duration in minutes: 0.9304166666666667
Speaker 13.wav has 64 samples with total duration in minutes: 5.32208125
Speaker 22.wav has 80 samples with total duration in minutes: 6.3853333333333335
Speaker 21.wav has 67 samples with total duration in min

In [7]:
class AudioDataset(Dataset):
    """Waveform dataset built from a tab-separated manifest file.

    The first line of the manifest is the root directory of the audio files.
    Every following line is expected to be ``<file_name><TAB><duration>``,
    with the duration given as an integer number of samples. Lines that do
    not split into exactly two tab-separated fields are skipped. Entries whose
    duration lies outside ``[min_duration, max_duration]`` are dropped, and
    the remaining entries are sorted by ascending duration.

    Args:
        input_manifest: Path to the manifest file. Defaults to the
            train-clean-100 manifest this notebook was written against.
        min_duration: Smallest accepted duration in samples (inclusive).
            The default 32000 is 2 seconds at 16 kHz.
        max_duration: Largest accepted duration in samples (inclusive).
            The default 128000 is 8 seconds at 16 kHz.

    Attributes:
        dataset: List of ``(full_path, duration)`` tuples, sorted by duration.
    """

    def __init__(
        self,
        input_manifest="/raid/home/rajivratn/hemant_rajivratn/librispeech/data/manifest/train-clean-100.tsv",
        min_duration=32000,
        max_duration=128000,
    ):
        filtered_samples = []

        # Single pass: the header line gives the root directory, the rest are entries.
        with open(input_manifest, "r") as infile:
            root_dir = infile.readline().strip()
            for line in infile:
                parts = line.strip().split("\t")
                if len(parts) != 2:
                    continue  # skip blank or malformed rows
                file_name, duration = parts
                duration = int(duration)

                if min_duration <= duration <= max_duration:
                    full_path = os.path.join(root_dir, file_name)
                    filtered_samples.append((full_path, duration))

        # Optional subset for quick experiments (left disabled by the author):
        # filtered_samples = filtered_samples[:320]

        # Sort by duration (shortest first)
        filtered_samples.sort(key=lambda x: x[1])

        self.dataset = filtered_samples

    def __len__(self):
        """Return the number of entries that passed the duration filter."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Load and return the waveform tensor for entry ``idx``.

        Raises:
            AssertionError: If the file's sampling rate is not 16000 Hz.
        """
        waveform, sample_rate = torchaudio.load(self.dataset[idx][0])
        assert sample_rate == 16000, "Sampling rate must be 16000"
        return waveform

# Instantiate the dataset (the manifest is read and filtered inside AudioDataset.__init__);
# the DataLoader itself is created further down, once collate_fn is defined.
dataset = AudioDataset()

# Collate function: crop every clip in a batch to the shortest clip's length.
def collate_fn(batch):
    """Stack a list of waveforms after truncating them to a common length.

    Each item is indexed as ``waveform[:, :length]``, i.e. time is the second
    axis. All items are cut to the shortest item's length, stacked along a new
    leading batch axis, and axis 1 is squeezed away when it has size 1.

    (A random fixed-length 32000-sample crop was tried here previously and is
    left disabled.)
    """
    shortest = min(clip.shape[1] for clip in batch)
    stacked = torch.stack([clip[:, :shortest] for clip in batch])
    return stacked.squeeze(1)

# Batches of 32 are drawn in shuffled order (shuffle=True), so the duration-sorted order of
# the dataset is not what the loop sees; collate_fn crops each batch to its shortest clip.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=6, collate_fn=collate_fn)

In [8]:
# Freeze the codec weights and the vocabulary embeddings.
# (The encoder's requires_grad flags are set in a separate cell.)
for codec_weight in codec.model.parameters():
    codec_weight.requires_grad = False
vocab_embeddings.requires_grad = False

# Put the trainable modules in train mode and the codec network in eval mode.
for trainable_module in (downsampling, decoder, upsampling):
    trainable_module.train()
codec.model.eval()
1  # bare literal as the last expression, so the cell displays `1` rather than eval()'s return value

1

In [9]:
# Put the encoder in train mode, then enable gradients only for parameters whose
# name contains one of the two layer markers below; every other encoder parameter
# is frozen. (To freeze the whole encoder instead, set requires_grad = False for all.)
encoder.train()
trainable_markers = ("model.encoder.layers.10", "model.encoder.layers.11")
for param_name, param_tensor in encoder.named_parameters():
    param_tensor.requires_grad = any(marker in param_name for marker in trainable_markers)

In [10]:
# Adam over the downsampling, decoder, upsampling and encoder parameters.
# All encoder parameters are passed in, including those whose requires_grad was
# switched off in the cell above; the commented-out variant excludes the encoder.
optimizer = optim.Adam(
    # list(downsampling.parameters()) + list(decoder.parameters()) + list(upsampling.parameters()),
    list(downsampling.parameters()) + list(decoder.parameters()) + list(upsampling.parameters()) + list(encoder.parameters()),
    lr=0.0005)

# Training loop.
# NOTE: the loop variables `waveform`, `indices` and `decoder_output` are read by
# later cells after the loop stops, so their names must stay as they are.
num_epochs = 10000
for epoch in range(num_epochs):
    optimizer.zero_grad()
    # Sum of per-iteration losses within this epoch; reset at every epoch start.
    running_loss = 0.0
    for iteration, waveform in enumerate(dataloader):
        # Move the collated batch to the training device
        waveform = waveform.to(device) 

        # Forward pass: encoder -> downsampling
        encoder_output = encoder(waveform)
        downsampling_output = downsampling(encoder_output) # torch.Size([32, 768, 172]) (shape noted by the author for one batch)

        # Look up the closest vocabulary embeddings; the helper returns a
        # (commitment_loss, vocab_output, indices) triple.
        # NOTE(review): presumably a vector-quantization style nearest-neighbour lookup — confirm in vocab.py
        commitment_loss, vocab_output, indices = get_closest_vocab(downsampling_output, vocab_embeddings)

        # Upsampling
        upsampling_output = upsampling(vocab_output)
        # Decoder
        decoder_output = decoder(upsampling_output) # torch.Size([32, 1024, 172]) (shape noted by the author for one batch)

        # Regression target from the frozen codec; computed without gradient tracking
        with torch.no_grad():
            codec_output = codec.encode(waveform).detach()

        # Crop both tensors along the last axis to the shorter of the two lengths
        min_seq_len = min(codec_output.shape[-1], decoder_output.shape[-1])    
        codec_output = codec_output[:, :, :min_seq_len]
        decoder_output = decoder_output[:, :, :min_seq_len]    

        # Total loss = MSE between decoder output and codec target + commitment loss
        l2_loss = F.mse_loss(decoder_output, codec_output)
        # commitment_loss *= 100
        loss =  l2_loss + commitment_loss

        # Backward pass
        loss.backward()
        # Update weights
        optimizer.step()
        # Release cached CUDA memory, then clear gradients for the next iteration
        torch.cuda.empty_cache()  
        optimizer.zero_grad()

        running_loss += loss.item()

        # Every 10 iterations: print the first sequence's indices, the running
        # mean loss for this epoch so far, and the current batch's two loss terms
        if iteration % 10 == 0:
            print(f"Indices: {indices[0]}")
            print(f"Epoch: {epoch}, Iteration: {iteration}/{len(dataloader)}, Loss: {running_loss/(iteration+1)}, commit_loss: {commitment_loss.item()}, l2_loss: {l2_loss.item()}")

Indices: tensor([20, 20,  9,  4,  7, 10, 20,  5,  4, 20,  5, 16, 16, 10,  7, 19, 20, 20,
        20, 24, 28,  4,  4, 20, 13,  9,  5, 26,  4, 15,  5,  3, 23,  5,  5,  4,
         9,  5,  4, 22, 19, 20, 20,  9, 22, 22, 20, 20,  5,  5, 20, 20, 20, 20,
         9,  5, 15, 20, 10, 20,  5,  5,  5,  4,  4, 20,  5,  5,  5, 14, 20, 20,
         5,  5], device='cuda:2')
Epoch: 0, Iteration: 0/120, Loss: 12.06057071685791, commit_loss: 1.836810827255249, l2_loss: 10.223759651184082
Indices: tensor([17, 20, 20, 20, 20, 20, 20, 20, 20, 20,  4, 17, 20, 20, 20,  7,  4, 20,
        20, 20, 20,  7,  7,  5,  5,  5,  7,  4, 20, 20, 20, 20,  7,  7,  7,  5,
         5, 20, 20,  5,  5,  5,  5, 20, 20, 20, 20,  7,  7,  7,  7],
       device='cuda:2')
Epoch: 0, Iteration: 10/120, Loss: 7.983944632790306, commit_loss: 0.7960111498832703, l2_loss: 5.751102447509766
Indices: tensor([17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17,  5, 20,  7, 17,
        17,  5,  5,  5, 20, 20, 20, 20,  7,  7,  7,  7,  5,

KeyboardInterrupt: 

In [11]:
# Show the vocabulary indices of the first sequence from the most recent training batch.
indices[0]

tensor([17,  5,  5,  5,  5,  5,  5, 17, 17, 17,  5,  5,  7,  7,  7,  7, 20, 20,
        20, 20, 20,  4,  4,  4, 20,  7,  4,  4,  4, 20, 20, 20,  4,  4,  5,  5,
         5, 20, 20, 20, 20, 20,  4,  4, 20, 20,  5, 20, 20, 20, 20,  5, 20,  4,
        17, 20,  4,  4,  4,  4, 17, 17, 17,  5,  5,  5,  5,  5,  5,  5,  5,  5],
       device='cuda:2')

In [12]:
# Run the latest indices through merge_similar_indices and spell out the first
# resulting sequence using the idx_to_char table.
# NOTE(review): merge_similar_indices presumably collapses repeated neighbours — confirm in vocab.py
ind = indices
first_sequence = merge_similar_indices(ind)[0]
"".join(idx_to_char[i] for i in first_sequence)  # optionally: .replace("<sil>", " ")

'OCOCERBREBRBCRBRCRCRBORBOC'

In [13]:
import IPython.display as ipd

# Decode the first item of the latest decoder output back to audio with the codec,
# then bring the result to the CPU as a NumPy array.
first_prediction = decoder_output[:1]
y = codec.model.decode(first_prediction).detach().cpu().numpy()

# Render an inline player for the first channel of the first item at 16 kHz.
ipd.Audio(y[0, 0], rate=16000)

In [14]:
import IPython.display as ipd

# Play the first waveform of the batch currently bound to `waveform`, for comparison
# with the decoded prediction.
reference_audio = waveform[0].detach().cpu().numpy()
ipd.Audio(reference_audio, rate=16000)