# Fine tuning

In [4]:
config = {
    "lr" : 1e-5,
    "epochs" : 2    
    }

## Set up

In [5]:
!nvidia-smi # to see what GPU you have

Mon Apr 10 18:20:16 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.65.01    Driver Version: 515.65.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            On   | 00000000:00:1E.0 Off |                    0 |
| N/A   37C    P8    15W /  70W |      2MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [46]:
# !pip install wandb --quiet
# !pip install torchsummaryX -q
# !pip install mutagen
# !pip install jiwer
# !pip install git+https://github.com/openai/whisper.git 
# # on Ubuntu or Debian
# !sudo apt update && sudo apt install ffmpeg

In [6]:
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# os.environ['CUDA_PATH']= '/usr/local/cuda-11.7'

In [7]:
print(os.environ)

environ({'CONDA_SHLVL': '1', 'LS_COLORS': 'rs=0:di=01;34:ln=01;36:mh=00:pi=40;33:so=01;35:do=01;35:bd=40;33;01:cd=40;33;01:or=40;31;01:mi=00:su=37;41:sg=30;43:ca=30;41:tw=30;42:ow=34;42:st=37;44:ex=01;32:*.tar=01;31:*.tgz=01;31:*.arc=01;31:*.arj=01;31:*.taz=01;31:*.lha=01;31:*.lz4=01;31:*.lzh=01;31:*.lzma=01;31:*.tlz=01;31:*.txz=01;31:*.tzo=01;31:*.t7z=01;31:*.zip=01;31:*.z=01;31:*.Z=01;31:*.dz=01;31:*.gz=01;31:*.lrz=01;31:*.lz=01;31:*.lzo=01;31:*.xz=01;31:*.zst=01;31:*.tzst=01;31:*.bz2=01;31:*.bz=01;31:*.tbz=01;31:*.tbz2=01;31:*.tz=01;31:*.deb=01;31:*.rpm=01;31:*.jar=01;31:*.war=01;31:*.ear=01;31:*.sar=01;31:*.rar=01;31:*.alz=01;31:*.ace=01;31:*.zoo=01;31:*.cpio=01;31:*.7z=01;31:*.rz=01;31:*.cab=01;31:*.wim=01;31:*.swm=01;31:*.dwm=01;31:*.esd=01;31:*.jpg=01;35:*.jpeg=01;35:*.mjpg=01;35:*.mjpeg=01;35:*.gif=01;35:*.bmp=01;35:*.pbm=01;35:*.pgm=01;35:*.ppm=01;35:*.tga=01;35:*.xbm=01;35:*.xpm=01;35:*.tif=01;35:*.tiff=01;35:*.png=01;35:*.svg=01;35:*.svgz=01;35:*.mng=01;35:*.pcx=01;35:*.mov=

In [49]:
import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torchsummaryX import summary
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

import torchaudio.transforms as tat

# from sklearn.metrics import accuracy_score
import gc

import zipfile
import pandas as pd
from tqdm import tqdm
import os
import datetime

from mutagen.mp3 import MP3
import jiwer

import wandb

import warnings
warnings.filterwarnings('ignore')

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device: ", DEVICE)

if DEVICE == 'cuda':
    dtype = torch.float16
else:
    dtype = torch.float32

Device:  cuda


## Load the pretrained Whisper model and tokenizer

In [50]:
import whisper

model = whisper.load_model("tiny")
tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual, language="en", task="transcribe")

In [51]:
# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
N_FRAMES = whisper.utils.exact_div(N_SAMPLES, HOP_LENGTH)  # 3000 frames in a mel spectrogram input

N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
FRAMES_PER_SECOND = whisper.utils.exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
TOKENS_PER_SECOND = whisper.utils.exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token

## Dataloader

In [52]:
class AudioDataset(torch.utils.data.Dataset):

    # For this homework, we give you full flexibility to design your data set class.
    
    def __init__(self, partition, root, tokenizer): 
        '''
        Initializes the dataset.

        INPUTS: What inputs do you need here?
        '''
        df = pd.read_csv('common_voice_'+partition+'.csv')
        # self.length = len(df)
        
        self.mfccs = []
        self.transcripts = []
        self.probs = []
        # self.content_frames = []
        for index, entry in df.iterrows():
            if index == 100:
                print("100 done")
                break
            filepath = root + entry['path']
            audio = MP3(filepath)
#             if True:
            if audio.info.length < 30:
                transcript = tokenizer.encode(entry['sentence'])
                prob = entry['probs']
                mfcc = whisper.audio.log_mel_spectrogram(filepath,padding=N_SAMPLES)
                mfcc = whisper.audio.pad_or_trim(mfcc, N_FRAMES).to(DEVICE).to(dtype)
                # content_frame = mfcc.shape[-1] - N_FRAMES
                self.mfccs.append(mfcc)
                self.transcripts.append(transcript)
                self.probs.append(prob)
                # self.content_frames.append(content_frame)
            else:
                print("too long")
        self.length = len(self.mfccs)

    def __len__(self):
        
        return self.length

    def __getitem__(self, ind):
        '''
        TODO: RETURN THE MFCC COEFFICIENTS AND ITS CORRESPONDING LABELS

        If you didn't do the loading and processing of the data in __init__,
        do that here.

        Once done, return a tuple of features and labels.
        '''
        
        
#         mfcc = torch.FloatTensor(self.mfccs[ind])
        mfcc = self.mfccs[ind]
        transcript = torch.LongTensor(self.transcripts[ind])
        prob = self.probs[ind]
#         prob = torch.FloatTensor(self.probs[ind])
        # content_frame = torch.LongTensor(self.content_frames[ind])
        # return mfcc, transcript, content_frame, prob
        return mfcc, transcript, prob


    def collate_fn(self,batch):
        '''
        TODO:
        1.  Extract the features and labels from 'batch'
        2.  We will additionally need to pad both features and labels,
            look at pytorch's docs for pad_sequence
        3.  This is a good place to perform transforms, if you so wish. 
            Performing them on batches will speed the process up a bit.
        4.  Return batch of features, labels, lenghts of features, 
            and lengths of labels.
        '''
        # batch of input mfcc coefficients
        batch_mfcc = [] # TODO
        # batch of output phonemes
        batch_transcript = [] # TODO
        # batch_content_frame = []
        batch_prob = []
        lengths_mfcc = []
        lengths_transcript = []

        # for mfcc,transcript,content_frame, prob in batch:
        for mfcc,transcript, prob in batch:
        
          batch_mfcc.append(mfcc)
          batch_transcript.append(transcript)
          # batch_content_frame.append(content_frame)
          batch_prob.append(prob)
          lengths_mfcc.append(len(mfcc))
          lengths_transcript.append(len(transcript))
          
        # HINT: CHECK OUT -> pad_sequence (imported above)
        # Also be sure to check the input format (batch_first)
        batch_mfcc_pad = pad_sequence(batch_mfcc, batch_first= True) # TODO
        # lengths_mfcc = batch_mfcc_pad.shape[1] # TODO 

        batch_transcript_pad = pad_sequence(batch_transcript, batch_first= True) # TODO
        # lengths_transcript = batch_transcript_pad.shape[1] # TODO
        
        # You may apply some transformation, Time and Frequency masking, here in the collate function;
        # Food for thought -> Why are we applying the transformation here and not in the __getitem__?
        #                  -> Would we apply transformation on the validation set as well?
        #                  -> Is the order of axes / dimensions as expected for the transform functions?
        # Return the following values: padded features, padded labels, actual length of features, actual length of the labels
        # return batch_mfcc_pad, batch_transcript_pad, batch_content_frame,  batch_prob, torch.tensor(lengths_mfcc), torch.tensor(lengths_transcript)
        return batch_mfcc_pad, batch_transcript_pad, torch.tensor(batch_prob), torch.tensor(lengths_mfcc), torch.tensor(lengths_transcript)


In [53]:
BATCH_SIZE = 64 # Increase if your device can handle it

In [3]:
# Create objects for the dataset class
root = '/data/home/ubuntu/cv-corpus-13.0-2023-03-09/en/clips/'
train_data = AudioDataset('toy',root, tokenizer) #TODO
# val_data = AudioDataset('validation',root, tokenizer) # TODO : You can either use the same class with some modifications or make a new one :)

# Do NOT forget to pass in the collate function as parameter while creating the dataloader
train_loader = torch.utils.data.DataLoader(
    dataset     = train_data, 
    num_workers = 0,
    batch_size  = BATCH_SIZE, 
#     pin_memory  = True,
    shuffle     = True,
    collate_fn  = train_data.collate_fn
)
# val_loader = torch.utils.data.DataLoader(
#     dataset     = val_data, 
#     num_workers = 2,
#     batch_size  = BATCH_SIZE, 
#     pin_memory  = True,
#     shuffle     = False,
#     collate_fn  = val_data.collate_fn
# )

print("Batch size: ", BATCH_SIZE)
print("Train dataset samples = {}, batches = {}".format(train_data.__len__(), len(train_loader)))
# print("Val dataset samples = {}, batches = {}".format(val_data.__len__(), len(val_loader)))

NameError: name 'AudioDataset' is not defined

In [55]:
# sanity check
for data in train_loader:
    x, y, p, lx, ly = data
    print(x.shape)
    print(y.shape)
    print(lx.shape) 
    print(ly.shape) 
    print(p.shape)
    break 

torch.Size([64, 80, 3000])
torch.Size([64, 34])
torch.Size([64])
torch.Size([64])
torch.Size([64])


## Training Model

In [56]:
class ASRModel(torch.nn.Module):

    def __init__(self, model):
        super().__init__()
        self.options = whisper.decoding.DecodingOptions(fp16 = False,temperature=0.0)
        self.asr = model
        
    
    def forward(self, x, lengths_x):
        out = self.asr.decode(x,self.options)

        return out, lengths_x

## Criterion, optimizer, scheduler

In [57]:
criterion = torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True) # Define CTC loss as the criterion. How would the losses be reduced?

