In [2]:
!pip install dataloader
!pip install fairseq
!pip install pip==24.0
!pip install hydra-core==1.0.7 omegaconf==2.0.6
!git clone https://github.com/pytorch/fairseq
!pip install --editable ./fairseq -v

Collecting dataloader
  Downloading dataloader-2.0.tar.gz (9.1 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: dataloader
  Building wheel for dataloader (setup.py) ... [?25l[?25hdone
  Created wheel for dataloader: filename=dataloader-2.0-py3-none-any.whl size=10085 sha256=028cd8eb6feefca46c69a6e7d2b2d47bdbbe3adb04489d074db5da90c030f658
  Stored in directory: /root/.cache/pip/wheels/60/56/53/2b1c14a2abb6f40f1d59f97461a59e61f326433fac416794de
Successfully built dataloader
Installing collected packages: dataloader
Successfully installed dataloader-2.0
Collecting fairseq
  Downloading fairseq-0.12.2.tar.gz (9.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.6/9.6 MB[0m [31m63.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metada

In [3]:
!pip install jiwer

Collecting jiwer
  Downloading jiwer-3.0.5-py3-none-any.whl.metadata (2.7 kB)
Collecting rapidfuzz<4,>=3 (from jiwer)
  Downloading rapidfuzz-3.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Downloading jiwer-3.0.5-py3-none-any.whl (21 kB)
Downloading rapidfuzz-3.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m32.9 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h[33mDEPRECATION: omegaconf 2.0.6 has a non-standard dependency specifier PyYAML>=5.1.*. pip 24.1 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of omegaconf or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0mInstalling collected packages: rapidfuzz, jiwer
Successfully installed jiwer-3.0.5 rapidfuzz-3.11.0


In [4]:
import jiwer
import fairseq

In [5]:
# Necessary Library Imports
import torch
import torchaudio
import torch.utils.data
import json
from torch import nn
from torch.utils.data import Dataset
from typing import List, Dict

# For Trainer
import os
# import wandb
from pathlib import Path
from glob import glob
from torch import optim, nn, Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from abc import ABC
from torchmetrics import WordErrorRate

## **Audio Dataset**

In [6]:
class AudioDataset(Dataset):
    def __init__(self, label_manifest_path: str, unlabel_manifest_path: str = None):
        super().__init__()
        self.list_manifest = json.load(open(label_manifest_path, "r"))
        if unlabel_manifest_path:
            self.list_manifest += json.load(open(unlabel_manifest_path, "r"))

    def __len__(self):
        return len(self.list_manifest)

    def __getitem__(self, idx):
        """
        return:
        - for label example: wav, sr, text
        - for unlabel example: wav, sr, None
        """
        # label: { 'audio_filepath': ..., 'text': "hello there" }
        # unlabel: { 'audio_filepath': ...}
        sample = self.list_manifest[idx]
        wav, sr = torchaudio.load(sample.get("audio_filepath"))
        return dict(wav=wav, text=sample.get("text", None))

## **Utils**

In [7]:
test_subset = ["dev-clean", "dev-other", "test-clean", "test-other"]


def calc_length(lengths, padding, kernel_size, stride, ceil_mode, repeat_num=1):
    """Calculates the output length of a Tensor passed through a convolution or max pooling layer"""
    add_pad: float = (padding * 2) - kernel_size
    one: float = 1.0
    for i in range(repeat_num):
        lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one
        if ceil_mode:
            lengths = torch.ceil(lengths)
        else:
            lengths = torch.floor(lengths)
    return lengths.to(dtype=torch.int)


class ConvSubsampling(nn.Module):
    """Convolutional subsampling which supports VGGNet and striding approach introduced in:
    VGGNet Subsampling: https://arxiv.org/pdf/1910.12977.pdf
    Striding Subsampling:
        "Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong et al.
    Args:
        input_dim (int): size of the input features
        feat_out (int): size of the output features
        conv_channels (int): Number of channels for the convolution layers. (encoder dim)
        subsampling_factor (int): The subsampling factor which should be a power of 2
        activation (Module): activation function, default is nn.ReLU()
    """

    def __init__(
        self,
        input_dim: int = 80,
        feat_out: int = -1,
        conv_channels: int = -1,
        subsampling_factor: int = 4,
        activation=nn.ReLU(),
    ):
        super(ConvSubsampling, self).__init__()

        if subsampling_factor % 2 != 0:
            raise ValueError("Sampling factor should be a multiply of 2!")
        self._sampling_num = int(math.log(subsampling_factor, 2))

        in_channels = 1
        layers = []

        self._padding = 1
        self._stride = 2
        self._kernel_size = 3
        self._ceil_mode = False

        for i in range(self._sampling_num):
            layers.append(
                torch.nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=conv_channels,
                    kernel_size=self._kernel_size,
                    stride=self._stride,
                    padding=self._padding,
                )
            )
            layers.append(activation)
            in_channels = conv_channels

        in_length = torch.tensor(input_dim, dtype=torch.float)
        out_length = calc_length(
            in_length,
            padding=self._padding,
            kernel_size=self._kernel_size,
            stride=self._stride,
            ceil_mode=self._ceil_mode,
            repeat_num=self._sampling_num,
        )
        self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out)
        self.conv = torch.nn.Sequential(*layers)

    def forward(self, x, lengths):
        lengths = calc_length(
            lengths,
            padding=self._padding,
            kernel_size=self._kernel_size,
            stride=self._stride,
            ceil_mode=self._ceil_mode,
            repeat_num=self._sampling_num,
        )
        x = x.unsqueeze(1)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).reshape(b, t, -1))
        return x, lengths


class LogMelSpectrogram(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=16000, **kwargs
        )

    def forward(self, inputs):
        return self.mel_spec(inputs)


class ComposeTransform:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, audio_data):
        for t in self.transforms:
            audio_data = t(audio_data)
        return audio_data


def count_params(model):
    if type(model) == nn.DataParallel:
        return model.module.count_params()
    return model.count_params()


def save_state_dict(model):
    if type(model) == nn.DataParallel:
        return model.module.state_dict()
    return model.state_dict()


class EarlyStopping:
    def __init__(self, tolerance=5, min_delta=0):

        self.tolerance = tolerance
        self.min_delta = min_delta
        self.counter = 0
        self.early_stop = False

    def __call__(self, train_loss, validation_loss):
        if (validation_loss - train_loss) > self.min_delta:
            self.counter += 1
            if self.counter >= self.tolerance:
                self.early_stop = True

## **Dataset (LibriLight)**

In [8]:
import torchaudio
#from utils import LogMelSpectrogram
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import os
from pathlib import Path
import torch
from typing import List


class LibriLight(Dataset):
    def __init__(
        self,
        data_path: str = "/kaggle/input/libri-phone/libri-phone",
        subsets: List[str] = ["light", "dev-clean", "dev-other"],
        n_fft: int = 2048,
        n_mels: int = 80,
        win_length: int = 400,
        hop_length: int = 100,
    ):
        super().__init__()
        # assert subset in ['light', 'dev-clean', 'dev-other', 'test-clean', 'test-other'], 'Not found subset'
        df = pd.read_csv(data_path + os.sep + "phones.csv")
        df = df[df.subset.isin(subsets)].drop("subset", axis=1)
        df.label = df.label.map(eval)
        df.path = data_path + os.sep + df.path
        df.path = df.path.apply(lambda x: x.replace("\\", os.sep))
        self.walker = df.to_dict("records")

        self.feature_transform = LogMelSpectrogram(
            n_fft=n_fft, n_mels=n_mels, win_length=win_length, hop_length=hop_length
        )

    def __len__(self):
        return len(self.walker)

    def __getitem__(self, idx):
        item = self.walker[idx]
        label = item["label"]
        wave, sr = torchaudio.load(item["path"])

        specs = self.feature_transform(wave)
        specs = specs.permute(0, 2, 1)
        specs = specs.squeeze()

        return specs, label


class LibriLightLibriSpeechDataset(Dataset):
    def __init__(
        self,
        light_data_path: str = "data/libri-phone",
        subset: str = "train",
        libri_clean_path: str = "data/LibriSpeech",
        n_fft: int = 1024,
        n_mels: int = 128,
        win_length: int = 400,
        hop_length: int = 200,
        **kwargs,
    ):
        super().__init__()
        """
        subset \in ['train', 'val', 'test']
        """
        self.list_url = []
        is_test = True
        if subset == "train":
            self.list_url = [libri_clean_path + "/train-clean-360"]
            is_test = False

        sep = os.sep
        self.libri_walker = []
        for path in self.list_url:
            files_path = f"*{sep}*{sep}*" + ".flac"
            walker = [(str(p.stem), path) for p in Path(path).glob(files_path)]
            self.libri_walker.extend(walker)

        if subset == "train":
            subsets = ["light", "dev-clean", "dev-other", "test-other"]
        else:
            subsets = ["test-clean"]

        df = pd.read_csv(light_data_path + os.sep + "phones.csv")
        df = df[df.subset.isin(subsets)].drop("subset", axis=1)
        df.label = df.label.map(eval)
        df.path = light_data_path + os.sep + df.path
        df.path = df.path.apply(lambda x: x.replace("\\", os.sep))
        self.light_walker = df.to_dict("records")

        self.walker = self.libri_walker + self.light_walker

        sample_rate = 16000
        self.feature_transform = LogMelSpectrogram(
            n_fft=n_fft, n_mels=n_mels, win_length=win_length, hop_length=hop_length
        )

    def __len__(self):
        return len(self.walker)

    def __getitem__(self, idx):
        item = self.walker[idx]
        if type(item) == tuple:
            return self.load_librispeech_item(item)
        return self.load_libri_light_item(item)

    def load_libri_light_item(self, item):
        label = item["label"]
        wave, sr = torchaudio.load(item["path"])

        specs = self.feature_transform(wave)
        specs = specs.permute(0, 2, 1)
        specs = specs.squeeze()

        return specs, label, "labeled"

    def load_librispeech_item(self, item):
        """
        transform audio pack to spectrogram
        """
        fileid, path = item

        speaker_id, chapter_id, utterance_id = fileid.split("-")
        fileid_audio = speaker_id + "-" + chapter_id + "-" + utterance_id
        file_audio = fileid_audio + ".flac"
        file_audio = os.path.join(path, speaker_id, chapter_id, file_audio)

        # Load audio
        waveform, sample_rate = torchaudio.load(file_audio)

        spectrogram = self.feature_transform(waveform)
        spectrogram = spectrogram.squeeze().permute(1, 0)

        return spectrogram, "unlabel"


# class TimitLibriSpeechDataset(Dataset):
#     def __init__(
#         self,
#         timit_data_root: str = 'data/timit',
#         timit_csv_path: str = 'data/timit/timit_pronunciation.csv',
#         libri_clean_path: str = 'data/libri/LibriSpeech',
#         libri_other_path: str = '',
#         subset: str = 'train',
#         n_fft: int = 512,
#         n_mels: int = 80,
#         **kwargs,
#     ):
#         super().__init__()
#         """
#         subset \in ['train', 'val', 'test']
#         """
#         self.list_url = []
#         is_test = True
#         if subset == "train":
#             self.list_url = [libri_clean_path + "train-clean-100"]
#             is_test = False

#         sep = os.sep
#         self.libri_walker = []
#         for path in self.list_url:
#             files_path = f"*{sep}*{sep}*" + '.flac'
#             walker = [(str(p.stem), path) for p in Path(path).glob(files_path)]
#             self.libri_walker.extend(walker)

#         df = pd.read_csv(timit_csv_path, index_col=0)
#         df = df[df.is_test == is_test].drop("is_test", axis=1)
#         df.path = df.path.apply(lambda x: timit_data_root + os.sep + x)
#         df.trans = df.trans.str.split("|")
#         self.is_test = is_test

#         self.timit_walker = df.to_dict("records")

#         self.walker = self.libri_walker + self.timit_walker

#         sample_rate = 16000
#         self.feature_transform = LogMelSpectrogram(n_fft=n_fft, n_mels=n_mels)
#         self.augmentation = ComposeTransform([
#             SpeedPerturbation(sample_rate),
#             RandomBackgroundNoise(sample_rate, max_snr_db=30)
#         ])
#         self.augment_prob = 0.80


#     def __len__(self):
#         return len(self.walker)

#     def __getitem__(self, idx):
#         item = self.walker[idx]
#         if type(item) == tuple:
#             return self.load_librispeech_item(item)
#         return self.load_timit_item(item)

#     def load_timit_item(self, item):
#         trans = item["trans"]
#         wave, sr = torchaudio.load(item["path"])

#         is_augment = np.random.choice(2, p=(1 - self.augment_prob, self.augment_prob))
#         if is_augment and not self.is_test:
#             wave = self.augmentation(wave)

#         specs = self.feature_transform(wave)
#         specs = specs.permute(0, 2, 1)
#         specs = specs.squeeze()
#         return specs, trans, 'labelled'

#     def load_librispeech_item(self, item):
#         """
#         transform audio pack to spectrogram
#         """
#         fileid, path = item

#         speaker_id, chapter_id, utterance_id = fileid.split("-")
#         fileid_audio = speaker_id + "-" + chapter_id + "-" + utterance_id
#         file_audio = fileid_audio + '.flac'
#         file_audio = os.path.join(path, speaker_id, chapter_id, file_audio)

#         # Load audio
#         waveform, sample_rate = torchaudio.load(file_audio)

#         is_augment = np.random.choice(2, p=(1 - self.augment_prob, self.augment_prob))
#         if is_augment:
#             waveform = self.augmentation(waveform)

#         spectrogram = self.feature_transform(waveform)
#         spectrogram = spectrogram.squeeze().permute(1, 0)

#         return spectrogram, 'unlabel'


class TimitLibriSpeechDataset(Dataset):
    def __init__(
        self,
        timit_data_root: str = "/kaggle/input/libri-phone/libri-phone",
        timit_csv_path: str = "/kaggle/input/libri-phone/libri-phone/phones.csv",
        libri_clean_path: str = "data/libri/LibriSpeech",
        libri_other_path: str = "",
        subset: str = "train",
        n_fft: int = 512,
        n_mels: int = 80,
        **kwargs,
    ):
        super().__init__()
        """
        subset \in ['train', 'val', 'test']
        """
        self.list_url = []
        is_test = True
        if subset == "train":
            self.list_url = [libri_clean_path + "train-clean-100"]
            is_test = False

        sep = os.sep
        self.libri_walker = []
        for path in self.list_url:
            files_path = f"*{sep}*{sep}*" + ".flac"
            walker = [(str(p.stem), path) for p in Path(path).glob(files_path)]
            self.libri_walker.extend(walker)

        df = pd.read_csv(timit_csv_path, index_col=0)
        df = df[df.is_test == is_test].drop("is_test", axis=1)
        df.path = df.path.apply(lambda x: timit_data_root + os.sep + x)
        df.trans = df.trans.str.split("|")
        self.is_test = is_test

        self.timit_walker = df.to_dict("records")

        self.walker = self.libri_walker + self.timit_walker

        sample_rate = 16000
        self.feature_transform = LogMelSpectrogram(n_fft=n_fft, n_mels=n_mels)

    def __len__(self):
        return len(self.walker)

    def __getitem__(self, idx):
        item = self.walker[idx]
        if type(item) == tuple:
            return self.load_librispeech_item(item)
        return self.load_timit_item(item)

    def load_timit_item(self, item):
        trans = item["trans"]
        wave, sr = torchaudio.load(item["path"])

        specs = self.feature_transform(wave)
        specs = specs.permute(0, 2, 1)
        specs = specs.squeeze()
        return specs, trans, "labelled"

    def load_librispeech_item(self, item):
        """
        transform audio pack to spectrogram
        """
        fileid, path = item

        speaker_id, chapter_id, utterance_id = fileid.split("-")
        fileid_audio = speaker_id + "-" + chapter_id + "-" + utterance_id
        file_audio = fileid_audio + ".flac"
        file_audio = os.path.join(path, speaker_id, chapter_id, file_audio)

        # Load audio
        waveform, sample_rate = torchaudio.load(file_audio)

        spectrogram = self.feature_transform(waveform)
        spectrogram = spectrogram.squeeze().permute(1, 0)

        return spectrogram, "unlabel"


class TimitDataset(Dataset):
    def __init__(
        self,
        data_root: str = "data/timit",
        csv_path: str = "data/timit/timit_pronunciation.csv",
        n_fft: int = 159,
        n_mels: int = 80,
        is_test: bool = False,
        **kwargs,
    ):
        super().__init__()
        df = pd.read_csv(csv_path, index_col=0)
        df = df[df.is_test == is_test].drop("is_test", axis=1)
        df.path = df.path.apply(lambda x: data_root + os.sep + x)
        df.trans = df.trans.str.split("|")
        self.is_test = is_test
        sample_rate = 16000

        self.walker = df.to_dict("records")
        self.feature_transform = LogMelSpectrogram(n_fft=n_fft, n_mels=n_mels)

    def __len__(self):
        return len(self.walker)

    def __getitem__(self, idx):
        item = self.walker[idx]
        trans = item["trans"]
        wave, sr = torchaudio.load(item["path"])

        specs = self.feature_transform(wave)

        specs = specs.permute(0, 2, 1)
        specs = specs.squeeze()

        return specs, trans

# **DataLoader**

In [9]:
# from dataloader.dataset import AudioDataset

from torch.utils.data import DataLoader
from functools import partial


def collate_fn(batch, audio_transform, text_process):
    transform_fn = lambda x: audio_transform(x).squeeze().permute(1, 0)
    text_transform_fn = lambda x: text_process.text2int(text_process.tokenize(x))

    wavs = [i["wav"] for i in batch]
    feat = list(map(transform_fn, wavs))
    feat_len = torch.LongTensor([i.size(0) for i in feat])
    feat = pad_sequene(feat, batch_first=True)

    trans = [
        (idx, item["text"])
        for idx, text in enumerate(batch)
        if item["text"] is not None
    ]

    if len(trans) == len(batch):
        target = list(map(text_transform_fn, [i[1] for i in trans]))
        target_len = torch.LongTensor([i.size(0) for i in target])
        target = pad_sequence(target, batch_first=True)

        # for all label examples
        return feat, feat_len, target, target_len

    return feat, feat_len, trans


def create_dataloader(dataset_cfg: dict, dataloader_cfg: dict, text_process):

    audio_transform = MelSpectrogram(**dataset_cfg.mel_spectrogram_cfg)

    fused_collate_fn = partial(
        collate_fn, text_process=text_process, audio_transform=audio_transform
    )

    train_loader = DataLoader(
        AudioDataset(
            label_manifest_path=dataset_cfg.manifest_path.train.label,
            unlabel_manifest_path=dataset_cfg.manifest_path.train.unlabel,
        ),
        shuffle=True,
        collate_fn=fused_collate_fn,
        **dataloader_cfg
    )

    valid_loader = DataLoader(
        AudioDataset(
            label_manifest_path=dataset_cfg.manifest_path.train.valid,
        ),
        shuffle=False,
        collate_fn=fused_collate_fn,
        **dataloader_cfg
    )

    test_loader = DataLoader(
        AudioDataset(
            label_manifest_path=dataset_cfg.manifest_path.train.test,
        ),
        shuffle=False,
        collate_fn=fused_collate_fn,
        **dataloader_cfg
    )

    predict_loader = DataLoader(
        AudioDataset(
            label_manifest_path=dataset_cfg.manifest_path.train.predict,
        ),
        shuffle=False,
        collate_fn=fused_collate_fn,
        **dataloader_cfg
    )

    return train_loader, valid_loader, test_loader, predict_loader

 

# **Training Model**

In [10]:
import os
# import wandb
from pathlib import Path
from glob import glob
from torch import optim, nn, Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from abc import ABC
from torchmetrics import WordErrorRate


class Trainer(ABC):
    def __init__(
        self,
        max_epochs: int,
        experiment_path: str,
        wandb_conf: dict,
        optim_conf: dict,
        scheduler_conf: dict,
        text_process,
        device: str = "cpu",
    ):
        super().__init__()
        self.device = device
        self.optim_conf = optim_conf
        self.scheduler_conf = scheduler_conf
        self.max_epochs = max_epochs
        self.wandb_conf = wandb_conf
        self.experiment_path = experiment_path

        if os.path.exists(experiment_path):
            os.mkdir(experiment_path)

        if wandb_conf.is_log:
            wandb.init(**wandb_conf.config)

    def get_optimizer_and_scheduler(self, model: nn.Module, dataloader: DataLoader):
        optimizer = getattr(optim, self.optim_conf.optim_name)(
            model.paramters(), **self.optim_conf
        )
        if self.scheduler_conf.sched_name == "OneCycleLR":
            # for training only
            self.scheduler_conf.update(
                {"total_steps": len(dataloader) * self.max_epochs}
            )
        scheduler = getattr(optim.lr_scheduler, self.scheduler_conf.sched_name)(
            optimizer, **self.scheduler_conf
        )

        if self.optim_ckpt:
            optimizer.load_state_dict(self.optim_ckpt.get("optimizer_state_dict"))
            scheduler.load_state_dict(self.optim_ckpt.get("scheduler_state_dict"))

        return optimizer, scheduler

    def save_ckpt(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        scheduler: _LRScheduler,
        epoch: int = -1,
        step: int = -1,
    ) -> None:
        trainer = dict(
            optim=dict(
                optimizer_state_dict=optimizer.state_dict(),
                scheduler_state_dict=scheduler.state_dict(),
            ),
            hyperparams=self.__dict__,
        )

        model = dict(model_state_dict=model.state_dict(), hyperparams=model.__dict__)

        version = len(glob(str(Path(self.experiment_path) / "version_*")))
        version_path = os.path.join(self.experiment_path, f"version_{version}")
        os.mkdir(version_path)

        trainer_name = f"{self.__class__.__name__}.epoch={epoch}.step={step}.pt"
        model_name = f"{model.__class__.__name__}.epoch={epoch}.step={step}.pt"

        trainer_path = os.path.join(version_path, trainer_name)
        model_path = os.path.join(version_path, model_name)

        torch.save(trainer, trainer_path)
        torch.save(model, model_path)

    def load_from_ckpt(self, ckpt_path: dict) -> None:
        trainer_ckpt_path = ckpt_path.get("trainer")

        if trainer_ckpt_path:
            trainer = torch.load(trainer_ckpt_path)
            for k, v in trainer.get("hyperparams"):
                setattr(self, k, v)

        print("<Restore Trainer checkpoint successfully>")

    def train(self, model: nn.Module, dataloader: DataLoader):
        optimizer, scheduler = self.get_optimizer_and_scheduler(model, dataloader)

        for epoch in range(1, self.max_epochs + 1):
            self.train_epoch(model, dataloader, optimizer, scheduler, epoch)
            self.test_epoch(model, dataloader, epoch, "valid")

            if self.scheduler_conf.interval == "epoch":
                scheduler.step()

    def test(self, model: nn.Module, dataloader: DataLoader):
        self.test_epoch(model, dataloader, 0, "test")

    def predict(self, model: nn.Module, dataloader: DataLoader, outcome_path: str):
        self.test_epoch(model, dataloader, 0, outcome_path)

    def train_epoch(
        self,
        model: nn.Module,
        dataloader: DataLoader,
        optimizer: Optimizer,
        scheduler: _LRScheduler,
        epoch: int,
    ):
        pass

    def test_epoch(
        self,
        model: nn.Module,
        dataloader: DataLoader,
        epoch: int,
        task: str = "test",
        outcome_name: str = None,
    ):
        pass


class TeacherTrainer(Trainer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def train_epoch(
        self,
        model: nn.Module,
        dataloader: DataLoader,
        optimizer: Optimizer,
        scheduler: _LRScheduler,
        epoch: int,
    ):
        size = len(dataloader)
        pbar = tqdm(dataloader, total=size)

        for batch_idx, batch in enumerate(tqdm, start=1):
            feat, feat_len, target, target_len = list(
                map(lambda x: x.to(self.device), batch)
            )

            optimizer.zero_grad()

            # for training only
            out, out_len, loss = model(feat, feat_len, target, target_len)

            loss.backward()

            optimizer.step()

            if self.scheduler_conf.interval == "step":
                scheduler.step()

            # if self.wandb_conf.is_log:
            #     wandb.log({"train/loss": loss.item()})

            #     sched_name = scheduler.__class__.__name__
            #     last_lr = scheduler.get_last_lr()[0]
            #     wandb.log({f"lr-{sched_name}": last_lr})

            pbar.set_description(f"[Epoch: {epoch}] Loss: {loss.item():.2f}")

            self.save_ckpt(model, optimizer, scheduler, epoch, epoch * batch_idx)

    def test_epoch(
        self,
        model: nn.Module,
        dataloader: DataLoader,
        epoch: int,
        task: str = "test",
        outcome_name: str = None,
    ):
        size = len(dataloader)
        pbar = tqdm(dataloader, total=size)
        cal_wer = WordErrorRate()

        outcome_path = os.path.join(self.experiment_path, outcome_name)

        with open(outcome_path, "a") as f:
            f.write("=" * 10 + f"{task} | Epoch: {epoch}" + "=" * 10)
            f.write("\n")

        with torch.inference_mode():
            for batch_idx, batch in enumerate(tqdm, start=1):
                feat, feat_len, target, target_len = list(
                    map(lambda x: x.to(self.device), batch)
                )

                # for training only
                out, out_len, loss = model(
                    feat, feat_len, target, target_len, predict=True
                )

                predict = model.recognize(inputs, input_lengths)
                actual = list(map(self.text_process.int2text, targets))
                list_wer = [
                    cal_wer(hypot, truth).item()
                    for hypot, truth in zip(predict, actual)
                ]
                mean_wer = cal_wer(predict, actual).item()

                # with open(outcome_path, "a") as f:
                #     for pred, act, wer in zip(predict, actual, list_wer):
                #         f.write(f"PER    : {wer}\n")
                #         f.write(f"Actual : {act}\n")
                #         f.write(f"Predict: {pred}\n")
                #         f.write("=" * 20 + "\n")

                # if self.wandb_conf.is_log:
                #     wandb.log({f"{task}/loss": loss.item()})
                #     wandb.log({f"{task}/wer": mean_wer})

                pbar.set_description(
                    f"[Epoch: {epoch}] Loss: {loss.item():.2f} | WER: {mean_wer:.2f}%"
                )


class StudentTrainer(Trainer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def train_epoch(
        self,
        teacher_model: nn.Module,
        student_model: nn.Module,
        dataloader: DataLoader,
        optimizer: Optimizer,
        scheduler: _LRScheduler,
        epoch: int,
    ):
        size = len(dataloader)
        pbar = tqdm(dataloader, total=size)

        for batch_idx, batch in enumerate(tqdm, start=1):
            feat, feat_len, trans = batch
            feat, feat_len = feat.to(self.device), feat_len.to(self.device)

            # teacher generate pseudo-label for student learning
            predicted = teacher_model.recognize(inputs, input_lengths)

            # replace the origin transcript of timit dataset
            for origin_trans, idx in trans:
                predicted[idx] = origin_trans

            for i in range(len(predicted)):
                if type(predicted[i]) == str:
                    predicted[i] = self.text_process.tokenize(predicted[i])

            predicted = [self.text_process.text2int(s) for s in predicted]

            target_len = torch.IntTensor([s.size(0) for s in predicted]).to(self.device)
            target = pad_sequence(predicted, batch_first=True).to(
                self.device, torch.int
            )

            optimizer.zero_grad()

            # for training only
            out, out_len, loss = student_model(feat, feat_len, target, target_len)

            loss.backward()

            optimizer.step()

            if self.scheduler_conf.interval == "step":
                scheduler.step()

            # if self.wandb_conf.is_log:
            #     wandb.log({"train/loss": loss.item()})

            #     sched_name = scheduler.__class__.__name__
            #     last_lr = scheduler.get_last_lr()[0]
            #     wandb.log({f"lr-{sched_name}": last_lr})

            pbar.set_description(f"[Epoch: {epoch}] Loss: {loss.item():.2f}")

            self.save_ckpt(
                student_model, optimizer, scheduler, epoch, epoch * batch_idx
            )

    def test_epoch(
        self,
        student_model: nn.Module,
        dataloader: DataLoader,
        epoch: int,
        outcome_name: str = None,
    ):
        size = len(dataloader)
        pbar = tqdm(dataloader, total=size)
        cal_wer = WordErrorRate()

        outcome_path = os.path.join(self.experiment_path, outcome_name)

        with open(outcome_path, "a") as f:
            f.write("=" * 10 + f"{task} | Epoch: {epoch}" + "=" * 10)
            f.write("\n")

        with torch.inference_mode():
            for batch_idx, batch in enumerate(tqdm, start=1):
                feat, feat_len, target, target_len = list(
                    map(lambda x: x.to(self.device), batch)
                )

                # for training only
                out, out_len, loss = student_model(
                    feat, feat_len, target, target_len, predict=True
                )

                predict = student_model.recognize(inputs, input_lengths)
                actual = list(map(self.text_process.int2text, targets))
                list_wer = [
                    cal_wer(hypot, truth).item()
                    for hypot, truth in zip(predict, actual)
                ]
                mean_wer = cal_wer(predict, actual).item()

                # with open(outcome_path, "a") as f:
                #     for pred, act, wer in zip(predict, actual, list_wer):
                #         f.write(f"PER    : {wer}\n")
                #         f.write(f"Actual : {act}\n")
                #         f.write(f"Predict: {pred}\n")
                #         f.write("=" * 20 + "\n")

                # if self.wandb_conf.is_log:
                #     wandb.log({f"{task}/loss": loss.item()})
                #     wandb.log({f"{task}/wer": mean_wer})

                pbar.set_description(
                    f"[Epoch: {epoch}] Loss: {loss.item():.2f} | WER: {mean_wer:.2f}%"
                )

## **Convolutional Subsampling**

In [13]:
class ConvSubsampling(nn.Module):
    def __init__(
        self,
        input_dim: int = 80,
        feat_out: int = -1,
        conv_channels: int = -1,
        subsampling_factor: int = 4,
        activation=nn.ReLU(),
    ):
        super(ConvSubsampling, self).__init__()

        if subsampling_factor % 2 != 0:
            raise ValueError("Sampling factor should be a multiply of 2!")
        self._sampling_num = int(math.log(subsampling_factor, 2))

        in_channels = 1
        layers = []

        self._padding = 1
        self._stride = 2
        self._kernel_size = 3
        self._ceil_mode = False

        for i in range(self._sampling_num):
            layers.append(
                torch.nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=conv_channels,
                    kernel_size=self._kernel_size,
                    stride=self._stride,
                    padding=self._padding,
                )
            )
            layers.append(activation)
            in_channels = conv_channels

        in_length = torch.tensor(input_dim, dtype=torch.float)
        out_length = calc_length(
            in_length,
            padding=self._padding,
            kernel_size=self._kernel_size,
            stride=self._stride,
            ceil_mode=self._ceil_mode,
            repeat_num=self._sampling_num,
        )
        self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out)
        self.conv = torch.nn.Sequential(*layers)

    def calc_length(
        self, lengths, padding, kernel_size, stride, ceil_mode, repeat_num=1
    ):
        """Calculates the output length of a Tensor passed through a convolution or max pooling layer"""
        add_pad: float = (padding * 2) - kernel_size
        one: float = 1.0
        for i in range(repeat_num):
            lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one
            if ceil_mode:
                lengths = torch.ceil(lengths)
            else:
                lengths = torch.floor(lengths)
        return lengths.to(dtype=torch.int)

    def forward(self, x, lengths):
        lengths = self.calc_length(
            lengths,
            padding=self._padding,
            kernel_size=self._kernel_size,
            stride=self._stride,
            ceil_mode=self._ceil_mode,
            repeat_num=self._sampling_num,
        )
        x = x.unsqueeze(1)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).reshape(b, t, -1))
        return x, lengths

