# train.py フロー検証 (nGPT classifier)

train.pyの実行ステップを1つずつ追跡しながらnGPT分類器をテストできる検証ノートブックです。

**リファクタリング後のモジュール構造に対応（修正版）**

## 0. 準備とインポート

In [1]:
%load_ext autoreload
%autoreload 2

import json
import logging
import os
import pathlib
import random
import sys
from functools import partial

import numpy as np
import torch
from transformers import AutoConfig, AutoTokenizer, PrinterCallback

# 環境変数設定
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("WANDB_MODE", "disabled")
os.environ.setdefault("WANDB_DISABLED", "true")

# パス設定
PROJECT_ROOT = pathlib.Path('..').resolve()
UTILS_DIR = PROJECT_ROOT / 'utils'
DATASET_DIR = PROJECT_ROOT / 'dataset'
OUTPUT_ROOT = PROJECT_ROOT / 'outputs'
OUTPUT_ROOT.mkdir(exist_ok=True)

# パスに追加（後方互換性のため）
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(UTILS_DIR))

# リファクタリング後のインポート
from utils.config import ModelArguments, DataTrainingArguments, TrainingArguments
from utils.data import load_raw_datasets, prepare_label_mappings
from utils.training import (
    setup_model_and_config,
    setup_tokenizer,
    prepare_datasets,
    create_trainer,
)

# ロギング設定
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# シード設定
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

print(f"Project root: {PROJECT_ROOT}")
print(f"CUDA available: {torch.cuda.is_available()}")

Project root: /remote/csifs1/disk3/users/yama11235/yama11235/Sentiment-Circle
CUDA available: True


## 1. 引数の設定

train.pyと同じようにModelArguments、DataTrainingArguments、TrainingArgumentsを設定します。

In [2]:
# モデル引数
model_args = ModelArguments(
    model_name_or_path='mixedbread-ai/mxbai-embed-large-v1',
    encoding_type='bi_encoder',
    freeze_encoder=True,
    classifier_configs=str(OUTPUT_ROOT / 'ngpt_classifier_config.json'),
    device_map='cuda:0',
)

# データ引数
data_args = DataTrainingArguments(
    max_seq_length=512,
    max_train_samples=64,
    max_eval_samples=64,
    max_predict_samples=64,
    train_file=[str(DATASET_DIR / 'Train_df.csv')],
    validation_file=[str(DATASET_DIR / 'Valid_df.csv')],
    test_file=[str(DATASET_DIR / 'Test_df.csv')],
)

# トレーニング引数
training_args = TrainingArguments(
    output_dir=str(OUTPUT_ROOT / 'ngpt_test'),
    do_train=True,
    do_eval=True,
    do_predict=True,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=64,
    gradient_accumulation_steps=2,
    learning_rate=2e-5,
    num_train_epochs=3,
    warmup_ratio=0.1,
    logging_steps=10,
    save_steps=50,
    eval_steps=50,
    eval_strategy='steps',
    save_strategy='steps',
    load_best_model_at_end=True,
    metric_for_best_model='eval_loss',
    greater_is_better=False,
    bf16=True if torch.cuda.is_bf16_supported() else False,
    fp16=False if torch.cuda.is_bf16_supported() else True,
    seed=seed,
    report_to=[],
    remove_unused_columns=False,
)

print(f"✓ Model: {model_args.model_name_or_path}")
print(f"✓ Encoding type: {model_args.encoding_type}")
print(f"✓ Max seq length: {data_args.max_seq_length}")
print(f"✓ Train samples: {data_args.max_train_samples}")
print(f"✓ Batch size: {training_args.per_device_train_batch_size}")
print(f"✓ Learning rate: {training_args.learning_rate}")
print(f"✓ Epochs: {training_args.num_train_epochs}")