optimizer =  torch.optim.SGD(model.parameters(),lr = config["lr"]) # What goes in here?

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3) #TODO

# Mixed Precision, if you need it
scaler = torch.cuda.amp.GradScaler()

## Train and validate functions

In [58]:
from tqdm import tqdm

def train_model(model, train_loader, optimizer):
    
    model.train()
    batch_bar = tqdm(total=len(train_loader), dynamic_ncols=True, leave=False, position=0, desc='Train') 

    total_loss = 0

    for i, data in enumerate(train_loader):
        optimizer.zero_grad()

        x, y, lx, ly, p = data
        x, y = x.to(DEVICE), y.to(DEVICE)

        with torch.cuda.amp.autocast():     
            h, lh = model(x, lx)
            # h = torch.permute(h, (1, 0, 2))
            loss = criterion(h, y, lh, ly)
            loss = loss/p

        total_loss += loss.item()

        batch_bar.set_postfix(
            loss="{:.04f}".format(float(total_loss / (i + 1))),
            lr="{:.06f}".format(float(optimizer.param_groups[0]['lr'])))

        batch_bar.update() # Update tqdm bar

        # Another couple things you need for FP16. 
        scaler.scale(loss).backward() # This is a replacement for loss.backward()
        scaler.step(optimizer) # This is a replacement for optimizer.step()
        scaler.update() # This is something added just for FP16

        del x, y, lx, ly, h, lh, loss 
        torch.cuda.empty_cache()

    batch_bar.close() # You need this to close the tqdm bar
    
    return total_loss / len(train_loader)


