In [1]:
# This mounts your Google Drive to the Colab VM.
from google.colab import drive
drive.mount('/content/drive')

# TODO: Enter the foldername in your Drive where you have saved the unzipped
# assignment folder, e.g. 'cs231/assignment1'
FOLDERNAME = 'Introduction to Speech Processing/'
assert FOLDERNAME is not None, "[!] Enter the folername."

# Now that we've mounted your Drive, this ensures that
# the Python interpreter of the Colab VM can load
# python files from within it
import sys
sys.path.append('/content/drive/MyDrive/{}'.format(FOLDERNAME))

%cd /content/drive/MyDrive/$FOLDERNAME

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/Introduction to Speech Processing


In [2]:
from google.colab import drive

import os

import torch
import random
import torchaudio
import torch.nn as nn
import torch.optim as optim

from librosa import effects
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

!pip install jiwer
from jiwer import wer, cer



In [3]:
class TextTransform:
    """Maps characters to integers and vice versa"""
    def __init__(self):
        char_map_str = """
        ' 0
        <SPACE> 1
        A 2
        B 3
        C 4
        D 5
        E 6
        F 7
        G 8
        H 9
        I 10
        J 11
        K 12
        L 13
        M 14
        N 15
        O 16
        P 17
        Q 18
        R 19
        S 20
        T 21
        U 22
        V 23
        W 24
        X 25
        Y 26
        Z 27
        """
        self.char_map = {}
        self.index_map = {}
        for line in char_map_str.strip().split('\n'):
            ch, index = line.split()
            self.char_map[ch] = int(index)
            self.index_map[int(index)] = ch
        self.index_map[1] = ' '

    def text_to_int(self, text):
        """ Use a character map and convert text to an integer sequence """
        int_sequence = []
        for c in text:
            if c == ' ':
                ch = self.char_map['<SPACE>']
            else:
                ch = self.char_map[c]
            int_sequence.append(ch)
        return int_sequence

    def int_to_text(self, labels):
        """ Use a character map and convert integer labels to an text sequence """
        string = []
        for i in labels:
            string.append(self.index_map[i])
        return ''.join(string).replace('<SPACE>', ' ')

# Initilalize the text transformations class
text_transform = TextTransform()

def Decoder(output, labels, label_lengths, blank_label=28, collapse_repeated=True):
    arg_maxes = torch.argmax(output, dim=2) # (batch, time, n_class)
    decodes = []
    targets = []
    for i, args in enumerate(arg_maxes):
      decode = []
      targets.append(text_transform.int_to_text(labels[i][:label_lengths[i]].tolist()))
      for j, index in enumerate(args):
        if index != blank_label:
          if collapse_repeated and j != 0 and index == args[j -1]:
            continue
          decode.append(index.item())
      decodes.append(text_transform.int_to_text(decode))
    return decodes, targets

In [4]:
class AudioTranscriptDataset(Dataset):
    """ Custom dataset to zip audio files and their matching transcripts """
    def __init__(self, audio_folder, transcript_folder, audio_extension=".wav", transcript_extension=".txt"):
        self.audio_folder = audio_folder
        self.transcript_folder = transcript_folder
        self.audio_extension = audio_extension
        self.transcript_extension = transcript_extension
        self.audio_files = os.listdir(audio_folder)

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        audio_file = os.path.join(self.audio_folder, self.audio_files[idx])
        transcript_file = os.path.join(self.transcript_folder, self.audio_files[idx].replace(self.audio_extension, self.transcript_extension))
        waveform, sample_rate = torchaudio.load(audio_file)
        with open(transcript_file, "r") as f:
            transcript = f.read().strip()

        return waveform, transcript

In [5]:
# Define the CTC model
class CTCModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, dropout=0.3, ckpt_path=None):
        super(CTCModel, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, dropout=dropout, bidirectional=True, batch_first=True)
        self.linear = nn.Linear(hidden_dim * 2, output_dim)  # *2 for bidirectional
        self.softmax = nn.LogSoftmax(dim=-1)

        # Initilize model's weigths when a path is given
        if ckpt_path != None:
          self.load_state_dict(torch.load(ckpt_path))

    def forward(self, x):
        lstm_out, _ = self.lstm(x)
        linear_out = self.linear(lstm_out)
        return self.softmax(linear_out)


