In [None]:
import sys, os, random, math, time
import math
import logging, warnings, csv, base64, gzip
import json, datetime, numpy as np, torch, torch.nn as nn
import torch.optim as optim, torch.nn.functional as F, torchaudio
import torchaudio.transforms as transforms, torch.utils.checkpoint as checkpoint
import torch.utils.tensorboard as tensorboard, torch.optim.lr_scheduler as lr_scheduler
import transformers, neologdn, evaluate, MeCab, deepl, logging, datasets, tqdm, whisper
import transformers.utils.logging
from datasets import load_from_disk, load_dataset
from contextlib import contextmanager
from dataclasses import dataclass
from torch.utils.data import Subset
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from tqdm import tqdm
from torch.profiler import profile, ProfilerActivity, record_function
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from torch import amp, Tensor
from torch.optim import Adamax
import logging
from transformers import (
    logging,
    Trainer,
    TrainingArguments,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    PretrainedConfig,
    GenerationConfig,
    WhisperFeatureExtractor,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    WhisperTokenizerFast,
    WhisperTokenizer,
    WhisperModel,
    WhisperConfig,
    Adafactor,
    TrainerCallback,
    logging
)
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.amp import GradScaler, autocast
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
# Silence Python warnings process-wide: ignore them via the filter and also
# replace warnings.warn with a no-op for Python code that calls it through the
# module attribute.
# NOTE(review): this also hides deprecation warnings from torch/transformers.
warnings.filterwarnings(action="ignore")
warnings.warn = lambda *args, **kwargs: None

# Expose only GPU 0 to this process.
# NOTE(review): presumably only effective if CUDA has not been initialized yet.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Allow TF32 math for CUDA matmuls and cuDNN (PyTorch backend flags).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# NOTE(review): not referenced by the checkpoint.checkpoint calls in this part of the file.
checkpointing_args = {"reentrant": False}

# Probe for PyTorch's fused attention kernel; MultiHeadAttention consults
# SDPA_AVAILABLE and falls back to an explicit softmax path when it is False.
try:
    from torch.nn.functional import scaled_dot_product_attention
    SDPA_AVAILABLE = True
except (ImportError, RuntimeError, OSError):
    scaled_dot_product_attention = None
    SDPA_AVAILABLE = False

# openai-whisper entry points, aliased (not used in the visible part of the file).
from whisper.decoding import decode as decode_function
from whisper.decoding import detect_language as detect_language_function
from whisper.transcribe import transcribe as transcribe_function

# Default compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# MeCab tagger with "-Owakati" output (space-separated tokens); presumably used
# to segment Japanese text, e.g. for metrics — confirm at call sites.
mecab = MeCab.Tagger("-Owakati")

# Only show transformers log messages at ERROR level or above.
transformers.utils.logging.set_verbosity_error()


In [None]:
class LayerNorm(nn.LayerNorm):
    """nn.LayerNorm that always normalizes in float32.

    The result is cast back to the dtype of the input, so half-precision
    activations are normalized with full-precision arithmetic.
    """

    def forward(self, x) -> Tensor:
        input_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.type(input_dtype)

class Linear(nn.Linear):
    """nn.Linear with Xavier-uniform init, dtype-following parameters and output dropout.

    Args:
        *args, **kwargs: forwarded unchanged to ``nn.Linear``.
        dropout_rate: probability of the ``nn.Dropout`` applied to the
            projection output (default 0.5).
    """

    def __init__(self, *args, dropout_rate: float = 0.5, **kwargs):
        super().__init__(*args, **kwargs)
        # nn.Linear.__init__ has already called reset_parameters(); the second
        # call re-draws the weights and is kept so initialization is unchanged.
        self.reset_parameters()
        self.dropout = nn.Dropout(dropout_rate)

    def reset_parameters(self):
        """Xavier-uniform weight, zero bias (when a bias exists)."""
        nn.init.xavier_uniform_(self.weight)
        if self.bias is None:
            return
        nn.init.zeros_(self.bias)

    def forward(self, x) -> Tensor:
        # Cast parameters to the activation dtype so mixed-precision inputs work.
        target_dtype = x.dtype
        projected = F.linear(
            x,
            self.weight.to(target_dtype),
            self.bias.to(target_dtype) if self.bias is not None else None,
        )
        return self.dropout(projected)
    
class Conv1d(nn.Conv1d):
    """nn.Conv1d with Kaiming-uniform init whose parameters follow the input dtype."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform weight (ReLU gain), zero bias (when a bias exists)."""
        nn.init.kaiming_uniform_(self.weight, nonlinearity='relu')
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def _conv_forward(self, x, weight, bias) -> Tensor:
        # Use the weight/bias actually passed in (nn.Conv1d.forward passes
        # self.weight/self.bias, but other callers may pass different tensors)
        # and cast them to the activation dtype for mixed precision.
        weight = weight.to(x.dtype)
        bias = None if bias is None else bias.to(x.dtype)
        return super()._conv_forward(x, weight, bias)

class RotationLayer(nn.Module):
    """Learnable square matrix applied to the last dimension of its input.

    The matrix starts as the identity. ``reset_parameters`` re-initializes it
    orthogonally, but ``__init__`` does not call it.
    """

    def __init__(self, embed_dim):
        super().__init__()
        identity = torch.eye(embed_dim)
        self.rotation_matrix = nn.Parameter(identity)

    def forward(self, x):
        return x @ self.rotation_matrix

    def reset_parameters(self):
        nn.init.orthogonal_(self.rotation_matrix)

class RotaryEmbedding(nn.Module):
    """Rotary position embedding applied to feature pairs of the last dimension.

    Args:
        dim: size of the rotated last dimension (the per-head size here).
        base: base of the geometric frequency progression.
    """

    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def rotate_queries_or_keys(self, x, offset=0):
        """Rotate even/odd feature pairs of ``x`` by position-dependent angles.

        ``x`` is indexed as (batch, seq, heads, dim): positions come from axis 1
        and the angles broadcast over axes 0 and 2. The result is returned as
        ``cat([rotated_even, rotated_odd], dim=-1)``, i.e. de-interleaved
        relative to the input layout.

        Args:
            x: 4-D tensor whose axis 1 is the sequence axis.
            offset: absolute position of ``x[:, 0]``; pass the number of
                already-processed positions when decoding incrementally.
                Defaults to 0 (positions start at zero).
        """
        positions = torch.arange(
            offset, offset + x.shape[1], device=x.device, dtype=self.inv_freq.dtype
        )
        angles = torch.outer(positions, self.inv_freq)
        sin = angles.sin()[None, :, None, :]
        cos = angles.cos()[None, :, None, :]
        x1, x2 = x[..., ::2], x[..., 1::2]
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