def validate_model(model, val_loader):

    model.eval()
    batch_bar = tqdm(total=len(val_loader), dynamic_ncols=True, position=0, leave=False, desc='Val')

    total_loss = 0
    val_wil = 0

    for i, data in enumerate(val_loader):

        x, y, lx, ly = data
        x, y = x.to(DEVICE), y.to(DEVICE)

        with torch.inference_mode():
            h, lh = model(x, lx)
            # h = torch.permute(h, (1, 0, 2))
            loss = criterion(h, y, lh, ly)

        total_loss += float(loss)
        val_wil += wil(h,y)

        batch_bar.set_postfix(loss="{:.04f}".format(float(total_loss / (i + 1))), dist="{:.04f}".format(float(val_wil / (i + 1))))

        batch_bar.update()
    
        del x, y, lx, ly, h, lh, loss
        torch.cuda.empty_cache()
        
    batch_bar.close()
    total_loss = total_loss/len(val_loader)
    avg_wil = val_wil/len(val_loader)
    return total_loss, avg_wil

## Training setup

In [59]:
def save_model(model, optimizer, scheduler, metric, epoch, path):
    torch.save(
        {'model_state_dict'         : model.state_dict(),
         'optimizer_state_dict'     : optimizer.state_dict(),
         'scheduler_state_dict'     : scheduler.state_dict(),
         metric[0]                  : metric[1], 
         'epoch'                    : epoch}, 
         path
    )

