In [None]:
import sys, os, random, math, time
import math
import logging, warnings, csv, base64, gzip
import json, datetime, numpy as np, torch, torch.nn as nn
import torch.optim as optim, torch.nn.functional as F, torchaudio
import torchaudio.transforms as transforms, torch.utils.checkpoint as checkpoint
import torch.utils.tensorboard as tensorboard, torch.optim.lr_scheduler as lr_scheduler
import transformers, neologdn, evaluate, MeCab, deepl, logging, datasets, tqdm, whisper
import transformers.utils.logging
from datasets import load_from_disk, load_dataset
from contextlib import contextmanager
from dataclasses import dataclass
from torch.utils.data import Subset, IterableDataset
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from tqdm import tqdm
from torch.profiler import profile, ProfilerActivity, record_function
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from torch import amp, Tensor
from torch.optim import Adamax
import logging
from safetensors import safe_open
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.trainer_pt_utils import IterableDatasetShard
from transformers.trainer_utils import is_main_process
from transformers.trainer_pt_utils import find_batch_size, get_parameter_names
from transformers import (
    TrainerState,
    TrainerControl,
    logging,
    Trainer,
    TrainingArguments,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    PretrainedConfig,
    GenerationConfig,
    WhisperFeatureExtractor,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    WhisperTokenizerFast,
    WhisperTokenizer,
    WhisperModel,
    WhisperConfig,
    Adafactor,
    TrainerCallback,
    logging
)
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.amp import GradScaler, autocast
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
# Silence every warning: first through the filter, then by replacing
# warnings.warn itself with a no-op (affects all code running in this kernel).
warnings.filterwarnings(action="ignore")
warnings.warn = lambda *args, **kwargs: None

# NOTE(review): this is set after torch has already been imported above;
# presumably it still takes effect because CUDA is initialised lazily - confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Allow TF32 for CUDA matmuls and cuDNN kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Not referenced by any of the visible cells.
checkpointing_args = {"reentrant": False}

# Probe for the fused attention kernel; MultiHeadAttention.qkv_attention falls
# back to an explicit softmax(QK^T)V when SDPA_AVAILABLE is False.
try:
    from torch.nn.functional import scaled_dot_product_attention
    SDPA_AVAILABLE = True
except (ImportError, RuntimeError, OSError):
    scaled_dot_product_attention = None
    SDPA_AVAILABLE = False

# Reference implementations from the openai-whisper package; they are attached
# to the Whisper class below as detect_language / transcribe / decode.
from whisper.decoding import decode as decode_function
from whisper.decoding import detect_language as detect_language_function
from whisper.transcribe import transcribe as transcribe_function

# Rebinds `checkpoint` from the torch.utils.checkpoint module (imported with
# that alias at the top of this cell) to the checkpoint function.
from torch.utils.checkpoint import checkpoint

# Default compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# MeCab tagger created with the "-Owakati" option.
mecab = MeCab.Tagger("-Owakati")

# Only show errors from the transformers library logger.
transformers.utils.logging.set_verbosity_error()


In [None]:

@dataclass
class ModelDimensions:
    """Size hyper-parameters consumed by ``Whisper.__init__``.

    The ten positional fields are forwarded to ``AudioEncoder`` and
    ``TextDecoder``. The two trailing optional fields are the special-token
    ids that ``Whisper.forward`` reads from its config when building decoder
    inputs from labels; they default to ``None`` so existing positional
    construction keeps working.
    """

    n_mels: int          # input channels of the encoder's first convolution
    n_audio_ctx: int     # length of the encoder positional embedding
    n_audio_state: int   # encoder hidden width
    n_audio_head: int    # attention heads per encoder block
    n_audio_layer: int   # number of encoder blocks
    n_vocab: int         # rows of the decoder token embedding
    n_text_ctx: int      # length of the decoder positional embedding / causal mask
    n_text_state: int    # decoder hidden width
    n_text_head: int     # attention heads per decoder block
    n_text_layer: int    # number of decoder blocks
    # Read by Whisper.forward / shift_tokens_right; a None pad id yields the
    # explicit "pad_token_id has to be defined." error there.
    pad_token_id: Optional[int] = None
    decoder_start_token_id: Optional[int] = None

class LayerNorm(nn.LayerNorm):
    """LayerNorm that normalises in float32 and returns the caller's dtype."""

    def forward(self, x: Tensor) -> Tensor:
        """Normalise ``x`` after casting to float32, then cast the result back."""
        caller_dtype = x.dtype
        normalised = super().forward(x.float())
        return normalised.type(caller_dtype)

