In [3]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader, random_split
from rich.progress import (
    Progress,
    TextColumn,
    BarColumn,
    TimeRemainingColumn,
    MofNCompleteColumn,
)
from transformers import Wav2Vec2Model
import torch.nn.functional as F
import wandb
from torchmetrics import MeanSquaredError

# Seed PyTorch's global RNG once, before any model initialisation or data shuffling below.
torch.manual_seed(42)

<torch._C.Generator at 0x2854cb56f70>

In [4]:
# from kaggle_secrets import UserSecretsClient
# user_secrets = UserSecretsClient()
# secret_value_0 = user_secrets.get_secret("WANDB_API_KEY")
# wandb.login(key = secret_value_0)

In [5]:
class AudioGrammarDataset(Dataset):
    """Audio clips (and, for training, grammar-score labels) listed in a CSV manifest.

    Each item is a mono waveform resampled to ``target_sample_rate`` and then
    truncated or right-padded with zeros to exactly ``max_length`` samples,
    together with an attention mask that is 1 over real audio and 0 over padding.

    Args:
        data_dir: Directory containing the audio files.
        metadata_path: CSV with a ``filename`` column, plus a ``label`` column
            unless ``is_test`` is True.
        is_test: When True, labels are neither read nor returned.
        max_length: Fixed output length in samples. The default 1,000,000
            samples is 62.5 s at 16 kHz.
        target_sample_rate: Sample rate (Hz) every clip is resampled to.
    """

    def __init__(
        self,
        data_dir,
        metadata_path,
        is_test=False,
        max_length=1000000,
        target_sample_rate=16000,
    ):
        self.df = pd.read_csv(metadata_path)
        self.audio_files = [
            os.path.join(data_dir, file) for file in self.df["filename"]
        ]
        self.is_test = is_test

        if not self.is_test:
            self.labels = self.df["label"]

        self.max_length = max_length
        self.target_sample_rate = target_sample_rate
        # Resample transforms keyed by source rate, so one is not rebuilt per item.
        self._resamplers = {}

    def __len__(self):
        """Number of clips listed in the manifest."""
        return len(self.audio_files)

    def _load_mono(self, audio_path):
        """Load ``audio_path`` as a 1-D mono waveform at ``self.target_sample_rate``."""
        waveform, sample_rate = torchaudio.load(audio_path)
        # Downmix all channels to mono: [C, T] -> [1, T].
        waveform = torch.mean(waveform, dim=0, keepdim=True)

        if sample_rate != self.target_sample_rate:
            resampler = self._resamplers.get(sample_rate)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(
                    sample_rate, self.target_sample_rate
                )
                self._resamplers[sample_rate] = resampler
            waveform = resampler(waveform)

        return waveform.squeeze(0)

    def _pad_or_truncate(self, waveform):
        """Return ``(waveform, attention_mask)``, both exactly ``self.max_length`` long.

        Longer inputs are cut at ``max_length``; shorter ones are right-padded
        with zeros. The mask is a long tensor with 1 on real samples and 0 on
        padding.
        """
        num_real = min(waveform.shape[0], self.max_length)
        waveform = waveform[: self.max_length]

        attention_mask = torch.zeros(self.max_length, dtype=torch.long)
        attention_mask[:num_real] = 1

        if num_real < self.max_length:
            padding = waveform.new_zeros(self.max_length - num_real)
            waveform = torch.cat([waveform, padding])
        return waveform, attention_mask

    def __getitem__(self, idx):
        """Return ``raw_waveform`` and ``attention_mask`` (plus ``label`` as a [1] float tensor when training)."""
        waveform = self._load_mono(self.audio_files[idx])
        waveform, attention_mask = self._pad_or_truncate(waveform)

        item = {"raw_waveform": waveform, "attention_mask": attention_mask}
        if not self.is_test:
            # Positional lookup, so a non-default DataFrame index cannot misalign labels.
            label = float(self.labels.iloc[idx])
            item["label"] = torch.tensor([label], dtype=torch.float32)
        return item