def load_model(path, model, metric= 'valid_acc', optimizer= None, scheduler= None):

    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])

    if optimizer != None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler != None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        
    epoch   = checkpoint['epoch']
    metric  = checkpoint[metric]

    return [model, optimizer, scheduler, epoch, metric]

In [60]:
# This is for checkpointing, if you're doing it over multiple sessions

last_epoch_completed = 0
start = last_epoch_completed
end = config["epochs"]
best_wil = 1 # if you're restarting from some checkpoint, use what you saw there.
epoch_model_path = "/content/epoch_model.checkpoint"#TODO set the model path( Optional, you can just store best one. Make sure to make the changes below )
best_model_path = "/content/best_model.checkpoint"#TODO set best model path 

## wandb

In [None]:
wandb.login(key="a5b7420abbe354e6d0b2f5554b97ee11f327fc92") #API Key is in your wandb account, under settings (wandb.ai/settings)

In [None]:
# Create your wandb run
run = wandb.init(
    name = "first-attempt", ## Wandb creates random run names if you skip this field
    reinit = True, ### Allows reinitalizing runs when you re-run this cell
    # run_id = 'oigknwdr', #Insert specific run id here if you want to resume a previous run
    # resume = "must", ### You need this to resume previous runs, but comment out reinit = True when using this
    project = "whisper-finetune", ### Project should be created in your wandb account 
    config = config ### Wandb Config for your run
)

## Train

In [61]:
asrmodel = ASRModel(model)

In [62]:
torch.cuda.empty_cache()
gc.collect()

#TODO: Please complete the training loop

for epoch in range(0, config['epochs']):

    print("\nEpoch: {}/{}".format(epoch+1, config['epochs']))
    
    curr_lr = float(optimizer.param_groups[0]['lr'])

    train_loss              = train_model(asrmodel, train_loader, optimizer)
    valid_loss, val_wil  = validate_model(asrmodel, val_loader)
    scheduler.step(valid_loss)

    print("\tTrain Loss {:.04f}\t Learning Rate {:.07f}".format(train_loss, curr_lr))
    print("\tVal Dist {:.04f}\t Val Loss {:.04f}".format(val_wil, valid_loss))


#     wandb.log({
#         'train_loss': train_loss,  
#         'val_wil': val_wil, 
#         'valid_loss': valid_loss, 
#         'lr'        : curr_lr
#     })
    
    save_model(arsmodel, optimizer, scheduler, ['val_wil', val_wil], epoch, epoch_model_path)
#     wandb.save('epoch_model.checkpoint')
    print("Saved epoch model")

    if val_wil <= best_wil:
        best_wil = val_wil
        save_model(asrmodel, optimizer, scheduler, ['val_wil', val_wil], epoch, best_model_path)
#         wandb.save('best_model.checkpoint')
        print("Saved best model")
      # You may find it interesting to exlplore Wandb Artifcats to version your models
# run.finish()


Epoch: 1/2


Train:   0%|                                              | 0/2 [00:00<?, ?it/s]

RuntimeError: GET was unable to find an engine to execute this computation

# Playing with the model

In [13]:
# tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual, language="en", task="transcribe")
# t = tokenizer.encode("Please call Stella. Ask her to bring these things with her from the store: Six spoons of fresh snow peas, five thick slabs of blue cheese, and maybe a snack for her brother Bob.  We also need a small plastic snake and a big toy frog for the kids. She can scoop these things into three red bags, and we will go meet her Wednesday at the train station.")
# print(len(t),t)

