In [1]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import json
import torch
import torchaudio
import random
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from tqdm import tqdm
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift
import pytorch_lightning as pl
from omegaconf import OmegaConf
import torchaudio.transforms as transforms

print("CUDA available:", torch.cuda.is_available())
print("Number of GPUs:", torch.cuda.device_count())
print("GPU Name:", torch.cuda.get_device_name(0))

CUDA available: True
Number of GPUs: 1
GPU Name: Tesla T4


### Defining Directories

In [2]:
cur_dir = os.getcwd()
src_dir = os.path.dirname(cur_dir)
til_dir = os.path.dirname(os.path.dirname(src_dir))
home_dir = os.path.dirname(til_dir)
test_dir = os.path.join(home_dir, 'novice')
audio_dir = os.path.join(test_dir, 'audio')
data_dir = os.path.join(cur_dir, 'data')
model_path = os.path.join(src_dir, "models", "whisper")
config_path = os.path.join(cur_dir, "config.yaml")

# paths for converting datasets to manifest files
train_manifest_path = os.path.join(data_dir, 'train.json')
val_manifest_path = os.path.join(data_dir, 'val.json')
test_manifest_path = os.path.join(data_dir, 'test.json')

test_dir

'/home/jupyter/novice'

In [3]:
# import json

# def validate_jsonl(file_path):
#     with open(file_path, 'r') as f:
#         for i, line in enumerate(f, 1):
#             try:
#                 json.loads(line)
#             except json.JSONDecodeError as e:
#                 print(f"Error decoding JSON on line {i}: {e}")
#                 break

# validate_jsonl('./data/train.json')
# validate_jsonl('./data/val.json')
# validate_jsonl('./data/test.json')

### Select subsets (if any) and find max length

In [4]:
def load_manifest_in_chunks(manifest_path, chunk_size=1000):
    with open(manifest_path, 'r') as f:
        chunk = []
        for line in f:
            chunk.append(json.loads(line))
            if len(chunk) >= chunk_size:
                yield chunk
                chunk = []
        if chunk:
            yield chunk

### Define Augmentations

In [5]:
augmentations = Compose([
    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
    TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
    PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
])

## Define Dataset Class and Create Datasets

In [6]:
class AudioDataset(Dataset):
    def __init__(self, manifest_chunk, augmentations=None, max_length=3000):
        self.manifest = manifest_chunk
        self.augmentations = augmentations
        self.max_length = max_length

    def __len__(self):
        return len(self.manifest)

    def __getitem__(self, idx):
        example = self.manifest[idx]

        # Load precomputed Mel spectrogram
        if 'mel_spectrogram' in example:
            mel_spectrogram = torch.tensor(example['mel_spectrogram'])
        else:
            raise Exception("mel_spectrogram not found in the example.")
        
        transcript = example['text']

        # Apply augmentations if any
        if self.augmentations:
            mel_spectrogram = self.augmentations(samples=mel_spectrogram.numpy(), sample_rate=16000)
            mel_spectrogram = torch.tensor(mel_spectrogram)

        # Ensure Mel spectrogram has the correct length
        if mel_spectrogram.shape[1] < self.max_length:
            padding = torch.zeros((mel_spectrogram.shape[0], self.max_length - mel_spectrogram.shape[1]))
            mel_spectrogram = torch.cat((mel_spectrogram, padding), dim=1)
        elif mel_spectrogram.shape[1] > self.max_length:
            mel_spectrogram = mel_spectrogram[:, :self.max_length]

        return {
            'audio_filepath': example['audio_filepath'],
            'duration': example['duration'],
            'text': transcript,
            'mel_spectrogram': mel_spectrogram,
            'labels': torch.tensor(example['labels'])
        }

### Define Collate Function and Create Dataloaders

In [7]:
def custom_collate_fn(batch):
    # Stack Mel spectrograms and labels into batches
    mel_spectrograms = torch.stack([item['mel_spectrogram'] for item in batch])
    labels = [item['labels'] for item in batch]
    max_label_length = max(len(label) for label in labels)
    padded_labels = torch.zeros((len(labels), max_label_length), dtype=torch.long)

    for i, label in enumerate(labels):
        padded_labels[i, :len(label)] = label

    audio_filepaths = [item['audio_filepath'] for item in batch]
    durations = [item['duration'] for item in batch]
    texts = [item['text'] for item in batch]

    return {
        'input_values': mel_spectrograms.to(dtype=torch.float16),  # Convert to float16
        'labels': padded_labels,
        'audio_filepaths': audio_filepaths,
        'durations': durations,
        'texts': texts
    }

## Define Model Class and Load Model

