In [None]:
import sys, os
import random
import math
import time
import logging
import warnings
import csv
import base64
import gzip
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchaudio
import whisper
import neologdn
import evaluate
import MeCab
import torch.utils.checkpoint as checkpoint
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, Iterable, Optional, Tuple
from torch.utils.data import Dataset, DataLoader, Subset
from torch.amp import autocast, GradScaler
from torch.utils.tensorboard import SummaryWriter
from torch.optim import lr_scheduler
from torch.profiler import profile, record_function, ProfilerActivity
from sklearn.model_selection import train_test_split
from transformers import WhisperTokenizerFast
from tqdm import tqdm
from torch import Tensor
import torchaudio.transforms as at
import transformers

# Only let the transformers library log at error level.
transformers.utils.logging.set_verbosity_error()

# Probe for the fused attention kernel. MultiHeadAttention.qkv_attention reads
# SDPA_AVAILABLE to choose between this kernel and its manual softmax path.
try:
    from torch.nn.functional import scaled_dot_product_attention
    SDPA_AVAILABLE = True
except (ImportError, RuntimeError, OSError):
    scaled_dot_product_attention = None
    SDPA_AVAILABLE = False

# Entry points from the whisper package; the Whisper class in this file binds
# them as its decode / detect_language / transcribe attributes.
from whisper.decoding import decode as decode_function
from whisper.decoding import detect_language as detect_language_function
from whisper.transcribe import transcribe as transcribe_function

# Silence all warnings: install a blanket "ignore" filter and replace
# warnings.warn itself with a no-op.
# NOTE(review): this is process-wide, so warnings raised by third-party
# libraries are hidden as well.
warnings.filterwarnings(action="ignore")
warnings.warn = lambda *args, **kwargs: None

In [None]:
@dataclass
class ModelDimensions:
    """Size hyper-parameters read by the ``Whisper`` constructor.

    ``Whisper.__init__`` forwards ``n_mels`` and the ``n_audio_*`` fields to
    ``AudioEncoder`` and the ``n_text_*`` fields to ``TextDecoder``.
    """
    n_mels: int  # input channel count of the encoder's first convolution
    n_audio_ctx: int  # passed to AudioEncoder as ``n_ctx``
    n_audio_state: int  # encoder model width
    n_audio_head: int  # attention heads per encoder block
    n_audio_layer: int  # number of encoder blocks
    n_vocab: int  # read by Whisper.is_multilingual / Whisper.num_languages
    n_text_ctx: int  # decoder context length (positional table and causal mask size)
    n_text_state: int  # decoder model width
    n_text_head: int  # attention heads per decoder block
    n_text_layer: int  # number of decoder blocks

class LayerNorm(nn.LayerNorm):
    """Layer normalisation evaluated in float32 and returned in the input dtype."""

    def forward(self, x: Tensor) -> Tensor:
        # Remember the caller's dtype, normalise an up-cast copy, cast back.
        caller_dtype = x.dtype
        normalised = super().forward(x.float())
        return normalised.type(caller_dtype)

class GroupNorm(nn.Module):  # GroupNorm followed by an L2 rescale
    """Group normalisation over channels-last input, then L2 normalisation.

    The input is expected as ``(batch, seq, channels)``: ``forward`` swaps the
    last two axes before calling ``nn.GroupNorm`` and swaps them back at the
    end. After group normalisation every channel is divided by its L2 norm
    along the sequence axis and passed through a learned per-channel
    scale / gain / bias.

    Args:
        num_groups: Number of groups handed to ``nn.GroupNorm``.
        num_channels: Size of the channel (last) axis of the input.
        eps: Stabiliser used by ``nn.GroupNorm`` and as the lower bound of the
            L2 norm divisor.
    """

    def __init__(self, num_groups, num_channels, eps=1e-6):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps
        self.g = nn.Parameter(torch.ones(num_channels))
        self.b = nn.Parameter(torch.zeros(num_channels))
        self.scale = nn.Parameter(torch.ones(num_channels))

        self.group_norm = nn.GroupNorm(num_groups, num_channels, eps=eps)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (batch, channels, seq) for nn.GroupNorm
        x = self.group_norm(x)
        # A channel that is all zeros after group norm (e.g. constant input)
        # has norm 0; bound the divisor so the result is 0 instead of NaN.
        norm_x = torch.norm(x, dim=-1, keepdim=True).clamp_min(self.eps)
        rms_x = x / norm_x
        scaled_x = rms_x * self.scale.view(1, -1, 1)
        x = self.g.view(1, -1, 1) * scaled_x + self.b.view(1, -1, 1)
        return x.permute(0, 2, 1)  # back to (batch, seq, channels)

class Linear(nn.Linear):
    """``nn.Linear`` with Xavier-uniform weights, zero bias and dtype-following forward."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Explicitly (re-)apply the initialisation scheme defined below.
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform for the weight matrix; zeros for the bias if present."""
        nn.init.xavier_uniform_(self.weight)
        if self.bias is None:
            return
        nn.init.zeros_(self.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the affine map with parameters cast to the dtype of ``x``."""
        target_dtype = x.dtype
        cast_bias = self.bias.to(target_dtype) if self.bias is not None else None
        return F.linear(x, self.weight.to(target_dtype), cast_bias)

class Conv1d(nn.Conv1d):
    """``nn.Conv1d`` with Kaiming-uniform weights, zero bias and dtype-following forward."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform (ReLU gain) for the kernel; zeros for the bias if present."""
        nn.init.kaiming_uniform_(self.weight, nonlinearity='relu')
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
        """Run the convolution with ``weight``/``bias`` cast to the dtype of ``x``.

        The tensors that are cast are the ones passed in, not ``self.weight`` /
        ``self.bias``, so a caller that supplies different parameters gets
        them honoured.
        """
        weight = weight.to(x.dtype)
        bias = None if bias is None else bias.to(x.dtype)
        return super()._conv_forward(x, weight, bias)

class RotaryEmbedding(nn.Module):
    """Rotary position encoding applied to per-head query/key tensors.

    Args:
        dim: Size of the last axis of the tensors to rotate.
        base: Base of the geometric frequency progression.
    """

    def __init__(self, dim, base=10000):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1.0 / (base ** exponents))

    def rotate_queries_or_keys(self, x):
        """Rotate ``x`` by an angle that grows with the index along axis 1.

        Positions are ``0 .. x.shape[1] - 1``; the sin/cos tables broadcast
        over axes 0 and 2. The even- and odd-indexed features of the last axis
        are rotated as pairs and returned as two concatenated halves (not
        re-interleaved).
        """
        positions = torch.arange(x.shape[1], device=x.device)
        angles = torch.einsum('i , j -> i j', positions, self.inv_freq)
        sin = angles.sin()[None, :, None, :]
        cos = angles.cos()[None, :, None, :]

        even = x[..., ::2]
        odd = x[..., 1::2]
        first_half = even * cos - odd * sin
        second_half = even * sin + odd * cos
        return torch.cat([first_half, second_half], dim=-1)

class SinusoidalFeatures:
    """Pre-computed table of fixed sinusoidal position features.

    The table is built once in the constructor and returned unchanged every
    time the instance is called.
    """

    def __init__(self, n_ctx, n_state):
        self.n_ctx, self.n_state = n_ctx, n_state
        self.features = self.sinusoidal_features(n_ctx, n_state)

    @staticmethod
    def sinusoidal_features(n_ctx, n_state):
        """Return an ``(n_ctx, n_state)`` table: sines in even columns, cosines in odd."""
        position = torch.arange(0, n_ctx, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, n_state, 2).float() * -(math.log(10000.0) / n_state))
        angles = position * div_term

        table = torch.zeros(n_ctx, n_state)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        return table

    def __call__(self):
        return self.features

class LearnedSinusoidalEmbeddings(nn.Module):
    """Trainable positional table initialised from sinusoidal features.

    Args:
        n_ctx: Number of positions in the table.
        n_state: Width of each positional vector.
        gradient_checkpointing: Stored for interface compatibility only; the
            lookup is never checkpointed (see ``forward``).
    """

    def __init__(self, n_ctx, n_state, gradient_checkpointing=False):
        super().__init__()
        self.n_ctx = n_ctx
        self.n_state = n_state
        self.gradient_checkpointing = gradient_checkpointing

        sinusoidal_embeddings = SinusoidalFeatures(n_ctx, n_state)()
        self.positional_embeddings = nn.Parameter(sinusoidal_embeddings)

    def forward(self, positions):
        """Return the L2-normalised embeddings of the given integer positions.

        The lookup is a plain index operation. It is deliberately not wrapped
        in ``checkpoint.checkpoint``: indexing stores no activations worth
        recomputing, and the only checkpoint input would be the integer
        ``positions`` tensor, which cannot require grad -- with reentrant
        checkpointing that detaches the output and the table never trains.
        """
        position_embeddings = self.positional_embeddings[positions]
        return F.normalize(position_embeddings, p=2, dim=-1)