# import torchaudio
# def load_wave(wave_path, sample_rate:int=16000) -> torch.Tensor:
#     waveform, sr = torchaudio.load(wave_path, normalize=True)
#     if sample_rate != sr:
#         waveform = tat.Resample(sr, sample_rate)(waveform)
#     return waveform
# audio = load_wave('/content/data/recordings/recordings/english1.mp3', sample_rate=16000)
# audio = whisper.pad_or_trim(audio.flatten())
# mel = whisper.log_mel_spectrogram(audio)
# mel = whisper.audio.log_mel_spectrogram('/content/data/recordings/recordings/english10.mp3',padding=N_SAMPLES)
# mel = whisper.audio.pad_or_trim(mel, N_FRAMES).to('cpu').to(torch.float32)
# options = whisper.decoding.DecodingOptions(fp16 = False,temperature=0.0)
# m = model.decode(mel,options)
# print(len(m.tokens),m.tokens)
# d = predict(model, mel)
# print(d)

79 [16216, 818, 45073, 13, 12320, 720, 281, 1565, 613, 721, 365, 720, 490, 264, 3531, 25, 11678, 36316, 295, 4451, 5756, 24494, 11, 1732, 5060, 1061, 17243, 295, 3344, 5399, 11, 293, 1310, 257, 13288, 337, 720, 3708, 6085, 13, 220, 492, 611, 643, 257, 1359, 5900, 12650, 293, 257, 955, 12058, 17259, 337, 264, 2301, 13, 1240, 393, 19555, 613, 721, 666, 1045, 2182, 10405, 11, 293, 321, 486, 352, 1677, 720, 10579, 412, 264, 3847, 5214, 13]
[50364, 2555, 818, 45073, 11, 1029, 720, 281, 1565, 613, 721, 365, 720, 490, 264, 3531, 13, 50589, 50589, 11678, 36316, 295, 4451, 5756, 24494, 11, 1732, 5060, 1061, 17243, 295, 3344, 5399, 11, 293, 1310, 257, 13288, 337, 720, 3708, 12, 65, 404, 13, 50914, 50914, 492, 611, 643, 257, 1359, 5900, 12650, 293, 257, 955, 12058, 17259, 337, 264, 2301, 13, 51114, 51114, 1240, 486, 19555, 613, 721, 493, 666, 1045, 2182, 10405, 11, 293, 321, 486, 352, 1677, 720, 10579, 412, 264, 3097, 5214, 13, 51414] 89


In [11]:
tokenizer.decode([50364, 2555, 818, 45073, 13, 12320, 720, 281, 1565, 613, 721, 365, 720, 490, 264, 3531, 13, 11678, 50608, 50608, 36316, 295, 4451, 5756, 24494, 11, 1732, 5060, 1061, 17243, 295, 3344, 5399, 11, 293, 1310, 257, 13288, 337, 720, 50880, 50880, 3708, 6085, 13, 492, 611, 643, 257, 1359, 5900, 12650, 293, 257, 955, 12058, 17259, 337, 264, 2301, 13, 51120, 51120, 1240, 393, 19555, 613, 721, 666, 1045, 2182, 10405, 293, 321, 486, 352, 1677, 720, 10579, 51390, 51390, 412, 264, 3847, 5214, 13, 51440])

' Please call Stella. Ask her to bring these things with her from the store. Six spoons of fresh snow peas, five thick slabs of blue cheese, and maybe a snack for her brother Bob. We also need a small plastic snake and a big toy frog for the kids. She can scoop these things into three red bags and we will go meet her Wednesday at the train station.'

In [12]:
tokenizer.decode([16216, 818, 45073, 13, 220, 12320, 720, 281, 1565, 613, 721, 365, 720, 490, 264, 3531, 25, 220, 11678, 36316, 295, 4451, 5756, 24494, 11, 1732, 5060, 1061, 17243, 295, 3344, 5399, 11, 293, 1310, 257, 13288, 337, 720, 3708, 6085, 13, 220, 492, 611, 643, 257, 1359, 5900, 12650, 293, 257, 955, 12058, 17259, 337, 264, 2301, 13, 220, 1240, 393, 19555, 613, 721, 666, 1045, 2182, 10405, 11, 293, 321, 486, 352, 1677, 720, 10579, 412, 264, 3847, 5214, 13])

