In [1]:
# !pip install -q transformers librosa jiwer torchaudio jsonlines datasets accelerate audiomentations # Audio Augmentation
# !pip install -q Cython
# !pip install openai-whisper

In [2]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # set before torch/CUDA initialise so cuBLAS sees it
import json
import random
import re
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Union

import jiwer
import jsonlines
import librosa
import numpy as np
import pytorch_lightning as pl
import torch
import torchaudio
import whisper
from datasets import Dataset
from torch.utils.data import DataLoader, Dataset
# from tqdm import tqdm
from torch.utils.data import IterableDataset, DataLoader
from torchaudio import transforms
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, Seq2SeqTrainingArguments, Seq2SeqTrainer

# Report the compute environment. get_device_name(0) raises when no CUDA device exists,
# so only query it when CUDA is actually available (the notebook falls back to CPU later).
cuda_ok = torch.cuda.is_available()
print("CUDA available:", cuda_ok)
print("Number of GPUs:", torch.cuda.device_count())
if cuda_ok:
    print("GPU Name:", torch.cuda.get_device_name(0))
else:
    print("GPU Name: N/A (running on CPU)")

CUDA available: True
Number of GPUs: 1
GPU Name: Tesla T4


### Defining Directories

In [3]:
cur_dir = os.getcwd()
src_dir = os.path.dirname(cur_dir)
til_dir = os.path.dirname(os.path.dirname(src_dir))
home_dir = os.path.dirname(til_dir)

# Raw data folder (audio files + transcript manifest). Kept under its own name so it is
# not clobbered by the split-output `test_dir` defined further down.
novice_dir = os.path.join(home_dir, 'novice')
audio_dir = os.path.join(novice_dir, 'audio')
metadata_path = os.path.join(novice_dir, "asr.jsonl")

data_dir = os.path.join(cur_dir, 'data')
model_path = os.path.join(src_dir, "models", "whisper")

# paths for converting datasets to manifest files
train_dir = os.path.join(data_dir, "train")
test_dir = os.path.join(data_dir, "test")
val_dir = os.path.join(data_dir, "val")

metadata_path

'/home/jupyter/novice/asr.jsonl'

### Split data into train / val / test

In [4]:
def split_data_indices(data, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed: Optional[int] = None):
    """Shuffle example indices and partition them into train / val / test index lists.

    Args:
        data: dict whose ``'audio'`` entry is a sequence; only its length is used.
        train_ratio: fraction of examples assigned to the training split.
        val_ratio: fraction of examples assigned to the validation split.
        test_ratio: nominal test fraction. It is only validated, not used in the
            arithmetic: the test split always receives every index left after train
            and val, so no example is dropped by rounding.
        seed: optional seed for a private ``random.Random``. ``None`` (the default)
            shuffles with the global ``random`` state, as before.

    Returns:
        ``(train_indices, val_indices, test_indices)``: disjoint lists that together
        cover ``range(len(data['audio']))``.

    Raises:
        ValueError: if any ratio is negative or ``train_ratio + val_ratio`` exceeds 1.
    """
    if min(train_ratio, val_ratio, test_ratio) < 0:
        raise ValueError("split ratios must be non-negative")
    # Small tolerance so e.g. 0.9 + 0.1 (not exact in binary) is still accepted.
    eps = 1e-9
    if train_ratio + val_ratio > 1 + eps:
        raise ValueError("train_ratio + val_ratio must not exceed 1")

    total_examples = len(data['audio'])
    indices = list(range(total_examples))
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(indices)

    # Decimal ratios are not exact in binary (e.g. 0.7 + 0.2 == 0.8999999999999999),
    # so int() alone can truncate a boundary down by one; the epsilon compensates.
    train_split = min(int(total_examples * train_ratio + eps), total_examples)
    val_split = min(int(total_examples * (train_ratio + val_ratio) + eps), total_examples)

    train_indices = indices[:train_split]
    val_indices = indices[train_split:val_split]
    test_indices = indices[val_split:]

    return train_indices, val_indices, test_indices

