In [None]:
import sys, os, random, math, time
import math
import logging, warnings, csv, base64, gzip
import json, datetime, numpy as np, torch, torch.nn as nn
import torch.optim as optim, torch.nn.functional as F, torchaudio
import torchaudio.transforms as transforms, torch.utils.checkpoint as checkpoint
import torch.utils.tensorboard as tensorboard, torch.optim.lr_scheduler as lr_scheduler
import transformers, neologdn, evaluate, MeCab, deepl, logging, datasets, tqdm, whisper
import transformers.utils.logging
from datasets import load_from_disk, load_dataset
from contextlib import contextmanager
from dataclasses import dataclass
from torch.utils.data import Subset, IterableDataset
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from tqdm import tqdm
from torch.profiler import profile, ProfilerActivity, record_function
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from torch import amp, Tensor
from torch.optim import Adamax
import logging
from safetensors import safe_open
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.trainer_pt_utils import IterableDatasetShard
from transformers.trainer_utils import is_main_process
from transformers.trainer_pt_utils import find_batch_size, get_parameter_names
from transformers import (
    TrainerState,
    TrainerControl,
    logging,
    Trainer,
    TrainingArguments,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    PretrainedConfig,
    GenerationConfig,
    WhisperFeatureExtractor,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    WhisperTokenizerFast,
    WhisperTokenizer,
    WhisperModel,
    WhisperConfig,
    Adafactor,
    TrainerCallback,
    logging
)
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.amp import GradScaler, autocast
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
# Silence warnings globally: the filter ignores every category, and
# `warnings.warn` itself is replaced by a function that does nothing.
warnings.filterwarnings(action="ignore")
warnings.warn = lambda *args, **kwargs: None

# Restrict CUDA device visibility to device index 0 via the environment.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Allow TF32 for CUDA matmuls and for cuDNN operations.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# NOTE(review): not referenced anywhere else in the visible part of this file.
checkpointing_args = {"reentrant": False}

# Probe for torch's scaled_dot_product_attention. When the import fails the
# name is set to None and SDPA_AVAILABLE to False; MultiHeadAttention checks
# this flag to choose between the fused call and its manual attention path.
try:
    from torch.nn.functional import scaled_dot_product_attention
    SDPA_AVAILABLE = True
except (ImportError, RuntimeError, OSError):
    scaled_dot_product_attention = None
    SDPA_AVAILABLE = False

# Functions from the `whisper` package; the Whisper class in this file binds
# them as `decode`, `detect_language` and `transcribe`.
from whisper.decoding import decode as decode_function
from whisper.decoding import detect_language as detect_language_function
from whisper.transcribe import transcribe as transcribe_function

# Rebinds `checkpoint` to the function. The import block above also executes
# `import torch.utils.checkpoint as checkpoint`, which bound the same name to
# the module; from here on `checkpoint(...)` is callable.
from torch.utils.checkpoint import checkpoint

# Use the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# MeCab tagger created with the "-Owakati" option
# (presumably space-separated word output -- confirm against MeCab docs).
mecab = MeCab.Tagger("-Owakati")

# Limit the transformers library logger to error-level messages.
transformers.utils.logging.set_verbosity_error()