class BiasedCrossAttention(nn.Module):
    """Multi-head attention with a learned per-head bias on the queries.

    Args:
        n_state: Model width; split evenly across ``n_head`` heads.
        n_head: Number of attention heads.
        dropout_rate: Dropout probability applied to the attention weights.
    """

    def __init__(self, n_state: int, n_head: int, dropout_rate=0.1):
        super().__init__()
        self.n_head = n_head
        self.n_state = n_state
        self.head_dim = n_state // n_head

        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)

        self.bias = nn.Parameter(torch.zeros(n_head, self.head_dim))
        self.dropout = nn.Dropout(dropout_rate)
        self.norm = nn.LayerNorm(n_state)

    def forward(self, q: Tensor, k: Optional[Tensor] = None, v: Optional[Tensor] = None,
                mask: Optional[Tensor] = None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: ``(batch, n_query, n_state)`` query source.
            k: ``(batch, n_key, n_state)`` key source; defaults to ``q``
                (self-attention), which is how ``ResidualAttentionBlock``
                invokes this module.
            v: Value source with the same length as ``k``; defaults to ``k``.
            mask: Optional additive mask broadcastable to
                ``(batch, n_head, n_query, n_key)``.

        Returns:
            ``(batch, n_query, n_state)`` tensor: layer-normalised sum of the
            output projection and the projected queries.
        """
        k = q if k is None else k
        v = k if v is None else v
        n_batch, n_query = q.shape[0], q.shape[1]

        q_proj = self.query(q).view(n_batch, n_query, self.n_head, self.head_dim)
        k_proj = self.key(k).view(k.shape[0], k.shape[1], self.n_head, self.head_dim)
        v_proj = self.value(v).view(v.shape[0], v.shape[1], self.n_head, self.head_dim)

        # Heads must precede the sequence axis so the matmul scores positions
        # against positions (not heads against heads). The (n_head, head_dim)
        # bias lines up with the trailing axes of the projected queries, so it
        # is applied there as a learned per-head query offset.
        q_heads = (q_proj + self.bias).permute(0, 2, 1, 3)
        k_heads = k_proj.permute(0, 2, 1, 3)
        v_heads = v_proj.permute(0, 2, 1, 3)

        qk = (q_heads @ k_heads.transpose(-2, -1)) / self.head_dim ** 0.5
        if mask is not None:
            qk = qk + mask

        w = F.softmax(qk, dim=-1)
        w = self.dropout(w)

        out = (w @ v_heads).permute(0, 2, 1, 3).reshape(n_batch, n_query, self.n_state)
        # The residual term is the projected query, as in the original code.
        residual = q_proj.reshape(n_batch, n_query, self.n_state)
        return self.norm(self.out(out) + residual)

class AugmentedMemory(nn.Module):
    """Attention from the input sequence over a bank of memory slots.

    Every forward pass reads from the memory with multi-head attention and
    then writes a projection of the result back into the memory.

    Args:
        n_state: Model width; split evenly across ``n_head`` heads.
        memory_size: Number of memory slots.
        n_head: Number of attention heads.
        dropout_rate: Dropout probability applied to the attention weights.
    """

    def __init__(self, n_state: int, memory_size: int, n_head: int, dropout_rate=0.1):
        super().__init__()
        self.memory = nn.Parameter(torch.zeros(memory_size, n_state))
        self.n_head = n_head
        self.head_dim = n_state // n_head

        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.memory_update = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)

        self.dropout = nn.Dropout(dropout_rate)
        self.norm = nn.LayerNorm(n_state)

    def forward(self, x: Tensor):
        """Read from memory for ``x`` of shape ``(batch, seq, n_state)``.

        Returns:
            ``(batch, seq, n_state)`` tensor: layer-normalised sum of the
            memory read-out and ``x``. As a side effect the memory slots are
            updated in place (see ``_write_memory``).
        """
        bsz, seq_len, n_state = x.size()
        n_slots = self.memory.size(0)

        # Read through a copy: the in-place write at the end must not touch a
        # tensor that autograd saved for the backward pass of key/value.
        memory = self.memory.clone()

        # Heads precede the sequence / slot axis so that scores are taken
        # between sequence positions and memory slots.
        q = self.query(x).view(bsz, seq_len, self.n_head, self.head_dim).permute(0, 2, 1, 3)
        k = self.key(memory).view(n_slots, self.n_head, self.head_dim).permute(1, 0, 2)
        v = self.value(memory).view(n_slots, self.n_head, self.head_dim).permute(1, 0, 2)

        qk = (q @ k.transpose(-2, -1)) / self.head_dim ** 0.5  # (batch, head, seq, slots)
        w = F.softmax(qk, dim=-1)
        w = self.dropout(w)

        memory_out = (w @ v).permute(0, 2, 1, 3).reshape(bsz, seq_len, n_state)
        memory_out = self.norm(self.out(memory_out) + x)

        self._write_memory(memory_out)
        return memory_out

    def _write_memory(self, memory_out: Tensor) -> None:
        """Add the batch-averaged projection of ``memory_out`` to the memory.

        The update has one row per sequence position. When the sequence length
        equals the number of slots it is added row-for-row; otherwise it is
        resampled along the sequence axis to one row per slot (a length-1
        update is replicated to every slot). The write is done in place on the
        registered parameter and outside autograd, so ``memory`` stays an
        ``nn.Parameter`` and no graph is kept across calls; consequently
        ``memory_update`` receives no gradient from this write.

        NOTE(review): the write runs on every forward call, including the
        recomputation pass when this module is wrapped in gradient
        checkpointing.
        """
        with torch.no_grad():
            update = self.memory_update(memory_out).mean(dim=0).to(self.memory.dtype)
            n_slots = self.memory.size(0)
            if update.size(0) != n_slots:
                pooled = F.adaptive_avg_pool1d(update.t().unsqueeze(0), n_slots)
                update = pooled.squeeze(0).t()
            self.memory.add_(update)

class DynamicConvAttention(nn.Module):
    """Grouped 1-D convolution used as a drop-in replacement for attention.

    Args:
        n_state: Model width (channel count of the convolution).
        n_head: Number of convolution groups (and of GroupNorm groups).
        kernel_size: Convolution kernel width; padding is ``kernel_size // 2``.
            NOTE(review): presumably only odd kernel sizes keep the sequence
            length that the residual add in ``forward`` needs -- confirm.
        dropout_rate: Dropout probability applied after normalisation.
        use_GroupNorm: Normalise with ``nn.GroupNorm`` instead of
            ``nn.LayerNorm``.
    """

    def __init__(self, n_state: int, n_head: int, kernel_size=3, dropout_rate=0.1, use_GroupNorm=False):
        super().__init__()
        self.n_state = n_state
        self.n_head = n_head
        self.kernel_size = kernel_size
        self.conv = nn.Conv1d(n_state, n_state, kernel_size, padding=kernel_size//2, groups=n_head)
        self.dropout = nn.Dropout(dropout_rate)

        if use_GroupNorm:
            self.norm = nn.GroupNorm(num_groups=n_head, num_channels=n_state)
        else:
            self.norm = nn.LayerNorm(n_state)

        self.out = nn.Linear(n_state, n_state)

    def forward(self, x: Tensor):
        """Convolve ``x`` of shape ``(batch, seq, n_state)`` along the sequence axis.

        Returns:
            ``(batch, seq, n_state)`` tensor: output projection plus ``x``.

        Raises:
            ValueError: If the last axis of ``x`` is not ``n_state`` wide.
        """
        batch_size, seq_len, embed_dim = x.size()
        if embed_dim != self.n_state:
            raise ValueError(f"Expected embed_dim of {self.n_state}, but got {embed_dim}")

        conv_out = self.conv(x.permute(0, 2, 1))  # (batch, n_state, seq)
        if isinstance(self.norm, nn.GroupNorm):
            # nn.GroupNorm expects channels on axis 1, so normalise before
            # switching back to the channels-last layout.
            conv_out = self.norm(conv_out).permute(0, 2, 1)
        else:
            conv_out = self.norm(conv_out.permute(0, 2, 1))
        conv_out = self.dropout(conv_out)
        return self.out(conv_out) + x

class HybridAttention(nn.Module):
    """Sum of a windowed (local) and a full (global) multi-head self-attention.

    Args:
        n_state: Model width.
        n_head: Number of heads in both attention modules.
        window_size: Each position attends to neighbours at most this many
            steps away in the local branch; truncated with ``int()`` on use.
        dropout_rate: Dropout probability inside both attention modules and
            on the combined output.
        use_GroupNorm: Pre-normalise with ``GroupNorm`` instead of ``LayerNorm``.
    """

    def __init__(self, n_state: int, n_head: int, window_size=1.0, dropout_rate=0.1, use_GroupNorm=False):
        super().__init__()
        self.local_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate)
        self.global_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate)

        self.use_GroupNorm = use_GroupNorm
        if self.use_GroupNorm:
            self.ln_local = GroupNorm(num_groups=1, num_channels=n_state)
            self.ln_global = GroupNorm(num_groups=1, num_channels=n_state)
        else:
            self.ln_local = LayerNorm(n_state)
            self.ln_global = LayerNorm(n_state)

        self.dropout = nn.Dropout(dropout_rate)
        self.window_size = window_size

    def forward(self, x: Tensor):
        """Combine local and global attention for ``x`` of shape ``(batch, seq, n_state)``."""
        x_local = self.ln_local(x)
        x_global = self.ln_global(x)

        # Both attention modules were built with the default sequence-first layout.
        x_local = x_local.permute(1, 0, 2)
        x_global = x_global.permute(1, 0, 2)

        local_out = self.sliding_window_attention(x_local)
        global_out, _ = self.global_attn(x_global, x_global, x_global)

        combined_out = local_out + global_out
        combined_out = combined_out.permute(1, 0, 2)

        return self.dropout(combined_out)

    def sliding_window_attention(self, x):
        """Local self-attention over ``x`` of shape ``(seq, batch, n_state)``.

        Position ``i`` attends to positions ``i - w .. i + w`` (clipped to the
        sequence), where ``w = int(self.window_size)``. The window is expressed
        as a banded boolean mask so the whole sequence is handled by a single
        attention call instead of one call per position.
        """
        seq_len = x.size(0)
        window_size = max(int(self.window_size), 0)

        positions = torch.arange(seq_len, device=x.device)
        # True marks key positions a query may NOT attend to.
        outside_window = (positions[None, :] - positions[:, None]).abs() > window_size

        output, _ = self.local_attn(x, x, x, attn_mask=outside_window, need_weights=False)
        return output.to(x.dtype)
    
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention, optional cross-attention, MLP.

    The self-attention module is chosen by the first flag that is set, in the
    order ``use_dynamic_conv``, ``use_hybrid_attention``,
    ``use_biased_attention``, ``use_augmented_memory``; with none set a
    ``MultiHeadAttention`` is used. Only ``MultiHeadAttention`` receives
    ``mask`` and ``kv_cache``; the alternative modules are called with the
    normalised input alone.

    Args:
        n_state: Model width.
        n_head: Number of attention heads.
        cross_attention: Add a ``MultiHeadAttention`` over ``xa``.
        dropout_rate: Dropout probability for attention modules and the MLP.
        gradient_checkpointing: Wrap each sub-module call in
            ``checkpoint.checkpoint``.
        use_GroupNorm: Use ``GroupNorm`` instead of ``LayerNorm``.
        window_size: Window forwarded to ``HybridAttention``.
    """

    def __init__(self, n_state: int, n_head: int, use_dynamic_conv: bool = False, use_hybrid_attention: bool = False,
                 use_biased_attention: bool = False, use_augmented_memory: bool = False,
                 cross_attention: bool = False, dropout_rate=0.1, gradient_checkpointing=False, use_GroupNorm=False, window_size: int = 5):
        super().__init__()

        self.use_dynamic_conv = use_dynamic_conv
        self.use_hybrid_attention = use_hybrid_attention
        self.use_biased_attention = use_biased_attention
        self.use_augmented_memory = use_augmented_memory

        if self.use_dynamic_conv:
            self.attn = DynamicConvAttention(n_state, n_head, dropout_rate=dropout_rate)
        elif self.use_hybrid_attention:
            self.attn = HybridAttention(n_state, n_head, window_size=window_size, dropout_rate=dropout_rate, use_GroupNorm=use_GroupNorm)
        elif self.use_biased_attention:
            self.attn = BiasedCrossAttention(n_state, n_head, dropout_rate=dropout_rate)
        elif self.use_augmented_memory:
            self.attn = AugmentedMemory(n_state, memory_size=512, n_head=n_head, dropout_rate=dropout_rate)
        else:
            self.attn = MultiHeadAttention(n_state, n_head, dropout_rate=dropout_rate, gradient_checkpointing=gradient_checkpointing, use_GroupNorm=use_GroupNorm)

        self.attn_ln = GroupNorm(num_groups=4, num_channels=n_state) if use_GroupNorm else LayerNorm(n_state)
        self.cross_attention = cross_attention
        if self.cross_attention:
            self.cross_attn = MultiHeadAttention(n_state, n_head, dropout_rate=dropout_rate, gradient_checkpointing=gradient_checkpointing, use_GroupNorm=use_GroupNorm)
            self.cross_attn_ln = GroupNorm(num_groups=4, num_channels=n_state) if use_GroupNorm else LayerNorm(n_state)

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp),
            GroupNorm(num_groups=4, num_channels=n_mlp) if use_GroupNorm else LayerNorm(n_mlp),
            nn.GELU(),
            nn.Dropout(p=dropout_rate),
            Linear(n_mlp, n_state)
        )
        self.mlp_ln = GroupNorm(num_groups=4, num_channels=n_state) if use_GroupNorm else LayerNorm(n_state)
        self.gradient_checkpointing = gradient_checkpointing

    def _run(self, fn, *tensors):
        """Call ``fn(*tensors)``, through ``checkpoint.checkpoint`` when enabled."""
        if self.gradient_checkpointing:
            return checkpoint.checkpoint(fn, *tensors)
        return fn(*tensors)

    def forward(self, x: Tensor, xa: Optional[Tensor] = None,
                mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None):
        """Apply the block to ``x``.

        Args:
            x: Input sequence.
            xa: Sequence attended to by the cross-attention; the
                cross-attention step is skipped when this is ``None``.
            mask: Attention mask forwarded to ``MultiHeadAttention`` self-attention.
            kv_cache: Key/value cache forwarded to the ``MultiHeadAttention`` modules.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The closures bind mask / kv_cache by keyword. Passing them
        # positionally through checkpoint.checkpoint would shift them into the
        # wrong parameters of MultiHeadAttention.forward(x, xa, mask, kv_cache).
        def self_attend(inp):
            if isinstance(self.attn, MultiHeadAttention):
                return self.attn(inp, mask=mask, kv_cache=kv_cache)[0]
            if isinstance(self.attn, BiasedCrossAttention):
                # Its forward takes separate query / key / value inputs.
                return self.attn(inp, inp, inp)
            return self.attn(inp)

        def cross_attend(inp, audio):
            return self.cross_attn(inp, audio, kv_cache=kv_cache)[0]

        attn_out = x + self._run(self_attend, self.attn_ln(x))

        if self.cross_attention and xa is not None:
            cross_attn_input = self.cross_attn_ln(attn_out)
            attn_out = attn_out + self._run(cross_attend, cross_attn_input, xa)

        mlp_input = self.mlp_ln(attn_out)
        return attn_out + self._run(self.mlp, mlp_input)

class MultiHeadAttention(nn.Module):
    """Multi-head attention with rotary positions and a learned temperature.

    Used for self-attention (``xa is None``) and cross-attention. The class
    attribute ``use_sdpa`` selects the fused kernel when it is available.

    Args:
        n_state: Model width; split evenly across ``n_head`` heads.
        n_head: Number of attention heads.
        dropout_rate: Dropout probability on the attention weights.
        gradient_checkpointing: Stored on the module; not read by this class.
        use_GroupNorm: Pre-normalise with ``GroupNorm`` instead of ``LayerNorm``.
    """

    use_sdpa = True

    def __init__(self, n_state: int, n_head: int, dropout_rate=0.1, gradient_checkpointing=False, use_GroupNorm=False):
        super().__init__()
        self.n_head = n_head
        self.n_state = n_state
        self.head_dim = n_state // n_head

        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

        self.rotary_emb = RotaryEmbedding(dim=self.head_dim)
        # Multiplier of the attention logits; starts at the usual 1/sqrt(head_dim).
        self.temperature = nn.Parameter(torch.ones(1) * (self.head_dim ** -0.5))
        self.dropout = nn.Dropout(dropout_rate)

        self.use_GroupNorm = use_GroupNorm
        self.attn_ln = GroupNorm(num_groups=1, num_channels=n_state) if use_GroupNorm else LayerNorm(n_state)

        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None):
        """Attend from ``x`` over itself, or over ``xa`` when it is given.

        Args:
            x: ``(batch, n_ctx, n_state)`` query source.
            xa: Optional key/value source for cross-attention.
            mask: Optional additive causal mask.
            kv_cache: Optional dict keyed by the ``key`` / ``value`` modules;
                entries found there replace the freshly computed keys/values.

        Returns:
            Tuple ``(out, qk)``: the output projection plus ``x``, and the
            detached pre-softmax logits (``None`` on the fused-kernel path).
        """
        x_norm = self.attn_ln(x)
        source = x_norm if xa is None else xa

        q = self.query(x_norm)
        k = self.key(source)
        v = self.value(source)

        if kv_cache is not None and self.key in kv_cache:
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        q = q.view(q.shape[0], q.shape[1], self.n_head, -1)
        k = k.view(k.shape[0], k.shape[1], self.n_head, -1)

        # In cached self-attention the keys cover all tokens so far while the
        # queries cover only the newest ones, so the query positions must start
        # where the cached prefix ends. The rotary helper always counts from 0,
        # so the queries are front-padded to the key length, rotated, and the
        # padding is dropped again.
        offset = k.shape[1] - q.shape[1] if xa is None else 0
        if offset > 0:
            q = F.pad(q, (0, 0, 0, 0, offset, 0))
            q = self.rotary_emb.rotate_queries_or_keys(q)[:, offset:]
        else:
            q = self.rotary_emb.rotate_queries_or_keys(q)
        k = self.rotary_emb.rotate_queries_or_keys(k)

        q = q.reshape(q.shape[0], q.shape[1], -1)
        k = k.reshape(k.shape[0], k.shape[1], -1)

        wv, qk = self.qkv_attention(q, k, v, mask)

        return self.out(wv) + x, qk

    def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        """Scaled dot-product attention over ``(batch, length, n_state)`` inputs.

        Both code paths scale the logits by ``self.temperature`` and apply the
        module's dropout to the attention weights while training, so they give
        matching results.

        Returns:
            Tuple ``(out, qk)`` with ``out`` of shape ``(batch, n_ctx, n_state)``;
            ``qk`` is ``None`` on the fused-kernel path.
        """
        n_batch, n_ctx, n_state = q.shape

        q = q.reshape(n_batch, n_ctx, self.n_head, self.head_dim).permute(0, 2, 1, 3)
        k = k.reshape(n_batch, k.shape[1], self.n_head, self.head_dim).permute(0, 2, 1, 3)
        v = v.reshape(n_batch, v.shape[1], self.n_head, self.head_dim).permute(0, 2, 1, 3)

        if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
            # The kernel multiplies the logits by head_dim ** -0.5 on its own;
            # pre-scaling q makes the net factor equal to the learned
            # temperature (the pre-scale is 1 at initialisation).
            q_scale = (self.temperature * self.head_dim ** 0.5).to(q.dtype)
            a = scaled_dot_product_attention(
                q * q_scale, k, v,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=mask is not None and n_ctx > 1,
            )
            out = a.permute(0, 2, 1, 3).reshape(n_batch, n_ctx, n_state)
            qk = None
        else:
            # Apply the temperature once (scaling q and k each would square it).
            qk = (q @ k.transpose(-2, -1)) * self.temperature
            if mask is not None:
                qk = qk + mask[:n_ctx, :n_ctx]
            qk = qk.float()

            w = F.softmax(qk, dim=-1).to(q.dtype)
            w = self.dropout(w)
            out = (w @ v).permute(0, 2, 1, 3).reshape(n_batch, n_ctx, n_state)
            qk = qk.detach()

        return out, qk

class AudioEncoder(nn.Module):
    """Convolutional stem followed by a stack of ``ResidualAttentionBlock``s.

    Args:
        n_mels: Input channel count of the first convolution.
        n_ctx: Accepted for signature compatibility; not read by this class.
        n_state: Model width.
        n_head: Attention heads per block.
        n_layer: Number of blocks.
        dropout_rate: Dropout probability after each convolution and in the blocks.
        gradient_checkpointing: Checkpoint every block call.
        use_dynamic_conv, use_hybrid_attention, use_biased_attention,
        use_augmented_memory: Forwarded to every block.
        use_GroupNorm: Use ``GroupNorm`` instead of ``LayerNorm``.
    """

    main_input_name = "input_features"

    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int,
                 dropout_rate=0.1, gradient_checkpointing=False, use_dynamic_conv: bool = False, use_hybrid_attention: bool = False, use_biased_attention: bool = False, use_augmented_memory: bool = False, use_GroupNorm=False):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.dropout = nn.Dropout(dropout_rate)
        self.gradient_checkpointing = gradient_checkpointing
        self.use_GroupNorm = use_GroupNorm

        # Every block shares the same configuration.
        block_options = dict(
            use_dynamic_conv=use_dynamic_conv,
            use_hybrid_attention=use_hybrid_attention,
            use_biased_attention=use_biased_attention,
            use_augmented_memory=use_augmented_memory,
            dropout_rate=dropout_rate,
            gradient_checkpointing=gradient_checkpointing,
            use_GroupNorm=use_GroupNorm,
        )
        layers = [ResidualAttentionBlock(n_state, n_head, **block_options) for _ in range(n_layer)]
        self.blocks = nn.ModuleList(layers)

        self.ln_post = (
            GroupNorm(num_groups=1, num_channels=n_state)
            if self.use_GroupNorm
            else LayerNorm(n_state)
        )

    def forward(self, x: torch.Tensor):
        """Encode ``x``; the convolutions consume it channels-first and the
        blocks receive it with the last two axes swapped."""
        for conv in (self.conv1, self.conv2):
            x = self.dropout(F.gelu(conv(x)))
        x = x.permute(0, 2, 1)

        run_checkpointed = self.gradient_checkpointing
        for block in self.blocks:
            x = checkpoint.checkpoint(block, x) if run_checkpointed else block(x)

        return self.ln_post(x)

class TextDecoder(nn.Module):
    """Token decoder: embeddings, cross-attending blocks, tied output projection.

    Args:
        n_vocab: Vocabulary size of the token embedding (and of the logits).
        n_ctx: Number of positions in the positional table and causal mask.
        n_state: Model width.
        n_head: Attention heads per block.
        n_layer: Number of blocks.
        dropout_rate: Dropout probability inside the blocks.
        gradient_checkpointing: Checkpoint every block call.
        use_dynamic_conv, use_hybrid_attention, use_biased_attention,
        use_augmented_memory: Forwarded to every block.
        use_GroupNorm: Use ``GroupNorm`` instead of ``LayerNorm``.
    """

    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int,
                 dropout_rate=0.1, gradient_checkpointing=False, use_dynamic_conv: bool = False, use_hybrid_attention: bool = False, use_biased_attention: bool = False, use_augmented_memory: bool = False, use_GroupNorm=False):
        super().__init__()
        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = LearnedSinusoidalEmbeddings(n_ctx, n_state, gradient_checkpointing=gradient_checkpointing)
        self.gradient_checkpointing = gradient_checkpointing
        self.use_GroupNorm = use_GroupNorm

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList([
            ResidualAttentionBlock(
                n_state, n_head,
                use_dynamic_conv=use_dynamic_conv,
                use_hybrid_attention=use_hybrid_attention,
                use_biased_attention=use_biased_attention,
                use_augmented_memory=use_augmented_memory,
                cross_attention=True,
                dropout_rate=dropout_rate,
                gradient_checkpointing=gradient_checkpointing,
                use_GroupNorm=use_GroupNorm
            )
            for _ in range(n_layer)
        ])

        if self.use_GroupNorm:
            self.ln_post = GroupNorm(num_groups=1, num_channels=n_state)
        else:
            self.ln_post = LayerNorm(n_state)

        # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer('mask', mask, persistent=False)

    def _cache_offset(self, kv_cache: Optional[dict]) -> int:
        """Return how many tokens the self-attention key cache already holds.

        The count is read from the cache entry of the first block's
        self-attention ``key`` module. It is 0 when there is no cache, no
        block, or no such entry (e.g. the self-attention variant in use has
        no cached ``key`` projection).
        """
        if not kv_cache or len(self.blocks) == 0:
            return 0
        key_module = getattr(self.blocks[0].attn, "key", None)
        cached_keys = kv_cache.get(key_module)
        return 0 if cached_keys is None else cached_keys.shape[1]

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """Compute vocabulary logits.

        Args:
            x: ``(batch, n_tokens)`` integer token ids.
            xa: Encoded audio the blocks cross-attend to.
            kv_cache: Optional key/value cache for incremental decoding.

        Returns:
            ``(batch, n_tokens, n_vocab)`` float32 logits.
        """
        # With a warm cache only the newest tokens are passed in, so their
        # positions continue after the tokens that are already cached.
        offset = self._cache_offset(kv_cache)
        positions = torch.arange(offset, offset + x.shape[1], device=x.device)
        pos_emb = self.positional_embedding(positions).unsqueeze(0)
        x = self.token_embedding(x) + pos_emb
        x = x.to(xa.dtype)

        for block in self.blocks:
            if self.gradient_checkpointing:
                x = checkpoint.checkpoint(block, x, xa, self.mask, kv_cache)
            else:
                x = block(x, xa, self.mask, kv_cache)

        x = self.ln_post(x)
        # Output projection shares its weights with the token embedding.
        logits = (x @ self.token_embedding.weight.to(x.dtype).T).float()
        return logits


In [None]:

class Whisper(nn.Module):
    """Encoder-decoder speech model assembled from ``AudioEncoder`` and ``TextDecoder``.

    Args:
        dims: Layer sizes for encoder and decoder.
        n_vocab: Vocabulary size used for the decoder and for reshaping logits.
        dropout_rate: Dropout probability forwarded to both halves.
        gradient_checkpointing: Forwarded to the encoder only.
        use_dynamic_conv, use_hybrid_attention, use_biased_attention,
        use_augmented_memory, use_GroupNorm: Forwarded to both halves.
    """

    def __init__(self, dims: ModelDimensions, n_vocab, dropout_rate=0.1, gradient_checkpointing=False,
                 use_dynamic_conv=False, use_hybrid_attention=False, use_biased_attention=False, use_augmented_memory=False, use_GroupNorm=False):
        super().__init__()
        self.dims = dims
        self.n_vocab = n_vocab
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
            dropout_rate=dropout_rate,
            gradient_checkpointing=gradient_checkpointing,
            use_dynamic_conv=use_dynamic_conv,
            use_hybrid_attention=use_hybrid_attention,
            use_biased_attention=use_biased_attention,
            use_augmented_memory=use_augmented_memory,
            use_GroupNorm=use_GroupNorm
        )
        self.decoder = TextDecoder(
            self.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
            dropout_rate=dropout_rate,
            #gradient_checkpointing=gradient_checkpointing, # needs debugging
            use_dynamic_conv=use_dynamic_conv,
            use_hybrid_attention=use_hybrid_attention,
            use_biased_attention=use_biased_attention,
            use_augmented_memory=use_augmented_memory,
            use_GroupNorm=use_GroupNorm
        )
        # Default alignment heads: every head in the second half of the decoder.
        all_heads = torch.zeros(
            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
        )
        all_heads[self.dims.n_text_layer // 2:] = True
        self.register_buffer('alignment_heads', all_heads.to_sparse(), persistent=False)

    def forward(self, audio_features: torch.Tensor, input_ids: torch.Tensor = None,
                labels: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Encode audio and, when tokens are given, decode and compute a loss.

        Args:
            audio_features: Input to the encoder.
            input_ids: Decoder input tokens. When ``None`` only the encoder
                runs and ``loss`` / ``logits`` are returned as ``None``.
            labels: Optional cross-entropy targets with the same number of
                elements as ``input_ids``; entries equal to ``-100`` are
                ignored. When omitted the loss is taken against ``input_ids``
                themselves (the behaviour before ``labels`` existed).

        Returns:
            Dict with ``"loss"``, ``"logits"`` (flattened to
            ``(-1, n_vocab)``) and ``"audio_features_encoded"``.
        """
        audio_features_encoded = self.encoder(audio_features)
        if input_ids is None:
            return {"loss": None, "logits": None, "audio_features_encoded": audio_features_encoded}

        logits = self.decoder(input_ids, audio_features_encoded)
        logits = logits.view(-1, self.n_vocab)

        targets = input_ids if labels is None else labels
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(logits, targets.reshape(-1).long())

        return {"loss": loss, "logits": logits, "audio_features_encoded": audio_features_encoded}

    @property
    def device(self):
        """Device of the first model parameter."""
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        return self.dims.n_vocab >= 51865

    @property
    def num_languages(self):
        return self.dims.n_vocab - 51765 - int(self.is_multilingual)

    def set_alignment_heads(self, dump: bytes):
        """Replace ``alignment_heads`` from a base85-encoded, gzip-compressed bool array."""
        array = np.frombuffer(
            gzip.decompress(base64.b85decode(dump)), dtype=bool
        ).copy()
        mask = torch.from_numpy(array).reshape(
            self.dims.n_text_layer, self.dims.n_text_head
        )
        self.register_buffer('alignment_heads', mask.to_sparse(), persistent=False)

    def embed_audio(self, audio_features: torch.Tensor):
        """Run the encoder only."""
        return self.encoder(audio_features)

    def logits(self, input_ids: torch.Tensor, audio_features: torch.Tensor):
        """Run the decoder only, on already-encoded audio."""
        return self.decoder(input_ids, audio_features)

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """Attach forward hooks that cache key/value projections in the decoder.

        Hooks are registered on the ``key`` and ``value`` modules of every
        ``MultiHeadAttention`` inside the decoder.

        Args:
            cache: Optional existing cache; it is copied, not modified.

        Returns:
            Tuple ``(cache, hooks)``: the dict the hooks write to (keyed by
            module) and the hook handles, so the caller can remove them.
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            # First call for this module, or an output longer than the text
            # context: store as is. Otherwise append along the sequence axis.
            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    # Decoding / transcription entry points imported from the whisper package.
    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function

In [None]:
def load_wave(wave_path, sample_rate: int = 16000) -> torch.Tensor:
    """Load an audio file and resample it to ``sample_rate`` if it differs.

    Args:
        wave_path: Path handed to ``torchaudio.load``.
        sample_rate: Desired sample rate of the returned waveform.

    Returns:
        The waveform tensor returned by ``torchaudio.load``, resampled when
        the file's own rate is not ``sample_rate``.
    """
    waveform, source_rate = torchaudio.load(wave_path, normalize=True)
    if source_rate == sample_rate:
        return waveform
    resampler = torchaudio.transforms.Resample(source_rate, sample_rate)
    return resampler(waveform)

class CustomAudioDataset(Dataset):
    """Dataset of ``(audio path, transcript)`` pairs read from a CSV file.

    The first CSV row is treated as a header and skipped. Each remaining row
    supplies the audio file name in column 0 and the transcript in column 1;
    blank lines and rows with fewer than two columns are ignored.

    Args:
        csv_file: Path of the UTF-8 encoded CSV file.
        audio_dir: Directory prefix joined with each audio file name.
        tokenizer: Stored on the instance; not used by this class.
        sample_rate: Stored on the instance; not used by this class.
    """

    def __init__(self, csv_file, audio_dir, tokenizer, sample_rate=16000):
        self.audio_dir = audio_dir
        self.tokenizer = tokenizer
        self.sample_rate = sample_rate

        # newline='' lets the csv module handle line endings inside quoted fields.
        with open(csv_file, 'r', encoding='utf-8', newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header; tolerate an empty file
            self.samples = [(row[0], row[1]) for row in reader if len(row) >= 2]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the audio path and the raw transcript (under two keys) for ``idx``."""
        audio_path, label = self.samples[idx]
        audio = f'{self.audio_dir}/{audio_path}'
        return {
            'audio_features': audio,
            'labels': label,
            'input_ids': label
        }