✓ Model: mixedbread-ai/mxbai-embed-large-v1
✓ Encoding type: bi_encoder
✓ Max seq length: 512
✓ Train samples: 64
✓ Batch size: 4
✓ Learning rate: 2e-05
✓ Epochs: 3


## 2. データセット読み込み

`load_raw_datasets`関数を使用してデータセットを読み込みます。

In [3]:
# データセット読み込み（train.pyと同じ引数）
raw_datasets, sentence3_flag = load_raw_datasets(
    model_args=model_args,
    data_args=data_args,
    training_args=training_args,
    seed=seed,
)

print("\nRaw datasets loaded:")
print(raw_datasets)
print(f"\nsentence3_flag: {sentence3_flag}")
print(f"\nTrain: {len(raw_datasets['train'])} samples")
print(f"Validation: {len(raw_datasets['validation'])} samples")
print(f"Test: {len(raw_datasets['test'])} samples")

# サンプル確認
print("\nSample from train:")
print(raw_datasets['train'][0])

2025-11-17 14:27:23,104 - utils.data.data_loader - INFO: Load train files: ['/remote/csifs1/disk3/users/yama11235/yama11235/Sentiment-Circle/dataset/Train_df.csv']
2025-11-17 14:27:23,105 - utils.data.data_loader - INFO: Load validation files: ['/remote/csifs1/disk3/users/yama11235/yama11235/Sentiment-Circle/dataset/Valid_df.csv']
2025-11-17 14:27:23,106 - utils.data.data_loader - INFO: Load test files: ['/remote/csifs1/disk3/users/yama11235/yama11235/Sentiment-Circle/dataset/Test_df.csv']



Raw datasets loaded:
DatasetDict({
    train: Dataset({
        features: ['sentence1', 'labels'],
        num_rows: 64
    })
    validation: Dataset({
        features: ['sentence1', 'labels'],
        num_rows: 64
    })
    test: Dataset({
        features: ['sentence1', 'labels'],
        num_rows: 64
    })
})

sentence3_flag: False

Train: 64 samples
Validation: 64 samples
Test: 64 samples

Sample from train:
{'sentence1': 'is cold and wished to go back to bed', 'labels': 'relief'}


## 3. ラベルマッピングの準備

`prepare_label_mappings`関数でラベルマッピングと分類器設定を準備します。

In [4]:
# ラベルマッピング準備（train.pyと同じ引数）
(
    raw_datasets,
    labels,
    id2label,
    label2id,
    aspect_key,
    classifier_configs,
    classifier_configs_for_trainer,
    corr_labels,
    corr_weights,
    label_name_mappings,
) = prepare_label_mappings(
    raw_datasets=raw_datasets,
    model_args=model_args,
    data_args=data_args,
)

print("\nLabel mappings prepared:")
print(f"Labels: {labels}")
print(f"Aspect keys: {aspect_key}")
print(f"\nClassifier configs:")
for name, config in classifier_configs.items():
    print(f"  {name}: {config.get('type', 'N/A')}")
print(f"\nLabel name mappings:")
print(json.dumps(label_name_mappings, indent=2))


Label mappings prepared:
Labels: ['labels']
Aspect keys: ['sentiment']

Classifier configs:
  sentiment: nGPT

Label name mappings:
{
  "sentiment": {
    "0": "anger",
    "1": "boredom",
    "2": "disgust",
    "3": "excitement",
    "4": "fear",
    "5": "gratitude",
    "6": "joy",
    "7": "optimism",
    "8": "relief",
    "9": "sadness",
    "10": "surprise"
  }
}


## 4. モデルと設定のセットアップ

`setup_model_and_config`関数でモデルとconfigを初期化します。

In [5]:
# モデルとconfig設定（train.pyと同じ引数）
config, model, use_ngpt_riemann = setup_model_and_config(
    model_args=model_args,
    training_args=training_args,
    labels=list(classifier_configs_for_trainer.keys()),
    id2label=id2label,
    label2id=label2id,
    classifier_configs=classifier_configs,
)

