In [2]:
import json
from functools import partial
from torch.utils.data import Dataset
import torch
import torch.nn as nn
from transformers import Wav2Vec2Model
import torch.nn as nn
from torch.utils.data import DataLoader

import re
import os

import random
import torchaudio
import numpy as np
import pandas as pd
from jiwer import wer
from tqdm.auto import tqdm
from IPython.display import Audio
# AdamW is best optimizer
from torch.optim import AdamW
from transformers import get_scheduler
from transformers import Wav2Vec2Model, Wav2Vec2Processor, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer

### Data Making (Do Not Run this Cell)

In [5]:
full_data = pd.read_csv('/home/IITB/ai-at-ieor/23m1508/Shivam_23M1508/Interspeech/data/set1_set2_set3_combined.csv')

In [6]:
full_data.shape

(1398574, 4)

In [None]:
# Define the class column name
class_column = 'class'  # Replace with actual class column name

# Shuffle each class group independently
class_groups = {cls: df.sample(frac=1, random_state=42).reset_index(drop=True)
                for cls, df in full_data.groupby(class_column)}

# Determine how many total rows and chunks we want
total_rows = len(full_data)
chunk_size = 100_000
n_chunks = (total_rows + chunk_size - 1) // chunk_size  # round up

# Prepare empty list to collect rows
balanced_data = []

# Interleave rows in balanced way
for chunk_idx in range(n_chunks):
    chunk = []
    chunk_start = chunk_idx * chunk_size
    chunk_end = min((chunk_idx + 1) * chunk_size, total_rows)

    rows_needed = chunk_end - chunk_start
    per_class_rows = rows_needed // len(class_groups)

    for cls, df in class_groups.items():
        take = min(per_class_rows, len(df))
        chunk.append(df.iloc[:take])
        class_groups[cls] = df.iloc[take:].reset_index(drop=True)

    # Handle remainder rows (to reach 100k)
    remainder = rows_needed - sum(len(c) for c in chunk)
    if remainder > 0:
        # Grab remainder rows randomly from what's left
        leftovers = pd.concat([df for df in class_groups.values() if not df.empty])
        extra = leftovers.sample(n=remainder, random_state=chunk_idx)
        chunk.append(extra)

    balanced_data.append(pd.concat(chunk).sample(frac=1, random_state=chunk_idx).reset_index(drop=True))

# Final combined dataframe
final_data = pd.concat(balanced_data).reset_index(drop=True)


In [None]:
# Check class distribution in the first 100,000 rows
final_data.loc[1300000:, 'class'].value_counts()

class
english    98574
Name: count, dtype: int64

In [20]:
final_data.loc[:100000, 'class'].value_counts()

class
english       9092
korean        9092
thai          9092
spanish       9092
russian       9092
japanese      9091
portuguese    9090
french        9090
german        9090
italian       9090
vietnamese    9090
Name: count, dtype: int64

Now we have equal class distribution on each subsequent 100000 lakh samples of the now we will use these chunks to train the model, last 1 lakh chunk only have english smaples.

In [21]:
final_data.loc[100000:200000, 'class'].value_counts()

class
italian       9092
english       9092
vietnamese    9092
russian       9092
spanish       9092
french        9091
japanese      9090
thai          9090
portuguese    9090
german        9090
korean        9090
Name: count, dtype: int64

In [19]:
final_data.to_csv('/home/IITB/ai-at-ieor/23m1508/Shivam_23M1508/Interspeech/data/balanced_shuffled_data.csv', index=False)

### Config (Only Modify this then run all below this.)

