<a href="https://colab.research.google.com/github/y-chiba1008/talk-support-asr/blob/main/notebooks/training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip -q install evaluate ginza ja-ginza jiwer audiomentations

In [None]:
# Mount Google Drive and expose the synced data/model folders under /content.

from google.colab import drive
import os
import shutil

## Mount Google Drive
drive.mount('/drive')

## Symlinks to create: Drive folder ("from") -> local path ("to")
links = [
    {'from': '/drive/Othercomputers/マイ コンピュータ/data', 'to': '/content/data'},
    {'from': '/drive/Othercomputers/マイ コンピュータ/model', 'to': '/content/model'},
]

for link in links:
    drive_folder_path = link['from']
    colab_link_path = link['to']

    ### Remove whatever already occupies the link path.
    # lexists() (unlike isdir()) also reports broken symlinks and plain files,
    # either of which would otherwise make os.symlink fail below.
    if os.path.lexists(colab_link_path):
        print(f'{colab_link_path}がすでに存在する為、一度削除します')
        if os.path.islink(colab_link_path) or not os.path.isdir(colab_link_path):
            # For a symlink this removes the link itself, never the Drive target.
            os.remove(colab_link_path)
        else:
            # A real local directory.
            shutil.rmtree(colab_link_path)

    ### Create the symlink
    os.symlink(drive_folder_path, colab_link_path)

In [None]:
# Settings shared by the cells below.

# Paths / base model
PREPROCESSED_DATA_PATH: str = 'data/02_all/preprocessed_data'  # read with load_from_disk; must contain train/validation/test splits
BASE_MODEL: str = 'openai/whisper-small'  # used for the model, its generation config and the processor

# Reproducibility (training seed and data-order seed)
SEED: int = 42

# Regularisation
DROPOUT_RATE: float = 0.1

# LoRA
LORA_ENABLE: bool = True  # wrap the model with a LoRA adapter in load_model
LORA_R: int = 32  # adapter rank; lora_alpha is derived from it

## データコレータ

In [None]:
from dataclasses import dataclass
import torch
from transformers import WhisperProcessor

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    '''データコレータ
    '''
    processor: WhisperProcessor

    def __call__(self, features: list[dict[str, list[int] | torch.Tensor]]) -> dict[str, torch.Tensor]:

        # 音響特徴量側をまとめる処理
        # (一応バッチ単位でパディングしているが、すべて30秒分であるはず)
        input_features \
            = [{'input_features': feature['input_features']} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors='pt')

        # トークン化された系列をバッチ単位でパディング
        label_features = [{'input_ids': feature['labels']} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors='pt')

        # attention_maskが0の部分は、トークンを-100に置き換えてロス計算時に無視させる
        # -100を無視するのは、PyTorchの仕様
        labels \
            = labels_batch['input_ids'].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # BOSトークンがある場合は削除
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        # 整形したlabelsをバッチにまとめる
        batch['labels'] = labels

        return batch


## 評価関数

In [None]:
from dataclasses import dataclass
import evaluate
from evaluate import EvaluationModule
import ginza
import spacy
from spacy import Language
from transformers import WhisperProcessor

@dataclass
class ComputeMetrics:
    """Word error rate (WER) metric used as ``compute_metrics`` for the trainer.

    Japanese text has no spaces between words, so both hypotheses and references
    are segmented with GiNZA (split mode ``'C'``) and re-joined with single
    spaces before WER is computed.

    ``@dataclass`` still generates ``__repr__``/``__eq__`` from the annotated
    fields, but the hand-written ``__init__`` below replaces the generated one.
    """
    processor: 'WhisperProcessor'  # decodes token ids back to text
    metric: 'EvaluationModule'  # `evaluate` WER module
    nlp: 'Language'  # GiNZA pipeline used for word segmentation

    def __init__(self, processor: 'WhisperProcessor') -> None:
        """Load the WER metric and the GiNZA pipeline.

        Args:
            processor: Processor whose tokenizer decodes predictions and labels.
        """
        self.processor = processor
        self.metric = evaluate.load('wer')

        nlp = spacy.load('ja_ginza')
        ginza.set_split_mode(nlp, 'C')
        self.nlp = nlp

    def __call__(self, pred) -> dict[str, float]:
        """Decode ``pred`` and return ``{'wer': <WER in percent>}``.

        Args:
            pred: Object with ``predictions`` (generated token ids) and
                ``label_ids`` (reference token ids, -100 where ignored).
        """
        pad_token_id = self.processor.tokenizer.pad_token_id

        # -100 is not a valid token id. Labels are fixed IN PLACE on purpose:
        # callers may decode the very same label_ids array afterwards.
        label_ids = pred.label_ids
        label_ids[label_ids == -100] = pad_token_id

        # Predictions may also contain -100 if they had to be padded to a common
        # length; sanitise a copy so the caller's array is left untouched.
        pred_ids = pred.predictions.copy()
        pred_ids[pred_ids == -100] = pad_token_id

        # Token ids -> strings
        pred_strs = self.processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        label_strs = self.processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

        return {'wer': self.compute_wer(pred_strs, label_strs)}

    def insert_spaces(self, sentences: list[str]) -> list[str]:
        '''Segment each sentence into words and join the words with single spaces.'''
        # nlp.pipe processes the texts in batches; each resulting doc is the same
        # as calling nlp(sentence) individually.
        return [' '.join(token.text for token in doc) for doc in self.nlp.pipe(sentences)]

    def compute_wer(self, pred_strs: list[str], label_strs: list[str]) -> float:
        '''Return the WER (in percent) between predictions and references.'''
        preds_spaced = self.insert_spaces(pred_strs)
        labels_spaced = self.insert_spaces(label_strs)

        return 100 * self.metric.compute(predictions=preds_spaced, references=labels_spaced)