In [6]:
# Define the dataset and data loader
def load_an4_dataset(mode='train'):
    # Load the AN4 dataset and adjust the path accordingly
    audio_folder = f"./an4/{mode}/an4/wav/"
    transcript_folder = f"./an4/{mode}/an4/txt/"
    dataset = AudioTranscriptDataset(audio_folder, transcript_folder)

    return dataset

# MFCC feature extraction
# copied from fairseq/examples/hubert/simple_kmeans/dump_mfcc_feature.py on github
def get_feats(waveform):
  with torch.no_grad():
    x = waveform.float()
    x = x.view(1, -1)

    mfccs = torchaudio.compliance.kaldi.mfcc(
        waveform=x,
        sample_frequency=16000,
        use_energy=False,
    )  # (time, freq)
    mfccs = mfccs.transpose(0, 1)  # (freq, time)
    deltas = torchaudio.functional.compute_deltas(mfccs)
    ddeltas = torchaudio.functional.compute_deltas(deltas)
    concat = torch.cat([mfccs, deltas, ddeltas], dim=0)
    concat = concat.transpose(0, 1).contiguous()  # (freq, time)
    return concat

# Define the preprocessing function for MFCC feature extraction and target labels extraction
def data_processing(data, mode='train'):
    mfccs = []
    labels = []
    input_lengths = []
    label_lengths = []
    for (waveform, transcript) in data:
      # Augment the waveform with probability
      if mode == 'train':
        waveform = augment_waveform(waveform)

      # Use 39 MFCC coefficients
      mfcc = get_feats(waveform)
      mfccs.append(mfcc)
      label = torch.Tensor(text_transform.text_to_int(transcript))
      labels.append(label)
      input_lengths.append(mfcc.shape[0])
      label_lengths.append(len(label))

    mfccs = nn.utils.rnn.pad_sequence(mfccs, batch_first=True)
    labels = nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=1)
    return mfccs, labels, torch.tensor([mfccs.shape[1]]*mfccs.shape[0]), torch.tensor([labels.shape[1]]*labels.shape[0]) # input_lengths, label_lengths

def augment_waveform(waveform, augmentation_probability=0.1):
    """
    Augment a waveform by adding random noise with the given probability.
    """
    noise_level = 0.01
    minimum_stretch_rate = 0.9
    stretch_factor = 0.2
    sample_rate = 16000
    num_semitones = 2
    min_factor = 1.0
    max_factor = 1.1

    waveform = waveform.numpy()

    add_white_noise = lambda waveform, noise_level: waveform + torch.randn_like(torch.tensor(waveform)).numpy() * noise_level
    time_stretch = lambda waveform, stretch_rate: effects.time_stretch(waveform, rate=stretch_rate)
    pitch_scale = lambda waveform, sr, num_semitones: effects.pitch_shift(waveform, sr=sr, n_steps=num_semitones)
    random_gain = lambda waveform, min_factor, max_factor: waveform * random.uniform(min_factor, max_factor)
    invert_polarity = lambda waveform: waveform * -1
    if torch.rand(1) < augmentation_probability:
        waveform = add_white_noise(waveform, noise_level)

    if torch.rand(1) < augmentation_probability:
        stretch_rate = minimum_stretch_rate + stretch_factor * torch.rand(1).item()
        waveform = time_stretch(waveform, stretch_rate)

    if torch.rand(1) < augmentation_probability:
        waveform = pitch_scale(waveform, sample_rate, num_semitones)

    if torch.rand(1) < augmentation_probability:
        waveform = random_gain(waveform, min_factor, max_factor)

    if torch.rand(1) < augmentation_probability:
        waveform = invert_polarity(waveform)

    return torch.tensor(waveform)


In [7]:
# Training function
def train(model, batch_size, print_interval, save_ckpt_interval, criterion, optimizer, device):
    # Load the AN4 train dataset
    train_dataset = load_an4_dataset(mode='train')

    # Preprocess the dataset and create the data loader
    data_loader = DataLoader(dataset=train_dataset,
                                batch_size=batch_size,
                                shuffle=True,
                                collate_fn=lambda x: data_processing(x))
    model.train()
    total_loss = 0.0

    for batch_idx, _data in enumerate(data_loader):
        batch_inputs, targets, input_lengths, target_lengths = _data
        batch_inputs, targets = batch_inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        # Get model predictions
        outputs = model(batch_inputs).transpose(0, 1)

        # Calculate the CTC loss
        loss = criterion(outputs, targets, input_lengths, target_lengths)

        # Backpropagation and optimization step
        loss.backward()
        optimizer.step()

        # Print current batch loss
        if batch_idx % print_interval == 0:
          print(f"Batch {batch_idx}, Loss: {loss:.4f}")

        # Save the trained model
        # if batch_idx % save_ckpt_interval == 1:
        #   torch.save(model.state_dict(), "ctc_model.pt")

        total_loss += loss.item()

    return total_loss / len(data_loader)