class Linear(nn.Linear):
    """Linear layer whose parameters are cast to the input dtype on each call."""

    def forward(self, x: Tensor) -> Tensor:
        """Apply the affine map with weight (and bias, if any) cast to ``x.dtype``."""
        weight = self.weight.to(x.dtype)
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, weight, bias)

class Conv1d(nn.Conv1d):
    """1-D convolution whose parameters are cast to the input dtype on each call."""

    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        """Run the parent convolution with ``weight``/``bias`` cast to ``x.dtype``."""
        cast_weight = weight.to(x.dtype)
        cast_bias = bias.to(x.dtype) if bias is not None else None
        return super()._conv_forward(x, cast_weight, cast_bias)

def sinusoids(length, channels, max_timescale=10000):
    """Build a ``(length, channels)`` sinusoidal positional-embedding table.

    The first ``channels // 2`` columns hold sines and the remaining columns
    hold cosines of the position scaled by geometrically spaced inverse
    timescales. ``channels`` must be even.
    """
    assert channels % 2 == 0
    half = channels // 2
    log_timescale_increment = np.log(max_timescale) / (half - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    positions = torch.arange(length).unsqueeze(1)
    scaled_time = positions * inv_timescales.unsqueeze(0)
    return torch.cat((torch.sin(scaled_time), torch.cos(scaled_time)), dim=1)

@contextmanager
def disable_sdpa():
    """Temporarily force ``MultiHeadAttention`` onto its explicit softmax path.

    Sets the class-level ``use_sdpa`` flag to ``False`` for the duration of
    the ``with`` body and restores the previous value afterwards, even if the
    body raises.
    """
    saved_flag = MultiHeadAttention.use_sdpa
    MultiHeadAttention.use_sdpa = False
    try:
        yield
    finally:
        MultiHeadAttention.use_sdpa = saved_flag

class MultiHeadAttention(nn.Module):
    """Multi-head attention used for both self- and cross-attention.

    When ``xa`` is passed to ``forward`` the keys/values are projected from
    ``xa`` (cross-attention); otherwise they come from ``x`` (self-attention).
    """

    # Class-wide switch read by qkv_attention; disable_sdpa() flips it temporarily.
    use_sdpa = True

    def __init__(self, n_state: int, n_head: int):
        """Create the query/key/value/output projections of width ``n_state``."""
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        """Attend from ``x`` and return ``(projected_output, qk)``.

        ``qk`` holds the detached pre-softmax scores on the explicit path and
        is ``None`` on the fused SDPA path.
        """
        q = self.query(x)

        # Cached tensors are only reused for cross-attention, and only once the
        # key projection of this layer already has an entry in the cache.
        reuse_cached_kv = (
            kv_cache is not None and xa is not None and self.key in kv_cache
        )
        if reuse_cached_kv:
            k, v = kv_cache[self.key], kv_cache[self.value]
        else:
            # Forward hooks, if installed, extend the projections with cached
            # tensors; otherwise this is a plain projection.
            source = x if xa is None else xa
            k = self.key(source)
            v = self.value(source)

        attended, qk = self.qkv_attention(q, k, v, mask)
        return self.out(attended), qk

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Split heads, attend, and merge heads back into the last dimension."""
        _, n_ctx, n_state = q.shape
        # Applied to both q and k, so their product is scaled by head_dim ** -0.5.
        scale = (n_state // self.n_head) ** -0.25

        def split_heads(t: Tensor) -> Tensor:
            # (batch, ctx, state) -> (batch, head, ctx, head_dim)
            return t.view(*t.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        if not (SDPA_AVAILABLE and MultiHeadAttention.use_sdpa):
            qk = (q * scale) @ (k * scale).transpose(-1, -2)
            if mask is not None:
                qk = qk + mask[:n_ctx, :n_ctx]
            qk = qk.float()

            weights = F.softmax(qk, dim=-1).to(q.dtype)
            merged = (weights @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
            return merged, qk.detach()

        # Fused kernel: the mask tensor itself is not passed, only whether
        # causal masking applies.
        attended = scaled_dot_product_attention(
            q, k, v, is_causal=mask is not None and n_ctx > 1
        )
        return attended.permute(0, 2, 1, 3).flatten(start_dim=2), None

class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention, optional cross-attention, MLP."""

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        """Build the sub-layers; cross-attention parts are ``None`` unless requested."""
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        if cross_attention:
            self.cross_attn = MultiHeadAttention(n_state, n_head)
            self.cross_attn_ln = LayerNorm(n_state)
        else:
            self.cross_attn = None
            self.cross_attn_ln = None

        hidden_width = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, hidden_width), nn.GELU(), Linear(hidden_width, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        """Apply each sub-layer as a residual update of ``x`` and return the result."""
        self_attended, _ = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
        x = x + self_attended
        if self.cross_attn:
            # The mask is not forwarded to the cross-attention call.
            cross_attended, _ = self.cross_attn(
                self.cross_attn_ln(x), xa, kv_cache=kv_cache
            )
            x = x + cross_attended
        return x + self.mlp(self.mlp_ln(x))

class AudioEncoder(nn.Module):
    """Convolutional front-end followed by a stack of self-attention blocks."""

    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        """Create two convolutions, a fixed sinusoidal table and ``n_layer`` blocks."""
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        # Fixed (non-trainable) positional table of shape (n_ctx, n_state).
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        layers = []
        for _ in range(n_layer):
            layers.append(ResidualAttentionBlock(n_state, n_head))
        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(layers)
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor):
        """Encode ``x`` (channels on dim 1, time on dim 2) into per-frame features.

        After the stride-2 convolution the time/feature dims must match the
        positional table exactly, otherwise the assertion fails.
        """
        hidden = F.gelu(self.conv1(x))
        hidden = F.gelu(self.conv2(hidden))
        hidden = hidden.permute(0, 2, 1)

        assert hidden.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        hidden = (hidden + self.positional_embedding).to(hidden.dtype)

        for block in self.blocks:
            hidden = block(hidden)

        return self.ln_post(hidden)