In [3]:
config = {
            "BASE_MODEL_ID" : "facebook/wav2vec2-xls-r-2b",
            
            "start_data" : 200000,
            "end_data" : 300000,
            'device' : 1,
            'BATCH_SIZE' : 16,
            'EPOCHS' : 30,
            'LR' : 1e-5,
            
            'Number_of_first_layers_freeze_transofrmer' : 48,
            
            'vocab_json_path' : "/home/IITB/ai-at-ieor/23m1508/Shivam_23M1508/Interspeech/code/model_code/model_files/multilingual_vocab.json",
            
            'num_warmup_steps' : 1000,
            
            'loss_csv_saving_path'  : "/home/IITB/ai-at-ieor/23m1508/Shivam_23M1508/Interspeech/code/model_code/model_loss_files/loss_2_3_lakh",
            
            'loss_csv_name' : "loss_2_3_lakh.csv",
             
            'prev_checkpoint_dir' : "/home/IITB/ai-at-ieor/23m1508/Shivam_23M1508/Interspeech/code/model_code/model_weights/weights_0_2_lakh/multilingual_asr_model_0_2_lakh.pt",
            
            'new_checkpoint_dir' : "/home/IITB/ai-at-ieor/23m1508/Shivam_23M1508/Interspeech/code/model_code/model_weights/weights_2_3_lakh",
            
            'model_name' : "multilingual_asr_model_2_3_lakh",
            
            'load_from_prev' : True,
            'resume_training' : False
          
          
          }

In [4]:
device = torch.device(f"cuda:{config['device']}" if torch.cuda.is_available() else "cpu")

In [5]:
device

device(type='cuda', index=1)

### Data Preprocessing (Do Run From this Cell)

In [2]:
import pandas as pd

In [4]:
final_data = pd.read_csv('/home/IITB/ai-at-ieor/23m1508/Shivam_23M1508/Interspeech/data/scratch_balanced_shuffled_data.csv')

Now we will take 2 lakh  to 4 lakh sample for training

lower case only the english text along the class

In [7]:
final_data = final_data.loc[config['start_data'] : config['end_data'], :]

In [8]:
def lowercase_english_letters(text):
    return ''.join([c.lower() if c.isascii() and c.isalpha() else c for c in text])

final_data['Text'] = final_data['Text'].apply(lowercase_english_letters)

In [9]:
def clean_text(text):
    characters_to_remove = r'[,\?\.\!\-\;\:\%\'\`\{}()@#$%^&*\+\[\]\_｡]+'
    cleaned_text = re.sub(characters_to_remove, '', text)
    # cleaned_text = ''.join([c.lower() if c.isascii() else c for c in cleaned_text]) + ' '
    return cleaned_text

final_data['Text'] = final_data['Text'].apply(clean_text)

In [10]:
final_data['class'].value_counts()

class
english       9094
korean        9091
vietnamese    9091
thai          9091
spanish       9091
russian       9091
french        9091
japanese      9091
german        9090
italian       9090
portuguese    9090
Name: count, dtype: int64

In [11]:
final_data = final_data.reset_index(drop = True)

In [12]:
# final_data

### Building Tokenizer

In [13]:
tokenizer = Wav2Vec2CTCTokenizer(r"/home/IITB/ai-at-ieor/23m1508/Shivam_23M1508/Interspeech/code/model_code/model_files/multilingual_vocab.json", bos_token = "<s>",
                                 eos_token = "</s>",
                                 unk_token = "<unk>", 
                                 pad_token = "<pad>", 
                                 word_delimiter_token = "|")

In [14]:
feature_extractor = Wav2Vec2FeatureExtractor(feature_size = 1, 
                                             sampling_rate = 16000, 
                                             padding_value = 0.0, 
                                             do_normalize = True, 
                                             return_attention_mask = True)

In [15]:
processor = Wav2Vec2Processor(feature_extractor = feature_extractor, 
                              tokenizer = tokenizer)

In [16]:
# speech = speech.squeeze()

# input_values = processor.feature_extractor(speech, sampling_rate = 16000, return_tensors="pt")

# labels = processor.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

In [17]:
# import torch

# model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**2)  # Convert to MB
# print(f"Model size: {model_size:.2f} MB")

### Data Class \& Data Loader Making

In [18]:
# from sklearn.model_selection import train_test_split