# **Audio Augmentation**
  ###  (To Add noise to Audio)

In [14]:
import torch
import random
import torchaudio
from torch import nn
import math
import os
import pathlib
import numpy as np


class SpeedPerturbation(nn.Module):
    def __init__(self, sample_rate):
        super().__init__()
        self.sample_rate = sample_rate

    @torch.inference_mode()
    def forward(self, audio_data):
        #         speed_factor = random.choice([0.9, 1.0, 1.1])
        speed_factor = np.random.uniform(0.9, 1.1)
        if speed_factor == 1.0:  # no change
            return audio_data

        # change speed and resample to original rate:
        sox_effects = [
            ["speed", str(speed_factor)],
            ["rate", str(self.sample_rate)],
        ]
        transformed_audio, _ = torchaudio.sox_effects.apply_effects_tensor(
            audio_data, self.sample_rate, sox_effects
        )
        return transformed_audio


class RandomBackgroundNoise:
    def __init__(
        self,
        noise_dir: str,
        sample_rate: int,
        min_snr_db: int = 0,
        max_snr_db: int = 15,
    ):
        self.sample_rate = sample_rate
        self.min_snr_db = min_snr_db
        self.max_snr_db = max_snr_db

        if not os.path.exists(noise_dir):
            raise IOError(f"Noise directory `{noise_dir}` does not exist")
        # find all WAV files including in sub-folders:
        noise = list(pathlib.Path(noise_dir).glob("noise/**/*.wav"))
        self.noise_files_list = noise
        if len(self.noise_files_list) == 0:
            raise IOError(f"No .wav file found in the noise directory `{noise_dir}`")

    def __call__(self, audio_data):
        random_noise_file = random.choice(self.noise_files_list)
        effects = [
            ["remix", "1"],  # convert to mono
            ["rate", str(self.sample_rate)],  # resample
        ]
        noise, _ = torchaudio.sox_effects.apply_effects_file(
            random_noise_file, effects, normalize=True
        )
        audio_length = audio_data.shape[-1]
        noise_length = noise.shape[-1]
        if noise_length > audio_length:
            offset = random.randint(0, noise_length - audio_length)
            noise = noise[..., offset : offset + audio_length]
        elif noise_length < audio_length:
            noise = torch.cat(
                [noise, torch.zeros((noise.shape[0], audio_length - noise_length))],
                dim=-1,
            )

        snr_db = random.randint(self.min_snr_db, self.max_snr_db)
        snr = math.exp(snr_db / 10)
        audio_power = audio_data.norm(p=2)
        noise_power = noise.norm(p=2)
        scale = snr * noise_power / audio_power

        return (scale * audio_data + noise) / 2