In [6]:
class Wav2Vec2GrammarScoring(nn.Module):
    """wav2vec 2.0 encoder plus an MLP head that regresses a score in ``[0, max_score]``.

    The encoder's last hidden states are averaged over time and passed through
    a small MLP ending in a sigmoid, which is scaled by ``max_score``. When an
    ``attention_mask`` is supplied, only the frames that correspond to real
    (non-padded) audio contribute to the average.

    Args:
        pretrained_model_name: Checkpoint passed to ``Wav2Vec2Model.from_pretrained``.
        max_score: Upper end of the output range (default 5).
    """

    def __init__(self, pretrained_model_name="facebook/wav2vec2-base-960h", max_score=5.0):
        super(Wav2Vec2GrammarScoring, self).__init__()

        self.wav2vec2 = Wav2Vec2Model.from_pretrained(pretrained_model_name)
        self.max_score = max_score

        hidden_size = self.wav2vec2.config.hidden_size

        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _masked_mean(hidden_states, frame_mask):
        """Average ``hidden_states`` [B, T, H] over T, counting only frames where ``frame_mask`` [B, T] is set.

        A row whose mask is entirely zero yields a zero vector instead of NaN.
        """
        weights = frame_mask.to(hidden_states.dtype).unsqueeze(-1)  # [B, T, 1]
        summed = (hidden_states * weights).sum(dim=1)  # [B, H]
        counts = weights.sum(dim=1).clamp(min=1.0)  # [B, 1]; avoid division by zero
        return summed / counts

    def forward(self, input_values, attention_mask=None):
        """Score a batch of raw waveforms.

        Args:
            input_values: Raw audio; presumably ``[B, num_samples]`` at 16 kHz — TODO confirm for other callers.
            attention_mask: Optional ``[B, num_samples]`` mask, non-zero on real
                samples and zero on padding. NOTE(review): the Transformers docs
                advise against passing a mask to wav2vec2-base checkpoints; confirm
                this is intended.

        Returns:
            ``[B, 1]`` tensor of scores in ``[0, max_score]``.
        """
        outputs = self.wav2vec2(input_values, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state  # [B, T_frames, H]

        if attention_mask is None:
            pooled_output = hidden_states.mean(dim=1)
        else:
            # Map the sample-level mask to the encoder's frame rate so padding
            # frames do not dilute the average (same approach as HF's
            # Wav2Vec2ForSequenceClassification).
            frame_mask = self.wav2vec2._get_feature_vector_attention_mask(
                hidden_states.shape[1], attention_mask
            )
            pooled_output = self._masked_mean(hidden_states, frame_mask)

        return self.max_score * self.classifier(pooled_output)

    def freeze_feature_encoder(self):
        """Stop gradient updates to the convolutional feature extractor."""
        for param in self.wav2vec2.feature_extractor.parameters():
            param.requires_grad = False

    def unfreeze_feature_encoder(self):
        """Re-enable gradient updates to the convolutional feature extractor."""
        for param in self.wav2vec2.feature_extractor.parameters():
            param.requires_grad = True

    def freeze_base_model(self):
        """Freeze every parameter of the wav2vec2 backbone (the classifier head stays trainable)."""
        for param in self.wav2vec2.parameters():
            param.requires_grad = False

    def unfreeze_transformer_layers(self, num_layers=4):
        """Freeze the backbone, then unfreeze only its last ``num_layers`` transformer layers.

        ``num_layers`` is clamped to ``[0, number of layers]``.
        """
        self.freeze_base_model()
        layers = self.wav2vec2.encoder.layers
        num_layers = max(0, min(num_layers, len(layers)))
        for layer in layers[len(layers) - num_layers:]:
            for param in layer.parameters():
                param.requires_grad = True

In [None]:
class Trainer:
    """Fine-tunes a grammar-scoring model with MSE loss and logs to Weights & Biases.

    Required ``config`` keys: ``device``, ``batch_size``, ``num_epochs``,
    ``learning_rate``, ``weight_decay``, ``val_size`` and, when no ``model`` is
    given, ``pretrained_model_name``. Optional keys (defaults in brackets):
    ``random_state`` [42] seeds the train/val split; ``train_audio_dir`` and
    ``train_metadata_path`` [the Kaggle dataset paths below]; ``eval_every``
    [100 optimizer steps]; ``lr_gamma`` [0.95, applied once per epoch];
    ``model_path`` ["best_model.pt"] for the best checkpoint.

    Args:
        config: Hyper-parameter / path dictionary described above.
        model: Optional pre-built model; it is moved to ``config["device"]``.
            When omitted, a ``Wav2Vec2GrammarScoring`` is built with its
            convolutional feature encoder frozen.
    """

    DEFAULT_TRAIN_AUDIO_DIR = "/kaggle/input/shl-intern-hiring-assessment/dataset/audios_train"
    DEFAULT_TRAIN_METADATA_PATH = "/kaggle/input/shl-intern-hiring-assessment/dataset/train.csv"

    def __init__(self, config, model=None):
        self.config = config
        self.device = torch.device(config["device"])
        self._prepare_dataloaders()

        self.criterion = torch.nn.MSELoss()

        # Training and validation keep separate RMSE accumulators: evaluate()
        # runs mid-epoch and must neither reset nor mix into the training RMSE.
        self.rmse_metric = MeanSquaredError(squared=False).to(self.device)
        self.val_rmse_metric = MeanSquaredError(squared=False).to(self.device)

        if model is not None:
            self.model = model.to(self.device)
        else:
            self._prepare_model()

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config["learning_rate"],
            weight_decay=self.config["weight_decay"],
        )

        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=1, gamma=self.config.get("lr_gamma", 0.95)
        )
        wandb.init(project="wav2vec2-grammar", config=config)

        # One entry per validation run (every `eval_every` optimizer steps).
        self.history = {
            "train_loss": [],
            "train_rmse": [],
            "val_loss": [],
            "val_rmse": [],
        }

        self.best_val_rmse = float("inf")

    def _prepare_dataloaders(self):
        """Build the training dataset and split it into shuffled train and ordered val loaders.

        Raises:
            ValueError: if ``val_size`` leaves the train or validation split empty.
        """
        dataset = AudioGrammarDataset(
            data_dir=self.config.get("train_audio_dir", self.DEFAULT_TRAIN_AUDIO_DIR),
            metadata_path=self.config.get(
                "train_metadata_path", self.DEFAULT_TRAIN_METADATA_PATH
            ),
        )

        train_size = int((1 - self.config["val_size"]) * len(dataset))
        val_size = len(dataset) - train_size
        if train_size == 0 or val_size == 0:
            raise ValueError(
                f"val_size={self.config['val_size']} leaves an empty train or "
                f"validation split for a dataset of {len(dataset)} clips"
            )

        # A dedicated generator makes the split depend only on `random_state`,
        # not on how much of the global RNG earlier cells consumed.
        split_generator = torch.Generator().manual_seed(
            self.config.get("random_state", 42)
        )
        train_data, val_data = random_split(
            dataset, [train_size, val_size], generator=split_generator
        )

        self.train_loader = DataLoader(
            train_data, batch_size=self.config["batch_size"], shuffle=True
        )
        self.val_loader = DataLoader(
            val_data, batch_size=self.config["batch_size"], shuffle=False
        )

        print("Train and Val dataloaders Prepared !!")

    def _prepare_model(self):
        """Create the default model on ``self.device`` with its feature encoder frozen."""
        self.model = Wav2Vec2GrammarScoring(
            self.config["pretrained_model_name"],
        ).to(self.device)
        self.model.freeze_feature_encoder()
        print("Model Initialized !!")

    def train(self):
        """Run the training loop and return ``(history, model)``.

        The reported train loss is the running mean over the current epoch.
        Every ``eval_every`` optimizer steps the model is validated; whenever
        validation RMSE improves, the weights are saved to ``model_path`` and
        uploaded to W&B. The LR scheduler steps once per epoch.
        """
        eval_every = self.config.get("eval_every", 100)
        model_path = self.config.get("model_path", "best_model.pt")
        step = 0  # global optimizer step, used for W&B and the eval cadence

        for epoch in range(self.config["num_epochs"]):
            self.model.train()
            self.rmse_metric.reset()
            running_loss = 0.0
            epoch_batches = 0

            with Progress(
                TextColumn("[bold blue]{task.description}"),
                BarColumn(),
                MofNCompleteColumn(),
                TimeRemainingColumn(),
                TextColumn("• Loss: {task.fields[loss]}", justify="right"),
                transient=True,
            ) as progress:
                train_task = progress.add_task(
                    f"Epoch {epoch + 1}/{self.config['num_epochs']} [bold white on blue]TRAIN[/]",
                    total=len(self.train_loader),
                    loss="0.0000"
                )
                for batch in self.train_loader:
                    waveforms = batch["raw_waveform"].to(self.device)
                    attention_mask = batch["attention_mask"].to(self.device)
                    labels = batch["label"].to(self.device).squeeze(1)  # [B, 1] -> [B]

                    self.optimizer.zero_grad()

                    predictions = self.model(waveforms, attention_mask).squeeze(1)  # [B, 1] -> [B]
                    loss = self.criterion(predictions, labels)

                    loss.backward()
                    self.optimizer.step()

                    step += 1
                    epoch_batches += 1
                    running_loss += loss.item()
                    train_loss = running_loss / epoch_batches
                    self.rmse_metric.update(predictions.detach(), labels)

                    progress.update(train_task, advance=1, loss=f"{train_loss:.4f}")
                    wandb.log({"train_step_loss": train_loss}, step=step)

                    if step % eval_every == 0:
                        self._validate_and_checkpoint(
                            epoch, step, train_loss, progress, model_path
                        )
                        # evaluate() switches to eval mode; resume training mode
                        # so dropout stays active for the rest of the epoch.
                        self.model.train()

            self.scheduler.step()

        wandb.finish()
        print(" Training Completed !!")
        return self.history, self.model

    def _validate_and_checkpoint(self, epoch, step, train_loss, progress, model_path):
        """Validate, record history, log to W&B, and save the model if val RMSE improved."""
        train_rmse = self.rmse_metric.compute().item()
        val_loss, val_rmse = self.evaluate(epoch, progress)

        self.history["train_loss"].append(train_loss)
        self.history["train_rmse"].append(train_rmse)
        self.history["val_loss"].append(val_loss)
        self.history["val_rmse"].append(val_rmse)

        wandb.log(
            {
                "epoch": epoch + 1,
                "train_rmse": train_rmse,
                "val_loss": val_loss,
                "val_rmse": val_rmse,
            },
            step=step,
        )

        if val_rmse < self.best_val_rmse:
            self.best_val_rmse = val_rmse
            torch.save(self.model.state_dict(), model_path)
            wandb.save(model_path)
            self._log_model_as_artifact(model_path)

        print(
            f"[Epoch {epoch + 1}] [Step {step}] Train Loss: {train_loss:.4f}, Train RMSE: {train_rmse:.4f} | Val Loss: {val_loss:.4f}, Val RMSE: {val_rmse:.4f}"
        )

    def evaluate(self, epoch, progress):
        """Compute ``(mean batch MSE, RMSE)`` on the validation loader.

        Leaves the model in eval mode. Uses its own RMSE accumulator, so the
        training metric is untouched.
        """
        self.model.eval()
        val_loss = 0.0
        self.val_rmse_metric.reset()

        val_task = progress.add_task(
            f"Epoch {epoch + 1}/{self.config['num_epochs']} [bold white on green]VAL[/]",
            total=len(self.val_loader),
            loss="0.0000",
        )
        with torch.no_grad():
            for batches_seen, batch in enumerate(self.val_loader, start=1):
                waveforms = batch["raw_waveform"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["label"].to(self.device).squeeze(1)

                predictions = self.model(waveforms, attention_mask).squeeze(1)
                loss = self.criterion(predictions, labels)
                val_loss += loss.item()
                self.val_rmse_metric.update(predictions, labels)

                progress.update(
                    val_task, advance=1, loss=f"{val_loss / batches_seen:.4f}"
                )
        # Drop the finished validation bar so repeated evaluations don't stack up.
        progress.remove_task(val_task)

        return val_loss / len(self.val_loader), self.val_rmse_metric.compute().item()

    def _log_model_as_artifact(self, model_path):
        """Upload ``model_path`` to W&B as a ``model`` artifact named ``best_model``."""
        artifact = wandb.Artifact("best_model", type="model")
        artifact.add_file(model_path)
        wandb.log_artifact(artifact)

In [None]:
# Training hyper-parameters consumed by `Trainer`.
config = dict(
    batch_size=2,
    num_epochs=8,
    learning_rate=3e-5,
    weight_decay=1e-5,
    val_size=0.2,  # fraction of the training manifest held out for validation
    random_state=42,
    device="cuda" if torch.cuda.is_available() else "cpu",
    pretrained_model_name="facebook/wav2vec2-base-960h",
)

In [None]:
# Builds the train/val dataloaders and the model, and starts a W&B run (see Trainer.__init__).
trainer = Trainer(config=config)

In [None]:
# Runs the full training loop.
# NOTE(review): this unpacking requires Trainer.train to return a (history, model) pair; verify the Trainer version in use does.
history, model = trainer.train()

In [None]:
class Inference:
    """Loads a trained ``Wav2Vec2GrammarScoring`` checkpoint and scores the test set.

    Expected ``config`` keys: ``pretrained_model_name``, ``model_path``,
    ``test_audio_dir``, ``test_metadata_path`` and ``batch_size``. The keys
    ``device`` (falls back to CPU when missing or empty) and ``sampling_rate``
    (default 16000) are optional.
    """

    def __init__(self, config):
        self.config = config
        self.device = torch.device(config.get("device") or "cpu")
        # Per-batch prediction arrays from the most recent predict() call.
        self.scores = []
        self.model = Wav2Vec2GrammarScoring(config["pretrained_model_name"]).to(
            self.device
        )

        # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
        # checkpoints (or pass weights_only=True where the torch version supports it).
        self.model.load_state_dict(
            torch.load(config["model_path"], map_location=self.device)
        )
        self.model.eval()

        # NOTE(review): stored but not passed to AudioGrammarDataset below.
        self.sampling_rate = config.get("sampling_rate", 16000)
        print("Inference model loaded and ready!")

        self.preprocess()

    def preprocess(self):
        """Build ``self.test_loader`` over the test manifest in ``self.config``."""
        test_dataset = AudioGrammarDataset(
            data_dir=self.config["test_audio_dir"],
            metadata_path=self.config["test_metadata_path"],
            is_test=True,
        )

        # No shuffling: predictions must line up with the rows of the test CSV.
        self.test_loader = DataLoader(
            test_dataset, batch_size=self.config["batch_size"], shuffle=False
        )

        print("Test_loader initialized !!")

    def predict(self):
        """Score every test clip in loader order.

        Returns:
            1-D numpy array with one score per clip, clipped to ``[0, 5]``;
            empty when the test set is empty. Repeated calls return the same
            result rather than accumulating earlier predictions.
        """
        self.scores = []
        with torch.no_grad():
            for batch in self.test_loader:
                waveform = batch["raw_waveform"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)

                output = self.model(waveform, attention_mask)  # [B, 1]
                batch_scores = output.squeeze(1).cpu().numpy()
                self.scores.append(np.clip(batch_scores, 0.0, 5.0))

        if not self.scores:
            return np.empty(0, dtype=np.float32)
        return np.concatenate(self.scores, axis=0)


In [34]:
# Inference settings consumed by `Inference`.
config = {
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    # Same backbone the checkpoint was trained from (training config above), so
    # the architecture built here matches the saved state_dict.
    "pretrained_model_name": "facebook/wav2vec2-base-960h",
    "sampling_rate": 16000,
    "model_path": "best_model.pt",
    "test_metadata_path": "./dataset/test.csv",
    "batch_size": 8,
    "test_audio_dir": "./dataset/audios_test",
}

In [35]:
# Loads the checkpoint at config["model_path"] and builds the test dataloader.
inference_engine = Inference(config=config)



Inference model loaded and ready!
Test_loader initialized !!


In [None]:
# One predicted score per test clip, clipped to [0, 5].
scores = inference_engine.predict()

In [None]:
# Sanity check: should equal the number of rows in the test CSV.
len(scores)

3

In [None]:
pd.to