'Please call Stella.  Ask her to bring these things with her from the store:  Six spoons of fresh snow peas, five thick slabs of blue cheese, and maybe a snack for her brother Bob.  We also need a small plastic snake and a big toy frog for the kids.  She can scoop these things into three red bags, and we will go meet her Wednesday at the train station.'

In [None]:
tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual, language="en", task="transcribe")

In [None]:
def predict(model, tokenizer, mel):
  decode_options = dict()
  decode_options["fp16"] = (DEVICE == 'cuda')
  if decode_options["fp16"]:
    dtype = torch.float16
  else:
    dtype = torch.float32
  logprob_threshold = -1.0
  no_speech_threshold = 0.6
  decode_options["language"] = "en"


  content_frames = mel.shape[-1] - N_FRAMES

  def decode_with_fallback(segment: torch.Tensor) -> whisper.decoding.DecodingResult:
      kwargs = {**decode_options}
      options = whisper.decoding.DecodingOptions(**kwargs, temperature=0.0)
      decode_result = model.decode(segment, options)
      return decode_result

  seek = 0
  input_stride = whisper.utils.exact_div(
      N_FRAMES, model.dims.n_audio_ctx
  )  # mel frames per output token: 2
  time_precision = (
      input_stride * HOP_LENGTH / SAMPLE_RATE
  )  # time per output token: 0.02 (seconds)
  all_tokens = []
  all_segments = []
  prompt_reset_since = 0

  initial_prompt_tokens = []

  def new_segment(
      *, start: float, end: float, tokens: torch.Tensor, result: whisper.decoding.DecodingResult
  ):
      tokens = tokens.tolist()
      text_tokens = [token for token in tokens if token < tokenizer.eot]
      return {
          "seek": seek,
          "start": start,
          "end": end,
          "text": tokenizer.decode(text_tokens),
          "tokens": tokens,
          "temperature": result.temperature,
          "avg_logprob": result.avg_logprob,
          "compression_ratio": result.compression_ratio,
          "no_speech_prob": result.no_speech_prob,
      }

# show the progress bar when verbose is False (if True, transcribed text will be printed)

  while seek < content_frames:
      time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
      mel_segment = mel[:, seek : seek + N_FRAMES]
      segment_size = min(N_FRAMES, content_frames - seek)
      segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
      mel_segment = whisper.audio.pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

      decode_options["prompt"] = all_tokens[prompt_reset_since:]
      result: whisper.decoding.DecodingResult = decode_with_fallback(mel_segment)
      tokens = torch.tensor(result.tokens)

      if no_speech_threshold is not None:
          # no voice activity check
          should_skip = result.no_speech_prob > no_speech_threshold
          if (
              logprob_threshold is not None
              and result.avg_logprob > logprob_threshold
          ):
              # don't skip if the logprob is high enough, despite the no_speech_prob
              should_skip = False

          if should_skip:
              seek += segment_size  # fast-forward to the next segment boundary
              continue

      previous_seek = seek
      current_segments = []

      timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
      single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]

      consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
      consecutive.add_(1)
      if len(consecutive) > 0:
          # if the output contains two consecutive timestamp tokens
          slices = consecutive.tolist()
          if single_timestamp_ending:
              slices.append(len(tokens))

          last_slice = 0
          for current_slice in slices:
              sliced_tokens = tokens[last_slice:current_slice]
              start_timestamp_pos = (
                  sliced_tokens[0].item() - tokenizer.timestamp_begin
              )
              end_timestamp_pos = (
                  sliced_tokens[-1].item() - tokenizer.timestamp_begin
              )
              current_segments.append(
                  new_segment(
                      start=time_offset + start_timestamp_pos * time_precision,
                      end=time_offset + end_timestamp_pos * time_precision,
                      tokens=sliced_tokens,
                      result=result,
                  )
              )
              last_slice = current_slice

          if single_timestamp_ending:
              # single timestamp at the end means no speech after the last timestamp.
              seek += segment_size
          else:
              # otherwise, ignore the unfinished segment and seek to the last timestamp
              last_timestamp_pos = (
                  tokens[last_slice - 1].item() - tokenizer.timestamp_begin
              )
              seek += last_timestamp_pos * input_stride
      else:
          duration = segment_duration
          timestamps = tokens[timestamp_tokens.nonzero().flatten()]
          if (
              len(timestamps) > 0
              and timestamps[-1].item() != tokenizer.timestamp_begin
          ):
              # no consecutive timestamps but it has a timestamp; use the last one.
              last_timestamp_pos = (
                  timestamps[-1].item() - tokenizer.timestamp_begin
              )
              duration = last_timestamp_pos * time_precision

          current_segments.append(
              new_segment(
                  start=time_offset,
                  end=time_offset + duration,
                  tokens=tokens,
                  result=result,
              )
          )
          seek += segment_size

      if result.temperature > 0.5:
          # do not feed the prompt tokens if a high temperature was used
          prompt_reset_since = len(all_tokens)

      # if a segment is instantaneous or does not contain text, clear it
      for i, segment in enumerate(current_segments):
          if segment["start"] == segment["end"] or segment["text"].strip() == "":
              segment["text"] = ""
              segment["tokens"] = []
              segment["words"] = []

      all_segments.extend(
          [
              {"id": i, **segment}
              for i, segment in enumerate(
                  current_segments, start=len(all_segments)
              )
          ]
      )
      all_tokens.extend(
          [token for segment in current_segments for token in segment["tokens"]]
      )


  return tokenizer.decode(all_tokens)