class TextDecoder(nn.Module):
    """Token decoder: embeddings, cross-attending blocks and tied output projection."""

    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        """Create embeddings, ``n_layer`` cross-attention blocks and the causal mask."""
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
        # torch.empty leaves the storage uninitialised, so a model that is
        # trained without first loading a checkpoint would start from arbitrary
        # (possibly NaN/inf) values. Give the table a small, finite starting
        # point; loading a state dict still overwrites it.
        # NOTE(review): std=0.02 is a conventional choice, tune if needed.
        nn.init.normal_(self.positional_embedding, mean=0.0, std=0.02)

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [
                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)

        # Strictly upper-triangular -inf mask; not saved in the state dict.
        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """Return float32 vocabulary logits for token ids ``x`` given features ``xa``.

        Args:
            x: token ids; the last dimension is the sequence length.
            xa: encoder features that the blocks cross-attend to.
            kv_cache: optional cache dict; when non-empty, the length already
                cached offsets the positional embedding.
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = (
            self.token_embedding(x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        # Output projection reuses (ties) the token-embedding matrix.
        logits = (
            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()

        return logits

class Whisper(nn.Module):
    """Encoder-decoder speech model built from ``AudioEncoder`` and ``TextDecoder``.

    ``config`` must expose the ``n_*`` size attributes used in ``__init__``.
    Training through ``forward(labels=...)`` additionally reads
    ``config.pad_token_id`` and ``config.decoder_start_token_id``.
    """

    def __init__(self, config: ModelDimensions):
        super().__init__()
        self.config = config
        self.encoder = AudioEncoder(
            self.config.n_mels,
            self.config.n_audio_ctx,
            self.config.n_audio_state,
            self.config.n_audio_head,
            self.config.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.config.n_vocab,
            self.config.n_text_ctx,
            self.config.n_text_state,
            self.config.n_text_head,
            self.config.n_text_layer,
        )

        # Default alignment heads: every head in the second half of the decoder
        # layers. Stored sparse and excluded from the state dict.
        all_heads = torch.zeros(
            self.config.n_text_layer, self.config.n_text_head, dtype=torch.bool
        )
        all_heads[self.config.n_text_layer // 2 :] = True
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

    def set_alignment_heads(self, dump: bytes):
        """Replace ``alignment_heads`` from a base85 + gzip encoded boolean mask."""
        array = np.frombuffer(
            gzip.decompress(base64.b85decode(dump)), dtype=bool
        ).copy()
        mask = torch.from_numpy(array).reshape(
            self.config.n_text_layer, self.config.n_text_head
        )
        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

    def embed_audio(self, mel: torch.Tensor):
        """Run only the encoder."""
        return self.encoder(mel)

    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        """Run only the decoder on precomputed audio features."""
        return self.decoder(tokens, audio_features)

    @staticmethod
    def shift_tokens_right(input_ids: torch.Tensor, pad_token_id, decoder_start_token_id) -> torch.Tensor:
        """Build decoder inputs from labels.

        Shifts ``input_ids`` one position to the right, puts
        ``decoder_start_token_id`` in column 0 and replaces the ``-100``
        ignore markers with ``pad_token_id``.

        Raises:
            ValueError: if ``pad_token_id`` is ``None``.
        """
        if pad_token_id is None:
            raise ValueError("pad_token_id has to be defined.")
        shifted_input_ids = input_ids.new_zeros(input_ids.shape, dtype=torch.long)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = decoder_start_token_id
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
        return shifted_input_ids

    def forward(
        self,
        input_features: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Encode the audio, decode, and optionally compute the training loss.

        Args:
            input_features: encoder input; cast to float32 before encoding.
            labels: target token ids with ``-100`` at ignored positions. When
                given (and ``decoder_input_ids`` is not), decoder inputs are
                the labels shifted right.
            decoder_input_ids: explicit decoder inputs. When neither this nor
                ``labels`` is given, decoding starts from a single
                ``config.decoder_start_token_id`` column instead of failing on
                a ``None`` input.

        Returns:
            Dict with ``loss`` (``None`` without labels), ``logits`` in their
            per-sequence shape (not flattened, so downstream
            ``argmax(axis=2)`` works) and the encoder output under
            ``input_features``.
        """
        input_features = input_features.float()
        encoded_features = self.encoder(input_features)

        if labels is not None:
            labels = labels.long()
        if decoder_input_ids is None:
            if labels is not None:
                decoder_input_ids = self.shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )
            else:
                decoder_input_ids = torch.full(
                    (input_features.size(0), 1),
                    self.config.decoder_start_token_id,
                    dtype=torch.long,
                    device=encoded_features.device,
                )

        logits = self.decoder(decoder_input_ids, encoded_features)

        loss = None
        if labels is not None:
            # Flatten only for the loss. reshape (not view) because labels may
            # be a non-contiguous slice produced by the data collator.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
            )
        return {"loss": loss, "logits": logits, "input_features": encoded_features}

    @property
    def device(self):
        """Device of the first model parameter."""
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        """True when the vocabulary has at least 51865 entries."""
        return self.config.n_vocab >= 51865

    @property
    def num_languages(self):
        """``n_vocab - 51765 - int(is_multilingual)``."""
        return self.config.n_vocab - 51765 - int(self.is_multilingual)

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """Register forward hooks that cache key/value projections in the decoder.

        Returns:
            ``(cache, hooks)``: the dict the hooks write to (a shallow copy of
            ``cache`` when one is given) and the hook handles, so the caller
            can remove them.
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            # First call for a module, or an output longer than the text
            # context: store as is. Otherwise append along the sequence dim.
            if module not in cache or output.shape[1] > self.config.n_text_ctx:
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    # Reference inference entry points from the openai-whisper package.
    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function

    def set_input_embeddings(self, new_embeddings: torch.nn.Embedding):
        """Replace the decoder token embedding (which also serves as output projection)."""
        self.decoder.token_embedding = new_embeddings

    def get_input_embeddings(self):
        """Return the decoder token embedding."""
        return self.decoder.token_embedding

    def resize_token_embeddings(self, new_num_tokens: int):
        """Resize the token embedding to ``new_num_tokens`` rows.

        Rows shared by the old and new tables are copied, so this works for
        both growing and shrinking. The new table is created on the same
        device and dtype as the old one. ``config.n_vocab`` is updated.
        """
        old_embeddings = self.get_input_embeddings()
        old_num_tokens, embedding_dim = old_embeddings.weight.size()
        new_embeddings = torch.nn.Embedding(new_num_tokens, embedding_dim).to(
            device=old_embeddings.weight.device, dtype=old_embeddings.weight.dtype
        )

        rows_to_copy = min(old_num_tokens, new_num_tokens)
        with torch.no_grad():
            new_embeddings.weight[:rows_to_copy] = old_embeddings.weight[:rows_to_copy]
        self.set_input_embeddings(new_embeddings)
        self.config.n_vocab = new_num_tokens

    def save_pretrained(self, save_directory, safetensor=False):
        """Write the model weights (and the config, when it supports it).

        Args:
            save_directory: target directory; created if missing.
            safetensor: write ``model.safetensors`` when true, otherwise
                ``pytorch_model.bin`` via ``torch.save``.
        """
        os.makedirs(save_directory, exist_ok=True)
        # from_pretrained reads the config from the same directory, so persist
        # it alongside the weights when the config object knows how to.
        if hasattr(self.config, "save_pretrained"):
            self.config.save_pretrained(save_directory)

        state_dict = self.state_dict()
        if safetensor:
            # safe_open is a reader; writing goes through save_file.
            from safetensors.torch import save_file

            save_file(
                {key: value.detach().contiguous() for key, value in state_dict.items()},
                os.path.join(save_directory, "model.safetensors"),
            )
        else:
            torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, safetensor=False, *model_args, **kwargs):
        """Build a model from a directory holding a config and a weight file.

        ``kwargs`` are forwarded to ``WhisperConfig.from_pretrained`` only;
        ``model_args`` are forwarded to the constructor. Weights are loaded
        with ``strict=False``, so missing or unexpected keys are ignored.
        """
        config = WhisperConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        model = cls(config, *model_args)
        if safetensor:
            safetensor_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")
            with safe_open(safetensor_path, framework="pt", device="cpu") as f:
                state_dict = {key: f.get_tensor(key) for key in f.keys()}
        else:
            # NOTE(security): torch.load unpickles the file; only use it on
            # checkpoints from a trusted source.
            state_dict = torch.load(
                os.path.join(pretrained_model_name_or_path, "pytorch_model.bin"),
                map_location="cpu",
            )
        model.load_state_dict(state_dict, strict=False)
        return model

    def get_encoder(self):
        """Return the audio encoder."""
        return self.encoder

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Wrap ``input_ids`` under the ``input_features`` key."""
        return {'input_features': input_ids}

    def _prepare_decoder_input_ids_for_generation(self, batch_size, decoder_start_token_id=None, bos_token_id=None):
        """Return a ``(batch_size, 1)`` column of ``config.decoder_start_token_id``."""
        return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * self.config.decoder_start_token_id

    def can_generate(self):
        """Always ``True``."""
        return True

    def generate(self, inputs, **kwargs):
        """Predict the single most likely first token for each input.

        Only one decoding step is performed; ``kwargs`` are accepted and
        ignored. Decoding starts from ``config.decoder_start_token_id`` and
        falls back to id 0 when the config does not define one.
        """
        encoder_outputs = self.encoder(inputs)
        start_token_id = getattr(self.config, "decoder_start_token_id", None)
        if start_token_id is None:
            start_token_id = 0
        decoder_input_ids = torch.full(
            (inputs.size(0), 1), start_token_id, dtype=torch.long, device=inputs.device
        )
        outputs = self.decoder(decoder_input_ids, encoder_outputs)
        return outputs.argmax(dim=-1)

