In [1]:
!pip install --upgrade pip
!pip install transformers
!pip install datasets>=1.18.3
!pip install librosa
!pip install jiwer
!pip install evaluate>=0.30
!pip install wandb


Collecting pip
  Downloading pip-24.0-py3-none-any.whl.metadata (3.6 kB)
Downloading pip-24.0-py3-none-any.whl (2.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 23.3.2
    Uninstalling pip-23.3.2:
      Successfully uninstalled pip-23.3.2
Successfully installed pip-24.0
Collecting jiwer
  Downloading jiwer-3.0.3-py3-none-any.whl.metadata (2.6 kB)
Downloading jiwer-3.0.3-py3-none-any.whl (21 kB)
Installing collected packages: jiwer
Successfully installed jiwer-3.0.3


In [2]:
!pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git@main
!pip install huggingface_hub
!pip install -U datasets huggingface-hub

Collecting datasets
  Downloading datasets-2.17.0-py3-none-any.whl.metadata (20 kB)
Collecting pyarrow>=12.0.0 (from datasets)
  Downloading pyarrow-15.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.0 kB)
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl.metadata (3.6 kB)
Collecting fsspec<=2023.10.0,>=2023.1.0 (from fsspec[http]<=2023.10.0,>=2023.1.0->datasets)
  Downloading fsspec-2023.10.0-py3-none-any.whl.metadata (6.8 kB)