class SpecAugment(nn.Module):
    def __init__(self, freq_masks=2, time_masks=10, freq_width=27, time_width=0.05):
        super().__init__()
        self._rng = random.Random()
        self.freq_masks = freq_masks
        self.time_masks = time_masks
        self.freq_width = freq_width
        self.time_width = time_width
        self.mask_value = 0

    @torch.inference_mode()
    def forward(self, input_spec, length):
        sh = input_spec.shape
        for idx in range(sh[0]):
            for i in range(self.freq_masks):
                x_left = self._rng.randint(0, sh[2] - self.freq_width)
                w = self._rng.randint(0, self.freq_width)
                input_spec[idx, :, x_left : x_left + w] = self.mask_value

            for i in range(self.time_masks):
                time_width = max(1, int(length[idx] * self.time_width))
                y_left = self._rng.randint(0, max(1, length[idx] - time_width))
                w = self._rng.randint(0, time_width)
                input_spec[idx, y_left : y_left + w, :] = self.mask_value
        return input_spec, length


class AdaptiveSpecAugment(nn.Module):
    def __init__(
        self,
        freq_masks=2,
        time_masks=0.05,
        freq_width=27,
        time_width=0.05,
        max_time_masks=10,
    ):
        super().__init__()
        self._rng = random.Random()
        self.freq_masks = freq_masks
        self.time_masks = time_masks
        self.freq_width = freq_width
        self.time_width = time_width
        self.max_time_masks = max_time_masks
        self.mask_value = 0

    @torch.inference_mode()
    def forward(self, input_spec, length):
        sh = input_spec.shape
        for idx in range(sh[0]):
            for i in range(self.freq_masks):
                x_left = self._rng.randint(0, sh[2] - self.freq_width)
                w = self._rng.randint(0, self.freq_width)
                input_spec[idx, :, x_left : x_left + w] = self.mask_value

            time_masks = min(self.max_time_masks, int(length[idx] * self.time_masks))
            for i in range(time_masks):
                time_width = max(1, int(length[idx] * self.time_width))
                y_left = self._rng.randint(0, max(1, length[idx] - time_width))
                w = self._rng.randint(0, time_width)
                input_spec[idx, y_left : y_left + w, :] = self.mask_value
        return input_spec, length

