In [1]:
# !pip install -q transformers jiwer torchaudio jsonlines datasets accelerate audiomentations # Audio Augmentation
# !pip install -q Cython
# !pip install openai-whisper

In [1]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

from typing import Optional, Dict, Union, List
from dataclasses import dataclass

import numpy as np
import random
import re
import json
import jsonlines
# from tqdm import tqdm

from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift
from jiwer import wer
import whisper

import torch
from torch.utils.data import IterableDataset, DataLoader
import torchaudio
from torchaudio import transforms
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, Seq2SeqTrainingArguments, Seq2SeqTrainer

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

from tqdm import tqdm


print("CUDA available:", torch.cuda.is_available())
print("Number of GPUs:", torch.cuda.device_count())
print("GPU Name:", torch.cuda.get_device_name(0))

  from .autonotebook import tqdm as notebook_tqdm


CUDA available: True
Number of GPUs: 1
GPU Name: NVIDIA GeForce RTX 3060 Laptop GPU


### Defining Directories

In [13]:
cur_dir = os.getcwd()
src_dir = os.path.dirname(cur_dir)
til_dir = os.path.dirname(os.path.dirname(src_dir))
home_dir = os.path.dirname(til_dir)
novice_dir = os.path.join(home_dir, 'novice')
audio_dir = os.path.join(novice_dir, 'audio')
data_dir = os.path.join(cur_dir, 'data')
model_path = os.path.join(src_dir, "models", "whisper")
metadata_path = os.path.join(novice_dir, "asr.jsonl")

# paths for converting datasets to manifest files
train_dir = os.path.join(data_dir, "train")
test_dir = os.path.join(data_dir, "test")
val_dir = os.path.join(data_dir, "val")

novice_dir

'/home/rachtrx/workspace/til-ai/novice'

### Split

In [15]:
def split_data(data, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=42):
    
    random.seed(seed)

    total_examples = len(data['audios'])
    indices = list(range(total_examples))
    random.shuffle(indices)
    
    train_end = int(train_ratio * total_examples)
    val_end = train_end + int(val_ratio * total_examples)
    
    train_indices = indices[:train_end]
    val_indices = indices[train_end:val_end]
    test_indices = indices[val_end:]
    
    train_data = {'audios': [data['audios'][i] for i in train_indices], 'sentences': [data['sentences'][i] for i in train_indices]}
    val_data = {'audios': [data['audios'][i] for i in val_indices], 'sentences': [data['sentences'][i] for i in val_indices]}
    test_data = {'audios': [data['audios'][i] for i in test_indices], 'sentences': [data['sentences'][i] for i in test_indices]}
    
    return train_data, val_data, test_data

MAX_FILE_COUNT = None # Set if only want max files

data = {'audios': [], 'sentences': []}
data_path = os.path.join(novice_dir, "asr.jsonl")
with jsonlines.open(metadata_path) as reader:
    for obj in reader:
        if MAX_FILE_COUNT and len(data['audio']) >= MAX_FILE_COUNT:
            break
        data['audios'].append(os.path.join(audio_dir, obj['audio']))
        data['sentences'].append(obj['transcript'])

train_data, val_data, test_data = split_data(data)

