In [1]:
#        *
#       /|\
#      /*|O\
#     /*/|\*\
#    /X/O|*\X\
#   /*/X/|\O/*\
#  /X/O/X/|*\X\
# /O/X/*/O/|X/O\
#        | |
#        | |

## Whisper with rotary encoder and learned-sinusoid rotary decoder

# Standard Libraries
import os
import re
import gzip
import csv
import warnings
from contextlib import contextmanager

# Third-Party Libraries
import torch
import MeCab
import librosa
import neologdn
import evaluate
import torchaudio
import soundfile as sf
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from dataclasses import dataclass
from typing import Any, Dict, List, Union, Optional, Tuple, Iterable

# PyTorch Components
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import torch.profiler as profiler
from torch.utils.data import Dataset, DataLoader, Subset

# Datasets and Transformers
from datasets import (
    load_dataset, 
    load_from_disk, 
    concatenate_datasets, 
    Dataset, 
    DatasetDict, 
    Audio, 
    IterableDataset, 
    interleave_datasets
)
from transformers import (
    Trainer, 
    Seq2SeqTrainer, 
    TrainingArguments, 
    Seq2SeqTrainingArguments, 
    WhisperForConditionalGeneration, 
    WhisperModel, 
    WhisperTokenizer, 
    WhisperProcessor, 
    WhisperConfig, 
    WhisperFeatureExtractor, 
    WhisperTokenizerFast, 
    GenerationConfig, 
    pipeline, 
    TrainerControl, 
    TrainerCallback,
    Adafactor, 
    get_cosine_schedule_with_warmup
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.trainer_pt_utils import IterableDatasetShard

# Whisper and Related Libraries
from whisper import load_audio, log_mel_spectrogram, pad_or_trim
from rotary_embedding_torch import RotaryEmbedding
from deepl import translator
from whisper_normalizer.basic import BasicTextNormalizer
from accelerate import Accelerator
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift

# Custom Functions
from decoding import decode as decode_function
from decoding import detect_language as detect_language_function
from transcribe import transcribe as transcribe_function

# Optional Import Handling
try:
    from torch.nn.functional import scaled_dot_product_attention
    SDPA_AVAILABLE = True
except (ImportError, RuntimeError, OSError):
    scaled_dot_product_attention = None
    SDPA_AVAILABLE = False

# Ignore Warnings
warnings.filterwarnings(action="ignore")
warnings.warn = lambda *args, **kwargs: None

# Device and Tensor Type
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float32 if torch.cuda.is_available() else torch.float16



In [2]:

class LayerNorm(nn.Module):  # RMSNorm
    def __init__(self, dim, unit_offset=False):
        super().__init__()
        self.unit_offset = unit_offset
        self.scale = dim ** 0.5
        self.g = nn.Parameter(torch.zeros(dim))
        nn.init.constant_(self.g, 1. - float(unit_offset))

    def forward(self, x):
        gamma = self.g + float(self.unit_offset)
        return F.normalize(x, dim=-1) * self.scale * gamma

class Linear(nn.Linear):
    def forward(self, x):
        return F.linear(x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype))

class Conv1d(nn.Conv1d):
    def _conv_forward(self, x, weight, bias):
        return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))

def sinusofeatures(length, channels, max_timescale=10000):
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

@contextmanager
def disable_sdpa():
    prev_state = MultiHeadAttention.use_sdpa
    try:
        MultiHeadAttention.use_sdpa = False
        yield
    finally:
        MultiHeadAttention.use_sdpa = prev_state