In [None]:
class BiasedCrossAttention(nn.Module):
    """Multi-head attention with a learned per-head bias, residual and LayerNorm.

    Queries, keys and values are linearly projected, split into ``n_head``
    heads and combined with scaled dot-product attention over the key/value
    sequence. The learned ``bias`` parameter holds one ``head_dim`` vector per
    head; it is dotted with every key and added to the attention scores, so it
    acts as a query-independent preference over key positions. The attended
    values are projected, added to the projected queries and normalised.

    Args:
        n_state: Model width; must be divisible by ``n_head``.
        n_head: Number of attention heads.
        dropout_rate: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``n_state`` is not divisible by ``n_head``.
    """

    def __init__(self, n_state, n_head, dropout_rate=0.1):
        super().__init__()
        if n_state % n_head != 0:
            raise ValueError(
                f"n_state ({n_state}) must be divisible by n_head ({n_head})"
            )
        self.n_head = n_head
        self.n_state = n_state
        self.head_dim = n_state // n_head

        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)

        # Shape kept as (n_head, 1, head_dim) so existing checkpoints still load.
        self.bias = nn.Parameter(torch.zeros(n_head, 1, self.head_dim))
        self.dropout = nn.Dropout(dropout_rate)
        self.norm = LayerNorm(n_state)

    def _split_heads(self, t):
        """Reshape (batch, length, n_state) to (batch, n_head, length, head_dim)."""
        batch_size, length, _ = t.size()
        return t.view(batch_size, length, self.n_head, self.head_dim).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Query input of shape (batch, q_len, n_state).
            k: Key input of shape (batch, kv_len, n_state).
            v: Value input of shape (batch, kv_len, n_state).
            mask: Optional tensor broadcastable to
                (batch, n_head, q_len, kv_len); positions equal to 0 are
                excluded from attention.

        Returns:
            Tensor of shape (batch, q_len, n_state).
        """
        batch_size, q_len, _ = q.size()

        # Keep the flat projected query: it is also the residual branch below.
        q_proj = self.query(q)

        # Heads must sit in front of the sequence axis, otherwise the matmul
        # below would mix heads instead of attending over sequence positions.
        # Keys/values use their own length so q_len != kv_len is supported.
        qh = self._split_heads(q_proj)
        kh = self._split_heads(self.key(k))
        vh = self._split_heads(self.value(v))

        kh_t = kh.transpose(-2, -1)
        # (n_head, 1, head_dim) @ (batch, n_head, head_dim, kv_len)
        # -> (batch, n_head, 1, kv_len), broadcast over all query positions.
        key_bias = self.bias @ kh_t
        scores = (qh @ kh_t) / (self.head_dim ** 0.5) + key_bias
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))

        # (batch, n_head, q_len, head_dim) -> (batch, q_len, n_state)
        context = (weights @ vh).transpose(1, 2).contiguous().view(batch_size, q_len, -1)
        return self.norm(self.out(context) + q_proj)

class Conv1d(nn.Conv1d):
    """``nn.Conv1d`` with Kaiming-uniform weights, zero bias and dtype-following parameters.

    The convolution parameters are cast to the dtype of the incoming tensor on
    every forward pass, so the layer can be fed e.g. half- or double-precision
    input without converting the module itself.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): the torch base constructor appears to call
        # reset_parameters() itself, so this may initialise a second time.
        # Kept so that the random-number stream consumed at construction
        # (and therefore seeded initial weights) stays unchanged.
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the weight with Kaiming-uniform (ReLU gain) and zero the bias."""
        nn.init.kaiming_uniform_(self.weight, nonlinearity='relu')
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def _conv_forward(self, x, weight, bias) -> Tensor:
        """Run the convolution with ``weight``/``bias`` cast to ``x``'s dtype.

        The tensors passed in by the caller are the ones that get used; the
        previous implementation silently replaced them with ``self.weight`` and
        ``self.bias``.
        """
        cast_weight = weight.to(x.dtype)
        cast_bias = None if bias is None else bias.to(x.dtype)
        return super()._conv_forward(x, cast_weight, cast_bias)
    
class LearnedSinusoidalEmbeddings(nn.Module):
    """Learnable positional-embedding table initialised from fixed sinusoids.

    A (n_ctx, n_state) table of interleaved sine/cosine features is stored as
    the buffer ``sinusoidal_features``; a trainable copy of it,
    ``positional_embeddings``, is what ``forward`` indexes. Looked-up rows are
    L2-normalised along the feature axis.

    NOTE: a class with this same name is defined again later in this file; the
    later definition is the one the name refers to once the module has loaded.

    Args:
        n_ctx: Number of positions in the table.
        n_state: Embedding width.
        checkpointing: When true, the table lookup goes through
            ``torch.utils.checkpoint.checkpoint``.
    """

    def __init__(self, n_ctx, n_state, checkpointing=False):
        super().__init__()
        self.n_ctx = n_ctx
        self.n_state = n_state
        self.checkpointing = checkpointing

        position = torch.arange(0, n_ctx, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, n_state, 2).float() * -(math.log(10000.0) / n_state))
        features = torch.zeros(n_ctx, n_state)
        features[:, 0::2] = torch.sin(position * div_term)
        # With an odd n_state the cosine columns are one fewer than the sine
        # columns, so trim the frequencies to avoid a shape mismatch.
        features[:, 1::2] = torch.cos(position * div_term[: n_state // 2])
        self.register_buffer('sinusoidal_features', features)

        self.positional_embeddings = nn.Parameter(self.sinusoidal_features.clone())

    def _lookup(self, positions):
        """Return the table rows selected by ``positions``."""
        return self.positional_embeddings[positions]

    def forward(self, positions):
        """Return L2-normalised embeddings for the given integer positions.

        Args:
            positions: Integer index tensor into the table.

        Returns:
            Tensor with the shape of ``positions`` plus a trailing ``n_state`` axis.
        """
        if self.checkpointing:
            # The only tensor input is an integer index that cannot require
            # grad. In reentrant mode the checkpointed output would then be
            # cut off from autograd and the table would never receive a
            # gradient, so the non-reentrant variant is requested explicitly.
            position_embeddings = checkpoint(self._lookup, positions, use_reentrant=False)
        else:
            position_embeddings = self._lookup(positions)

        return torch.nn.functional.normalize(position_embeddings, p=2, dim=-1)

class DynamicConvAttention(nn.Module):
    """Grouped 1-D convolution branch summed with a dot-product attention branch.

    The input goes through a grouped ``nn.Conv1d`` (one group per head) and,
    in parallel, through softmax attention built from linear query/key/value
    projections of the full ``n_state`` width. Both branches are added,
    normalised, passed through dropout and an output projection, and the
    original input is added back as a residual.

    Args:
        n_state: Feature width of the input and output.
        n_head: Number of convolution groups.
        kernel_size: Convolution kernel width; padding is ``kernel_size // 2``.
        dropout_rate: Dropout probability used on both branches.
    """

    def __init__(self, n_state, n_head, kernel_size=3, dropout_rate=0.1):
        super().__init__()
        self.n_state = n_state
        self.n_head = n_head
        self.kernel_size = kernel_size

        self.conv = nn.Conv1d(
            n_state,
            n_state,
            kernel_size,
            padding=kernel_size // 2,
            groups=n_head,
        )
        self.dropout = nn.Dropout(dropout_rate)

        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out_proj = nn.Linear(n_state, n_state)

        # A single LayerNorm instance is applied twice in forward().
        self.norm = LayerNorm(n_state)

    def forward(self, x):
        """Apply both branches to ``x`` of shape (batch, seq_len, n_state).

        Raises:
            ValueError: If the last dimension of ``x`` differs from ``n_state``.
        """
        _, _, embed_dim = x.size()
        if embed_dim != self.n_state:
            raise ValueError(f"Expected embed_dim of {self.n_state}, but got {embed_dim}")

        residual = x
        queries = self.query(x)
        keys = self.key(x)
        values = self.value(x)

        # Conv1d wants channels in the middle; move them there and back again.
        conv_branch = self.conv(x.permute(0, 2, 1)).permute(0, 2, 1)
        conv_branch = self.dropout(self.norm(conv_branch))

        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.n_state ** 0.5)
        attn_branch = torch.matmul(F.softmax(scores, dim=-1), values)

        merged = self.norm(conv_branch + attn_branch)
        return self.out_proj(self.dropout(merged)) + residual

class HybridAttention(nn.Module):
    """Sum of a windowed ("local") and a full-sequence ("global") attention.

    Each branch has its own LayerNorm and its own ``nn.MultiheadAttention``.
    The two outputs are added and passed through dropout.

    Args:
        n_state: Feature width.
        n_head: Number of heads in both attention modules.
        window_size: Upper bound for the local window length.
        dropout_rate: Dropout probability for the attention modules and output.
    """

    def __init__(self, n_state, n_head, window_size=1, dropout_rate=0.1):
        super().__init__()
        self.local_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate)
        self.global_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate)
        self.ln_local = LayerNorm(n_state)
        self.ln_global = LayerNorm(n_state)

        self.dropout = nn.Dropout(dropout_rate)
        self.window_size = window_size

    def forward(self, x: torch.Tensor):
        """Run both branches on ``x`` and return their sum after dropout.

        The first two axes are swapped before the attention calls and swapped
        back on the way out, so the output layout matches the input layout.
        """
        local_in = self.ln_local(x).permute(1, 0, 2)
        global_in = self.ln_global(x).permute(1, 0, 2)

        local_out = self.sliding_window_attention(local_in)
        global_out = self.global_attn(global_in, global_in, global_in)[0]

        summed = (local_out + global_out).permute(1, 0, 2)
        return self.dropout(summed)

    def sliding_window_attention(self, x):
        """Attend chunk by chunk; each chunk sees itself and the chunk before it.

        Args:
            x: Tensor unpacked as (seq_len, batch, n_state).

        Returns:
            Tensor with the same shape, dtype and device as ``x``.
        """
        seq_len, _, _ = x.size()
        # The effective window never exceeds a quarter of the sequence
        # (but is at least 1).
        step = min(self.window_size, max(1, seq_len // 4))
        result = torch.zeros_like(x, device=x.device, dtype=x.dtype)

        for chunk_start in range(0, seq_len, step):
            chunk_end = min(chunk_start + step, seq_len)
            context_start = max(0, chunk_start - step)

            chunk = x[chunk_start:chunk_end, :, :]
            context = x[context_start:chunk_end, :, :]

            attended, _ = self.local_attn(chunk, context, context)
            result[chunk_start:chunk_end, :, :] = attended[:chunk_end - chunk_start, :, :]

        return result

class LayerNorm(nn.Module):
    """Normalise the last axis to zero mean / unit std, then scale and shift.

    Unlike ``torch.nn.LayerNorm`` this divides by ``std + eps`` (``x.std`` with
    its default settings) and names its parameters ``gamma`` and ``beta``.

    Args:
        num_features: Size of the last axis.
        eps: Constant added to the standard deviation before dividing.
    """

    def __init__(self, num_features, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        self.eps = eps

    def forward(self, x):
        """Return ``gamma * (x - mean) / (std + eps) + beta`` over the last axis."""
        centered = x - x.mean(dim=-1, keepdim=True)
        spread = x.std(dim=-1, keepdim=True) + self.eps
        normalised = centered / spread
        return self.gamma * normalised + self.beta

class Linear(nn.Module):
    """Linear layer followed by optional BatchNorm1d, an activation and dropout.

    The input's leading axes are flattened into one batch axis for the linear
    and batch-norm steps and restored afterwards.

    Args:
        in_features: Size of the last input axis.
        out_features: Size of the last output axis.
        dropout_rate: Dropout probability applied after the activation.
        use_batchnorm: Whether to apply ``nn.BatchNorm1d`` after the linear map.
        activation: One of ``'relu'``, ``'tanh'`` or ``'sigmoid'``.
    """

    def __init__(self, in_features: int, out_features: int, dropout_rate: float = 0.01, use_batchnorm: bool = True, activation: str = 'relu'):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.dropout = nn.Dropout(dropout_rate)
        self.use_batchnorm = use_batchnorm
        self.activation = activation

        if self.use_batchnorm:
            self.batchnorm = nn.BatchNorm1d(out_features)
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform weight init (gain chosen by ``activation``), zero bias."""
        nn.init.kaiming_uniform_(self.linear.weight, nonlinearity=self.activation)
        if self.linear.bias is not None:
            nn.init.zeros_(self.linear.bias)

    def forward(self, x):
        """Apply the layer to ``x`` of shape (..., in_features).

        Returns:
            Tensor of shape (..., out_features) with the same leading axes.
        """
        leading_shape = x.shape[:-1]
        # reshape (not view) so non-contiguous inputs, e.g. the result of a
        # transpose/permute, are accepted instead of raising a RuntimeError.
        flat = x.reshape(-1, x.size(-1))
        flat = self.linear(flat)

        if self.use_batchnorm:
            flat = self.batchnorm(flat)

        flat = self.apply_activation(flat)
        flat = self.dropout(flat)
        return flat.reshape(*leading_shape, -1)

    def apply_activation(self, x):
        """Apply the configured activation.

        Raises:
            ValueError: If ``self.activation`` is not a supported name.
        """
        activations = {
            'relu': F.relu,
            'tanh': torch.tanh,
            'sigmoid': torch.sigmoid,
        }
        try:
            fn = activations[self.activation]
        except KeyError:
            raise ValueError(f'Unsupported activation function: {self.activation}') from None
        return fn(x)