## **Main Conformer Model**

In [15]:

class ConformerModel(nn.Module):
    def __init__(
        self,
        pretrained_conformer,
        input_dim: int = 80,
        vocab_size: int = 41,
        feat_extract_dim: int = 512,
        conv_channels: int = 256,
        conformer_dim: int = 1024,
        dropout_inp_proj: int = 0.1,
        dropout_outp_proj: int = 0.1,
        freq_masks=2,
        time_masks=0.05,
        freq_width=27,
        time_width=0.05,
        max_time_masks=10,
    ):
        super().__init__()
        self.augmentation = AdaptiveSpecAugment(
            freq_masks, time_masks, freq_width, time_width, max_time_masks
        )
        self.conv_subsampling = ConvSubsampling(
            input_dim=input_dim, feat_out=feat_extract_dim, conv_channels=conv_channels
        )
        self.input_projection = nn.Sequential(
            nn.Linear(feat_extract_dim, conformer_dim), nn.Dropout(dropout_inp_proj)
        )
        self.conformer = pretrained_conformer
        self.output_projection = nn.Sequential(
            nn.Linear(conformer_dim, conformer_dim),
            nn.SiLU(),
            nn.Dropout(dropout_outp_proj),
            nn.Linear(conformer_dim, vocab_size),
        )

    def freeze_conformer_blocks(self, n_block: int = 0):
        for l in range(n_block):
            for p in self.conformer.layers[l].parameters():
                p.requires_grad = False

    def forward(
        self,
        input_values: Tensor,
        length: Tensor,
        attention_mask: Tensor = None,
        predict: bool = False,
    ):
        if not predict:
            input_values, length = self.augmentation(input_values, length)

        out, length = self.conv_subsampling(input_values, length)
        hidden_states = self.input_projection(out)
        out = self.conformer(hidden_states, attention_mask).last_hidden_state
        out = self.output_projection(out)

        return out, length

    def count_params(self):
        return sum(p.numel() for p in self.parameters())

    def get_state_dict(self, path_to_ckpt: str):
        ckpt = torch.load(path_to_ckpt)
        state = self.load_state_dict(ckpt)
        print("Conformer:", state)