# # Assume df is your DataFrame and 'class' is the column with class labels
# train_df, test_df = train_test_split(
#                                         full_data,
#                                         test_size = 0.5,
#                                         stratify = full_data['class'],      # Ensures equal class distribution
#                                         random_state = 42            # For reproducibility
#                                     )

Currently i am using only 50 percent of data for training to see wether the wer decresing or not.

In [19]:
# train_df = train_df.reset_index(drop = True)
# test_df = test_df.reset_index(drop = True)

In [20]:
# train_df.shape, test_df.shape

In [21]:
# train_df['class'].value_counts(), test_df['class'].value_counts()

In [22]:
# # Check if index is not in sequential order
# mismatched_indices = train_df.index != range(len(train_df))

# # Show the rows where index is not in order
# train_df[mismatched_indices]

In [23]:
class SpeechDataset(Dataset):
    def __init__(self, df, processor, transforms=None):
        self.df = df.reset_index(drop=True)  # ensure index is clean
        self.processor = processor
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        waveform, sample_rate = torchaudio.load(row['Path'])
        text = row['Text']

        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform = resampler(waveform)

        waveform = waveform.squeeze()

        # Optional audio transforms
        if self.transforms:
            for transform in self.transforms:
                waveform = transform(waveform)

        # Process audio
        input_values = self.processor.feature_extractor(
                                                        waveform, sampling_rate=16000, return_tensors="pt"
                                                    )["input_values"].squeeze(0)

        # Process labels (text)
        labels = self.processor.tokenizer(
                                            text, return_tensors="pt", truncation=True
                                        )["input_ids"].squeeze(0)

        return {
            "input_values": input_values,
            "labels": labels
        }

In [24]:
speech_dataset = SpeechDataset(final_data, processor, transforms = None)

In [25]:
len(speech_dataset)

100001

In [26]:
# train_df.loc[957005, 'Text']

In [27]:
# sample = speech_dataset[45]
# audio = sample["input_values"]
# transcription = sample["labels"]

# print(audio)
# print(audio.shape)
# print(transcription)
# print(transcription.shape)

In [28]:
def collate_function(batch, processor, 
                     padding = True, 
                     max_length = None, 
                     max_length_labels = None, 
                     pad_to_multiple_of = None, 
                     pad_to_multiple_of_labels = None):
  
    # Extract input values and labels from each sample in the batch
    b_X = [{"input_values": sample["input_values"]} for sample in batch]
    b_Y = [{"input_ids": sample["labels"]} for sample in batch]

    # Pad the audio inputs the same length
    features = processor.feature_extractor.pad(
                                                b_X,
                                                padding = padding,
                                                max_length = max_length,
                                                pad_to_multiple_of = pad_to_multiple_of,
                                                return_tensors = "pt"
                                              )

    # Pad the labels
    batchY = processor.tokenizer.pad(
                            b_Y,
                            padding = padding,
                            max_length = max_length_labels,
                            pad_to_multiple_of = pad_to_multiple_of_labels,
                            return_tensors = "pt"
                          )

    # Replace padding tokens in labels with -100, so they are ignored during loss calculation
    labels = batchY["input_ids"].masked_fill(batchY.attention_mask.ne(1), -100)

    # Add the padded labels back into the features dictionary
    features["labels"] = labels

    # Return the features, which now include both input values and labels
    return features

In [29]:
collate_fn = partial(collate_function, 
                     processor = processor, 
                     padding = True, 
                     max_length = None, 
                     max_length_labels = None, 
                     pad_to_multiple_of = None, 
                     pad_to_multiple_of_labels = None)

In [30]:
batch_size = config['BATCH_SIZE']
epochs = config['EPOCHS']
lr = config['LR']

In [31]:
batch_size, epochs, lr

(16, 30, 1e-05)

In [32]:
train_dataloader = DataLoader(speech_dataset, 
                              batch_size = batch_size, 
                              shuffle = True, 
                              collate_fn = collate_fn)

In [33]:
# test_speech_data = SpeechDataset(test_df, processor, transforms = None)