In [None]:
# Rebinds `datetime` from the module (imported in the first cell) to the class.
from datetime import datetime

# One timestamped log directory per run.
log_dir = os.path.join('./output/logs/', datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
os.makedirs(log_dir, exist_ok=True)

# Feature extractor and tokenizer come from the pretrained checkpoint; the
# model itself is built from scratch below.
old_model = "openai/whisper-small"
feature_extractor = WhisperFeatureExtractor.from_pretrained(old_model)#, sampling_rate=16000, n_fft=1024, hop_length=256, n_mels=80)
tokenizer = WhisperTokenizerFast.from_pretrained(old_model)
proccesor = WhisperProcessor.from_pretrained(old_model, tokenizer=tokenizer, feature_extractor=feature_extractor)

config = WhisperConfig(
    n_mels=80,
    n_audio_ctx=1500,
    n_audio_state=768,
    n_audio_head=12,
    n_audio_layer=12,
    n_vocab=51865,
    n_text_ctx=448,
    n_text_state=768,
    n_text_head=12,
    n_text_layer=12,
    # Take the special-token ids from the tokenizer that produces the labels
    # instead of relying on WhisperConfig's built-in defaults, so that
    # Whisper.shift_tokens_right and the data collator's start-token check
    # agree with the label stream.
    pad_token_id=tokenizer.pad_token_id,
    decoder_start_token_id=tokenizer.convert_tokens_to_ids("<|startoftranscript|>"),
    )

# Use the device selected in the first cell rather than hard-coding CUDA.
model = Whisper(config).to(device)

optimizer = transformers.Adafactor(model.parameters(), 
                                clip_threshold=0.99, 
                                weight_decay=0.025, 
                                scale_parameter=True, 
                                relative_step=False, 
                                warmup_init=False, 
                                lr=2.25e-3)

# NOTE(review): T_max=100 is shorter than the 1000 training steps configured
# later in the notebook - confirm the resulting schedule is intended.
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)
# Not referenced by the visible cells; Whisper.forward computes its own loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)

