In [1]:
# !pip install -q transformers jiwer torchaudio jsonlines datasets accelerate audiomentations # Audio Augmentation
# !pip install -q Cython
# !pip install openai-whisper

In [10]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

from typing import Optional, Dict, Union, List
from dataclasses import dataclass

import numpy as np
import random
import re
import json
import jsonlines
# from tqdm import tqdm

from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift
from jiwer import wer
import whisper

import torch
from torch.utils.data import IterableDataset, DataLoader
import torchaudio
from torchaudio import transforms
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, Seq2SeqTrainingArguments, Seq2SeqTrainer

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping


# Report the accelerator setup. get_device_name(0) raises when no CUDA device exists,
# so only query it when CUDA is available (keeps the notebook runnable on CPU).
print("CUDA available:", torch.cuda.is_available())
print("Number of GPUs:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("GPU Name:", torch.cuda.get_device_name(0))
else:
    print("GPU Name: n/a (no CUDA device)")

CUDA available: True
Number of GPUs: 1
GPU Name: Tesla T4


### Defining Directories

In [3]:
# Resolve project paths relative to the notebook's working directory.
cur_dir = os.getcwd()
src_dir = os.path.dirname(cur_dir)
til_dir = os.path.dirname(os.path.dirname(src_dir))
home_dir = os.path.dirname(til_dir)

# Source dataset: audio clips plus the asr.jsonl transcript metadata. It has its own name
# so that the split directory `test_dir` defined below does not silently overwrite it.
novice_dir = os.path.join(home_dir, 'novice')
audio_dir = os.path.join(novice_dir, 'audio')
metadata_path = os.path.join(novice_dir, "asr.jsonl")

data_dir = os.path.join(cur_dir, 'data')
model_path = os.path.join(src_dir, "models", "whisper")

# paths for converting datasets to manifest files
train_dir = os.path.join(data_dir, "train")
test_dir = os.path.join(data_dir, "test")
val_dir = os.path.join(data_dir, "val")

metadata_path

'/home/jupyter/novice/asr.jsonl'

### Train / Validation / Test Split

In [5]:
def split_data_indices(data, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=42):
    """Randomly split parallel audio/transcript lists into train/val/test subsets.

    Args:
        data: dict with equally long ``'audio'`` (file paths) and ``'sentence'``
            (transcripts) lists.
        train_ratio: fraction of examples assigned to the training split.
        val_ratio: fraction of examples assigned to the validation split.
        test_ratio: accepted for API symmetry only. The test split always receives every
            example left over after train and validation, so it also absorbs the
            rounding remainder of the ``int()`` truncations.
        seed: value passed to ``random.seed``. NOTE: this reseeds Python's *global*
            ``random`` module, which other code in the process also draws from.

    Returns:
        ``(train_data, val_data, test_data)``, each a dict with ``'audio'`` and
        ``'sentence'`` lists whose entries stay paired.

    Raises:
        ValueError: if the two lists differ in length, ``train_ratio``/``val_ratio`` is
            negative, or ``train_ratio + val_ratio`` exceeds 1.
    """
    audio, sentences = data['audio'], data['sentence']
    # Mismatched lists would otherwise silently drop extra transcripts or raise an
    # opaque IndexError halfway through the comprehension.
    if len(audio) != len(sentences):
        raise ValueError(
            f"'audio' has {len(audio)} entries but 'sentence' has {len(sentences)}"
        )
    if train_ratio < 0 or val_ratio < 0:
        raise ValueError("train_ratio and val_ratio must be non-negative")
    if train_ratio + val_ratio > 1 + 1e-9:  # small tolerance for float rounding
        raise ValueError("train_ratio + val_ratio must not exceed 1")

    random.seed(seed)
    total_examples = len(audio)
    indices = list(range(total_examples))
    random.shuffle(indices)

    train_end = int(train_ratio * total_examples)
    val_end = train_end + int(val_ratio * total_examples)

    def _subset(idx):
        # Build one split, keeping each audio path aligned with its transcript.
        return {'audio': [audio[i] for i in idx],
                'sentence': [sentences[i] for i in idx]}

    return (_subset(indices[:train_end]),
            _subset(indices[train_end:val_end]),
            _subset(indices[val_end:]))

In [6]:
## Helpers adapted from an external Whisper fine-tuning example (TODO: record the source link)

def to_pad_to_mel(array):
    """Convert one audio waveform into Whisper log-mel input features.

    The waveform is cast to a float32 numpy array, padded or trimmed by
    ``whisper.pad_or_trim`` to Whisper's fixed 30 s input window, then passed to
    ``whisper.log_mel_spectrogram``.

    Args:
        array: a single audio waveform (anything ``np.asarray`` accepts). Callers
            invoke this once per clip, not on a list of clips.

    Returns:
        The log-mel spectrogram tensor produced by ``whisper.log_mel_spectrogram``.
    """
    waveform = np.asarray(array, dtype=np.float32)
    fixed_length_waveform = whisper.pad_or_trim(waveform)
    return whisper.log_mel_spectrogram(fixed_length_waveform)

@dataclass
class WhisperDataCollatorWithPadding:
    """
    Data collator that turns a list of dataset samples into a padded Whisper batch.

    Each feature is a dict with:
        - ``"input_ids"``: one raw audio waveform, converted here to log-mel features by
          ``to_pad_to_mel`` (padded/trimmed to a fixed 30 s window),
        - ``"decoder_input_ids"``: list of token ids fed to the decoder,
        - ``"labels"``: list of target token ids.

    An EOS token is appended to both token sequences, which are then right-padded to the
    longest sequence in the batch.

    Args:
        eos_token_id (`int`):
            The end-of-sentence token for the Whisper tokenizer, appended to every decoder
            input and label sequence so sequences terminate before generation max length.
        pad_token_id (`int`, *optional*):
            Token id used to pad ``decoder_input_ids``; defaults to ``eos_token_id``.
            Decoder inputs are looked up in an embedding table, so they must be valid
            vocabulary ids. Only ``labels`` are padded with -100.
    """

    eos_token_id: int
    pad_token_id: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Collate ``features`` into ``labels``, ``decoder_input_ids`` and ``input_ids`` tensors."""
        # Audio and token sequences have different lengths and padding rules, so handle
        # them separately.
        waveforms = [feature["input_ids"] for feature in features]
        decoder_input_ids = [list(feature["decoder_input_ids"]) + [self.eos_token_id] for feature in features]
        # NOTE(review): if the tokenizer already ends sequences with EOS, labels end with
        # two EOS tokens here. Confirm this is intended.
        labels = [list(feature["labels"]) + [self.eos_token_id] for feature in features]

        # Fixed-size log-mel features per clip, stacked along a new leading batch dimension.
        input_ids = torch.concat([to_pad_to_mel(waveform)[None, :] for waveform in waveforms])

        # Pad decoder inputs with a real token id. The previous -100 padding is not a valid
        # embedding index and crashed on any batch with unequal lengths.
        decoder_pad_id = self.eos_token_id if self.pad_token_id is None else self.pad_token_id
        max_decoder_len = max(len(ids) for ids in decoder_input_ids)
        decoder_input_ids = [ids + [decoder_pad_id] * (max_decoder_len - len(ids)) for ids in decoder_input_ids]

        # -100 is the default ignore_index of torch's CrossEntropyLoss, so padded label
        # positions contribute no loss.
        max_label_len = max(len(lab) for lab in labels)
        labels = [lab + [-100] * (max_label_len - len(lab)) for lab in labels]

        return {
            "labels": torch.tensor(labels, dtype=torch.long),
            "decoder_input_ids": torch.tensor(decoder_input_ids, dtype=torch.long),
            "input_ids": input_ids,
        }

### Set Up Custom Classes

In [7]:
# Define your ASR model
class ASRModel(pl.LightningModule):
    """LightningModule fine-tuning a Hugging Face speech seq2seq model with teacher forcing.

    Each step runs the wrapped model on log-mel ``input_ids`` plus ``decoder_input_ids``,
    computes token-level cross-entropy against ``labels`` and logs ``{stage}_loss`` and a
    greedy-decoding ``{stage}_wer``.

    Args:
        model: seq2seq model called as ``model(input_features=..., decoder_input_ids=...)``
            and returning an object with ``.logits``.
        processor: object providing ``batch_decode`` and ``tokenizer.pad_token_id``, used
            to decode text for WER. When omitted, the notebook-global ``processor`` is used.
        lr: Adam learning rate (default 1e-3).
    """

    def __init__(self, model, processor=None, lr=1e-3):
        super().__init__()
        self.model = model
        self.processor = processor
        self.lr = lr
        # Default ignore_index=-100 skips padded label positions.
        self.loss_function = torch.nn.CrossEntropyLoss()

    def forward(self, input_ids, decoder_input_ids=None):
        return self.model(input_features=input_ids, decoder_input_ids=decoder_input_ids)

    def compute_wer(self, logits, labels):
        """Word error rate of greedy (argmax) predictions against the reference labels.

        Args:
            logits: decoder logits, assumed ``(batch, seq, vocab)``. A flattened
                ``(batch*seq, vocab)`` tensor is treated as a single sequence.
            labels: target ids aligned position-by-position with ``logits``; -100 marks padding.

        Returns:
            ``jiwer.wer(references, hypotheses)`` over the decoded batch.
        """
        proc = self.processor if self.processor is not None else processor
        pad_id = proc.tokenizer.pad_token_id

        pred_ids = torch.argmax(logits, dim=-1)
        if pred_ids.dim() == 1:
            # batch_decode on a 1-D tensor would decode every token as its own "sentence".
            pred_ids = pred_ids.unsqueeze(0)
            labels = labels.reshape(1, -1)

        # -100 is not a decodable id, and predictions at padded positions are meaningless.
        # Map both to the pad token, which skip_special_tokens then drops.
        padding = labels == -100
        pred_ids = pred_ids.masked_fill(padding, pad_id)
        label_ids = labels.masked_fill(padding, pad_id)

        pred_str = proc.batch_decode(pred_ids.cpu(), skip_special_tokens=True)
        label_str = proc.batch_decode(label_ids.cpu(), skip_special_tokens=True)
        return wer(label_str, pred_str)

    def _shared_step(self, batch, stage):
        """Run one teacher-forced step, log ``{stage}_loss``/``{stage}_wer``, return the loss."""
        outputs = self(batch['input_ids'], decoder_input_ids=batch['decoder_input_ids'])
        logits = outputs.logits
        labels = batch['labels']

        # CrossEntropyLoss expects (N, vocab) vs (N,), so fold the batch and time dims together.
        loss = self.loss_function(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        self.log(f'{stage}_loss', loss)

        # WER needs per-sequence structure, so pass the un-flattened tensors.
        wer_value = self.compute_wer(logits.detach(), labels)
        self.log(f'{stage}_wer', wer_value, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, 'test')

    def configure_optimizers(self):
        # NOTE(review): 1e-3 is high for fine-tuning a pretrained Whisper model; Hugging Face's
        # Whisper fine-tuning guide uses 1e-5. Pass `lr=` to override.
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)
    
class ASRIterableDataset(IterableDataset):
    """Stream (audio file, transcript) pairs as model-ready sample dicts.

    Each yielded sample has ``input_ids`` (16 kHz waveform, optionally transformed),
    ``labels`` (tokenized transcript) and ``decoder_input_ids`` (labels shifted right,
    starting with the tokenizer's pad token).

    Args:
        data: either a dict with ``'audio'`` (file paths) and ``'sentence'`` (transcripts)
            lists, as produced by ``split_data_indices``, or a ``(file_paths, sentences)`` pair.
        tokenizer: callable returning an object with ``.input_ids``; must expose ``pad_token_id``.
        augmentations: optional callable ``(samples=..., sample_rate=...)`` applied to the raw waveform.
        shuffle: shuffle sample order on every iteration.
        transform: optional callable applied to the loaded waveform; its result replaces ``input_ids``.

    Raises:
        ValueError: if the number of file paths and transcripts differ.
    """

    def __init__(self, data, tokenizer, augmentations=None, shuffle=False, transform=None):
        if isinstance(data, dict):
            # Tuple-unpacking a dict yields its *keys* ('audio', 'sentence'), not the lists.
            file_paths, sentences = data['audio'], data['sentence']
        else:
            file_paths, sentences = data
        if len(file_paths) != len(sentences):
            raise ValueError(
                f"got {len(file_paths)} audio files but {len(sentences)} transcripts"
            )
        self.file_paths = list(file_paths)
        self.sentences = list(sentences)
        self.tokenizer = tokenizer
        self.shuffle = shuffle
        self.augmentations = augmentations
        self.transform = transform

    def __iter__(self):
        pairs = list(zip(self.file_paths, self.sentences))

        # With num_workers > 0 each DataLoader worker runs its own copy of this iterator.
        # Give each worker a disjoint strided shard so samples are not yielded once per worker.
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            pairs = pairs[worker_info.id::worker_info.num_workers]

        if self.shuffle:
            # Shuffle after sharding. Workers' RNGs differ, which is harmless for disjoint shards.
            random.shuffle(pairs)

        for file_path, transcript in pairs:
            sample = self.load_audio(file_path)
            if self.transform:
                sample['input_ids'] = self.transform(sample['input_ids'])

            label_ids = self.tokenizer(transcript).input_ids
            sample['labels'] = label_ids
            # Teacher forcing: the decoder sees the labels shifted right by one position.
            sample['decoder_input_ids'] = [self.tokenizer.pad_token_id] + label_ids[:-1]

            yield sample

    def load_audio(self, file_path):
        """Load ``file_path`` as a mono waveform resampled to 16 kHz.

        Augmentations, if any, run at the file's original sample rate before resampling.

        Returns:
            ``{'input_ids': waveform}`` where ``waveform`` is a 1-D numpy array.
        """
        waveform, sample_rate = torchaudio.load(file_path)
        # torchaudio returns (channels, frames). Average channels to mono: flattening would
        # concatenate stereo channels end to end and double the clip length.
        waveform = waveform.mean(dim=0).numpy()

        if self.augmentations:
            waveform = self.augmentations(samples=waveform, sample_rate=sample_rate)

        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform = resampler(torch.tensor(waveform)).numpy().flatten()

        return {'input_ids': waveform}

class ASRDataModule(pl.LightningDataModule):
    """Lightning data module building streaming train/val/test datasets and loaders.

    Args:
        tokenizer: text tokenizer handed to each ``ASRIterableDataset``.
        train_data, val_data, test_data: split data in a form ``ASRIterableDataset`` accepts.
        augmentations: optional waveform augmentation, applied to the training split only.
        collate_fn: batch collator passed to every ``DataLoader``.
        num_workers: number of ``DataLoader`` worker processes.
        transform: optional per-sample transform applied to all splits.
        batch_size: samples per batch for every ``DataLoader`` (default 1).
    """

    def __init__(self, tokenizer, train_data, val_data, test_data, augmentations=None, collate_fn=None, num_workers=0, transform=None, batch_size=1):
        super().__init__()
        self.tokenizer = tokenizer
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data
        self.augmentations = augmentations
        self.collate_fn = collate_fn
        self.num_workers = num_workers
        self.transform = transform
        self.batch_size = batch_size

    def setup(self, stage=None):
        # Only training audio is augmented. Validation/test must see clean audio so that
        # val_loss/val_wer (which drive early stopping and checkpointing) are not random.
        self.train_dataset = ASRIterableDataset(self.train_data, self.tokenizer, self.augmentations, shuffle=True, transform=self.transform)
        self.val_dataset = ASRIterableDataset(self.val_data, self.tokenizer, None, transform=self.transform)
        self.test_dataset = ASRIterableDataset(self.test_data, self.tokenizer, None, transform=self.transform)

    def _make_loader(self, dataset):
        # Shared DataLoader construction for all splits.
        return DataLoader(dataset, batch_size=self.batch_size, collate_fn=self.collate_fn, num_workers=self.num_workers)

    def train_dataloader(self):
        return self._make_loader(self.train_dataset)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset)

## Load Data and Model

In [8]:
MAX_FILE_COUNT = None  # set to an int to load only the first N entries (quick debug runs)

# Read (audio path, transcript) pairs from the metadata file.
data = {'audio': [], 'sentence': []}
with jsonlines.open(metadata_path) as reader:
    for obj in reader:
        if MAX_FILE_COUNT and len(data['audio']) >= MAX_FILE_COUNT:
            break
        data['audio'].append(os.path.join(audio_dir, obj['audio']))
        data['sentence'].append(obj['transcript'])

# The split helper defined above is `split_data_indices` (`split_data` does not exist).
train_data, val_data, test_data = split_data_indices(data)

torch_dtype = torch.float32
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
processor = AutoProcessor.from_pretrained(model_path)

model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.eos_token_id = processor.tokenizer.eos_token_id
# Whisper decoding starts from <|startoftranscript|>, not EOS. Setting EOS here would make
# model.generate() emit end-of-text immediately.
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")

# Wrap the Hugging Face model in the Lightning training module.
model = ASRModel(model)
model.to(device)

cuda


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


### Define Augmentations

In [None]:
# Training-time waveform augmentation: each transform is applied with probability p=0.5,
# so a clip may receive any subset of these perturbations (or none).
augmentations = Compose(
    transforms=[
        AddGaussianNoise(p=0.5, min_amplitude=0.001, max_amplitude=0.015),
        TimeStretch(p=0.5, min_rate=0.8, max_rate=1.25),
        PitchShift(p=0.5, min_semitones=-4, max_semitones=4),
    ]
)

## Configure and Run Training

In [11]:
collator = WhisperDataCollatorWithPadding(
    eos_token_id=processor.tokenizer.eos_token_id
)

data_module = ASRDataModule(
    tokenizer=processor.tokenizer,
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
    augmentations=augmentations,
    collate_fn=collator,
    num_workers=2,
    transform=None,
    # batch_size=1, # Removed param as setting to 2 causes errors, probably due to IterableDataset? Perhaps need to manually handle using arrays in Dataset class and update collate function.
)

early_stopping_callback = EarlyStopping(
    monitor='val_loss',  # metric to monitor
    patience=3,          # number of validation checks (not epochs) with no improvement before stopping
    verbose=True,        # logging
    mode='min'           # minimize or maximize the monitored metric
)

# Keep the best checkpoints by validation loss.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    dirpath='my_model_checkpoints',
    filename='asr_model-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
)

# Lightning rejects a fractional `val_check_interval` when the training loader wraps an
# IterableDataset (its length is unknown): it must be 1.0 or an int number of batches.
# With the DataLoader's default batch size of 1, half the training set gives roughly two
# validation runs per epoch.
val_check_batches = max(1, len(train_data['audio']) // 2)

trainer = pl.Trainer(
    max_epochs=2,
    callbacks=[checkpoint_callback, early_stopping_callback],
    val_check_interval=val_check_batches,
    check_val_every_n_epoch=1,
)

# Train the model
trainer.fit(model, datamodule=data_module)

# Test the best checkpoint tracked by ModelCheckpoint. Fall back to the current weights
# if no checkpoint was written (e.g. training was interrupted before any validation).
best_ckpt = checkpoint_callback.best_model_path or None
trainer.test(model, datamodule=data_module, ckpt_path=best_ckpt)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                            | Params
------------------------------------------------------------------
0 | model         | WhisperForConditionalGeneration | 394 M 
1 | loss_function | CrossEntropyLoss                | 0     
------------------------------------------------------------------
392 M     Trainable params
1.5 M     Non-trainable params
394 M     Total params
1,577.501 Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

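#### Old notes
Longest clip found (previously used as the padding length): 219,847 samples.
<br>
Use a max length of 220,000 samples, i.e. 13.75 s of audio at 16,000 samples/s. (Superseded: `to_pad_to_mel` now pads/trims every clip to Whisper's fixed 30 s window.)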

In [None]:
# max_length = calculate_max_length(dataset, audio_dir)
# print(f"Maximum length for padding: {max_length}")