print("Model Defined Successfully!!")

Model Defined Successfully!!


## **CTC Model**

In [16]:
class CTCModel(nn.Module):
    def __init__(self, conformer_model: nn.Module, text_process):
        super().__init__()
        self.model = conformer_model
        self.ctc_loss = nn.CTCLoss()
        self.text_process = text_process

    def forward(self, feat, feat_len, target, target_len):
        out, out_len = self.model(feat, feat_len)
        if target and target_len:
            loss = self.criterion(out, target, out_len, target_len)
            return out, out_len, loss
        return out, out_len

    def criterion(
        self,
        logits: Tensor,
        targets: Tensor,
        input_lengths: Tensor,
        target_lengths: Tensor,
    ):
        log_prob = nn.functional.log_softmax(logits, dim=-1)
        return self.ctc_loss(log_prob, targets, input_lengths, target_lengths)

    def decode(encoder_output: Tensor):
        argmax = encoder_output.squeeze(0).argmax(-1)
        return self.text_process.decode(argmax)

    def recognize(inputs: Tensor, input_lengths: Tensor):
        outputs = list()

        encoder_outputs, _ = self(inputs, input_lengths)

        for encoder_output in encoder_outputs:
            predict = decode(encoder_output)
            outputs.append(predict)

        return outputs

