In [1]:
import json
import csv
import re
import os
import torch
import random
import torchaudio
import torch.nn as nn
import numpy as np
import pandas as pd
from jiwer import wer
from tqdm.auto import tqdm
import torch.nn.functional as F
from IPython.display import Audio
from functools import partial
from torch.utils.data import Dataset

from transformers import Wav2Vec2Model
import torch.nn as nn
from torch.utils.data import DataLoader
# AdamW is best optimizer
from torch.optim import AdamW
from transformers import get_scheduler
from transformers import Wav2Vec2Model, Wav2Vec2Processor, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer

In [2]:
# Run configuration:
#   VOCAB_PATH           - vocabulary JSON used to build the CTC tokenizer
#   DEVICE               - CUDA device index used when a GPU is available
#   BASE_MODEL_ID        - pretrained wav2vec2 backbone to load
#   infer_checkpoint_dir - despite its name, the path of a single .pt
#                          checkpoint file loaded for inference
config = dict(
    VOCAB_PATH="/home/IITB/ai-at-ieor/23m1508/Shivam_23M1508/Interspeech/code/model_code/model_files/multilingual_vocab.json",
    DEVICE=0,
    BASE_MODEL_ID="facebook/wav2vec2-xls-r-2b",
    infer_checkpoint_dir="/scratch/IITB/ai-at-ieor/23m1508/23m1508_backup/checkpoint_points_2/weights_401000_406000/multilingual_asr_model_401000_406000.pt",
)