In [16]:
class AudioPreprocessor:
    
    def __init__(self, dataset, output_dir, tokenizer, batch_size=4, max_length=30):
        self.dataset = dataset
        self.output_dir = output_dir
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.augmentations = Compose([
            AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.15),
            TimeStretch(min_rate=0.8, max_rate=1.25, p=0.15),
            PitchShift(min_semitones=-4, max_semitones=4, p=0.15),
        ])
        self.padding_audio = np.zeros((80, 3000))  # Example shape for mel spectrogram
        self.padding_sentence = [self.tokenizer.pad_token_id] * self.max_length

    def preprocess_data(self):
        audios = self.dataset['audios']
        sentences = self.dataset['sentences']
        num_batches = (len(audios) + self.batch_size - 1) // self.batch_size
        
        for batch_idx in tqdm(range(num_batches), desc="Processing Batches"):
            batch_audio = audios[batch_idx * self.batch_size:(batch_idx + 1) * self.batch_size]
            batch_sentences = sentences[batch_idx * self.batch_size:(batch_idx + 1) * self.batch_size]
            batch_data = list(zip(batch_audio, batch_sentences))
            
            input_ids_arr, decoded_input_ids_arr, labels_arr = self.process_batch(batch_data)
            input_ids_arr, decoded_input_ids_arr, labels_arr = self.pad_batch(input_ids_arr, decoded_input_ids_arr, labels_arr)
            self.save_batch(batch_idx, input_ids_arr, decoded_input_ids_arr, labels_arr)
        return num_batches

    def process_batch(self, batch_data):
        
        input_ids_arr = []
        decoded_input_ids_arr = []
        labels_arr = []
        
        for file_path, sentence in batch_data:
        
            sample = self.load_audio(file_path)
            # if self.transform:
            #     input_ids = self.transform(sample['input_ids'])
            # else:
            input_ids = sample['input_ids']

            tokenized_output = self.tokenizer(
                sentence,
                padding='max_length',  # Pad to max_length
                max_length=self.max_length,  # Specify the maximum length
                truncation=True,  # Truncate if longer than max_length
                return_tensors='pt'  # Return PyTorch tensors
            )
            labels = tokenized_output['input_ids'][0].numpy().tolist()  # Convert tensor to list

            # Generate decoder_input_ids from labels by shifting them to the right
            decoder_input_ids = [self.tokenizer.pad_token_id] + labels[:-1]
            
            assert len(decoder_input_ids) == len(labels)
            
            input_ids_arr.append(input_ids)
            decoded_input_ids_arr.append(decoder_input_ids)
            labels_arr.append(labels)
            
        return input_ids_arr, decoded_input_ids_arr, labels_arr

    def load_audio(self, file_path):
        waveform, sample_rate = torchaudio.load(file_path)
        waveform = waveform.numpy().flatten() # waveform is a list
        
        # Apply augmentations if provided
        waveform = self.augmentations(samples=waveform, sample_rate=sample_rate)

        # Resample if needed
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform = resampler(torch.tensor(waveform)).numpy().flatten()
            
        # Compute log-mel spectrogram
        input_ids = self.to_pad_to_mel(waveform)

        # Extract audio features
        return { 'input_ids': input_ids }
    
    def save_batch(self, batch_idx, input_ids_arr, decoded_input_ids_arr, labels_arr):
        batch_output_dir = os.path.join(self.output_dir, f"batch_{batch_idx}")
        os.makedirs(batch_output_dir, exist_ok=True)

        # Save input_ids
        input_ids_path = os.path.join(batch_output_dir, "input_ids.npy")
        np.save(input_ids_path, input_ids_arr)

        # Save decoder_input_ids
        decoder_input_ids_path = os.path.join(batch_output_dir, "decoder_input_ids.npy")
        np.save(decoder_input_ids_path, decoded_input_ids_arr)

        # Save labels
        labels_path = os.path.join(batch_output_dir, "labels.npy")
        np.save(labels_path, labels_arr)
        
    def pad_batch(self, input_ids_arr, decoded_input_ids_arr, labels_arr):
        while len(input_ids_arr) < self.batch_size:
            input_ids_arr.append(self.padding_audio)
            decoded_input_ids_arr.append(self.padding_sentence)
            labels_arr.append(self.padding_sentence)
        
        return input_ids_arr, decoded_input_ids_arr, labels_arr
    
    ## Referred to https://huggingface.co/sanchit-gandhi/whisper-medium-switchboard-5k/blob/main/run_speech_recognition_whisper.py by sanchit-gandhi
    @staticmethod
    def to_pad_to_mel(array):
        """Static function which:
            1. Pads/trims a list of audio arrays to a max length of 30s
            2. Computes log-mel filter coefficients from padded/trimmed audio sequences
            Inputs:
                array: list of audio arrays
            Returns:
                input_ids: torch.tensor of log-mel filter bank coefficients
        """
        padded_input = whisper.pad_or_trim(np.asarray(array, dtype=np.float32))
        input_ids = whisper.log_mel_spectrogram(padded_input)
        return input_ids

In [17]:
# model_name = "distil-whisper/distil-medium.en"  # You can change this to any model you want to use
# save_directory = "../models/whisper"  # Path to save the model and processor

# model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name, torch_dtype=torch.float16, low_cpu_mem_usage=True)
# model.save_pretrained(save_directory)
# processor = AutoProcessor.from_pretrained(model_name)
# processor.save_pretrained(save_directory)

model_path = "../models/whisper"  # Path where the model and processor are saved
# Load the model
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
model = model.float()
# Load the processor
processor = AutoProcessor.from_pretrained(model_path)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [18]:
train_processor = AudioPreprocessor(train_data, output_dir=train_dir, tokenizer=processor.tokenizer)
val_processor = AudioPreprocessor(val_data, output_dir=val_dir, tokenizer=processor.tokenizer)
test_processor = AudioPreprocessor(test_data,output_dir=test_dir, tokenizer=processor.tokenizer)

In [19]:
train_processor.preprocess_data()
val_processor.preprocess_data()
test_processor.preprocess_data()

Processing Batches: 100%|██████████| 700/700 [06:06<00:00,  1.91it/s]
Processing Batches: 100%|██████████| 88/88 [00:47<00:00,  1.85it/s]
Processing Batches: 100%|██████████| 88/88 [00:48<00:00,  1.82it/s]


