ライブラリのインストール

In [None]:
# Install the dependencies (notebook shell commands; most versions are pinned).
# NOTE(review): the last command runs `pip install --upgrade accelerate`, which
# replaces the accelerate==1.4.0 pin from the first command — confirm this is
# intended. `datasets`, `wandb` and `flash-attn` are installed without a pin.
!pip install numpy==1.26.4 torch==2.6.0 transformers==4.49.0 tokenizers==0.21.0 accelerate==1.4.0 trl==0.15.2 peft==0.14.0 sentencepiece==0.2.0 wandb deepspeed==0.16.4 bitsandbytes==0.45.3 scipy==1.15.2
!pip install datasets
!pip install flash-attn --no-build-isolation
!pip install --upgrade accelerate

モデルのパス

In [None]:
# Mount Google Drive at /content/drive so the paths below resolve.
from google.colab import drive
drive.mount('/content/drive')

# Locations of the training data, the base model and the output directory
# (adjust to your own Drive layout).
DATA_FILE_PATH = '/content/drive/MyDrive/program/code_translation/data/sample_data.jsonl'
MODEL_PATH = '/content/drive/MyDrive/program/code_translation/base_model/llm-jp-3-1.8b'
OUTPUT_DIR = '/content/drive/MyDrive/program/code_translation/model/test_model_1'

Weights & Biasesにログイン

In [None]:
# Log this session in to Weights & Biases before any run is started.
import wandb
wandb.login()

In [None]:
train_chat.pyのコード

In [None]:
import logging
from dataclasses import dataclass
from typing import Optional, List

import torch
import wandb
from peft import LoraConfig
from datasets import disable_caching, load_dataset, concatenate_datasets
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    BitsAndBytesConfig,
)
from trl import SFTTrainer

# Turn off the `datasets` library's caching for this session.
disable_caching()

# Module-level logger used by the training code below.
logger = logging.getLogger(__name__)

@dataclass
class SFTTrainingArguments:
    """Model, data, quantization, LoRA and W&B options for a fine-tuning run.

    Raises:
        ValueError: if both ``load_in_8bit`` and ``load_in_4bit`` are set.
    """

    # Model / tokenizer locations and data files.
    model_name_or_path: str
    data_files: List[str]
    eval_data_files: Optional[List[str]] = None
    tokenizer_name_or_path: Optional[str] = None
    use_fast: bool = True
    additional_special_tokens: Optional[List[str]] = None
    max_seq_length: int = 4096
    # Model loading options (the two quantization flags are mutually exclusive).
    load_in_8bit: bool = False
    load_in_4bit: bool = False
    use_flash_attention_2: bool = False
    # LoRA options; `peft_target_modules` is filled from the `peft_target_model`
    # preset in __post_init__ when it is left as None.
    use_peft: bool = False
    peft_target_model: Optional[str] = "llm-jp"
    peft_target_modules: Optional[List[str]] = None
    peft_lora_r: int = 8
    peft_lora_alpha: int = 32
    peft_lora_dropout: float = 0.05
    # Weights & Biases options.
    wandb_project: Optional[str] = None
    wandb_run_name: Optional[str] = None
    wandb_log_steps: Optional[int] = None

    def __post_init__(self):
        """Validate the quantization flags and resolve the LoRA target preset."""
        if self.load_in_8bit and self.load_in_4bit:
            raise ValueError("load_in_8bit and load_in_4bit are mutually exclusive")

        # An explicit module list always wins over the preset name.
        if not self.peft_target_model or self.peft_target_modules is not None:
            return

        llama_modules = [
            "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"
        ]
        presets = {
            "llm-jp": ["c_attn", "c_proj", "c_fc"],
            "llama": llama_modules,
            "llama-all": llama_modules + ["lm_head", "embed_tokens"],
        }
        modules = presets.get(self.peft_target_model)
        if modules is None:
            logger.warning(
                f"peft_target_model '{self.peft_target_model}' is not supported, "
                f"so peft_target_modules is set to None."
            )
            return
        self.peft_target_modules = modules

    def from_pretrained_kwargs(self, training_args):
        """Build the keyword arguments for ``AutoModelForCausalLM.from_pretrained``.

        Args:
            training_args: object exposing a boolean ``bf16`` attribute; it is
                only consulted when no quantization flag is set.

        Returns:
            A dict holding either a ``quantization_config`` (8-bit / 4-bit) or a
            ``torch_dtype`` (bfloat16 when ``training_args.bf16`` else float16),
            plus ``attn_implementation="flash_attention_2"`` when
            ``use_flash_attention_2`` is set.
        """
        if self.load_in_8bit:
            # Passing `load_in_8bit=True` straight to from_pretrained is
            # deprecated; the flag belongs inside a BitsAndBytesConfig.
            kwargs = {"quantization_config": BitsAndBytesConfig(load_in_8bit=True)}
        elif self.load_in_4bit:
            kwargs = {
                "quantization_config": BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.bfloat16,
                )
            }
        elif training_args.bf16:
            kwargs = {"torch_dtype": torch.bfloat16}
        else:
            kwargs = {"torch_dtype": torch.float16}

        # `use_flash_attention_2=` is deprecated in transformers; the supported
        # spelling is `attn_implementation`. When the flag is off the key is
        # omitted so the library picks its default attention backend.
        if self.use_flash_attention_2:
            kwargs["attn_implementation"] = "flash_attention_2"
        return kwargs

