In [3]:
## Install dependencies
!pip install openai-whisper
!pip install wget
!apt-get install sox libsndfile1 ffmpeg -y
!pip install text-unidecode
!pip install matplotlib>=3.3.2
## Install NeMo
BRANCH = 'main'
!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]

Collecting openai-whisper
  Downloading openai-whisper-20231117.tar.gz (798 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m798.6/798.6 kB[0m [31m15.0 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hCollecting triton<3,>=2.0.0 (from openai-whisper)
  Downloading triton-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Collecting tiktoken (from openai-whisper)
  Downloading tiktoken-0.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)
Downloading triton-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (168.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m168.1/168.1 MB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hDownloading tiktoken-0.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2

In [4]:
import os
import glob
import os
import re
import torchaudio as ta
import random
import subprocess
import editdistance
import tarfile
import wget
import librosa
import IPython.display as ipd
import numpy as np
import json
import soundfile as sf
import argparse
from datasets import load_dataset, Dataset
from whisper.normalizers import EnglishTextNormalizer
import torch
from joblib import Parallel, delayed
from tqdm import tqdm
from torch import nn
import json
import pandas as pd
import nemo.collections.asr as nemo_asr
# drive.mount('/content/drive')

In [5]:
# os.mkdir('/kaggle/working/data')

In [6]:
# Run this once: download and unpack the Svarah archive unless it is already on disk.
data_dir = '/kaggle/working/data'
# Create the target directory up front so the cell also works on a fresh
# kernel (no separate manual mkdir step is required).
os.makedirs(data_dir, exist_ok=True)
if not os.path.exists(data_dir + '/svarah.tar'):
    svarah_url = 'https://indic-asr-public.objectstore.e2enetworks.net/svarah.tar'
    svarah_path = wget.download(svarah_url, data_dir)
    print(f"Dataset downloaded at: {svarah_path}")
    # The context manager closes the archive handle even if extraction fails.
    # NOTE(review): extractall() trusts the member paths inside a remotely
    # fetched archive; confirm the source is trusted.
    with tarfile.open(svarah_path) as tar:
        tar.extractall(path=data_dir)
else:
    print('data already downloaded')

data already downloaded


In [7]:
def get_data(split):
    """Decode one manifest line and load the audio file it points to.

    Args:
        split: JSON string for a single manifest entry. It must provide the
            'audio_filepath' and 'text' keys.

    Returns:
        A tuple ``(audio, text)``. ``audio`` is a dict with the keys 'path',
        'array' (the samples returned by ``sf.read``) and 'sampling_rate';
        ``text`` is the entry's 'text' value.
    """
    entry = json.loads(split)
    audio_path = entry['audio_filepath']
    samples, sample_rate = sf.read(audio_path)
    audio = {
        'path': audio_path,
        'array': samples,
        'sampling_rate': sample_rate,
    }
    return audio, entry['text']

In [8]:
class audio_dataset(Dataset):
    """In-memory container pairing loaded audio clips with their transcripts.

    Clips and sentences are kept in two parallel lists that are filled through
    ``fill_data``; indexing returns one flat dict per example.
    """

    def __init__(self):
        # Parallel lists: audios[i] belongs to sents[i].
        self.audios, self.sents = [], []

    def __len__(self):
        """Number of stored examples."""
        return len(self.audios)

    def __getitem__(self, i):
        """Return example ``i`` as a dict with raw samples, rate, path and reference text."""
        clip = self.audios[i]
        return {
            "raw": clip['array'],
            "sampling_rate": clip['sampling_rate'],
            "audio_path": clip['path'],
            "reference": self.sents[i],
        }

    def fill_data(self, aud, sent):
        """Append one audio dict (keys 'array', 'sampling_rate', 'path') and its transcript."""
        self.audios.append(aud)
        self.sents.append(sent)


In [9]:
manifest_path = '/kaggle/working/data/svarah/svarah_manifest.json'
train_portion = 0.7
with open(manifest_path, 'r', encoding='utf-8') as f:
    data = f.read()

# The manifest holds one JSON record per line. Dropping blank lines (instead of
# unconditionally discarding the last element) keeps the final record when the
# file does not end with a newline.
splits = [line for line in data.split('\n') if line.strip()]
jsons = [json.loads(split) for split in splits]
# Manifest paths are relative to the extracted dataset folder; make them absolute.
for js in jsons:
    js['audio_filepath'] = '/kaggle/working/data/svarah/'+js['audio_filepath']
splits = [json.dumps(js) for js in jsons]
# Fixed seed so the shuffled order (and therefore the split) is reproducible.
random.seed(0)
random.shuffle(splits)
train_last_idx = int(len(splits)*train_portion)
train_splits = splits[:train_last_idx]
eval_splits = splits[train_last_idx:]

In [10]:
# Load every clip listed in the manifest in parallel; each result is (audio_dict, text).
da = Parallel(n_jobs=20)(delayed(get_data)(split) for split in tqdm(splits))
norm = EnglishTextNormalizer()
dataset = audio_dataset()
for d in da:
    tr = norm(d[1])
    # Raw string: '\d' in a plain literal is an invalid escape sequence and
    # triggers a warning on current Python versions.
    if not re.search(r'\d', tr): # remove tags with numbers
        dataset.fill_data(d[0], tr)

100%|██████████| 6656/6656 [00:14<00:00, 464.23it/s]


In [46]:
# Restore the pretrained English QuartzNet 15x5 checkpoint, mapped onto the CPU
quartznet = nemo_asr.models.EncDecCTCModel.from_pretrained(
    model_name="QuartzNet15x5Base-En",
    map_location='cpu',
)

# Freeze all encoder layers first; the BatchNorm and SqueezeExcite layers are
# re-enabled afterwards through enable_bn_se
quartznet.encoder.freeze()

def enable_bn_se(m):
    """Switch a BatchNorm1d or SqueezeExcite module back to trainable.

    Intended to be passed to ``Module.apply``: modules whose type is exactly
    ``nn.BatchNorm1d``, or whose class name contains 'SqueezeExcite', are put
    in train mode and all of their parameters get ``requires_grad=True``.
    Every other module is left untouched.
    """
    is_batchnorm = type(m) == nn.BatchNorm1d
    is_squeeze_excite = 'SqueezeExcite' in type(m).__name__
    if not (is_batchnorm or is_squeeze_excite):
        return

    m.train()
    for weight in m.parameters():
        weight.requires_grad_(True)
            
# Visit every encoder submodule and re-enable the BatchNorm / SqueezeExcite ones.
# The trailing semicolon stops the notebook from echoing the module that
# apply() returns (the former no-op `quartznet = quartznet` line only served
# to hide that output).
quartznet.encoder.apply(enable_bn_se);

[NeMo I 2024-09-21 11:41:57 cloud:58] Found existing object /root/.cache/torch/NeMo/NeMo_2.1.0rc0/QuartzNet15x5Base-En/2b066be39e9294d7100fb176ec817722/QuartzNet15x5Base-En.nemo.
[NeMo I 2024-09-21 11:41:57 cloud:64] Re-using file from: /root/.cache/torch/NeMo/NeMo_2.1.0rc0/QuartzNet15x5Base-En/2b066be39e9294d7100fb176ec817722/QuartzNet15x5Base-En.nemo
[NeMo I 2024-09-21 11:41:57 common:826] Instantiating model from pre-trained checkpoint
[NeMo I 2024-09-21 11:41:58 features:305] PADDING: 16
[NeMo I 2024-09-21 11:41:59 save_restore_connector:275] Model EncDecCTCModel was successfully restored from /root/.cache/torch/NeMo/NeMo_2.1.0rc0/QuartzNet15x5Base-En/2b066be39e9294d7100fb176ec817722/QuartzNet15x5Base-En.nemo.


In [49]:
class Dataset(torch.utils.data.Dataset):
    """Torch dataset that pads audio and character targets to fixed lengths.

    Wraps an indexable dataset whose items are dicts with a 'raw' entry (audio
    samples) and a 'reference' entry (transcript text). Each example is
    returned as ``(inputs, transcription, len_inputs, len_transcription)``
    where ``inputs`` and ``transcription`` are zero-padded to the longest
    example in the wrapped dataset and the two lengths are the unpadded sizes.

    NOTE: this class shadows the ``Dataset`` name imported from ``datasets``
    at the top of the notebook; cells that need that class must run first.

    Args:
        old_ds: indexable dataset supporting ``len()`` and integer indexing.
        vocab: sequence of characters; a character's position is its label id.
    """

    def __init__(self,old_ds,vocab):
        self.old_ds = old_ds
        self.vocab = vocab
        self.normalizer = EnglishTextNormalizer()

        # Measure the padding targets on the dataset that was passed in (not on
        # a notebook-global variable), and on the *normalized* text, because
        # that is what __getitem__ encodes and pads.
        max_len_inputs = 0
        max_len_transcription = 0
        for i in range(len(old_ds)):
            item = old_ds[i]
            max_len_inputs = max(max_len_inputs, len(item['raw']))
            max_len_transcription = max(
                max_len_transcription, len(self.normalizer(item['reference']))
            )
        self.max_len_inputs = max_len_inputs
        self.max_len_transcription = max_len_transcription

    def _pad_inputs(self,x):
        """Zero-pad ``x`` along its first (sample) axis up to ``max_len_inputs``."""
        # transpose(0, -1) moves the sample axis last so F.pad can extend it;
        # for 1-D input it is the identity, for 2-D input it equals `.T`
        # without the deprecation warning `.T` raises on non-2-D tensors.
        padded = torch.nn.functional.pad(
            x.transpose(0, -1), (0, self.max_len_inputs - x.shape[0])
        )
        return padded.transpose(0, -1)

    def _pad_transcription(self,x):
        """Zero-pad the 1-D label tensor ``x`` up to ``max_len_transcription``."""
        return torch.nn.functional.pad(x,(0,self.max_len_transcription-x.shape[0]))

    def __len__(self):
        return len(self.old_ds)

    def __getitem__(self,idx):
        """Return ``(inputs, transcription, len_inputs, len_transcription)`` for ``idx``.

        Raises:
            ValueError: if the normalized transcript contains a character that
                is not present in ``vocab``.
        """
        item = self.old_ds[idx]
        transcription = self.normalizer(item['reference'])
        len_transcription = len(transcription)
        # Explicit integer dtype so an empty transcript still yields an
        # integer label tensor instead of an empty float tensor.
        transcription = torch.tensor(
            [self.vocab.index(char) for char in transcription], dtype=torch.long
        )
        transcription = self._pad_transcription(transcription)
        inputs = torch.tensor(item['raw'])
        len_inputs = inputs.shape[0]
        inputs = self._pad_inputs(inputs)
        return inputs,transcription,len_inputs,len_transcription

In [50]:
# Wrap the in-memory clips for QuartzNet, encoding text with the decoder's own vocabulary
qnet_ds = Dataset(dataset, vocab=quartznet.decoder.vocabulary)

# Shuffled mini-batches of 32 padded examples
loader = torch.utils.data.DataLoader(
    qnet_ds,
    batch_size=32,
    shuffle=True,
)

In [65]:
# Sanity check: push one un-padded clip through the network and inspect the
# output shape, which is (batch, time, classes). Uses `quartznet` directly
# because `model` is only bound in the training cell further down, and passes
# the keyword arguments the forward call needs.
# NOTE(review): the BatchNorm layers re-enabled earlier are in train mode, so
# this forward pass presumably also updates their running statistics.
sample_inputs, _, sample_len, _ = qnet_ds[0]
with torch.no_grad():
    sample_log_probs, _, _ = quartznet(
        input_signal=sample_inputs[:sample_len].unsqueeze(0),
        input_signal_length=torch.tensor([sample_len]),
    )
sample_log_probs.shape

torch.Size([1, 1488, 29])

## Train on new data

In [None]:
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt

# Define the model, optimizer, and loss function
model = quartznet

vocab = quartznet.decoder.vocabulary
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01)
# The CTC blank label is index len(vocab), i.e. one past the last vocabulary entry.
criterion = torch.nn.CTCLoss(blank=len(vocab),zero_infinity=True)

# Create the dataset and dataloader
dataloader = torch.utils.data.DataLoader(qnet_ds, batch_size=1, shuffle=True)

# Training loop
num_epochs = 200
for epoch in range(num_epochs):
    losses = []
    model.train()
    for inputs, transcription, len_inputs, len_transcription in tqdm(dataloader):
        optimizer.zero_grad()

        # qnet_ds pads every clip to the longest clip of the whole dataset.
        # Keep only as many samples as the longest clip in this batch so the
        # network does not spend time on (and normalize over) pure padding.
        inputs = inputs[:, :int(len_inputs.max())]

        # Pass the raw audio through the model; it returns per-frame
        # log-probabilities and the valid output length of each example
        output = model(input_signal=inputs,input_signal_length=len_inputs)
        logits,length,_ = output

        # Reshape the output for CTC loss
        logits = logits.permute(1, 0, 2)  # (T, N, C)

        # Calculate the CTC loss
        loss = criterion(logits, transcription, length, len_transcription)

        # Backpropagate and optimize
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    # Report the mean training loss so progress is visible even while the
    # WER/CER evaluation below is disabled.
    mean_loss = sum(losses) / max(len(losses), 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss:.3f}")

#     # evaluate wer and cer
#     model.eval()
#     total_pred = ''
#     total_gt = ''
#     total_pre_pred = ''
#     eval_vocab = np.array(vocab + ['eps'])
#     with torch.no_grad():
#         for logits, transcription, len_logits, len_transcription in evalset:
#             output = model(logits[:len_logits].unsqueeze(0).cuda()).squeeze(0).cpu()
#             total_gt += ''.join(eval_vocab[transcription[:len_transcription]])
#             total_pred += transcribe(output.detach().numpy(),eval_vocab)
#             # if epoch == 0:
#             #   total_pre_pred += transcribe(logits[:len_logits].detach().numpy(),eval_vocab)

#     wer,cer = eval_wer_cer(total_gt,total_pred)
#     # if epoch == 0:
#     #   old_wer,old_cer = eval_wer_cer(total_gt,total_pre_pred)
#     print(f"""Epoch {epoch+1}/{num_epochs}, Loss: {np.mean(losses):.3f} WER: {wer:.3f}, CER: {cer:.3f}""")

#     # mat = model.linear.weight.detach().cpu().numpy()
#     # plt.imshow(mat, cmap='viridis')
#     # plt.colorbar()
#     # plt.show()


      with torch.cuda.amp.autocast(enabled=False):
    
  0%|          | 9/5502 [00:21<3:33:54,  2.34s/it]

In [21]:
# !pip install GPUtil

import torch
from GPUtil import showUtilization as gpu_usage
from numba import cuda

def free_gpu_cache():
    """Release cached GPU memory held by PyTorch and print utilisation before and after.

    Only memory that is no longer referenced can be returned: delete large
    tensors/models (``del name``) before calling this.

    The CUDA context is deliberately left alone. Closing it through numba
    (``cuda.close()``) destroys the context PyTorch is using as well, after
    which CUDA calls in this kernel fail until it is restarted.
    """
    import gc

    print("Initial GPU Usage")
    gpu_usage()

    # Collect unreachable Python objects first so their tensors go back to
    # PyTorch's caching allocator, then hand the cached blocks back to the driver.
    gc.collect()
    torch.cuda.empty_cache()

    print("GPU Usage after emptying the cache")
    gpu_usage()

free_gpu_cache()

Initial GPU Usage
| ID | GPU | MEM |
------------------
|  0 |  0% | 99% |
GPU Usage after emptying the cache
| ID | GPU | MEM |
------------------
|  0 | 95% |  2% |


In [30]:
from numba import cuda
import torch 
# `device` must be bound in this cell; relying on a value left over from an
# earlier kernel state raises NameError on a fresh run.
# NOTE(review): reset() deletes the CUDA context for this process, so GPU
# tensors/models created earlier are presumably unusable afterwards —
# restarting the kernel is the safer way to get a clean GPU.
device = cuda.get_current_device()
device.reset()


In [32]:
# Ask PyTorch whether it reports CUDA as available in this kernel.
torch.cuda.is_available()

True

In [40]:
# Query numba for the current CUDA device (diagnostic only; the result is just displayed).
cuda.get_current_device()

CudaAPIError: [700] Call to cuDevicePrimaryCtxRetain results in UNKNOWN_CUDA_ERROR