In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 
import torchaudio
from torch.utils.data import DataLoader, Dataset
import os
import numpy as np

import warnings
warnings.simplefilter("ignore")

In [13]:
class AudioDataset(Dataset):
    """Duration-filtered audio dataset built from a tab-separated manifest.

    Manifest layout (as read by ``__init__``):
      * first line: root directory of the audio files;
      * every following line: ``<relative file name>\t<duration in samples>``.
        Lines that do not split into exactly two tab-separated fields are ignored.

    Only clips whose duration lies in ``[min_duration, max_duration]`` are kept.
    Each kept clip is stored in ``self.dataset`` as ``(path, speaker_index, duration)``
    and the list is sorted by ascending duration.

    Speaker indices are zero-based (``0 .. diff_speakers - 1``) in order of first
    appearance in the manifest, so they can be fed directly to an
    ``nn.Embedding(diff_speakers, ...)``; a one-based label would make the last
    speaker index fall outside such a table.

    Args:
        manifest_path: Manifest to read; ``None`` uses ``DEFAULT_MANIFEST``.
        min_duration: Minimum clip length in samples (32000 = 2 s at 16 kHz).
        max_duration: Maximum clip length in samples (250000 = 15.625 s at 16 kHz).

    Attributes set:
        diff_speakers: Number of distinct speakers that survived the filter.
        dataset: List of ``(path, speaker_index, duration)`` tuples.
    """

    # Manifest used when no path is given, so ``AudioDataset()`` keeps working.
    DEFAULT_MANIFEST = "/raid/home/rajivratn/hemant_rajivratn/librispeech/data/manifest/train-clean-100.tsv"

    def __init__(self, manifest_path=None, min_duration=32000, max_duration=250000):
        input_manifest = self.DEFAULT_MANIFEST if manifest_path is None else manifest_path

        # speaker id -> list of (full_path, duration); dict order = first appearance
        filtered_samples_by_speaker = {}

        with open(input_manifest, "r") as infile:
            # First line is the root directory, the rest are the sample rows.
            root_dir = infile.readline().strip()
            for line in infile:
                parts = line.strip().split("\t")
                if len(parts) != 2:
                    continue
                file_name, duration = parts
                duration = int(duration)

                if min_duration <= duration <= max_duration:
                    full_path = os.path.join(root_dir, file_name)
                    # Speaker id is the second "_"-separated token of the file name.
                    speaker_id = file_name.split("_")[1]
                    filtered_samples_by_speaker.setdefault(speaker_id, []).append((full_path, duration))

        self.diff_speakers = len(filtered_samples_by_speaker)
        print(f"Total speakers: {self.diff_speakers}")

        # Flatten to (path, speaker_index, duration) with zero-based speaker indices.
        filtered_samples = []
        for speaker_index, samples in enumerate(filtered_samples_by_speaker.values()):
            for full_path, duration in samples:
                filtered_samples.append((full_path, speaker_index, duration))

        print(f"Total samples: {len(filtered_samples)}, Total speakers: {self.diff_speakers}")
        # Sort by duration
        filtered_samples.sort(key=lambda x: x[-1])

        self.dataset = filtered_samples

    def __len__(self):
        """Return the number of clips that passed the duration filter."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Load one clip.

        Returns:
            ``(waveform, speaker_index, duration)`` where ``waveform`` is the loaded
            audio with its leading dimension squeezed away when that dimension is 1,
            and ``duration`` is the manifest value (not re-measured here).

        Raises:
            AssertionError: If the file's sample rate is not 16000.
        """
        path, speaker, duration = self.dataset[idx]
        waveform, sample_rate = torchaudio.load(path)
        assert sample_rate == 16000, "Sampling rate must be 16000"
        return waveform.squeeze(0), speaker, duration

# Create the dataset and dataloader
# Instantiating AudioDataset reads and filters the manifest and prints the
# speaker / sample counts shown in this cell's output.
dataset = AudioDataset()