In [None]:
from transformers import GenerationConfig
from transformers import WhisperForConditionalGeneration
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

def load_model(base_model, lora=False):
    """Load a Whisper model prepared for Japanese transcription fine-tuning.

    Args:
        base_model: Model id or path passed to ``from_pretrained``.
        lora: If True, attach a LoRA adapter (rank ``LORA_R``) to the
            attention ``q_proj``/``v_proj`` layers via PEFT.

    Returns:
        ``WhisperForConditionalGeneration``, or its PEFT wrapper when ``lora``
        is True.
    """
    # Config overrides are handed to from_pretrained so they are applied BEFORE
    # the layers are built. Whisper layers copy their dropout rates when they are
    # constructed, so editing model.config after loading would not change them.
    config_overrides = {
        'use_cache': False,
        # Dropout
        'dropout': DROPOUT_RATE,             # general dropout (fully connected layers etc.)
        'attention_dropout': DROPOUT_RATE,   # dropout on attention weights
        'activation_dropout': DROPOUT_RATE,  # dropout after the activation function
        # SpecAugment (mask parts of the input features to reduce overfitting)
        'apply_spec_augment': True,
        'mask_time_prob': 0.05,      # masking probability along the time axis
        'mask_time_length': 10,      # length of each time-mask span
        'mask_feature_prob': 0.05,   # masking probability along the frequency axis
        'mask_feature_length': 10,   # length of each frequency-mask span
    }
    model = WhisperForConditionalGeneration.from_pretrained(base_model, **config_overrides)

    # Generation settings. Configure them on the base model BEFORE any PEFT
    # wrapping: assigning generation_config on the wrapper would only set an
    # attribute on the wrapper, while generation runs on the wrapped model.
    model.generation_config = GenerationConfig.from_pretrained(base_model)
    model.generation_config.language = 'japanese'
    model.generation_config.task = 'transcribe'
    model.generation_config.forced_decoder_ids = None  # cleared just in case
    model.generation_config.use_cache = False

    # LoRA
    if lora:
        # Kept from the original setup (the model loaded above is not quantized).
        model = prepare_model_for_kbit_training(model)
        lora_config = LoraConfig(
            r=LORA_R,                            # rank: higher = more capacity but more overfitting risk (8/16/32 are common)
            lora_alpha=LORA_R * 2,               # scaling factor, typically about 2x the rank
            target_modules=['q_proj', 'v_proj'], # Whisper attention projections to adapt
            lora_dropout=0.1,                    # dropout inside the LoRA layers
            bias='none',
        )
        model = get_peft_model(model, lora_config)

    return model

In [None]:
from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer
from transformers import EarlyStoppingCallback
from transformers import WhisperProcessor