def split_data(data):
    """Split parallel audio-path / transcript lists into train, val and test subsets.

    Uses the random index partition from ``split_data_indices`` and returns a
    3-tuple ``(train, val, test)``, where each element is an
    ``(audio_paths, sentences)`` pair of lists.
    """
    def _gather(positions):
        # Select the same positions from both parallel lists so each path stays paired with its transcript.
        audio_subset = [data['audio'][pos] for pos in positions]
        sentence_subset = [data['sentence'][pos] for pos in positions]
        return audio_subset, sentence_subset

    index_splits = split_data_indices(data)
    return tuple(_gather(positions) for positions in index_splits)

In [16]:
def to_pad_to_mel(array):
    """Turn one raw waveform into Whisper log-mel input features.

    Steps:
        1. cast the samples to a float32 numpy array,
        2. pad or trim them to Whisper's fixed 30 s window (``whisper.pad_or_trim``),
        3. compute log-mel filter-bank coefficients (``whisper.log_mel_spectrogram``).

    Args:
        array: a single audio waveform (array-like of samples); the collator calls
            this once per example.
    Returns:
        torch.Tensor of log-mel filter-bank coefficients.
    """
    waveform = np.asarray(array, dtype=np.float32)
    fixed_length = whisper.pad_or_trim(waveform)
    return whisper.log_mel_spectrogram(fixed_length)