class WhisperDataCollatorWithPadding:
    """Turn dataset items into a padded batch of log-mel features and token ids.

    Args:
        tokenizer: Object providing ``encode``, ``bos_token_id``,
            ``eos_token_id`` and ``pad_token_id``.
        sample_rate: Rate every audio file is resampled to before feature
            extraction.
        n_mels: Number of mel bins requested from ``whisper.log_mel_spectrogram``.
    """

    def __init__(self, tokenizer, sample_rate=16000, n_mels=80):
        self.tokenizer = tokenizer
        self.sample_rate = sample_rate
        self.n_mels = n_mels

    @staticmethod
    def _pad(sequences, length, fill_value):
        """Right-pad integer sequences with ``fill_value`` into a ``(n, length)`` long tensor."""
        padded = torch.full((len(sequences), length), fill_value, dtype=torch.long)
        for row, seq in zip(padded, sequences):
            row[:len(seq)] = torch.tensor(seq, dtype=torch.long)
        return padded

    def __call__(self, features):
        """Collate items carrying an audio path (``'audio_features'``) and a transcript (``'labels'``).

        Returns:
            Dict with ``"input_ids"`` (BOS + tokens, padded with the pad id),
            ``"labels"`` (tokens + EOS, padded with ``-100``) and
            ``"audio_features"`` (stacked log-mel spectrograms).
        """
        audio_features, dec_input_ids, labels = [], [], []

        for f in features:
            # Resample to the configured rate and mix all channels down to
            # one, so multi-channel files are not concatenated end to end.
            audio = load_wave(f['audio_features'], sample_rate=self.sample_rate)
            audio = whisper.pad_or_trim(audio.mean(dim=0))
            audio = whisper.log_mel_spectrogram(audio, n_mels=self.n_mels)

            token_ids = self.tokenizer.encode(f['labels'])

            dec_input_ids.append([self.tokenizer.bos_token_id] + token_ids)
            labels.append(token_ids + [self.tokenizer.eos_token_id])
            audio_features.append(audio)

        max_len = max(len(seq) for seq in dec_input_ids + labels)

        return {
            "input_ids": self._pad(dec_input_ids, max_len, self.tokenizer.pad_token_id),
            "labels": self._pad(labels, max_len, -100),
            "audio_features": torch.stack(audio_features),
        }