def get_adamax_optimizer(model, learning_rate=0.001, weight_decay=0.0):
    """Create an ``Adamax`` optimizer over all parameters of ``model``.

    Args:
        model: module whose ``parameters()`` are optimised.
        learning_rate: value passed as ``lr``.
        weight_decay: value passed as ``weight_decay``.
    """
    params = model.parameters()
    optimizer_kwargs = {"lr": learning_rate, "weight_decay": weight_decay}
    return Adamax(params, **optimizer_kwargs)

# Re-import the standard-library logging module: the name `logging` was last
# bound by `from transformers import (..., logging)` in the first cell.
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Streaming pipeline over the local "audiofolder" dataset:
#   1. take the 'train' split and convert it to an iterable dataset (20 shards),
#   2. drop rows whose 'sentence' is empty,
#   3. normalise each sentence with neologdn (repeat=1),
#   4. shuffle with a 1000-element buffer and a fixed seed.
# NOTE(review): data_dir / cache_dir are absolute, machine-specific paths.
dataset = load_dataset("audiofolder", data_dir="D:/proj/datasets/gf_1", cache_dir = "D:/hf")['train'].to_iterable_dataset(num_shards=20).filter(lambda x: len(x['sentence']) > 0).map(lambda x: {"sentence": neologdn.normalize(x['sentence'], repeat=1)}).shuffle(seed=42, buffer_size=1000)

