# train.py フロー検証 (nGPT classifier)

`Sentiment-Circle/utils/demo2.ipynb` と同じ方針で、`train.py` の実行ステップを 1 つずつ追跡しながら **nGPT 分類器** をテストできるようにした検証ノートブックです。

## 0. 準備
- `train.py` で定義されているデータセット前処理・トレーナー初期化の関数を直接呼び出し、フローをそのまま再現します。
- 本ノートブックはデバッグ用途のため、Weights & Biases 連携は無効化しています (`WANDB_MODE=disabled`)。
- `Train_df.csv` / `Valid_df.csv` / `Test_df.csv` から少数サンプルを取り、計算負荷を抑えます。

In [1]:
%load_ext autoreload
%autoreload 2

import json
import logging
import os
import pathlib
import random
import sys
from functools import partial

import numpy as np
import torch
from transformers import AutoConfig, AutoTokenizer, PrinterCallback

os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("WANDB_MODE", "disabled")
os.environ.setdefault("WANDB_DISABLED", "true")

PROJECT_ROOT = pathlib.Path('..').resolve()
UTILS_DIR = PROJECT_ROOT / 'utils'
DATASET_DIR = PROJECT_ROOT / 'dataset'
OUTPUT_ROOT = PROJECT_ROOT / 'outputs'
OUTPUT_ROOT.mkdir(exist_ok=True)

sys.path.append(str(UTILS_DIR))

from train import (
    ModelArguments,
    DataTrainingArguments,
    TrainingArguments,
    load_raw_datasets,
    prepare_label_mappings,
)
from dataset_preprocessing import batch_get_preprocessing_function, get_preprocessing_function
from model.modeling_utils import DataCollatorForBiEncoder, get_model
from clf_trainer import CustomTrainer
from progress_logger import LogCallback
from model.nGPT_model import NGPTWeightNormCallback
from metrics import compute_metrics

logging.basicConfig(level=logging.INFO)
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

print(f"Project root: {PROJECT_ROOT}")
print(f"CUDA available: {torch.cuda.is_available()}")

Project root: /remote/csifs1/disk3/users/yama11235/yama11235/Sentiment-Circle
CUDA available: True


## 1. ハイパーパラメータとクラス分類器設定
`train.sh` のデフォルト値 (学習率・エポック数など) を参考にしつつ、デバッグしやすいようにバッチサイズとサンプル数だけ縮小しています。

In [7]:
MODEL_NAME = "mixedbread-ai/mxbai-embed-large-v1"
POOLER_TYPE = "avg"
MAX_SEQ_LENGTH = 512
LEARNING_RATE = 1e-4
TRAIN_BATCH_SIZE = 32   # reduced from train.sh's 128, which puts too much pressure on memory
EVAL_BATCH_SIZE = 64    # reduced from train.sh's 256
NUM_EPOCHS = 1
GRAD_ACCUM = 1
MAX_TRAIN_SAMPLES = 64
MAX_EVAL_SAMPLES = 64
MAX_PRED_SAMPLES = 64
# This debug run has only ceil(64 / (32 * 1)) * 1 epoch = 2 optimizer steps, so
# an interval of 5 would never trigger an intermediate log or evaluation.
# Log and evaluate every step so the training flow is actually observable.
LOGGING_STEPS = 1
EVAL_STEPS = 1

# Classifier-head definition; written to JSON because ModelArguments takes a
# path to the config file rather than the dict itself.
classifier_config = {
    "sentiment": {
        "type": "nGPT",
        "layer": -1,
        "objective": "infoNCE",
        "distance": "cosine",
        "pooler_type": POOLER_TYPE,
        "dropout": 0.1,
        "bias": False,
        "base_scale": 0.03125
    }
}
CLASSIFIER_CONFIG_PATH = OUTPUT_ROOT / "ngpt_classifier_config.json"
with open(CLASSIFIER_CONFIG_PATH, "w") as f:
    json.dump(classifier_config, f, indent=2)

model_args = ModelArguments(
    model_name_or_path=MODEL_NAME,
    pooler_type=POOLER_TYPE,
    encoding_type="bi_encoder",
    freeze_encoder=True,
    classifier_configs=str(CLASSIFIER_CONFIG_PATH),
)

data_args = DataTrainingArguments(
    max_seq_length=MAX_SEQ_LENGTH,
    max_train_samples=MAX_TRAIN_SAMPLES,
    max_eval_samples=MAX_EVAL_SAMPLES,
    max_predict_samples=MAX_PRED_SAMPLES,
    train_file=[str(DATASET_DIR / "Train_df.csv")],
    validation_file=[str(DATASET_DIR / "Valid_df.csv")],
    test_file=[str(DATASET_DIR / "Test_df.csv")],
)