In [34]:
# test_dataloader = DataLoader(test_speech_data, 
#                               batch_size = batch_size, 
#                               shuffle = True, 
#                               collate_fn = collate_fn)

In [35]:
for i in train_dataloader:
    print(i['input_values'].shape)
    print(i['attention_mask'].shape)
    print(i['labels'].shape)
    break

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


torch.Size([16, 128000])
torch.Size([16, 128000])
torch.Size([16, 131])


In [36]:
# for i in test_dataloader:
#     print(i['input_values'].shape)
#     print(i['attention_mask'].shape)
#     print(i['labels'].shape)
#     break

### Loading the Pretrained Model

In [37]:
model = Wav2Vec2Model.from_pretrained(config['BASE_MODEL_ID'])
# model

In [38]:
# methods = [method for method in dir(model) if callable(getattr(model, method))]
# print(methods)

In [39]:
# for name, param in model.named_parameters():
#     print(f"{name}: requires_grad = {param.requires_grad}")

In [40]:
model.encoder.layers

ModuleList(
  (0-47): 48 x Wav2Vec2EncoderLayerStableLayerNorm(
    (attention): Wav2Vec2SdpaAttention(
      (k_proj): Linear(in_features=1920, out_features=1920, bias=True)
      (v_proj): Linear(in_features=1920, out_features=1920, bias=True)
      (q_proj): Linear(in_features=1920, out_features=1920, bias=True)
      (out_proj): Linear(in_features=1920, out_features=1920, bias=True)
    )
    (dropout): Dropout(p=0.1, inplace=False)
    (layer_norm): LayerNorm((1920,), eps=1e-05, elementwise_affine=True)
    (feed_forward): Wav2Vec2FeedForward(
      (intermediate_dropout): Dropout(p=0.0, inplace=False)
      (intermediate_dense): Linear(in_features=1920, out_features=7680, bias=True)
      (intermediate_act_fn): GELUActivation()
      (output_dense): Linear(in_features=7680, out_features=1920, bias=True)
      (output_dropout): Dropout(p=0.1, inplace=False)
    )
    (final_layer_norm): LayerNorm((1920,), eps=1e-05, elementwise_affine=True)
  )
)

In [41]:
# There are 48 transformer layers in the model

### Currently i am training only My custom CTC Head

In [42]:
# Freeze feature extractor (same as model.freeze_feature_encoder())
for param in model.feature_extractor.parameters():
    param.requires_grad = False

# Freeze feature projection (optional)
for param in model.feature_projection.parameters():
    param.requires_grad = False

# Freeze all transformer encoder layers
# for param in model.encoder.parameters():
#     param.requires_grad = False
    
    
# Freeze first 30 transformer layers (out of 48 for base model)
# for i in range(config['Number_of_first_layers_freeze_transofrmer']):
#     for param in model.encoder.layers[i].parameters():
#         param.requires_grad = False
        
        
for param in model.parameters():
    param.requires_grad = False
    
# Optional: Print which layers are frozen
# for name, param in model.named_parameters():
#     print(f"{name}: requires_grad = {param.requires_grad}")

In [43]:
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable_params} / Total: {total_params}")

Trainable: 0 / Total: 2159259648


### Building the custom CTC Head

In [44]:
class Projector(nn.Module):
    def __init__(self, model, projection_dim = 5000):
        super().__init__()
        self.wav2vec2 = model
        self.projection = nn.Linear(1920, projection_dim)

    def forward(self, input_values, attention_mask = None):
        outputs = self.wav2vec2(input_values, attention_mask = attention_mask)
        hidden_states = outputs.last_hidden_state  # [batch, time, hidden]
        projected = self.projection(hidden_states)  # [batch, time, 5000]
        return projected


# Custom CTC model
class CustomWav2Vec2CTC(nn.Module):
    def __init__(self, model, vocab_size, projection_dim = 5000):
        super().__init__()

        self.projector = Projector(model, projection_dim = projection_dim)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(projection_dim, vocab_size)

    def forward(self, input_values, attention_mask = None):
        hidden_states = self.projector(input_values, attention_mask)
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)
        return logits