def sinusoids(length, channels, max_timescale=10000):
    """Build a (length, channels) table of sinusoidal position features.

    The first ``channels // 2`` columns hold sines and the remaining columns
    hold cosines of ``position * inverse_timescale``, with the inverse
    timescales spaced geometrically from 1 down to ``1 / max_timescale``.
    ``channels`` must be even.
    """
    assert channels % 2 == 0
    half = channels // 2
    log_step = np.log(max_timescale) / (half - 1)
    inverse_timescales = torch.exp(-log_step * torch.arange(half))
    positions = torch.arange(length)
    angles = positions[:, None] * inverse_timescales[None, :]
    return torch.cat((torch.sin(angles), torch.cos(angles)), dim=1)

class MultiHeadAttention(nn.Module):
    """Multi-head attention with rotary position encoding and a learned rotation.

    Queries and keys are split into heads, rotated by position-dependent
    sinusoids and then multiplied by a learned ``(h_dim, h_dim)`` matrix that
    starts out orthogonal. Attention itself uses torch's fused
    ``scaled_dot_product_attention`` when it is available and ``use_sdpa`` is
    true, otherwise an explicit softmax path.

    Args:
        n_state: Model width.
        n_head: Number of heads.
        base: Base of the rotary frequency schedule.
        checkpointing: Stored on the instance; not otherwise used by this class.
    """

    use_sdpa = True

    def __init__(self, n_state: int, n_head: int, base: int = 10000, checkpointing=False):
        super().__init__()
        self.n_head = n_head
        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)
        self.h_dim = n_state // n_head
        self.checkpointing = checkpointing

        self.rotation_matrix = nn.Parameter(torch.empty(self.h_dim, self.h_dim))
        nn.init.orthogonal_(self.rotation_matrix)
        self.base = base
        exponents = torch.arange(0, self.h_dim, 2).float() / self.h_dim
        inv_freq = 1.0 / (self.base ** exponents)
        self.register_buffer('inv_freq', inv_freq)

    def rotate_queries_or_keys(self, x):
        """Apply the rotary rotation to ``x`` of shape (batch, seq, head, h_dim).

        Even and odd feature columns are treated as the two components of a
        2-D vector that is rotated by ``position * inv_freq``; the rotated
        first components are concatenated in front of the second components.
        """
        positions = torch.arange(x.shape[1], device=x.device)
        angles = torch.einsum('i , j -> i j', positions, self.inv_freq)
        sin = angles.sin()[None, :, None, :]
        cos = angles.cos()[None, :, None, :]
        even = x[..., ::2]
        odd = x[..., 1::2]
        first = even * cos - odd * sin
        second = even * sin + odd * cos
        return torch.cat([first, second], dim=-1)

    def forward(self, x: torch.Tensor, xa: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, kv_cache: Optional[dict] = None):
        """Attend from ``x`` to itself, or to ``xa`` when it is given.

        Keys and values are read from ``kv_cache`` only when a cache is
        supplied, ``xa`` is given and this module's key projection is already
        a key of the cache; in every other case they are computed here.

        Returns:
            Tuple of the projected attention output and the detached attention
            scores (``None`` on the fused path).
        """
        q = self.query(x)

        reuse_cached = (
            kv_cache is not None
            and xa is not None
            and self.key in kv_cache
        )
        if reuse_cached:
            k = kv_cache[self.key]
            v = kv_cache[self.value]
        else:
            source = x if xa is None else xa
            k = self.key(source)
            v = self.value(source)

        # Split into heads for the rotary step and the learned rotation.
        q = q.view(q.shape[0], q.shape[1], self.n_head, -1)
        k = k.view(k.shape[0], k.shape[1], self.n_head, -1)

        q = torch.matmul(self.rotate_queries_or_keys(q), self.rotation_matrix)
        k = torch.matmul(self.rotate_queries_or_keys(k), self.rotation_matrix)

        # Merge the heads again; qkv_attention performs its own split.
        q = q.view(q.shape[0], q.shape[1], -1)
        k = k.view(k.shape[0], k.shape[1], -1)

        attended, scores = self.qkv_attention(q, k, v, mask)
        return self.out(attended), scores

    def qkv_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Scaled dot-product attention over (batch, seq, n_state) inputs.

        On the fused path the mask tensor is not passed on; only its presence
        is used to request causal attention (for sequences longer than one).
        """
        _, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25

        def to_heads(t):
            return t.view(*t.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        q = to_heads(q)
        k = to_heads(k)
        v = to_heads(v)

        if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
            causal = mask is not None and n_ctx > 1
            fused = scaled_dot_product_attention(q, k, v, is_causal=causal)
            return fused.permute(0, 2, 1, 3).flatten(start_dim=2), None

        # Manual path: the scale is applied to both operands, i.e. squared overall.
        qk = (q * scale) @ (k * scale).transpose(-1, -2)
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]
        qk = qk.float()

        weights = F.softmax(qk, dim=-1).to(q.dtype)
        out = (weights @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
        return out, qk.detach()

class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention, optional cross-attention, MLP.

    Every sub-layer is applied to a LayerNorm of its input and added back to
    that input as a residual.

    Args:
        n_state: Model width.
        n_head: Number of attention heads.
        cross_attention: Whether to add a cross-attention sub-layer.
        checkpointing: Stored on the instance; not otherwise used by this class.
    """

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, checkpointing=False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)
        self.checkpointing = checkpointing

        if cross_attention:
            self.cross_attn = MultiHeadAttention(n_state, n_head)
            self.cross_attn_ln = LayerNorm(n_state)
        else:
            self.cross_attn = None
            self.cross_attn_ln = None

        hidden_width = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, hidden_width),
            nn.GELU(),
            Linear(hidden_width, n_state),
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(self, x: torch.Tensor, xa: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, kv_cache: Optional[dict] = None):
        """Run the block on ``x``; ``xa`` feeds the cross-attention when present.

        ``mask`` is only handed to the self-attention; ``kv_cache`` is handed
        to both attention sub-layers.
        """
        self_attended, _ = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
        x = x + self_attended

        if self.cross_attn is not None:
            cross_attended, _ = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
            x = x + cross_attended

        return x + self.mlp(self.mlp_ln(x))