class MultiHeadAttention(nn.Module):
    use_sdpa = True

    def __init__(self, n_state, n_head):
        super().__init__()
        self.n_head = n_head
        self.head_dim = n_state // n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)
        self.rotary_emb = RotaryEmbedding(dim=n_state // n_head)

    def forward(self, x, xa=None, mask=None, kv_cache=None):
        q = self.query(x)
        if kv_cache is None or xa is None or self.key not in kv_cache:
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        q = self.rotary_emb.rotate_queries_or_keys(q)
        k = self.rotary_emb.rotate_queries_or_keys(k)

        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), qk

    def qkv_attention(self, q, k, v, mask=None):
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25

        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
            a = scaled_dot_product_attention(q, k, v, is_causal=mask is not None and n_ctx > 1)
            out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
            qk = None
        else:
            qk = (q * scale) @ (k * scale).transpose(-1, -2)
            if mask is not None:
                qk = qk + mask[:n_ctx, :n_ctx]
            qk = qk.float()

            w = F.softmax(qk, dim=-1).to(q.dtype)
            out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
            qk = qk.detach()
        return out, qk

class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state, n_head, cross_attention=False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attention = cross_attention
        if self.cross_attention:
            self.cross_attn = MultiHeadAttention(n_state, n_head)
            self.cross_attn_ln = LayerNorm(n_state)

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(self, x, xa=None, mask=None, kv_cache=None):
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
        if self.cross_attention:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x
    
class AudioEncoder(nn.Module):
    def __init__(self, n_mels, n_ctx, n_state, n_head, n_layer, activation='gelu'):
        super().__init__()
        self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.blocks = nn.ModuleList([ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)])
        self.ln_post = LayerNorm(n_state)
        self.checkpointing = False
        self.activation = activation

    def forward(self, x):
        if self.checkpointing:
            x = checkpoint(self._activate_checkpointed, x, self.conv1)
            x = checkpoint(self._activate_checkpointed, x, self.conv2)
            for block in self.blocks:
                x = checkpoint(self._block_checkpointed, x, block)
        else:
            x = self._activate(self.conv1(x))
            x = self._activate(self.conv2(x))
            x = x.permute(0, 2, 1)
            for block in self.blocks:
                x = block(x)
        x = self.ln_post(x)
        return x

    def _activate(self, layer_output):
        if self.activation == 'gelu':
            return F.gelu(layer_output)
        elif self.activation == 'relu':
            return F.relu(layer_output)
        else:
            raise ValueError("Unsupported activation function")

    def _activate_checkpointed(self, x, layer):
        output = layer(x)
        return self._activate(output)

    def _block_checkpointed(self, x, block):
        return block(x)

    def gradient_checkpointing_enable(self):
        self.checkpointing = True

    def gradient_checkpointing_disable(self):
        self.checkpointing = False



# class AudioEncoder(nn.Module):
#     def __init__(self, n_mels, n_ctx, n_state, n_head, n_layer):
#         super().__init__()
#         self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
#         self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
#         self.blocks = nn.ModuleList([ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)])
#         self.ln_post = LayerNorm(n_state)
#         self.checkpointing = False

#     def forward(self, x):
    
#         x = F.gelu(self.conv1(x))
#         x = F.gelu(self.conv2(x))
#         x = x.permute(0, 2, 1)

#         for block in self.blocks:
#             x = checkpoint.checkpoint(block, x)
        
#         x = self.ln_post(x)
#         return x

class TextDecoder(nn.Module):
    def __init__(self, n_vocab, n_ctx, n_state, n_head, n_layer):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = LearnedSinusoidalEmbeddings(n_ctx, n_state)
        self.rotary_emb = RotaryEmbedding(dim=n_state // n_head)
        
        self.blocks = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
        )
        self.ln = LayerNorm(n_state)

        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer('mask', mask, persistent=False)

    def forward(self, x, xa, kv_cache=None):
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        positions = torch.arange(offset, offset + x.shape[-1], device=x.device)
        x = self.token_embedding(x) + self.positional_embedding(positions)
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
        return logits

def block_forward(block, x, xa, mask, kv_cache):
    return block(x, xa, mask=mask, kv_cache=kv_cache)
   
class LearnedSinusoidalEmbeddings(nn.Module):
    def __init__(self, n_ctx, n_state):
        super().__init__()
        self.n_ctx = n_ctx
        self.n_state = n_state

        # Initialize with sinusoidal embeddings
        sinusoidal_embeddings = sinusofeatures(n_ctx, n_state)
        self.positional_embeddings = nn.Parameter(sinusoidal_embeddings)

    def forward(self, positions):
        position_embeddings = self.positional_embeddings[positions]
        return position_embeddings