training_args = TrainingArguments(
    output_dir=str(OUTPUT_ROOT / "ngpt_debug_run"),
    overwrite_output_dir=True,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    num_train_epochs=NUM_EPOCHS,
    lr_scheduler_type="constant",
    logging_steps=LOGGING_STEPS,
    eval_steps=EVAL_STEPS,
    eval_strategy="steps",
    save_strategy="no",
    bf16=True,
    do_train=True,
    do_eval=True,
    do_predict=True,
    report_to=["none"],
    wandb_project_name="sentiment_info_nce_ngpt_demo",
    wandb_project="sentiment_circle",
    seed=42,
)
# Keep custom columns (e.g. active_heads) that the model's forward signature
# does not list; otherwise the Trainer would drop them before collation.
training_args.remove_unused_columns = False

print(model_args)
print(data_args)
print(training_args)

ModelArguments(model_name_or_path='mixedbread-ai/mxbai-embed-large-v1', config_name=None, tokenizer_name=None, cache_dir=None, use_fast_tokenizer=True, model_revision='main', use_auth_token=False, use_flash_attention='eager', device_map=None, encoding_type='bi_encoder', pooler_type='avg', freeze_encoder=True, transform=False, classifier_configs='/remote/csifs1/disk3/users/yama11235/yama11235/Sentiment-Circle/outputs/ngpt_classifier_config.json', corr_labels=None, corr_weights=None, aspect_key=None, objective='regression')
DataTrainingArguments(max_seq_length=512, overwrite_cache=False, pad_to_max_length=False, max_train_samples=64, max_eval_samples=64, max_predict_samples=64, train_file=['/remote/csifs1/disk3/users/yama11235/yama11235/Sentiment-Circle/dataset/Train_df.csv'], validation_file=['/remote/csifs1/disk3/users/yama11235/yama11235/Sentiment-Circle/dataset/Valid_df.csv'], test_file=['/remote/csifs1/disk3/users/yama11235/yama11235/Sentiment-Circle/dataset/Test_df.csv'], max_simil

## 2. データセット読み込み
`load_raw_datasets` で `Train/Valid/Test` を読み込み、必要であれば `sentence1` 列へリネームします。

In [8]:
# Load Train/Valid/Test through train.py's own loader; it returns the
# DatasetDict together with a flag telling whether a sentence3 column exists.
loaded_splits = load_raw_datasets(
    model_args=model_args,
    data_args=data_args,
    training_args=training_args,
    seed=training_args.seed,
)
raw_datasets, sentence3_flag = loaded_splits

print(raw_datasets)
print(f"sentence3 flag: {sentence3_flag}")
train_split = raw_datasets["train"]
print(train_split[0])

2025-11-16 12:11:37,196 - train - INFO: Load train files: ['/remote/csifs1/disk3/users/yama11235/yama11235/Sentiment-Circle/dataset/Train_df.csv']
2025-11-16 12:11:37,198 - train - INFO: Load validation files: ['/remote/csifs1/disk3/users/yama11235/yama11235/Sentiment-Circle/dataset/Valid_df.csv']
2025-11-16 12:11:37,198 - train - INFO: Load test files: ['/remote/csifs1/disk3/users/yama11235/yama11235/Sentiment-Circle/dataset/Test_df.csv']


DatasetDict({
    train: Dataset({
        features: ['sentence1', 'labels'],
        num_rows: 64
    })
    validation: Dataset({
        features: ['sentence1', 'labels'],
        num_rows: 64
    })
    test: Dataset({
        features: ['sentence1', 'labels'],
        num_rows: 64
    })
})
sentence3 flag: False
{'sentence1': 'is cold and wished to go back to bed', 'labels': 'relief'}


## 3. ラベルマッピング & クラス分類器辞書
CSV の `labels` 列を `sentiment` に付け替え、`nGPT` 分類器設定を `prepare_label_mappings` に渡します。

In [9]:
# prepare_label_mappings returns a 10-tuple; unpack it in the order train.py uses.
mapping_result = prepare_label_mappings(
    raw_datasets=raw_datasets,
    model_args=model_args,
    data_args=data_args,
)
(
    raw_datasets, labels, id2label, label2id, aspect_key,
    classifier_configs, classifier_configs_for_trainer,
    corr_labels, corr_weights, label_name_mappings,
) = mapping_result

print(f"labels: {labels}")
print(f"aspect_key: {aspect_key}")
print(f"classifier configs: {json.dumps(classifier_configs, indent=2)}")

labels: ['labels']
aspect_key: ['sentiment']
classifier configs: {
  "sentiment": {
    "type": "nGPT",
    "layer": -1,
    "objective": "infoNCE",
    "distance": "cosine",
    "output_dim": 256,
    "dropout": 0.1,
    "bias": false,
    "base_scale": 0.03125
  }
}


## 4. Config / Tokenizer / モデル (nGPT 判定込み)
ここから `train.py` と同様に `AutoConfig` / `AutoTokenizer` をロードしてモデルを構築し、nGPT ブロックの有無 (`use_ngpt_blocks`) を検出します。この結果によって、nGPT 用の重み正規化コールバックを有効にするかどうかが決まります。

In [11]:
# Fall back to the model checkpoint when no dedicated config/tokenizer is given.
config_source = model_args.config_name or model_args.model_name_or_path
config = AutoConfig.from_pretrained(
    config_source,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)

tokenizer_source = model_args.tokenizer_name or model_args.model_name_or_path
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_source,
    use_fast=model_args.use_fast_tokenizer,
)