In [45]:
import json

# Path to your JSON file
file_path = config['vocab_json_path']

# Open and load JSON data
with open(file_path, "r", encoding="utf-8") as f:
    vocab = json.load(f)

len(vocab)

3147

In [46]:
vocab_size = 3147  # your vocab size here

# This is complete model

multilingual_asr_model = CustomWav2Vec2CTC(model, vocab_size = vocab_size)

In [47]:
trainable_params = sum(p.numel() for p in multilingual_asr_model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in multilingual_asr_model.parameters())
print(f"Trainable: {trainable_params} / Total: {total_params}")

Trainable: 25343147 / Total: 2184602795


### Training Loop

In [48]:
optimizer = AdamW(multilingual_asr_model.parameters(), lr = lr)

In [49]:
# total number of training steps the model will take during the entire training process.
num_training_steps = epochs * len(train_dataloader)

# https://huggingface.co/docs/transformers/main_classes/optimizer_schedules#transformers.get_scheduler
# SchedulerType, please select one of ['linear', 'cosine', 'cosine_with_restarts', 'polynomial', 'constant', 'constant_with_warmup', 'inverse_sqrt', 'reduce_lr_on_plateau', 'cosine_with_min_lr', 'warmup_stable_decay']

lr_scheduler = get_scheduler(
                            "linear",
                            optimizer = optimizer,
                            num_warmup_steps = config['num_warmup_steps'],
                            num_training_steps = num_training_steps
                            )

In [50]:
num_training_steps

187530

In [51]:
# For mixed training of torch.float32, torch.float64
# https://pytorch.org/docs/stable/amp.html

# scaler = torch.cuda.amp.GradScaler()

scaler = torch.amp.GradScaler(device)

multilingual_asr_model = multilingual_asr_model.train()

In [52]:
import torch.nn.functional as F

ctc_loss_fn = nn.CTCLoss(blank = processor.tokenizer.pad_token_id, zero_infinity = True, reduction = 'mean')

### Start Training

In [53]:
def compute_metrics(labels, preds):
    
    # preds = torch.argmax(preds, axis=-1)

    labels[labels == -100] = processor.tokenizer.pad_token_id

    # print('The shape of the preds is', preds)
    # print('The shape of the label is', labels)
    pred_str = processor.batch_decode(preds)
    
    # print('The pred str is', pred_str)
    label_str = processor.batch_decode(labels, group_tokens = False)
    # print('The label str is', label_str)
    return wer(label_str, pred_str)

In [54]:
import os
import csv

# Define CSV directory and file
csv_dir = config['loss_csv_saving_path']
csv_file_name = config['loss_csv_name']
csv_file_path = os.path.join(csv_dir, csv_file_name)

# Ensure the directory exists
os.makedirs(csv_dir, exist_ok=True)

# If the CSV file doesn't exist, create it with headers
if not os.path.exists(csv_file_path):
    with open(csv_file_path, mode='w', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=["step", "epoch", "loss", "wer"])
        writer.writeheader()

### Checkpoint saving variables

In [55]:
import os
import torch

new_checkpoint_dir = config['new_checkpoint_dir']

os.makedirs(new_checkpoint_dir, exist_ok=True)

start_epoch = 0
curr_best_loss = 1e9

resume_training = config['resume_training']  # set this to False to start fresh


model_name = config['model_name']

new_checkpoint_path = os.path.join(new_checkpoint_dir, model_name + ".pt")




if config['load_from_prev']:
    
    checkpoint_path = config['prev_checkpoint_dir']
    
    checkpoint = torch.load(checkpoint_path, map_location = 'cpu', weights_only = False)
    
    multilingual_asr_model.load_state_dict(checkpoint['model_state_dict'])
    
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    lr_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    
    start_epoch = checkpoint['epoch']
    
    curr_best_loss = checkpoint.get('best_loss', curr_best_loss)
    
    print('Loaded form the previous model')
    print(f"Resumed training from epoch {start_epoch} with best loss {curr_best_loss:.4f}")
    