def compute_metrics(pred):
    """Compute the character error rate (in percent) of a prediction batch.

    Args:
        pred: Mapping with ``"predictions"`` and ``"label_ids"`` token-id
            arrays. Label entries equal to ``-100`` are replaced by the pad
            token id before decoding; the caller's array is left untouched.

    Returns:
        ``{"cer": value}`` where ``value`` is 100 times the result of the
        module-level ``metrics_cer`` metric.
    """
    pred_ids = pred["predictions"]
    label_ids = pred["label_ids"]
    pad_id = tokenizer.pad_token_id

    # Build a new array/tensor instead of overwriting the caller's labels.
    if torch.is_tensor(label_ids):
        label_ids = label_ids.masked_fill(label_ids == -100, pad_id)
    else:
        label_ids = np.where(np.asarray(label_ids) == -100, pad_id, label_ids)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    cer = 100 * metrics_cer.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}

# Output locations for checkpoints and logs; both directories are created on
# import of this cell if they do not exist yet.
checkpoint_dir = 'D:/proj3/ckpt/'
os.makedirs(checkpoint_dir, exist_ok=True)
log_dir = 'D:/proj3/ckpt/logs/'
os.makedirs(log_dir, exist_ok=True)

# Module-level TensorBoard writer. Note that train_and_evaluate creates its
# own local ``writer`` from its ``log_dir`` argument.
writer = SummaryWriter(log_dir)