# Collate function: crop every waveform in a batch to the shortest clip so they can be stacked
def collate_fn(batches):
    """Crop a batch of clips to a common length and stack them.

    Args:
        batches: Sequence of ``(waveform, speaker, duration)`` items as produced
            by ``AudioDataset.__getitem__``.

    Returns:
        ``(waveforms, speakers)`` where ``waveforms`` is the stacked, cropped audio
        of shape ``(bsz, seq_len)`` for 1-D inputs and ``speakers`` is an integer
        tensor of shape ``(bsz, 1)``.
    """
    # Bound the crop length by the tensors actually loaded as well as by the
    # manifest duration: if a manifest entry overstates a clip's length, slicing
    # would otherwise leave ragged tensors and torch.stack would fail.
    min_dur = min(min(duration, waveform.shape[0]) for waveform, _, duration in batches)
    waveforms = [waveform[:min_dur] for waveform, _, _ in batches]
    speakers = [speaker for _, speaker, _ in batches]
    return torch.stack(waveforms), torch.tensor(speakers).unsqueeze(1)  # (bsz, seq_len) and (bsz, 1)

# Batches of two clips, reshuffled every epoch, loaded by one worker process and
# cropped/stacked by collate_fn.
# NOTE(review): AudioDataset sorts its samples by duration, but shuffle=True draws
# them in random order, so the two clips in a batch can differ greatly in length and
# the longer one is cropped to the shorter — confirm this is intended.
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=1, collate_fn=collate_fn)

Total speakers: 585
Total samples: 24473, Total speakers: 585


In [None]:
# Import the model components used below.
# NOTE(review): encoder / vocab / decoder / codec are presumably project-local
# modules next to this notebook — confirm they are on the import path.
from encoder import Encoder, Downsampling
from vocab import FrozenVocabulary, get_closest_vocab, merge_similar_indices
from decoder import Upsampling, Decoder, calculate_params
from codec import Codec

# Progress marker: reaching this line means none of the imports above raised.
print("All imports are successful")

class Spk_Embed(nn.Module):
    """Learnable lookup table that maps integer speaker indices to vectors.

    Args:
        num_speakers: Number of rows in the table; valid indices are
            ``0 .. num_speakers - 1``.
        spk_embed_dim: Size of each speaker vector.
    """

    def __init__(self, num_speakers=100, spk_embed_dim=256):
        super().__init__()
        self.spk_embed = nn.Embedding(
            num_embeddings=num_speakers,
            embedding_dim=spk_embed_dim,
        )

    def forward(self, speaker):
        """Return the embedding for each index in ``speaker`` (shape ``speaker.shape + (spk_embed_dim,)``)."""
        embedded = self.spk_embed(speaker)
        return embedded
    
# params
hidden_dim = 256
spk_embed_dim = 256
# One embedding row per speaker kept by the dataset filter.
# NOTE(review): the table has exactly num_speakers rows, so every speaker label
# produced by the dataset must lie in [0, num_speakers - 1].
num_speakers = dataset.diff_speakers

# models 
spk_embed = Spk_Embed(num_speakers=num_speakers, spk_embed_dim=spk_embed_dim)
encoder = Encoder() # frozen
downsampling = Downsampling()
vocab = FrozenVocabulary(path="vocab.pth") # frozen
# Upsampling input width = 768 + speaker embedding width, because the speaker
# vector is concatenated onto the vocab output along the channel axis before
# upsampling (presumably 768 is the vocab/downsampling feature width — TODO confirm).
upsampling = Upsampling(inp_dim=int(768+spk_embed_dim), hidden_dim=hidden_dim)
decoder = Decoder(hidden_dim=hidden_dim, out_dim=1024, num_blocks=5, kernel_size=11)
codec = Codec() # frozen
vocab_embeddings, char_to_idx, idx_to_char = vocab.embeddings, vocab.char_to_idx, vocab.idx_to_char

# Show the vocabulary mapping and the size of each trainable component.
# NOTE(review): judging by the recorded output, calculate_params appears to report
# counts in millions — confirm against its definition.
print(idx_to_char)
print(f"Paraeters of spk_embed: {calculate_params(spk_embed)}")
print(f"Paraeters of downsampling: {calculate_params(downsampling)}")
print(f"Paraeters of upsampling: {calculate_params(upsampling)}")
print(f"Paraeters of decoder: {calculate_params(decoder)}")