model_cls = get_model(model_args)
# Copy the model-construction options into the HF config the model class reads.
extra_config_fields = {
    "freeze_encoder": model_args.freeze_encoder,
    "model_name_or_path": model_args.model_name_or_path,
    "pooler_type": model_args.pooler_type,
    "transform": model_args.transform,
    "attn_implementation": model_args.use_flash_attention,
    "device_map": model_args.device_map,
}
config.update(extra_config_fields)

labels_for_heads = list(classifier_configs_for_trainer)
id2_head = dict(enumerate(labels_for_heads))
model = model_cls(model_config=config, classifier_configs=classifier_configs)

# Freeze every backbone parameter so only the classifier heads are trained.
if model_args.freeze_encoder:
    model.backbone.requires_grad_(False)

# Whether the model reports nGPT-style blocks; drives NGPTWeightNormCallback later.
use_ngpt_riemann = bool(getattr(model, "use_ngpt_blocks", False))
print(f"use_ngpt_blocks: {use_ngpt_riemann}")

Some weights of the model checkpoint at mixedbread-ai/mxbai-embed-large-v1 were not used when initializing BertModel: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
2025-11-16 12:17:11,683 - model.modeling_encoders - INFO: Detected nGPT-style classifier block(s); applying initial weight normalization.


use_ngpt_blocks: True


In [12]:
# Display the full module tree of the assembled model (backbone + heads) via the notebook's auto-display.
model

BiEncoderForClassification(
  (backbone): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-23): 24 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024

## 5. トークナイズと特徴量生成
`sentence3_flag` に応じて `batch_get_preprocessing_function` (True の場合) / `get_preprocessing_function` (False の場合) を選び、`DatasetDict.map` で各 split にトークナイザを適用します。

In [13]:
# Honour ``pad_to_max_length``: pad each example to ``max_seq_length`` when it is
# set, otherwise do no padding here and leave it to the data collator.
# ("longest" would not pad to max length; with per-example mapping it pads nothing.)
# NOTE(review): keep this in sync with the equivalent logic in train.py.
padding = "max_length" if data_args.pad_to_max_length else False
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

# The batched preprocessor is used when sentence3_flag is set, the per-example
# one otherwise; both factories take exactly the same arguments.
if sentence3_flag:
    make_preprocess_function, batched = batch_get_preprocessing_function, True
else:
    make_preprocess_function, batched = get_preprocessing_function, False

preprocess_function = make_preprocess_function(
    tokenizer=tokenizer,
    sentence1_key="sentence1",
    sentence2_key="sentence2",
    sentence3_key="sentence3",
    sentence3_flag=sentence3_flag,
    aspect_key=aspect_key,
    padding=padding,
    max_seq_length=max_seq_length,
    model_args=model_args,
    scale=None,
)

# Tokenize every split and drop the raw text/label columns (assumes all splits
# share the train split's column names, as in the loaded CSVs).
processed_datasets = raw_datasets.map(
    preprocess_function,
    batched=batched,
    load_from_cache_file=False,
    desc="Running tokenizer on dataset",
    remove_columns=raw_datasets["train"].column_names,
)
train_dataset = processed_datasets["train"]
eval_dataset = processed_datasets["validation"]
predict_dataset = processed_datasets["test"]
print(train_dataset[0].keys())

Running tokenizer on dataset:   0%|          | 0/64 [00:00<?, ? examples/s]

Running tokenizer on dataset:   0%|          | 0/64 [00:00<?, ? examples/s]

Running tokenizer on dataset:   0%|          | 0/64 [00:00<?, ? examples/s]

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'input_ids_2', 'attention_mask_2', 'token_type_ids_2', 'active_heads', 'labels'])