class RotaryEmbeddingWithRotation(nn.Module):
    """Learned per-head rotation matrix followed by a rotary position rotation.

    The input is split into ``n_head`` heads of width ``h_dim``. Each head
    vector is multiplied by ``rotation_matrix`` (initialised to the identity)
    and then rotated by position-dependent sinusoids.

    Args:
        n_state: Model width.
        n_head: Number of heads.
        base: Base of the rotary frequency schedule.
    """

    def __init__(self, n_state, n_head, base=10000):
        super().__init__()
        self.n_state = n_state
        self.n_head = n_head
        self.h_dim = n_state // n_head

        self.rotation_matrix = nn.Parameter(torch.eye(self.h_dim))
        exponents = torch.arange(0, self.h_dim, 2).float() / self.h_dim
        self.register_buffer('inv_freq', 1.0 / (base ** exponents))

    def reset_parameters(self):
        """Re-initialise the rotation matrix with an orthogonal matrix."""
        nn.init.orthogonal_(self.rotation_matrix)

    def forward(self, x):
        """Rotate ``x`` and return it as (batch, seq_len, n_state).

        Args:
            x: Tensor of shape (batch, seq_len, n_state) or
                (batch, seq_len, n_head, h_dim).

        Raises:
            ValueError: If ``x`` is neither 3-D nor 4-D, or its feature width
                differs from ``n_state``.
        """
        if x.dim() not in (3, 4):
            raise ValueError(f"Expected input tensor to be 3D or 4D, but got {x.dim()}D")

        if x.dim() == 4:
            batch_size, seq_len, n_head, h_dim = x.size()
            n_state = n_head * h_dim
            x = x.view(batch_size, seq_len, n_state)
        else:
            batch_size, seq_len, n_state = x.size()

        if n_state != self.n_state:
            raise ValueError(f"Expected n_state of {self.n_state}, but got {n_state}")

        head_shape = (batch_size, seq_len, self.n_head, self.h_dim)

        # One row per (batch, position, head) so a single matmul rotates them all.
        rows = x.reshape(-1, self.h_dim)
        rotated = torch.matmul(rows, self.rotation_matrix).reshape(head_shape)

        positions = torch.arange(seq_len, device=x.device)
        angles = torch.einsum('i, j -> i j', positions, self.inv_freq)
        sin = angles.sin()[None, :, None, :]
        cos = angles.cos()[None, :, None, :]

        even = rotated[..., ::2]
        odd = rotated[..., 1::2]
        rotated = torch.cat([even * cos - odd * sin, even * sin + odd * cos], dim=-1)

        return rotated.reshape(batch_size, seq_len, self.n_state)

class LearnedSinusoidalEmbeddings(nn.Module):
    """Learnable positional-embedding table initialised from fixed sinusoids.

    A (n_ctx, n_state) table of interleaved sine/cosine features is stored as
    the buffer ``sinusoidal_features``; a trainable copy of it,
    ``positional_embeddings``, is what ``forward`` indexes. Looked-up rows are
    L2-normalised along the feature axis.

    Args:
        n_ctx: Number of positions in the table.
        n_state: Embedding width.
        checkpointing: When true, the table lookup goes through
            ``torch.utils.checkpoint.checkpoint``.
    """

    def __init__(self, n_ctx, n_state, checkpointing=False):
        super().__init__()
        self.n_ctx = n_ctx
        self.n_state = n_state
        self.checkpointing = checkpointing

        position = torch.arange(0, n_ctx, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, n_state, 2).float() * -(math.log(10000.0) / n_state))
        features = torch.zeros(n_ctx, n_state)
        features[:, 0::2] = torch.sin(position * div_term)
        # With an odd n_state the cosine columns are one fewer than the sine
        # columns, so trim the frequencies to avoid a shape mismatch.
        features[:, 1::2] = torch.cos(position * div_term[: n_state // 2])
        self.register_buffer('sinusoidal_features', features)

        self.positional_embeddings = nn.Parameter(self.sinusoidal_features.clone())

    def _lookup(self, positions):
        """Return the table rows selected by ``positions``."""
        return self.positional_embeddings[positions]

    def forward(self, positions):
        """Return L2-normalised embeddings for the given integer positions.

        Args:
            positions: Integer index tensor into the table.

        Returns:
            Tensor with the shape of ``positions`` plus a trailing ``n_state`` axis.
        """
        if self.checkpointing:
            # The only tensor input is an integer index that cannot require
            # grad. In reentrant mode the checkpointed output would then be
            # cut off from autograd and the table would never receive a
            # gradient, so the non-reentrant variant is requested explicitly.
            position_embeddings = checkpoint(self._lookup, positions, use_reentrant=False)
        else:
            position_embeddings = self._lookup(positions)

        return torch.nn.functional.normalize(position_embeddings, p=2, dim=-1)

class AudioEncoder(nn.Module):
    """Convolutional stem followed by a stack of residual attention blocks.

    Two ``nn.Conv1d`` layers (the second with stride 2) with GELU map the
    ``n_mels`` input channels to ``n_state`` features. The result is passed
    through the rotary/rotation layer, has learned positional embeddings
    added, runs through ``n_layer`` attention blocks and a final LayerNorm.

    Args:
        n_mels: Number of input channels.
        n_ctx: Number of positions in the positional-embedding table.
        n_state: Model width.
        n_head: Number of attention heads.
        n_layer: Number of attention blocks.
        checkpointing: Run each block through activation checkpointing.
    """

    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, checkpointing=False):
        super().__init__()
        self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.positional_embedding = LearnedSinusoidalEmbeddings(n_ctx, n_state, checkpointing=checkpointing)
        self.rotor_layer = RotaryEmbeddingWithRotation(n_state, n_head)
        self.checkpointing = checkpointing

        self.blocks = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head, checkpointing=checkpointing) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: torch.Tensor):
        """Encode ``x`` of shape (batch, n_mels, frames).

        Returns:
            Tensor with the batch axis first, one row per strided frame and
            ``n_state`` features last.
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        # Channels-first conv output -> (batch, frames, n_state).
        x = x.permute(0, 2, 1)
        x = self.rotor_layer(x)
        positions = torch.arange(x.size(1), device=x.device)
        x = x + self.positional_embedding(positions).unsqueeze(0)

        for block in self.blocks:
            if self.checkpointing:
                # Pass use_reentrant explicitly: torch warns when it is
                # omitted, and the non-reentrant variant is the recommended one.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return self.ln_post(x)

class TextDecoder(nn.Module):
    """Token decoder: embeddings, rotary layer, cross-attending blocks, tied output.

    Tokens are embedded and summed with learned positional embeddings, passed
    through the rotary/rotation layer and ``n_layer`` residual attention
    blocks with cross-attention, normalised, and projected to vocabulary
    logits with the transposed token-embedding matrix.

    Args:
        n_vocab: Vocabulary size.
        n_ctx: Maximum number of positions (also the causal-mask size).
        n_state: Model width.
        n_head: Number of attention heads.
        n_layer: Number of attention blocks.
        checkpointing: Run each block through activation checkpointing.
    """

    def __init__(self, n_vocab, n_ctx, n_state, n_head, n_layer, checkpointing=False):
        super().__init__()
        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = LearnedSinusoidalEmbeddings(n_ctx, n_state, checkpointing=checkpointing)
        self.rotor_layer = RotaryEmbeddingWithRotation(n_state, n_head)
        self.checkpointing = checkpointing
        self.n_head = n_head

        self.blocks = nn.ModuleList([
            ResidualAttentionBlock(n_state, n_head, cross_attention=True, checkpointing=checkpointing)
            for _ in range(n_layer)
        ])
        self.ln = LayerNorm(n_state)
        # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: torch.Tensor, xa: torch.Tensor, kv_cache: Optional[dict] = None):
        """Compute vocabulary logits for token ids ``x`` given encoder output ``xa``.

        Args:
            x: Integer token ids of shape (batch, seq_len).
            xa: Encoder features attended to by the cross-attention layers.
            kv_cache: Optional cache dict; when non-empty, the length of its
                first entry is used as the position offset of ``x``.

        Returns:
            Float tensor of logits with a trailing vocabulary axis.
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        positions = torch.arange(x.shape[1], device=x.device) + offset
        pos_emb = self.positional_embedding(positions).unsqueeze(0)

        x = self.token_embedding(x) + pos_emb
        x = x.to(xa.dtype)

        # The rotary layer accepts (batch, seq, n_state) directly and returns
        # the same layout, so no head split/merge is needed around the call.
        x = self.rotor_layer(x)

        for block in self.blocks:
            if self.checkpointing:
                # Pass use_reentrant explicitly: torch warns when it is
                # omitted, and the non-reentrant variant is the recommended one.
                x = checkpoint(block, x, xa, self.mask, kv_cache, use_reentrant=False)
            else:
                x = block(x, xa, self.mask, kv_cache)

        x = self.ln(x)
        # Output projection shares its weights with the token embedding.
        logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()

        return logits