In [3]:
class CustomWhisperConfig(WhisperConfig):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.n_mels = kwargs.get("n_mels", 80)
        self.n_audio_ctx = kwargs.get("n_audio_ctx", 1500)
        self.n_audio_state = kwargs.get("n_audio_state", 1024)
        self.n_audio_head = kwargs.get("n_audio_head", 16)
        self.n_audio_layer = kwargs.get("n_audio_layer", 24)
        self.n_vocab = kwargs.get("n_vocab", 51865)
        self.n_text_ctx = kwargs.get("n_text_ctx", 1500)
        self.n_text_state = kwargs.get("n_text_state", 1024)
        self.n_text_head = kwargs.get("n_text_head", 15)
        self.n_text_layer = kwargs.get("n_text_layer", 24
        )

class CustomWhisperModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config


        self.encoder = AudioEncoder(
            config.n_mels,
            config.n_audio_ctx,
            config.n_audio_state,
            config.n_audio_head,
            config.n_audio_layer
        )
        self.decoder = TextDecoder(
            config.n_vocab,
            config.n_text_ctx,
            config.n_text_state,
            config.n_text_head,
            config.n_text_layer
        )

        all_heads = torch.zeros(config.n_text_layer, config.n_text_head, dtype=torch.bool)
        all_heads[config.n_text_layer // 2:] = True
        self.register_buffer('alignment_heads', all_heads.to_sparse(), persistent=False)

    def forward(self, input_ids=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, labels=None, **kwargs):
        if input_ids is None:
            raise ValueError("Input IDs cannot be None")
        encoder_outputs = self.encoder(input_ids)
        if decoder_input_ids is None:
            raise ValueError("Decoder input IDs cannot be None")
        decoder_outputs = self.decoder(decoder_input_ids, encoder_outputs)
        logits = decoder_outputs.view(-1, decoder_outputs.size(-1))
        loss = None
        if labels is not None:
            criterion = nn.CrossEntropyLoss(ignore_index=-100)
            loss = criterion(logits, labels.view(-1))
        return {"loss": loss, "logits": logits}
    
    # def gradient_checkpointing_enable(self):
    #     self.encoder.gradient_checkpointing_enable()
    #     # Optionally enable checkpointing for other parts of the model if needed

    # def gradient_checkpointing_disable(self):
    #     self.encoder.gradient_checkpointing_disable()
    #     # Optionally disable checkpointing for other parts of the model if needed


    # def gradient_checkpointing_enable(self, **kwargs):
    #     # Enable gradient checkpointing for layers that support it
    #     for layer in self.children():
    #         if hasattr(layer, 'gradient_checkpointing_enable'):
    #             layer.gradient_checkpointing_enable(**kwargs)

    def embed_audio(self, mel):
        return self.encoder(mel)

    def logits(self, tokens, audio_features):
        return self.decoder(tokens, audio_features)

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        return self.config.n_vocab >= 51865

    @property
    def num_languages(self):
        return self.config.n_vocab - 51765 - int(self.is_multilingual)

    def install_kv_cache_hooks(self, cache=None):
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.config.n_text_ctx:
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function


In [4]:
# Initialize the custom configuration
config = CustomWhisperConfig(
    n_mels=80, 
    n_audio_ctx=1500, 
    n_audio_state=1024, 
    n_audio_head=16, 
    n_audio_layer=24, 
    n_vocab=51865, 
    n_text_ctx=448, 
    n_text_state=1024, 
    n_text_head=16, 
    n_text_layer=24
    )

# Initialize the custom model
model = CustomWhisperModel(config).cuda()


In [5]:
pretrained_model = WhisperForConditionalGeneration.from_pretrained('openai/whisper-medium')
pretrained_state_dict = pretrained_model.state_dict()

model_state_dict = model.state_dict()

def transfer_layer(src_name, tgt_name):
    if src_name in pretrained_state_dict and tgt_name in model_state_dict:
        src_tensor = pretrained_state_dict[src_name]
        tgt_tensor = model_state_dict[tgt_name]
        print(f'Transferring layer {src_name} to {tgt_name}')
        print(f'Source shape: {src_tensor.shape}, Target shape: {tgt_tensor.shape}')
        tgt_tensor.copy_(src_tensor)

# Transfer convolutional layers
transfer_layer('model.encoder.conv1.weight', 'encoder.conv1.weight')
transfer_layer('model.encoder.conv1.bias', 'encoder.conv1.bias')
transfer_layer('model.encoder.conv2.weight', 'encoder.conv2.weight')
transfer_layer('model.encoder.conv2.bias', 'encoder.conv2.bias')

# Transfer positional embeddings
# transfer_layer('model.encoder.embed_positions.weight', 'encoder.positional_embedding.weight')
# transfer_layer('model.decoder.embed_positions.weight', 'decoder.positional_embedding.weight')

# Transfer layer norms
transfer_layer('model.encoder.layer_norm.weight', 'encoder.ln_post.weight')
transfer_layer('model.encoder.layer_norm.bias', 'encoder.ln_post.bias')
transfer_layer('model.decoder.layer_norm.weight', 'decoder.ln.weight')
transfer_layer('model.decoder.layer_norm.bias', 'decoder.ln.bias')

# Transfer token embeddings
transfer_layer('model.decoder.embed_tokens.weight', 'decoder.token_embedding.weight') # tokenizer

# Transfer encoder and decoder block layers
for i in range(6):
    transfer_layer(f'model.encoder.layers.{i}.self_attn.k_proj.weight', f'encoder.blocks.{i}.attn.key.weight')
    transfer_layer(f'model.encoder.layers.{i}.self_attn.v_proj.weight', f'encoder.blocks.{i}.attn.value.weight')
    transfer_layer(f'model.encoder.layers.{i}.self_attn.q_proj.weight', f'encoder.blocks.{i}.attn.query.weight')
    transfer_layer(f'model.encoder.layers.{i}.self_attn.out_proj.weight', f'encoder.blocks.{i}.attn.out.weight')
    transfer_layer(f'model.encoder.layers.{i}.self_attn_layer_norm.weight', f'encoder.blocks.{i}.attn_ln.weight')
    transfer_layer(f'model.encoder.layers.{i}.self_attn_layer_norm.bias', f'encoder.blocks.{i}.attn_ln.bias')
    transfer_layer(f'model.encoder.layers.{i}.fc1.weight', f'encoder.blocks.{i}.mlp.0.weight')
    transfer_layer(f'model.encoder.layers.{i}.fc1.bias', f'encoder.blocks.{i}.mlp.0.bias')
    transfer_layer(f'model.encoder.layers.{i}.fc2.weight', f'encoder.blocks.{i}.mlp.2.weight')
    transfer_layer(f'model.encoder.layers.{i}.fc2.bias', f'encoder.blocks.{i}.mlp.2.bias')
    transfer_layer(f'model.encoder.layers.{i}.final_layer_norm.weight', f'encoder.blocks.{i}.mlp_ln.weight')
    transfer_layer(f'model.encoder.layers.{i}.final_layer_norm.bias', f'encoder.blocks.{i}.mlp_ln.bias')
    transfer_layer(f'model.decoder.layers.{i}.self_attn.k_proj.weight', f'decoder.blocks.{i}.attn.key.weight')
    transfer_layer(f'model.decoder.layers.{i}.self_attn.v_proj.weight', f'decoder.blocks.{i}.attn.value.weight')
    transfer_layer(f'model.decoder.layers.{i}.self_attn.q_proj.weight', f'decoder.blocks.{i}.attn.query.weight')
    transfer_layer(f'model.decoder.layers.{i}.self_attn.out_proj.weight', f'decoder.blocks.{i}.attn.out.weight')
    transfer_layer(f'model.decoder.layers.{i}.self_attn_layer_norm.weight', f'decoder.blocks.{i}.attn_ln.weight')
    transfer_layer(f'model.decoder.layers.{i}.self_attn_layer_norm.bias', f'decoder.blocks.{i}.attn_ln.bias')
    transfer_layer(f'model.decoder.layers.{i}.encoder_attn.k_proj.weight', f'decoder.blocks.{i}.cross_attn.key.weight')
    transfer_layer(f'model.decoder.layers.{i}.encoder_attn.v_proj.weight', f'decoder.blocks.{i}.cross_attn.value.weight')
    transfer_layer(f'model.decoder.layers.{i}.encoder_attn.q_proj.weight', f'decoder.blocks.{i}.cross_attn.query.weight')
    transfer_layer(f'model.decoder.layers.{i}.encoder_attn.out_proj.weight', f'decoder.blocks.{i}.cross_attn.out.weight')
    transfer_layer(f'model.decoder.layers.{i}.encoder_attn_layer_norm.weight', f'decoder.blocks.{i}.cross_attn_ln.weight')
    transfer_layer(f'model.decoder.layers.{i}.encoder_attn_layer_norm.bias', f'decoder.blocks.{i}.cross_attn_ln.bias')
    transfer_layer(f'model.decoder.layers.{i}.fc1.weight', f'decoder.blocks.{i}.mlp.0.weight')
    transfer_layer(f'model.decoder.layers.{i}.fc1.bias', f'decoder.blocks.{i}.mlp.0.bias')
    transfer_layer(f'model.decoder.layers.{i}.fc2.weight', f'decoder.blocks.{i}.mlp.2.weight')
    transfer_layer(f'model.decoder.layers.{i}.fc2.bias', f'decoder.blocks.{i}.mlp.2.bias')
    transfer_layer(f'model.decoder.layers.{i}.final_layer_norm.weight', f'decoder.blocks.{i}.mlp_ln.weight')
    transfer_layer(f'model.decoder.layers.{i}.final_layer_norm.bias', f'decoder.blocks.{i}.mlp_ln.bias')

# Load the updated state dict into the custom model
model.load_state_dict(model_state_dict)

# Now you can proceed with training

Transferring layer model.encoder.conv1.weight to encoder.conv1.weight
Source shape: torch.Size([1024, 80, 3]), Target shape: torch.Size([1024, 80, 3])
Transferring layer model.encoder.conv1.bias to encoder.conv1.bias
Source shape: torch.Size([1024]), Target shape: torch.Size([1024])
Transferring layer model.encoder.conv2.weight to encoder.conv2.weight
Source shape: torch.Size([1024, 1024, 3]), Target shape: torch.Size([1024, 1024, 3])
Transferring layer model.encoder.conv2.bias to encoder.conv2.bias
Source shape: torch.Size([1024]), Target shape: torch.Size([1024])
Transferring layer model.decoder.embed_tokens.weight to decoder.token_embedding.weight
Source shape: torch.Size([51865, 1024]), Target shape: torch.Size([51865, 1024])
Transferring layer model.encoder.layers.0.self_attn.k_proj.weight to encoder.blocks.0.attn.key.weight
Source shape: torch.Size([1024, 1024]), Target shape: torch.Size([1024, 1024])
Transferring layer model.encoder.layers.0.self_attn.v_proj.weight to encoder.bl

<All keys matched successfully>

In [None]:
# Define processor and tokenizer
processor = WhisperProcessor.from_pretrained("openai/whisper-base")
Feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base", feature_size=128, do_normalize=True)
tokenizer = WhisperTokenizerFast.from_pretrained("openai/whisper-base", padding="longest")

metric = evaluate.load("cer")

class DataCollatorSpeechSeq2SeqWithPadding:
    def __init__(self, processor, decoder_start_token_id):
        self.processor = processor
        self.decoder_start_token_id = decoder_start_token_id

    def __call__(self, features):
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        batch["input_ids"] = batch["input_features"]
        batch["decoder_input_ids"] = labels_batch["input_ids"]

        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

# Prepare the datasets
ds_a = load_from_disk("D:/proj/datasets/gvjas")["train"].to_iterable_dataset(num_shards=200).filter(lambda sample: bool(sample["sentence"]))
ds_b = load_from_disk("D:/proj/datasets/gvjas")["test"].to_iterable_dataset(num_shards=2).filter(lambda sample: bool(sample["sentence"]))

def map_dataset(map):
    map["input_features"] = processor.feature_extractor(map["audio"]["array"], sampling_rate=map["audio"]["sampling_rate"]).input_features[0]
    map["labels"] = processor.tokenizer(map["sentence"]).input_ids
    return map

train = ds_a.map(map_dataset).select_columns(["input_features", "labels"])
test = ds_b.map(map_dataset).select_columns(["input_features", "labels"])

metric = evaluate.load("cer")
wakati = MeCab.Tagger("-Owakati")

def compute_metrics(pred):
    pred_features = pred.predictions
    label_features = pred.label_features
    label_features[label_features == -100] = tokenizer.pad_token_id
    pred_str = tokenizer.batch_decode(pred_features, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_features, skip_special_tokens=True)
    
    pred_str_nj = [wakati.parse(pred) for pred in pred_str] 
    label_str_nj = [wakati.parse(label) for label in label_str] 
    pred_str_nj = [pred_str_nj[i] for i in range(len(pred_str_nj)) if len(label_str_nj[i]) > 0]
    label_str_nj = [
        label_str_nj[i]
        for i in range(len(label_str_nj))
        if len(label_str_nj[i]) > 0]
    
    pred_str_neo = [neologdn.normalize(pred) for pred in pred_str] 
    label_str_neo = [neologdn.normalize(label) for label in label_str] 
    pred_str_neo = [pred_str_neo[i] for i in range(len(pred_str_neo)) if len(label_str_neo[i]) > 0]
    label_str_neo = [
        label_str_neo[i]
        for i in range(len(label_str_neo))
        if len(label_str_neo[i]) > 0]
    
    cer = 100 * metric.compute(predictions=pred_str, references=label_str) # no normalization
    cer_mecab = 100 * metric.compute(predictions=pred_str_nj, references=label_str_nj) # mecab normalization
    cer_neo = 100 * metric.compute(predictions=pred_str_neo, references=label_str_neo) # 
    return {"cer": cer,  "cer_mecab": cer_mecab, "cer_neo": cer_neo}#, "blue": blue, "accuracy": accuracy} 


Loading dataset from disk:   0%|          | 0/26 [00:00<?, ?it/s]

Loading dataset from disk:   0%|          | 0/26 [00:00<?, ?it/s]

In [None]:
# Initialize the custom model
model = CustomWhisperModel(config, n_mels=n_mels, n_ctx=1500, n_state=1280, n_head=20, n_layer=26, activation='relu').cuda()

training_args = Seq2SeqTrainingArguments(
    output_dir="./out",
    overwrite_output_dir=True,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=1,
    learning_rate=6.2e-4,
    warmup_steps=100,
    num_train_epochs=1,
    max_steps=7500,
    eval_strategy="steps",
    save_steps=50,
    eval_steps=50,
    fp16=True,
    eval_on_start=False,
    logging_steps=5,
    logging_dir=("./logs"),
    logging_strategy="steps",
    logging_first_step=True,
    report_to=["tensorboard"],
    push_to_hub=False,
    remove_unused_columns=False,
    label_names=["labels"],
    hub_private_repo=True,
    metric_for_best_model="cer",
    predict_with_generate=True,
    greater_is_better=False,
    generation_max_length=128,
    optim = "adafactor",
    weight_decay=0.002,

    )


training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    learning_rate=5e-5,
    gradient_accumulation_steps=8,
    gradient_checkpointing=True,  # Enable gradient checkpointing
    num_train_epochs=3,
)

# Enable gradient checkpointing for the model
model.gradient_checkpointing_enable()

# To disable gradient checkpointing later if needed
model.gradient_checkpointing_disable()


In [None]:

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=test,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

trainer.train()