# Send INFO-and-above log records to training.log inside log_dir;
# filemode='w' truncates the file each time this cell runs.
logging.basicConfig(
    filename=os.path.join(log_dir, 'training.log'), 
    filemode='w', 
    format='%(asctime)s - %(levelname)s - %(message)s', 
    level=logging.INFO
)

In [None]:
def _grad_l2_norm(parameters):
    """Return the global L2 norm of the gradients held by ``parameters``."""
    squared_sum = 0.0
    for p in parameters:
        if p.grad is not None:
            squared_sum += p.grad.detach().norm(2).item() ** 2
    return squared_sum ** 0.5


def _run_eval_pass(model, eval_loader, loss_fn, device):
    """Run one no-grad pass over ``eval_loader``.

    Returns:
        tuple: ``(mean_loss, predictions, labels)``. ``mean_loss`` is ``None``
        when no batch could be evaluated; the other two are lists of token-id
        lists, one entry per evaluated sample.
    """
    loss_sum = 0.0
    batches_done = 0
    predictions = []
    label_rows = []

    with torch.no_grad():
        for eval_batch in eval_loader:
            try:
                audio_features = eval_batch['audio_features'].to(device)
                input_ids = eval_batch['input_ids'].to(device)
                labels = eval_batch['labels'].long().to(device)
            except KeyError as e:
                print(f"Key error: {e}. Available keys in eval batch: {eval_batch.keys()}")
                continue

            audio_features_encoded = model.encoder(audio_features)
            decoder_output = model.decoder(input_ids, audio_features_encoded)

            logits = decoder_output.view(-1, decoder_output.size(-1))
            loss_sum += loss_fn(logits, labels.view(-1)).item()
            batches_done += 1

            predictions.extend(torch.argmax(decoder_output, dim=-1).cpu().numpy().tolist())
            label_rows.extend(labels.cpu().numpy().tolist())

    mean_loss = loss_sum / batches_done if batches_done else None
    return mean_loss, predictions, label_rows