In [8]:
def test(model, device, batch_size, criterion, mode='test'):
    # Load the AN4 test dataset
    test_dataset = load_an4_dataset(mode=mode)

    # Preprocess the dataset and create the data loader
    test_loader = DataLoader(dataset=test_dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                collate_fn=lambda x: data_processing(x, mode='test'))

    print('\nevaluating...')
    model.eval()
    test_loss = 0
    test_cer, test_wer = [], []
    with torch.no_grad():
        for i, _data in enumerate(test_loader):
            mfccs, labels, input_lengths, label_lengths = _data
            mfccs, labels = mfccs.to(device), labels.to(device)

            output = model(mfccs)  # (batch, time, n_class)
            output = output.transpose(0, 1) # (time, batch, n_class)

            loss = criterion(output, labels, input_lengths, label_lengths)
            test_loss += loss.item() / len(test_loader)

            decoded_preds, decoded_targets = Decoder(output.transpose(0, 1), labels, label_lengths)
            for j in range(len(decoded_preds)):
                print("Target: " + decoded_targets[j])
                print("Predicted: " + decoded_preds[j])
                test_cer.append(cer(decoded_targets[j], decoded_preds[j]))
                test_wer.append(wer(decoded_targets[j], decoded_preds[j]))


    avg_cer = sum(test_cer)/len(test_cer)
    avg_wer = sum(test_wer)/len(test_wer)

    print('Test set: Average loss: {:.4f}, Average CER: {:4f}% Average WER: {:.4f}%\n'.format(test_loss, avg_cer*100, avg_wer*100))


In [9]:
# CTC loss function
class CTCLoss(nn.Module):
    def __init__(self, device):
        super(CTCLoss, self).__init__()
        self.ctc_loss = nn.CTCLoss(blank=28).to(device)

    def forward(self, log_probs, targets, input_lengths, target_lengths):
        return self.ctc_loss(log_probs, targets, input_lengths, target_lengths)


In [10]:
# Hyperparameters
input_dim = 39
hidden_dim = 256
output_dim = 29  # Number of characters in AN4 dataset, including blank (28)

batch_size = 32
num_epochs = 200
print_interval = 15
save_ckpt_interval = 45
learning_rate = 0.0005
regularization = 0.0001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize the CTC model
model = CTCModel(input_dim, hidden_dim, output_dim).to(device)

# Loss function and optimizer
criterion = CTCLoss(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=regularization)

In [11]:
# Training loop
for epoch in range(num_epochs):
    avg_loss = train(model, batch_size, print_interval, save_ckpt_interval, criterion, optimizer, device)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

# save final weights
torch.save(model.state_dict(), "./aug_ctc_lstm_model.pt")

Batch 0, Loss: 29.0631
Batch 15, Loss: 4.1947
Epoch 1/200, Average Loss: 10.7399
Batch 0, Loss: 2.3308
Batch 15, Loss: 1.7132
Epoch 2/200, Average Loss: 1.7404
Batch 0, Loss: 1.5165
Batch 15, Loss: 1.6819
Epoch 3/200, Average Loss: 1.6346
Batch 0, Loss: 1.6212
Batch 15, Loss: 1.6822
Epoch 4/200, Average Loss: 1.5007
Batch 0, Loss: 1.3278
Batch 15, Loss: 1.4395
Epoch 5/200, Average Loss: 1.4224
Batch 0, Loss: 1.6086
Batch 15, Loss: 1.5189
Epoch 6/200, Average Loss: 1.4447
Batch 0, Loss: 1.7496
Batch 15, Loss: 1.3180
Epoch 7/200, Average Loss: 1.4677
Batch 0, Loss: 1.5744
Batch 15, Loss: 1.7741
Epoch 8/200, Average Loss: 1.4301
Batch 0, Loss: 1.2087
Batch 15, Loss: 1.8061
Epoch 9/200, Average Loss: 1.3803
Batch 0, Loss: 1.3059
Batch 15, Loss: 1.2139
Epoch 10/200, Average Loss: 1.4299
Batch 0, Loss: 1.4087
Batch 15, Loss: 1.2082
Epoch 11/200, Average Loss: 1.3854
Batch 0, Loss: 1.2006
Batch 15, Loss: 1.4916
Epoch 12/200, Average Loss: 1.3935
Batch 0, Loss: 1.2413
Batch 15, Loss: 1.5164
Ep