Downloading datasets-2.17.0-py3-none-any.whl (536 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m536.6/536.6 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading fsspec-2023.10.0-py3-none-any.whl (166 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m166.4/166.4 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyarrow-15.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (38.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m38.

In [3]:
from huggingface_hub import notebook_login
notebook_login()


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [1]:

import os
import yaml
import logging
import random
import torch
import numpy as np


COLORS = {
    "yellow": "\x1b[33m",
    "blue": "\x1b[94m",
    "green": "\x1b[32m",
    "end": "\033[0m"
}


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def progress_bar(progress=0, status="", bar_len=20):
    status = status.ljust(30)
    if progress == 1:
        status = "{}".format(status.ljust(30))

    block = int(round(bar_len * progress))
    text = "\rProgress: [{}] {:.2f}% {}".format(
        COLORS['green'] + "="*(block-1) + ">" + COLORS['end'] + '-'*(bar_len-block), round(progress*100, 2), status
    )
    print(text, end="")

class AverageMeter:
    ''' Keeps track of metric statistics '''

    def __init__(self):
        self.reset()

    def reset(self):
        self.metrics = {}
        self.count = 0

    def add(self, metrics):
        if len(self.metrics) == 0:
            self.metrics = {key: value for key, value in metrics.items()}
            self.count += 1
        else:
            for key, value in metrics.items():
                if key in self.metrics.keys():
                    self.metrics[key] = (self.metrics[key] * self.count + value)/(self.count + 1)
                else:
                    raise KeyError(f'Metric key "{key}" not found')
            self.count += 1

    def return_metrics(self):
        metrics = {key: value for key, value in self.metrics.items()}
        return metrics

    def return_msg(self):
        metrics = self.return_metrics()
        msg = "".join(["[{}] {:.4f} ".format(key, value) for key, value in metrics.items()])
        return msg


class Logger:
    ''' For logging and sending messages to terminal '''

    def __init__(self, output_dir):
        # Reset logger and setup output file
        [logging.root.removeHandler(handler) for handler in logging.root.handlers[:]]
        logging.basicConfig(
            level = logging.INFO,
            format = "%(message)s",
            handlers = [logging.FileHandler(os.path.join(output_dir, "trainlogs.txt"))]
        )

    def show(self, msg, mode=""):
        if mode == 'info':
            print(f"{COLORS['yellow']}[INFO] {msg}{COLORS['end']}")
        elif mode == 'train':
            print(f"\n[TRAIN] {msg}")
        elif mode == 'val':
            print(f"\n{COLORS['blue']}[VALID] {msg}{COLORS['end']}")
        elif mode == 'test':
            print(f"\n{COLORS['green']}[TEST] {msg}{COLORS['end']}")
        else:
            print(f"{msg}")

    def write(self, msg, mode=''):
        if mode == "info":
            msg = f"[INFO] {msg}"
        elif mode == "train":
            msg = f"[TRAIN] {msg}"
        elif mode == "val":
            msg = f"[VALID] {msg}"
        elif mode == "test":
            msg = f"[TEST] {msg}"
        else:
            msg = f"{msg}"
        logging.info(msg)

    def record(self, msg, mode=''):
        self.show(msg, mode)
        self.write(msg, mode)


def open_config(config_dict):
    ''' Opens a configuration file '''

    # config = yaml.safe_load(open(file, 'r'))
    return config_dict


def init_experiment(args, seed=420):
    ''' Instantiates output file, loggers and random seeds '''

    # Set seeds
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Some other stuff
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # open config
    config = open_config(args["config"])

    # Setup logging directory
    output_dir = os.path.join("./outputs", args["dataset"], args["task"], args["output"])
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    logger = Logger(output_dir)

    logger.show("Logging at {}".format(output_dir), mode="info")
    logger.show("-" * 50)
    logger.show("{:>25}".format("Configuration"))
    logger.show("-" * 50)
    logger.show(yaml.dump(config))
    logger.show("-" * 50)

    # write hyper params to seperate file
    with open(os.path.join(output_dir, "hyperparameters.txt"), "w") as logs:
        logs.write(yaml.dump(config))

    # setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        logger.show(f"Found device {torch.cuda.get_device_name(0)}", mode="info")

    return config, output_dir, logger, device


def print_network(model, name=""):
    """
    Pretty prints the model.
    """
    print(name.rjust(35))
    print("-" * 70)
    print("{:>25} {:>27} {:>15}".format("Layer.Parameter", "Shape", "Param"))
    print("-" * 70)

    for param in model.state_dict():
        p_name = param.split(".")[-2] + "." + param.split(".")[-1]
        if p_name[:2] != "BN" and p_name[:2] != "bn":  # Not printing batch norm layers
            print(
                "{:>25} {:>27} {:>15}".format(
                    p_name,
                    str(list(model.state_dict()[param].squeeze().size())),
                    "{0:,}".format(np.product(list(model.state_dict()[param].size()))),
                )
            )
    print("-" * 70 + "\n")

In [2]:

import re
import json
import torch
import string
import librosa
import datasets
import soundfile
import transformers
import numpy as np


def remove_punctuation_and_lower(texts):
    punctuation = re.sub(r"\'", r"", string.punctuation)
    for i in range(len(texts)):
        texts[i] = texts[i].translate(str.maketrans("", "", punctuation)).upper()
    return texts

def create_vocabulary_file(texts):
    vocab_list = list(set(" ".join(texts)))
    vocab_dict = {v: k for k, v in enumerate(vocab_list)}
    vocab_dict["|"] = vocab_dict.pop(" ")
    vocab_dict["[UNK]"] = len(vocab_dict)
    vocab_dict["[PAD]"] = len(vocab_dict)
    with open("./med_asr_vocab.json", "w") as f:
        json.dump(vocab_dict, f)

def process_med_asr_dataset(read_limit=2500):
    med_asr_test = datasets.load_dataset("yashtiwari/PaulMooney-Medical-ASR-Data", split = "train[:10%]")
    med_asr_train = datasets.load_dataset("yashtiwari/PaulMooney-Medical-ASR-Data", split = "train[:60%]")
    med_asr_test = med_asr_test.remove_columns(["prompt", "speaker_id", "id"])
    med_asr_train = med_asr_train.remove_columns(["prompt", "speaker_id", "id"])
    train_files, train_text = med_asr_train["path"][:read_limit], med_asr_train["sentence"][:read_limit]
    test_files, test_text = med_asr_test["path"][:read_limit], med_asr_test["sentence"][:read_limit]
    train_text = remove_punctuation_and_lower(train_text)
    test_text = remove_punctuation_and_lower(test_text)
    create_vocabulary_file(train_text + test_text)
    return {"file": train_files, "text": train_text}, {"file": test_files, "text": test_text}


class MedASRDataloader:

    def __init__(self, data, batch_size):
        self.processor = transformers.Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        self.files, self.text = data["file"], data["text"]
        self.batch_size = batch_size
        self.ptr = 0

    def __len__(self):
        return len(self.files) // self.batch_size

    def flow(self):
        speech, text = [], []
        for _ in range(self.batch_size):
            signal, sr = np.array(self.files[self.ptr]["array"]), self.files[self.ptr]["sampling_rate"]
            speech.append(signal)
            text.append(self.text[self.ptr])
            self.ptr += 1
            if self.ptr >= len(self.files):
                self.ptr = 0

        inputs = self.processor(speech, sampling_rate=16000, padding=True, return_attention_mask=True, return_tensors="pt")
        input_data, input_attention = inputs["input_values"], inputs["attention_mask"]
        with self.processor.as_target_processor():
            labels = self.processor(text, padding=True, return_tensors="pt")
            targets, attention_mask = labels["input_ids"], labels["attention_mask"]
            targets = targets.masked_fill(attention_mask.ne(1), -100)
        return input_data, input_attention, targets

    def generate_from_file(self, file_path):
        signal, sr = soundfile.read(file_path, dtype="float32")
        signal = librosa.resample(np.mean(signal, axis=1), orig_sr=sr, target_sr=16000)
        inputs = self.processor(signal, sampling_rate=16000, return_attention_mask=True, return_tensors="pt")
        input_data, input_attention = inputs["input_values"], inputs["attention_mask"]
        return input_data, input_attention


def get_dataloaders(batch_size, read_limit=2500):
    train_data, test_data = process_med_asr_dataset(read_limit=read_limit)
    train_loader = MedASRDataloader(train_data, batch_size)
    test_loader = MedASRDataloader(test_data, batch_size)
    return train_loader, test_loader



if __name__ == "__main__":

    train_data, test_data = process_med_asr_dataset()
    train_loader = MedASRDataloader(train_data, batch_size=4)

    inputs, input_attention, targets = train_loader.flow()
    print(inputs)
    print(input_attention)
    print(targets)

2024-02-18 18:25:35.463086: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-18 18:25:35.463218: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-18 18:25:35.633434: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


tensor([[ 0.1572,  0.1572,  0.1587,  ...,  0.0000,  0.0000,  0.0000],
        [-0.7458, -0.7419, -0.7723,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.0691,  1.0712,  1.0937,  ...,  0.0000,  0.0000,  0.0000],
        [-0.2316, -0.2057, -0.1885,  ...,  0.3391,  0.3897,  0.4607]])
tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1]], dtype=torch.int32)
tensor([[  10,    4,   11,    7,   25,    5,    4,   12,   11,    8,   16,   15,
           14,    5,   13,    4,   23,    7,   10,    9,    4,   18,   11,    5,
            9,    4,   10,    4,    6,   13,   22,    4,    6,    8,    4,   19,
            7,   13,   13,   22,    4,   17,   22,    4,   21,   13,    8,   19,
            5,   13,   10,    5,   12, -100, -100, -100, -100, -100],
        [  10,    4,   11,    7,   25,    5,    4,   23,    7,   10,    9,    4,
           15,   10,   26,    5,    4,    9,    5,    5,   14,   15,    5,   12,
       



In [3]:

import os
import torch
import wandb
import datasets
import numpy as np
import transformers
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


class SpeechRecognitionModel(nn.Module):

    def __init__(self, processor):
        super().__init__()
        self.model = transformers.Wav2Vec2ForCTC.from_pretrained(
            "facebook/wav2vec2-base-960h",
            gradient_checkpointing = True,
            ctc_loss_reduction = "mean",
            pad_token_id = processor.tokenizer.pad_token_id
        )
        self.model.freeze_feature_extractor()

    def forward(self, inputs, input_attention, targets):
        output = self.model(input_values=inputs, attention_mask=input_attention, labels=targets)
        return output.loss, output.logits


class Trainer:

    def __init__(self, args):
        self.args = args
        self.config, self.output_dir, self.logger, _ = init_experiment(args)
        self.device = torch.device(self.args["device"]) if args["device"] in ["cpu", "cuda"] else torch.device("cpu")
        self.train_loader, self.val_loader = get_dataloaders(
            batch_size = self.config["data"]["batch_size"], read_limit = self.config["data"]["read_limit"])

        self.model = SpeechRecognitionModel(processor=self.train_loader.processor).to(self.device)
        if self.args["task"] == "train":
            if type(self.device) == "cpu":
                self.logger.show("\nTraining model on CPU! I sure hope you know what you are doing", mode='info')
            self.optim = optim.SGD(self.model.parameters(), lr=self.config["model"]["optim_lr"], weight_decay=0.005, momentum=0.9)
            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optim, T_max=self.config["epochs"]-self.config["warmup_epochs"], eta_min=0.0, last_epoch=-1)
            self.warmup_epochs = self.config.get("warmup_epochs", 0)

            if self.warmup_epochs > 0:
                self.warmup_rate = (self.config["model"]["optim_lr"] - 1e-12) / self.warmup_epochs

            self.done_epochs = 1
            self.metric_best = np.inf
            run = wandb.init(project="assignment-asr-prototype")
            self.logger.write(f"Wandb run: {run.get_url()}", mode='info')
        else:
            if args["load"] is not None:
                self.load_model(args["load"])

    def compute_word_error_rate(self, loader):
        wer_values, preds, trgs = [], [], []
        metric = datasets.load_metric("wer")
        for idx in range(len(loader)):
            inputs, input_mask, targets = loader.flow()
            inputs, input_mask, targets = inputs.to(self.device), input_mask.to(self.device), targets.to(self.device)
            with torch.no_grad():
                loss, logits = self.model(inputs, input_mask, targets)

            predictions = F.softmax(logits, dim=-1).argmax(dim=-1).detach().cpu().numpy()
            targets = targets.detach().cpu().numpy()
            targets[targets == -100] = self.train_loader.processor.tokenizer.pad_token_id
            pred_str = self.train_loader.processor.batch_decode(predictions)
            target_str = self.train_loader.processor.batch_decode(targets, group_tokens=False)
            wer_values.append(metric.compute(predictions=pred_str, references=target_str))
            preds.extend(pred_str), trgs.extend(target_str)
            progress_bar(status="", progress=(idx+1)/len(loader))

        progress_bar(status="[WER] {:.4f}".format(np.mean(wer_values)), progress=1.0)
        return np.mean(wer_values), preds, trgs

    def train_on_batch(self, batch):
        self.model.train()
        inputs, input_mask, targets = batch
        inputs, input_mask, targets = inputs.to(self.device), input_mask.to(self.device), targets.to(self.device)
        loss, logits = self.model(inputs, input_mask, targets)
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        return {"CTC loss": loss.item()}

    def infer_on_batch(self, batch):
        self.model.eval()
        inputs, input_mask, targets = batch
        inputs, input_mask, targets = inputs.to(self.device), input_mask.to(self.device), targets.to(self.device)
        with torch.no_grad():
            loss, logits = self.model(inputs, input_mask, targets)
        return {"CTC loss": loss.item()}

    def save_model(self, epoch, metric):
        state = {
            "model": self.model.state_dict(),
            "optim": self.optim.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            "metric": metric,
            "epoch": epoch
        }
        torch.save(state, os.path.join(self.output_dir, "best_model.pt"))

    def load_model(self, path):
        if not os.path.exists(os.path.join(self.args["load"], "best_model.pt")):
            raise NotImplementedError(f"Could not find saved model 'best_model.pt' at {self.args['load']}")
        else:
            state = torch.load(os.path.join(self.args["load"], "best_model.pt"), map_location=self.device)
            self.model.load_state_dict(state["model"])
            self.logger.show(f"Successfully loaded model from {path}", mode='info')

    def adjust_learning_rate(self, epoch):
        if epoch < self.warmup_epochs:
            for group in self.optim.param_groups:
                group["lr"] = 1e-12 + epoch * self.warmup_rate
        else:
            self.scheduler.step()

    def get_test_performance(self):
        test_meter = AverageMeter()
        for idx in range(len(self.val_loader)):
            batch = self.val_loader.flow()
            test_metrics = self.infer_on_batch(batch)
            test_meter.add(test_metrics)
            progress_bar(status=test_meter.return_msg(), progress=(idx+1)/len(self.val_loader))

        progress_bar(status=test_meter.return_msg(), progress=1.0)
        self.logger.record("Computing WER", mode='test')
        test_wer, preds, trgs = self.compute_word_error_rate(self.val_loader)
        self.logger.record(test_meter.return_msg() + " [WER] {:.4f}".format(test_wer), mode="test")
        print("\n\nSample predictions")
        print("============================================================")
        for i in np.random.choice(np.arange(len(preds)), size=10, replace=False):
            print("Target     : {}".format(trgs[i]))
            print("Prediction : {}".format(preds[i]))
            print("--------------------------------------------------------")

    def predict_for_file(self, file_path):
        inputs, input_mask = self.train_loader.generate_from_file(file_path)
        inputs, input_mask = inputs.to(self.device), input_mask.to(self.device)
        with torch.no_grad():
            logits = self.model.model(inputs, attention_mask=input_mask).logits
            predictions = F.softmax(logits, dim=-1).argmax(dim=-1).detach().cpu().numpy()
        pred_str = self.train_loader.processor.batch_decode(predictions)
        print("\nPrediction: {}".format(pred_str))
        return pred_str

    def train(self):
        print()
        for epoch in range(max(1, self.done_epochs), self.config["epochs"]+1):
            self.logger.record(f"Epoch {epoch}/{self.config['epochs']}", mode="train")
            train_meter = AverageMeter()

            for idx in range(len(self.train_loader)):
                batch = self.train_loader.flow()
                train_metrics = self.train_on_batch(batch)
                train_meter.add(train_metrics)
                wandb.log({"Train CTC loss": train_metrics["CTC loss"]})
                progress_bar(status=train_meter.return_msg(), progress=(idx+1)/len(self.train_loader))

            progress_bar(status=train_meter.return_msg(), progress=1.0)
            self.logger.record(f"Epoch {epoch}/{self.config['epochs']} Computing WER", mode='train')
            train_wer, _, _ = self.compute_word_error_rate(self.train_loader)
            wandb.log({"Train WER": train_wer, "Epoch": epoch})
            self.logger.write(train_meter.return_msg() + f" [WER] {round(train_wer, 4)}", mode="train")
            self.adjust_learning_rate(epoch)

            if epoch % self.config["eval_every"] == 0:
                self.logger.record(f"Epoch {epoch}/{self.config['epochs']}", mode='val')
                val_meter = AverageMeter()

                for idx in range(len(self.val_loader)):
                    batch = self.val_loader.flow()
                    val_metrics = self.infer_on_batch(batch)
                    val_meter.add(val_metrics)
                    progress_bar(status=val_meter.return_msg(), progress=(idx+1)/len(self.val_loader))

                progress_bar(status=val_meter.return_msg(), progress=1.0)
                self.logger.record(f"Epoch {epoch}/{self.config['epochs']} Computing WER", mode='val')
                val_wer, _, _ = self.compute_word_error_rate(self.val_loader)
                wandb.log({"Val CTC loss": val_meter.return_metrics()["CTC loss"], "Val WER": val_wer, "Epoch": epoch})
                self.logger.write(val_meter.return_msg() + f" [WER] {round(val_wer, 4)}", mode='val')

                if val_wer < self.metric_best:
                    self.metric_best = val_wer
                    self.save_model(epoch, val_wer)

        print()
        self.logger.record("Training complete! Generating test predictions...", mode='info')


In [4]:

import os
import time
import argparse
from datetime import datetime as dt


if __name__ == "__main__":

    config = {
        "epochs": 10,
        "warmup_epochs": 2,
        "eval_every": 1,
        "data": {
            "batch_size": 8,
          "read_limit": 5000
        },
        "model":{
          "optim_lr": 0.0001
        }
    }

    # ap = argparse.ArgumentParser()
    # ap.add_argument("-c", "--config", required=True, help="Path to configuration file")
    # ap.add_argument("-d", "--device", default="cpu", type=str, help="Whether to perform task on CPU ('cpu') or GPU ('cuda')")
    # ap.add_argument("-a", "--dataset", default="med_asr", type=str, help="Name of dataset finetuned on")
    # ap.add_argument("-t", "--task", default="train", type=str, help="Task to perform. Choose between ['train', 'test']")
    # ap.add_argument("-o", "--output", default=dt.now().strftime("%d-%m-%Y-%H-%M"), type=str, help="Output directory path")
    # ap.add_argument("-l", "--load", default=None, type=str, help="Path to directory containing checkpoint as best_model.pt")
    # ap.add_argument("-f", "--file", default="test.wav", type=str, help="Path to single testing file")
    # args = vars(ap.parse_args())
    args = {"config": config, "device": "cuda", "dataset": "med_asr", "test": "train", "output": dt.now().strftime("%d-%m-%Y-%H-%M"), "load": None, "file": "test.wav"}
    trainer = Trainer(args)

    if args["task"] == "train":
        trainer.train()

    elif args["task"] == "test":
        assert args["load"] is not None, "Please provide a checkpoint to load using --load to check test performance"
        trainer.get_test_performance()

    elif args["task"] == "single_test":
        assert os.path.exists(args["file"]), "No wav file found at path provided"
        assert args["load"] is not None, "Please provide a checkpoint to load using --load to check test performance"
        trainer.predict_for_file(args["file"])

    else:
        raise ValueError(f"Unrecognized argument passed to --task: {args['task']}")

[33m[INFO] Logging at ./outputs/med_asr/train/18-02-2024-18-25[0m
--------------------------------------------------
            Configuration
--------------------------------------------------
data:
  batch_size: 8
  read_limit: 5000
epochs: 10
eval_every: 1
model:
  optim_lr: 0.0001
warmup_epochs: 2

--------------------------------------------------
[33m[INFO] Found device Tesla T4[0m


model.safetensors:   0%|          | 0.00/378M [00:00<?, ?B/s]

Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You sho

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc




[TRAIN] Epoch 1/10




[TRAIN] Epoch 1/10 Computing WER


  metric = datasets.load_metric("wer")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


Downloading builder script:   0%|          | 0.00/1.90k [00:00<?, ?B/s]



[94m[VALID] Epoch 1/10[0m
[94m[VALID] Epoch 1/10 Computing WER[0m


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


[TRAIN] Epoch 2/10




[TRAIN] Epoch 2/10 Computing WER


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


[94m[VALID] Epoch 2/10[0m
[94m[VALID] Epoch 2/10 Computing WER[0m


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


[TRAIN] Epoch 3/10




[TRAIN] Epoch 3/10 Computing WER


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


[94m[VALID] Epoch 3/10[0m
[94m[VALID] Epoch 3/10 Computing WER[0m


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


[TRAIN] Epoch 4/10




[TRAIN] Epoch 4/10 Computing WER


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


[94m[VALID] Epoch 4/10[0m
[94m[VALID] Epoch 4/10 Computing WER[0m


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


[TRAIN] Epoch 5/10




[TRAIN] Epoch 5/10 Computing WER


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


[94m[VALID] Epoch 5/10[0m
[94m[VALID] Epoch 5/10 Computing WER[0m


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


[TRAIN] Epoch 6/10




[TRAIN] Epoch 6/10 Computing WER


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


[94m[VALID] Epoch 6/10[0m
[94m[VALID] Epoch 6/10 Computing WER[0m


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


[TRAIN] Epoch 7/10




[TRAIN] Epoch 7/10 Computing WER


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


[94m[VALID] Epoch 7/10[0m
[94m[VALID] Epoch 7/10 Computing WER[0m


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


[TRAIN] Epoch 8/10




[TRAIN] Epoch 8/10 Computing WER


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


[94m[VALID] Epoch 8/10[0m
[94m[VALID] Epoch 8/10 Computing WER[0m


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


[TRAIN] Epoch 9/10




[TRAIN] Epoch 9/10 Computing WER


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


[94m[VALID] Epoch 9/10[0m
[94m[VALID] Epoch 9/10 Computing WER[0m


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


[TRAIN] Epoch 10/10




[TRAIN] Epoch 10/10 Computing WER


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


[94m[VALID] Epoch 10/10[0m
[94m[VALID] Epoch 10/10 Computing WER[0m


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


[33m[INFO] Training complete! Generating test predictions...[0m


In [8]:
args = {"config": config, "device": "cuda", "dataset": "med_asr", "task": "test", "output": dt.now().strftime("%d-%m-%Y-%H-%M"), "load": "/kaggle/working/outputs/med_asr/train/18-02-2024-18-25/", "file": "test.wav"}
trainer = Trainer(args)

if args["task"] == "train":
    trainer.train()

elif args["task"] == "test":
    assert args["load"] is not None, "Please provide a checkpoint to load using --load to check test performance"
    trainer.get_test_performance()

elif args["task"] == "single_test":
    assert os.path.exists(args["file"]), "No wav file found at path provided"
    assert args["load"] is not None, "Please provide a checkpoint to load using --load to check test performance"
    trainer.predict_for_file(args["file"])

else:
    raise ValueError(f"Unrecognized argument passed to --task: {args['task']}")

[33m[INFO] Logging at ./outputs/med_asr/test/18-02-2024-19-09[0m
--------------------------------------------------
            Configuration
--------------------------------------------------
data:
  batch_size: 8
  read_limit: 5000
epochs: 10
eval_every: 1
model:
  optim_lr: 0.0001
warmup_epochs: 2

--------------------------------------------------
[33m[INFO] Found device Tesla T4[0m


Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You sho

[33m[INFO] Successfully loaded model from /kaggle/working/outputs/med_asr/train/18-02-2024-18-25/[0m
[32m[TEST] Computing WER[0m


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


[32m[TEST] [CTC loss] 2.7950  [WER] 0.9948[0m


Sample predictions
Target     : I AM ALWAYS COLD EVEN WHEN I AM WEARING LAYERS
Prediction : 
--------------------------------------------------------
Target     : IT IS LIKE I HAVE A NEEDLE PUSHING THROUGH MY HEART
Prediction : 
--------------------------------------------------------
Target     : THERE IS SO MUCH PAIN WHEN I MOVE MY ARM
Prediction : ATEES
--------------------------------------------------------
Target     : I HAVE A GREAT STOMACH ACHE AND I CAN'T EAT ANY THING
Prediction : G
--------------------------------------------------------
Target     : IN THE MORNING MY RESPIRATION IS LOUD
Prediction : 
--------------------------------------------------------
Target     : I HAVE A SHARP PAIN IN MY ABDOMEN
Prediction : 
--------------------------------------------------------
Target     : I CAN'T REALLY JUMP ON MY LEFT FOOT BECAUSE MY TRIPLE FRACTURE OF THE ANKLE LEFT ME WITH NEVERENDING PAINS
Prediction : AAAAAAEAAAAAAA
-------