88

### Setup Custom Classes

In [20]:
# Define your ASR model
class ASRModel(pl.LightningModule):
    def __init__(self, model, processor):
        super().__init__()
        self.model = model
        self.processor = processor
        self.loss_function = torch.nn.CrossEntropyLoss()
        
    def forward(self, input_ids, decoder_input_ids=None):
        return self.model(input_features=input_ids, decoder_input_ids=decoder_input_ids)
    
    def compute_wer(self, logits, labels):
        pred_ids = torch.argmax(logits, dim=-1)
        pred_str = self.processor.batch_decode(pred_ids)
        label_str = self.processor.batch_decode(labels, skip_special_tokens=True)
        return wer(label_str, pred_str), pred_str, label_str

    def training_step(self, batch, batch_idx):
        input_ids, decoder_input_ids, labels = batch
        outputs = self(input_ids, decoder_input_ids=decoder_input_ids)
        logits = outputs.logits
        
        # Reshape logits to (batch_size * sequence_length, vocab_size)
        logits = logits.view(-1, logits.size(-1))
        labels = labels.view(-1)

        loss = self.loss_function(logits, labels)
        self.log('train_loss', loss)
        
        wer_value = self.compute_wer(logits, labels)[0]
        self.log('train_wer', wer_value, prog_bar=True)
        
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids, decoder_input_ids, labels = batch
        outputs = self(input_ids, decoder_input_ids=decoder_input_ids)
        logits = outputs.logits
        
        # Reshape logits to (batch_size * sequence_length, vocab_size)
        logits = logits.view(-1, logits.size(-1))
        labels = labels.view(-1)

        loss = self.loss_function(logits, labels)
        self.log('val_loss', loss)
        
        wer_value = self.compute_wer(logits, labels)[0]
        self.log('val_wer', wer_value, prog_bar=True)
        
        return loss
    
    def test_step(self, batch):
        input_ids, decoder_input_ids, labels = batch
        self.test_results = []
        with torch.no_grad():
            outputs = self(input_ids, decoder_input_ids=decoder_input_ids)
            logits = outputs.logits
            
            # Reshape logits to (batch_size * sequence_length, vocab_size)
            logits = logits.view(-1, logits.size(-1))
            labels = labels.view(-1)

            test_loss = self.loss_function(logits, labels)
            self.log('test_loss', test_loss)

            wer_value, pred_str, label_str = self.compute_wer(logits, labels)

            # Store results
            for pred, actual in zip(pred_str, label_str):
                self.test_results.append({'predicted': pred, 'actual': actual})
            
            wer_value = self.compute_wer(logits, labels)[0]
            self.log('test_wer', wer_value, prog_bar=True)
            
            return test_loss

    def configure_optimizers(self):
        # Implement your optimizer configuration here
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        return optimizer
    
    @classmethod
    def load_from_checkpoint(cls, checkpoint_path, model, processor):
        # Load the checkpoint
        checkpoint = torch.load(checkpoint_path)
        # Initialize the model
        instance = cls(model, processor)
        # Load the state dict into the model
        instance.load_state_dict(checkpoint['state_dict'])
        return instance
    
class ASRIterableDataset(IterableDataset):
    def __init__(self, data, tokenizer):
        self.type_dir, self.num_batches = data
        self.tokenizer = tokenizer

    def __iter__(self):
        device = torch.device("cuda")  # Define the device as GPU
        for batch_idx in range(self.num_batches):
            batch_output_dir = os.path.join(self.type_dir, f"batch_{batch_idx}")

            # Load input_ids
            input_ids_path = os.path.join(batch_output_dir, "input_ids.npy")
            input_ids_arr = np.load(input_ids_path)

            # Load decoder_input_ids
            decoded_input_ids_path = os.path.join(batch_output_dir, "decoder_input_ids.npy")
            decoded_input_ids_arr = np.load(decoded_input_ids_path)

            # Load labels
            labels_path = os.path.join(batch_output_dir, "labels.npy")
            labels_arr = np.load(labels_path)

            # Convert to tensors, adjust data types, and move to GPU
            input_ids = torch.tensor(input_ids_arr, dtype=torch.float16).to(device)
            decoder_input_ids = torch.tensor(decoded_input_ids_arr, dtype=torch.long).to(device)
            labels = torch.tensor(labels_arr, dtype=torch.long).to(device)

            yield input_ids, decoder_input_ids, labels