@dataclass
class WhisperDataCollatorWithPadding:
    """
    Batch collator for Whisper fine-tuning.

    Each example's raw waveform is converted to input features. An EOS token is
    appended to both the decoder inputs and the labels, and both are right-padded
    to the longest sequence in the batch.

    Args:
        eos_token_id (`int`)
            The end-of-sentence token for the Whisper tokenizer. Appended so sequences
            terminate before the generation max length.
        decoder_pad_token_id (`int`, optional)
            Token used to pad ``decoder_input_ids``; defaults to ``eos_token_id``.
            Decoder inputs go through the embedding layer, so they must be padded with a
            real token id. Padding them with -100, as the labels are, is an out-of-range
            embedding index and crashes any batch whose transcripts differ in length.
        feature_fn (callable, optional)
            Maps one raw waveform to a feature tensor; defaults to ``to_pad_to_mel``.
            Outputs are stacked along a new batch dimension, so all must share a shape.
    """

    eos_token_id: int
    decoder_pad_token_id: Optional[int] = None
    feature_fn: Optional[Callable[..., torch.Tensor]] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Collate examples into a padded batch.

        Each feature dict must provide ``input_ids`` (raw waveform), ``decoder_input_ids``
        and ``labels`` (token-id sequences).

        Returns:
            dict with ``labels`` and ``decoder_input_ids`` as int64 tensors of shape
            (batch, max_len), and ``input_ids`` as the stacked audio features.
        """
        feature_fn = self.feature_fn if self.feature_fn is not None else to_pad_to_mel
        decoder_pad_id = self.eos_token_id if self.decoder_pad_token_id is None else self.decoder_pad_token_id

        # Audio: compute features per example and stack along a new leading batch dim.
        input_ids = torch.concat([feature_fn(feature["input_ids"])[None, :] for feature in features])

        # Append EOS to every sequence; list() copies so the caller's lists are not mutated.
        decoder_input_ids = [list(feature["decoder_input_ids"]) + [self.eos_token_id] for feature in features]
        labels = [list(feature["labels"]) + [self.eos_token_id] for feature in features]

        batch = {
            # -100 is CrossEntropyLoss's default ignore_index, so padded label positions add no loss.
            "labels": self._pad(labels, -100),
            "decoder_input_ids": self._pad(decoder_input_ids, decoder_pad_id),
        }
        batch["input_ids"] = input_ids
        return batch

    @staticmethod
    def _pad(sequences: List[List[int]], pad_value: int) -> torch.Tensor:
        """Right-pad token-id lists to a common length and return an int64 tensor."""
        max_len = max(len(seq) for seq in sequences)
        padded = [seq + [pad_value] * (max_len - len(seq)) for seq in sequences]
        return torch.tensor(padded, dtype=torch.long)

### Model, dataset and data module (audio preprocessing happens in the dataset)

In [43]:
# Define your ASR model
class ASRModel(pl.LightningModule):
    """LightningModule wrapping a Hugging Face speech seq2seq model for fine-tuning.

    Batches must contain ``input_ids`` (audio features, forwarded as ``input_features``),
    ``decoder_input_ids`` and ``labels``. Label positions equal to -100 are ignored by the loss.

    Args:
        model: the wrapped seq2seq model (called with ``input_features`` / ``decoder_input_ids``).
        lr: Adam learning rate.
    """

    def __init__(self, model, lr: float = 1e-3):
        super().__init__()
        self.model = model
        self.lr = lr
        # CrossEntropyLoss ignores targets equal to -100 by default (padded label positions).
        self.loss_function = torch.nn.CrossEntropyLoss()

    def forward(self, input_features, decoder_input_ids=None):
        return self.model(input_features=input_features, decoder_input_ids=decoder_input_ids)

    def _compute_loss(self, batch):
        """Run the model on ``batch`` and return token-level cross-entropy against ``batch['labels']``."""
        outputs = self(batch['input_ids'], decoder_input_ids=batch['decoder_input_ids'])
        logits = outputs.logits
        # Flatten to (batch * seq_len, vocab) vs (batch * seq_len,); reshape also handles non-contiguous logits.
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_labels = batch['labels'].reshape(-1)
        return self.loss_function(flat_logits, flat_labels)

    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        # Explicit batch_size avoids Lightning guessing it from the batch dict.
        self.log('train_loss', loss, batch_size=batch['labels'].size(0))
        return loss

    def validation_step(self, batch, batch_idx):
        val_loss = self._compute_loss(batch)
        # 'val_loss' is the metric monitored by the ModelCheckpoint callback.
        self.log('val_loss', val_loss, batch_size=batch['labels'].size(0))

    def test_step(self, batch, batch_idx):
        # Required: Trainer.test() refuses to run on a module without test_step.
        test_loss = self._compute_loss(batch)
        self.log('test_loss', test_loss, batch_size=batch['labels'].size(0))
        return test_loss

    def configure_optimizers(self):
        # NOTE(review): 1e-3 is high for fine-tuning a pretrained Whisper model; Hugging Face's
        # Whisper fine-tuning examples use ~1e-5. Kept as the default for backward compatibility.
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)
    
class ASRIterableDataset(IterableDataset):
    """Streams ASR training examples from parallel lists of audio paths and transcripts.

    Each yielded sample is a dict with:
        ``input_ids``: waveform from ``load_audio`` (optionally passed through ``transform``),
        ``labels``: the tokenizer's ``input_ids`` for the transcript,
        ``decoder_input_ids``: the labels shifted right by one, starting with
        ``tokenizer.pad_token_id``.

    Args:
        data: ``(file_paths, sentences)`` pair of equal-length sequences.
        tokenizer: callable returning an object with ``input_ids``; must expose ``pad_token_id``.
        batch_size: stored for reference; batching itself is done by the DataLoader.
        transform: optional callable applied to the loaded waveform.
        shuffle: reshuffle the example order on every iteration (i.e. every epoch).
    """

    TARGET_SAMPLE_RATE = 16000

    def __init__(self, data, tokenizer, batch_size, transform=None, shuffle=False):
        self.file_paths, self.sentences = data
        self.transform = transform
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __len__(self):
        # Gives the DataLoader / progress bar an epoch length (IterableDataset has none by default).
        return len(self.file_paths)

    def __iter__(self):
        pairs = list(zip(self.file_paths, self.sentences))
        if self.shuffle:
            # Shuffle a local copy; the stored lists keep their original order.
            random.shuffle(pairs)

        for file_path, transcript in pairs:
            sample = self.load_audio(file_path)
            if self.transform is not None:
                sample['input_ids'] = self.transform(sample['input_ids'])

            label_ids = self.tokenizer(transcript).input_ids
            sample['labels'] = label_ids

            # Teacher forcing: decoder inputs are the labels shifted right by one.
            # NOTE(review): starts with pad_token_id while the notebook sets
            # decoder_start_token_id = eos_token_id; confirm these agree for this tokenizer.
            sample['decoder_input_ids'] = [self.tokenizer.pad_token_id] + label_ids[:-1]

            yield sample

    def load_audio(self, file_path):
        """Load ``file_path`` as a mono waveform resampled to 16 kHz.

        Returns:
            ``{'input_ids': waveform}`` where ``waveform`` is a 1-D numpy array.
        """
        waveform, sample_rate = torchaudio.load(file_path)  # (channels, frames)
        # Downmix to mono. Flattening a multi-channel tensor would concatenate the
        # channels end to end instead of mixing them.
        waveform = waveform.mean(dim=0)

        if sample_rate != self.TARGET_SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.TARGET_SAMPLE_RATE)
            waveform = resampler(waveform)

        return {'input_ids': waveform.numpy().flatten()}

class ASRDataModule(pl.LightningDataModule):
    """Wraps the train / val / test splits in ``ASRIterableDataset`` DataLoaders.

    Args:
        tokenizer: tokenizer handed to every dataset.
        train_data, val_data, test_data: ``(file_paths, sentences)`` pairs.
        batch_size: DataLoader batch size.
        transform: optional waveform transform applied by every dataset.
        collate_fn: DataLoader collate function (e.g. ``WhisperDataCollatorWithPadding``).
        shuffle_train: reshuffle the training split each epoch; val/test keep a fixed order.
    """

    def __init__(self, tokenizer, train_data, val_data, test_data, batch_size=1, transform=None, collate_fn=None,
                 shuffle_train=True):
        super().__init__()
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data
        self.batch_size = batch_size
        self.transform = transform
        self.tokenizer = tokenizer
        self.collate_fn = collate_fn
        self.shuffle_train = shuffle_train

    def _make_dataset(self, data, shuffle=False):
        """Build an ``ASRIterableDataset`` for one split with this module's shared settings."""
        return ASRIterableDataset(data, self.tokenizer, self.batch_size, self.transform, shuffle=shuffle)

    def _make_loader(self, dataset):
        """Wrap ``dataset`` in a DataLoader; ordering is handled by the iterable dataset itself."""
        return DataLoader(dataset, batch_size=self.batch_size, collate_fn=self.collate_fn)

    def setup(self, stage=None):
        self.train_dataset = self._make_dataset(self.train_data, shuffle=self.shuffle_train)
        self.val_dataset = self._make_dataset(self.val_data)
        self.test_dataset = self._make_dataset(self.test_data)

    def train_dataloader(self):
        return self._make_loader(self.train_dataset)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset)