print("Models are initialized")


# Move the models to the compute device: the GPU when one is available, otherwise
# the CPU, so the cell does not fail outright on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = encoder.to(device)
downsampling = downsampling.to(device)
vocab_embeddings = vocab_embeddings.to(device)
decoder = decoder.to(device)
upsampling = upsampling.to(device)
codec.model = codec.model.to(device)
spk_embed = spk_embed.to(device)

# Freeze the encoder, the codec and the vocabulary embeddings: none of them is
# handed to the optimizer, so they must not accumulate gradients either.
for param in encoder.parameters():
    param.requires_grad = False
for param in codec.model.parameters():
    param.requires_grad = False
vocab_embeddings.requires_grad = False

# Module modes: trainable parts in train mode, the codec in eval mode.
downsampling.train()
decoder.train()
upsampling.train()
codec.model.eval()
spk_embed.train()
# Bare literal as the cell's last expression; presumably here so the notebook
# echoes `1` instead of the module repr returned by .train().
1

  from .autonotebook import tqdm as notebook_tqdm


All imports are successful


  return self.fget.__get__(instance, owner)()


{0: '<pad>', 1: ' ', 2: "'", 3: 'A', 4: 'B', 5: 'C', 6: 'D', 7: 'E', 8: 'F', 9: 'G', 10: 'H', 11: 'I', 12: 'J', 13: 'K', 14: 'L', 15: 'M', 16: 'N', 17: 'O', 18: 'P', 19: 'Q', 20: 'R', 21: 'S', 22: 'T', 23: 'U', 24: 'V', 25: 'W', 26: 'X', 27: 'Y', 28: 'Z', 29: '<sil>', 30: '<bos>', 31: '<eos>'}
Paraeters of spk_embed: 0.14976
Paraeters of downsampling: 4.130304
Paraeters of upsampling: 0.786688
Paraeters of decoder: 7.478272
Models are initialized


1

In [None]:
# Freeze every encoder weight: the encoder is used only as a fixed feature extractor
# and is not part of the optimizer below.
# NOTE(review): the encoder is still switched to train() mode here, which keeps any
# dropout / normalisation layers in their training behaviour even though no weight
# is updated — confirm this is intended rather than encoder.eval().
encoder.train()
for param in encoder.parameters():
    param.requires_grad = False
# To fine-tune selected encoder layers instead (an earlier experiment targeted
# "model.encoder.layers.8" and "model.encoder.layers.11"), re-enable requires_grad
# on those parameters and add encoder.parameters() to the optimizer.

# Adam over the trainable modules only: downsampling, decoder, upsampling and the
# speaker embedding table.
optimizer = optim.Adam(
    list(downsampling.parameters()) + list(decoder.parameters()) + list(upsampling.parameters()) + list(spk_embed.parameters()),
    lr=0.0005)

In [None]:
def merge_tensors(t1, t2):
    """Append a per-item vector to every time step of a feature sequence.

    Args:
        t1: Tensor of shape ``(batch, features, time)``.
        t2: Tensor of shape ``(batch, extra)``, or ``(batch, 1, extra)`` as produced
            by an embedding lookup on ``(batch, 1)`` indices.

    Returns:
        Tensor of shape ``(batch, features + extra, time)``: ``t1`` with ``t2``
        repeated along the time axis and concatenated on the feature axis.
    """
    # An embedding of (batch, 1) indices is (batch, 1, extra); drop the singleton
    # axis, otherwise the unsqueeze below yields a 4-D tensor that cannot be
    # expanded with three sizes.
    if t2.dim() == 3 and t2.shape[1] == 1:
        t2 = t2.squeeze(1)
    t2 = t2.unsqueeze(-1)  # Reshape to (batch, extra, 1)
    return torch.cat([t1, t2.expand(-1, -1, t1.shape[-1])], dim=1)