def load_datasets(data_files, tokenizer, max_seq_length=2048):
    """Load JSON chat files, tokenize them with the chat template, and merge them.

    Args:
        data_files: iterable of paths, each loaded with the ``json`` builder.
        tokenizer: tokenizer whose ``apply_chat_template`` renders and encodes
            the ``messages`` field of every record.
        max_seq_length: truncation length passed as ``max_length``.

    Returns:
        The concatenation of all tokenized ``train`` splits; each row keeps only
        ``input_ids`` and ``attention_mask`` (plain Python lists).
    """

    def _encode(record):
        # Render one conversation and encode it; the batch dimension added by
        # return_tensors="pt" is squeezed away before converting to lists.
        encoded = tokenizer.apply_chat_template(
            record["messages"],
            tokenize=True,
            add_generation_prompt=False,
            truncation=True,
            max_length=max_seq_length,
            return_tensors="pt",
            return_dict=True,
        )
        return {
            field: encoded[field].squeeze(0).tolist()
            for field in ("input_ids", "attention_mask")
        }

    def _load_one(path):
        # Only the "train" split is used; the raw columns are dropped.
        split = load_dataset("json", data_files=path)["train"]
        return split.map(_encode, remove_columns=split.column_names)

    return concatenate_datasets([_load_one(path) for path in data_files])