class SinusoidalFeatures:
    """Fixed sinusoidal position table of shape (n_ctx, n_state).

    Calling an instance returns the precomputed table.
    """

    def __init__(self, n_ctx, n_state):
        self.n_ctx = n_ctx
        self.n_state = n_state
        self.features = self.sinusoidal_features(n_ctx, n_state)

    @staticmethod
    def sinusoidal_features(n_ctx, n_state):
        """Build the table: column 2i holds sin(pos * w_i), column 2i+1 holds cos(pos * w_i).

        The frequencies are w_i = 10000 ** (-2i / n_state). Odd ``n_state`` is supported.
        """
        position = torch.arange(0, n_ctx, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, n_state, 2).float() * -(math.log(10000.0) / n_state))
        features = torch.zeros(n_ctx, n_state)
        features[:, 0::2] = torch.sin(position * div_term)
        # With odd n_state there is one fewer odd column than frequencies.
        features[:, 1::2] = torch.cos(position * div_term[: n_state // 2])
        return features

    def __call__(self):
        return self.features

class LearnedSinusoidalEmbeddings(nn.Module):
    """Trainable position table initialized from sinusoids; lookups are L2-normalized.

    Args:
        n_ctx: number of positions in the table.
        n_state: embedding width.
        gradient_checkpointing: wrap the table lookup in torch.utils.checkpoint.
    """

    def __init__(self, n_ctx, n_state, gradient_checkpointing=False):
        super().__init__()
        self.n_ctx = n_ctx
        self.n_state = n_state
        self.gradient_checkpointing = gradient_checkpointing

        sinusoidal_embeddings = SinusoidalFeatures(n_ctx, n_state)()
        self.positional_embeddings = nn.Parameter(sinusoidal_embeddings)

    def forward(self, positions):
        """Return the unit-L2-norm embeddings of the integer indices ``positions``."""
        if self.gradient_checkpointing:
            # use_reentrant=False: the reentrant variant only propagates gradients
            # if some *input* requires grad. ``positions`` is an index tensor, so
            # the table would silently receive no gradient.
            position_embeddings = checkpoint.checkpoint(
                lambda idx: self.positional_embeddings[idx], positions, use_reentrant=False
            )
        else:
            position_embeddings = self.positional_embeddings[positions]

        return F.normalize(position_embeddings, p=2, dim=-1)


In [None]:
class BiasedCrossAttention(nn.Module):
    """Multi-head attention with a learned per-head query bias and a post-norm residual.

    Args:
        n_state: model width (divisible by ``n_head``).
        n_head: number of heads.
        dropout_rate: dropout on the attention weights.
        group_norm: normalize the output with ``nn.GroupNorm(16, n_state)`` over
            the feature axis instead of ``nn.LayerNorm(n_state)``.
    """

    def __init__(self, n_state, n_head, dropout_rate=0.1, group_norm=False):
        super().__init__()
        self.n_head = n_head
        self.n_state = n_state
        self.head_dim = n_state // n_head

        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)

        self.bias = nn.Parameter(torch.zeros(n_head, self.head_dim))
        self.dropout = nn.Dropout(dropout_rate)
        self.norm = nn.GroupNorm(num_groups=16, num_channels=n_state) if group_norm else nn.LayerNorm(n_state)

    def _split_heads(self, t):
        """(batch, seq, n_state) -> (batch, n_head, seq, head_dim)."""
        return t.view(t.shape[0], t.shape[1], self.n_head, self.head_dim).transpose(1, 2)

    def forward(self, q, k, v, mask = None):
        """Attend from ``q`` (batch, q_len, n_state) over ``k``/``v`` (batch, kv_len, n_state).

        Args:
            mask: optional additive mask; a 2-D mask is sliced to
                ``[:q_len, :kv_len]`` and broadcast over batch and heads.

        Returns:
            Tensor of shape (batch, q_len, n_state).
        """
        batch, q_len, _ = q.shape
        kv_len = k.shape[1]

        q_proj = self.query(q)
        qh = self._split_heads(q_proj)
        kh = self._split_heads(self.key(k))
        vh = self._split_heads(self.value(v))

        # The (n_head, head_dim) bias shifts each head's queries before the dot
        # product, i.e. a learned per-head preference over key content.
        qh = qh + self.bias[None, :, None, :]
        qk = (qh @ kh.transpose(-2, -1)) / self.head_dim ** 0.5  # (batch, heads, q_len, kv_len)
        if mask is not None:
            if mask.dim() == 2:
                mask = mask[:q_len, :kv_len]
            qk = qk + mask

        w = F.softmax(qk.float(), dim=-1).to(qh.dtype)
        w = self.dropout(w)

        out = (w @ vh).transpose(1, 2).reshape(batch, q_len, -1)
        # Residual uses the projected query.
        out = self.out(out) + q_proj
        if isinstance(self.norm, nn.GroupNorm):
            # nn.GroupNorm normalizes dim 1, so move features there and back.
            return self.norm(out.transpose(1, 2)).transpose(1, 2)
        return self.norm(out)

In [None]:
class AugmentedMemory(nn.Module):
    """Attention from the input sequence over a learned bank of memory slots.

    Args:
        n_state: model width (divisible by ``n_head``).
        n_head: number of heads.
        memory_size: number of learned memory slots.
        dropout_rate: dropout on the attention weights.
    """

    def __init__(self, n_state, n_head, memory_size, dropout_rate=0.1):
        super().__init__()
        self.n_head = n_head
        self.head_dim = n_state // n_head
        self.memory_size = memory_size

        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.memory = nn.Parameter(torch.randn(memory_size, n_state))
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Return, for each position of ``x`` (batch, seq, n_state), a mix of memory values.

        Output shape is (batch, seq, n_head * head_dim).
        """
        batch, seq_len, _ = x.shape
        q = self.query(x).view(batch, seq_len, self.n_head, self.head_dim).transpose(1, 2)  # (b, h, q, d)
        k = self.key(self.memory).view(self.memory_size, self.n_head, self.head_dim).transpose(1, 0)  # (h, n, d)
        v = self.value(self.memory).view(self.memory_size, self.n_head, self.head_dim).transpose(1, 0)  # (h, n, d)

        # Scores are (b, h, q, n): each query gets a distribution over memory slots.
        qk = torch.einsum('bhqd,hnd->bhqn', q, k) / self.head_dim ** 0.5
        w = F.softmax(qk, dim=-1)
        w = self.dropout(w)

        out = torch.einsum('bhqn,hnd->bhqd', w, v)
        return out.transpose(1, 2).reshape(batch, seq_len, self.n_head * self.head_dim)

In [None]:

class DynamicConvAttention(nn.Module):
    """Grouped 1-D convolution mixer over the sequence with a residual connection.

    No attention weights are computed: the input is convolved along the
    sequence axis, normalized, passed through dropout and an output projection,
    and added back to the input.

    Args:
        n_state: model width (channel count of the convolution).
        n_head: number of convolution groups.
        kernel_size: convolution width along the sequence.
        dropout_rate: dropout after normalization.
        group_norm: use ``nn.GroupNorm(16, n_state)`` instead of ``nn.LayerNorm``.
    """

    def __init__(self, n_state, n_head, kernel_size=3, dropout_rate=0.1, group_norm=False):
        super().__init__()

        self.n_state = n_state
        self.n_head = n_head
        self.kernel_size = kernel_size
        self.conv = nn.Conv1d(n_state, n_state, kernel_size, padding=kernel_size//2, groups=n_head)
        self.dropout = nn.Dropout(dropout_rate)

        # query/key/value are not used by forward(); they stay registered so
        # existing state_dicts keep loading.
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        # Registered before `norm` to keep the original module/parameter order.
        self.out = nn.Linear(n_state, n_state)

        self.norm = nn.GroupNorm(num_groups=16, num_channels=n_state) if group_norm else nn.LayerNorm(n_state)

    def forward(self, x):
        """Mix ``x`` (batch, seq, n_state) along the sequence; returns the same shape.

        Raises:
            ValueError: if the last dimension of ``x`` is not ``n_state``.
        """
        batch_size, seq_len, embed_dim = x.size()
        if embed_dim != self.n_state:
            raise ValueError(f"Expected embed_dim of {self.n_state}, but got {embed_dim}")

        # Conv1d wants (batch, channels, seq). Crop so even kernel sizes (which
        # pad to seq_len + 1) still line up with the residual.
        conv_out = self.conv(x.transpose(1, 2))[..., :seq_len]
        if isinstance(self.norm, nn.GroupNorm):
            # nn.GroupNorm normalizes dim 1 (channels), so apply it before transposing back.
            conv_out = self.norm(conv_out).transpose(1, 2)
        else:
            conv_out = self.norm(conv_out.transpose(1, 2))
        conv_out = self.dropout(conv_out)
        return self.out(conv_out) + x


In [None]:

class HybridAttention(nn.Module):
    """Sum of windowed local self-attention and full global self-attention.

    Args:
        n_state: model width.
        n_head: heads for both ``nn.MultiheadAttention`` modules.
        window_size: upper bound on the local window (see ``sliding_window_attention``).
        dropout_rate: dropout inside both attention modules and on the output.
        group_norm: use this file's ``GroupNorm`` instead of ``LayerNorm`` for the
            pre-attention norms.

    NOTE(review): forward takes no mask, so it cannot restrict positions to
    attend only to earlier ones (relevant if used in a causal decoder).
    """

    def __init__(self, n_state, n_head, window_size=1, dropout_rate=0.1, group_norm=False):
        super().__init__()
        self.local_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate)
        self.global_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate)
        self.group_norm = group_norm
        self.ln_local = self._make_norm(n_state, group_norm)
        self.ln_global = self._make_norm(n_state, group_norm)
        self.dropout = nn.Dropout(dropout_rate)
        self.window_size = window_size

    @staticmethod
    def _make_norm(n_state, group_norm):
        """Pre-attention norm: GroupNorm with 16 groups, or LayerNorm."""
        if group_norm:
            return GroupNorm(num_groups=16, num_channels=n_state)
        return LayerNorm(n_state)

    def forward(self, x: torch.Tensor):
        """Return ``dropout(local + global)`` for ``x`` laid out as (batch, seq, n_state)."""
        # The attention modules are sequence-first (batch_first is not set),
        # so swap the batch and sequence axes in and out.
        local_in = self.ln_local(x).permute(1, 0, 2)
        global_in = self.ln_global(x).permute(1, 0, 2)
        local_out = self.sliding_window_attention(local_in)
        global_out = self.global_attn(global_in, global_in, global_in)[0]
        return self.dropout((local_out + global_out).permute(1, 0, 2))

    def sliding_window_attention(self, x):
        """Blockwise local attention over sequence-first ``x`` (seq, batch, n_state).

        Queries are processed in chunks of ``win`` positions. Each chunk attends
        to itself plus the preceding ``win`` positions, where
        ``win = min(window_size, max(1, seq // 4))``.
        """
        seq_len = x.size(0)
        win = min(self.window_size, max(1, seq_len // 4))
        output = torch.zeros_like(x, device=x.device, dtype=x.dtype)

        for q_start in range(0, seq_len, win):
            q_stop = min(q_start + win, seq_len)
            ctx_start = max(0, q_start - win)
            queries = x[q_start:q_stop, :, :]
            keys = x[ctx_start:q_stop, :, :]
            values = x[ctx_start:q_stop, :, :]
            attended, _ = self.local_attn(queries, keys, values)
            output[q_start:q_stop, :, :] = attended[:q_stop - q_start, :, :]

        return output


In [None]:

class GroupNorm(nn.Module):
    """Channel-last GroupNorm followed by a norm rescale and per-channel affine.

    Input is 3-D with channels last, (batch, length, num_channels), as implied
    by the permute before ``nn.GroupNorm``. Output has the same shape.

    NOTE(review): both ``nn.GroupNorm`` and the L2 rescale below take
    statistics across the length axis, so each output position depends on the
    whole sequence.
    """

    def __init__(self, num_groups, num_channels, eps=1e-6):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps
        self.g = nn.Parameter(torch.ones(num_channels))
        self.b = nn.Parameter(torch.zeros(num_channels))
        self.scale = nn.Parameter(torch.ones(num_channels))

        self.group_norm = nn.GroupNorm(num_groups, num_channels, eps=eps)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (batch, num_channels, length) for nn.GroupNorm
        x = self.group_norm(x)
        # Scale each channel to unit L2 norm over the length axis. Clamping by eps
        # keeps an all-zero channel at 0 instead of producing NaN (0 / 0).
        norm_x = torch.norm(x, dim=-1, keepdim=True).clamp_min(self.eps)
        rms_x = x / norm_x
        scaled_x = rms_x * self.scale.view(1, -1, 1)
        x = self.g.view(1, -1, 1) * scaled_x + self.b.view(1, -1, 1)
        return x.permute(0, 2, 1)


In [None]:
class MultiHeadAttention(nn.Module):
    """Pre-norm multi-head attention with rotary and learned rotations on q/k.

    Uses the fused ``scaled_dot_product_attention`` when available and the
    class attribute ``sdpa`` is True; otherwise an explicit softmax path. Both
    paths use a 1/sqrt(head_dim) softmax temperature and the module's dropout.
    """
    sdpa = True

    def __init__(self, n_state, n_head, dropout_rate=0.01, gradient_checkpointing=False, group_norm=False, use_rotation_dynamics=False, use_dynamic_attention_integration=False, hybrid_attention=False, dynamic_conv=False, biased_attention=False, augmented_memory=False):
        super().__init__()
        self.n_head = n_head
        self.n_state = n_state
        self.head_dim = n_state // n_head

        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

        # Flags stored for introspection; forward() does not branch on them.
        self.use_rotation_dynamics = use_rotation_dynamics
        self.use_dynamic_attention_integration = use_dynamic_attention_integration
        self.hybrid_attention = hybrid_attention
        self.dynamic_conv = dynamic_conv
        self.biased_attention = biased_attention
        self.augmented_memory = augmented_memory
        self.gradient_checkpointing = gradient_checkpointing
        self.group_norm = group_norm

        self.rotary_emb = RotaryEmbedding(dim=self.head_dim)
        self.rotation_layer = RotationLayer(self.head_dim)
        self.dropout = nn.Dropout(dropout_rate)

        self.attn_ln = GroupNorm(num_groups=16, num_channels=n_state) if group_norm else nn.LayerNorm(n_state)

    def forward(self, x, xa = None, mask = None, kv_cache = None):
        """Attend from ``x`` to itself, or to ``xa`` when given.

        Args:
            x: query sequence, normalized by ``attn_ln`` before projection.
            xa: optional key/value source (not normalized).
            mask: additive mask for the explicit path; on the SDPA path only
                its presence (causal flag) is used.
            kv_cache: optional dict keyed by the key/value modules; when it
                already holds this module's key, cached k/v replace the fresh ones.

        Returns:
            ``(self.out(attention) + x, qk)`` where ``qk`` are the detached
            pre-softmax scores from the explicit path, or None on the SDPA path.
        """
        x_norm = self.attn_ln(x)

        q = self.query(x_norm)
        k = self.key(x_norm if xa is None else xa)
        v = self.value(x_norm if xa is None else xa)

        if kv_cache is not None and self.key in kv_cache:
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        # Rotate per head: (batch, seq, n_head, head_dim), then merge heads again.
        q = q.view(q.shape[0], q.shape[1], self.n_head, -1)
        k = k.view(k.shape[0], k.shape[1], self.n_head, -1)

        q = self.rotary_emb.rotate_queries_or_keys(q)
        k = self.rotary_emb.rotate_queries_or_keys(k)
        q = self.rotation_layer(q)
        k = self.rotation_layer(k)

        q = q.view(q.shape[0], q.shape[1], -1)
        k = k.view(k.shape[0], k.shape[1], -1)

        wv, qk = self.qkv_attention(q, k, v, mask)

        return self.out(wv) + x, qk

    def qkv_attention(self, q, k, v, mask = None):
        """Scaled dot-product attention over (batch, seq, n_state) inputs split into heads."""
        n_batch, n_ctx, n_state = q.shape
        # Split the 1/sqrt(head_dim) temperature evenly between q and k so the
        # explicit path matches SDPA's default scaling.
        scale = self.head_dim ** -0.25

        q = q.view(n_batch, n_ctx, self.n_head, self.head_dim).permute(0, 2, 1, 3)
        k = k.view(n_batch, k.shape[1], self.n_head, self.head_dim).permute(0, 2, 1, 3)
        v = v.view(n_batch, v.shape[1], self.n_head, self.head_dim).permute(0, 2, 1, 3)

        if SDPA_AVAILABLE and MultiHeadAttention.sdpa:
            a = scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=mask is not None and n_ctx > 1,
            )
            out = a.permute(0, 2, 1, 3).reshape(n_batch, n_ctx, n_state)
            qk = None
        else:
            qk = (q * scale) @ (k.transpose(-2, -1) * scale)
            if mask is not None:
                qk += mask[:n_ctx, :n_ctx]
            qk = qk.float()

            w = F.softmax(qk, dim=-1).to(q.dtype)
            w = self.dropout(w)
            out = (w @ v).permute(0, 2, 1, 3).reshape(n_batch, n_ctx, n_state)
            qk = qk.detach()

        return out, qk


In [None]:
class ResidualAttentionBlock(nn.Module):
    """Pre-norm block: a set of self-attention variants, optional cross-attention, then an MLP.

    ``attention_mechanisms`` always starts with a MultiHeadAttention. The
    constructor flags append HybridAttention, DynamicConvAttention,
    BiasedCrossAttention and AugmentedMemory.
    """

    def __init__(self, n_state, n_head, dynamic_conv = False, hybrid_attention = False,
                 biased_attention = False, augmented_memory = False,
                 cross_attention = False, dropout_rate=0.1, gradient_checkpointing=False, group_norm=False, 
                 window_size=1, use_rotation_dynamics=False, use_dynamic_attention_integration=False):
        super().__init__()
        # key/value feed BiasedCrossAttention; query/out are registered but not used by forward().
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

        self.use_rotation_dynamics = use_rotation_dynamics
        self.use_dynamic_attention_integration = use_dynamic_attention_integration
        self.hybrid_attention = hybrid_attention
        self.dynamic_conv = dynamic_conv
        self.biased_attention = biased_attention
        self.augmented_memory = augmented_memory

        self.attention_mechanisms = nn.ModuleList([MultiHeadAttention(n_state, n_head, dropout_rate=dropout_rate, gradient_checkpointing=gradient_checkpointing, group_norm=group_norm)])
        
        if self.hybrid_attention:
            self.attention_mechanisms.append(HybridAttention(n_state, n_head, window_size=window_size, dropout_rate=dropout_rate, group_norm=group_norm))
        if self.dynamic_conv:
            self.attention_mechanisms.append(DynamicConvAttention(n_state, n_head, dropout_rate=dropout_rate))
        if self.biased_attention:
            self.attention_mechanisms.append(BiasedCrossAttention(n_state, n_head, dropout_rate=dropout_rate))
        if self.augmented_memory:
            self.attention_mechanisms.append(AugmentedMemory(n_state, memory_size=512, n_head=n_head, dropout_rate=dropout_rate))

        # One learnable mixing logit per mechanism (used with dynamic integration).
        self.attn_scores = nn.Parameter(torch.randn(len(self.attention_mechanisms)))
        self.attn_ln = GroupNorm(num_groups=16, num_channels=n_state) if group_norm else nn.LayerNorm(n_state)
        self.cross_attention = cross_attention
        if self.cross_attention:
            self.cross_attn = MultiHeadAttention(n_state, n_head, dropout_rate=dropout_rate, gradient_checkpointing=gradient_checkpointing, group_norm=group_norm)
            self.cross_attn_ln = GroupNorm(num_groups=16, num_channels=n_state) if group_norm else nn.LayerNorm(n_state)

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp),
            GroupNorm(num_groups=16, num_channels=n_mlp) if group_norm else nn.LayerNorm(n_mlp),
            nn.GELU(),
            nn.Dropout(p=dropout_rate),
            Linear(n_mlp, n_state)
        )
        self.mlp_ln = GroupNorm(num_groups=16, num_channels=n_state) if group_norm else nn.LayerNorm(n_state)
        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, x, xa = None, mask = None, kv_cache = None, 
                biased_attention=False, augmented_memory=False):
        """Run the block on ``x`` and return a tensor shaped like ``x``.

        Args:
            x: input sequence (presumably (batch, seq, n_state) — confirm with callers).
            xa: optional cross-attention source (e.g. encoder output). Also the
                key/value source for BiasedCrossAttention.
            mask: additive self-attention mask. TextDecoder passes its causal
                mask here and AudioEncoder passes none. It is forwarded to the
                MultiHeadAttention self-attention and, for self-attention only,
                to BiasedCrossAttention.
            kv_cache: forwarded to the cross-attention module.
            biased_attention, augmented_memory: run BiasedCrossAttention /
                AugmentedMemory only when True. Both default to False,
                independently of the constructor flags.

        Raises:
            ValueError: if no attention mechanism is active.
        """
        global printed_attention_mechanisms_once
        # The flag may not exist at module level yet; a missing flag means "not printed".
        if not globals().get("printed_attention_mechanisms_once", False):
            print("Attention mechanisms being used:")
            for attn in self.attention_mechanisms:
                print(f"- {type(attn).__name__}")
            printed_attention_mechanisms_once = True

        attn_input = self.attn_ln(x)

        # Collect the mechanisms that run on this call, keeping their index in
        # attention_mechanisms so the integration logits stay aligned.
        active = []
        for idx, attn in enumerate(self.attention_mechanisms):
            if isinstance(attn, BiasedCrossAttention) and not biased_attention:
                continue
            if isinstance(attn, AugmentedMemory) and not augmented_memory:
                continue
            active.append((idx, attn))

        if not active:
            raise ValueError("No attention mechanisms are enabled. Ensure at least one attention mechanism is active.")
        if not self.use_dynamic_attention_integration:
            # Only the first active output is used below, so skip computing the rest.
            active = active[:1]

        k = v = None
        if any(isinstance(attn, BiasedCrossAttention) for _, attn in active):
            kv_source = xa if xa is not None else x
            k = self.key(kv_source)
            v = self.value(kv_source)

        attn_outputs = []
        for _, attn in active:
            if isinstance(attn, BiasedCrossAttention):
                # A causal mask only makes sense when attending over x itself.
                args = (attn_input, k, v, mask if xa is None else None)
            elif isinstance(attn, MultiHeadAttention):
                # Pass the mask so decoder positions cannot attend to later ones.
                args = (attn_input, None, mask)
            else:
                # NOTE(review): the other mechanisms take no mask.
                args = (attn_input,)

            if self.gradient_checkpointing:
                output = checkpoint.checkpoint(attn, *args, use_reentrant=False)
            else:
                output = attn(*args)
            # MultiHeadAttention returns (out, qk); the others return a tensor.
            # NOTE(review): MultiHeadAttention already adds its input as a residual,
            # so the self-attention path carries a second skip via `x +` below.
            attn_outputs.append(output if isinstance(output, Tensor) else output[0])

        if self.use_dynamic_attention_integration:
            score_idx = torch.tensor([idx for idx, _ in active], device=self.attn_scores.device)
            weights = F.softmax(self.attn_scores[score_idx], dim=0)
            stacked = torch.stack(attn_outputs, dim=0)
            attn_out = x + torch.einsum('i,ijkm->jkm', weights, stacked)
        else:
            attn_out = x + attn_outputs[0]

        if self.cross_attention and xa is not None:
            cross_attn_input = self.cross_attn_ln(attn_out)
            if self.gradient_checkpointing:
                # Positional order is (x, xa, mask, kv_cache); no mask for cross-attention.
                cross_out = checkpoint.checkpoint(
                    self.cross_attn, cross_attn_input, xa, None, kv_cache, use_reentrant=False
                )[0]
            else:
                cross_out = self.cross_attn(cross_attn_input, xa, kv_cache=kv_cache)[0]
            attn_out = attn_out + cross_out

        mlp_input = self.mlp_ln(attn_out)
        if self.gradient_checkpointing:
            mlp_hidden = checkpoint.checkpoint(self.mlp, mlp_input, use_reentrant=False)
        else:
            mlp_hidden = self.mlp(mlp_input)

        return attn_out + mlp_hidden


In [None]:
class AudioEncoder(nn.Module):
    """Two-layer convolutional front-end followed by a stack of ResidualAttentionBlocks.

    Input is (batch, n_mels, frames). conv2 has stride 2, so the output is
    (batch, ~frames / 2, n_state) after a final norm.

    NOTE(review): ``n_ctx`` is accepted but not used. ``query``/``key``/``value``/``out``
    are registered but not used by ``forward``.
    """
    main_input_name = "input_features"

    def __init__(self, window_size, n_mels, n_ctx, n_state, n_head, n_layer,
                 dropout_rate=0.1, gradient_checkpointing=False, dynamic_conv = False, hybrid_attention = False, biased_attention = False, augmented_memory = False, group_norm=False,
                 use_rotation_dynamics=False, use_dynamic_attention_integration=False):
        super().__init__()

        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.dropout = nn.Dropout(dropout_rate)
        self.gradient_checkpointing = gradient_checkpointing
        self.group_norm = group_norm

        block_options = dict(
            window_size=window_size,
            hybrid_attention=hybrid_attention,
            dynamic_conv=dynamic_conv,
            biased_attention=biased_attention,
            augmented_memory=augmented_memory,
            dropout_rate=dropout_rate,
            gradient_checkpointing=gradient_checkpointing,
            group_norm=group_norm,
            use_rotation_dynamics=use_rotation_dynamics,
            use_dynamic_attention_integration=use_dynamic_attention_integration,
        )
        self.blocks = nn.ModuleList(
            ResidualAttentionBlock(n_state, n_head, **block_options) for _ in range(n_layer)
        )

        self.ln_post = (
            GroupNorm(num_groups=16, num_channels=n_state)
            if group_norm
            else LayerNorm(n_state)
        )

    def forward(self, x: torch.Tensor):
        # conv -> GELU -> dropout, twice.
        for conv in (self.conv1, self.conv2):
            x = self.dropout(F.gelu(conv(x)))
        # (batch, n_state, frames) -> (batch, frames, n_state) for the blocks.
        x = x.permute(0, 2, 1)

        for block in self.blocks:
            x = checkpoint.checkpoint(block, x) if self.gradient_checkpointing else block(x)

        return self.ln_post(x)


In [None]:
class TextDecoder(nn.Module):
    """Token decoder with learned positions, causal blocks with cross-attention, and tied output weights.

    Logits are computed against ``token_embedding.weight``, so input and output
    embeddings are shared.
    """

    def __init__(self, n_vocab, n_ctx, n_state, n_head, n_layer,
                 dropout_rate=0.1, gradient_checkpointing=False, dynamic_conv=False, hybrid_attention=False, biased_attention=False, augmented_memory=False, group_norm=False,
                 use_rotation_dynamics=False, use_dynamic_attention_integration=False):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = LearnedSinusoidalEmbeddings(n_ctx, n_state, gradient_checkpointing=gradient_checkpointing)
        self.gradient_checkpointing = gradient_checkpointing
        self.group_norm = group_norm
        self.n_head = n_head

        self.rotary_emb = RotaryEmbedding(dim=n_state // n_head)
        self.rotation_layer = RotationLayer(n_state // n_head)

        self.blocks = nn.ModuleList([
            ResidualAttentionBlock(
                n_state, n_head,
                hybrid_attention=hybrid_attention,
                dynamic_conv=dynamic_conv,
                biased_attention=biased_attention,
                augmented_memory=augmented_memory,
                cross_attention=True,
                dropout_rate=dropout_rate,
                gradient_checkpointing=gradient_checkpointing,
                group_norm=group_norm,
                use_rotation_dynamics=use_rotation_dynamics,
                use_dynamic_attention_integration=use_dynamic_attention_integration
            )
            for _ in range(n_layer)
        ])

        if self.group_norm:
            self.ln_post = GroupNorm(num_groups=16, num_channels=n_state)
        else:
            self.ln_post = LayerNorm(n_state)

        # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer('mask', mask, persistent=False)

    def forward(self, x, xa, kv_cache=None):
        """Decode token ids ``x`` (batch, seq) against encoder output ``xa``.

        Args:
            x: token ids.
            xa: encoder output; hidden states are cast to its dtype.
            kv_cache: optional dict. When non-empty, dim 1 of its first value is
                used as the position offset for the positional embedding.

        Returns:
            float32 logits of shape (batch, seq, n_vocab).
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        positions = torch.arange(x.shape[1], device=x.device) + offset
        pos_emb = self.positional_embedding(positions).unsqueeze(0)

        x = self.token_embedding(x) + pos_emb
        x = x.to(xa.dtype)

        # Apply rotary + learned rotation per head, then merge the heads back.
        batch_size, seq_length, embedding_dim = x.shape
        head_dim = embedding_dim // self.n_head
        x = x.view(batch_size, seq_length, self.n_head, head_dim)
        x = self.rotary_emb.rotate_queries_or_keys(x)
        x = self.rotation_layer(x)
        x = x.view(batch_size, seq_length, embedding_dim)

        for block in self.blocks:
            if self.gradient_checkpointing:
                x = checkpoint.checkpoint(block, x, xa, self.mask, kv_cache)
            else:
                x = block(x, xa, self.mask, kv_cache)

        x = self.ln_post(x)
        logits = (x @ self.token_embedding.weight.to(x.dtype).T).float()

        if kv_cache is not None:
            # ``kv_cache_layers`` is never created in this class. Merge it only if
            # something else has set it, instead of raising AttributeError.
            cached_layers = getattr(self, "kv_cache_layers", None)
            if cached_layers:
                for layer, cache in cached_layers.items():
                    kv_cache[layer] = cache

        return logits


In [None]:
class WhisperConfig(PretrainedConfig):
    """Config for the custom Whisper model in this file.

    Every field below can be overridden by keyword. OpenAI-Whisper style
    aliases (``n_vocab``, ``n_audio_state``, ...) are derived from them.
    Additional keyword arguments, such as the attention-variant flags read by
    ``Whisper``, are passed to ``PretrainedConfig``, which stores leftover
    kwargs as attributes.
    """
    model_type = "whisper"
    keys_to_ignore_at_inference = []
    attribute_map = {}

    def __init__(self, **kwargs):
        # Whisper defaults for fields that PretrainedConfig.__init__ manages itself.
        # They are merged into kwargs and passed in a single super().__init__
        # call. This way a caller-supplied value (e.g. from from_dict /
        # from_pretrained) is never passed twice, and nothing set below is
        # reset by a later base-class initialization.
        base_defaults = {
            "architectures": ["WhisperForConditionalGeneration"],
            "begin_suppress_tokens": [220, 50257],
            "bos_token_id": 50257,
            "decoder_start_token_id": 50258,
            "eos_token_id": 50257,
            "forced_decoder_ids": [[1, 50259], [2, 50359], [3, 50363]],
            "is_encoder_decoder": True,
            "max_length": 448,
            "pad_token_id": 50257,
            "suppress_tokens": [
                1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873,
                893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585,
                6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553,
                16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865,
                42863, 47425, 49870, 50254, 50258, 50358, 50359, 50360, 50361, 50362
            ],
        }
        for key, value in base_defaults.items():
            kwargs.setdefault(key, value)
        super().__init__(**kwargs)

        # Model hyper-parameters.
        self.activation_dropout = kwargs.get("activation_dropout", 0.0)
        self.activation_function = kwargs.get("activation_function", "gelu")
        self.attention_dropout = kwargs.get("attention_dropout", 0.0)
        self.d_model = kwargs.get("d_model", 1024)
        self.decoder_attention_heads = kwargs.get("decoder_attention_heads", 16)
        self.decoder_ffn_dim = kwargs.get("decoder_ffn_dim", 4096)
        self.decoder_layerdrop = kwargs.get("decoder_layerdrop", 0.0)
        self.decoder_layers = kwargs.get("decoder_layers", 24)
        self.dropout = kwargs.get("dropout", 0.0)
        self.encoder_attention_heads = kwargs.get("encoder_attention_heads", 16)
        self.encoder_ffn_dim = kwargs.get("encoder_ffn_dim", 4096)
        self.encoder_layerdrop = kwargs.get("encoder_layerdrop", 0.0)
        self.encoder_layers = kwargs.get("encoder_layers", 24)
        self.init_std = kwargs.get("init_std", 0.02)
        self.max_source_positions = kwargs.get("max_source_positions", 1500)
        self.max_target_positions = kwargs.get("max_target_positions", 448)
        self.num_hidden_layers = kwargs.get("num_hidden_layers", 24)
        self.num_mel_bins = kwargs.get("num_mel_bins", 80)
        self.scale_embedding = kwargs.get("scale_embedding", False)
        self.cache = kwargs.get("cache", True)
        self.vocab_size = kwargs.get("vocab_size", 51865)

        # Compatibility mappings (OpenAI Whisper naming).
        self.n_vocab = self.vocab_size
        self.n_mels = self.num_mel_bins
        self.n_audio_ctx = self.max_source_positions
        self.n_audio_state = self.d_model
        self.n_audio_head = self.encoder_attention_heads
        self.n_audio_layer = self.encoder_layers
        self.n_text_ctx = self.max_target_positions
        self.n_text_state = self.d_model
        self.n_text_head = self.decoder_attention_heads
        self.n_text_layer = self.decoder_layers
        self.dropout_rate = self.dropout


In [None]:
class Whisper(nn.Module):
    """Speech-to-text encoder-decoder assembled from ``AudioEncoder`` and ``TextDecoder``.

    Exposes a HF-``Trainer``-friendly surface (``forward`` returning a dict with
    ``loss``/``logits``, ``generate``, ``from_pretrained``/``save_pretrained``)
    alongside helpers carried over from the OpenAI Whisper API (alignment-head
    mask, kv-cache hooks, ``detect_language``/``transcribe``/``decode``).

    Args:
        config: configuration exposing the ``n_*`` compatibility attributes
            (``n_mels``, ``n_audio_ctx``, ``n_text_layer``, ...), ``dropout_rate``,
            ``pad_token_id``/``decoder_start_token_id`` and the feature flags
            consumed by the encoder and decoder.
    """

    def __init__(self, config: WhisperConfig):
        super().__init__()
        self.config = config
        self.n_vocab = config.n_vocab
        self.generation_config = GenerationConfig()
        # Arguments are positional: their order must match AudioEncoder.__init__.
        self.encoder = AudioEncoder(
            self.config.window_size,
            self.config.n_mels,
            self.config.n_audio_ctx,
            self.config.n_audio_state,
            self.config.n_audio_head,
            self.config.n_audio_layer,
            self.config.dropout_rate,
            self.config.gradient_checkpointing,
            self.config.dynamic_conv,
            self.config.hybrid_attention,
            self.config.biased_attention,
            self.config.augmented_memory,
            self.config.group_norm,
            self.config.use_rotation_dynamics,
            self.config.use_dynamic_attention_integration,
        )
        # Arguments are positional: their order must match TextDecoder.__init__.
        self.decoder = TextDecoder(
            self.config.n_vocab,
            self.config.n_text_ctx,
            self.config.n_text_state,
            self.config.n_text_head,
            self.config.n_text_layer,
            self.config.dropout_rate,
            self.config.dynamic_conv,
            self.config.hybrid_attention,
            self.config.biased_attention,
            self.config.augmented_memory,
            self.config.group_norm,
            self.config.use_rotation_dynamics,
            self.config.use_dynamic_attention_integration,
        )

        # Default alignment-head mask: every head in the upper half of the decoder
        # layers. set_alignment_heads() replaces it.
        all_heads = torch.zeros(
            self.config.n_text_layer, self.config.n_text_head, dtype=torch.bool
        )
        all_heads[self.config.n_text_layer // 2:] = True
        self.register_buffer('alignment_heads', all_heads.to_sparse(), persistent=False)

    @staticmethod
    def shift_tokens_right(input_ids: torch.Tensor, pad_token_id, decoder_start_token_id) -> torch.Tensor:
        """Build decoder inputs by shifting ``input_ids`` one position to the right.

        Column 0 becomes ``decoder_start_token_id``, the last column of
        ``input_ids`` is dropped, and ``-100`` (ignored-label) entries are replaced
        with ``pad_token_id``. The input tensor is not modified.

        Raises:
            ValueError: if ``pad_token_id`` is ``None``.
        """
        # Validate before allocating anything.
        if pad_token_id is None:
            raise ValueError("pad_token_id has to be defined.")
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = decoder_start_token_id
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
        return shifted_input_ids

    def forward(self, input_features: torch.Tensor, labels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """Encode audio features and, when labels are given, decode with teacher forcing.

        Args:
            input_features: batch passed straight to the encoder.
            labels: target token ids with ``-100`` at positions to ignore. They
                are shifted right to form the decoder input and used as the
                cross-entropy targets.

        Returns:
            dict with ``loss`` (``None`` without labels), ``logits`` (the decoder
            output, in the shape the decoder produced) and ``input_features``
            (the encoder output).
        """
        encoded_features = self.encoder(input_features)

        decoder_input_ids = None
        if labels is not None:
            decoder_input_ids = self.shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )
        # NOTE(review): with labels=None the decoder receives None as its token
        # input; whether TextDecoder supports that is not visible here.
        logits = self.decoder(decoder_input_ids, encoded_features)

        loss = None
        if labels is not None:
            # Flatten only for the loss so the returned logits keep their shape
            # (consumers argmax over the last axis). reshape() also tolerates a
            # non-contiguous decoder output, unlike view().
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1).long(),
                ignore_index=-100,
            )
        return {"loss": loss, "logits": logits, "input_features": encoded_features}

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        # NOTE(review): reads the module-level `tokenizer` global, which must
        # exist by the time this property is accessed.
        return self.config.n_vocab >= len(tokenizer)

    @property
    def num_languages(self):
        # NOTE(review): also depends on the module-level `tokenizer` global.
        return self.config.n_vocab - (len(tokenizer)-100)

    def set_alignment_heads(self, dump: bytes):
        """Replace the alignment-head mask with one decoded from ``dump``.

        ``dump`` is base85 text wrapping gzip-compressed booleans. They are
        reshaped to ``(n_text_layer, n_text_head)`` and stored as a sparse,
        non-persistent buffer.
        """
        array = np.frombuffer(
            gzip.decompress(base64.b85decode(dump)), dtype=bool
        ).copy()
        mask = torch.from_numpy(array).reshape(
            self.config.n_text_layer, self.config.n_text_head
        )
        self.register_buffer('alignment_heads', mask.to_sparse(), persistent=False)

    def embed_aud(self, input_features: torch.Tensor):
        """Run only the audio encoder."""
        return self.encoder(input_features)

    def logits(self, input_ids: torch.Tensor, input_features: torch.Tensor):
        """Run only the text decoder on ``input_ids`` given encoder output ``input_features``."""
        return self.decoder(input_ids, input_features)

    def install_kv_cache_hooks(self, cache = None):
        """Attach forward hooks that accumulate key/value projections of decoder attention.

        Args:
            cache: optional existing mapping of module -> tensor. It is copied,
                not mutated.

        Returns:
            ``(cache, hooks)``: the dict the hooks write into, and the hook
            handles. Call ``.remove()`` on each handle to detach them.
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            # First call for this module (or an output longer than the text
            # context): store as-is. Otherwise append along dim 1.
            if module not in cache or output.shape[1] > self.config.n_text_ctx:
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            # A forward hook's non-None return value replaces the module output.
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    def set_input_embeddings(self, new_embeddings: torch.nn.Embedding):
        """Install ``new_embeddings`` as the decoder token embedding."""
        self.decoder.token_embedding = new_embeddings

    def get_input_embeddings(self):
        """Return the decoder token embedding."""
        return self.decoder.token_embedding

    def resize_token_embeddings(self, new_num_tokens):
        """Replace the decoder token embedding with one of ``new_num_tokens`` rows.

        Rows shared by the old and new tables are copied, so both growing and
        shrinking work. Added rows keep ``nn.Embedding``'s default init. The new
        table is created on the old table's device and dtype.
        ``config.n_vocab`` and ``self.n_vocab`` are updated.

        NOTE(review): only the input embedding is replaced. Whether TextDecoder's
        output projection reads this embedding's weight is not visible here;
        confirm the logits layer follows the resize.
        """
        old_embeddings = self.get_input_embeddings()
        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
        new_embeddings = torch.nn.Embedding(
            new_num_tokens,
            old_embedding_dim,
            device=old_embeddings.weight.device,
            dtype=old_embeddings.weight.dtype,
        )
        n_copy = min(old_num_tokens, new_num_tokens)
        with torch.no_grad():
            new_embeddings.weight[:n_copy] = old_embeddings.weight[:n_copy]
        self.set_input_embeddings(new_embeddings)
        self.config.n_vocab = new_num_tokens
        self.n_vocab = new_num_tokens

    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, ignore_mismatched_sizes=False, **kwargs):
        """Build a model from a local directory holding a config and ``pytorch_model.bin``.

        Args:
            pretrained_model_name_or_path: directory written by ``save_pretrained``.
            *model_args: accepted for API compatibility; unused.
            ignore_mismatched_sizes: if True, checkpoint tensors whose shape
                differs from the freshly built model are replaced by the fresh
                values, and loading is non-strict.
            **kwargs: forwarded to ``WhisperConfig.from_pretrained``.
        """
        config = WhisperConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        # __init__ accepts only the config; config overrides were already
        # applied by WhisperConfig.from_pretrained above.
        model = cls(config)
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints (or consider weights_only=True).
        state_dict = torch.load(os.path.join(pretrained_model_name_or_path, "pytorch_model.bin"), map_location="cpu")

        if ignore_mismatched_sizes:
            model_state_dict = model.state_dict()
            for key, tensor in list(state_dict.items()):
                if key in model_state_dict and tensor.size() != model_state_dict[key].size():
                    print(f"Skipping loading of {key} due to size mismatch")
                    state_dict[key] = model_state_dict[key]

        model.load_state_dict(state_dict, strict=not ignore_mismatched_sizes)
        return model

    def get_encoder(self):
        """Return the audio encoder."""
        return self.encoder

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # The generation input is forwarded unchanged as `input_features`.
        return {'input_features': input_ids}

    def _prepare_decoder_input_ids_for_generation(self, batch_size, decoder_start_token_id=None, bos_token_id=None):
        """Return a ``(batch_size, 1)`` tensor of ``config.decoder_start_token_id``.

        The ``decoder_start_token_id``/``bos_token_id`` arguments are ignored.
        """
        return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * self.config.decoder_start_token_id

    def can_generate(self):
        return True

    def generate(self, inputs, **kwargs):
        """Single decoder step from the start token; returns argmax ids over the last dim.

        The decoder prompt is ``config.decoder_start_token_id``, the same token
        ``shift_tokens_right`` places at position 0 during training. ``kwargs``
        are accepted for API compatibility and ignored.
        """
        encoder_outputs = self.encoder(inputs)
        decoder_input_ids = torch.full(
            (inputs.size(0), 1),
            self.config.decoder_start_token_id,
            dtype=torch.long,
            device=inputs.device,
        )
        outputs = self.decoder(decoder_input_ids, encoder_outputs)
        return outputs.argmax(dim=-1)

    def save_pretrained(self, save_directory):
        """Write the config and the state dict (``pytorch_model.bin``) to ``save_directory``."""
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))


In [None]:
# NOTE(review): this rebinds `datetime` from the module imported at the top of
# the file to the class; any later code expecting the module sees the class.
from datetime import datetime

# Timestamped directory; used below for TensorBoard and as the Trainer output_dir.
log_dir = os.path.join('./output/logs', datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
os.makedirs(log_dir, exist_ok=True)

# Feature extractor, tokenizer and processor all come from one local checkpoint.
old_model = "D:/proj3/checkpoints/whisper3"
feature_extractor = WhisperFeatureExtractor.from_pretrained(old_model) 
tokenizer = WhisperTokenizerFast.from_pretrained(old_model)
# `proccesor` (sic) is the name the rest of the file refers to.
proccesor = WhisperProcessor.from_pretrained(old_model, tokenizer=tokenizer, feature_extractor=feature_extractor, local_files_only=True)


# NOTE(review): the config class derives n_vocab / n_mels / dropout_rate from
# vocab_size / num_mel_bins / dropout. These n_* keywords presumably take effect
# only via the **kwargs it forwards to PretrainedConfig.__init__ — verify the
# final values (e.g. that config.vocab_size agrees with len(tokenizer)).
config = WhisperConfig(
    n_mels=80,
    n_audio_ctx=1500,
    n_audio_state=768, 
    n_audio_head=12, 
    n_audio_layer=6, 
    n_vocab=len(tokenizer),#55116,
    n_text_ctx=448,
    n_text_state=768, 
    n_text_head=12, 
    n_text_layer=6,
    dropout_rate=0.01,
    max_source_positions=1500,
    window_size=40,
    gradient_checkpointing=False,
    hybrid_attention=False,
    dynamic_conv=False,
    biased_attention=False,
    augmented_memory=False,
    group_norm=False,
    use_rotation_dynamics=True,
    use_dynamic_attention_integration=False,
)

model = Whisper(config).cuda()
# model.resize_token_embeddings(len(tokenizer))
# NOTE(review): this optimizer/scheduler pair only takes effect if passed to the
# Trainer via `optimizers=`, which is commented out in the Trainer construction
# below; the Trainer otherwise builds its own (optim="adafactor").
optimizer = transformers.Adafactor(model.parameters(), 
                                clip_threshold=0.99, 
                                weight_decay=0.025, 
                                scale_parameter=True, 
                                relative_step=False, 
                                warmup_init=False, 
                                lr=2.25e-3)

scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)
# NOTE(review): not used in this section; Whisper.forward computes its own loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)

def get_adamax_optimizer(model, learning_rate=0.001, weight_decay=0.0):
    """Build an ``Adamax`` optimizer over every parameter of ``model``.

    Args:
        model: module whose ``parameters()`` are optimized.
        learning_rate: step size, passed as ``lr``.
        weight_decay: weight-decay coefficient passed to ``Adamax``.

    Returns:
        A ``torch.optim.Adamax`` instance.
    """
    hyperparams = {"lr": learning_rate, "weight_decay": weight_decay}
    trainable = model.parameters()
    return Adamax(trainable, **hyperparams)

# `global` at module level has no effect; this is an ordinary module global.
# Presumably consulted elsewhere in the file to print attention info once — TODO confirm.
global printed_attention_mechanisms_once
printed_attention_mechanisms_once = False

# Re-import the stdlib module: `from transformers import logging` at the top of
# the file had rebound the name `logging`.
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Streaming pipeline: 'train' split of an audiofolder dataset -> 20-shard
# IterableDataset -> drop empty transcripts -> neologdn-normalize the text ->
# shuffle with a 1000-example buffer (seed 42).
dataset = load_dataset("audiofolder", data_dir="D:/proj/datasets/gf_1", cache_dir = "D:/hf")['train'].to_iterable_dataset(num_shards=20).filter(lambda x: len(x['sentence']) > 0).map(lambda x: {"sentence": neologdn.normalize(x['sentence'], repeat=1)}).shuffle(seed=42, buffer_size=1000)

def prepare_dataset(batch):
    """Attach model inputs to one dataset example and return it (mutated in place).

    Adds ``input_features``, computed by the module-level ``feature_extractor``
    from ``batch["audio"]``; ``[0]`` selects the single example's features. Also
    adds ``labels``, the token ids of ``batch["sentence"]`` from the module-level
    ``tokenizer``.
    """
    clip = batch["audio"]
    extracted = feature_extractor(clip["array"], sampling_rate=clip["sampling_rate"])
    batch["input_features"] = extracted.input_features[0]

    encoded = tokenizer(batch["sentence"])
    batch["labels"] = encoded.input_ids
    return batch

# Add model inputs and keep only the columns the collator consumes.
dataset = dataset.map(prepare_dataset).select_columns(["input_features", "labels"])
# The first 100 streamed examples are held out for evaluation; the rest train.
# NOTE(review): both splits are views of the same seeded streaming shuffle. If
# the shuffle order changes between iterations (e.g. via set_epoch), examples
# could leak between them — confirm.
test , train = dataset.take(100), dataset.skip(100)

# NOTE(review): `metric` and `wakati` are not referenced in this section
# (the CER helpers use `cer_metric`); possibly used elsewhere — verify.
metric = evaluate.load("cer")
wakati = MeCab.Tagger("-Owakati")

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Pad a list of prepared examples into a single training batch.

    Fields:
        proccesor: processor exposing ``feature_extractor.pad`` and
            ``tokenizer.pad``. The misspelled name is kept because callers pass
            it by keyword.
        decoder_start_token_id: when every label row begins with this id, that
            leading column is removed from ``labels``.

    Calling the collator returns the padded feature batch with a ``labels``
    tensor added, in which padded positions are set to ``-100``.
    """

    proccesor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        audio_part = [{"input_features": example["input_features"]} for example in features]
        padded = self.proccesor.feature_extractor.pad(audio_part, return_tensors="pt")

        text_part = [{"input_ids": example["labels"]} for example in features]
        token_batch = self.proccesor.tokenizer.pad(text_part, return_tensors="pt")

        # Positions the tokenizer padded are replaced with -100 so the loss skips them.
        is_padding = token_batch.attention_mask.ne(1)
        label_ids = token_batch["input_ids"].masked_fill(is_padding, -100)

        all_start_with_bos = (label_ids[:, 0] == self.decoder_start_token_id).all().cpu().item()
        padded["labels"] = label_ids[:, 1:] if all_start_with_bos else label_ids
        return padded

# The collator drops a leading decoder_start_token_id column from the labels;
# Whisper.shift_tokens_right puts that id back at position 0 of the decoder input.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(proccesor=proccesor, decoder_start_token_id=model.config.decoder_start_token_id)

# Character-error-rate metric used by compute_cer and compute_metrics.
cer_metric = evaluate.load("cer")

class CustomTensorBoardCallback(TrainerCallback):
    """Trainer callback that mirrors logged scalars to a TensorBoard ``SummaryWriter``.

    Args:
        tb_writer: the ``SummaryWriter`` that receives the scalars.
    """

    def __init__(self, tb_writer):
        self.tb_writer = tb_writer

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Write each numeric entry of ``logs`` at ``state.global_step``.

        Non-numeric entries (strings, arrays such as ``predictions``) are
        skipped, because ``add_scalar`` cannot record them. When ``logs`` carries
        both ``predictions`` and ``label_ids``, a ``cer`` scalar is computed via
        ``compute_cer`` and written as well.
        """
        if logs is None:
            return
        for key, value in logs.items():
            # Same filter as transformers' built-in TensorBoard callback: only
            # plain numbers (including numpy scalars) are valid scalar values.
            if isinstance(value, (int, float, np.integer, np.floating)):
                self.tb_writer.add_scalar(key, value, state.global_step)
        if 'predictions' in logs and 'label_ids' in logs:
            cer = compute_cer(logs['predictions'], logs['label_ids'])
            self.tb_writer.add_scalar("cer", cer, state.global_step)

def compute_cer(predictions, label_ids):
    """Return the character error rate, in percent, of predicted vs. reference token ids.

    Both arguments are batches of token-id sequences (lists, arrays or tensors).
    ``-100`` entries, used for ignored label positions and cross-batch padding,
    are mapped to the tokenizer's pad token before decoding, because fast
    tokenizers reject negative ids. ``compute_metrics`` does the same for its
    labels. The inputs are not modified.
    """
    pad_id = proccesor.tokenizer.pad_token_id

    def _restore_padding(batch):
        # Rebuild each row as plain ints so lists, arrays and tensors all work.
        return [[pad_id if int(tok) == -100 else int(tok) for tok in row] for row in batch]

    pred_str = proccesor.tokenizer.batch_decode(_restore_padding(predictions), skip_special_tokens=True)
    label_str = proccesor.tokenizer.batch_decode(_restore_padding(label_ids), skip_special_tokens=True)
    return 100 * cer_metric.compute(predictions=pred_str, references=label_str)

def compute_metrics(pred):
    """Compute CER and token-level classification metrics for a Trainer evaluation.

    ``pred`` is the Trainer's ``EvalPrediction``. Its ``predictions`` may be:
    raw logits (``predict_with_generate=False``); a tuple whose first element is
    the logits (the model returns several tensors); or token ids. All three are
    reduced to token ids before decoding. ``pred.label_ids`` is not modified.

    Returns:
        dict with ``cer`` (percent) plus ``accuracy`` and weighted
        ``precision``/``recall``/``f1`` over the label positions that are not
        ``-100``.
    """
    predictions = pred.predictions
    if isinstance(predictions, (tuple, list)):
        # Take the first returned tensor (the logits).
        predictions = predictions[0]
    pred_ids = np.asarray(predictions)
    if np.issubdtype(pred_ids.dtype, np.floating):
        # Logits -> token ids.
        pred_ids = pred_ids.argmax(axis=-1)

    # Copy so the caller's label array is left untouched.
    label_ids = np.array(pred.label_ids, copy=True)
    if pred_ids.ndim == 1 and pred_ids.size == label_ids.size:
        # Logits that were flattened to (tokens, vocab) before argmax.
        pred_ids = pred_ids.reshape(label_ids.shape)

    pad_id = proccesor.tokenizer.pad_token_id
    # Record the ignored positions before overwriting them: pad_token_id can
    # coincide with a real target token (e.g. end-of-text).
    valid = label_ids != -100
    label_ids[~valid] = pad_id
    pred_ids = np.where(pred_ids == -100, pad_id, pred_ids)

    pred_str = proccesor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = proccesor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    cer = 100 * cer_metric.compute(predictions=pred_str, references=label_str)

    # Token-level metrics need aligned positions; truncate to the shared length
    # (generated sequences may be shorter or longer than the labels).
    seq_len = min(pred_ids.shape[-1], label_ids.shape[-1])
    pred_flat = pred_ids[..., :seq_len].flatten()
    labels_flat = label_ids[..., :seq_len].flatten()
    mask = valid[..., :seq_len].flatten()
    y_true, y_pred = labels_flat[mask], pred_flat[mask]

    return {
        "cer": cer,
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, average='weighted'),
        "recall": recall_score(y_true, y_pred, average='weighted'),
        "f1": f1_score(y_true, y_pred, average='weighted'),
    }

tb_writer = SummaryWriter(log_dir=log_dir)
# NOTE(review): report_to=["tensorboard"] below also enables transformers' own
# TensorBoard logging (into logging_dir), so scalars may be recorded twice.
tb_callback = CustomTensorBoardCallback(tb_writer)

out = log_dir
training_args = Seq2SeqTrainingArguments(
    output_dir=out,
    overwrite_output_dir=True,
    per_device_train_batch_size=1, 
    gradient_accumulation_steps=1,
    learning_rate=2.4e-5,
    warmup_steps=5,
    # A positive max_steps takes precedence over num_train_epochs.
    num_train_epochs=1,
    max_steps=100,
    tf32=True,
    bf16=True,
    # NOTE(review): save_steps (1000) exceeds max_steps (100), so no periodic
    # checkpoint is written; whether a final one is saved depends on the
    # transformers version.
    save_steps=1000,
    logging_steps=5,
    logging_dir=out+"/log_2",
    logging_strategy="steps",
    report_to=["tensorboard"],
    push_to_hub=False,
    # Keep all dataset columns; Whisper.forward takes input_features/labels.
    remove_unused_columns=False,
    label_names=["labels"],
    hub_private_repo=True,
    # NOTE(review): no compute_metrics is passed to the Trainer below and no
    # eval strategy is set here, so no "cer" metric is produced for these
    # best-model settings to use.
    metric_for_best_model="cer",
    predict_with_generate=False,
    greater_is_better=False,
    generation_max_length=128,
    # The Trainer builds its own Adafactor; the optimizer created earlier is not used.
    optim="adafactor",
    weight_decay=0.0025,
    disable_tqdm=False,
    save_total_limit=2,  
    # torch_empty_cache_steps=1,
    # NOTE(review): torch.utils.checkpoint's keyword is `use_reentrant`; these
    # kwargs only matter if gradient checkpointing is enabled (not set here).
    gradient_checkpointing_kwargs={"reentrant": False},
    # max_grad_norm=0.9,
)

# Repeats the TF32 settings made at the top of the file.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.cuda.empty_cache()
torch.cuda.set_device(0)

# Re-shuffle the (already shuffled) training stream with the default buffer size.
train = train.shuffle(seed=42)

# NOTE(review): newer transformers versions deprecate `tokenizer=` in favour of
# `processing_class=`; the feature extractor is passed so it is saved with checkpoints.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train,
    eval_dataset=test,
    data_collator=data_collator,
    tokenizer=proccesor.feature_extractor,
    # optimizers=(optimizer, scheduler), # optimizers=(get_adamax_optimizer, None)
    callbacks=[tb_callback]
)