## **Text Process**

In [17]:
import json
class TextProcess:
    def __init__(self, dataset="libri", **kwargs):
        assert dataset in ["libri", "timit"]
        """label for timit"""
        self.base_vocabs = [
            "<p>",
            "<s>",
            "<e>",
        ]
        if dataset == "libri":
            vocab = list(
                json.load(open("/kaggle/input/asr-model/phones_mapping.json", "r", encoding="utf-8")).keys()
            )
        else:
            vocab = list(json.load(open("/kaggle/input/asr-model/timit_vocab.json", "r")))
        self.vocabs = self.base_vocabs + vocab

        self.n_class = len(self.vocabs)
        self.label_vocabs = dict(zip(self.vocabs, range(self.n_class)))

        self.sos_id = 1
        self.eos_id = 2
        self.blank_id = 0

    def tokenize(self, data):
        return data

    def text2int(self, s: str) -> torch.Tensor:
        return torch.Tensor([self.label_vocabs[i] for i in s])

    def int2text(self, s: torch.Tensor) -> str:
        text = ""
        for i in s:
            if i in [self.sos_id, self.blank_id]:
                continue
            if i == self.eos_id:
                break
            text += " " + self.vocabs[i]
        return text

    def decode(self, argmax: torch.Tensor):
        """
        decode greedy with collapsed repeat
        """
        decode = []
        for i, index in enumerate(argmax):
            if index != self.blank_id:
                if i != 0 and index == argmax[i - 1]:
                    continue
                decode.append(index.item())
        return self.int2text(decode)