print(f"\nModel: {model.__class__.__name__}")
print(f"Config: {config.__class__.__name__}")
print(f"Use nGPT Riemann: {use_ngpt_riemann}")
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

Some weights of the model checkpoint at mixedbread-ai/mxbai-embed-large-v1 were not used when initializing BertModel: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
2025-11-17 14:27:30,475 - utils.model.modeling_encoders - INFO: Detected nGPT-style classifier block(s); applying initial weight normalization.
2025-11-17 14:27:30,516 - utils.training.train_setup - INFO: nGPT-style classifier detected. Enabling pseudo-Riemann weight normalization and nGPT-friendly optimizer settings.



Model: BiEncoderForClassification
Config: BertConfig
Use nGPT Riemann: True

Total parameters: 350,880,768
Trainable parameters: 16,788,480


### モデルのアーキテクチャ

In [40]:
model

BiEncoderForClassification(
  (backbone): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-23): 24 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024

## 5. トークナイザーのセットアップ

`setup_tokenizer`関数でトークナイザーを初期化します。

In [6]:
# トークナイザー設定（train.pyと同じ引数）
tokenizer = setup_tokenizer(model_args)

print(f"\nTokenizer: {tokenizer.__class__.__name__}")
print(f"Vocab size: {len(tokenizer)}")
print(f"Model max length: {tokenizer.model_max_length}")


Tokenizer: BertTokenizerFast
Vocab size: 30522
Model max length: 512


## 6. データセットの前処理

`prepare_datasets`関数でトークナイズと前処理を適用します。

In [7]:
# データセット前処理（train.pyと同じ引数）
train_dataset, eval_dataset, predict_dataset, max_train_samples = prepare_datasets(
    raw_datasets=raw_datasets,
    tokenizer=tokenizer,
    data_args=data_args,
    model_args=model_args,
    training_args=training_args,
    aspect_key=aspect_key,
    sentence3_flag=sentence3_flag,
)

print(f"\nPreprocessed datasets:")
print(f"Train: {len(train_dataset)} samples")
print(f"Eval: {len(eval_dataset)} samples")
print(f"Test: {len(predict_dataset)} samples")
print(f"Max train samples: {max_train_samples}")

# サンプル確認
print("\nPreprocessed sample keys:")
print(list(train_dataset[0].keys()))

Running tokenizer on dataset:   0%|          | 0/64 [00:00<?, ? examples/s]