class Whisper(nn.Module):
    """Encoder-decoder speech model built from ``AudioEncoder`` and ``TextDecoder``.

    The model dimensions are read from attributes of ``config`` (``n_mels``,
    ``n_audio_*``, ``n_text_*``, ``n_vocab``, ``checkpointing``). ``forward``
    implements teacher-forced training and returns a dict with the loss;
    ``decode``, ``transcribe`` and ``detect_language`` are bound from the
    ``whisper`` package.
    """

    def __init__(self, config: WhisperConfig):
        super().__init__()
        self.config = config
        self.encoder = AudioEncoder(
            self.config.n_mels,
            self.config.n_audio_ctx,
            self.config.n_audio_state,
            self.config.n_audio_head,
            self.config.n_audio_layer,
            self.config.checkpointing,
        )
        self.decoder = TextDecoder(
            self.config.n_vocab,
            self.config.n_text_ctx,
            self.config.n_text_state,
            self.config.n_text_head,
            self.config.n_text_layer,
            self.config.checkpointing,
        )

        # Default alignment heads: every head in the second half of the
        # decoder layers.
        all_heads = torch.zeros(
            self.config.n_text_layer, self.config.n_text_head, dtype=torch.bool
        )
        all_heads[self.config.n_text_layer // 2 :] = True
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

    def set_alignment_heads(self, dump: bytes):
        """Replace the alignment-head mask from a base85-encoded, gzipped bool array."""
        array = np.frombuffer(
            gzip.decompress(base64.b85decode(dump)), dtype=bool
        ).copy()
        mask = torch.from_numpy(array).reshape(
            self.config.n_text_layer, self.config.n_text_head
        )
        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

    def embed_audio(self, mel: torch.Tensor):
        """Run the encoder on ``mel``."""
        return self.encoder(mel)

    def logits(self, tokens: torch.Tensor, input_features: torch.Tensor):
        """Run the decoder on ``tokens`` given already-encoded ``input_features``."""
        return self.decoder(tokens, input_features)

    @staticmethod
    def shift_tokens_right(input_ids: torch.Tensor, pad_token_id, decoder_start_token_id) -> torch.Tensor:
        """Shift ``input_ids`` one position right and prepend the start token.

        Label positions holding -100 are replaced by ``pad_token_id``.

        Raises:
            ValueError: If ``pad_token_id`` is None.
        """
        if pad_token_id is None:
            raise ValueError("pad_token_id has to be defined.")

        shifted_input_ids = input_ids.new_zeros(input_ids.shape, dtype=torch.long)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = decoder_start_token_id
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
        return shifted_input_ids

    @staticmethod
    def _report_nan(tensor: torch.Tensor, message: str) -> None:
        """Print ``message`` when ``tensor`` contains at least one NaN."""
        if torch.isnan(tensor).any():
            print(message)

    def forward(self, input_features: torch.Tensor, labels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """Teacher-forced forward pass.

        Args:
            input_features: Encoder input; cast to float32.
            labels: Target token ids with -100 at ignored positions. Required:
                the decoder inputs are derived from them.

        Returns:
            Dict with ``loss`` (cross-entropy ignoring -100), ``logits``
            (flattened to (batch * seq_len, n_vocab)) and ``input_features``
            (the encoder output).

        Raises:
            ValueError: If ``labels`` is None.
        """
        if labels is None:
            # Without labels there are no decoder inputs; previously this
            # surfaced as an AttributeError deep inside the decoder.
            raise ValueError(
                "labels must be provided to forward(); use generate() for inference."
            )

        input_features = input_features.float()
        self._report_nan(input_features, "NaNs detected in input features")

        encoded_features = self.encoder(input_features)
        self._report_nan(encoded_features, "NaNs detected in encoded features")

        labels = labels.long()
        # Integer token ids cannot hold NaN, so they are not checked.
        decoder_input_ids = self.shift_tokens_right(
            labels, self.config.pad_token_id, self.config.decoder_start_token_id
        )

        logits = self.decoder(decoder_input_ids, encoded_features)
        self._report_nan(logits, "NaNs detected in logits")

        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
        logits = logits.view(-1, self.config.n_vocab)
        loss = loss_fct(logits, labels.view(-1))
        self._report_nan(loss, "NaNs detected in loss- do the NaN dance!")

        return {"loss": loss, "logits": logits, "input_features": encoded_features}

    @property
    def device(self):
        """Device of the first model parameter."""
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        return self.config.n_vocab >= 51865

    @property
    def num_languages(self):
        return self.config.n_vocab - 51765 - int(self.is_multilingual)

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """Register forward hooks that cache key/value projections in the decoder.

        Every ``MultiHeadAttention`` in the decoder gets a hook on its key and
        value projections. The first output of a module (or any output longer
        than ``n_text_ctx``) is stored as is; later outputs are concatenated
        along axis 1 onto the cached tensor.

        Returns:
            Tuple ``(cache, hooks)``: the cache dict keyed by projection module
            and the list of hook handles (call ``remove()`` on each to undo).
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.config.n_text_ctx:
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    def set_input_embeddings(self, new_embeddings: torch.nn.Embedding):
        self.decoder.token_embedding = new_embeddings

    def get_input_embeddings(self):
        return self.decoder.token_embedding

    def resize_token_embeddings(self, new_num_tokens: int):
        """Replace the token embedding with one of ``new_num_tokens`` rows.

        Rows shared by the old and new table are copied; the new table is
        created on the old table's device and dtype. Works for both growing
        and shrinking the vocabulary. ``config.n_vocab`` is updated.
        """
        old_embeddings = self.get_input_embeddings()
        old_num_tokens, embedding_dim = old_embeddings.weight.size()
        new_embeddings = torch.nn.Embedding(new_num_tokens, embedding_dim).to(
            device=old_embeddings.weight.device, dtype=old_embeddings.weight.dtype
        )

        # Copy only the overlapping rows so shrinking does not raise.
        n_shared = min(old_num_tokens, new_num_tokens)
        with torch.no_grad():
            new_embeddings.weight[:n_shared] = old_embeddings.weight[:n_shared]

        self.set_input_embeddings(new_embeddings)
        self.config.n_vocab = new_num_tokens

    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function

    def save_pretrained(self, save_directory, safetensor=False):
        """Save the config and the state dict into ``save_directory``.

        Args:
            save_directory: Target directory.
            safetensor: Write ``model.safetensors`` instead of
                ``pytorch_model.bin``.
        """
        self.config.save_pretrained(save_directory)
        if safetensor:
            # safe_open is a read-only API; writing goes through save_file.
            from safetensors.torch import save_file

            tensors = {name: value.contiguous() for name, value in self.state_dict().items()}
            save_file(tensors, os.path.join(save_directory, "model.safetensors"))
        else:
            torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, safetensor=False, *model_args, **kwargs):
        """Build a model from a directory written by ``save_pretrained``.

        Args:
            pretrained_model_name_or_path: Directory holding the config and weights.
            safetensor: Read ``model.safetensors`` instead of ``pytorch_model.bin``.
            *model_args: Extra positional arguments for the model constructor.
            **kwargs: Forwarded to ``WhisperConfig.from_pretrained`` only.

        Returns:
            The model with the stored weights loaded on CPU. Loading is
            non-strict, so missing or unexpected keys are ignored.
        """
        config = WhisperConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        # kwargs are config options; the constructor only takes the config,
        # so passing them on as well would raise a TypeError.
        model = cls(config, *model_args)
        if safetensor:
            from safetensors.torch import load_file

            state_dict = load_file(
                os.path.join(pretrained_model_name_or_path, "model.safetensors"),
                device="cpu",
            )
        else:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            state_dict = torch.load(
                os.path.join(pretrained_model_name_or_path, "pytorch_model.bin"),
                map_location="cpu",
            )
        model.load_state_dict(state_dict, strict=False)
        return model

    def get_encoder(self):
        return self.encoder

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {'input_features': input_ids}

    def _prepare_decoder_input_ids_for_generation(self, batch_size, decoder_start_token_id=None, bos_token_id=None):
        return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * self.config.decoder_start_token_id

    def can_generate(self):
        return True

    def _first_step_argmax(self, inputs):
        """Encode ``inputs``, run one decoder step and return the arg-max ids.

        The decoder is seeded with ``config.decoder_start_token_id`` (the
        token the training path places at position 0); token id 0 is used
        only when the config does not define one.
        """
        encoder_outputs = self.encoder(inputs)
        start_token_id = getattr(self.config, "decoder_start_token_id", None)
        if start_token_id is None:
            start_token_id = 0
        decoder_input_ids = torch.full(
            (inputs.size(0), 1), start_token_id, dtype=torch.long, device=inputs.device
        )
        outputs = self.decoder(decoder_input_ids, encoder_outputs)
        return outputs.argmax(dim=-1)

    def generate(self, inputs, **kwargs):
        """Return the arg-max token of a single decoder step, shape (batch, 1).

        Only one decoding step is performed; ``kwargs`` are accepted and ignored.
        """
        return self._first_step_argmax(inputs)

    def generate_beam_search(self, inputs, **kwargs):
        """Same single-step arg-max as ``generate``; no beam search is performed."""
        return self._first_step_argmax(inputs)
    

In [None]:
def load_wave(wave_path, sample_rate: int = 16000) -> torch.Tensor:
    """Load an audio file and return its waveform at ``sample_rate``.

    The file is read with ``torchaudio.load(..., normalize=True)``; when its
    native rate differs from ``sample_rate`` the waveform is resampled.
    """
    waveform, native_rate = torchaudio.load(wave_path, normalize=True)
    if native_rate == sample_rate:
        return waveform
    resampler = torchaudio.transforms.Resample(native_rate, sample_rate)
    return resampler(waveform)

class audioDataset(Dataset):
    """Dataset of (audio path, transcript) pairs read from a CSV file.

    The first CSV row is treated as a header and skipped. Column 0 is the
    audio path relative to ``aud_dir`` and column 1 is the transcript.

    Args:
        csv_file: Path of the UTF-8 encoded CSV file.
        aud_dir: Directory the audio paths are relative to.
        tokenizer: Stored on the instance; not used by this class itself.
        sample_rate: Stored on the instance; not used by this class itself.

    Raises:
        ValueError: If a non-blank row has fewer than two columns.
    """

    def __init__(self, csv_file, aud_dir, tokenizer, sample_rate=16000):
        self.aud_dir = aud_dir
        self.tokenizer = tokenizer
        self.sample_rate = sample_rate
        self.samples = []

        # newline='' lets the csv module handle line endings itself, which
        # matters for quoted fields that contain newlines.
        with open(csv_file, 'r', encoding='utf-8', newline='') as f:
            reader = csv.reader(f)
            # Skip the header; the default keeps an empty file from raising.
            next(reader, None)
            for row in reader:
                if not row:
                    # csv.reader yields [] for blank lines.
                    continue
                if len(row) < 2:
                    raise ValueError(
                        f"{csv_file}: line {reader.line_num} has {len(row)} "
                        "column(s), expected at least 2"
                    )
                self.samples.append((row[0], row[1]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict.

        ``input_features`` holds the audio file path (not audio data);
        ``labels`` and ``input_ids`` both hold the normalised transcript text.
        """
        aud_path, label = self.samples[idx]
        label = handle_unknown_characters(label)
        aud = f'{self.aud_dir}/{aud_path}'
        return {
            'input_features': aud,
            'labels': label,
            'input_ids': label
        }

def handle_unknown_characters(label):
    """Replace un-encodable characters in ``label`` and normalise it with neologdn.

    The text is round-tripped through UTF-8 with ``errors='replace'`` and then
    passed to ``neologdn.normalize(..., repeat=1)``.
    """
    # The replacement has to happen on the encode side: that is the step that
    # can fail (e.g. on lone surrogates). Decoding bytes that were just
    # produced by a UTF-8 encode cannot fail.
    label = label.encode('utf-8', errors='replace').decode('utf-8', errors='replace')
    label = neologdn.normalize(label, repeat=1)
    return label

class WhisperDataCollatorWithPadding:
    """Collate dataset items into padded token tensors and log-mel features.

    Each item provides an audio file path under ``'input_features'`` and a
    transcript under ``'labels'``. The audio is loaded, padded or trimmed with
    ``whisper.pad_or_trim`` and converted to a log-mel spectrogram; the
    transcript is normalised and tokenised.

    Args:
        tokenizer: Object providing ``encode``, ``bos_token_id``,
            ``eos_token_id`` and ``pad_token_id``.
        n_mels: Number of mel bins.
        n_fft: FFT size.
        hop_length: Hop length between frames.
        sample_rate: Sample rate the audio is resampled to and the mel
            transform is configured for.
    """

    def __init__(self, tokenizer, n_mels, n_fft, hop_length, sample_rate=16000):
        self.tokenizer = tokenizer
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.mel_spectrogram_transform = transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_fft=self.n_fft,
            n_mels=self.n_mels,
            hop_length=self.hop_length
        )

    def __call__(self, features):
        """Build a batch dict from a list of dataset items.

        Returns:
            Dict with ``input_ids`` (BOS + tokens, padded with the pad id),
            ``labels`` (tokens + EOS, padded with -100) and ``input_features``
            (stacked log-mel spectrograms). Token tensors are int64.
        """
        input_features, dec_input_ids, labels = [], [], []

        for f in features:
            # load_wave resamples to self.sample_rate, so the audio matches
            # the rate the mel transform was configured with.
            waveform = load_wave(f['input_features'], sample_rate=self.sample_rate)
            # Average the channels instead of flattening them end to end
            # (assumes torchaudio's default (channels, frames) layout); for
            # mono input this gives exactly the flattened samples.
            aud = whisper.pad_or_trim(waveform.mean(dim=0))

            mel_spectrogram = self.mel_spectrogram_transform(aud)
            # Small constant keeps log() finite on silent bins.
            input_features.append(torch.log(mel_spectrogram + 1e-8))

            label = handle_unknown_characters(f['labels'])
            # Tokenise once; decoder inputs and labels share the same ids.
            token_ids = list(self.tokenizer.encode(label))

            dec_input_ids.append([self.tokenizer.bos_token_id] + token_ids)
            labels.append(token_ids + [self.tokenizer.eos_token_id])

        input_features = torch.stack(input_features)

        max_len = max(len(ids) for ids in dec_input_ids + labels)
        pad_id = self.tokenizer.pad_token_id

        padded_inputs = [ids + [pad_id] * (max_len - len(ids)) for ids in dec_input_ids]
        # -100 is the index the loss is configured to ignore.
        padded_labels = [ids + [-100] * (max_len - len(ids)) for ids in labels]

        return {
            "input_ids": torch.tensor(padded_inputs, dtype=torch.long),
            "labels": torch.tensor(padded_labels, dtype=torch.long),
            "input_features": input_features
        }

metrics_cer = evaluate.load("cer")
def compute_metrics(pred):
    """Return the character error rate (in percent) for a prediction batch.

    ``pred`` is indexed with the keys ``"predictions"`` and ``"label_ids"``.
    Label positions equal to -100 are overwritten in place with the
    tokenizer's pad id before decoding. Relies on a module-level ``tokenizer``.
    """
    predicted_ids = pred["predictions"]
    reference_ids = pred["label_ids"]
    ignored = reference_ids == -100
    reference_ids[ignored] = tokenizer.pad_token_id

    predicted_text = tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)
    reference_text = tokenizer.batch_decode(reference_ids, skip_special_tokens=True)

    error_rate = metrics_cer.compute(predictions=predicted_text, references=reference_text)
    return {"cer": 100 * error_rate}