## **Teacher Model Training**

In [21]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn, Tensor, optim
import fairseq
import transformers
from torchsummary import summary
import torchaudio
from transformers import (
    AutoFeatureExtractor,
    Wav2Vec2FeatureExtractor,
    Trainer,
    Wav2Vec2ConformerForPreTraining,
)
import os
from pathlib import Path
from dataclasses import dataclass, field#Check--OK
import random
from copy import deepcopy
import matplotlib.pyplot as plt
from typing import List, Tuple
import pandas as pd
import jiwer
# import wandb
import time
# from audio_augmentation import (
#     SpeedPerturbation,
#     AdaptiveSpecAugment,
#     RandomBackgroundNoise,
# )
#from dataset import LibriLight # Done
#from utils import * # Done
#from text_process import TextProcess# Check This
import numpy as np
from transformers import Adafactor

num_hidden_layers = 2

# wandb.init(
#     project="speech_verification",
#     name=f"conformer_{num_hidden_layers}_hidden_gen_1_libri_subsampling_swish",
# )

device = "cuda:0" if torch.cuda.is_available() else "cpu"
# device = "cuda:0"
print("Device:", device)
# device = 'cpu'

"""# Self-training"""

batch_size = 32

text_process = TextProcess()

n_fft = 1024
win_length = 400
hop_length = 200
n_mels = 80

# lr = 0.0005 * batch_size ** (1 / 2)
lr = 0.001
max_epochs = 100
log_idx = 2

print("upto before Model it is cleared--->!!!!")
class ConformerModel(nn.Module):
    def __init__(
        self,
        pretrained_conformer,
        input_dim,
        vocab_size,
        freq_masks=2,
        time_masks=0.05,
        freq_width=27,
        time_width=0.05,
    ):
        super().__init__()
        # input_dim: 80 (n_mels)
        feat_extract_dim = 512
        conv_channels = 256
        conformer_dim = 1024
        self.spec_augment = AdaptiveSpecAugment(
            freq_masks, time_masks, freq_width, time_width
        )
        self.conv_subsampling = ConvSubsampling(
            input_dim=input_dim, feat_out=feat_extract_dim, conv_channels=conv_channels
        )
        self.input_projection = nn.Sequential(
            nn.Linear(feat_extract_dim, conformer_dim), nn.Dropout(0.1)
        )
        self.conformer = pretrained_conformer
        self.output_projection = nn.Sequential(
            # nn.Linear(conformer_dim, conformer_dim),
            # nn.Dropout(0.1),
            nn.SiLU(),
            nn.Linear(conformer_dim, vocab_size),
        )
        self.log_softmax = nn.LogSoftmax(-1)
    print("Model Defined Successfully!!!--->!!!!")
    def freeze_conformer_blocks(self):
        for p in self.conformer.layers[0].parameters():
            p.requires_grad = False
    print("Model Forwarding Started!!!--->!!!!")
    def forward(self, input_values, length, attention_mask=None):
        out, length = self.spec_augment(input_values, length)
        out, length = self.conv_subsampling(out, length)
        hidden_states = self.input_projection(out)
        out = self.conformer(hidden_states, attention_mask).last_hidden_state
        out = self.output_projection(out)
        out = self.log_softmax(out)
        return out, length

    def count_params(self):
        return sum(p.numel() for p in self.parameters())


def decode(encoder_output: Tensor) -> str:
    argmax = encoder_output.squeeze(0).argmax(-1)
    return text_process.decode(argmax)


def recognize(inputs: Tensor, input_lengths: Tensor, model: nn.Module) -> List[str]:
    outputs = list()

    encoder_outputs, _ = model(inputs, input_lengths)

    for encoder_output in encoder_outputs:
        predict = decode(encoder_output)
        outputs.append(predict)

    return outputs