def prepare_dataset(batch):
    """Turn one raw example into model inputs.

    Reads ``batch["audio"]`` (``array`` and ``sampling_rate``) and
    ``batch["sentence"]``, and adds:

    * ``input_features``: feature-extractor output for the audio, with
      NaN / +inf / -inf replaced by 0.0 / 1.0 / -1.0, as nested lists;
    * ``labels``: the tokenizer's ``input_ids`` for the sentence.

    The same ``batch`` mapping is updated and returned.
    """
    audio = batch["audio"]
    input_features = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    batch["input_features"] = torch.nan_to_num(
        torch.tensor(input_features), nan=0.0, posinf=1.0, neginf=-1.0
    ).tolist()

    # Token ids are integers and cannot hold NaN/inf, so they need no
    # nan_to_num round trip through a tensor.
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch

# Apply prepare_dataset lazily and keep only the two columns the collator reads.
dataset = dataset.map(prepare_dataset).select_columns(["input_features", "labels"])
# First 100 streamed examples form the evaluation set; the rest is training data.
test , train = dataset.take(100), dataset.skip(100)

# NOTE(review): neither `metric` nor `wakati` is referenced by the visible
# cells (`cer_metric` is loaded separately below; compute_metrics builds its
# own tagger) - confirm whether they are still needed.
metric = evaluate.load("cer")
wakati = MeCab.Tagger("-Owakati")

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Pad audio features and label ids of a list of examples into one batch.

    Attributes are set through the dataclass constructor: ``proccesor`` must
    expose ``feature_extractor.pad`` and ``tokenizer.pad``;
    ``decoder_start_token_id`` is the id stripped from the front of the labels
    when every sequence in the batch starts with it.
    """

    proccesor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Collate ``features`` into a batch with ``input_features`` and ``labels``."""
        audio_items = []
        token_items = []
        for feature in features:
            audio_items.append({"input_features": feature["input_features"]})
            token_items.append({"input_ids": feature["labels"]})

        batch = self.proccesor.feature_extractor.pad(audio_items, return_tensors="pt")
        labels_batch = self.proccesor.tokenizer.pad(token_items, return_tensors="pt")

        # Padding positions become -100 so the loss ignores them.
        is_padding = labels_batch.attention_mask.ne(1)
        labels = labels_batch["input_ids"].masked_fill(is_padding, -100)

        # Drop the leading start token only when every row has it.
        all_start_with_token = (labels[:, 0] == self.decoder_start_token_id).all().cpu().item()
        if all_start_with_token:
            labels = labels[:, 1:]

        batch["input_features"] = torch.nan_to_num(batch["input_features"], nan=0.0, posinf=1.0, neginf=-1.0)
        batch["labels"] = torch.nan_to_num(labels, nan=0.0, posinf=1.0, neginf=-1.0)

        return batch

# The collator strips a leading token equal to the model config's
# decoder_start_token_id when every label sequence in a batch starts with it.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(proccesor=proccesor, decoder_start_token_id=model.config.decoder_start_token_id)
# Character error rate metric used by compute_cer and compute_metrics.
cer_metric = evaluate.load("cer")


In [None]:
class CustomTensorBoardCallback(TrainerCallback):
    """Mirror the trainer's log entries to a TensorBoard ``SummaryWriter``."""

    def __init__(self, tb_writer):
        """Store the writer that scalars are sent to."""
        self.tb_writer = tb_writer

    @staticmethod
    def _is_scalar(value) -> bool:
        """True for Python/NumPy numbers and single-element tensors."""
        if isinstance(value, (int, float, np.number)):
            return True
        return isinstance(value, torch.Tensor) and value.numel() == 1

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Write every scalar in ``logs`` at the current global step.

        Non-scalar entries (strings, arrays, ...) are skipped instead of being
        handed to ``add_scalar``, which would raise. When ``logs`` carries both
        ``predictions`` and ``label_ids`` a ``cer`` scalar is written as well.
        """
        if logs is None:
            return
        for key, value in logs.items():
            if self._is_scalar(value):
                self.tb_writer.add_scalar(key, value, state.global_step)
        if 'predictions' in logs and 'label_ids' in logs:
            cer = compute_cer(logs['predictions'], logs['label_ids'])
            self.tb_writer.add_scalar("cer", cer, state.global_step)

class SavePeftModelCallback(TrainerCallback):
    """On every save, store the model under ``<checkpoint>/adapter``."""

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Save ``kwargs["model"]`` and delete the checkpoint's ``pytorch_model.bin`` if present."""
        step_dir_name = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        checkpoint_folder = os.path.join(args.output_dir, step_dir_name)

        kwargs["model"].save_pretrained(os.path.join(checkpoint_folder, "adapter"))

        # Remove the full-weights file written next to the adapter folder.
        full_weights_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(full_weights_path):
            os.remove(full_weights_path)
        return control
    