class ASRDataModule(pl.LightningDataModule):
    def __init__(self, tokenizer, train_data, val_data, test_data, num_workers=0):
        super().__init__()
        self.tokenizer = tokenizer # can just use the global one?
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data
        self.num_workers = num_workers

    def setup(self, stage=None):
        self.train_dataset = ASRIterableDataset(self.train_data, self.tokenizer)
        self.val_dataset = ASRIterableDataset(self.val_data, self.tokenizer,)
        self.test_dataset = ASRIterableDataset(self.test_data, self.tokenizer)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=None, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=None, num_workers=self.num_workers)
    
    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=None, num_workers=self.num_workers)

## Set Configs and Run

In [21]:
early_stopping_callback = EarlyStopping(
    monitor='val_loss',  # metric to monitor
    patience=3,          # no of epochs with no improvement to wait before stopping
    verbose=True,        # logging
    mode='min'           # minimize or maximize the monitored metric
)

# Initialize Trainer with model checkpointing
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    dirpath='model_checkpoints',
    filename='asr_model-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
)

trainer = pl.Trainer(
    max_steps=700*100,  # Maximum number of steps (batches) to train for
    callbacks=[checkpoint_callback, early_stopping_callback],
    val_check_interval=700,
    limit_val_batches=88,  # Limit the number of validation batches
)

torch.set_float32_matmul_precision('medium')

# model_path = "../models/whisper"  # Path where the model and processor are saved
# # Load the model
# model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
# # Load the processor
# processor = AutoProcessor.from_pretrained(model_path)

data_module = ASRDataModule(
    tokenizer=processor.tokenizer,
    train_data=(train_dir, 700),
    val_data=(val_dir, 88),
    test_data=(test_dir, 88),
    num_workers=4,
    # batch_size=1, # Removed param as setting to 2 causes errors, probably due to IterableDataset? Perhaps need to manually handle using arrays in Dataset class and update collate function.
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [22]:
asr_model = ASRModel(model, processor)
asr_model.to('cuda')

# Train the model
trainer.fit(asr_model, data_module) # pl.LightningDataModule can be 2nd parameter

# Test the model
trainer.test(asr_model, data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                            | Params
------------------------------------------------------------------
0 | model         | WhisperForConditionalGeneration | 394 M 
1 | loss_function | CrossEntropyLoss                | 0     
------------------------------------------------------------------
392 M     Trainable params
1.5 M     Non-trainable params
394 M     Total params
1,577.501 Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/rachtrx/mambaforge/envs/til-ai/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/home/rachtrx/mambaforge/envs/til-ai/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 41, in fetch
    data = next(self.dataset_iter)
  File "/tmp/ipykernel_2998/585947386.py", line 115, in __iter__
    input_ids = torch.tensor(input_ids_arr, dtype=torch.float16).to(device)
  File "/home/rachtrx/mambaforge/envs/til-ai/lib/python3.9/site-packages/torch/cuda/__init__.py", line 279, in _lazy_init
    raise RuntimeError(
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method


#### Old notes
Maximum length for padding: 219847
<br>
Use Max Length = 220000, which is around 13.75s for a video at 16000 samples/s 

In [10]:
# max_length = calculate_max_length(dataset, audio_dir)
# print(f"Maximum length for padding: {max_length}")

In [15]:
checkpoint_path = 'model_checkpoints/asr_model-epoch=04-val_loss=0.61.ckpt'

# Load the model from the checkpoint
asr_model = ASRModel.load_from_checkpoint(checkpoint_path, model=model, processor=processor)
asr_model.to(device)

# Initialize Trainer for testing (no callbacks needed)
trainer = pl.Trainer()

# Test the model
trainer.test(asr_model, data_module)

# Print the results
for result in asr_model.test_results:
    print(f"Predicted: {result['predicted']}, Actual: {result['actual']}")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

Predicted: <|startoftranscript|>, Actual: <|startoftranscript|>
Predicted: <|notimestamps|>, Actual: <|notimestamps|>
Predicted: Head, Actual: Head
Predicted: ing, Actual: ing
Predicted:  is, Actual:  is
Predicted:  one, Actual:  three
Predicted:  one, Actual:  one
Predicted:  five, Actual:  five
Predicted: ,, Actual: ,
Predicted:  target, Actual:  target
Predicted:  is, Actual:  is
Predicted:  red, Actual:  black
Predicted:  and, Actual: ,
Predicted:  black, Actual:  green
Predicted: ,, Actual: ,
Predicted:  and, Actual:  and
Predicted:  green, Actual:  grey
Predicted:  fighter, Actual:  drone
Predicted: ,, Actual: ,
Predicted:  tool, Actual:  tool
Predicted:  to, Actual:  to
Predicted:  deploy, Actual:  deploy
Predicted:  is, Actual:  is
Predicted:  surface, Actual:  drone
Predicted:  catcher, Actual:  catcher
Predicted: ., Actual: .
Predicted: <|endoftext|>, Actual: <|endoftext|>
Predicted: <|endoftext|>, Actual: <|endoftext|>