In [8]:
class WhisperASRModel(pl.LightningModule):
    def __init__(self, model, processor, lr):
        super(WhisperASRModel, self).__init__()
        self.model = model
        self.processor = processor
        self.lr = lr
        self.test_predictions = []
        self.test_references = []

    def forward(self, input_values):
        outputs = self.model(input_values)
        return outputs

    def training_step(self, batch, batch_idx):
        input_values = batch['input_values'].to(dtype=torch.float16)
        labels = batch['labels']
        outputs = self.model(input_values, labels=labels)
        loss = outputs.loss
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        input_values = batch['input_values'].to(dtype=torch.float16)
        labels = batch['labels']
        outputs = self.model(input_values, labels=labels)
        val_loss = outputs.loss
        self.log('val_loss', val_loss)
        return val_loss

    def test_step(self, batch, batch_idx):
        input_values = batch['input_values'].to(dtype=torch.float16)
        labels = batch['labels']
        outputs = self.model.generate(input_values)  # Generate predictions
        predicted_texts = self.processor.batch_decode(outputs, skip_special_tokens=True)

        # Decode the actual labels
        actual_texts = self.processor.batch_decode(labels, skip_special_tokens=True)
        
        # Collect predictions and actual texts
        self.test_predictions.extend(predicted_texts)
        self.test_references.extend(actual_texts)

        # Log the test loss
        outputs = self.model(input_values, labels=labels)
        test_loss = outputs.loss
        self.log('test_loss', test_loss)
        return test_loss

    def test_epoch_end(self, outputs):
        # Print predictions and actual texts
        for pred, actual in zip(self.test_predictions, self.test_references):
            print(f"PRED: {pred}")
            print(f"ACTUAL: {actual}")
            print('-' * 40)  # Separator for readability

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer


### Setup and Train Model

In [16]:
# def get_dataloader(manifest_chunk, batch_size, shuffle, num_workers):
#     dataset = AudioDataset(manifest_chunk)
#     return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, collate_fn=custom_collate_fn)

# def train_and_evaluate(train_manifest_path, val_manifest_path, test_manifest_path, processor, model, config, chunk_size=50):
#     whisper_asr_model = WhisperASRModel(model, processor, config.optim.lr)
#     trainer = pl.Trainer(**config.trainer)

#     for epoch in range(config.trainer.max_epochs):
#         print(f"Epoch {epoch + 1}/{config.trainer.max_epochs}")

#         # Load training data in chunks
#         for train_chunk in load_manifest_in_chunks(train_manifest_path, chunk_size):
#             train_loader = get_dataloader(train_chunk, config.model.train_ds.batch_size, config.model.train_ds.shuffle, config.model.train_ds.num_workers)
#             trainer.fit(whisper_asr_model, train_dataloaders=train_loader)

#         # Validate
#         for val_chunk in load_manifest_in_chunks(val_manifest_path, chunk_size):
#             val_loader = get_dataloader(val_chunk, config.model.validation_ds.batch_size, config.model.validation_ds.shuffle, config.model.validation_ds.num_workers)
#             trainer.validate(whisper_asr_model, val_dataloaders=val_loader)

#     # Test
#     for test_chunk in load_manifest_in_chunks(test_manifest_path, chunk_size):
#         test_loader = get_dataloader(test_chunk, config.model.test_ds.batch_size, config.model.test_ds.shuffle, config.model.test_ds.num_workers)
#         trainer.test(whisper_asr_model, test_dataloaders=test_loader)

In [24]:
# from torch.utils.data import IterableDataset

class ChunkedDataset(IterableDataset):
    def __init__(self, manifest_path, chunk_size):
        self.manifest_path = manifest_path
        self.chunk_size = chunk_size

    def __iter__(self):
        for chunk in load_manifest_in_chunks(self.manifest_path, self.chunk_size):
            dataset = AudioDataset(chunk, augmentations)
            for data in dataset:
                yield data
                
def get_dataloader(manifest_path, processor, batch_size, num_workers, chunk_size):
    dataset = ChunkedDataset(manifest_path, chunk_size)
    return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=custom_collate_fn)

def train_and_evaluate(train_manifest_path, val_manifest_path, test_manifest_path, processor, model, config, chunk_size=50):
    whisper_asr_model = WhisperASRModel(model, processor, config.optim.lr)
    trainer = pl.Trainer(**config.trainer)

    # Prepare the train DataLoader
    train_loader = get_dataloader(train_manifest_path, processor, config.model.train_ds.batch_size, config.model.train_ds.num_workers, chunk_size)

    # Prepare the validation DataLoader
    val_loader = get_dataloader(val_manifest_path, processor, config.model.validation_ds.batch_size, config.model.validation_ds.num_workers, chunk_size)

    # Train the model
    trainer.fit(whisper_asr_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    # Prepare the test DataLoader
    test_loader = get_dataloader(test_manifest_path, processor, config.model.test_ds.batch_size, config.model.test_ds.num_workers, chunk_size)

    # Test the model
    trainer.test(whisper_asr_model, test_dataloaders=test_loader)

    # Print the test results
    whisper_asr_model.test_epoch_end(None)

In [21]:
# Load model

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
model.to(device)
processor = AutoProcessor.from_pretrained(model_path)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [27]:
### Update config
config = OmegaConf.load(config_path)

config.model.train_ds.manifest_filepath = train_manifest_path
config.model.validation_ds.manifest_filepath = val_manifest_path
config.model.test_ds.manifest_filepath = test_manifest_path

train_and_evaluate(train_manifest_path, val_manifest_path, test_manifest_path, processor, model, config)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                            | Params
----------------------------------------------------------
0 | model | WhisperForConditionalGeneration | 394 M 
----------------------------------------------------------
392 M     Trainable params
1.5 M     Non-trainable params
394 M     Total params
1,577.501 Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

TypeError: Trainer.test() got an unexpected keyword argument 'test_dataloaders'