In [1]:
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer
import wandb

This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.
If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=
If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH
For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64



In [2]:
# --- Device Configuration ---
# Select the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
# --- Model and Tokenizer Setup ---
model_name_or_path = "gpt2"
ignore_bias_buffers = False

# Tokenizer for the checkpoint; reuse the EOS token for padding when the
# checkpoint does not define a pad token of its own.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Two separately loaded copies of the same checkpoint. Later in the notebook
# they are handed to DPOTrainer as its first (`model`) and second
# (`model_ref`) arguments.
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
model_ref = AutoModelForCausalLM.from_pretrained(model_name_or_path)

if ignore_bias_buffers:
    # Record the names of every boolean buffer on the trained model.
    # NOTE(review): presumably consumed by DistributedDataParallel to skip
    # these buffers - confirm before enabling the flag.
    model._ddp_params_and_buffers_to_ignore = [
        buffer_name
        for buffer_name, buffer_tensor in model.named_buffers()
        if buffer_tensor.dtype == torch.bool
    ]



In [4]:
# --- Dataset Loading and Prompt Extraction ---
def create_prompt(sample):
    """Return the new ``prompt`` column for one record.

    Args:
        sample: A mapping for a single dataset row; only its ``"input"``
            entry is read.

    Returns:
        A one-key dict ``{"prompt": <value of sample["input"]>}``.

    Raises:
        KeyError: If ``sample`` has no ``"input"`` entry.
    """
    prompt_text = sample["input"]
    return dict(prompt=prompt_text)


dataset_name = "argilla/distilabel-intel-orca-dpo-pairs"

# The first 90% of the "train" split is used for training and the remaining
# 10% for evaluation; each part gets a "prompt" column via create_prompt.
train_dataset = load_dataset(dataset_name, split="train[:90%]").map(create_prompt)
eval_dataset = load_dataset(dataset_name, split="train[90%:]").map(create_prompt)

Map:   0%|          | 0/11573 [00:00<?, ? examples/s]

Map:   0%|          | 0/1286 [00:00<?, ? examples/s]

In [5]:
# --- Hyperparameters ---
# Optimisation schedule (consumed by TrainingArguments further down).
learning_rate = 5.38e-5
max_steps = 500
per_device_train_batch_size = 8
gradient_accumulation_steps = 4  # with the batch size above: 8 x 4 = 32 examples per optimizer step per device
gradient_checkpointing = True

# Length limits forwarded to DPOTrainer.
max_length = 512
max_prompt_length = 256
max_target_length = 256

# DPO `beta`, forwarded to DPOTrainer.
beta = 0.2

# Where the trainer reports its metrics.
report_to = "wandb"

In [6]:
# --- Weights & Biases Setup ---
# Single source of truth for the optimizer name, so the config logged to W&B
# cannot drift from what TrainingArguments actually receives.
optim_name = "adamw_torch"

# Only request bf16 mixed precision when the current hardware supports it.
# The notebook allows a CPU fallback in its device cell, and a hard-coded
# bf16=True would make TrainingArguments construction fail on such setups.
# is_available() is checked first so is_bf16_supported() is never queried
# without a CUDA device.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

wandb.init(
    project="dpo-training",
    config={
        "model_name": model_name_or_path,
        "dataset": "argilla/distilabel-intel-orca-dpo-pairs",
        "learning_rate": learning_rate,
        "batch_size": per_device_train_batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "max_length": max_length,
        "max_prompt_length": max_prompt_length,
        "max_target_length": max_target_length,
        "max_steps": max_steps,
        "beta": beta,
        "gradient_checkpointing": gradient_checkpointing,
        "optimizer": optim_name,
        "bf16": use_bf16,
    },
)

