In [4]:
# @title 1. Install & Setup
!apt-get update && apt-get install -y libsndfile1 ffmpeg

# We use transformers for the stable MMS model loader
# We use sentencepiece for the Tokenizer
!pip install --upgrade transformers sentencepiece torch torchaudio soundfile jiwer

import torch
import torch.nn as nn
import torchaudio
import soundfile as sf
import sentencepiece as spm
import numpy as np
import json
import os
import math
import random
from pathlib import Path
from tqdm import tqdm
from transformers import Wav2Vec2Model

# Setup Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

0% [Working]            Hit:1 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
0% [Connecting to archive.ubuntu.com (91.189.91.81)] [Connecting to security.ub                                                                               Hit:2 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease
Hit:3 https://cli.github.com/packages stable InRelease
Hit:4 http://security.ubuntu.com/ubuntu jammy-security InRelease
Hit:5 http://archive.ubuntu.com/ubuntu jammy InRelease
Hit:6 http://archive.ubuntu.com/ubuntu jammy-updates InRelease
Hit:7 http://archive.ubuntu.com/ubuntu jammy-backports InRelease
Hit:8 https://r2u.stat.illinois.edu/ubuntu jammy InRelease
Hit:9 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:10 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Hit:11 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Reading package lists... Done
W: Skipping acq

In [None]:
# @title 2. Data Preparation & Tokenizer Training (JSON/CSV Support) - Fixed
import os
import shutil
import glob
import pandas as pd
import sentencepiece as spm
import random
import json
from google.colab import files
from pathlib import Path

# --- CONFIGURATION ---
VOCAB_SIZE = 1000  # Keep small for small datasets
DATA_DIR = "/content/my_custom_data"
# ---------------------

# 1. Clean previous runs
if os.path.exists(DATA_DIR):
    shutil.rmtree(DATA_DIR)
os.makedirs(DATA_DIR, exist_ok=True)

# 2. Upload Zip File
print("Please upload your ZIP file containing .wav files and the metadata (JSON or CSV)...")
uploaded = files.upload()

if not uploaded:
    raise ValueError("No file uploaded. Please run the cell again and upload a zip.")

zip_filename = list(uploaded.keys())[0]
print(f"Extracting {zip_filename}...")

# 3. Extract Zip
import zipfile
with zipfile.ZipFile(zip_filename, 'r') as zip_ref:
    zip_ref.extractall(DATA_DIR)

# 4. Find Metadata File (JSON or CSV)
metadata_files = list(Path(DATA_DIR).rglob("*.json")) + \
                 list(Path(DATA_DIR).rglob("*.jsonl")) + \
                 list(Path(DATA_DIR).rglob("*.csv"))

if not metadata_files:
    raise FileNotFoundError("Could not find a .json, .jsonl, or .csv file in the uploaded zip.")

meta_path = metadata_files[0]
print(f"Found Metadata File: {meta_path}")

# 5. Load Metadata into DataFrame
df = None

if meta_path.suffix == '.csv':
    df = pd.read_csv(meta_path)
elif meta_path.suffix == '.jsonl':
    df = pd.read_json(meta_path, lines=True)
elif meta_path.suffix == '.json':
    try:
        # Try standard list-of-dicts
        df = pd.read_json(meta_path)
    except ValueError:
        # If structure is complex (e.g. dict of dicts), try loading manually
        with open(meta_path, 'r') as f:
            data = json.load(f)
        # If it's a dict containing a list (e.g. {'data': [...]})
        if isinstance(data, dict):
            # Look for the first list value
            for k, v in data.items():
                if isinstance(v, list):
                    df = pd.DataFrame(v)
                    break
        if df is None:
             # Fallback: dict of entries
            df = pd.DataFrame.from_dict(data, orient='index')

if df is None:
    raise ValueError("Could not parse the JSON file structure.")

print("Metadata Head (First 5 rows):")
print(df.head())

# --- AUTO-DETECT COLUMNS ---
cols = df.columns.tolist()
filename_col = None
text_col = None

# Heuristic: find column names resembling 'path' or 'text'
for c in cols:
    c_lower = str(c).lower()
    if any(x in c_lower for x in ['path', 'file', 'audio', 'id', 'wav']):
        filename_col = c
    if any(x in c_lower for x in ['text', 'transcript', 'sentence', 'caption']):
        text_col = c

# Fallback defaults
if filename_col is None: filename_col = cols[0]
if text_col is None: text_col = cols[1]

print(f"Using Column '{filename_col}' for Audio Paths.")
print(f"Using Column '{text_col}' for Transcriptions.")

# 6. Build Manifest & Corpus
manifest = []
corpus_file = "corpus.txt"

# Map all wav files found in the extracted folder
wav_map = {p.name: p for p in Path(DATA_DIR).rglob("*.wav")}

# --- FIXED LINE BELOW ---
print(f"Found {len(wav_map)} .wav files in extracted folder.")

missing_count = 0
with open(corpus_file, "w") as txt_out:
    for idx, row in df.iterrows():
        # Get filename and clean it
        fname = str(row[filename_col]).strip()
        # Add extension if missing (common in some JSON datasets)
        if not fname.lower().endswith('.wav'):
            fname += '.wav'

        # Get text
        text = str(row[text_col]).strip().lower()

        # Match audio file
        fname_base = os.path.basename(fname)

        if fname_base in wav_map:
            full_path = str(wav_map[fname_base])

            # Write for Tokenizer
            txt_out.write(text + "\n")

            # Add to manifest
            manifest.append({
                "audio": full_path,
                "text": text
            })
        else:
            missing_count += 1

print(f"Successfully matched {len(manifest)} audio-text pairs.")
if missing_count > 0:
    print(f"Warning: {missing_count} entries in metadata did not match any WAV file.")

# 7. Train Tokenizer
print(f"Training BPE Tokenizer...")

# Dynamic vocab size for small data
estimated_vocab = VOCAB_SIZE
if len(manifest) < 50:
    estimated_vocab = 100
    print(f"Small dataset detected. Using vocab size: {estimated_vocab}")

spm.SentencePieceTrainer.train(
    input=corpus_file,
    model_prefix='spm_bpe',
    vocab_size=estimated_vocab,
    model_type='bpe',
    pad_id=0, unk_id=1, bos_id=2, eos_id=3,
    pad_piece='<pad>', unk_piece='<unk>', bos_piece='<s>', eos_piece='</s>',
    user_defined_symbols=[],
    hard_vocab_limit=False
)

# Load Tokenizer
tokenizer = spm.SentencePieceProcessor(model_file='spm_bpe.model')
print("Tokenizer ready.")

# 8. Split Data
if len(manifest) < 2:
    print("Warning: Only 1 sample found. Using it for both Train and Valid.")
    train_data = manifest
    valid_data = manifest
else:
    random.shuffle(manifest)
    split = int(len(manifest) * 0.8)
    if split == 0: split = 1
    train_data = manifest[:split]
    valid_data = manifest[split:]
    if not valid_data: valid_data = train_data

print(f"Train samples: {len(train_data)}, Valid samples: {len(valid_data)}")

Please upload your ZIP file containing .wav files and the metadata (JSON or CSV)...


In [6]:
# @title 3. Build Generative ASR Model
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class GenerativeASR(nn.Module):
    def __init__(self, encoder_name, vocab_size, embed_dim=512, num_decoder_layers=4):
        super().__init__()

        # 1. Encoder: Load MMS-300M from Hugging Face
        print(f"Loading Encoder: {encoder_name}...")
        self.encoder = Wav2Vec2Model.from_pretrained(encoder_name)

        # Freeze encoder to save memory (Optional: set to True to fine-tune encoder too)
        self.encoder.feature_extractor._requires_grad = False

        # 2. Adapter: Project Encoder Dim (1024) -> Decoder Dim (512)
        enc_dim = self.encoder.config.hidden_size
        self.enc_to_dec_proj = nn.Linear(enc_dim, embed_dim)

        # 3. Decoder: Standard Transformer
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.pos_encoder = PositionalEncoding(embed_dim)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim, nhead=8, dim_feedforward=2048, dropout=0.1, batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        # 4. Output Head
        self.output_proj = nn.Linear(embed_dim, vocab_size)

    def forward(self, audio_values, audio_attn_mask, text_input_ids, text_pad_mask):
        """
        audio_values: [Batch, Time]
        audio_attn_mask: [Batch, Time] (1=Valid, 0=Pad) - Standard HF format
        text_input_ids: [Batch, Seq]
        text_pad_mask: [Batch, Seq] (True=Pad) - Standard PyTorch format
        """

        # --- Encode Audio ---
        # HF Wav2Vec2 forward
        enc_out = self.encoder(audio_values, attention_mask=audio_attn_mask)
        enc_hidden = enc_out.last_hidden_state # [Batch, AudioFrames, 1024]

        # Project to Decoder Size
        memory = self.enc_to_dec_proj(enc_hidden) # [Batch, AudioFrames, 512]

        # --- Decode Text ---
        # Embed and Add Position Info
        tgt = self.embed(text_input_ids) # [Batch, Seq, 512]
        tgt = self.pos_encoder(tgt)

        # Generate Causal Mask (Prevent looking ahead)
        seq_len = tgt.size(1)
        causal_mask = torch.triu(torch.ones(seq_len, seq_len) * float('-inf'), diagonal=1).to(device)

        # Run Decoder
        # Note: We skip memory_key_padding_mask for simplicity here (model learns to ignore padded audio features)
        output = self.decoder(
            tgt=tgt,
            memory=memory,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=text_pad_mask
        )

        logits = self.output_proj(output) # [Batch, Seq, Vocab]
        return logits

# Initialize
# We use 'facebook/mms-300m' which is robust.
model = GenerativeASR("facebook/mms-300m", vocab_size=VOCAB_SIZE).to(device)
print("Model initialized successfully.")

Loading Encoder: facebook/mms-300m...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

Model initialized successfully.


In [9]:
# @title 4. Training Loop
# Hyperparams
BATCH_SIZE = 4
LR = 1e-4
EPOCHS = 100

optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss(ignore_index=0) # Ignore Padding

# Collate Function (Batches variable length data)
def collate_fn(batch):
    audio_list, text_in_list, text_out_list = [], [], []

    for item in batch:
        # Load Audio
        wav, sr = sf.read(item['audio'])
        # Ensure 16k
        if sr != 16000:
            # In real code, use torchaudio.transforms.Resample here
            pass
        audio_list.append(torch.tensor(wav, dtype=torch.float32))

        # Tokenize Text
        # Input: <s> + text
        # Target: text + </s>
        ids = tokenizer.encode(item['text'], out_type=int)
        text_in_list.append(torch.tensor([tokenizer.bos_id()] + ids, dtype=torch.long))
        text_out_list.append(torch.tensor(ids + [tokenizer.eos_id()], dtype=torch.long))

    # Pad Audio
    audio_padded = torch.nn.utils.rnn.pad_sequence(audio_list, batch_first=True).to(device)

    # Create Audio Attention Mask (1 for Valid, 0 for Pad)
    # We first create lengths, then a mask
    audio_lens = [len(x) for x in audio_list]
    max_len = max(audio_lens)
    audio_mask = torch.zeros(len(audio_list), max_len).to(device)
    for i, l in enumerate(audio_lens):
        audio_mask[i, :l] = 1

    # Pad Text (Padding Value = 0)
    text_in_padded = torch.nn.utils.rnn.pad_sequence(text_in_list, batch_first=True, padding_value=0).to(device)
    text_out_padded = torch.nn.utils.rnn.pad_sequence(text_out_list, batch_first=True, padding_value=0).to(device)

    # Create Text Padding Mask (True where Pad exists)
    text_pad_mask = (text_in_padded == 0)

    return audio_padded, audio_mask, text_in_padded, text_out_padded, text_pad_mask

# Train
model.train()
print("Starting training...")

for epoch in range(EPOCHS):
    random.shuffle(train_data)
    total_loss = 0

    # Batch Iterator
    for i in range(0, len(train_data), BATCH_SIZE):
        batch = train_data[i : i+BATCH_SIZE]
        if not batch: break

        # Prepare Inputs
        aud, aud_mask, txt_in, txt_out, txt_mask = collate_fn(batch)

        # Forward
        logits = model(aud, aud_mask, txt_in, txt_mask)

        # Loss (Flatten [Batch*Seq, Vocab])
        loss = criterion(logits.reshape(-1, VOCAB_SIZE), txt_out.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {total_loss / (len(train_data)/BATCH_SIZE):.4f}")

print("Training Complete.")

Starting training...
Epoch 1/100 | Loss: 11.1617
Epoch 2/100 | Loss: 10.4625
Epoch 3/100 | Loss: 9.7826
Epoch 4/100 | Loss: 9.2090
Epoch 5/100 | Loss: 8.6998
Epoch 6/100 | Loss: 8.1564
Epoch 7/100 | Loss: 7.6395
Epoch 8/100 | Loss: 7.1127
Epoch 9/100 | Loss: 6.6067
Epoch 10/100 | Loss: 6.0991
Epoch 11/100 | Loss: 5.5642
Epoch 12/100 | Loss: 5.0738
Epoch 13/100 | Loss: 4.6488
Epoch 14/100 | Loss: 4.2573
Epoch 15/100 | Loss: 3.8260
Epoch 16/100 | Loss: 3.3988
Epoch 17/100 | Loss: 3.0443
Epoch 18/100 | Loss: 2.6928
Epoch 19/100 | Loss: 2.4160
Epoch 20/100 | Loss: 2.1036
Epoch 21/100 | Loss: 1.8639
Epoch 22/100 | Loss: 1.6350
Epoch 23/100 | Loss: 1.4354
Epoch 24/100 | Loss: 1.2428
Epoch 25/100 | Loss: 1.1121
Epoch 26/100 | Loss: 0.9440
Epoch 27/100 | Loss: 0.8675
Epoch 28/100 | Loss: 0.7462
Epoch 29/100 | Loss: 0.6390
Epoch 30/100 | Loss: 0.5534
Epoch 31/100 | Loss: 0.4942
Epoch 32/100 | Loss: 0.4370
Epoch 33/100 | Loss: 0.3888
Epoch 34/100 | Loss: 0.3402
Epoch 35/100 | Loss: 0.3126
Epoch 

In [10]:
# @title 5. Inference / Testing
def transcribe(audio_path, model, tokenizer, max_tokens=50):
    model.eval()

    # 1. Load and Prepare Audio
    wav, sr = sf.read(audio_path)
    audio_tensor = torch.tensor(wav, dtype=torch.float32).unsqueeze(0).to(device) # [1, Time]
    audio_mask = torch.ones_like(audio_tensor).to(device) # All valid

    with torch.no_grad():
        # 2. Encode Audio (Once)
        enc_out = model.encoder(audio_tensor, attention_mask=audio_mask)
        enc_hidden = enc_out.last_hidden_state
        memory = model.enc_to_dec_proj(enc_hidden) # [1, Seq, 512]

        # 3. Autoregressive Loop
        # Start with <s> (BOS)
        curr_tokens = [tokenizer.bos_id()]

        for _ in range(max_tokens):
            # Prepare text input tensor
            tgt_in = torch.tensor([curr_tokens], dtype=torch.long).to(device)

            # Embed
            tgt_emb = model.embed(tgt_in)
            tgt_emb = model.pos_encoder(tgt_emb)

            # Causal Mask
            seq_len = tgt_in.size(1)
            causal_mask = torch.triu(torch.ones(seq_len, seq_len) * float('-inf'), diagonal=1).to(device)

            # Decode
            # Note: We pass the WHOLE sequence so far.
            # (Optimized inference would cache Key/Values, but this is safer for simple scripts)
            out = model.decoder(tgt=tgt_emb, memory=memory, tgt_mask=causal_mask)

            # Get logits of the LAST token only
            last_token_logits = model.output_proj(out[:, -1, :])

            # Greedy Choice
            next_token_id = torch.argmax(last_token_logits, dim=-1).item()

            # Check EOS
            if next_token_id == tokenizer.eos_id():
                break

            curr_tokens.append(next_token_id)

    # Decode to string
    # Skip the first BOS token
    return tokenizer.decode(curr_tokens[1:])

# Run on a validation sample
if len(valid_data) > 0:
    sample = valid_data[0]
else:
    sample = train_data[0] # Fallback if valid empty

print(f"--- Inference Test ---")
print(f"File: {sample['audio']}")
print(f"Ground Truth: {sample['text']}")
predicted = transcribe(sample['audio'], model, tokenizer)
print(f"Prediction:   {predicted}")

--- Inference Test ---
File: /content/my_custom_data/dataset/audio_files/588001011271948_chL_seg002.wav
Ground Truth: <hesitation> bu tahminiy bo'lsa ham qachondan qachongacha bo'ldi masalan bir yilmi yo uch oymi osha muddatini bilsam bo'ladimi qachon boshlangan.
Prediction:   volte deganda xurmatli abonent bu xizmat aynan tort ji tarmog'i asosida aloqa va internet xizmatlar