In [44]:
MAX_FILE_COUNT = None  # Set if only want max files
SEED = 42  # seeds python / numpy / torch RNGs so the split and training are reproducible

pl.seed_everything(SEED)

# Read the transcript manifest. At this point `test_dir` refers to the split output
# directory (data/test), so the manifest path must come from metadata_path.
data = {'audio': [], 'sentence': []}
data_path = metadata_path
with jsonlines.open(data_path) as reader:
    for obj in reader:
        if MAX_FILE_COUNT and len(data['audio']) >= MAX_FILE_COUNT:
            break
        data['audio'].append(os.path.join(audio_dir, obj['audio']))
        data['sentence'].append(obj['transcript'])

train_data, val_data, test_data = split_data(data)

torch_dtype = torch.float32
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the pretrained Hugging Face model under its own name; `model` is the Lightning wrapper below.
whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
processor = AutoProcessor.from_pretrained(model_path)

whisper_model.config.pad_token_id = processor.tokenizer.pad_token_id
whisper_model.config.eos_token_id = processor.tokenizer.eos_token_id
whisper_model.config.decoder_start_token_id = processor.tokenizer.eos_token_id

# Instantiate your ASR model
model = ASRModel(whisper_model)
model.to(device)

# Keep the 3 checkpoints with the lowest validation loss.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    dirpath='my_model_checkpoints',
    filename='asr_model-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
collator = WhisperDataCollatorWithPadding(
    eos_token_id=processor.tokenizer.eos_token_id
)

data_module = ASRDataModule(
    tokenizer=processor.tokenizer,
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
    # NOTE(review): batch_size > 1 was observed to crash. The likely cause is the collator padding
    # decoder_input_ids with -100 (not a valid embedding index), not the IterableDataset.
    batch_size=1,
    transform=None,
    collate_fn=collator,
)

trainer = pl.Trainer(max_epochs=2, callbacks=[checkpoint_callback])

# Train the model; val_loss logged during validation drives checkpoint_callback.
trainer.fit(model, datamodule=data_module)