def train_and_evaluate(model, train_loader, eval_loader, optimizer, loss_fn, scheduler, num_epochs=1, max_steps=None, device='cuda', accumulation_steps=1, clear_cache=True, log_interval=10, eval_interval=20, save_interval=100, checkpoint_dir="checkpoint_dir", log_dir="log_dir"):
    """Train ``model`` with mixed precision, evaluating and checkpointing periodically.

    Relies on the module-level ``tokenizer`` and ``compute_metrics``.

    Args:
        model: Module exposing ``encoder`` and ``decoder`` sub-modules; the
            decoder is called as ``decoder(input_ids, encoded_audio)``.
        train_loader: Iterable of dict batches with ``'audio_features'``,
            ``'input_ids'`` and ``'labels'`` tensors. Batches missing a key
            are reported and skipped.
        eval_loader: Same batch format, used for the periodic evaluation.
        optimizer: Optimizer stepped every ``accumulation_steps`` batches.
        loss_fn: Callable ``loss_fn(logits_2d, labels_1d)``.
        scheduler: Stepped with the mean evaluation loss after each evaluation.
        num_epochs: Number of passes over ``train_loader``.
        max_steps: Optional cap on the number of processed batches.
        device: Device the model and batches are moved to.
        accumulation_steps: Batches to accumulate before an optimizer step.
        clear_cache: Call ``torch.cuda.empty_cache()`` after each optimizer step.
        log_interval: Log training scalars every this many steps.
        eval_interval: Evaluate every this many steps.
        save_interval: Save a checkpoint every this many steps.
        checkpoint_dir: Directory receiving the ``.pt`` state dicts.
        log_dir: Directory passed to ``SummaryWriter``.

    Returns:
        None. Scalars go to TensorBoard, checkpoints and ``final_model.pt``
        are written to ``checkpoint_dir``.
    """
    model.to(device)
    global_step = 0
    total_norm = 0.0  # gradient norm measured at the most recent optimizer step
    scaler = torch.amp.GradScaler()
    writer = SummaryWriter(log_dir=log_dir)

    # try/finally so the event file is flushed and closed even if training raises.
    try:
        for epoch in range(num_epochs):
            if max_steps is not None and global_step >= max_steps:
                break

            model.train()
            total_loss = 0.0
            batches_seen = 0  # batches that actually contributed to total_loss
            optimizer.zero_grad()
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")

            for step, batch in enumerate(progress_bar):
                if max_steps is not None and global_step >= max_steps:
                    break

                start_time = time.time()

                try:
                    audio_features = batch['audio_features'].to(device)
                    input_ids = batch['input_ids'].to(device)
                    labels = batch['labels'].long().to(device)
                except KeyError as e:
                    print(f"Key error: {e}. Available keys in batch: {batch.keys()}")
                    continue

                # record_function only labels this region for a profiler that
                # is started elsewhere; no profiler session is opened per step
                # because that adds overhead to every iteration.
                with record_function("model_training"):
                    with torch.amp.autocast(device_type='cuda'):
                        audio_features_encoded = model.encoder(audio_features)
                        decoder_output = model.decoder(input_ids, audio_features_encoded)
                        logits = decoder_output.view(-1, decoder_output.size(-1))
                        loss = loss_fn(logits, labels.view(-1))

                    total_loss += loss.item()
                    batches_seen += 1

                    # Divide so that the accumulated gradient is an average.
                    scaler.scale(loss / accumulation_steps).backward()

                    if (step + 1) % accumulation_steps == 0:
                        # Measure the norm on unscaled gradients and *before*
                        # zero_grad(); afterwards the gradients are gone.
                        scaler.unscale_(optimizer)
                        total_norm = _grad_l2_norm(model.parameters())

                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad()

                        if clear_cache:
                            torch.cuda.empty_cache()

                global_step += 1
                end_time = time.time()
                samples_per_sec = len(batch['audio_features']) / (end_time - start_time)

                if global_step % log_interval == 0:
                    writer.add_scalar('Loss/train', total_loss / batches_seen, global_step)
                    writer.add_scalar('GradientNorm', total_norm, global_step)
                    writer.add_scalar('LearningRate', optimizer.param_groups[0]['lr'], global_step)
                    writer.add_scalar('SamplesPerSec', samples_per_sec, global_step)
                    writer.add_scalar("Memory/Allocated", torch.cuda.memory_allocated(), global_step)
                    writer.add_scalar("Memory/Cached", torch.cuda.memory_reserved(), global_step)

                if global_step % eval_interval == 0:
                    model.eval()
                    mean_eval_loss, all_predictions, all_labels = _run_eval_pass(
                        model, eval_loader, loss_fn, device
                    )

                    if mean_eval_loss is None:
                        # Nothing was evaluated: skip metrics and the scheduler
                        # rather than dividing by zero.
                        print(f"Evaluation skipped at step {global_step}: no usable eval batches")
                        logging.warning(f"Evaluation skipped at step {global_step}: no usable eval batches")
                    else:
                        predictions = {
                            "predictions": np.array(all_predictions, dtype="object"),
                            "label_ids": np.array(all_labels, dtype="object")
                        }

                        metrics = compute_metrics(predictions)
                        writer.add_scalar('Loss/eval', mean_eval_loss, global_step)
                        writer.add_scalar('CER', metrics['cer'], global_step)

                        scheduler.step(mean_eval_loss)

                        sample_indices = range(min(1, len(all_predictions)))
                        for idx in sample_indices:
                            pred_str = tokenizer.decode(all_predictions[idx], skip_special_tokens=True)
                            # -100 marks padded label positions and is not a
                            # token id, so drop it before decoding.
                            label_tokens = [tok for tok in all_labels[idx] if tok != -100]
                            label_str = tokenizer.decode(label_tokens, skip_special_tokens=True)
                            print(f"Sample {idx}: Prediction: {pred_str}, Label: {label_str}")
                            logging.info(f"Sample {idx}: Prediction: {pred_str}, Label: {label_str}")

                    model.train()

                if global_step % save_interval == 0:
                    checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_step_{global_step}.pt')
                    torch.save(model.state_dict(), checkpoint_path)
                    print(f"Model saved at step {global_step} to {checkpoint_path}")
                    logging.info(f"Model saved at step {global_step} to {checkpoint_path}")

            # Average over the batches actually processed, so a truncated or
            # empty epoch neither skews the value nor divides by zero.
            mean_epoch_loss = total_loss / max(batches_seen, 1)
            print(f'Epoch {epoch + 1}, Loss: {mean_epoch_loss}')
            logging.info(f'Epoch {epoch + 1}, Loss: {mean_epoch_loss}')

        final_model_path = os.path.join(checkpoint_dir, 'final_model.pt')
        torch.save(model.state_dict(), final_model_path)
        print(f"Final model saved to {final_model_path}")
        logging.info(f"Final model saved to {final_model_path}")
    finally:
        writer.close()