In [14]:
# Inspect one tokenized training example as produced by preprocess_function.
train_dataset[0]

{'input_ids': [101, 2003, 3147, 1998, 6257, 2000, 2175, 2067, 2000, 2793, 102],
 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 'input_ids_2': None,
 'attention_mask_2': None,
 'token_type_ids_2': None,
 'active_heads': ['sentiment'],
 'labels': [8]}

## 6. DataCollator / Trainer 構築
`CustomTrainer` を初期化し、nGPT 用の正規化コールバックやメトリクス関数を登録します。

In [15]:
# dtype handed to the collator and the trainer. ``config.torch_dtype`` may be
# present but set to None; getattr's default only covers a *missing* attribute,
# so fall back to float32 explicitly in both cases.
collator_dtype = getattr(config, "torch_dtype", None) or torch.float32
# NOTE(review): no max_length is passed, so "max_length" padding presumably falls
# back to the tokenizer's model_max_length (512 here) — confirm in
# DataCollatorForBiEncoder.
data_collator = DataCollatorForBiEncoder(
    tokenizer=tokenizer,
    padding="max_length",
    pad_to_multiple_of=None,
    dtype=collator_dtype,
)

# Mutable holder so the closures below can reach the trainer once it exists;
# the trainer itself needs compute_fn at construction time.
trainer_ref = {"trainer": None}


def train_centroid_getter():
    """Return ``trainer.get_train_label_centroids()``, or ``{}`` before the trainer exists."""
    trainer_obj = trainer_ref["trainer"]
    if trainer_obj is None:
        return {}
    return trainer_obj.get_train_label_centroids()


def compute_fn(eval_pred):
    """Metric hook for CustomTrainer that forwards to ``compute_metrics``.

    Uses the "original" embedding mode when the trainer has
    ``use_original_eval_embeddings`` set, otherwise "classifier".
    """
    trainer_obj = trainer_ref["trainer"]
    embedding_mode = "classifier"
    if trainer_obj is not None and getattr(trainer_obj, "use_original_eval_embeddings", False):
        embedding_mode = "original"
    return compute_metrics(
        eval_pred,
        classifier_configs=classifier_configs_for_trainer,
        id2_head=id2_head,
        train_centroid_getter=train_centroid_getter,
        embedding_eval_mode=embedding_mode,
    )


# Weight-norm callback is only active when the model reported nGPT blocks.
ngpt_callback = NGPTWeightNormCallback(enabled=use_ngpt_riemann)
trainer = CustomTrainer(
    model=model,
    args=training_args,
    classifier_configs=classifier_configs_for_trainer,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_fn,
    tokenizer=tokenizer,
    callbacks=[LogCallback, ngpt_callback],
    dtype=collator_dtype,
    corr_labels=corr_labels,
    corr_weights=corr_weights,
    tsne_save_dir=os.path.join(training_args.output_dir, "tsne_plots"),
    tsne_label_mappings=label_name_mappings,
)
trainer_ref["trainer"] = trainer
# Drop HF's default stdout printer; progress output presumably comes from LogCallback.
trainer.remove_callback(PrinterCallback)

  super().__init__(*args, **kwargs)


## 7. 評価→学習→テスト
`train.py` と同様に、初期 `evaluate` → `train` → `test (evaluate on test split)` の順に実行してログを確認します。

In [16]:
# Baseline metrics on the validation split before any training step.
# NOTE(review): the recorded run of this cell failed with "StopIteration: Caught
# StopIteration in replica 1 on device 1", i.e. inside a multi-GPU DataParallel
# replica; running with a single visible GPU should avoid DataParallel — confirm.
baseline_metrics = trainer.evaluate(eval_dataset=eval_dataset)
baseline_metrics

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


StopIteration: Caught StopIteration in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/remote/csifs1/disk3/users/yama11235/yama11235/SLBERT/my-project/.venv/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py", line 97, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/remote/csifs1/disk3/users/yama11235/yama11235/SLBERT/my-project/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/remote/csifs1/disk3/users/yama11235/yama11235/SLBERT/my-project/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/remote/csifs1/disk3/users/yama11235/yama11235/Sentiment-Circle/utils/model/modeling_encoders.py", line 270, in forward
    "token_type_ids_2": token_type_ids_2,
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
StopIteration


In [None]:
# Run training, then attach the number of (possibly truncated) training examples
# to the returned metrics dict and display it.
train_result = trainer.train()
train_metrics = train_result.metrics
train_metrics.update(train_samples=len(train_dataset))
train_metrics

In [None]:
# Evaluate the trained model on the test split (predict_dataset) with the same trainer/metric hook.
test_metrics = trainer.evaluate(eval_dataset=predict_dataset)
test_metrics