# Evaluate the best checkpoint (lowest val_loss) rather than the last-epoch weights.
trainer.test(model, datamodule=data_module, ckpt_path="best")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                            | Params
------------------------------------------------------------------
0 | model         | WhisperForConditionalGeneration | 394 M 
1 | loss_function | CrossEntropyLoss                | 0     
------------------------------------------------------------------
392 M     Trainable params
1.5 M     Non-trainable params
394 M     Total params
1,577.501 Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

In [None]:
# def preprocess_audio(audio_path, transcript):
#     try:
#         waveform, sample_rate = torchaudio.load(os.path.join(audio_dir, audio_path))
#         waveform = waveform.numpy().flatten()

#         # Resample if needed
#         if sample_rate != 16000:
#             resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
#             waveform = resampler(torch.tensor(waveform)).numpy().flatten()

#         return {
#             'audio': {'array': waveform.tolist(), 'sampling_rate': 16000},
#             'sentence': transcript,
#         }
#     except Exception as e:
#         print(f"Error processing audio: {e}")
#         return None

# def process_and_save_dataset(data, output_dir):
#     dataset_dict = {'audio': [], 'sentence': []}
#     for audio_path, transcript in tqdm(zip(data['audio'], data['sentence']), total=len(data['audio']), desc="Processing audio"):
#         processed_example = preprocess_audio(audio_path, transcript)
#         if processed_example is not None:
#             dataset_dict['audio'].append(processed_example['audio'])
#             dataset_dict['sentence'].append(processed_example['sentence'])
            
#     # Convert dictionary to Dataset
#     dataset = Dataset.from_dict(dataset_dict)

#     # Save dataset to disk
#     dataset.save_to_disk(output_dir)

## Load Model

In [None]:
# device = "cuda:0" if torch.cuda.is_available() else "cpu"
# torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# # model_id = "distil-whisper/distil-medium.en"

# # model = AutoModelForSpeechSeq2Seq.from_pretrained(
# #     model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
# # )
# # model.to(device)
# # processor = AutoProcessor.from_pretrained(model_id)

# # model.save_pretrained(model_path)
# # processor.save_pretrained(model_path)

# model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
# model.to(device)
# processor = AutoProcessor.from_pretrained(model_path)

In [None]:
# vocab = processor.tokenizer.get_vocab()
# print(vocab)

## Load and preprocess data - Ran once to create manifest files

In [None]:
# def split_data_indices(data, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1):
#     total_examples = len(data['audio'])
#     indices = list(range(total_examples))
#     random.shuffle(indices)

#     train_split = int(total_examples * train_ratio)
#     val_split = int(total_examples * (train_ratio + val_ratio))

#     train_indices = indices[:train_split]
#     val_indices = indices[train_split:val_split]
#     test_indices = indices[val_split:]

#     return train_indices, val_indices, test_indices

# def preprocess_and_save(data):
#     train_indices, val_indices, test_indices = split_data_indices(data)

#     train_set = {'audio': [data['audio'][i] for i in train_indices], 'sentence': [data['sentence'][i] for i in train_indices]}
#     val_set = {'audio': [data['audio'][i] for i in val_indices], 'sentence': [data['sentence'][i] for i in val_indices]}
#     test_set = {'audio': [data['audio'][i] for i in test_indices], 'sentence': [data['sentence'][i] for i in test_indices]}

#     process_and_save_dataset(train_set, train_dir)
#     process_and_save_dataset(val_set, val_dir)
#     process_and_save_dataset(test_set, test_dir)

In [None]:
# import random

# MAX_FILE_COUNT = None # Set if only want max files

# data = {'audio': [], 'sentence': []}
# data_path = os.path.join(test_dir, "asr.jsonl")
# with jsonlines.open(metadata_path) as reader:
#     for obj in reader:
#         if MAX_FILE_COUNT and len(data['audio']) >= MAX_FILE_COUNT:
#             break
#         data['audio'].append(obj['audio'])
#         data['sentence'].append(obj['transcript'])
        
# preprocess_and_save(data)

Reason for choosing a max length of 220000 samples:

In [None]:
# max_length = calculate_max_length(dataset, audio_dir)
# print(f"Maximum length for padding: {max_length}")

Maximum length for padding: 219847
<br>
Use a max length of 220000 samples, which is about 13.75 s of audio at 16000 samples/s (the longest clip is 219847 samples).