def train_epoch(model, dataloader, optimizer, scheduler, criterion, epoch):
    size = len(dataloader)
    running_loss = 0
    for batch_idx, batch in enumerate(dataloader):
        inputs, input_lengths, targets, target_lengths = batch
        inputs, input_lengths = inputs.to(device), input_lengths.to(device)
        targets, target_lengths = targets.to(device), target_lengths.to(device)

        outputs, output_lengths = model(inputs, input_lengths)

        loss = criterion(
            outputs.permute(1, 0, 2), targets, output_lengths, target_lengths
        )

        if torch.isnan(loss).item() == True:
            break

        running_loss += loss.item()

        optimizer.zero_grad()

        loss.backward()

        optimizer.step()
        if scheduler:
            scheduler.step()

        # wandb.log({"train/epoch": epoch})
        # wandb.log({"train/loss": loss.item()})
        # if scheduler:
        #     wandb.log({"train/lr": scheduler.get_last_lr()[0]})
        
    return running_loss / len(dataloader)


def eval_epoch(model, dataloader, criterion, epoch, run_type="eval"):
    size = len(dataloader)
    start_time = time.perf_counter()
    running_loss = 0
    running_wer = 0
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            inputs, input_lengths, targets, target_lengths = batch
            inputs, input_lengths = inputs.to(device), input_lengths.to(device)
            targets, target_lengths = targets.to(device), target_lengths.to(device)

            outputs, output_lengths = model(inputs, input_lengths)

            loss = criterion(
                outputs.permute(1, 0, 2), targets, output_lengths, target_lengths
            )

            predict_sequences = recognize(inputs, input_lengths, model)
            label_sequences = list(map(text_process.int2text, targets))
            wer = torch.Tensor(
                [
                    jiwer.wer(truth, hypot)
                    for truth, hypot in zip(label_sequences, predict_sequences)
                ]
            )
            wer = torch.mean(wer).item()
            running_loss += loss.item()
            running_wer += wer

            # wandb.log({f"{run_type}/loss": loss.item()})
            # wandb.log({f"{run_type}/wer": wer})

    return running_loss / size, running_wer / size


train_dataset = LibriLight(
    n_fft=n_fft,
    n_mels=n_mels,
    win_length=win_length,
    hop_length=hop_length,
    subsets=["light", 'dev-clean', 'dev-other'],
)

test_dataset = LibriLight(
    n_fft=n_fft,
    n_mels=n_mels,
    win_length=win_length,
    hop_length=hop_length,
    subsets=['test-clean'],
)

# test_dataset = {}
# for subset in test_subset:
#     test_dataset[subset] = LibriLight(
#         n_fft=n_fft,
#         n_mels=n_mels,
#         win_length=win_length,
#         hop_length=hop_length,
#         subset=subset,
#     )


def collate_fn(batch):
    """
    Take feature and input, transform and then padding it
    """

    specs = [i[0] for i in batch]
    input_lengths = torch.IntTensor([i.size(0) for i in specs])
    trans = [i[1] for i in batch]
    
    bs = len(specs)

    # batch, time, feature
    specs = torch.nn.utils.rnn.pad_sequence(specs, batch_first=True)

    trans = [text_process.text2int(s) for s in trans]
    target_lengths = torch.IntTensor([s.size(0) for s in trans])
    trans = torch.nn.utils.rnn.pad_sequence(trans, batch_first=True).to(dtype=torch.int)

    return specs, input_lengths, trans, target_lengths


train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    collate_fn=collate_fn,
    shuffle=True,
    pin_memory=True,
    num_workers=2,
    drop_last=False,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    collate_fn=collate_fn,
    shuffle=False,
    pin_memory=True,
    num_workers=2,
)


# load pretrained big
wav2vec2_model = Wav2Vec2ConformerForPreTraining.from_pretrained(
    "facebook/wav2vec2-conformer-rel-pos-large"
)
wav2vec2_conformer = wav2vec2_model.wav2vec2_conformer.encoder
wav2vec2_conformer.layers = nn.ModuleList(
    [wav2vec2_conformer.layers[i] for i in range(num_hidden_layers)]
)


def count_params(model):
    if type(model) == nn.DataParallel:
        return model.module.count_params()
    return model.count_params()


vocab_size = text_process.n_class

conformer = ConformerModel(wav2vec2_conformer, input_dim=n_mels, vocab_size=vocab_size)
# if torch.cuda.device_count() > 1:
#     conformer = nn.DataParallel(conformer, device_ids=[0, 1])
conformer = conformer.to(device)
#print(conformer)
# print(
#     summary(conformer, [(300, n_mels), (1,)], dtypes=[torch.float, torch.long]) 
# )
ckpt = torch.load('/kaggle/input/nst-pretrained-model/teacher_2_hidden_libri_subsampling_swish_final.pt') # Need to Change This

conformer_state_dict = ckpt['conformer_state_dict']
conformer.load_state_dict(conformer_state_dict, strict=False)
print("Load Done!!!")
# conformer.load_state_dict(ckpt)
print("Model Successfully Loaded State Dictionary and Weights!!!!")
total_steps = len(train_dataloader) * max_epochs

print("Total steps:", total_steps)

criterion = nn.CTCLoss().to(device)
optimizer = optim.AdamW(conformer.parameters(), lr=lr, betas=(0.9, 0.9999))

scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=lr, pct_start=0.3, total_steps=total_steps
)
# scheduler = None

config = {
    "learning_rate": lr,
    "max_epochs": max_epochs,
    "batch_size": batch_size,
    "n_fft": n_fft,
    "n_mels": n_mels,
    "num_hidden_layers": num_hidden_layers,
    "dataset": "libri-light",
    "no_params": count_params(conformer),
    "augmentation": f"SpecAugment",
}
# wandb.config = config
early_stopping = EarlyStopping()

eval_loss, eval_wer = eval_epoch(
    conformer, test_dataloader, criterion, 0, 'val'
)

print("Eval wer:", eval_wer)


Device: cuda:0
upto before Model it is cleared--->!!!!
Model Defined Successfully!!!--->!!!!
Model Forwarding Started!!!--->!!!!


Some weights of the model checkpoint at facebook/wav2vec2-conformer-rel-pos-large were not used when initializing Wav2Vec2ConformerForPreTraining: ['wav2vec2_conformer.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2_conformer.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ConformerForPreTraining from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ConformerForPreTraining from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ConformerForPreTraining were not initialized from the model checkpoint at facebook/wav2vec2-conformer-rel-pos-large and are newly initialized: ['wav2vec2_conformer.encoder.pos_conv_embed.conv.parametrizations.weig

Load Done!!!
Model Successfully Loaded State Dictionary and Weights!!!!
Total steps: 26100


  self.pid = os.fork()


Eval wer: 6.769102445462855