In [18]:
text_transformation = jiwer.Compose([
    jiwer.ToLowerCase(),
    jiwer.RemoveWhiteSpace(replace_by_space=True),
    jiwer.RemoveMultipleSpaces(),
    jiwer.RemovePunctuation(),
    jiwer.Strip(),
    jiwer.ReduceToListOfListOfWords()
]) 

def wer(target, output):
  return jiwer.wer(
    target, 
    output, 
    truth_transform=text_transformation, 
    hypothesis_transform=text_transformation)
  
def wil(target, output):
  return jiwer.wil(
    target, 
    output, 
    truth_transform=text_transformation, 
    hypothesis_transform=text_transformation)


In [None]:
from tqdm import tqdm

def train_model(model, train_loader, criterion, optimizer):
    
    model.train()
    batch_bar = tqdm(total=len(train_loader), dynamic_ncols=True, leave=False, position=0, desc='Train') 

    total_loss = 0

    for i, data in enumerate(train_loader):
        optimizer.zero_grad()

        x, y, cfx, p, lx, ly = data
        x, y = x.to(device), y.to(device)

        with torch.cuda.amp.autocast():     
            h, lh = model(x, lx)
            h = torch.permute(h, (1, 0, 2))
            loss = criterion(h, y, lh, ly)

        total_loss += loss.item()

        batch_bar.set_postfix(
            loss="{:.04f}".format(float(total_loss / (i + 1))),
            lr="{:.06f}".format(float(optimizer.param_groups[0]['lr'])))

        batch_bar.update() # Update tqdm bar

        # Another couple things you need for FP16. 
        scaler.scale(loss).backward() # This is a replacement for loss.backward()
        scaler.step(optimizer) # This is a replacement for optimizer.step()
        scaler.update() # This is something added just for FP16

        del x, y, lx, ly, h, lh, loss 
        torch.cuda.empty_cache()

    batch_bar.close() # You need this to close the tqdm bar
    
    return total_loss / len(train_loader)


def validate_model(model, val_loader, decoder, phoneme_map= LABELS):

    model.eval()
    batch_bar = tqdm(total=len(val_loader), dynamic_ncols=True, position=0, leave=False, desc='Val')

    total_loss = 0
    vdist = 0

    for i, data in enumerate(val_loader):

        x, y, lx, ly = data
        x, y = x.to(device), y.to(device)

        with torch.inference_mode():
            h, lh = model(x, lx)
            h = torch.permute(h, (1, 0, 2))
            loss = criterion(h, y, lh, ly)

        total_loss += float(loss)
        vdist += calculate_levenshtein(torch.permute(h, (1, 0, 2)), y, lh, ly, decoder, phoneme_map)

        batch_bar.set_postfix(loss="{:.04f}".format(float(total_loss / (i + 1))), dist="{:.04f}".format(float(vdist / (i + 1))))

        batch_bar.update()
    
        del x, y, lx, ly, h, lh, loss
        torch.cuda.empty_cache()
        
    batch_bar.close()
    total_loss = total_loss/len(val_loader)
    val_dist = vdist/len(val_loader)
    return total_loss, val_dist