class ShuffleCallback(TrainerCallback):
    """Advance the epoch of an iterable training dataset when an epoch starts."""

    def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
        """Call ``shuffle()`` and ``set_epoch(epoch + 1)`` on plain iterable datasets.

        Datasets wrapped in ``IterableDatasetShard`` are left untouched.
        """
        dataset = train_dataloader.dataset
        if isinstance(dataset, IterableDatasetShard):
            return
        if isinstance(dataset, IterableDataset):
            # NOTE(review): the return value of shuffle() is discarded -
            # confirm that the dataset type in use shuffles in place.
            dataset.shuffle()
            dataset.set_epoch(dataset.epoch + 1)

class GradientClippingCallback(TrainerCallback):
    """Clip the global gradient norm of the model during training.

    Args:
        max_norm: maximum allowed gradient norm (default 1.0, the value that
            was previously hard-coded).
    """

    def __init__(self, max_norm: float = 1.0):
        self.max_norm = max_norm

    def _clip(self, model):
        """Clip ``model``'s gradients in place to ``self.max_norm``."""
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.max_norm)

    def on_pre_optimizer_step(self, args, state, control, **kwargs):
        """Clip right before the optimizer consumes the gradients.

        NOTE(review): this event exists only in recent transformers releases;
        on older versions the method is simply never called.
        """
        self._clip(kwargs["model"])

    def on_step_end(self, args, state, control, **kwargs):
        """Original hook, kept for backward compatibility.

        NOTE(review): the Trainer is expected to fire this after the optimizer
        step, so clipping here cannot influence the update that was just
        applied - confirm against the installed transformers version.
        """
        self._clip(kwargs["model"])

class PredictionLogger(TrainerCallback):
    """Print one decoded prediction/label pair when evaluation metrics are present."""

    def on_log(self, args, state, control, **kwargs):
        """Print the eval loss and the first sample on logging steps.

        Nothing is printed unless ``kwargs["metrics"]`` contains ``eval_loss``;
        in that case ``predictions``, ``label_ids`` and ``kwargs["tokenizer"]``
        are required as well.
        """
        if state.global_step % args.logging_steps != 0:
            return

        eval_loss = kwargs.get('metrics', {}).get('eval_loss', None)
        if eval_loss is None:
            return

        metrics = kwargs['metrics']
        all_predictions = metrics['predictions']
        all_labels = metrics['label_ids']
        tokenizer = kwargs['tokenizer']

        # At most one sample is shown.
        for idx in range(min(1, len(all_predictions))):
            pred_str = tokenizer.decode(all_predictions[idx], skip_special_tokens=True)
            label_str = tokenizer.decode(all_labels[idx], skip_special_tokens=True)
            print(f"Evaluation Loss: {eval_loss:.4f}")
            print(f"Evaluation Sample {idx}: Prediction: {pred_str}, Label: {label_str}")

class ModelFreezingCallback(TrainerCallback):
    """Freeze parameters whose name contains a fixed substring at training start."""

    def on_train_begin(self, args, state, control, **kwargs):
        """Disable gradients for every parameter of ``kwargs['model']`` matching 'layer.0'."""
        model = kwargs['model']
        for name, param in model.named_parameters():
            # Change this substring to match the layers you want to freeze.
            if 'layer.0' not in name:
                continue
            param.requires_grad = False
            print(f"Freezing layer: {name}")


In [None]:

def compute_cer(predictions, label_ids):
    """Return the character error rate, as a percentage, between two id batches.

    Both arguments are decoded with the processor's tokenizer (special tokens
    skipped) before being compared with the global CER metric.
    """
    decode = proccesor.tokenizer.batch_decode
    pred_str = decode(predictions, skip_special_tokens=True)
    label_str = decode(label_ids, skip_special_tokens=True)
    error_rate = cer_metric.compute(predictions=pred_str, references=label_str)
    return 100 * error_rate