def main():
    """Run one supervised fine-tuning job end to end.

    Uses the module-level MODEL_PATH, DATA_FILE_PATH and OUTPUT_DIR. Steps:
    build the training configuration, start a W&B run, load the tokenizer and
    install the custom chat template, tokenize the data, load the model,
    train with ``SFTTrainer`` and save the result with ``trainer.save_model()``.

    The base model is loaded in 4-bit, so a LoRA adapter is attached: the
    transformers Trainer refuses to fine-tune a purely quantized model without
    trainable adapters. With LoRA enabled the saved artifact is presumably the
    adapter weights rather than a full model — verify before deployment.
    """
    # Training arguments
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=1,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=8,
        learning_rate=5e-5,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        bf16=True,
        logging_steps=10,
        report_to=["wandb"],  # send logs to Weights & Biases
    )

    # SFT-specific arguments
    sft_training_args = SFTTrainingArguments(
        model_name_or_path=MODEL_PATH,
        data_files=[DATA_FILE_PATH],
        max_seq_length=4096,
        load_in_4bit=True,  # 4-bit loading to save Colab GPU memory
        use_flash_attention_2=True,  # use Flash Attention when supported
        use_peft=True,  # required: a 4-bit model can only be trained via adapters
        # NOTE(review): assumes the base model uses Llama-style projection
        # names (q_proj, k_proj, ...) — confirm against the model's modules.
        peft_target_model="llama",
        wandb_project="puwaer-sft",
        wandb_run_name="doujinshi-1.8b-instruct-1",
        wandb_log_steps=10,
    )

    # Start the W&B run. `to_dict()` yields a serializable config and masks
    # token values, unlike `vars(training_args)`.
    wandb.init(
        project=sft_training_args.wandb_project,
        name=sft_training_args.wandb_run_name,
        config=training_args.to_dict(),
    )

    try:
        # Load the tokenizer
        tokenizer_name_or_path = sft_training_args.tokenizer_name_or_path or sft_training_args.model_name_or_path
        logger.info(f"Loading tokenizer from {tokenizer_name_or_path}")
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name_or_path,
            use_fast=sft_training_args.use_fast,
            additional_special_tokens=sft_training_args.additional_special_tokens,
            trust_remote_code=True,
        )

        # Install the chat template. A system message is rendered as a fixed
        # instruction line; its own content is not used by the template.
        chat_template = (
            "{{bos_token}}{% for message in messages %}"
            "{% if message['role'] == 'user' %}{{ '\\n\\n### 指示:\\n' + message['content'] }}"
            "{% elif message['role'] == 'system' %}{{ '以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。' }}"
            "{% elif message['role'] == 'assistant' %}{{ '\\n\\n### 応答:\\n' + message['content'] + eos_token }}"
            "{% endif %}"
            "{% if loop.last and add_generation_prompt %}{{ '\\n\\n### 応答:\\n' }}{% endif %}"
            "{% endfor %}"
        )
        tokenizer.chat_template = chat_template
        logger.info("Custom chat template applied to tokenizer")

        # The template references bos_token / eos_token, so give them defaults
        # before any conversation is rendered.
        if tokenizer.bos_token is None:
            tokenizer.bos_token = "<|begin_of_text|>"
            logger.info(f"Set default bos_token: {tokenizer.bos_token}")
        if tokenizer.eos_token is None:
            tokenizer.eos_token = "<|end_of_text|>"
            logger.info(f"Set default eos_token: {tokenizer.eos_token}")

        # Load and tokenize the data
        logger.info("Loading data")
        train_dataset = load_datasets(sft_training_args.data_files, tokenizer, sft_training_args.max_seq_length)

        # Load the model
        logger.info(f"Loading model from {sft_training_args.model_name_or_path}")
        kwargs = sft_training_args.from_pretrained_kwargs(training_args)
        model = AutoModelForCausalLM.from_pretrained(
            sft_training_args.model_name_or_path,
            trust_remote_code=True,
            **kwargs,
        )

        # Set up the trainer
        logger.info("Setting up trainer")
        trainer = SFTTrainer(
            model=model,
            args=training_args,
            processing_class=tokenizer,
            train_dataset=train_dataset,
            eval_dataset=None,  # no evaluation data
            peft_config=_build_lora_config(sft_training_args),
        )

        # Train
        logger.info("Training")
        trainer.train()

        # Save the model
        logger.info("Saving model")
        trainer.save_model()
    except BaseException:
        # Close the run as failed so an interrupted or crashed cell does not
        # leave a dangling W&B run in the notebook session.
        wandb.finish(exit_code=1)
        raise

    # Close the W&B run
    wandb.finish()


def _build_lora_config(sft_training_args):
    """Return a LoraConfig built from the SFT arguments, or None when PEFT is off."""
    if not sft_training_args.use_peft:
        return None
    logger.info("Setting up LoRA")
    return LoraConfig(
        r=sft_training_args.peft_lora_r,
        lora_alpha=sft_training_args.peft_lora_alpha,
        lora_dropout=sft_training_args.peft_lora_dropout,
        target_modules=sft_training_args.peft_target_modules,
        bias="none",
        task_type="CAUSAL_LM",
    )

if __name__ == "__main__":
    # Configure root logging at DEBUG with timestamp, logger name and line
    # number in every record, then start the training run.
    _log_format = "%(asctime)s %(name)s:%(lineno)d: %(levelname)s: %(message)s"
    logging.basicConfig(format=_log_format, level=logging.DEBUG)
    main()