In [None]:
def evaluate_model(model, eval_loader, device, tokenizer):
    """Evaluate ``model`` on ``eval_loader`` and return the mean loss and CER.

    Args:
        model: Callable as ``model(audio_features, labels=labels)``; the result
            must support ``outputs['logits']`` and may contain ``'loss'``.
        eval_loader: Iterable of dict batches. Only ``'audio_features'`` and
            ``'labels'`` are read; batches missing either key are reported
            and skipped.
        device: Device the batch tensors are moved to.
        tokenizer: Unused here; decoding is done by ``compute_metrics``, which
            reads the module-level tokenizer. Kept for call compatibility.

    Returns:
        tuple: ``(eval_loss, metrics)``. ``eval_loss`` is the mean loss over
        the batches that were evaluated and ``metrics`` is the dict returned
        by ``compute_metrics``. When no batch could be evaluated, both the
        loss and ``metrics['cer']`` are NaN.
    """
    model.eval()
    eval_loss = 0.0
    batches_done = 0
    all_predictions = []
    all_labels = []
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)

    with torch.no_grad():
        for batch in eval_loader:
            try:
                audio_features = batch['audio_features'].to(device)
                labels = batch['labels'].long().to(device)
            except KeyError as e:
                print(f"Key error: {e}. Available keys in batch: {batch.keys()}")
                continue

            outputs = model(audio_features, labels=labels)
            logits = outputs['logits']
            if 'loss' in outputs:
                loss = outputs['loss']
            else:
                # Flatten on the logits' own last dimension so the fallback
                # does not depend on the model exposing an ``n_vocab`` attribute.
                loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
            eval_loss += loss.item()
            batches_done += 1

            all_predictions.extend(torch.argmax(logits, dim=-1).cpu().numpy().tolist())
            all_labels.extend(labels.cpu().numpy().tolist())

    if batches_done == 0:
        # Nothing to average or decode: report NaN instead of dividing by zero.
        logging.warning("evaluate_model: eval_loader yielded no usable batches")
        return float("nan"), {"cer": float("nan")}

    # Average over evaluated batches only, so skipped batches do not dilute it.
    eval_loss /= batches_done
    metrics = compute_metrics({"predictions": np.array(all_predictions, dtype="object"), "label_ids": np.array(all_labels, dtype="object")})
    return eval_loss, metrics


In [None]:
# Hyper-parameters serialised to config.json in the current working directory.
# NOTE(review): these values differ from `dimensions` below (state 1024 vs 768,
# heads 14 vs 12, layers 8 vs 12), and 1024 is not divisible by 14 heads;
# confirm which set is intended.
config = dict(
    n_mels=80,
    n_audio_ctx=1500,
    n_audio_state=1024,
    n_audio_head=14,
    n_audio_layer=8,
    n_vocab=51865,
    n_text_ctx=448,
    n_text_state=1024,
    n_text_head=14,
    n_text_layer=8,
)
with open('config.json', 'w') as f:
    f.write(json.dumps(config))

# Dimensions of the model that is actually instantiated further down.
dimensions = ModelDimensions(
    n_mels=80, n_vocab=51865,
    n_audio_ctx=1500, n_audio_state=768, n_audio_head=12, n_audio_layer=12,
    n_text_ctx=448, n_text_state=768, n_text_head=12, n_text_layer=12,
)