2025-11-17 14:27:37,190 - utils.training.train_setup - INFO: tokens: [CLS] thank you! forgot to add her youtube. and her second, smaller, channels ( super entertaining rants ) : [SEP]
2025-11-17 14:27:37,191 - utils.training.train_setup - INFO: Sample 14 of the training set: {'input_ids': [101, 4067, 2017, 999, 9471, 2000, 5587, 2014, 7858, 1012, 1998, 2014, 2117, 1010, 3760, 1010, 6833, 1006, 3565, 14036, 2743, 3215, 1007, 1024, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'input_ids_2': None, 'attention_mask_2': None, 'token_type_ids_2': None, 'active_heads': ['sentiment'], 'labels': [5]}.
2025-11-17 14:27:37,192 - utils.training.train_setup - INFO: tokens: [CLS] they all clapped. [SEP]
2025-11-17 14:27:37,193 - utils.training.train_setup - INFO: Sample 3 of the training set: {'input_ids': [101, 2027, 2035, 18310, 1012, 102], 'token_ty


Preprocessed datasets:
Train: 64 samples
Eval: 64 samples
Test: 64 samples
Max train samples: 64

Preprocessed sample keys:
['input_ids', 'token_type_ids', 'attention_mask', 'input_ids_2', 'attention_mask_2', 'token_type_ids_2', 'active_heads', 'labels']


### 推論テスト

In [45]:
test_case = tokenizer("This is a test sentence for inference.", return_tensors="pt")
model(input_ids=test_case['input_ids'].to(model_args.device_map), attention_mask=test_case['attention_mask'].to(model_args.device_map))

{'original_avg': tensor([[ 0.3068, -0.1795, -0.2015,  ...,  0.3230, -0.2399, -0.2997]],
        device='cuda:0'),
 'original_cls': tensor([[ 0.2106, -0.0967, -0.1799,  ...,  0.4371, -0.1756, -0.3231]],
        device='cuda:0'),
 'original_max': tensor([[ 0.6953, -0.0715,  0.1926,  ...,  0.5316,  0.0780,  0.0718]],
        device='cuda:0'),
 'sentiment': tensor([[ 0.0185, -0.0078, -0.0117,  ...,  0.0192, -0.0134, -0.0153]],
        device='cuda:0', grad_fn=<DivBackward0>)}

## 7. トレーナーの作成

`create_trainer`関数でCustomTrainerを初期化します。

In [8]:
# トレーナー作成（train.pyと同じ引数）
id2_head = {i: head for i, head in enumerate(classifier_configs_for_trainer.keys())}

trainer, trainer_state = create_trainer(
    model=model,
    config=config,
    training_args=training_args,
    classifier_configs_for_trainer=classifier_configs_for_trainer,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    corr_labels=corr_labels,
    corr_weights=corr_weights,
    label_name_mappings=label_name_mappings,
    use_ngpt_riemann=use_ngpt_riemann,
    id2_head=id2_head,
)

# PrinterCallbackを削除（ノートブック環境用）
trainer.remove_callback(PrinterCallback)

print("\n✓ Trainer initialized successfully")
print(f"Model device: {trainer.model.device}")
print(f"\nHead objectives:")
for head_name, objective in trainer.head_objectives.items():
    print(f"  {head_name}: {objective.__class__.__name__}")


✓ Trainer initialized successfully
Model device: cuda:0

Head objectives:
  sentiment: InfoNCEObjective


  super().__init__(*args, **kwargs)


## 8. 初期評価（ベースライン）

トレーニング前の初期性能を確認します。

In [9]:
print("Running baseline evaluation...")
baseline_metrics = trainer.evaluate(eval_dataset=eval_dataset)

print("\n" + "="*80)
print("BASELINE METRICS")
print("="*80)
for key, value in baseline_metrics.items():
    print(f"{key:40s}: {value}")
print("="*80)

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Running baseline evaluation...


RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/remote/csifs1/disk3/users/yama11235/yama11235/SLBERT/my-project/.venv/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py", line 97, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/remote/csifs1/disk3/users/yama11235/yama11235/SLBERT/my-project/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/remote/csifs1/disk3/users/yama11235/yama11235/SLBERT/my-project/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/remote/csifs1/disk3/users/yama11235/yama11235/Sentiment-Circle/utils/model/modeling_encoders.py", line 222, in forward
    outputs = self._paths[uniform].run_full(batch_inputs, extra_kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/remote/csifs1/disk3/users/yama11235/yama11235/Sentiment-Circle/utils/model/sentence_paths.py", line 45, in run_full
    return self._forward(batch, extra_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/remote/csifs1/disk3/users/yama11235/yama11235/Sentiment-Circle/utils/model/sentence_paths.py", line 68, in _forward
    return self.model.encode(**args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/remote/csifs1/disk3/users/yama11235/yama11235/Sentiment-Circle/utils/model/modeling_encoders.py", line 255, in encode
    outputs = self.backbone(
              ^^^^^^^^^^^^^^
  File "/remote/csifs1/disk3/users/yama11235/yama11235/SLBERT/my-project/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/remote/csifs1/disk3/users/yama11235/yama11235/SLBERT/my-project/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/remote/csifs1/disk3/users/yama11235/yama11235/SLBERT/my-project/.venv/lib/python3.12/site-packages/transformers/models/bert/modeling_bert.py", line 932, in forward
    embedding_output = self.embeddings(
                       ^^^^^^^^^^^^^^^^
  File "/remote/csifs1/disk3/users/yama11235/yama11235/SLBERT/my-project/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/remote/csifs1/disk3/users/yama11235/yama11235/SLBERT/my-project/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/remote/csifs1/disk3/users/yama11235/yama11235/SLBERT/my-project/.venv/lib/python3.12/site-packages/transformers/models/bert/modeling_bert.py", line 179, in forward
    inputs_embeds = self.word_embeddings(input_ids)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/remote/csifs1/disk3/users/yama11235/yama11235/SLBERT/my-project/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/remote/csifs1/disk3/users/yama11235/yama11235/SLBERT/my-project/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/remote/csifs1/disk3/users/yama11235/yama11235/SLBERT/my-project/.venv/lib/python3.12/site-packages/torch/nn/modules/sparse.py", line 190, in forward
    return F.embedding(
           ^^^^^^^^^^^^
  File "/remote/csifs1/disk3/users/yama11235/yama11235/SLBERT/my-project/.venv/lib/python3.12/site-packages/torch/nn/functional.py", line 2551, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument index in method wrapper_CUDA__index_select)


## 9. トレーニング実行

モデルをトレーニングします。

In [None]:
print("Starting training...\n")
train_result = trainer.train()

train_metrics = train_result.metrics
train_metrics["train_samples"] = len(train_dataset)

print("\n" + "="*80)
print("TRAINING COMPLETED")
print("="*80)
for key, value in train_metrics.items():
    print(f"{key:40s}: {value}")
print("="*80)

## 10. テストセットでの最終評価

In [None]:
print("Running test evaluation...")
test_metrics = trainer.evaluate(eval_dataset=predict_dataset, metric_key_prefix="test")

print("\n" + "="*80)
print("TEST METRICS")
print("="*80)
for key, value in test_metrics.items():
    print(f"{key:40s}: {value}")
print("="*80)

## 11. 結果の比較

ベースライン vs テスト結果を比較します。

In [None]:
import pandas as pd

# 比較用データ作成
comparison_data = []

# 主要なメトリクスを抽出
for key in baseline_metrics.keys():
    if key.startswith('eval_') and not any(x in key for x in ['runtime', 'samples_per_second', 'steps_per_second', 'model_preparation']):
        metric_name = key.replace('eval_', '')
        test_key = 'test_' + metric_name
        
        baseline_val = baseline_metrics.get(key, 'N/A')
        test_val = test_metrics.get(test_key, 'N/A')
        
        # 数値なら改善率を計算
        if isinstance(baseline_val, (int, float)) and isinstance(test_val, (int, float)):
            improvement = ((test_val - baseline_val) / baseline_val * 100) if baseline_val != 0 else 0
            comparison_data.append({
                'Metric': metric_name,
                'Baseline': f"{baseline_val:.4f}",
                'Test': f"{test_val:.4f}",
                'Change': f"{improvement:+.2f}%"
            })
        else:
            comparison_data.append({
                'Metric': metric_name,
                'Baseline': str(baseline_val),
                'Test': str(test_val),
                'Change': 'N/A'
            })

df = pd.DataFrame(comparison_data)

print("\n" + "="*100)
print("PERFORMANCE COMPARISON: BASELINE vs TEST")
print("="*100)
print(df.to_string(index=False))
print("="*100)

## 12. モデルの保存（オプション）

In [None]:
# モデル保存
save_path = OUTPUT_ROOT / 'ngpt_test_final'
save_path.mkdir(exist_ok=True)

print(f"Saving model to {save_path}...")
trainer.save_model(str(save_path))
tokenizer.save_pretrained(str(save_path))

print("\n✓ Model saved successfully!")
print(f"Saved to: {save_path}")