# --- Training Arguments ---
# Evaluation and checkpointing both happen every 50 steps; training metrics
# are logged every 5 steps (plus the very first step).
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` - confirm against the installed version before upgrading.
training_args = TrainingArguments(
    per_device_train_batch_size=per_device_train_batch_size,
    max_steps=max_steps,
    # Keep every dataset column; the mapped datasets are passed to the
    # trainer with their original columns plus "prompt".
    remove_unused_columns=False,
    gradient_accumulation_steps=gradient_accumulation_steps,
    learning_rate=learning_rate,
    evaluation_strategy="steps",
    logging_first_step=True,
    logging_steps=5,
    eval_steps=50,
    output_dir="./nlp-a5",
    optim=optim_name,
    warmup_steps=50,
    report_to=report_to,
    bf16=use_bf16,
    gradient_checkpointing=gradient_checkpointing,
    save_strategy="steps",
    save_steps=50,
)

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: sila-nmht (sila-nmht-asian-institute-of-technology) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin




In [7]:
# --- DPOTrainer Initialization ---
# Length limits are grouped so they can be read (and tuned) together.
length_limits = dict(
    max_length=max_length,
    max_prompt_length=max_prompt_length,
    max_target_length=max_target_length,
)

# `model` is the network being trained; `model_ref` is passed as the
# trainer's reference model.
dpo_trainer = DPOTrainer(
    model=model,
    ref_model=model_ref,
    args=training_args,
    beta=beta,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    generate_during_eval=True,
    **length_limits,
)

Map:   0%|          | 0/11573 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1105 > 1024). Running this sequence through the model will result in indexing errors


Map:   0%|          | 0/1286 [00:00<?, ? examples/s]

max_steps is given, it will override any value given in num_train_epochs


In [8]:
# --- Training ---
# Keep a handle on the value returned by train() and end the cell with it so
# the notebook still displays it as the cell result.
train_output = dpo_trainer.train()
train_output



  0%|          | 0/500 [00:00<?, ?it/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
Could not estimate the number of tokens of the input, floating-point operations will not be computed


{'loss': 0.6931, 'grad_norm': 86.32217407226562, 'learning_rate': 1.076e-06, 'rewards/chosen': 0.0, 'rewards/rejected': 0.0, 'rewards/accuracies': 0.0, 'rewards/margins': 0.0, 'logps/rejected': -429.841064453125, 'logps/chosen': -380.14141845703125, 'logits/rejected': -101.80658721923828, 'logits/chosen': -101.73731994628906, 'epoch': 0.0}
{'loss': 0.6917, 'grad_norm': 82.48517608642578, 'learning_rate': 5.38e-06, 'rewards/chosen': -0.655906617641449, 'rewards/rejected': -1.1794593334197998, 'rewards/accuracies': 0.515625, 'rewards/margins': 0.5235527753829956, 'logps/rejected': -413.0698547363281, 'logps/chosen': -359.45428466796875, 'logits/rejected': -104.19410705566406, 'logits/chosen': -102.56261444091797, 'epoch': 0.01}
{'loss': 0.4781, 'grad_norm': 51.92380142211914, 'learning_rate': 1.076e-05, 'rewards/chosen': -0.745120644569397, 'rewards/rejected': -2.0767862796783447, 'rewards/accuracies': 0.8125, 'rewards/margins': 1.3316656351089478, 'logps/rejected': -422.7705078125, 'log

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


{'game_log': <wandb.sdk.data_types.table.Table object at 0x000002470D24FD90>, 'epoch': 0.14}


  0%|          | 0/161 [00:00<?, ?it/s]

{'eval_loss': 0.7701175808906555, 'eval_runtime': 118.3686, 'eval_samples_per_second': 10.864, 'eval_steps_per_second': 1.36, 'eval_rewards/chosen': 0.46673712134361267, 'eval_rewards/rejected': -1.3878153562545776, 'eval_rewards/accuracies': 0.7590579390525818, 'eval_rewards/margins': 1.8545525074005127, 'eval_logps/rejected': -406.8403015136719, 'eval_logps/chosen': -339.9384765625, 'eval_logits/rejected': -95.71627044677734, 'eval_logits/chosen': -95.33931732177734, 'epoch': 0.14}


  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 0.558, 'grad_norm': 79.04922485351562, 'learning_rate': 5.320222222222222e-05, 'rewards/chosen': 0.4374000132083893, 'rewards/rejected': -1.672490119934082, 'rewards/accuracies': 0.8062499761581421, 'rewards/margins': 2.1098902225494385, 'logps/rejected': -417.58819580078125, 'logps/chosen': -351.3348388671875, 'logits/rejected': -96.4607162475586, 'logits/chosen': -95.52983856201172, 'epoch': 0.15}
{'loss': 0.8263, 'grad_norm': 47.55983352661133, 'learning_rate': 5.260444444444444e-05, 'rewards/chosen': -1.0790084600448608, 'rewards/rejected': -3.867455244064331, 'rewards/accuracies': 0.762499988079071, 'rewards/margins': 2.7884469032287598, 'logps/rejected': -433.04083251953125, 'logps/chosen': -375.5039978027344, 'logits/rejected': -96.74488830566406, 'logits/chosen': -96.39244079589844, 'epoch': 0.17}
{'loss': 0.6953, 'grad_norm': 54.888671875, 'learning_rate': 5.2006666666666665e-05, 'rewards/chosen': 0.06507492810487747, 'rewards/rejected': -2.1249938011169434, 'rewards/



{'game_log': <wandb.sdk.data_types.table.Table object at 0x000002470BA3E550>, 'epoch': 0.28}


  0%|          | 0/161 [00:00<?, ?it/s]

{'eval_loss': 0.7531054019927979, 'eval_runtime': 117.5481, 'eval_samples_per_second': 10.94, 'eval_steps_per_second': 1.37, 'eval_rewards/chosen': 0.27912962436676025, 'eval_rewards/rejected': -2.1547889709472656, 'eval_rewards/accuracies': 0.7776915431022644, 'eval_rewards/margins': 2.4339182376861572, 'eval_logps/rejected': -410.67523193359375, 'eval_logps/chosen': -340.8764953613281, 'eval_logits/rejected': -85.44556427001953, 'eval_logits/chosen': -85.26910400390625, 'epoch': 0.28}


  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 0.8066, 'grad_norm': 41.73305892944336, 'learning_rate': 4.7224444444444444e-05, 'rewards/chosen': 0.4517017900943756, 'rewards/rejected': -1.5693342685699463, 'rewards/accuracies': 0.7875000238418579, 'rewards/margins': 2.021036148071289, 'logps/rejected': -420.3628845214844, 'logps/chosen': -355.36126708984375, 'logits/rejected': -86.13858795166016, 'logits/chosen': -85.57471466064453, 'epoch': 0.29}
{'loss': 0.7271, 'grad_norm': 67.39576721191406, 'learning_rate': 4.6626666666666665e-05, 'rewards/chosen': 0.5972003936767578, 'rewards/rejected': -1.8452813625335693, 'rewards/accuracies': 0.800000011920929, 'rewards/margins': 2.442481756210327, 'logps/rejected': -399.28125, 'logps/chosen': -358.2326354980469, 'logits/rejected': -85.44587707519531, 'logits/chosen': -84.99433898925781, 'epoch': 0.3}
{'loss': 0.6502, 'grad_norm': 41.41155242919922, 'learning_rate': 4.602888888888889e-05, 'rewards/chosen': 0.0006014645332470536, 'rewards/rejected': -2.6945464611053467, 'rewards/a



{'game_log': <wandb.sdk.data_types.table.Table object at 0x000002474D7F9490>, 'epoch': 0.41}


  0%|          | 0/161 [00:00<?, ?it/s]

{'eval_loss': 0.7164323329925537, 'eval_runtime': 117.5167, 'eval_samples_per_second': 10.943, 'eval_steps_per_second': 1.37, 'eval_rewards/chosen': 0.04009328410029411, 'eval_rewards/rejected': -2.6230392456054688, 'eval_rewards/accuracies': 0.7743271589279175, 'eval_rewards/margins': 2.663132667541504, 'eval_logps/rejected': -413.0164489746094, 'eval_logps/chosen': -342.0716857910156, 'eval_logits/rejected': -77.79004669189453, 'eval_logits/chosen': -78.4781494140625, 'epoch': 0.41}


  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 0.6313, 'grad_norm': 61.368038177490234, 'learning_rate': 4.124666666666667e-05, 'rewards/chosen': -0.0895061120390892, 'rewards/rejected': -3.2197909355163574, 'rewards/accuracies': 0.768750011920929, 'rewards/margins': 3.1302852630615234, 'logps/rejected': -424.0433044433594, 'logps/chosen': -370.56231689453125, 'logits/rejected': -77.89362335205078, 'logits/chosen': -78.12007904052734, 'epoch': 0.43}
{'loss': 0.9444, 'grad_norm': 36.428863525390625, 'learning_rate': 4.064888888888889e-05, 'rewards/chosen': 0.01371699571609497, 'rewards/rejected': -2.3674216270446777, 'rewards/accuracies': 0.7875000238418579, 'rewards/margins': 2.381138324737549, 'logps/rejected': -397.24969482421875, 'logps/chosen': -346.69927978515625, 'logits/rejected': -76.88948822021484, 'logits/chosen': -77.67308044433594, 'epoch': 0.44}
{'loss': 0.6902, 'grad_norm': 41.582950592041016, 'learning_rate': 4.005111111111111e-05, 'rewards/chosen': 0.8930956721305847, 'rewards/rejected': -1.7185869216918945



{'game_log': <wandb.sdk.data_types.table.Table object at 0x000002474D819970>, 'epoch': 0.55}


  0%|          | 0/161 [00:00<?, ?it/s]

{'eval_loss': 0.7222762703895569, 'eval_runtime': 121.0032, 'eval_samples_per_second': 10.628, 'eval_steps_per_second': 1.331, 'eval_rewards/chosen': -0.032697293907403946, 'eval_rewards/rejected': -3.1584994792938232, 'eval_rewards/accuracies': 0.796066164970398, 'eval_rewards/margins': 3.1258020401000977, 'eval_logps/rejected': -415.6938171386719, 'eval_logps/chosen': -342.43560791015625, 'eval_logits/rejected': -73.72225952148438, 'eval_logits/chosen': -74.38448333740234, 'epoch': 0.55}


  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 0.5457, 'grad_norm': 42.19619369506836, 'learning_rate': 3.526888888888889e-05, 'rewards/chosen': 0.3945646286010742, 'rewards/rejected': -3.658681869506836, 'rewards/accuracies': 0.8187500238418579, 'rewards/margins': 4.05324649810791, 'logps/rejected': -466.9610900878906, 'logps/chosen': -357.0932922363281, 'logits/rejected': -72.65667724609375, 'logits/chosen': -73.61896514892578, 'epoch': 0.57}
{'loss': 0.7351, 'grad_norm': 44.63496780395508, 'learning_rate': 3.4671111111111116e-05, 'rewards/chosen': 1.0736498832702637, 'rewards/rejected': -1.8460795879364014, 'rewards/accuracies': 0.7749999761581421, 'rewards/margins': 2.919729232788086, 'logps/rejected': -423.306884765625, 'logps/chosen': -362.35284423828125, 'logits/rejected': -70.54705810546875, 'logits/chosen': -70.61494445800781, 'epoch': 0.58}
{'loss': 0.826, 'grad_norm': 63.51594161987305, 'learning_rate': 3.407333333333333e-05, 'rewards/chosen': 1.1699292659759521, 'rewards/rejected': -1.4635512828826904, 'rewards



{'game_log': <wandb.sdk.data_types.table.Table object at 0x000002474D80AC70>, 'epoch': 0.69}


  0%|          | 0/161 [00:00<?, ?it/s]

{'eval_loss': 0.6676872372627258, 'eval_runtime': 121.6993, 'eval_samples_per_second': 10.567, 'eval_steps_per_second': 1.323, 'eval_rewards/chosen': 0.6186218857765198, 'eval_rewards/rejected': -2.0402002334594727, 'eval_rewards/accuracies': 0.7903726696968079, 'eval_rewards/margins': 2.6588220596313477, 'eval_logps/rejected': -410.102294921875, 'eval_logps/chosen': -339.17901611328125, 'eval_logits/rejected': -66.41825103759766, 'eval_logits/chosen': -67.22669219970703, 'epoch': 0.69}


  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 0.7138, 'grad_norm': 52.590721130371094, 'learning_rate': 2.9291111111111107e-05, 'rewards/chosen': 0.6930635571479797, 'rewards/rejected': -2.0636179447174072, 'rewards/accuracies': 0.8062499761581421, 'rewards/margins': 2.756681442260742, 'logps/rejected': -430.04669189453125, 'logps/chosen': -366.04327392578125, 'logits/rejected': -66.62266540527344, 'logits/chosen': -67.40010833740234, 'epoch': 0.7}
{'loss': 0.4698, 'grad_norm': 29.189342498779297, 'learning_rate': 2.8693333333333332e-05, 'rewards/chosen': 0.6795655488967896, 'rewards/rejected': -2.4594383239746094, 'rewards/accuracies': 0.84375, 'rewards/margins': 3.1390042304992676, 'logps/rejected': -394.29034423828125, 'logps/chosen': -321.86761474609375, 'logits/rejected': -67.69721984863281, 'logits/chosen': -68.16939544677734, 'epoch': 0.72}
{'loss': 1.2202, 'grad_norm': 43.41884994506836, 'learning_rate': 2.8095555555555557e-05, 'rewards/chosen': 0.10626578330993652, 'rewards/rejected': -2.0635104179382324, 'reward



{'game_log': <wandb.sdk.data_types.table.Table object at 0x000002474D7F0910>, 'epoch': 0.83}


  0%|          | 0/161 [00:00<?, ?it/s]

{'eval_loss': 0.6199276447296143, 'eval_runtime': 115.807, 'eval_samples_per_second': 11.105, 'eval_steps_per_second': 1.39, 'eval_rewards/chosen': 0.5863352417945862, 'eval_rewards/rejected': -2.493701934814453, 'eval_rewards/accuracies': 0.8115941286087036, 'eval_rewards/margins': 3.0800371170043945, 'eval_logps/rejected': -412.36981201171875, 'eval_logps/chosen': -339.3404846191406, 'eval_logits/rejected': -66.51512908935547, 'eval_logits/chosen': -67.28250885009766, 'epoch': 0.83}


  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 0.5988, 'grad_norm': 44.31158447265625, 'learning_rate': 2.3313333333333333e-05, 'rewards/chosen': 0.581853449344635, 'rewards/rejected': -1.9521478414535522, 'rewards/accuracies': 0.800000011920929, 'rewards/margins': 2.534001111984253, 'logps/rejected': -392.5242614746094, 'logps/chosen': -335.1625061035156, 'logits/rejected': -66.18138122558594, 'logits/chosen': -66.18134307861328, 'epoch': 0.84}
{'loss': 0.5134, 'grad_norm': 32.24290084838867, 'learning_rate': 2.2715555555555554e-05, 'rewards/chosen': 1.1869428157806396, 'rewards/rejected': -1.7295684814453125, 'rewards/accuracies': 0.8125, 'rewards/margins': 2.9165115356445312, 'logps/rejected': -374.95172119140625, 'logps/chosen': -303.47381591796875, 'logits/rejected': -67.24435424804688, 'logits/chosen': -67.9585952758789, 'epoch': 0.86}
{'loss': 0.4492, 'grad_norm': 41.668556213378906, 'learning_rate': 2.2117777777777776e-05, 'rewards/chosen': 1.2061660289764404, 'rewards/rejected': -2.1104722023010254, 'rewards/accur



{'game_log': <wandb.sdk.data_types.table.Table object at 0x000002474D7EA730>, 'epoch': 0.97}


  0%|          | 0/161 [00:00<?, ?it/s]

{'eval_loss': 0.6213732957839966, 'eval_runtime': 117.5062, 'eval_samples_per_second': 10.944, 'eval_steps_per_second': 1.37, 'eval_rewards/chosen': 1.101764440536499, 'eval_rewards/rejected': -1.4390374422073364, 'eval_rewards/accuracies': 0.7841615080833435, 'eval_rewards/margins': 2.540802001953125, 'eval_logps/rejected': -407.0964660644531, 'eval_logps/chosen': -336.7633056640625, 'eval_logits/rejected': -64.94148254394531, 'eval_logits/chosen': -65.81298828125, 'epoch': 0.97}


  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 0.4819, 'grad_norm': 30.804285049438477, 'learning_rate': 1.7335555555555558e-05, 'rewards/chosen': 1.484919786453247, 'rewards/rejected': -1.1313515901565552, 'rewards/accuracies': 0.8187500238418579, 'rewards/margins': 2.6162712574005127, 'logps/rejected': -415.90447998046875, 'logps/chosen': -360.48321533203125, 'logits/rejected': -62.61079025268555, 'logits/chosen': -63.80034255981445, 'epoch': 0.98}
{'loss': 0.874, 'grad_norm': 38.822994232177734, 'learning_rate': 1.6737777777777776e-05, 'rewards/chosen': 0.8699784278869629, 'rewards/rejected': -1.6794168949127197, 'rewards/accuracies': 0.768750011920929, 'rewards/margins': 2.5493953227996826, 'logps/rejected': -416.244140625, 'logps/chosen': -341.0328674316406, 'logits/rejected': -64.50392150878906, 'logits/chosen': -65.36022186279297, 'epoch': 1.0}
{'loss': 0.9483, 'grad_norm': 9.058503150939941, 'learning_rate': 1.6139999999999998e-05, 'rewards/chosen': 0.9552860260009766, 'rewards/rejected': -2.8314433097839355, 'rewa



{'game_log': <wandb.sdk.data_types.table.Table object at 0x0000024784B3CFA0>, 'epoch': 1.11}


  0%|          | 0/161 [00:00<?, ?it/s]

{'eval_loss': 0.644171953201294, 'eval_runtime': 117.6817, 'eval_samples_per_second': 10.928, 'eval_steps_per_second': 1.368, 'eval_rewards/chosen': 0.4069330096244812, 'eval_rewards/rejected': -2.869373321533203, 'eval_rewards/accuracies': 0.8281574249267578, 'eval_rewards/margins': 3.276306629180908, 'eval_logps/rejected': -414.2481689453125, 'eval_logps/chosen': -340.2374572753906, 'eval_logits/rejected': -64.66106414794922, 'eval_logits/chosen': -65.4554443359375, 'epoch': 1.11}


  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 0.1039, 'grad_norm': 3.904911994934082, 'learning_rate': 1.1357777777777777e-05, 'rewards/chosen': 1.911709189414978, 'rewards/rejected': -4.101197242736816, 'rewards/accuracies': 0.9750000238418579, 'rewards/margins': 6.012906551361084, 'logps/rejected': -411.5382385253906, 'logps/chosen': -335.3758850097656, 'logits/rejected': -64.78614044189453, 'logits/chosen': -65.04131317138672, 'epoch': 1.12}
{'loss': 0.0957, 'grad_norm': 6.655702590942383, 'learning_rate': 1.076e-05, 'rewards/chosen': 1.8591140508651733, 'rewards/rejected': -3.2132465839385986, 'rewards/accuracies': 0.9750000238418579, 'rewards/margins': 5.072360992431641, 'logps/rejected': -390.045166015625, 'logps/chosen': -317.8333740234375, 'logits/rejected': -63.70588302612305, 'logits/chosen': -65.53626251220703, 'epoch': 1.13}
{'loss': 0.1196, 'grad_norm': 10.346579551696777, 'learning_rate': 1.0162222222222222e-05, 'rewards/chosen': 1.9247852563858032, 'rewards/rejected': -3.6962714195251465, 'rewards/accuracie



{'game_log': <wandb.sdk.data_types.table.Table object at 0x000002474D7F9D90>, 'epoch': 1.24}


  0%|          | 0/161 [00:00<?, ?it/s]

{'eval_loss': 0.6730199456214905, 'eval_runtime': 115.8401, 'eval_samples_per_second': 11.102, 'eval_steps_per_second': 1.39, 'eval_rewards/chosen': 1.1133185625076294, 'eval_rewards/rejected': -1.9897171258926392, 'eval_rewards/accuracies': 0.8131469488143921, 'eval_rewards/margins': 3.1030356884002686, 'eval_logps/rejected': -409.849853515625, 'eval_logps/chosen': -336.70556640625, 'eval_logits/rejected': -65.83477783203125, 'eval_logits/chosen': -66.72872161865234, 'epoch': 1.24}


  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 0.1846, 'grad_norm': 46.18772888183594, 'learning_rate': 5.38e-06, 'rewards/chosen': 2.432614326477051, 'rewards/rejected': -3.799328565597534, 'rewards/accuracies': 0.949999988079071, 'rewards/margins': 6.231942653656006, 'logps/rejected': -419.77239990234375, 'logps/chosen': -336.6206970214844, 'logits/rejected': -66.97927856445312, 'logits/chosen': -66.82938385009766, 'epoch': 1.26}
{'loss': 0.1087, 'grad_norm': 8.702348709106445, 'learning_rate': 4.7822222222222226e-06, 'rewards/chosen': 2.0061357021331787, 'rewards/rejected': -3.8021392822265625, 'rewards/accuracies': 0.96875, 'rewards/margins': 5.80827522277832, 'logps/rejected': -417.48211669921875, 'logps/chosen': -344.32440185546875, 'logits/rejected': -66.13191223144531, 'logits/chosen': -66.78289031982422, 'epoch': 1.27}
{'loss': 0.0955, 'grad_norm': 8.919349670410156, 'learning_rate': 4.184444444444444e-06, 'rewards/chosen': 2.3867013454437256, 'rewards/rejected': -3.760941743850708, 'rewards/accuracies': 0.96875, 



{'game_log': <wandb.sdk.data_types.table.Table object at 0x000002474D813F70>, 'epoch': 1.38}


  0%|          | 0/161 [00:00<?, ?it/s]

{'eval_loss': 0.6408655047416687, 'eval_runtime': 112.62, 'eval_samples_per_second': 11.419, 'eval_steps_per_second': 1.43, 'eval_rewards/chosen': 0.9778485894203186, 'eval_rewards/rejected': -2.149104356765747, 'eval_rewards/accuracies': 0.8234990239143372, 'eval_rewards/margins': 3.126952648162842, 'eval_logps/rejected': -410.6468505859375, 'eval_logps/chosen': -337.3829040527344, 'eval_logits/rejected': -66.98162841796875, 'eval_logits/chosen': -67.84807586669922, 'epoch': 1.38}
{'train_runtime': 4061.3209, 'train_samples_per_second': 3.94, 'train_steps_per_second': 0.123, 'train_loss': 0.5710406829714775, 'epoch': 1.38}


TrainOutput(global_step=500, training_loss=0.5710406829714775, metrics={'train_runtime': 4061.3209, 'train_samples_per_second': 3.94, 'train_steps_per_second': 0.123, 'total_flos': 0.0, 'train_loss': 0.5710406829714775, 'epoch': 1.38217000691085})

In [9]:
# --- Saving ---
# Directory that receives the final model written by the trainer.
final_model_dir = "./trained_model_orca_dpo_wandb"
dpo_trainer.save_model(final_model_dir)

In [10]:
# --- Finish wandb run ---
# Close the W&B run that was opened with wandb.init(...) earlier in this notebook.
wandb.finish()

0,1
eval/logits/chosen,▁▃▅▆█████▇
eval/logits/rejected,▁▃▅▆█████▇
eval/logps/chosen,▄▃▁▁▅▅█▄█▇
eval/logps/rejected,█▅▃▁▅▄█▂▆▅
eval/loss,█▇▅▆▃▁▁▂▃▂
eval/rewards/accuracies,▁▃▃▅▄▆▄█▆█
eval/rewards/chosen,▄▃▁▁▅▅█▄█▇
eval/rewards/margins,▁▄▅▇▅▇▄█▇▇
eval/rewards/rejected,█▅▃▁▅▄█▂▆▅
eval/runtime,▅▅▅▇█▃▅▅▃▁

0,1
eval/logits/chosen,-67.84808
eval/logits/rejected,-66.98163
eval/logps/chosen,-337.3829
eval/logps/rejected,-410.64685
eval/loss,0.64087
eval/rewards/accuracies,0.8235
eval/rewards/chosen,0.97785
eval/rewards/margins,3.12695
eval/rewards/rejected,-2.1491
eval/runtime,112.62


In [None]:
# --- Push to the Hugging Face Hub ---
# NOTE(review): this import sits outside the notebook's top import cell;
# consider moving it there so all dependencies are visible up front.
from huggingface_hub import login

# login() is called without arguments - presumably it prompts for a token
# interactively; confirm before running this notebook unattended.
login()
# Upload through the trainer. The recorded output of this cell shows
# model.safetensors and training_args.bin being committed.
dpo_trainer.push_to_hub()

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/498M [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/5.18k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/silanm/nlp-a5/commit/9c2eccfc0a7c23c6f8e54e323523296a810202a6', commit_message='End of training', commit_description='', oid='9c2eccfc0a7c23c6f8e54e323523296a810202a6', pr_url=None, repo_url=RepoUrl('https://huggingface.co/silanm/nlp-a5', endpoint='https://huggingface.co', repo_type='model', repo_id='silanm/nlp-a5'), pr_revision=None, pr_num=None)