# Resume logic and save in new check dir since it is new chunk of the data
if resume_training and os.path.exists(new_checkpoint_path):
    
    checkpoint = torch.load(checkpoint_path)
    
    multilingual_asr_model.load_state_dict(checkpoint['model_state_dict'])
    
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    lr_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    
    start_epoch = checkpoint['epoch']
    
    curr_best_loss = checkpoint.get('best_loss', curr_best_loss)
    
    print(f"Resumed training from epoch {start_epoch} with best loss {curr_best_loss:.4f}")

Loaded form the previous model
Resumed training from epoch 10 with best loss 1000000000.0000


In [56]:
start_epoch

10

In [57]:
multilingual_asr_model = multilingual_asr_model.to(device)

In [58]:
# Move optimizer state to the correct device
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

In [None]:
import warnings
warnings.filterwarnings("ignore")

curr_best_loss = 1e9

for n in tqdm(range(start_epoch, epochs)):
    
    losses = []
    wers = []

    total_number_batch = len(train_dataloader)
    
    for step, batch in enumerate(tqdm(train_dataloader)):

        optimizer.zero_grad()
        batch = {k: v.to(device) for k, v in batch.items()}


        with torch.autocast("cuda"):
            logits = multilingual_asr_model(batch["input_values"], attention_mask = batch.get("attention_mask"))
            
        log_probs = F.log_softmax(logits, dim = -1).transpose(0, 1)

        input_lengths = torch.full(
                                    size = (log_probs.size(1),),  # batch size
                                    fill_value = log_probs.size(0),  # time dimension
                                    dtype = torch.long).to(device)


        labels = batch["labels"]
        
        target_lengths = (labels != -100).sum(dim = 1)

        flattened_targets = labels[labels != -100]

        with torch.cuda.amp.autocast(enabled = False):
            loss = ctc_loss_fn(
                                log_probs.float(),         # ensure float32
                                flattened_targets,
                                input_lengths,
                                target_lengths
                                )
        
        losses.append(loss.item())

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        lr_scheduler.step()
        scaler.update()

        # WER computation
        preds = torch.argmax(log_probs, dim = -1).transpose(0, 1)  # back to (batch, time)
        
        labels_for_metrics = labels.clone()
        
        metrics = compute_metrics(labels_for_metrics, preds)
        wers.append(metrics)
        
        
        with open(csv_file_path, mode = 'a', newline = '') as file:
            writer = csv.DictWriter(file, fieldnames=["step", "epoch", "loss", "wer"])
            
            writer.writerow({
                                "step": step,
                                "epoch": n + 1,
                                "loss": loss.item(),
                                "wer": metrics
                            })


    result = {"loss": np.mean(losses), "wer": np.mean(wers)}    
    
    print("EPOCH: ", n + 1)
    print(result)
    print('=' * 100)
    
        
    if result["loss"] < curr_best_loss:
        
        print(f"New best model found at epoch {n + 1} with loss {result['loss']:.4f}, saving...")
        
        curr_best_loss = result["loss"]
    
    
    # Save latest checkpoint
    checkpoint = {
                        'epoch': n + 1,
                        'model_state_dict': multilingual_asr_model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': lr_scheduler.state_dict(),
                        'loss': result['loss'],
                        'best_loss': curr_best_loss
                    }
    
    
    torch.save(checkpoint, new_checkpoint_path)

print('Training completed.')

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/6251 [00:00<?, ?it/s]

EPOCH:  11
{'loss': np.float64(3.815136089990128), 'wer': np.float64(1.0008730460704365)}
New best model found at epoch 11 with loss 3.8151, saving...


RuntimeError: [enforce fail at inline_container.cc:626] . unexpected pos 6878457472 vs 6878457368