# Baseline

## Fetching the Dataset

In [4]:
!pip install --upgrade --force-reinstall --no-deps kaggle==1.5.8
!mkdir /root/.kaggle

# with open("/root/.kaggle/kaggle.json", "w+") as f:
#     f.write('{"username":"UserName","key":"Key"}') 
#     # Put your kaggle username & key here
with open("/root/.kaggle/kaggle.json", "w+") as f:
    f.write('{"username":"sma2023","key":"6c819f763f537a6b8bbb60cb11520dbf"}') 
    # Put your kaggle username & key here

!chmod 600 /root/.kaggle/kaggle.json

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting kaggle==1.5.8
  Downloading kaggle-1.5.8.tar.gz (59 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/59.2 KB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.2/59.2 KB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: kaggle
  Building wheel for kaggle (setup.py) ... [?25l[?25hdone
  Created wheel for kaggle: filename=kaggle-1.5.8-py3-none-any.whl size=73272 sha256=2761abf8cabcf8d60f78714a24ae0b4b7d40604aa99ffa664c0530e61002775e
  Stored in directory: /root/.cache/pip/wheels/d4/02/ef/3f8c8d86b8d5388a1d3155876837f1a1a3143ab3fc2ff1ffad
Successfully built kaggle
Installing collected packages: kaggle
  Attempting uninstall: kaggle
    Found existing installation: kaggle 1.5.13
    Uninstalling kaggle-1.5.13:
   

In [5]:
!kaggle datasets download -d rtatman/speech-accent-archive
!unzip -qo 'speech-accent-archive.zip' -d '/content/data'

Downloading speech-accent-archive.zip to /content
 99% 860M/865M [00:08<00:00, 65.5MB/s]
100% 865M/865M [00:08<00:00, 106MB/s] 


## Define the two metrics

In [None]:
!pip install jiwer

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting jiwer
  Downloading jiwer-3.0.0-py3-none-any.whl (21 kB)
Collecting rapidfuzz==2.13.7
  Downloading rapidfuzz-2.13.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m10.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, jiwer
Successfully installed jiwer-3.0.0 rapidfuzz-2.13.7


In [None]:
import jiwer

In [None]:
jiwer.wer("the cat", "cat the")

1.0

In [None]:
f = open("/content/data/reading-passage.txt")
target = ""
for line in f:
  target += line

text_transformation = jiwer.Compose([
    jiwer.ToLowerCase(),
    jiwer.RemoveWhiteSpace(replace_by_space=True),
    jiwer.RemoveMultipleSpaces(),
    jiwer.RemovePunctuation(),
    # jiwer.ReduceToSingleSentence(),
    jiwer.Strip(),
    jiwer.SubstituteRegexes({r"6": r"six", r"5": r"five", r"3": r"three"}),
    jiwer.ReduceToListOfListOfWords()
]) 

def wer(output):
  return jiwer.wer(
    target, 
    output, 
    truth_transform=text_transformation, 
    hypothesis_transform=text_transformation)
  
def wil(output):
  return jiwer.wil(
    target, 
    output, 
    truth_transform=text_transformation, 
    hypothesis_transform=text_transformation)


## Transcribe and recording the data

In [None]:
# open the speaker_all.csv
import pandas as pd
from tqdm.auto import tqdm

In [None]:
speakers = pd.read_csv('/content/data/speakers_all.csv')
speakers['wer'] = 1.0
speakers['wil'] = 1.0
cnt = 0
batch_bar   = tqdm(total=len(speakers), dynamic_ncols=True, leave=False, position=0)
for index, row in speakers.iterrows():
  if row['file_missing?']==False:
    file_name = row['filename']
    transcription = whisper.transcribe(model = model, audio = '/content/data/recordings/recordings/'+file_name+'.mp3', fp16=False)['text']
    speakers.at[index,'wer'] = wer(transcription)
    speakers.at[index,'wil'] = wil(transcription)
  batch_bar.update()
batch_bar.close()
speakers.to_csv('results.csv')


  0%|          | 0/2172 [00:00<?, ?it/s]