In [3]:
# Use the configured GPU when CUDA is available, otherwise fall back to CPU.
# Single quotes inside the f-string keep this valid on Python < 3.12
# (re-using the outer quote character inside an f-string needs Python 3.12+).
device = torch.device(f"cuda:{config['DEVICE']}" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

### Model Code

In [4]:
# Waveform front end: one value per time step (feature_size=1) at 16 kHz,
# padded with 0.0, with do_normalize=True and attention masks returned.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)

# CTC tokenizer built from the multilingual vocabulary file; "|" marks word
# boundaries and <s>, </s>, <unk>, <pad> are the special tokens.
tokenizer = Wav2Vec2CTCTokenizer(
    config['VOCAB_PATH'],
    bos_token="<s>",
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
    word_delimiter_token="|",
)

# Bundle audio and text preprocessing into a single processor object.
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [5]:
# Loading the pretrained wav2vec2 backbone; it returns hidden states only, and
# the CTC classifier is added on top by CustomWav2Vec2CTC.
model = Wav2Vec2Model.from_pretrained(pretrained_model_name_or_path=config['BASE_MODEL_ID'])

In [6]:
class Projector(nn.Module):
    """Backbone encoder followed by a linear projection of its hidden states.

    Args:
        model: backbone whose forward returns an object with a
            ``last_hidden_state`` attribute (here a ``Wav2Vec2Model``).
        projection_dim: output width of the projection layer.
        hidden_size: width of ``last_hidden_state``. When ``None`` it is read
            from ``model.config.hidden_size``; if the backbone exposes no such
            attribute it falls back to 1920, the value previously hard-coded
            for the wav2vec2-xls-r-2b checkpoint.
    """

    def __init__(self, model, projection_dim = 5000, hidden_size = None):
        super().__init__()
        if hidden_size is None:
            # Derive the input width from the backbone instead of hard-coding
            # it, so switching BASE_MODEL_ID does not break the Linear layer.
            backbone_config = getattr(model, "config", None)
            hidden_size = getattr(backbone_config, "hidden_size", 1920)
        self.wav2vec2 = model
        self.projection = nn.Linear(hidden_size, projection_dim)

    def forward(self, input_values, attention_mask = None):
        """Encode ``input_values`` and return ``[batch, time, projection_dim]`` features."""
        outputs = self.wav2vec2(input_values, attention_mask = attention_mask)
        hidden_states = outputs.last_hidden_state  # [batch, time, hidden]
        return self.projection(hidden_states)


# Custom CTC model
class CustomWav2Vec2CTC(nn.Module):
    """CTC acoustic model: projected backbone features -> dropout -> vocabulary logits.

    Args:
        model: wav2vec2 backbone, wrapped by ``Projector``.
        vocab_size: number of output classes of the final linear layer.
        projection_dim: width of the intermediate projection.

    The dropout layer is only active in training mode; call ``.eval()``
    before running inference.
    """

    def __init__(self, model, vocab_size, projection_dim = 5000):
        super().__init__()
        # Attribute names (projector / dropout / classifier) must stay as-is:
        # they are the key prefixes of the saved 'model_state_dict'.
        self.projector = Projector(model, projection_dim = projection_dim)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(projection_dim, vocab_size)

    def forward(self, input_values, attention_mask = None):
        """Return unnormalized per-frame logits over the vocabulary."""
        features = self.dropout(self.projector(input_values, attention_mask))
        return self.classifier(features)

In [7]:
# Read the vocabulary JSON the tokenizer was built from; its entry count is
# used as the output width of the CTC classifier.
# NOTE(review): if any of <s>, </s>, <unk>, <pad> are missing from this file,
# len(tokenizer) may exceed len(vocab) -- confirm they are included.
file_path = config['VOCAB_PATH']
with open(file_path, encoding="utf-8") as vocab_file:
    vocab = json.load(vocab_file)
vocab_size = len(vocab)

In [8]:
# Number of vocabulary entries; passed as the classifier's output size below.
vocab_size

3147

In [9]:
# The complete model: wav2vec2 backbone + projection + CTC classifier
# (projection_dim keeps its default of 5000).
multilingual_asr_model = CustomWav2Vec2CTC(model=model, vocab_size=vocab_size)

In [10]:
# Fallback "best loss", kept only if the checkpoint has no 'best_loss' entry.
curr_best_loss = 1000

In [11]:
checkpoint_path = config['infer_checkpoint_dir']

# Load onto CPU first; the model is moved to `device` afterwards.
# SECURITY: weights_only=False unpickles arbitrary Python objects -- only
# load checkpoints that come from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location = 'cpu', weights_only = False)
multilingual_asr_model.load_state_dict(checkpoint['model_state_dict'])

# Training-state entries, kept for reference / resuming training; the
# inference path in this notebook does not use the optimizer/scheduler state.
optimizer_state = checkpoint['optimizer_state_dict']
scheduler_state = checkpoint['scheduler_state_dict']
start_epoch = checkpoint['epoch']
curr_best_loss = checkpoint.get('best_loss', curr_best_loss)

print(f"Loaded checkpoint from {checkpoint_path}")
print(f"Trained up to epoch {start_epoch} with best loss {curr_best_loss:.4f}")

Loaded checkpoint from /scratch/IITB/ai-at-ieor/23m1508/23m1508_backup/checkpoint_points_2/weights_401000_406000/multilingual_asr_model_401000_406000.pt
Resuming at epoch 254, best loss: 1.0442
Loaded form the previous model: 
Loaded checkpoint from /scratch/IITB/ai-at-ieor/23m1508/23m1508_backup/checkpoint_points_2/weights_401000_406000/multilingual_asr_model_401000_406000.pt
Trained till to epoch 254 with best loss 1.0442


In [12]:
# Move to the target device and switch to eval mode. CustomWav2Vec2CTC is a
# freshly constructed nn.Module (training mode by default), so without
# .eval() its Dropout(0.1) stays active and transcriptions become noisy and
# non-deterministic. .eval() returns the module, so the repr is still shown.
multilingual_asr_model.to(device).eval()

CustomWav2Vec2CTC(
  (projector): Projector(
    (wav2vec2): Wav2Vec2Model(
      (feature_extractor): Wav2Vec2FeatureEncoder(
        (conv_layers): ModuleList(
          (0): Wav2Vec2LayerNormConvLayer(
            (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
            (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (activation): GELUActivation()
          )
          (1-4): 4 x Wav2Vec2LayerNormConvLayer(
            (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
            (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (activation): GELUActivation()
          )
          (5-6): 2 x Wav2Vec2LayerNormConvLayer(
            (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,))
            (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (activation): GELUActivation()
          )
        )
      )
      (feature_projection): Wav2Vec2FeatureProjection(
        (layer_n

In [13]:
# Sanity check: device of the model's first parameter (should match `device`).
next(multilingual_asr_model.parameters()).device

device(type='cuda', index=0)

In [14]:
def preprocess_audio(file_path, target_sr = 16000):
    """Load an audio file and convert it into model input values.

    The waveform is resampled to ``target_sr`` when needed, down-mixed to a
    single channel, and passed through ``processor.feature_extractor``.

    Args:
        file_path: path to an audio file readable by ``torchaudio.load``.
        target_sr: sampling rate expected by the feature extractor / model.

    Returns:
        The ``input_values`` tensor from the feature extractor (presumably
        shaped ``(1, num_samples)`` for one file -- TODO confirm).
    """
    waveform, sample_rate = torchaudio.load(file_path)  # (channels, frames)

    if sample_rate != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq = sample_rate, new_freq = target_sr)
        waveform = resampler(waveform)

    # Average over the channel axis. For mono files this gives exactly the
    # same 1-D signal as squeeze(). Multi-channel files also become one 1-D
    # signal instead of a (channels, frames) array.
    waveform = waveform.mean(dim = 0)

    inputs = processor.feature_extractor(
        waveform,
        sampling_rate = target_sr,
        return_tensors = "pt",
    )
    return inputs["input_values"]

### Final Inference Function

In [15]:
def infer(model, file_path):
    """Greedy (argmax) CTC transcription of a single audio file.

    Args:
        model: module mapping ``input_values`` to per-frame vocabulary
            logits (here ``CustomWav2Vec2CTC``). It is switched to eval mode.
        file_path: path to the audio file.

    Returns:
        The decoded transcription string, with special tokens skipped.
    """
    # Inference must run with dropout disabled. The model is built in
    # training mode, so enforce eval mode here instead of relying on callers.
    model.eval()

    # Send the input to wherever the model's weights live instead of relying
    # on the notebook-global `device`.
    model_device = next(model.parameters()).device
    input_values = preprocess_audio(file_path).to(model_device)  # (1, time)

    with torch.no_grad():
        logits = model(input_values)
        # log_softmax only subtracts a per-frame constant, so the argmax over
        # the raw logits yields the same ids for greedy decoding.
        predicted_ids = torch.argmax(logits, dim = -1)  # (batch, time)

    transcription = processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens = True)
    return transcription[0]

In [16]:
# file_path = "/scratch/IITB/ai-at-ieor/23m1508/23m1508_backup/Evaluation_Set_1/All_Evaluation_audio/Thai-0245_003_phone-O2-128859-128985.wav"

In [17]:
# infer(multilingual_asr_model, file_path)

In [18]:
# CSV listing the evaluation utterances; the "ID" and "Path" columns are used for inference.
evaluation_path = "/scratch/IITB/ai-at-ieor/23m1508/23m1508_backup/Evaluation_Set_1/evaluation_paths.csv"

In [19]:
# Load the evaluation index and display it for a quick sanity check.
# NOTE(review): the 'Unnamed: 0' column looks like an index written by an
# earlier to_csv; index_col=0 would drop it -- confirm nothing relies on it.
evaluation_data = pd.read_csv(evaluation_path)
evaluation_data

Unnamed: 0,ID,Text,Path
0,French-0185_004_phone-O2-000779-000914,/scratch/IITB/ai-at-ieor/23m1508/23m1508_backu...,/scratch/IITB/ai-at-ieor/23m1508/23m1508_backu...
1,French-0185_004_phone-O2-001194-001283,/scratch/IITB/ai-at-ieor/23m1508/23m1508_backu...,/scratch/IITB/ai-at-ieor/23m1508/23m1508_backu...
2,French-0185_004_phone-O1-001324-001450,/scratch/IITB/ai-at-ieor/23m1508/23m1508_backu...,/scratch/IITB/ai-at-ieor/23m1508/23m1508_backu...
3,French-0185_004_phone-O2-001467-001585,/scratch/IITB/ai-at-ieor/23m1508/23m1508_backu...,/scratch/IITB/ai-at-ieor/23m1508/23m1508_backu...
4,French-0185_004_phone-O1-001651-001737,/scratch/IITB/ai-at-ieor/23m1508/23m1508_backu...,/scratch/IITB/ai-at-ieor/23m1508/23m1508_backu...
...,...,...,...
29294,Thai-0245_003_phone-O2-128309-128491,/scratch/IITB/ai-at-ieor/23m1508/23m1508_backu...,/scratch/IITB/ai-at-ieor/23m1508/23m1508_backu...
29295,Thai-0245_003_phone-O1-128502-128660,/scratch/IITB/ai-at-ieor/23m1508/23m1508_backu...,/scratch/IITB/ai-at-ieor/23m1508/23m1508_backu...
29296,Thai-0245_003_phone-O2-128661-128699,/scratch/IITB/ai-at-ieor/23m1508/23m1508_backu...,/scratch/IITB/ai-at-ieor/23m1508/23m1508_backu...
29297,Thai-0245_003_phone-O1-128700-128843,/scratch/IITB/ai-at-ieor/23m1508/23m1508_backu...,/scratch/IITB/ai-at-ieor/23m1508/23m1508_backu...


In [20]:
import pandas as pd

def run_inference_on_dataframe(df, model, output_path, device):
    """Transcribe every row of ``df`` and write ``"<ID> <hypothesis>"`` lines.

    Args:
        df: DataFrame with an ``"ID"`` (utterance id) and a ``"Path"``
            (audio file) column.
        model: model passed through to ``infer``.
        output_path: text file to (over)write, one utterance per line.
        device: device the model runs on; CUDA cache cleanup is only done
            when this is a CUDA device.

    Failures are best-effort. The error is printed together with the
    utterance id, and an ``[ERROR: ...]`` marker is written for that
    utterance so the run continues and failures stay visible in the file.
    """
    # torch.device() accepts a torch.device, a string, or a CUDA index.
    use_cuda = torch.device(device).type == "cuda"

    with open(output_path, "w", encoding = "utf-8") as f:
        rows = zip(df["ID"], df["Path"])
        for utt_id, audio_path in tqdm(rows, total = len(df), desc = "Inference Progress"):
            try:
                transcription = infer(model, audio_path)
            except Exception as e:
                # Collapse whitespace so a multi-line exception message cannot
                # split one record across several lines of the output file.
                message = " ".join(str(e).split())
                transcription = f"[ERROR: {message}]"
                print(f"{utt_id}: {transcription}")

            f.write(f"{utt_id} {transcription.strip()}\n")
            f.flush()  # keep partial results on disk if the run is interrupted

            # Explicit GPU memory cleanup between utterances. ipc_collect()
            # requires CUDA, so it is skipped on a CPU-only run.
            if use_cuda:
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()

In [22]:
# Transcribe the whole evaluation set into the raw submission hypothesis file.
run_inference_on_dataframe(
    df=evaluation_data,
    model=multilingual_asr_model,
    output_path="/home/IITB/ai-at-ieor/23m1508/Shivam_23M1508/Interspeech/code/submission_code/submission_csv/submission_csv_ap1.txt",
    device=device,
)

Inference Progress:   0%|          | 0/29299 [00:00<?, ?it/s]

### Creating the `text_space` file (no extension) with per-character spacing for Japanese, Korean and Thai

In [23]:
import re
from tqdm.auto import tqdm
from pathlib import Path

def add_space_between_chars(text):
    """Surround every CJK / kana / Hangul / Thai / full-width code point with single spaces.

    Runs of other characters (e.g. Latin words) are kept intact. Whitespace-only
    fragments are dropped, and every whitespace run is collapsed to one space.
    Every Thai code point in U+0E00-U+0E7F is split out, including the vowel
    and tone marks.
    """
    # The single capture group makes re.split keep each matched character as
    # its own list item.
    char_splitter = re.compile(
        r"([\u1100-\u11ff\u2e80-\ua4cf\ua840-\uD7AF\uF900-\uFAFF\uFE30-\uFE4F\uFF65-\uFFDC\U00020000-\U0002FFFF\u3000-\u303F\uff01-\uff60\u0E00-\u0E7F])"
    )  # CJKT + Thai characters
    kept = [piece for piece in char_splitter.split(text) if piece.strip()]
    return re.sub(r"\s+", " ", " ".join(kept))

In [24]:
# Raw "<ID> <hypothesis>" file written by run_inference_on_dataframe.
input_path = Path("/home/IITB/ai-at-ieor/23m1508/Shivam_23M1508/Interspeech/code/submission_code/submission_csv/submission_csv_ap1.txt")  # input file path

# Post-processed copy; the file name intentionally has no extension.
output_path = Path("/home/IITB/ai-at-ieor/23m1508/Shivam_23M1508/Interspeech/code/submission_code/submission_csv/text_space")     # output file name without extension

In [None]:
# Re-write the hypothesis file. For utterance ids that mention Japanese,
# Korean or Thai, insert spaces between characters; other lines only get
# their whitespace normalized. Blank lines are dropped.
with open(input_path, "r", encoding="utf-8") as src, open(output_path, "w", encoding="utf-8") as dst:
    for raw_line in tqdm(src):
        fields = raw_line.split()
        if not fields:
            continue

        utt, hypothesis = fields[0], " ".join(fields[1:])
        if any(lang in utt for lang in ("Japanese", "Korean", "Thai")):
            hypothesis = add_space_between_chars(hypothesis)

        dst.write(f"{utt} {hypothesis}\n")

0it [00:00, ?it/s]