# Output locations: a fixed checkpoint folder and a per-run log folder whose
# name carries the launch timestamp (YYYYmmdd-HHMMSS).
checkpoint_dir = 'D:/proj3/checkpoints/'
log_dir = f"D:/proj3/logs/run_{datetime.datetime.now():%Y%m%d-%H%M%S}"
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

# TensorBoard event writer for this run.
writer = SummaryWriter(log_dir=log_dir)

# Re-bind `logging` to the standard-library module: the file header's
# `from transformers import (..., logging, ...)` rebinds that name to the
# transformers logging helper, which has no `basicConfig`.
import logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filename=os.path.join(log_dir, 'training.log'),
    filemode='w',  # start a fresh log file for every run
)

In [None]:

def _grad_l2_norm(parameters):
    """Return the global L2 norm over the gradients stored on *parameters*."""
    squared_sum = 0.0
    for param in parameters:
        if param.grad is not None:
            squared_sum += param.grad.detach().norm(2).item() ** 2
    return squared_sum ** 0.5


def _run_evaluation(model, eval_loader, loss_fn, device):
    """Run one teacher-forced, gradient-free pass over *eval_loader*.

    Returns ``(mean_loss, predictions, labels, num_batches)`` where
    predictions/labels are lists of per-example token-id lists. The model is
    switched to eval mode for the pass and back to train mode afterwards.
    """
    model.eval()
    loss_sum = 0.0
    num_batches = 0
    predictions, references = [], []
    with torch.no_grad():
        for eval_batch in eval_loader:
            input_features = eval_batch['input_features'].to(device)
            input_ids = eval_batch['input_ids'].to(device)
            labels = eval_batch['labels'].long().to(device)

            encoded = model.encoder(input_features)
            decoder_output = model.decoder(input_ids, encoded)
            logits = decoder_output.view(-1, decoder_output.size(-1))
            loss_sum += loss_fn(logits, labels.view(-1)).item()
            num_batches += 1

            predictions.extend(torch.argmax(decoder_output, dim=-1).cpu().numpy().tolist())
            references.extend(labels.cpu().numpy().tolist())
    model.train()

    mean_loss = loss_sum / num_batches if num_batches else 0.0
    return mean_loss, predictions, references, num_batches