def get_trainer(model, dataset):
    """Build the Seq2SeqTrainer used for fine-tuning.

    Args:
        model: Model returned by ``load_model``.
        dataset: Mapping with ``'train'`` and ``'validation'`` splits of
            pre-processed examples (``input_features`` / ``labels``).

    Returns:
        A configured ``Seq2SeqTrainer`` with early stopping on ``eval_loss``.
    """
    # For a quick smoke test use max_steps=10, warmup_steps=2 and
    # eval/logging/save every 2 steps. The Hugging Face Whisper fine-tuning blog
    # uses max_steps=4000 / warmup_steps=500 (same as below).
    training_args = Seq2SeqTrainingArguments(
        output_dir='./whisper-small-ja',  # output directory
        seed=SEED,
        data_seed=SEED,

        learning_rate=1e-5,
        # Best checkpoint selection ('wer' is an alternative to eval_loss).
        metric_for_best_model='eval_loss',
        greater_is_better=False,
        load_best_model_at_end=True,

        max_steps=4000,
        warmup_steps=500,
        eval_steps=50,
        logging_steps=50,

        # Checkpoints. save_steps matches eval_steps, as load_best_model_at_end needs.
        save_total_limit=2,
        save_steps=50,

        # Against overfitting ('linear' is an alternative scheduler).
        weight_decay=0.01,
        lr_scheduler_type='cosine',

        per_device_train_batch_size=5,
        per_device_eval_batch_size=5,
        group_by_length=True,

        gradient_accumulation_steps=1,
        gradient_checkpointing=False,
        fp16=True,
        predict_with_generate=True,
        generation_max_length=225,
        report_to=['none'],
        eval_strategy='steps',  # when to evaluate (steps, epoch, ...)
        push_to_hub=False,
    )

    # Early stopping: stop after 3 consecutive evaluations (3 x eval_steps =
    # 150 steps) without an eval_loss improvement larger than the threshold.
    early_stopping = EarlyStoppingCallback(
        early_stopping_patience=3,
        early_stopping_threshold=0.001,
    )

    # Processor for Japanese transcription
    processor = WhisperProcessor.from_pretrained(
        BASE_MODEL,
        language='Japanese',
        task='transcribe',
    )

    # Data collator
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor)

    # Metric computation (WER)
    compute_metrics = ComputeMetrics(processor)

    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=dataset['train'],
        eval_dataset=dataset['validation'],
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        processing_class=processor.feature_extractor,
        callbacks=[early_stopping],
    )

    return trainer

In [None]:
import torch
from datasets import load_from_disk

# Training needs a CUDA GPU (the trainer is configured with fp16=True).
# Raise explicitly rather than `assert`, which is skipped under `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError('GPUが利用できません。ランタイムのタイプでGPUを選択してください')

model = load_model(BASE_MODEL, lora=LORA_ENABLE)
dataset = load_from_disk(PREPROCESSED_DATA_PATH)
trainer = get_trainer(model, dataset)
trainer.train()

In [None]:
def show_predict(predict_result, processor):
    """Print each reference next to its prediction, followed by the test WER.

    Args:
        predict_result: Result of ``trainer.predict``: ``predictions`` and
            ``label_ids`` are token-id arrays, ``metrics`` holds ``'test_wer'``.
        processor: Processor whose ``tokenizer`` decodes the token ids.
    """
    import numpy as np

    pad_token_id = processor.tokenizer.pad_token_id

    def _decode(ids):
        # -100 marks ignored/padded positions and is not a valid token id;
        # replace it on a copy instead of relying on earlier in-place fixes.
        ids = np.asarray(ids)
        ids = np.where(ids == -100, pad_token_id, ids)
        return processor.tokenizer.batch_decode(ids, skip_special_tokens=True)

    pred_sentences = _decode(predict_result.predictions)
    label_sentences = _decode(predict_result.label_ids)

    for pre, lab in zip(pred_sentences, label_sentences):
        # Show only the first line; an empty decode has no lines at all.
        lab = next(iter(lab.splitlines()), '')
        pre = next(iter(pre.splitlines()), '')
        print(f'ラベル  : {lab}')
        print(f'推論結果: {pre}')

    print('\n')
    print(f"WER: {predict_result.metrics['test_wer']}")

In [None]:
# Evaluate on the held-out test split with Japanese transcription forced,
# then print reference/prediction pairs and the WER.
predict_result = trainer.predict(
    dataset['test'],
    task='transcribe',
    language='ja',
)
show_predict(predict_result, trainer.data_collator.processor)

In [None]:
from pathlib import Path

def save_model(trainer, output_dir):
    """Save the trained model and its processor into a new directory.

    Args:
        trainer: Trainer whose ``data_collator`` exposes the training ``processor``.
        output_dir: Destination directory. It must not exist yet, so an earlier
            run can never be overwritten.

    Raises:
        FileExistsError: If ``output_dir`` already exists.
    """
    output_dir = Path(output_dir)
    # Explicit raise instead of `assert` so the overwrite guard also holds under `python -O`.
    if output_dir.exists():
        raise FileExistsError(
            f'すでに存在するディレクトリです[{output_dir}]\nデータの上書きを防ぐため、保存先は存在しないディレクトリを指定してください'
        )

    # Model itself. NOTE(review): for a PEFT/LoRA-wrapped model this presumably
    # stores only the adapter weights - confirm before using it for inference.
    trainer.save_model(str(output_dir))

    # Processor (feature extractor and tokenizer)
    trainer.data_collator.processor.save_pretrained(str(output_dir))

# Destination for the fine-tuned model; it must not exist yet (save_model refuses to overwrite).
output_dir = './data/model/20260126_01'

# Uncomment to persist the trained model and processor:
# save_model(trainer, output_dir)