In [1]:
# Libs
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, Wav2Vec2Processor, Wav2Vec2Model, AutoModelForSeq2SeqLM
from tqdm import tqdm
import os
from datasets import load_dataset
from IPython.display import Audio, display
import json

2025-06-22 12:57:24.618949: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750597044.796914      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750597044.849587      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# Configs
# Run on the GPU whenever PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
!git clone https://github.com/VarunGumma/IndicTransToolkit
!pwd
%cd IndicTransToolkit

!python3 -m pip install --editable ./

Cloning into 'IndicTransToolkit'...
remote: Enumerating objects: 245, done.[K
remote: Counting objects: 100% (150/150), done.[K
remote: Compressing objects: 100% (89/89), done.[K
remote: Total 245 (delta 74), reused 108 (delta 49), pack-reused 95 (from 1)[K
Receiving objects: 100% (245/245), 4.45 MiB | 23.24 MiB/s, done.
Resolving deltas: 100% (102/102), done.
/kaggle/working
/kaggle/working/IndicTransToolkit
Obtaining file:///kaggle/working/IndicTransToolkit
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Collecting sacremoses (from IndicTransToolkit==1.0.4)
  Downloading sacremoses-0.1.1-py3-none-any.whl.metadata (8.3 kB)
Collecting sacrebleu (from IndicTransToolkit==1.0.4)
  Downloading sacrebleu-2.5.1-py3-none-any.whl.metadata (51 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━

In [4]:
# Print the working directory to confirm the %cd from the install cell is still in effect.
!pwd

/kaggle/working/IndicTransToolkit


In [5]:
# Import while the working directory is still the cloned repository.
# NOTE(review): presumably this lets the freshly installed editable package resolve
# without restarting the kernel — confirm.
from IndicTransToolkit.processor import IndicProcessor
# Step back up to the parent directory for the rest of the notebook.
%cd ..

/kaggle/working


In [6]:
# Complete multimodal punctuation training script

# Class index of every punctuation label, in order:
# no punctuation ("O"), comma, period, question mark, semicolon.
PUNCT_LABELS = {
    symbol: class_id
    for class_id, symbol in enumerate(("O", ",", ".", "?", ";"))
}

# ---------------------- Dataset ---------------------- #

class FleursPunctuationDataset(Dataset):
    """FLEURS (en_us, train split) utterances paired with cached features for punctuation tagging.

    Every item is a dict with:
      * ``id``: the utterance id,
      * ``text_tokens``: token ids of the normalised transcription,
      * ``punct_labels``: one ``PUNCT_LABELS`` class id per token, derived from the raw transcription,
      * ``text_feats``: the cached text-encoder output for the utterance,
      * ``audio_feats``: the cached wav2vec 2.0 features for the utterance,
      * ``word_times``: a list of ``(start, end)`` pairs, one per aligned word.

    Args:
        alignment_jsonl: Path to a JSON-lines file. Every line must provide an ``id``
            and a ``words`` list whose entries have ``start`` and ``end`` keys.
        tokenizer_name: Hugging Face identifier passed to ``AutoTokenizer.from_pretrained``.
        indictrans_name: Unused. Kept so existing callers keep working; the text
            encoder output is read from a cached file instead of running the model.
        wav2vec_name: Unused. Kept for the same reason.
    """

    def __init__(self, alignment_jsonl, tokenizer_name, indictrans_name, wav2vec_name):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)

        # Neither encoder is instantiated here: their outputs are read from cached files.
        # NOTE(review): torch.load unpickles its input — only point it at trusted files.
        w2v_temp = torch.load('/kaggle/input/w2v_feats/w2v_train_output.pt')
        self.w2v_features = {item["id"]: item["feature"] for item in w2v_temp}
        indic_temp = torch.load('/kaggle/input/indic_feats/indictrans_enc_unpunct_fleurs_train.pt')
        self.indic_features = {item["id"]: item["encoded_vector"] for item in indic_temp}

        with open(alignment_jsonl, 'r') as f:
            # Blank lines (e.g. a trailing newline) are skipped instead of breaking json.loads.
            alignment_lines = [json.loads(line) for line in f if line.strip()]
        aligned_ids = [line['id'] for line in alignment_lines]
        self.alignments = {line['id']: line['words'] for line in alignment_lines}

        fleurs = load_dataset("google/fleurs", "en_us", split="train", trust_remote_code=True)
        # The waveform is never read here, so the column is dropped before iterating.
        fleurs = fleurs.remove_columns(["audio"])

        wanted_ids = set(aligned_ids)  # set membership instead of scanning a list per example
        self.data = {}
        for example in fleurs:
            example_id = example['id']
            if example_id not in wanted_ids:
                continue
            # An utterance is only usable when both cached feature tensors exist for it.
            if example_id not in self.w2v_features or example_id not in self.indic_features:
                continue
            self.data[example_id] = {
                'id': example_id,
                'w2v_feat': self.w2v_features[example_id],
                'indic_feat': self.indic_features[example_id],
                'text': example['transcription'].strip(),
                'raw_text': example['raw_transcription'].strip()
            }

        # Keep only ids that can actually be served, so __getitem__ never hits a missing key.
        self.ids = [example_id for example_id in aligned_ids if example_id in self.data]

    def __len__(self):
        """Number of utterances that have an alignment, a FLEURS entry and both cached features."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Assemble the training sample for the ``idx``-th usable utterance."""
        audio_id = self.ids[idx]
        entry = self.data[audio_id]

        # Token ids are still needed for the labels and for the audio-to-token alignment.
        text_input = self.tokenizer(entry['text'], return_tensors="pt")
        input_ids = text_input['input_ids'][0]
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
        punct_targets = self.get_punctuation_labels(entry['raw_text'], tokens)

        # Lexical features come from the cache; no encoder is held by this dataset.
        # NOTE(review): assumes the cached tensor has one row per token produced above
        # (optionally with a leading batch dimension of 1) — confirm against the cache script.
        text_feats = entry['indic_feat']
        if text_feats.dim() == 3 and text_feats.size(0) == 1:
            text_feats = text_feats[0]

        word_times = [(w['start'], w['end']) for w in self.alignments[audio_id]]

        return {
            'id': audio_id,
            'text_tokens': input_ids,
            'punct_labels': punct_targets,
            'text_feats': text_feats.detach().cpu(),
            'audio_feats': entry['w2v_feat'].cpu(),
            'word_times': word_times,
        }

    def get_punctuation_labels(self, raw_text, tokens):
        """Return one ``PUNCT_LABELS`` class id per token as a ``torch.long`` tensor.

        A token starting with "▁" opens the next whitespace-separated word of
        ``raw_text`` and receives the label of that word's final character when it is
        one of ``. , ? ;``. Every other token, and every word-start token once the
        words are exhausted, is labelled "O".
        """
        punctuation_marks = {".", ",", "?", ";"}
        words = raw_text.replace("\u2019", "'").split()
        labels = []
        word_idx = 0
        for tok in tokens:
            label = "O"
            if tok.startswith("▁"):
                if word_idx < len(words) and words[word_idx][-1] in punctuation_marks:
                    label = words[word_idx][-1]
                word_idx += 1
            labels.append(PUNCT_LABELS[label])
        return torch.tensor(labels, dtype=torch.long)


# ------------------- Acoustic Encoder ------------------- #

class AcousticEncoder(nn.Module):
    """1-D convolution followed by a single-layer LSTM over frame-level features.

    Args:
        input_dim: Size of every incoming frame vector.
        conv_out_dim: Number of output channels of the convolution.
        lstm_hidden: Hidden size of the LSTM, i.e. the size of every output vector.
    """

    def __init__(self, input_dim=768, conv_out_dim=1024, lstm_hidden=1024):
        super().__init__()
        # kernel_size=5 with padding=2 leaves the length of the time axis unchanged.
        self.conv1d = nn.Conv1d(input_dim, conv_out_dim, kernel_size=5, padding=2)
        self.lstm = nn.LSTM(conv_out_dim, lstm_hidden, batch_first=True)

    def forward(self, features):
        """Map ``[B, T, input_dim]`` features to ``[B, T, lstm_hidden]`` hidden states."""
        # Conv1d wants channels before time; the batch-first LSTM wants time before channels.
        channels_first = features.permute(0, 2, 1)
        time_first = self.conv1d(channels_first).permute(0, 2, 1)
        hidden_states, _final_state = self.lstm(time_first)
        return hidden_states

# ------------------- Punctuation Model ------------------- #

class PunctuationModel(nn.Module):
    """Single linear layer that classifies concatenated text and audio features.

    Args:
        text_dim: Size of the last dimension of the text features.
        audio_dim: Size of the last dimension of the audio features.
        hidden_dim: Accepted but not used by this single-layer head.
        num_classes: Number of output classes per position.
    """

    def __init__(self, text_dim, audio_dim, hidden_dim=1024, num_classes=5):
        super().__init__()
        fused_dim = text_dim + audio_dim
        self.fuse = nn.Linear(fused_dim, num_classes)

    def forward(self, text_feats, audio_feats):
        """Concatenate both inputs along the last dimension and return the class logits."""
        return self.fuse(torch.cat((text_feats, audio_feats), dim=-1))


# ---------------------- Training Loop ---------------------- #

def train(model, dataloader, optimizer, num_epochs):
    """Train the acoustic encoder + punctuation head with token-level cross-entropy.

    Args:
        model: Indexable pair of modules. ``model[0]`` maps aligned audio features
            ``[B, L, D]`` to ``[B, L, H]``; ``model[1]`` maps ``(text_feats, audio_encoded)``
            to per-token logits ``[B, L, num_classes]``.
        dataloader: Yields dict-of-lists batches with the keys ``text_tokens``,
            ``punct_labels``, ``text_feats``, ``audio_feats`` and ``word_times``.
            When ``dataloader.dataset`` has a ``tokenizer`` attribute it is used for the
            audio-to-token alignment.
        optimizer: Optimizer over the trainable parameters of ``model``.
        num_epochs: Number of passes over ``dataloader``.

    Raises:
        ValueError: If a sample's text features and labels differ in length.
    """
    from torch.nn.utils.rnn import pad_sequence

    ignore_index = -100  # label value of padded positions; excluded from the loss
    model.train()
    # Inputs follow the model, so the function does not depend on a module-level device.
    run_device = next(model.parameters()).device
    loss_fn = nn.CrossEntropyLoss(ignore_index=ignore_index)
    batch_tokenizer = getattr(getattr(dataloader, 'dataset', None), 'tokenizer', None)

    for epoch in range(num_epochs):
        total_loss = 0.0
        for batch in tqdm(dataloader):
            for position, (feats, sample_labels) in enumerate(zip(batch['text_feats'], batch['punct_labels'])):
                if feats.size(0) != sample_labels.size(0):
                    raise ValueError(
                        f"sample {position} of the batch has {feats.size(0)} text feature rows "
                        f"but {sample_labels.size(0)} labels"
                    )

            # Sequences in a batch differ in length, so they are padded instead of stacked.
            text_feats = pad_sequence(
                [feats.float() for feats in batch['text_feats']], batch_first=True
            ).to(run_device)
            labels = pad_sequence(
                list(batch['punct_labels']), batch_first=True, padding_value=ignore_index
            ).to(run_device)
            audio_feats = align_audio_to_tokens(batch, text_feats.size(1), tokenizer=batch_tokenizer)
            audio_feats = audio_feats.to(device=run_device, dtype=torch.float32)

            # audio_feats is already [B, L, D]; an extra batch dimension would break Conv1d.
            audio_encoded = model[0](audio_feats)
            logits = model[1](text_feats, audio_encoded)
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1} Loss: {total_loss/max(len(dataloader), 1):.4f}")


# ---------------------- Audio Alignment ---------------------- #

def align_audio_to_tokens(batch, L, tokenizer=None, frames_per_second=50):
    """Give every text token the acoustic feature of the word it belongs to.

    Each aligned word is represented by a single frame: the last frame before the
    word's end time, or the frame at its start when the word spans no full frame.
    A token starting with "▁" moves on to the next word; all other tokens reuse the
    current word's feature.

    Args:
        batch: Dict of lists with the keys ``audio_feats`` (``[T', D]`` tensors),
            ``word_times`` (lists of ``(start, end)`` pairs) and ``text_tokens``
            (1-D tensors of token ids).
        L: Length of the token axis of the result. Shorter sequences are zero-padded,
            longer ones are truncated.
        tokenizer: Object with ``convert_ids_to_tokens``. When omitted, a module-level
            ``tokenizer`` is used.
        frames_per_second: Factor that converts a time into a frame index.
            NOTE(review): 50 assumes features with a 20 ms hop — confirm for the cached features.

    Returns:
        Tensor of shape ``[B, L, D]``.

    Raises:
        NameError: If no tokenizer is passed and no module-level ``tokenizer`` exists.
    """
    if tokenizer is None:
        tokenizer = globals().get("tokenizer")
        if tokenizer is None:
            raise NameError(
                "align_audio_to_tokens needs a tokenizer: pass tokenizer=... "
                "or define a module-level `tokenizer`"
            )

    aligned_batch = []
    for frame_feats, word_times, token_ids in zip(
        batch['audio_feats'], batch['word_times'], batch['text_tokens']
    ):
        num_frames = frame_feats.size(0)
        feat_dim = frame_feats.size(-1)
        no_feature = frame_feats.new_zeros(feat_dim)

        word_feats = []
        for start, end in word_times:
            start_idx = int(start * frames_per_second)
            end_idx = int(end * frames_per_second)
            frame_idx = end_idx - 1 if end_idx > start_idx else start_idx
            if num_frames == 0:
                word_feats.append(no_feature)
                continue
            # Alignment times can run slightly past the last frame; clamp instead of failing.
            frame_idx = min(max(frame_idx, 0), num_frames - 1)
            word_feats.append(frame_feats[frame_idx])

        # Convert the whole sequence once rather than one call per token.
        pieces = tokenizer.convert_ids_to_tokens([int(token_id) for token_id in token_ids])

        padded = frame_feats.new_zeros(L, feat_dim)
        word_ptr = -1  # index of the word the current token belongs to; -1 before the first word
        for position, piece in enumerate(pieces[:L]):
            if piece.startswith("▁"):
                word_ptr += 1
            if word_ptr < 0 or not word_feats:
                continue  # leading non-word pieces, or no aligned words at all: keep zeros
            # More word-start tokens than aligned words: keep reusing the last word.
            padded[position] = word_feats[min(word_ptr, len(word_feats) - 1)]
        aligned_batch.append(padded)
    return torch.stack(aligned_batch, dim=0)

In [7]:
# ---------------------- Example Usage ---------------------- #

ALIGNMENT_JSONL = '/kaggle/input/fleurs-alignment/alignment_output_cleaned_deduplicated.jsonl'
TOKENIZER = 'ai4bharat/indictrans2-indic-en-1B'
INDIC_MODEL = 'ai4bharat/indictrans2-indic-en-1B'
W2V_MODEL = 'facebook/wav2vec2-base-960h'

dataset = FleursPunctuationDataset(ALIGNMENT_JSONL, TOKENIZER, INDIC_MODEL, W2V_MODEL)

# align_audio_to_tokens looks up a module-level `tokenizer`; expose the dataset's
# instance under that name so the lookup cannot fail with a NameError.
tokenizer = dataset.tokenizer


def collate_as_lists(samples):
    """Turn a list of sample dicts into a dict of lists, leaving every tensor unpadded."""
    return {key: [sample[key] for sample in samples] for key in samples[0]}


dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_as_lists)

tokenizer_config.json:   0%|          | 0.00/1.10k [00:00<?, ?B/s]

tokenization_indictrans.py:   0%|          | 0.00/8.04k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/ai4bharat/indictrans2-indic-en-1B:
- tokenization_indictrans.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


dict.SRC.json:   0%|          | 0.00/3.39M [00:00<?, ?B/s]

dict.TGT.json:   0%|          | 0.00/645k [00:00<?, ?B/s]

model.SRC:   0%|          | 0.00/3.26M [00:00<?, ?B/s]

model.TGT:   0%|          | 0.00/759k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/96.0 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/159 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/163 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.60k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/291 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/378M [00:00<?, ?B/s]

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


config.json:   0%|          | 0.00/1.37k [00:00<?, ?B/s]

configuration_indictrans.py:   0%|          | 0.00/14.2k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/ai4bharat/indictrans2-indic-en-1B:
- configuration_indictrans.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling_indictrans.py:   0%|          | 0.00/79.8k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/ai4bharat/indictrans2-indic-en-1B:
- modeling_indictrans.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors:   0%|          | 0.00/4.09G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/163 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/13.3k [00:00<?, ?B/s]

fleurs.py:   0%|          | 0.00/12.5k [00:00<?, ?B/s]

train.tar.gz:   0%|          | 0.00/1.38G [00:00<?, ?B/s]

dev.tar.gz:   0%|          | 0.00/171M [00:00<?, ?B/s]

test.tar.gz:   0%|          | 0.00/290M [00:00<?, ?B/s]

train.tsv:   0%|          | 0.00/1.41M [00:00<?, ?B/s]

dev.tsv:   0%|          | 0.00/213k [00:00<?, ?B/s]

test.tsv:   0%|          | 0.00/368k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [None]:
acoustic_encoder = AcousticEncoder().to(device)
model = nn.Sequential(
    acoustic_encoder,
    PunctuationModel(text_dim=1024, audio_dim=1024)
).to(device)

# Freeze first and build the optimizer afterwards; otherwise the requires_grad
# filter runs while every parameter is still trainable and selects all of them.
TRAINABLE_NAME_PARTS = ('proj_audio', 'fuse', 'context', 'classifier')
for name, param in model.named_parameters():
    if not any(part in name for part in TRAINABLE_NAME_PARTS):
        param.requires_grad = False
# NOTE(review): with the modules defined in this notebook only the `fuse` layer matches,
# so the randomly initialised AcousticEncoder stays frozen — confirm this is intended.

optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=2e-4)

train(model, dataloader, optimizer, num_epochs=5)