def train_and_evaluate(
    model,
    train_loader,
    eval_loader,
    optimizer,
    scheduler,
    loss_fn,
    num_epochs=1,
    max_steps=None,
    device='cuda',
    accumulation_steps=1,
    clear_cache=True,
    log_interval=10,
    eval_interval=20,
    save_interval=100,
    checkpoint_dir="checkpoint_dir",
    log_dir="log_dir",
):
    """Train *model* with mixed precision, periodic evaluation and checkpoints.

    Args:
        model: Module exposing ``encoder(input_features)`` and
            ``decoder(input_ids, encoded_audio)``.
        train_loader: Iterable of batches with ``'input_features'``,
            ``'input_ids'`` and ``'labels'`` tensors.
        eval_loader: Same batch format, used every ``eval_interval`` steps.
        optimizer: Optimizer stepped every ``accumulation_steps`` batches.
        scheduler: LR scheduler; stepped after every evaluation and at the
            end of every epoch.
        loss_fn: Called as ``loss_fn(flat_logits, flat_labels)``.
        num_epochs: Maximum number of passes over ``train_loader``.
        max_steps: Optional cap on processed batches across all epochs.
        device: Device the model and batches are moved to. Autocast and
            gradient scaling are only enabled for CUDA devices.
        accumulation_steps: Batches accumulated per optimizer step.
        clear_cache: Call ``torch.cuda.empty_cache()`` after optimizer steps.
        log_interval: Steps between TensorBoard scalar writes.
        eval_interval: Steps between evaluation passes.
        save_interval: Steps between checkpoint saves.
        checkpoint_dir: Directory for checkpoints (created if missing).
        log_dir: TensorBoard log directory.

    Side effects:
        Writes TensorBoard scalars, saves ``checkpoint_step_<n>.pt`` and
        ``final_model.pt`` state dicts, prints progress, and logs through the
        ``logging`` module. Evaluation uses the module-level ``tokenizer``
        and ``compute_metrics``.
    """
    device_type = torch.device(device).type
    use_amp = device_type == 'cuda'

    os.makedirs(checkpoint_dir, exist_ok=True)
    model.to(device)
    global_step = 0
    scaler = GradScaler(enabled=use_amp)
    writer = SummaryWriter(log_dir=log_dir)
    lr_warning_printed = False
    grad_norm = 0.0  # norm seen at the most recent optimizer step

    try:
        for epoch in range(num_epochs):
            if max_steps is not None and global_step >= max_steps:
                break

            model.train()
            total_loss = 0.0
            batches_seen = 0
            optimizer.zero_grad()
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")

            for step, batch in enumerate(progress_bar):
                if max_steps is not None and global_step >= max_steps:
                    break

                start_time = time.perf_counter()

                input_features = batch['input_features'].to(device)
                input_ids = batch['input_ids'].to(device)
                labels = batch['labels'].long().to(device)

                with torch.amp.autocast(device_type=device_type, enabled=use_amp):
                    encoded = model.encoder(input_features)
                    decoder_output = model.decoder(input_ids, encoded)
                    logits = decoder_output.view(-1, decoder_output.size(-1))
                    loss = loss_fn(logits, labels.view(-1))

                total_loss += loss.item()
                batches_seen += 1

                # Divide so the accumulated gradient is the mean over the
                # accumulation window.
                scaler.scale(loss / accumulation_steps).backward()

                if (step + 1) % accumulation_steps == 0:
                    # Unscale first so the logged norm is in real units, and
                    # measure it before zero_grad() wipes the gradients.
                    scaler.unscale_(optimizer)
                    grad_norm = _grad_l2_norm(model.parameters())

                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()

                    if clear_cache:
                        torch.cuda.empty_cache()

                global_step += 1
                elapsed = time.perf_counter() - start_time
                samples_per_sec = len(batch['input_features']) / elapsed

                if global_step % log_interval == 0:
                    writer.add_scalar('Loss/train', total_loss / (step + 1), global_step)
                    writer.add_scalar('GradientNorm', grad_norm, global_step)

                    lr = optimizer.param_groups[0].get('lr', None)
                    if lr is not None:
                        writer.add_scalar('LearningRate', lr, global_step)
                    elif not lr_warning_printed:
                        print(f"Warning: Learning rate is None at step {global_step}")
                        lr_warning_printed = True

                    writer.add_scalar('SamplesPerSec', samples_per_sec, global_step)

                if global_step % eval_interval == 0:
                    eval_loss, all_predictions, all_labels, num_eval_batches = _run_evaluation(
                        model, eval_loader, loss_fn, device
                    )

                    if num_eval_batches == 0:
                        # Nothing to average or score; skip metrics instead
                        # of dividing by zero.
                        print(f"Warning: evaluation loader is empty at step {global_step}")
                        logging.warning(f"Evaluation loader is empty at step {global_step}")
                    else:
                        predictions = {
                            "predictions": np.array(all_predictions, dtype="object"),
                            "label_ids": np.array(all_labels, dtype="object"),
                        }
                        metrics = compute_metrics(predictions)
                        writer.add_scalar('Loss/eval', eval_loss, global_step)
                        writer.add_scalar('CER', metrics['cer'], global_step)

                    # NOTE(review): the scheduler is stepped here after every
                    # evaluation AND once more at the end of each epoch.
                    # Kept as in the original; confirm this cadence is
                    # intended for the chosen scheduler.
                    scheduler.step()

                    if num_eval_batches:
                        if all_predictions:
                            # Show the first example as a quick sanity check.
                            pred_str = tokenizer.decode(all_predictions[0], skip_special_tokens=True)
                            label_str = tokenizer.decode(all_labels[0], skip_special_tokens=True)
                            print(f"Evaluation Sample 0: Prediction: {pred_str}, Label: {label_str}")
                            logging.info(f"Evaluation Sample 0: Prediction: {pred_str}, Label: {label_str}")

                        print(f"Evaluation Loss: {eval_loss:.4f}")
                        print(f"Character Error Rate (CER): {metrics['cer']:.4f}")

                if global_step % save_interval == 0:
                    checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_step_{global_step}.pt')
                    torch.save(model.state_dict(), checkpoint_path)
                    print(f"Model saved at step {global_step} to {checkpoint_path}")
                    logging.info(f"Model saved at step {global_step} to {checkpoint_path}")

            scheduler.step()

            # Average over the batches actually processed; max_steps may have
            # ended the epoch before the loader was exhausted.
            epoch_loss = total_loss / max(batches_seen, 1)
            print(f'Epoch {epoch + 1}, Loss: {epoch_loss}')
            logging.info(f'Epoch {epoch + 1}, Loss: {epoch_loss}')

        final_model_path = os.path.join(checkpoint_dir, 'final_model.pt')
        torch.save(model.state_dict(), final_model_path)
        print(f"Final model saved to {final_model_path}")
        logging.info(f"Final model saved to {final_model_path}")
    finally:
        # Flush and close the event file even if training raised.
        writer.close()