def compute_metrics(pred):
    """Compute CER and token-level classification metrics for an evaluation run.

    Args:
        pred: object with ``predictions`` and ``label_ids``. ``predictions``
            may be raw logits (3-D, vocabulary on the last axis), already
            decoded token ids (2-D), or a tuple whose first element is one of
            those.

    Returns:
        Dict with ``cer`` (percentage). When predictions and labels have the
        same shape it also holds ``accuracy``, ``precision``, ``recall`` and
        ``f1`` (weighted averages) computed over non-padding label positions.
    """
    pad_token_id = proccesor.tokenizer.pad_token_id

    predictions = pred.predictions
    if isinstance(predictions, tuple):
        # The model returns more than one tensor besides the loss; the first
        # one holds the logits.
        predictions = predictions[0]
    predictions = np.asarray(predictions)

    # Turn logits into token ids before decoding; decoding logits directly
    # would interpret floats as vocabulary indices.
    pred_ids = predictions.argmax(axis=-1) if predictions.ndim == 3 else predictions
    pred_ids = np.where(pred_ids == -100, pad_token_id, pred_ids)
    # np.where builds a new array, so pred.label_ids is not modified in place.
    label_ids = np.where(pred.label_ids == -100, pad_token_id, pred.label_ids)

    pred_str = proccesor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = proccesor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    metrics = {"cer": 100 * cer_metric.compute(predictions=pred_str, references=label_str)}

    # Token-level scores need a one-to-one alignment of positions.
    if pred_ids.shape == label_ids.shape:
        pred_flat = pred_ids.flatten()
        labels_flat = label_ids.flatten()
        mask = labels_flat != pad_token_id
        y_true, y_pred = labels_flat[mask], pred_flat[mask]

        metrics["accuracy"] = accuracy_score(y_true, y_pred)
        metrics["precision"] = precision_score(y_true, y_pred, average='weighted')
        metrics["recall"] = recall_score(y_true, y_pred, average='weighted')
        metrics["f1"] = f1_score(y_true, y_pred, average='weighted')

    return metrics

# TensorBoard writer for this run's log directory.
tb_writer = SummaryWriter(log_dir=log_dir)
# NOTE(review): the trainer cell constructs its own CustomTensorBoardCallback
# from tb_writer, so `tb_callback` itself is not referenced by the visible cells.
tb_callback = CustomTensorBoardCallback(tb_writer)

In [None]:
# Not referenced below; checkpoints are written to log_dir.
out = "./output"
training_args = Seq2SeqTrainingArguments(
    output_dir=log_dir,
    overwrite_output_dir=True,
    per_device_train_batch_size=1, 
    gradient_accumulation_steps=1,
    learning_rate=2.4e-5,
    warmup_steps=10,
    num_train_epochs=1,
    max_steps=1000,
    tf32=True,
    bf16=True,
    save_steps=1000,
    logging_steps=10,
    # Join with a path separator; plain string concatenation glued
    # "logs_trainer/" directly onto the timestamp of log_dir.
    logging_dir=os.path.join(log_dir, "logs_trainer"),
    logging_strategy="steps",
    report_to=["tensorboard"],
    push_to_hub=False,
    remove_unused_columns=False,
    label_names=["labels"],
    hub_private_repo=True,
    metric_for_best_model="cer",
    predict_with_generate=False,
    greater_is_better=False,
    generation_max_length=128,
    # NOTE(review): an explicit (optimizer, scheduler) pair is passed to the
    # trainer below, so `optim`, `learning_rate`, `warmup_steps` and
    # `weight_decay` here are presumably not what drives the updates - confirm.
    optim="adafactor",
    weight_decay=0.0025,
    disable_tqdm=False,
    save_total_limit=2,  
    # torch_empty_cache_steps=1,
)

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.cuda.empty_cache()
torch.cuda.set_device(0)

# Re-shuffle both streaming splits with a fixed seed.
train = train.shuffle(seed=42)
test = test.shuffle(seed=42)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train,
    eval_dataset=test,
    data_collator=data_collator,
    tokenizer=proccesor.feature_extractor,
    compute_metrics=compute_metrics,
    optimizers=(optimizer, scheduler), # optimizers=(get_adamax_optimizer, None)
    callbacks=[CustomTensorBoardCallback(tb_writer), GradientClippingCallback, PredictionLogger]#, ModelFreezingCallback, SavePeftModelCallback, ShuffleCallback],
)

In [None]:
# trainer.evaluate()
try:
    trainer.train(resume_from_checkpoint=False)
finally:
    # Close the TensorBoard writer even when training raises or is
    # interrupted, so buffered events are flushed to disk.
    tb_writer.close()
from torch.utils.tensorboard import SummaryWriter