In [12]:
model = CTCModel(input_dim, hidden_dim, output_dim, ckpt_path="./aug_ctc_lstm_model.pt").to(device)
test(model, device, batch_size, criterion, 'train')


evaluating...
Target: HELP                                            
Predicted: HELP
Target: FIVE TWO SEVEN                                  
Predicted: FIVE TWO SEVEN
Target: O A K D A L E D R I V E                         
Predicted: O A K D A L E D R I V E
Target: ONE TWO FOUR ONE                                
Predicted: ONE TWO FOUR ONE
Target: NO                                              
Predicted: NO
Target: ONE FIVE TWO ZERO SEVEN                         
Predicted: ONE FIVE TWO ZERO SEVEN
Target: P I T T S B U R G H                             
Predicted: P I T T S B U R G H
Target: FIVE SIX THREE ONE                              
Predicted: FIVE SIX THREE ONE
Target: P I T T S B U R G H                             
Predicted: P I T T S B U R G H
Target: L O O F B O U R R O W                           
Predicted: L O O F B O U R R O W
Target: STOP                                            
Predicted: STOP
Target: ENTER FORTY THREE FORTY FIVE                    
Predic

In [13]:
test(model, device, batch_size, criterion, 'val')


evaluating...
Target: S P R I N G H O U S E L A N E                    
Predicted: S P R I N G  H  S C L A N E
Target: ENTER FIVE                                       
Predicted: ENTER FIVE
Target: STOP                                             
Predicted: STOP
Target: FOUR ONE TWO FOUR TWO TWO NINE EIGHT TWO EIGHT   
Predicted: FOUR ONE TWO FOUR TWO TERO NINE EIGHT TWO EIGHT
Target: V L C Z TWENTY SIX                               
Predicted: V L C Z TWENTY SIX
Target: J S P S Z NINE SIX NINE                          
Predicted: J S P S CNINE SIX NIND
Target: Z EIGHT OH TWO                                   
Predicted: Z A O TWO
Target: ONE SIXTEEN FORTY EIGHT                          
Predicted: ONE SIXTEEN FORTY EIGHT
Target: SEPTEMBER FIRST NINETEEN SIXTY NINE              
Predicted: SETOBR SIRTH NINETEEN SIXTY NINE
Target: P I T T S B U R G H                              
Predicted: P I T T S B U R G H
Target: YES                                              
Predicted: YES
T

In [14]:
test(model, device, batch_size, criterion)


evaluating...
Target: NO                                  
Predicted: NO
Target: ERASE K M H N I SIX OH FIVE         
Predicted: ERASE K M H N I SIX OH FIVE
Target: ONE FIVE TWO TWO SEVEN              
Predicted: OE FIVE TWO TWO SEVEN
Target: MAY SECOND NINETEEN SIXTY FIVE      
Predicted: ERAE SECOND NINETEEN SIXTY FIVE
Target: M Y E R S                           
Predicted: M Y E R S
Target: ENTER ONE SEVENTY SIX               
Predicted: ENTER NE SEVENTY SIX
Target: ONE FIVE TWO ONE THREE              
Predicted: ONE FIVE TWO ONE THREE
Target: ENTER EIGHT THIRTEEN                
Predicted: ENTER EIGHT FIRTEEN
Target: P H I N N E Y                       
Predicted: P H I N N EU Y
Target: ENTER SEVEN                         
Predicted: ENTER SEVEN
Target: P I T T S B U R G H                 
Predicted: P I T T S B U R G H
Target: TWO TWO SIX                         
Predicted: TWO TWO SIX
Target: RUBOUT G M E F THREE NINE           
Predicted: RUBOT G N E N TWOEE NINE
Target: R O C 