In [None]:

if __name__ == "__main__":
    # Script entry point: build tokenizer, data loaders, model, optimizer and
    # scheduler, then run a short training job.

    # `tokenizer` is a module-level name here; compute_metrics and
    # train_and_evaluate read it as a global.
    tokenizer = WhisperTokenizerFast.from_pretrained("openai/whisper-small")
    csv_file = 'D:/proj/datasets/gf_1/metadata.csv'
    audio_dir = 'D:/proj/datasets/gf_1/'

    def train_val_dataset(dataset, val_split=0.001):
        """Randomly split *dataset* into ``{'train': Subset, 'val': Subset}``.

        ``val_split`` is forwarded to ``train_test_split`` as ``test_size``.
        """
        train_idx, val_idx = train_test_split(list(range(len(dataset))), test_size=val_split)
        return {
            'train': Subset(dataset, train_idx),
            'val': Subset(dataset, val_idx),
        }

    dataset = audioDataset(csv_file, audio_dir, tokenizer)
    # Named `splits` rather than `datasets` so the imported `datasets`
    # package is not shadowed at module level.
    splits = train_val_dataset(dataset)
    train_dataset = splits['train']
    eval_dataset = splits['val']

    # Build the collator before the loader factories that reference it.
    collate_fn = WhisperDataCollatorWithPadding(tokenizer, n_fft=1024, hop_length=256, n_mels=80)

    def train_dataloader():
        """Return a shuffled, batch-size-1 loader over the training split."""
        return DataLoader(
            train_dataset,
            batch_size=1,
            drop_last=False,
            shuffle=True,
            num_workers=0,
            collate_fn=collate_fn,
        )

    def eval_dataloader():
        """Return an in-order, batch-size-1 loader over the validation split."""
        return DataLoader(
            eval_dataset,
            batch_size=1,
            drop_last=True,
            shuffle=False,
            num_workers=0,
            collate_fn=collate_fn,
        )

    train_loader = train_dataloader()
    eval_loader = eval_dataloader()

    config = WhisperConfig(
        n_mels=80,
        n_audio_ctx=1500,
        n_audio_state=1024,
        n_audio_head=16,
        n_audio_layer=24,
        n_vocab=51865,
        n_text_ctx=448,
        n_text_state=1024,
        n_text_head=16,
        n_text_layer=20,
        checkpointing=True,
    )

    model = Whisper(config).cuda()
    # model.resize_token_embeddings(len(tokenizer))
    optimizer = transformers.Adafactor(
        model.parameters(),
        clip_threshold=0.99,
        weight_decay=0.025,
        scale_parameter=True,
        relative_step=False,
        warmup_init=False,
        lr=2.25e-3,
    )

    scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)

    # -100 matches the label padding value produced by the collator.
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)

    train_and_evaluate(
        model,
        train_loader,
        eval_loader,
        optimizer,
        scheduler,
        loss_fn,
        max_steps=100,
        num_epochs=1,
        device='cuda',
        accumulation_steps=1,
        clear_cache=True,
        log_interval=1,
        eval_interval=10,
        save_interval=100,
        checkpoint_dir=checkpoint_dir,
        log_dir=log_dir,
    )