# start training
# Each step: frozen encoder -> downsampling -> nearest vocab embedding -> append the
# speaker embedding -> upsampling -> decoder, trained to match the frozen codec's
# encoding of the same waveform (MSE) plus the vocabulary commitment loss.
num_epochs = 10000
for epoch in range(num_epochs):
    optimizer.zero_grad()
    running_loss = 0.0
    for iteration, data in enumerate(dataloader):
        # data
        waveform, speaker = data
        waveform = waveform.to(device)
        # `speaker` is already a tensor (built by collate_fn); move it rather than
        # re-wrapping it with torch.tensor(), which copies and warns.
        speaker = speaker.to(device)

        # Forward pass (the encoder is frozen, so no graph is built for it)
        with torch.no_grad():
            encoder_output = encoder(waveform)
        downsampling_output = downsampling(encoder_output)  # (batch, 768, frames) per an earlier run — TODO confirm
        # Get the closest vocab embeddings
        commitment_loss, vocab_output, indices = get_closest_vocab(downsampling_output, vocab_embeddings)

        # Add speaker embeddings. collate_fn yields speaker indices of shape (bsz, 1),
        # so the lookup is (bsz, 1, dim); flatten it to the (bsz, dim) layout that
        # merge_tensors documents for its second argument.
        speaker_embedding = spk_embed(speaker).flatten(start_dim=1)
        vocab_output = merge_tensors(vocab_output, speaker_embedding)

        # Upsampling
        upsampling_output = upsampling(vocab_output)
        # Decoder
        decoder_output = decoder(upsampling_output).contiguous()  # (batch, 1024, frames) per an earlier run — TODO confirm

        # Codec target (frozen)
        with torch.no_grad():
            codec_output = codec.encode(waveform).detach().contiguous()

        # Ensure same sequence length for ground truth and output
        min_seq_len = min(codec_output.shape[-1], decoder_output.shape[-1])
        codec_output = codec_output[:, :, :min_seq_len]
        decoder_output = decoder_output[:, :, :min_seq_len]

        # Compute the loss
        l2_loss = F.mse_loss(decoder_output, codec_output)
        # commitment_loss *= 10
        loss = l2_loss + commitment_loss

        # Backward pass
        loss.backward()
        # Update weights
        optimizer.step()
        # empty cache
        torch.cuda.empty_cache()
        optimizer.zero_grad()

        running_loss += loss.item()

        # print for every 10 iterations
        if iteration % 10 == 0:
            print(f"Indices: {indices[0]}")
            print(f"Epoch: {epoch}, Iteration: {iteration}/{len(dataloader)}, Loss: {running_loss/(iteration+1)}, commit_loss: {commitment_loss.item()}, l2_loss: {l2_loss.item()}")

RuntimeError: shape '[32, 66]' is invalid for input of size 6336

In [None]:
# Inspect the shape of the downsampling output left over from the last training step.
# NOTE(review): the recorded output below shows a tuple of two shapes, which this
# single expression cannot produce — the output looks stale; re-run the cell.
downsampling_output.shape

(torch.Size([32, 768, 66]), torch.Size([29, 256]))

In [None]:
# Render the first item's predicted vocab indices as text.
ind = indices
first_item_tokens = merge_similar_indices(ind)[0]
# Optionally post-process with .replace("<sil>", " ") to show silences as spaces.
"".join(idx_to_char[token] for token in first_item_tokens)

In [None]:
# play the numpy array as audio using ipython.display.Audio
import IPython.display as ipd

# Decode audio signal: run the codec decoder on the second item of the last batch
# (slice 1:2 keeps the batch axis) and bring the result back as a NumPy array.
second_item_features = decoder_output[1:2, :, :]
decoded = codec.model.decode(second_item_features)
y = decoded.cpu().detach().numpy()

ipd.Audio(y[0, 0, :], rate=16000)  # load a NumPy array

In [None]:
import IPython.display as ipd

# Play the original waveform of the second item of the last batch, for comparison
# with the decoded audio.
reference_clip = waveform[1, :].cpu().detach().numpy()
ipd.Audio(reference_clip, rate=16000)  # load a NumPy array