if __name__ == "__main__":

    # Tokenizer, dataset locations and the CER metric used by compute_metrics.
    tokenizer = WhisperTokenizerFast.from_pretrained(
        "D:/good/my_tokenizer",
        task="transcribe",
        language="japanese",
        local_files_only=True,
    )
    csv_file = 'D:/proj/datasets/gvj/trimmed/metadata.csv'
    audio_dir = 'D:/proj/datasets/gvj/trimmed/'

    metrics_cer = evaluate.load("cer")

    def train_val_dataset(dataset, val_split=0.001):
        """Randomly split ``dataset`` into ``'train'`` and ``'val'`` subsets."""
        indices = list(range(len(dataset)))
        train_idx, val_idx = train_test_split(indices, test_size=val_split)
        datasets = {
            'train': Subset(dataset, train_idx),
            'val': Subset(dataset, val_idx),
        }
        return datasets

    dataset = CustomAudioDataset(csv_file, audio_dir, tokenizer)
    datasets = train_val_dataset(dataset)
    train_dataset = datasets['train']
    eval_dataset = datasets['val']

    collate_fn = WhisperDataCollatorWithPadding(tokenizer)

    def train_dataloader():
        """Build the shuffled, single-sample training loader."""
        return DataLoader(
            train_dataset,
            collate_fn=collate_fn,
            batch_size=1,
            shuffle=True,
            drop_last=True,
            num_workers=0,
        )

    def eval_dataloader():
        """Build the ordered, single-sample evaluation loader."""
        return DataLoader(
            eval_dataset,
            collate_fn=collate_fn,
            batch_size=1,
            shuffle=False,
            drop_last=True,
            num_workers=0,
        )

    train_loader = train_dataloader()
    eval_loader = eval_dataloader()

    model = Whisper(
        dimensions,
        n_vocab=51865,
        dropout_rate=0.0,
        gradient_checkpointing=True,
        use_hybrid_attention=False,
        use_GroupNorm=True,
        use_dynamic_conv=True,
        use_biased_attention=False,
        use_augmented_memory=False,
    ).cuda()

    optimizer = optim.Adafactor(
        model.parameters(),
        lr=0.025,
        beta2_decay=-0.8,
        eps=(None, 0.001),
        d=1.0,
        weight_decay=0.0,
        foreach=None,
        maximize=False,
    )

    # -100 marks padded label positions (see the collator) and is ignored.
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)

    # Stepped with the mean evaluation loss inside train_and_evaluate.
    scheduler = lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=10,
        threshold=0.0001,
        threshold_mode='rel',
        cooldown=0,
        min_lr=0,
        eps=1e-08,
    )

    train_and_evaluate(
        model,
        train_loader,
        eval_loader,
        optimizer,
        loss_fn,
        scheduler,
        num_epochs=1,
        device='cuda',
        accumulation_steps=1,
        clear_cache=True,
        log_interval=5,
        eval_interval=10,
        save_interval=1000,
        checkpoint_dir=checkpoint_dir,
        log_dir=log_dir,
    )


In [None]:
class DynamicConvAttention(nn.Module):
    """Grouped 1-D convolution over the sequence axis with a residual connection.

    The input is convolved along its sequence dimension with ``n_head`` channel
    groups, then layer-normalised, passed through dropout and a linear
    projection, and finally added back to the input.

    Args:
        n_state: Size of the last (embedding) dimension of the input.
        n_head: Number of convolution groups; ``nn.Conv1d`` requires
            ``n_state`` to be divisible by it.
        kernel_size: Width of the convolution kernel. Odd and even sizes are
            both supported.
        dropout_rate: Dropout probability applied after normalisation.
    """

    def __init__(self, n_state: int, n_head: int, kernel_size=3, dropout_rate=0.1):
        super().__init__()
        self.n_state = n_state
        self.n_head = n_head
        self.kernel_size = kernel_size
        self.conv = nn.Conv1d(n_state, n_state, kernel_size, padding=kernel_size//2, groups=n_head)
        self.dropout = nn.Dropout(dropout_rate)
        self.norm = nn.LayerNorm(n_state)
        self.out = nn.Linear(n_state, n_state)

    def forward(self, x: Tensor):
        """Apply the block to ``x`` of shape ``(batch_size, seq_len, embed_dim)``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``n_state``.
        """
        batch_size, seq_len, embed_dim = x.size()
        if embed_dim != self.n_state:
            raise ValueError(f"Expected embed_dim of {self.n_state}, but got {embed_dim}")

        # Conv1d expects channels first: (batch_size, embed_dim, seq_len).
        conv_out = self.conv(x.permute(0, 2, 1))
        # With an even kernel, padding=kernel_size//2 produces seq_len + 1
        # frames; trim so the residual addition below lines up. This is a
        # no-op for odd kernels.
        conv_out = conv_out[..., :seq_len]
        conv_out = conv_out.permute(0, 2, 1)  # back to (batch_size, seq_len, embed_dim)
        conv_out = self.norm(conv_out)
        conv_out = self.dropout(conv_out)
        return self.out(conv_out) + x
    
# Smoke test: push a random (batch_size, seq_len, embed_dim) batch through the
# block and print the resulting shape.
x = torch.randn(32, 512, 768)

dynamic_conv_attn = DynamicConvAttention(768, 8, kernel_size=3, dropout_rate=0.1)

output = dynamic_conv_attn(x)
print("Output shape:", output.shape)  # expected to match x: (batch_size, seq_len, embed_dim)


In [None]:

# # Helper function to transfer layers
# def transfer_layer(source_model, target_model, source_layer_name, target_layer_name):
#     source_layer = dict(source_model.named_parameters()).get(source_layer_name)
#     target_layer = dict(target_model.named_parameters()).get(target_layer_name)
#     if source_layer is not None and target_layer is not None:
#         target_layer.data.copy_(source_layer.data)

# # Load pretrained model
# pretrained_model_name = "openai/whisper-small"
# pretrained_model = WhisperForConditionalGeneration.from_pretrained(pretrained_model_name)
# model_state_dict = model.state_dict()

# transfer_layer(pretrained_model, model, 'model.decoder.embed_tokens.weight', 'decoder.token_embedding.weight')

# # Encoder layers
# for i in range(len(model.encoder.blocks)):
#     # Attention layers
#     transfer_layer(pretrained_model, model, f'model.encoder.layers.{i}.self_attn.q_proj.weight', f'encoder.blocks.{i}.attn.query.weight')
#     transfer_layer(pretrained_model, model, f'model.encoder.layers.{i}.self_attn.k_proj.weight', f'encoder.blocks.{i}.attn.key.weight')
#     transfer_layer(pretrained_model, model, f'model.encoder.layers.{i}.self_attn.v_proj.weight', f'encoder.blocks.{i}.attn.value.weight')
#     transfer_layer(pretrained_model, model, f'model.encoder.layers.{i}.self_attn.out_proj.weight', f'encoder.blocks.{i}.attn.out.weight')
    
#     # MLP layers
#     transfer_layer(pretrained_model, model, f'model.encoder.layers.{i}.fc1.weight', f'encoder.blocks.{i}.mlp.0.weight')
#     transfer_layer(pretrained_model, model, f'model.encoder.layers.{i}.fc1.bias', f'encoder.blocks.{i}.mlp.0.bias')
#     transfer_layer(pretrained_model, model, f'model.encoder.layers.{i}.fc2.weight', f'encoder.blocks.{i}.mlp.2.weight')
#     transfer_layer(pretrained_model, model, f'model.encoder.layers.{i}.fc2.bias', f'encoder.blocks.{i}.mlp.2.bias')

# # Decoder layers
# for i in range(len(model.decoder.blocks)):
#     # Self attention
#     transfer_layer(pretrained_model, model, f'model.decoder.layers.{i}.self_attn.q_proj.weight', f'decoder.blocks.{i}.attn.query.weight')
#     transfer_layer(pretrained_model, model, f'model.decoder.layers.{i}.self_attn.k_proj.weight', f'decoder.blocks.{i}.attn.key.weight')
#     transfer_layer(pretrained_model, model, f'model.decoder.layers.{i}.self_attn.v_proj.weight', f'decoder.blocks.{i}.attn.value.weight')
#     transfer_layer(pretrained_model, model, f'model.decoder.layers.{i}.self_attn.out_proj.weight', f'decoder.blocks.{i}.attn.out.weight')
    
#     # Cross attention
#     transfer_layer(pretrained_model, model, f'model.decoder.layers.{i}.cross_attn.q_proj.weight', f'decoder.blocks.{i}.cross_attn.query.weight')
#     transfer_layer(pretrained_model, model, f'model.decoder.layers.{i}.cross_attn.k_proj.weight', f'decoder.blocks.{i}.cross_attn.key.weight')
#     transfer_layer(pretrained_model, model, f'model.decoder.layers.{i}.cross_attn.v_proj.weight', f'decoder.blocks.{i}.cross_attn.value.weight')
#     transfer_layer(pretrained_model, model, f'model.decoder.layers.{i}.cross_attn.out_proj.weight', f'decoder.blocks.{i}.cross_attn.out.weight')
    
#     # MLP layers
#     transfer_layer(pretrained_model, model, f'model.decoder.layers.{i}.fc1.weight', f'decoder.blocks.{i}.mlp.0.weight')
#     transfer_layer(pretrained_model, model, f'model.decoder.layers.{i}.fc1.bias', f'decoder.blocks.{i}.mlp.0.bias')
#     transfer_layer(pretrained_model, model, f'model.decoder.layers.{i}.fc2.weight', f'decoder.blocks.{i}.mlp.2.weight')
#     transfer_layer(pretrained_model, model, f'model.decoder.layers.{i}.fc2.bias', f'decoder.blocks.{i}.mlp.2.bias')

# print("Full layer transfer complete! Ready to test training.")
# model.load_state_dict(model_state_dict)