# **Creating a GPT-2-Based Chatbot with Human Preferences**

## **Important Libraries**

### **Mount Google Drive**

In [None]:
# Mount Google Drive so checkpoints, plots and datasets persist across Colab sessions.
from google.colab import drive
drive.mount(mountpoint="/content/drive")

Mounted at /content/drive


### **Install**

In [None]:
# Install the `uv` package manager (used by the install cell below) via its official script.
# NOTE(review): this pipes a remote script straight into `sh`; pin/verify it if reproducibility matters.
!curl -LsSf https://astral.sh/uv/install.sh | sh

downloading uv 0.7.3 x86_64-unknown-linux-gnu
no checksums to verify
installing to /usr/local/bin
  uv
  uvx
everything's installed!


In [None]:
# https://github.com/astral-sh/uv/issues/12724
!mkdir /backend-container
!mkdir /backend-container/containers
!touch /backend-container/containers/build.constraints
!touch /backend-container/containers/requirements.constraints

In [None]:
try:
    !uv pip install -q --no-cache-dir --system lightning pytelegrambotapi
    import lightning
    !uv pip install -q --no-cache-dir --system datasets transformers trl bitsandbytes
    import datasets
    !uv pip install -q --no-cache-dir --system langchain langgraph
    import langchain
    del lightning, transformers, datasets, langchain
except Exception as e:
    print(e)
    !pip install -q --no-cache-dir lightning pytelegrambotapi
    !pip install -q --no-cache-dir transformers datasets trl bitsandbytes
    !pip install -q --no-cache-dir langchain langgraph

name 'transformers' is not defined


### **Import**

In [None]:
import lightning as L
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torch.nn.functional as F

from collections.abc import Sequence
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Callable

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.prompts import PromptTemplate

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph

from google.colab.patches import cv2_imshow
from google.colab import userdata

from tqdm.auto import tqdm

from transformers import (
    PretrainedConfig,
    DataCollatorForLanguageModeling,
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    GPT2LMHeadModel,
    PreTrainedTokenizer,
    get_cosine_with_hard_restarts_schedule_with_warmup,
    BaseImageProcessor,
    FeatureExtractionMixin,
    PreTrainedTokenizerBase,
    ProcessorMixin,
    PreTrainedModel,
    DataCollator,
    is_comet_available,
    is_wandb_available,
    is_torch_xla_available,
)

import transformers
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available, is_torch_fx_proxy

import inspect
import glob

import importlib.resources as pkg_resources

from collections import defaultdict
from dataclasses import dataclass
from contextlib import nullcontext

from torch.nn.utils.rnn import pad_sequence
import torch.amp as amp

from packaging import version
import textwrap

from huggingface_hub import ModelCard, ModelCardData

if is_peft_available():
    from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training

if is_wandb_available():
    import wandb

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

if is_comet_available():
    import comet_ml

from trl import (
    ORPOTrainer,
    ORPOConfig,
)
from accelerate import PartialState

import datasets
from datasets import Dataset, load_dataset

from rich.console import Console
from rich.panel import Panel
from rich.prompt import Prompt

import cv2
import telebot
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
import copy
import math
import time
import shutil
import random
import warnings

warnings.filterwarnings("ignore")

%matplotlib inline
plt.rcParams['axes.facecolor'] = 'lightgray'
plt.rcParams['mathtext.fontset'] = 'cm'
plt.rcParams['font.family'] = 'STIXGeneral'

## **Configuration**

In [None]:
BATCH_SIZE = 2
ACCUMULATE_GRAD_BATCH = 4  # Increase effective batch size
MAX_SEQ_LENGTH = 2048
VOCAB_SIZE = 50264  # 50257 GPT-2 tokens + 7 chat special tokens added by get_chat_tokenizer
LEARNING_RATE = 6.9e-5
ADAM_EPSILON = 1e-8
WEIGHT_DECAY = 1e-2
WARMUP_STEP = 420
NUM_EPOCH = 10
# Lightning precision flag: bf16 AMP on GPUs with native bf16 support (compute
# capability >= 8, e.g. Ampere and newer), fp16 AMP on other GPUs. The capability
# check avoids selecting bf16 where it is only emulated. The Trainer cell falls
# back to 32-bit precision when no GPU is available.
PRECISION = (
    "bf16-mixed"
    if torch.cuda.is_available()
    and torch.cuda.is_bf16_supported()
    and torch.cuda.get_device_capability()[0] >= 8
    else "16-mixed"
)
NUM_CYCLE = 4
NUM_GPU = torch.cuda.device_count()  # Use all available GPUs
NUM_WORKER = 0
SAVE_INTERVAL = 1024
LOG_INTERVAL = 512

In [None]:
# EarlyStopping: stop after PATIENCE checks in which the monitored metric did not
# improve by more than DELTA.
EARLY_STOPPING_PATIENCE, EARLY_STOPPING_DELTA = 2, 1e-3

In [None]:
BASE_MODEL = "gpt2"

# Supervised fine-tuning corpora, paired positionally with the split to load.
DATASET_NAME = list(("OpenAssistant/oasst1", "tatsu-lab/alpaca"))
DATASET_SUBSET = ["train"] * 2
MODEL_NAME = BASE_MODEL + "chat-finetuned-lightning"

In [None]:
# Chat-format special tokens; get_chat_tokenizer registers them with the tokenizer.
START_HEADER_TOKEN, END_HEADER_TOKEN = "<|start_header_id|>", "<|end_header_id|>"
END_OF_TURN_TOKEN = "<|eot_id|>"
PADDING_TOKEN = "<|finetune_right_pad_id|>"
BOS_TOKEN = "<|beginoftext|>"
START_CONTEXT, END_CONTEXT = "<|start_context|>", "<|end_context|>"

In [None]:
# Held-out OpenAssistant split reserved for evaluation.
EVAL_DATASET, EVAL_SUBSET = "OpenAssistant/oasst1", "validation"

In [None]:
# Preference-alignment stage: dataset, schedule and logging/saving cadence.
PREFERENCE_DATASET, PREFERENCE_DATASET_SUBSET = "trl-lib/ultrafeedback_binarized", "train"
PREFERENCE_EPOCH = 2
PREFERENCE_LEARNING_RATE = 9e-6
PREFERENCE_WARMUP_STEP, PREFERENCE_LOGGING_STEP, PREFERENCE_SAVE_STEP = 1200, 600, 2400

In [None]:
# All artefacts live on the mounted Drive so they survive runtime resets.
EXPERIMENT_DIR = "/content/drive/MyDrive/GPT2Chatbot"
TRAINING_DIR, DATASET_DIR, MODEL_DIR = (
    os.path.join(EXPERIMENT_DIR, sub_dir) for sub_dir in ("training", "dataset", "model")
)

for _experiment_path in (EXPERIMENT_DIR, TRAINING_DIR, DATASET_DIR, MODEL_DIR):
    os.makedirs(_experiment_path, exist_ok=True)

In [None]:
# Metric shared by ModelCheckpoint and EarlyStopping; lower is better.
METRIC_TO_MONITOR, METRIC_MODE = "train_loss", "min"

In [None]:
# A fresh seed per run; it is printed so a run can be reproduced by pinning SEED.
SEED = int(np.random.randint(2**31 - 1))
print(f"Random seed: {SEED}")

Random seed: 1651820508


## **Dataset**

### **Utils**

In [None]:
def build_conversation_threads(dataset):
    """Flatten OpenAssistant message trees into root-to-message conversation threads.

    Only English messages (``lang == 'en'``) are kept. Within each tree
    (``message_tree_id``) messages are visited in ``created_date`` order, and
    every message yields one thread: the path from the tree's root prompt down
    to that message. A message whose parent is not present (e.g. filtered out
    or created later) is dropped.

    Args:
        dataset: Anything ``pd.DataFrame`` accepts whose records carry
            ``message_id``, ``parent_id``, ``message_tree_id``, ``created_date``,
            ``lang``, ``role`` and ``text`` fields (e.g. an ``OpenAssistant/oasst1``
            split).

    Returns:
        A list of conversations; each is a list of ``{'role', 'content'}`` dicts
        ordered from the root prompt.
    """
    df = pd.DataFrame(dataset)

    # English only!
    df = df[df['lang'] == 'en']

    # tree id -> list of threads (each thread is a list of message records).
    threads = dict()

    # groupby(sort=False) visits trees in order of first appearance (like unique())
    # in a single pass instead of re-filtering the whole frame for every tree.
    grouped = df.groupby('message_tree_id', sort=False)
    for tree_id, tree_messages in tqdm(
        grouped, total=grouped.ngroups, desc="Traversing the message tree"
    ):
        tree_threads = list()
        # message_id -> the thread ending at that message, for O(1) parent lookup.
        thread_ending_at = dict()

        for message in tree_messages.sort_values('created_date').to_dict('records'):
            parent_id = message['parent_id']
            if pd.isna(parent_id):  # Root message (prompt); covers None and NaN
                thread = [message]
            elif parent_id in thread_ending_at:
                thread = thread_ending_at[parent_id] + [message]
            else:
                continue  # parent not available -> message cannot be placed
            tree_threads.append(thread)
            # Keep the first thread for a repeated id, as a first-match scan would.
            thread_ending_at.setdefault(message['message_id'], thread)

        threads[tree_id] = tree_threads

    # Convert threads to role/content conversations
    conversations = list()
    for tree_threads in tqdm(threads.values(), desc="Building conversations"):
        for thread in tree_threads:
            conversations.append(
                [{'role': msg['role'], 'content': msg['text']} for msg in thread]
            )

    return conversations

In [None]:
def get_chat_template(
    is_jinja: bool=True,
    **kwargs,
):
    """Return the chat prompt format as a Jinja template or as a rendered string.

    Layout, per message in order: a default system turn first (only when no
    ``system`` message is supplied), then for ``system`` / ``prompter`` /
    ``user`` / ``assistant`` messages a header turn
    ``<|start_header_id|>ROLE<|end_header_id|>`` + newline + content +
    ``<|eot_id|>`` (``prompter`` is rendered as ``user``), the bare
    ``<|start_context|>`` / ``<|end_context|>`` markers for the ``start_context``
    / ``end_context`` pseudo-roles, and the raw content for ``no_role``.

    Args:
        is_jinja: When True, return the Jinja source used as
            ``tokenizer.chat_template``. When False, render ``messages`` directly
            in Python.
        **kwargs: Only used when ``is_jinja`` is False:
            ``messages`` (required) – list of ``{'role', 'content'}`` dicts;
            ``add_generation_prompt`` (default True) – append an empty assistant
            header so generation continues as the assistant.

    Returns:
        The Jinja template source, or the rendered prompt string.
    """
    if is_jinja:
        # A Jinja `set` inside a for-loop is scoped to the loop, so the
        # system-message flag is kept in a namespace() to survive the loop.
        # `user` is accepted as an alias of `prompter` (standard chat datasets
        # use `user`), and explicit system messages are rendered in place.
        return (
            """
            {% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}
            {% set ns = namespace(has_system=false) %}
            {% for message in messages %}
                {% if message['role'] == 'system' %}
                    {% set ns.has_system = true %}
                {% endif %}
            {% endfor %}
            {% if not ns.has_system %}
                {{- '<|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant who converses with the user.<|eot_id|>' }}
            {% endif %}
            {%- for message in messages %}
                {% set role = message['role'] %}
                {% if role == 'system' %}
                    {{- '<|start_header_id|>' + role + '<|end_header_id|>' + '\n' + message['content'] + '<|eot_id|>' }}
                {% endif %}
                {% if role == 'prompter' or role == 'user' %}
                    {% set role = 'user' %}
                    {{- '<|start_header_id|>' + role + '<|end_header_id|>' + '\n' + message['content'] + '<|eot_id|>' }}
                {% endif %}
                {% if role == 'assistant' %}
                    {{- '<|start_header_id|>' + role + '<|end_header_id|>' + '\n' + message['content'] + '<|eot_id|>' }}
                {% endif %}
                {% if role == 'start_context' %}
                    {{- '<|start_context|>'}}
                {% endif %}
                {% if role == 'end_context' %}
                    {{- '<|end_context|>'}}
                {% endif %}
                {% if role == 'no_role' %}
                    {{- message['content'] }}
                {% endif %}
            {%- endfor %}
            {%- if add_generation_prompt %}
                {{- '<|start_header_id|>assistant<|end_header_id|>\n' }}
            {%- endif %}
            """
        )

    messages = kwargs["messages"]
    parts = ["\n"]
    if not any(message["role"] == "system" for message in messages):
        parts.append(
            "<|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant who converses with the user.<|eot_id|>\n"
        )
    for message in messages:
        role = message["role"]
        if role == "start_context":
            parts.append("<|start_context|>\n")
        elif role == "end_context":
            parts.append("<|end_context|>\n")
        elif role == "no_role":
            parts.append(f"{message['content']}\n")
        else:
            if role == "prompter":
                role = "user"
            parts.append(
                "<|start_header_id|>" + role + "<|end_header_id|>\n"
                + message["content"] + "<|eot_id|>\n"
            )
    if kwargs.get("add_generation_prompt", True):
        parts.append("<|start_header_id|>assistant<|end_header_id|>\n")

    return "".join(parts)

In [None]:
def get_chat_tokenizer(
    model_name: str=BASE_MODEL,
    bos_token: str=BOS_TOKEN,
    pad_token: str=PADDING_TOKEN,
    start_header_token: str=START_HEADER_TOKEN,
    end_header_token: str=END_HEADER_TOKEN,
    end_of_turn_token: str=END_OF_TURN_TOKEN,
    start_context: str=START_CONTEXT,
    end_context: str=END_CONTEXT,
    **kwargs,
):
    """Load ``model_name``'s tokenizer and extend it for the chat format.

    The tokenizer gets ``MAX_SEQ_LENGTH`` as its maximum length, the given BOS and
    padding tokens, the five chat marker tokens as additional special tokens, and
    the Jinja chat template from ``get_chat_template``.

    Args:
        model_name: Hugging Face id passed to ``AutoTokenizer.from_pretrained``.
        bos_token, pad_token: Tokens registered as BOS and padding.
        start_header_token, end_header_token, end_of_turn_token, start_context,
            end_context: Chat marker tokens registered as additional special tokens.
        **kwargs: Accepted and ignored.

    Returns:
        The configured tokenizer.
    """
    chat_markers = [
        start_header_token,
        end_header_token,
        end_of_turn_token,
        start_context,
        end_context,
    ]

    chat_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    chat_tokenizer.model_max_length = MAX_SEQ_LENGTH

    # Assign BOS/PAD directly first, then register them (plus the chat markers)
    # as special tokens so they are added to the vocabulary.
    chat_tokenizer.bos_token = bos_token
    chat_tokenizer.pad_token = pad_token
    chat_tokenizer.add_special_tokens(
        {
            "bos_token": bos_token,
            "pad_token": pad_token,
            "additional_special_tokens": chat_markers,
        }
    )

    chat_tokenizer.chat_template = get_chat_template(is_jinja=True)
    return chat_tokenizer

In [None]:
class InstructionDataModule(L.LightningDataModule):
    """Lightning data module that turns instruction/chat datasets into tokenized
    causal-LM training examples rendered with the chat template.

    Supported sources:
      * names containing "oasst": OpenAssistant message trees, flattened with
        ``build_conversation_threads``; only threads ending on an assistant turn
        are used. Multi-turn threads wrap all but the final prompter/assistant
        pair in ``start_context`` / ``end_context`` markers.
      * anything else: Alpaca-style rows with ``instruction`` / ``input`` /
        ``output`` columns; a non-empty ``input`` is wrapped as context.

    Each conversation becomes ``bos_token + chat_template(...) + eos_token`` and is
    tokenized with truncation to ``max_seq_length``. Padding and labels are left to
    ``DataCollatorForLanguageModeling(mlm=False)``.

    Args:
        model_name: Model id whose tokenizer is extended by ``get_chat_tokenizer``.
        dataset_name: A dataset id or a list of ids.
        dataset_subset: Split name(s), paired positionally with ``dataset_name``;
            a single split is reused for every dataset.
        batch_size: Train dataloader batch size.
        max_seq_length: Maximum tokens per example.
        num_workers: DataLoader worker processes.

    Raises:
        ValueError: If the number of dataset names and subsets cannot be paired.
    """

    def __init__(
        self,
        model_name: str = BASE_MODEL,
        dataset_name: Union[str, list] = DATASET_NAME,
        dataset_subset: Union[str, list] = DATASET_SUBSET,
        batch_size: int = BATCH_SIZE,
        max_seq_length: int = MAX_SEQ_LENGTH,
        num_workers: int = NUM_WORKER,
    ):
        super().__init__()
        self.model_name = model_name

        # Normalise to parallel lists: zipping two bare strings would silently
        # pair them character by character.
        names = [dataset_name] if isinstance(dataset_name, str) else list(dataset_name)
        subsets = (
            [dataset_subset] if isinstance(dataset_subset, str) else list(dataset_subset)
        )
        if len(subsets) == 1 and len(names) > 1:
            subsets = subsets * len(names)
        if len(names) != len(subsets):
            raise ValueError(
                f"Got {len(names)} dataset names but {len(subsets)} dataset subsets."
            )
        self.dataset_name = names
        self.dataset_subset = subsets

        self.batch_size = batch_size
        self.max_seq_length = max_seq_length
        self.num_workers = num_workers

        self.tokenizer = get_chat_tokenizer(model_name)

        # Causal-LM collator: mlm=False builds labels from input_ids and masks padding.
        self.data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )

        self.train_dataset = None

    def prepare_data(self):
        """Download every dataset split (Lightning runs this on the main process only)."""
        for dataset_name, dataset_subset in zip(
            self.dataset_name, self.dataset_subset
        ):
            load_dataset(dataset_name, split=dataset_subset)

    def setup(self, stage: Optional[str] = None):
        """Load, format and tokenize all sources into ``self.train_dataset``."""
        if self.train_dataset is not None:
            # Already prepared; the training data does not depend on `stage`.
            return

        conversations = list()
        for dataset_name, dataset_subset in zip(
            self.dataset_name, self.dataset_subset
        ):
            raw_dataset = load_dataset(dataset_name, split=dataset_subset)
            if "oasst" in dataset_name:
                conversations.extend(self._format_oasst(raw_dataset))
            else:
                conversations.extend(self._format_alpaca(raw_dataset))

        formatted_dataset = Dataset.from_dict({"text": conversations})

        # Capture only what tokenization needs (keeps datasets' fingerprinting light).
        tokenizer = self.tokenizer
        max_length = self.max_seq_length

        def tokenize_function(examples):
            tokenized = tokenizer(
                examples["text"],
                truncation=True,
                padding=False,  # Data collator will handle padding
                max_length=max_length,
            )
            # batched=True: input_ids is a list of sequences, so check each one.
            # A sequence at max_length may have been truncated and lost its EOS;
            # put EOS back in the last slot so the model still learns to stop.
            for input_ids in tokenized["input_ids"]:
                if len(input_ids) >= max_length:
                    input_ids[-1] = tokenizer.eos_token_id
            return tokenized

        self.train_dataset = formatted_dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=["text"],  # Keep only tokenized columns
        )
        print(
            "Dataset setup complete. "
            f"Train dataset size: {len(self.train_dataset)}"
        )

    def _render(self, messages):
        """Render one conversation with the chat template, wrapped in BOS/EOS."""
        conversation = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )
        return self.tokenizer.bos_token + conversation + self.tokenizer.eos_token

    def _format_oasst(self, raw_dataset):
        """Render OpenAssistant threads that end on an assistant reply."""
        formatted = list()
        for thread in build_conversation_threads(raw_dataset):
            # build_conversation_threads returns one thread per message (the
            # root-to-message path). An odd-length thread ends on a prompter turn;
            # trimmed to complete pairs it equals its parent's thread, which is
            # already in the list, so skip it rather than add a duplicate.
            if not thread or len(thread) % 2:
                continue
            message = list(thread)
            if len(message) > 2:
                # Earlier turns are context; the final prompter/assistant pair is the target.
                message.insert(0, {"role": "start_context", "content": ""})
                message.insert(-2, {"role": "end_context", "content": ""})
            formatted.append(self._render(message))
        return formatted

    def _format_alpaca(self, raw_dataset):
        """Render Alpaca-style rows; a non-empty ``input`` becomes context."""
        formatted = list()
        for instruction, context, response in zip(
            raw_dataset["instruction"], raw_dataset["input"], raw_dataset["output"]
        ):
            message = [
                {"role": "prompter", "content": instruction},
                {"role": "assistant", "content": response},
            ]
            if context != "":
                message = [
                    {"role": "start_context", "content": ""},
                    {"role": "no_role", "content": context},
                    {"role": "end_context", "content": ""},
                ] + message
            formatted.append(self._render(message))
        return formatted

    def train_dataloader(self):
        """Shuffled training loader that pads batches with the LM data collator."""
        if self.train_dataset is None:
            raise RuntimeError(
                "Train dataset not initialized. Call setup() first."
            )

        return data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=self.data_collator,
            num_workers=self.num_workers,
            shuffle=True,
            # Pinned host memory only helps (and only works) with a CUDA device.
            pin_memory=torch.cuda.is_available(),
        )

### **Load**

In [None]:
# Built once at module level: GPT2ChatFineTuner reads DATA_MODULE.tokenizer.
DATA_MODULE = InstructionDataModule(model_name=BASE_MODEL)

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

## **Model**

### **Utils**

In [None]:
class AvgMeter(object):
    """Collects scalar values (e.g. per-batch losses) and reports their mean.

    Values are stored as tensors. ``show`` returns a 0-d tensor on the values'
    device, or NaN when nothing has been recorded since the last ``reset``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget every recorded value."""
        self.scores = list()

    def update(self, val):
        """Record one value: a tensor, or a Python number (stored as a float tensor)."""
        if not torch.is_tensor(val):
            val = torch.tensor(float(val))
        self.scores.append(val)

    def show(self):
        """Return the mean of all recorded values as a 0-d tensor."""
        if not self.scores:
            # torch.stack([]) raises; report "no data yet" as NaN instead.
            return torch.tensor(float("nan"))
        return torch.stack(self.scores).mean()

### **Module**

In [None]:
class GPT2ChatConfig(PretrainedConfig):
    """Hugging Face config for the chat-tuned GPT-2 model.

    Carries the GPT-2 config fields under its own ``model_type`` ("gpt2chat"), so
    it can be registered with ``AutoConfig`` / ``AutoModelForCausalLM`` and the
    fine-tuned checkpoint reloaded via ``from_pretrained``. The architecture
    defaults are GPT-2 small sizes (768-dim, 12 layers, 12 heads); the vocabulary
    and context length default to this notebook's ``VOCAB_SIZE`` and
    ``MAX_SEQ_LENGTH``.
    """

    model_type = "gpt2chat"
    # Model-output keys the HF Trainer ignores when gathering predictions.
    keys_to_ignore_at_inference = ["past_key_values"]
    # Generic HF attribute names -> GPT-2 specific field names.
    attribute_map = {
        "hidden_size": "n_embd",
        "max_position_embeddings": "n_positions",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        vocab_size=VOCAB_SIZE,
        n_positions=MAX_SEQ_LENGTH,
        n_embd=768,
        n_layer=12,
        n_head=12,
        n_inner=None,
        activation_function="gelu_new",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=0.1,
        scale_attn_weights=True,
        use_cache=True,
        # NOTE(review): presumably the ids of the added BOS token and GPT-2's
        # <|endoftext|> in the extended tokenizer; GPT2ChatFineTuner overwrites
        # both from the tokenizer anyway.
        bos_token_id=50257,
        eos_token_id=50256,
        scale_attn_by_inverse_layer_idx=False,
        reorder_and_upcast_attn=False,
        **kwargs,
    ):
        # Architecture / regularisation fields, stored under the GPT-2 names.
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        # Sequence-summary head settings.
        self.summary_type = summary_type
        self.summary_use_proj = summary_use_proj
        self.summary_activation = summary_activation
        self.summary_first_dropout = summary_first_dropout
        self.summary_proj_to_labels = summary_proj_to_labels
        # Attention behaviour flags.
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn

        # Set here and also forwarded to PretrainedConfig.__init__ below.
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

### **Wrapper**

In [None]:
class GPT2ChatFineTuner(L.LightningModule):
    """Lightning module that fine-tunes a pretrained GPT-2 checkpoint for chat.

    Construction adapts the pretrained model to the chat setup:

    * token embeddings are resized to the chat tokenizer of the global
      ``DATA_MODULE``;
    * the model config is replaced by a ``GPT2ChatConfig`` that copies the
      pretrained architecture sizes and records the tokenizer's bos/eos ids,
      ``MAX_SEQ_LENGTH`` positions and the resized vocabulary;
    * the learned position table ``transformer.wpe`` grows to ``MAX_SEQ_LENGTH``
      rows (pretrained rows kept, extra rows sampled from their mean/std), and
      each block's ``attn.bias`` causal mask is rebuilt at that size.

    Training logs ``train_loss`` per step and per epoch, records the mean loss of
    every ``log_interval`` batches, and saves that curve as a PNG in
    ``TRAINING_DIR`` at the end of each epoch.

    Args:
        model_name: Hugging Face id of the pretrained GPT-2 model.
        learning_rate: Peak AdamW learning rate (reached after warm-up).
        adam_epsilon: AdamW epsilon.
        weight_decay: Weight decay for matrix-shaped weights; biases and
            LayerNorm parameters are not decayed.
        warmup_steps: Warm-up steps of the cosine-with-hard-restarts schedule.
        num_cycles: Number of hard restarts of the cosine schedule.
        log_interval: Number of batches averaged into one point of the loss curve.
    """

    def __init__(
        self,
        model_name: str = BASE_MODEL,
        learning_rate: float = LEARNING_RATE,
        adam_epsilon: float = ADAM_EPSILON,
        weight_decay: float = WEIGHT_DECAY,
        warmup_steps: int = WARMUP_STEP,
        num_cycles: int = NUM_CYCLE,
        log_interval: int = LOG_INTERVAL,
    ):
        super().__init__()

        self.model_name = model_name
        self.learning_rate = learning_rate
        self.adam_epsilon = adam_epsilon
        self.warmup_steps = warmup_steps
        self.weight_decay = weight_decay
        self.num_cycles = num_cycles
        self.log_interval = log_interval

        self.model = self._build_model(self.model_name, DATA_MODULE.tokenizer)
        self.model.train()

        self.train_loss = list()
        self.train_loss_recorder = AvgMeter()

        self.steps = list()

    @staticmethod
    def _build_model(model_name, tokenizer):
        """Load ``model_name`` and adapt its vocabulary, config and context length."""
        model = AutoModelForCausalLM.from_pretrained(model_name)
        model.resize_token_embeddings(len(tokenizer))

        pretrained = model.config
        # Copy the architecture sizes so the saved config matches the weights for
        # any GPT-2 size (GPT2ChatConfig's own defaults are gpt2-small's).
        config = GPT2ChatConfig(
            n_embd=pretrained.n_embd,
            n_layer=pretrained.n_layer,
            n_head=pretrained.n_head,
            n_inner=pretrained.n_inner,
            activation_function=pretrained.activation_function,
            layer_norm_epsilon=pretrained.layer_norm_epsilon,
        )
        config.bos_token_id = tokenizer.bos_token_id
        config.eos_token_id = tokenizer.eos_token_id
        config.n_positions = MAX_SEQ_LENGTH
        # Match the actual (resized) embedding matrix.
        config.vocab_size = len(tokenizer)
        model.config = config

        GPT2ChatFineTuner._extend_position_embeddings(model, MAX_SEQ_LENGTH)

        # The causal-mask buffer must cover the enlarged context; one copy per block.
        causal_mask = torch.tril(
            torch.ones((MAX_SEQ_LENGTH, MAX_SEQ_LENGTH), dtype=torch.bool)
        ).view(1, 1, MAX_SEQ_LENGTH, MAX_SEQ_LENGTH)
        for block in model.transformer.h:
            block.attn.bias = causal_mask.clone()

        return model

    @staticmethod
    def _extend_position_embeddings(model, n_positions):
        """Replace ``transformer.wpe`` with an ``n_positions``-row table that keeps
        the pretrained rows and samples the extra ones from their statistics."""
        old_weight = model.transformer.wpe.weight.data
        new_wpe = nn.Embedding(n_positions, old_weight.size(1))
        new_wpe.weight.data.normal_(
            mean=old_weight.mean().item(),
            std=old_weight.std().item(),
        )
        # Keep the positions the pretrained model has already learned; only the
        # rows beyond the original table start from the random draw.
        n_copy = min(n_positions, old_weight.size(0))
        new_wpe.weight.data[:n_copy] = old_weight[:n_copy]
        model.transformer.wpe = new_wpe

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
    ):
        """Run the wrapped causal LM; the output carries ``loss`` when labels are given."""
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,  # Pass labels to compute the loss
        )

    def training_step(self, batch, batch_idx):
        """Compute and log the LM loss for one collated batch.

        ``batch`` comes from ``DataCollatorForLanguageModeling`` and holds
        ``input_ids``, ``attention_mask`` and ``labels``.
        """
        outputs = self(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )

        loss = outputs.loss

        self.log(
            "train_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.train_loss_recorder.update(loss.detach())

        if (batch_idx + 1) % self.log_interval == 0:
            self.train_loss.append(
                self.train_loss_recorder.show().cpu().numpy()
            )
            self.train_loss_recorder = AvgMeter()
            # x position = batches seen across epochs; num_training_batches is the
            # real per-epoch count (it includes the final partial batch).
            self.steps.append(
                self.current_epoch * self.trainer.num_training_batches
                + batch_idx
            )

        return loss

    def on_train_epoch_end(self):
        # With several GPUs only one process writes the plot.
        if self.trainer.is_global_zero:
            self._save_plot_curves()

    def _save_plot_curves(self):
        """Save the recorded loss curve to ``TRAINING_DIR``."""
        img_file = os.path.join(
            TRAINING_DIR,
            f"{MODEL_NAME}_loss_plot.png",
        )
        # A dedicated figure that is closed afterwards, so nothing leaks into the
        # notebook's pyplot state.
        fig, ax = plt.subplots()
        ax.plot(self.steps, self.train_loss, color="b", label="loss")
        ax.set_title("Loss Curves")
        ax.set_xlabel("Step")
        ax.set_ylabel("Loss")
        ax.legend()
        ax.grid()
        fig.savefig(img_file)
        plt.close(fig)

    def configure_optimizers(self):
        """AdamW with decoupled weight decay plus a per-step cosine schedule with hard restarts."""
        # GPT-2 names its LayerNorms ln_1 / ln_2 / ln_f, so the BERT-style
        # "LayerNorm.weight" pattern alone matches nothing. Additionally treat every
        # 1-D parameter (biases, LayerNorm gains) as no-decay.
        # NOTE: this changes the parameter-group split compared with checkpoints
        # saved by the previous grouping, so their optimizer state cannot be resumed.
        no_decay = ["bias", "LayerNorm.weight", "ln_1.weight", "ln_2.weight", "ln_f.weight"]
        decay_params, no_decay_params = list(), list()
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if param.ndim < 2 or any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optimizer = optim.AdamW(
            [
                {"params": decay_params, "weight_decay": self.weight_decay},
                {"params": no_decay_params, "weight_decay": 0.0},
            ],
            lr=self.learning_rate,
            eps=self.adam_epsilon,
        )

        # Total optimizer steps (accounts for epochs, devices and accumulation);
        # only available once the trainer is attached.
        num_training_steps = self.trainer.estimated_stepping_batches

        scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=num_training_steps,
            num_cycles=self.num_cycles,
        )

        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

## **Training**

In [None]:
try:
    seed_everything(SEED, workers=True)

    MODEL_MODULE = GPT2ChatFineTuner()

    callback = list()

    # Tracks the best checkpoint by METRIC_TO_MONITOR under the file name MODEL_NAME
    # and a rolling "last.ckpt"; checked every SAVE_INTERVAL training steps.
    checkpoint_callback = ModelCheckpoint(
        dirpath=MODEL_DIR,
        filename=MODEL_NAME,
        save_top_k=1,
        monitor=METRIC_TO_MONITOR,
        mode=METRIC_MODE,
        save_last=True,
        every_n_train_steps=SAVE_INTERVAL,
    )
    callback.append(checkpoint_callback)

    early_stop_callback = EarlyStopping(
        monitor=METRIC_TO_MONITOR,
        min_delta=EARLY_STOPPING_DELTA,
        patience=EARLY_STOPPING_PATIENCE,
        verbose=False,
        mode=METRIC_MODE,
    )
    callback.append(early_stop_callback)

    # Resume in place from the rolling "last" checkpoint of a previous session.
    # Renaming it onto f"{MODEL_NAME}.ckpt" would overwrite the best checkpoint
    # that ModelCheckpoint keeps under that name.
    last_ckpt_path = os.path.join(MODEL_DIR, "last.ckpt")
    CKPT_PATH = last_ckpt_path if os.path.exists(last_ckpt_path) else None

    trainer = Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=NUM_GPU if torch.cuda.is_available() else 1,
        max_epochs=NUM_EPOCH,
        accumulate_grad_batches=ACCUMULATE_GRAD_BATCH,
        precision=PRECISION if torch.cuda.is_available() else 32,
        callbacks=callback,
        log_every_n_steps=20,
        logger=False,
    )

    print("Starting training...")
    trainer.fit(
        model=MODEL_MODULE,
        datamodule=DATA_MODULE,
        ckpt_path=CKPT_PATH,
    )
    print("Training finished.")

    final_model_path = os.path.join(MODEL_DIR, MODEL_NAME)

    print(f"Saving final model to: {final_model_path}")

    # Use the file ModelCheckpoint actually wrote (it appends "-vN" when the name
    # is already taken); fall back to the fixed name otherwise.
    best_ckpt_path = (
        checkpoint_callback.best_model_path
        or os.path.join(MODEL_DIR, f"{MODEL_NAME}.ckpt")
    )
    MODEL_MODULE = GPT2ChatFineTuner.load_from_checkpoint(best_ckpt_path)

    MODEL_MODULE.model.save_pretrained(final_model_path)
    DATA_MODULE.tokenizer.save_pretrained(final_model_path)

    print("Model saving complete.")

    print("Registering GPT2Chat...")

    GPT2LMHeadModel.config_class = GPT2ChatConfig
    AutoConfig.register(f"{BASE_MODEL}chat", GPT2ChatConfig)
    AutoModelForCausalLM.register(GPT2ChatConfig, GPT2LMHeadModel)

    print("GPT2Chat registered.")

except Exception as e:
    # Best effort: keep the notebook running (later cells can use a previously
    # saved model), but show the full traceback rather than only the message.
    import traceback

    traceback.print_exc()
    print(f"Training has been stopped. Reason: {e}")


torch.cuda.empty_cache()

INFO: Seed set to 1651820508
INFO:lightning.fabric.utilities.seed:Seed set to 1651820508
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


Starting training...


README.md:   0%|          | 0.00/10.2k [00:00<?, ?B/s]

(…)-00000-of-00001-b42a775f407cee45.parquet:   0%|          | 0.00/39.5M [00:00<?, ?B/s]

(…)-00000-of-00001-134b8fd0c89408b6.parquet:   0%|          | 0.00/2.08M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/84437 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/4401 [00:00<?, ? examples/s]

README.md:   0%|          | 0.00/7.47k [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


(…)-00000-of-00001-a09b74b3ef9c3b56.parquet:   0%|          | 0.00/24.2M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/52002 [00:00<?, ? examples/s]

Traversing the message tree:   0%|          | 0/3574 [00:00<?, ?it/s]

Building conversations:   0%|          | 0/3574 [00:00<?, ?it/s]

Map:   0%|          | 0/87694 [00:00<?, ? examples/s]

INFO: Restoring states from the checkpoint path at /content/drive/MyDrive/GPT2Chatbot/model/gpt2chat-finetuned-lightning.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Restoring states from the checkpoint path at /content/drive/MyDrive/GPT2Chatbot/model/gpt2chat-finetuned-lightning.ckpt


Dataset setup complete.Train dataset size: 87694


INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: Loading `train_dataloader` to estimate number of stepping batches.
INFO:lightning.pytorch.utilities.rank_zero:Loading `train_dataloader` to estimate number of stepping batches.
INFO: 
  | Name  | Type            | Params | Mode 
--------------------------------------------------
0 | model | GPT2LMHeadModel | 125 M  | train
--------------------------------------------------
125 M     Trainable params
0         Non-trainable params
125 M     Total params
500.926   Total estimated model params size (MB)
164       Modules in train mode
0         Modules in eval mode
INFO:lightning.pytorch.callbacks.model_summary:
  | Name  | Type            | Params | Mode 
--------------------------------------------------
0 | model | GPT2LMHeadModel | 125 M  | train
--------------------------------------------------
125 M     Trainable params
0         Non-trainable par

Training: |          | 0/? [00:00<?, ?it/s]

`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
INFO: `Trainer.fit` stopped: `max_epochs=10` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.


Training finished.
Saving final model to: /content/drive/MyDrive/GPT2Chatbot/model/gpt2chat-finetuned-lightning
Model saving complete.
Registering GPT2Chat...
GPT2Chat registered.


<Figure size 640x480 with 0 Axes>

## **Preference Alignment**

### **Initialize**

In [None]:
# Location of the saved GPT2Chat model that the preference-alignment cells below load.
MODEL_PATH = os.path.join(
    MODEL_DIR,
    MODEL_NAME,
)

In [None]:
# Load the configured split of the preference dataset used for alignment.
preference_dataset = load_dataset(PREFERENCE_DATASET, split=PREFERENCE_DATASET_SUBSET)

README.md:   0%|          | 0.00/643 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/131M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/2.14M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/62135 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [None]:
# Prefer CUDA, then Apple MPS, then CPU for preference alignment.
preference_device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
# Half precision on CUDA, float32 elsewhere; kept for evaluation / inference use.
eval_precision = (
    torch.float16
    if preference_device == "cuda"
    else torch.float32
)
# The model loaded below is the one ORPO updates. The training config does not enable
# mixed precision, so fp16 weights would be trained directly: with a small learning rate
# most per-step updates fall below fp16 resolution and are rounded away. Keep the
# trainable copy in float32.
train_precision = torch.float32


def _load_gpt2chat(model_path, dtype, device):
    """Load the GPT2Chat model (cast to ``dtype``, moved to ``device``) and its tokenizer."""
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=dtype,
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    return model, tokenizer


try:
    gpt2chat_model, gpt2chat_tokenizer = _load_gpt2chat(
        MODEL_PATH, train_precision, preference_device
    )
except ValueError as e:
    # Transformers raises ValueError for an unrecognised `model_type`; the custom
    # `gpt2chat` type has to be registered with the Auto* factories in this kernel
    # session first. Other failures (missing files, OOM, ...) propagate unchanged
    # instead of triggering a registration + retry.
    print(f"Could not load the model due to {e}.\n")
    print("Registering GPT2Chat...")
    GPT2LMHeadModel.config_class = GPT2ChatConfig
    AutoConfig.register(f"{BASE_MODEL}chat", GPT2ChatConfig)
    AutoModelForCausalLM.register(GPT2ChatConfig, GPT2LMHeadModel)
    print("GPT2Chat registered.")

    gpt2chat_model, gpt2chat_tokenizer = _load_gpt2chat(
        MODEL_PATH, train_precision, preference_device
    )

# The preference data collator raises when the tokenizer has no pad token
# (GPT-2 tokenizers have none by default), so fall back to EOS.
if gpt2chat_tokenizer.pad_token is None:
    gpt2chat_tokenizer.pad_token = gpt2chat_tokenizer.eos_token

# Disable the generation KV cache for training.
gpt2chat_model.config.use_cache = False

Could not load the model due to The checkpoint you are trying to load has model type `gpt2chat` but Transformers does not recognize this architecture. This could be because of an issue with the checkpoint, or because your version of Transformers is out of date.

You can update Transformers with the command `pip install --upgrade transformers`. If this does not work, and the checkpoint is very new, then there may not be a release version that supports this model yet. In this case, you can get the most up-to-date code by installing Transformers from source with the command `pip install git+https://github.com/huggingface/transformers.git`.

Registering GPT2Chat...
GPT2Chat registered.


In [None]:
def _latest_checkpoint(model_dir):
    """Return the ``checkpoint-<step>`` directory in ``model_dir`` with the largest step, or None."""
    candidates = []
    for path in glob.glob(os.path.join(model_dir, "checkpoint-*")):
        step = os.path.basename(path).split("-")[-1]
        # Ignore stray files and non-numeric suffixes (e.g. "checkpoint-best"),
        # which would otherwise crash an int()-based sort.
        if os.path.isdir(path) and step.isdecimal():
            candidates.append((int(step), path))
    return max(candidates)[1] if candidates else None


preference_resume_path = _latest_checkpoint(MODEL_DIR)
if preference_resume_path is not None:
    print(f"Resuming from {preference_resume_path}")
else:
    print("No checkpoints found. Starting from scratch.")

training_args = ORPOConfig(
    # Small learning rate to prevent catastrophic forgetting
    learning_rate=PREFERENCE_LEARNING_RATE,
    # Linear learning rate decay over training
    lr_scheduler_type="linear",
    # Maximum combined length of prompt + completion
    max_length=MAX_SEQ_LENGTH,
    # Maximum length for input prompts (a quarter of the total budget)
    max_prompt_length=MAX_SEQ_LENGTH // 4,
    # Controls weight of the odds ratio loss (λ in paper)
    beta=0.1,
    # Batch size for training
    per_device_train_batch_size=BATCH_SIZE,
    # Helps with training stability by accumulating gradients before updating
    gradient_accumulation_steps=ACCUMULATE_GRAD_BATCH,
    # Memory-efficient optimizer for CUDA, falls back to adamw_torch for CPU/MPS
    optim="paged_adamw_8bit" if preference_device == "cuda" else "adamw_torch",
    # Total number of training epochs
    num_train_epochs=PREFERENCE_EPOCH,
    # Log metrics every PREFERENCE_LOGGING_STEP optimizer steps
    logging_steps=PREFERENCE_LOGGING_STEP,
    # Gradual learning rate warmup
    warmup_steps=PREFERENCE_WARMUP_STEP,
    # Disable external logging
    report_to="none",
    # Where to save model/checkpoints (also where _latest_checkpoint looks)
    output_dir=MODEL_DIR,
    # Enable MPS (Metal Performance Shaders) if available
    use_mps_device=preference_device == "mps",
    restore_callback_states_from_checkpoint=True,
    # NOTE: TrainingArguments documents this field as not used by Trainer itself;
    # it only resumes if also passed as `trainer.train(resume_from_checkpoint=...)`.
    resume_from_checkpoint=preference_resume_path,
    # Save a checkpoint every PREFERENCE_SAVE_STEP steps
    save_steps=PREFERENCE_SAVE_STEP,
)

Resuming from /content/drive/MyDrive/GPT2Chatbot/model/checkpoint-2400


### **Utils**

In [None]:
def is_conversational(example: dict[str, Any]) -> bool:
    """Return True if ``example`` is in conversational (chat-messages) format.

    The first supported key present, checked in the fixed order
    ``prompt, chosen, rejected, completion, messages``, decides the result: its value
    must be a non-empty list whose first element is a dict with "role" and "content".

    Args:
        example: A single dataset row.

    Returns:
        True for conversational examples, False otherwise (including when no supported
        key is present or the selected value is an empty list).
    """
    supported_keys = ["prompt", "chosen", "rejected", "completion", "messages"]
    # Pick the first supported key deterministically: `set.pop()` order depends on string
    # hashing and can differ between runs and dataset worker processes, which made the
    # answer nondeterministic when keys disagree in format.
    key = next((k for k in supported_keys if k in example), None)
    if key is None:
        return False

    maybe_messages = example[key]
    # It must be a non-empty list of messages (indexing an empty list would raise)
    if not isinstance(maybe_messages, list) or len(maybe_messages) == 0:
        return False

    # Each message must be a dictionary with keys "role" and "content"
    maybe_message = maybe_messages[0]
    return (
        isinstance(maybe_message, dict)
        and "role" in maybe_message
        and "content" in maybe_message
    )

In [None]:
def extract_prompt(example: dict[str, Sequence]) -> dict[str, Sequence]:
    """Split ``chosen``/``rejected`` into their shared prompt prefix and diverging completions.

    Works on any sequence (strings or lists of messages). ``idx`` is the first position
    where ``chosen`` and ``rejected`` differ. If the element just before it is a single
    space, that space is moved out of the prompt and into the completions. If the two
    never differ within the shorter length, ``idx`` stays at the last index of that range,
    so the final shared element remains in the completions. Empty inputs yield an empty prompt.

    Args:
        example: Mapping with ``chosen`` and ``rejected`` sequences.

    Returns:
        Dict with ``prompt`` (``chosen[:idx]``), ``chosen`` and ``rejected`` (each ``[idx:]``).
    """
    # Initialised so empty inputs produce an empty prompt instead of UnboundLocalError.
    idx = 0
    for idx in range(min(len(example["chosen"]), len(example["rejected"]))):
        if example["chosen"][idx] != example["rejected"][idx]:
            # Move a trailing separator space from the prompt to the completions.
            # `idx > 0` guard: at idx 0, `idx - 1` would wrap around to the last element.
            if idx > 0 and example["chosen"][idx - 1] == " ":
                idx -= 1
            break

    return {
        "prompt": example["chosen"][:idx],
        "chosen": example["chosen"][idx:],
        "rejected": example["rejected"][idx:],
    }


def maybe_extract_prompt(example: dict[str, list]) -> dict[str, list]:
    """Split an implicit prompt out of ``chosen``/``rejected`` when needed.

    Examples lacking either ``chosen`` or ``rejected`` are returned unchanged. So are
    examples whose explicit ``prompt`` has the same format (conversational or plain)
    as ``chosen``. Otherwise the shared prefix of ``chosen`` and ``rejected`` becomes
    the prompt via ``extract_prompt``, and only its three fields are returned.
    """
    if not ("chosen" in example and "rejected" in example):
        return example
    if "prompt" in example:
        chosen_is_chat = is_conversational({"chosen": example["chosen"]})
        prompt_is_chat = is_conversational({"prompt": example["prompt"]})
        # Formats agree -> the explicit prompt is already usable as-is.
        if chosen_is_chat == prompt_is_chat:
            return example
    pair = {"chosen": example["chosen"], "rejected": example["rejected"]}
    return extract_prompt(pair)

In [None]:
def apply_chat_template(
    example: dict[str, list[dict[str, str]]],
    tokenizer: PreTrainedTokenizerBase,
    tools: Optional[list[Union[dict, Callable]]] = None,
) -> dict[str, str]:
    """Render a conversational example into plain-text fields using the tokenizer's chat template.

    Args:
        example: One dataset row whose supported keys form one of the combinations listed
            below; message-valued fields are lists of ``{"role", "content"}`` dicts.
        tokenizer: Tokenizer providing ``apply_chat_template``.
        tools: Optional tool definitions forwarded to the chat template.

    Returns:
        Dict with ``text`` (from ``messages``) and/or rendered ``prompt``, ``chosen``,
        ``rejected``, ``completion`` strings; ``label`` is copied through unchanged.
        With an explicit prompt, each completion is the suffix of the rendered
        prompt+completion that follows the rendered prompt.

    Raises:
        KeyError: If the example's supported keys are not a recognised combination.
        ValueError: If the prompt's last role is neither "user" nor "assistant", or the
            rendered prompt is not a prefix of a rendered prompt+completion.
    """
    # Check that the example has the correct keys
    supported_keys = ["prompt", "chosen", "rejected", "completion", "messages", "label"]
    example_keys = {key for key in example.keys() if key in supported_keys}
    if example_keys not in [
        {"messages"},  # language modeling
        {"prompt"},  # prompt-only
        {"prompt", "completion"},  # prompt-completion
        {"prompt", "chosen", "rejected"},  # preference
        {"chosen", "rejected"},  # preference with implicit prompt
        {"prompt", "completion", "label"},  # unpaired preference
    ]:
        raise KeyError(f"Invalid keys in the example: {example_keys}")

    # Apply the chat template to the whole conversation
    if "messages" in example:
        messages = tokenizer.apply_chat_template(
            example["messages"], tools=tools, tokenize=False
        )

    # Apply the chat template to the prompt, adding the generation prompt
    if "prompt" in example:
        last_role = example["prompt"][-1]["role"]
        if last_role == "user":
            add_generation_prompt = True
            continue_final_message = False
        elif last_role == "assistant":
            add_generation_prompt = False
            continue_final_message = True
        else:
            raise ValueError(f"Invalid role in the last message: {last_role}")
        prompt = tokenizer.apply_chat_template(
            example["prompt"],
            tools=tools,
            continue_final_message=continue_final_message,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
        )

        # Remove a run of 12 trailing spaces (presumably template indentation).
        # This must stay inside the prompt branch: `prompt` is undefined for
        # "messages"-only and implicit-prompt examples, where reading it raised
        # UnboundLocalError.
        if prompt.endswith(" " * 12):
            prompt = prompt[:-12]

    # Apply the chat template to the entire prompt + completion
    if "prompt" in example:  # explicit prompt and prompt-completion case
        if "chosen" in example:
            prompt_chosen = tokenizer.apply_chat_template(
                example["prompt"] + example["chosen"], tools=tools, tokenize=False
            )
            chosen = prompt_chosen[len(prompt) :]
        if "rejected" in example:
            prompt_rejected = tokenizer.apply_chat_template(
                example["prompt"] + example["rejected"], tools=tools, tokenize=False
            )
            rejected = prompt_rejected[len(prompt) :]
        if "completion" in example:
            prompt_completion = tokenizer.apply_chat_template(
                example["prompt"] + example["completion"], tools=tools, tokenize=False
            )
            completion = prompt_completion[len(prompt) :]
    else:  # implicit prompt case
        if "chosen" in example:
            chosen = tokenizer.apply_chat_template(
                example["chosen"], tools=tools, tokenize=False
            )
        if "rejected" in example:
            rejected = tokenizer.apply_chat_template(
                example["rejected"], tools=tools, tokenize=False
            )

    # Ensure that the prompt is the initial part of the prompt-completion string
    if "prompt" in example:
        error_message = (
            "The chat template applied to the prompt + completion does not start with the chat template applied to "
            "the prompt alone. This can indicate that the chat template is not supported by TRL."
            "\n**Prompt**:\n{}\n\n**Prompt + Completion**:\n{}"
        )
        if "chosen" in example and not prompt_chosen.startswith(prompt):
            raise ValueError(error_message.format(prompt, prompt_chosen))
        if "rejected" in example and not prompt_rejected.startswith(prompt):
            raise ValueError(error_message.format(prompt, prompt_rejected))
        if "completion" in example and not prompt_completion.startswith(prompt):
            raise ValueError(error_message.format(prompt, prompt_completion))

    # Collect the rendered fields
    output = {}
    if "messages" in example:
        output["text"] = messages
    if "prompt" in example:
        output["prompt"] = prompt
    if "chosen" in example:
        output["chosen"] = chosen
    if "rejected" in example:
        output["rejected"] = rejected
    if "completion" in example:
        output["completion"] = completion
    if "label" in example:
        output["label"] = example["label"]

    return output


def maybe_apply_chat_template(
    example: dict[str, list[dict[str, str]]],
    tokenizer: PreTrainedTokenizerBase,
    tools: Optional[list[Union[dict, Callable]]] = None,
) -> dict[str, str]:
    """Render ``example`` with the chat template if it is conversational; otherwise pass it through."""
    if not is_conversational(example):
        return example
    return apply_chat_template(example, tokenizer, tools)

In [None]:
def pad(
    tensors: list[torch.Tensor],
    padding_value: int = 0,
    padding_side: str = "right",
    pad_to_multiple_of: Optional[int] = None,
) -> torch.Tensor:
    """Stack ``tensors`` into one batch tensor, padding every dimension to the batch maximum.

    Args:
        tensors: Tensors of equal rank; dtype and device of the first are used for the output.
        padding_value: Fill value for padded positions.
        padding_side: "left" or "right"; applies to the first (sequence) dimension only.
            Trailing dimensions are always padded at the end.
        pad_to_multiple_of: If given, round the first dimension up to a multiple of it.

    Returns:
        Tensor of shape ``(len(tensors), *max_shape)``.

    Raises:
        ValueError: If ``padding_side`` is neither "left" nor "right".
    """
    # Largest extent along every dimension across the batch.
    target_shape = np.max([t.shape for t in tensors], 0).tolist()

    # Round the sequence dimension up, if requested.
    if pad_to_multiple_of is not None:
        overflow = target_shape[0] % pad_to_multiple_of
        if overflow != 0:
            target_shape[0] += pad_to_multiple_of - overflow

    first = tensors[0]
    padded = torch.full(
        (len(tensors), *target_shape), padding_value, dtype=first.dtype, device=first.device
    )

    for row, tensor in enumerate(tensors):
        length = tensor.shape[0]
        if padding_side == "right":
            start = 0
        elif padding_side == "left":
            start = target_shape[0] - length
        else:
            raise ValueError("padding_side must be 'left' or 'right'")

        region = (slice(start, start + length), *(slice(0, extent) for extent in tensor.shape[1:]))
        padded[row][region] = tensor

    return padded

In [None]:
@dataclass
class DPODataCollatorWithPadding:
    """Collate tokenized preference examples into a padded batch.

    Keys ending in ``_input_ids``, ``_attention_mask``, ``_labels`` or ``_pixel_values`` are
    converted to tensors and padded. Keys ending in ``_logps`` are gathered into a single
    tensor. Every other key is collected into a plain Python list.

    Attributes:
        pad_token_id: Fill value for input-id tensors; a ValueError is raised if it is None
            when such a key has to be padded.
        label_pad_token_id: Fill value for label-like tensors.
        is_encoder_decoder: True -> right-pad everything with ``pad_sequence``;
            False -> use the ``pad`` helper with left-padded prompt ids/masks.
    """

    pad_token_id: int = 0
    label_pad_token_id: int = -100
    is_encoder_decoder: Optional[bool] = False

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
        batch = {}
        for key in features[0].keys():
            if key.endswith(("_input_ids", "_attention_mask", "_labels", "_pixel_values")):
                if self.is_encoder_decoder:
                    batch[key] = self._collate_encoder_decoder(key, features)
                else:
                    batch[key] = self._collate_decoder_only(key, features)
            elif key.endswith("_logps"):
                # cached reference-model log-probabilities
                batch[key] = torch.tensor([feature[key] for feature in features])
            else:
                batch[key] = [feature[key] for feature in features]
        return batch

    def _checked_pad_token_id(self) -> int:
        """Return ``pad_token_id``, raising if the tokenizer had no padding token."""
        if self.pad_token_id is None:
            raise ValueError(
                "Padding is enabled, but the tokenizer is not configured with a padding token."
                " Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)"
                " before calling the trainer."
            )
        return self.pad_token_id

    def _collate_encoder_decoder(self, key: str, features: list[dict[str, Any]]) -> torch.Tensor:
        """Right-pad ``key`` across ``features`` with ``pad_sequence`` (encoder-decoder layout)."""
        sequences = [torch.LongTensor(feature[key]) for feature in features]
        if key.startswith("prompt") and key.endswith("input_ids"):
            fill = self._checked_pad_token_id()
        elif key.endswith("_attention_mask"):
            fill = 0
        elif key.startswith(("chosen", "rejected", "completion")) or ("decoder" in key):
            fill = self.label_pad_token_id
        else:
            raise ValueError(f"Unexpected key in batch '{key}'")
        return pad_sequence(sequences, batch_first=True, padding_value=fill)

    def _collate_decoder_only(self, key: str, features: list[dict[str, Any]]) -> torch.Tensor:
        """Pad ``key`` with ``pad``: prompt ids/masks on the left, everything else on the right."""
        if key.endswith("_input_ids"):
            fill = self._checked_pad_token_id()
        elif key.endswith("_labels"):
            fill = self.label_pad_token_id
        elif key.endswith(("_attention_mask", "_pixel_values")):
            fill = 0  # TODO (inherited): check if 0 is correct for pixel values
        else:
            raise ValueError(f"Unexpected key in batch '{key}'")

        side = "left" if key in ("prompt_input_ids", "prompt_attention_mask") else "right"
        # pixel values stay float32 (the Trainer may downcast later); ids/masks are int64
        dtype = torch.float32 if key.endswith("_pixel_values") else torch.int64
        sequences = [torch.tensor(feature[key], dtype=dtype) for feature in features]
        return pad(sequences, padding_value=fill, padding_side=side)

In [None]:
def peft_module_casting_to_bf16(model):
    """Cast norm layers to float32 and float32 embedding / LM-head modules to bfloat16, in place.

    Selection is by type and qualified module name. Any ``nn.LayerNorm``, or any module with
    "norm" in its name, is cast to float32. Otherwise, a module whose name contains "lm_head",
    "embed_tokens", "wte" or "wpe" and that has a float32 ``weight`` is cast to bfloat16.
    """
    embedding_markers = ("lm_head", "embed_tokens", "wte", "wpe")
    for qualified_name, submodule in model.named_modules():
        if isinstance(submodule, torch.nn.LayerNorm) or "norm" in qualified_name:
            # nn.Module.to() converts the module's parameters in place.
            submodule.to(torch.float32)
            continue
        if not any(marker in qualified_name for marker in embedding_markers):
            continue
        if hasattr(submodule, "weight") and submodule.weight.dtype == torch.float32:
            submodule.to(torch.bfloat16)

In [None]:
def disable_dropout_in_model(model: torch.nn.Module) -> None:
    """Set the drop probability of every ``nn.Dropout`` inside ``model`` (recursively) to 0, in place."""
    dropouts = (m for m in model.modules() if isinstance(m, torch.nn.Dropout))
    for dropout in dropouts:
        dropout.p = 0

In [None]:
def add_bos_token_if_needed(
    bos_token_id: Optional[int],
    prompt_len_input_ids: int,
    prompt_tokens: dict[str, list[int]],
    chosen_prompt_len_input_ids: int,
    chosen_tokens: dict[str, list[int]],
    rejected_prompt_len_input_ids: int,
    rejected_tokens: dict[str, list[int]],
):
    """Prepend ``bos_token_id`` to each prompt that does not already start with it.

    Each of the prompt / chosen / rejected token dicts is checked in that order. When its
    given length is 0, or its ``prompt_input_ids`` does not start with ``bos_token_id``,
    a BOS id (with mask 1) is prepended. New lists are assigned back into the dicts; the
    caller's original lists are not mutated. Nothing happens when ``bos_token_id`` is None.

    Returns:
        The (possibly updated) ``(prompt_tokens, chosen_tokens, rejected_tokens)``.
    """
    if bos_token_id is None:
        return prompt_tokens, chosen_tokens, rejected_tokens

    for length, tokens in (
        (prompt_len_input_ids, prompt_tokens),
        (chosen_prompt_len_input_ids, chosen_tokens),
        (rejected_prompt_len_input_ids, rejected_tokens),
    ):
        needs_bos = length == 0 or bos_token_id != tokens["prompt_input_ids"][0]
        if needs_bos:
            tokens["prompt_input_ids"] = [bos_token_id] + tokens["prompt_input_ids"]
            tokens["prompt_attention_mask"] = [1] + tokens["prompt_attention_mask"]

    return prompt_tokens, chosen_tokens, rejected_tokens

In [None]:
def add_eos_token_if_needed(
    eos_token_id: int, chosen_tokens: dict[str, list[int]], rejected_tokens: dict[str, list[int]]
):
    """Append ``eos_token_id`` (mask 1) to the chosen and rejected ids unless they already end with it.

    The ``input_ids`` / ``attention_mask`` lists are extended in place.

    Returns:
        ``(chosen_tokens, rejected_tokens)`` (the same dict objects).
    """
    for tokens in (chosen_tokens, rejected_tokens):
        ids = tokens["input_ids"]
        if len(ids) == 0 or eos_token_id != ids[-1]:
            ids.append(eos_token_id)
            tokens["attention_mask"].append(1)
    return chosen_tokens, rejected_tokens

In [None]:
def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor:
    """Right-pad ``tensor`` along ``dim`` with ``pad_value`` up to ``length``.

    Tensors already at least ``length`` long along ``dim`` are returned as-is (same object).
    The filler is built as ``pad_value * ones`` in ``tensor``'s dtype/device, then concatenated.
    """
    current = tensor.size(dim)
    if current < length:
        filler_shape = list(tensor.shape)
        filler_shape[dim] = length - current
        filler = pad_value * torch.ones(*filler_shape, dtype=tensor.dtype, device=tensor.device)
        return torch.cat([tensor, filler], dim=dim)
    return tensor

In [None]:
def selective_log_softmax(logits, index):
    """Return ``log_softmax(logits, -1)`` gathered at ``index``, with bounded peak memory.

    Args:
        logits: Tensor whose last dimension holds the class scores; iterated over its first dim.
        index: Integer tensor picking one class per position (expected shape
            ``logits.shape[:-1]``).

    Returns:
        Per-position log-probabilities with the shape of ``index``.
    """
    if logits.dtype not in [torch.float32, torch.float64]:
        # The logsumexp shortcut is unstable in reduced precision (e.g. bfloat16), so compute a
        # real log_softmax one row at a time (row loop bounds peak memory).
        rows = [
            F.log_softmax(row_logits, dim=-1).gather(dim=-1, index=row_index.unsqueeze(-1)).squeeze(-1)
            for row_logits, row_index in zip(logits, index)
        ]
        return torch.stack(rows)

    picked = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
    # log_softmax(x)_i = x_i - logsumexp(x); per-row loop bounds peak memory
    normalizers = torch.stack([torch.logsumexp(row, dim=-1) for row in logits])
    return picked - normalizers

In [None]:
def log_table_to_comet_experiment(name: str, table: pd.DataFrame) -> None:
    """Log ``table`` under filename ``name`` to the running Comet experiment, if there is one.

    Raises:
        ModuleNotFoundError: If comet-ml is not installed.
    """
    if not is_comet_available():
        raise ModuleNotFoundError("The comet-ml is not installed. Please install it first: pip install comet-ml")

    running = comet_ml.get_running_experiment()
    if running is None:
        return  # no active experiment to log to
    running.log_table(tabular_data=table, filename=name)

In [None]:
def get_comet_experiment_url() -> Optional[str]:
    """Return the URL of the running Comet experiment.

    Returns:
        The experiment URL, or None when comet-ml is unavailable or no experiment is running.
    """
    if not is_comet_available():
        return None

    # Query once: a second get_running_experiment() call is redundant, and the
    # None-check and the `.url` access would otherwise be made on separate lookups.
    experiment = comet_ml.get_running_experiment()
    if experiment is not None:
        return experiment.url

    return None

In [None]:
def generate_model_card(
    base_model: Optional[str],
    model_name: str,
    hub_model_id: str,
    dataset_name: Optional[str],
    tags: list[str],
    wandb_url: Optional[str],
    trainer_name: str,
    trainer_citation: Optional[str] = None,
    paper_title: Optional[str] = None,
    paper_id: Optional[str] = None,
    comet_url: Optional[str] = None,
) -> ModelCard:
    """Build a ``ModelCard`` from TRL's ``templates/lm_model_card.md`` template.

    Card metadata records the base model, dataset, library and tags; "generated_from_trainer"
    is always prepended to ``tags``. The template is rendered with the given run information
    plus the installed versions of trl, transformers, torch, datasets and tokenizers.
    """
    metadata = ModelCardData(
        base_model=base_model,
        datasets=dataset_name,
        library_name="transformers",
        # NOTE(review): the keyword is spelled `licence`, not `license`; confirm which
        # metadata field is intended before changing it.
        licence="license",
        model_name=model_name,
        tags=["generated_from_trainer", *tags],
    )
    template = pkg_resources.files("trl").joinpath("templates/lm_model_card.md")
    library_versions = {
        "trl_version": version("trl"),
        "transformers_version": version("transformers"),
        "pytorch_version": version("torch"),
        "datasets_version": version("datasets"),
        "tokenizers_version": version("tokenizers"),
    }
    return ModelCard.from_template(
        metadata,
        template_path=str(template),
        base_model=base_model,
        model_name=model_name,
        hub_model_id=hub_model_id,
        dataset_name=dataset_name,
        wandb_url=wandb_url,
        comet_url=comet_url,
        trainer_name=trainer_name,
        trainer_citation=trainer_citation,
        paper_title=paper_title,
        paper_id=paper_id,
        **library_versions,
    )

In [None]:
class HijackedORPOTrainer(transformers.Trainer):

    _tag_names = ["trl", "orpo"]

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        args: Optional[ORPOConfig] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional[dict] = None,
        compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
    ):
        if args.model_init_kwargs is None:
            model_init_kwargs = {}
        elif not isinstance(model, str):
            raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.")
        else:
            model_init_kwargs = args.model_init_kwargs
            torch_dtype = model_init_kwargs.get("torch_dtype")
            if torch_dtype is not None:
                # Convert to `torch.dtype` if an str is passed
                if isinstance(torch_dtype, str) and torch_dtype != "auto":
                    torch_dtype = getattr(torch, torch_dtype)
                if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
                    raise ValueError(
                        f"Invalid `torch_dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
                    )
                model_init_kwargs["torch_dtype"] = torch_dtype

        if isinstance(model, str):
            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

        # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
        # has been called in order to properly call autocast if needed.
        self._peft_has_been_casted_to_bf16 = False

        if not is_peft_available() and peft_config is not None:
            raise ValueError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            # if model is a peft model and we have a peft_config, we merge and unload it first
            if isinstance(model, PeftModel):
                model = model.merge_and_unload()

            if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
                _support_gc_kwargs = hasattr(
                    args, "gradient_checkpointing_kwargs"
                ) and "gradient_checkpointing_kwargs" in list(
                    inspect.signature(prepare_model_for_kbit_training).parameters
                )

                prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

                if _support_gc_kwargs:
                    prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

                model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
            elif getattr(args, "gradient_checkpointing", False):
                # For backward compatibility with older versions of transformers
                if hasattr(model, "enable_input_require_grads"):
                    model.enable_input_require_grads()
                else:

                    def make_inputs_require_grad(module, input, output):
                        output.requires_grad_(True)

                    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

            # get peft model with the given config
            model = get_peft_model(model, peft_config)
            if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
                peft_module_casting_to_bf16(model)
                # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
                self._peft_has_been_casted_to_bf16 = True

        # For models that use gradient_checkpointing, we need to attach a hook that enables input
        # to explicitly have `requires_grad=True`, otherwise training will either silently
        # fail or completely fail.
        elif getattr(args, "gradient_checkpointing", False):
            # For backward compatibility with older versions of transformers
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:

                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
            raise ValueError(
                "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
                " Please install `wandb` or `comet-ml` to resolve."
            )

        if model is not None:
            self.is_encoder_decoder = model.config.is_encoder_decoder
        elif args.is_encoder_decoder is None:
            raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
        else:
            self.is_encoder_decoder = args.is_encoder_decoder

        if self.is_encoder_decoder:
            self.decoder_start_token_id = model.config.decoder_start_token_id
            self.pad_token_id = model.config.pad_token_id

        if processing_class is None:
            raise ValueError("processing_class must be specified to tokenize a ORPO dataset.")
        if args.max_length is None:
            warnings.warn(
                "`max_length` is not set in the ORPOConfig's init"
                " it will default to `512` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            max_length = 512
        else:
            max_length = args.max_length
        if args.max_prompt_length is None:
            warnings.warn(
                "`max_prompt_length` is not set in the ORPOConfig's init"
                " it will default to `128` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            max_prompt_length = 128
        else:
            max_prompt_length = args.max_prompt_length

        if args.max_completion_length is None and self.is_encoder_decoder:
            warnings.warn(
                "When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init"
                " it will default to `128` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            self.max_completion_length = 128
        else:
            self.max_completion_length = args.max_completion_length

        if data_collator is None:
            data_collator = DPODataCollatorWithPadding(
                pad_token_id=processing_class.pad_token_id,
                label_pad_token_id=args.label_pad_token_id,
                is_encoder_decoder=self.is_encoder_decoder,
            )

            if args.remove_unused_columns:
                args.remove_unused_columns = False
                # warn users
                warnings.warn(
                    "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
                    " we have set it for you, but you should do it yourself in the future.",
                    UserWarning,
                )

            self.use_dpo_data_collator = True
        else:
            self.use_dpo_data_collator = False

        # Disable dropout in the model and reference model
        if args.disable_dropout:
            disable_dropout_in_model(model)

        self.max_length = max_length
        self.generate_during_eval = args.generate_during_eval
        self.label_pad_token_id = args.label_pad_token_id
        self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
        self.max_prompt_length = max_prompt_length
        self.truncation_mode = args.truncation_mode
        self.processing_class = processing_class

        self.beta = args.beta
        self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
        self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
        if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
            warnings.warn(
                "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
                "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
                "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
                "loss.",
                UserWarning,
            )

        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
        # input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the
        # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
        # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
        # of the input, floating-point operations will not be computed." To suppress this warning, we set the
        # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
        # that the warning has already been issued.
        model.warnings_issued["estimate_tokens"] = True

        # Compute that only on the main process for faster data processing.
        # see: https://github.com/huggingface/trl/pull/1255
        with PartialState().main_process_first():
            # Extract the prompt if needed, and apply the chat template if needed
            train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
            train_dataset = train_dataset.map(
                maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
            )
            train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
            if eval_dataset is not None:
                eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
                eval_dataset = eval_dataset.map(
                    maybe_apply_chat_template,
                    fn_kwargs={"tokenizer": processing_class},
                    num_proc=args.dataset_num_proc,
                )
                eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

        if not hasattr(self, "accelerator"):
            raise AttributeError(
                "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
            )

    def build_tokenized_answer(self, prompt, answer):
        """Tokenize ``prompt + answer`` and split the resulting ids into a prompt part and an answer part.

        Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
        It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
        Reference:
            https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257

        Both parts are sliced from the tokenization of the concatenated text, so they always re-join to
        exactly ``enc(prompt + answer)``. If the tokenizer merged the last prompt token with the start of the
        answer (so the standalone prompt ids are not a prefix of the full ids), that boundary token is
        assigned to the answer.

        Args:
            prompt: Prompt text.
            answer: Completion text, appended directly to ``prompt``.

        Returns:
            A dict with ``prompt_input_ids``, ``prompt_attention_mask`` (prompt part) and ``input_ids``,
            ``attention_mask`` (answer part), all sliced from ``self.processing_class(prompt + answer)``
            called with ``add_special_tokens=False``.

        Raises:
            ValueError: If the prompt alone tokenizes to more ids than prompt + answer, or if the prompt ids
                and prompt attention mask end up with different lengths.
        """
        full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
        full_input_ids = full_tokenized["input_ids"]
        full_attention_mask = full_tokenized["attention_mask"]
        prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]

        # `enc(a) + enc(a + b)[len(enc(a)):]` has the same length as `enc(a + b)` exactly when the
        # standalone prompt tokenization is not longer than the full tokenization.
        if len(prompt_input_ids) > len(full_input_ids):
            raise ValueError("Tokenized prompt is longer than the tokenized prompt + answer.")

        # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
        # can be merged together when tokenizing prompt+answer. This could result
        # on the last token from the prompt being different when tokenized on its own
        # vs when done as prompt+answer.
        response_token_ids_start_idx = len(prompt_input_ids)

        # If the standalone prompt ids are not a prefix of the full ids, the last prompt
        # token was merged with the answer; hand that token over to the answer.
        if prompt_input_ids != full_input_ids[:response_token_ids_start_idx]:
            response_token_ids_start_idx -= 1

        prompt_input_ids = full_input_ids[:response_token_ids_start_idx]
        prompt_attention_mask = full_attention_mask[:response_token_ids_start_idx]

        if len(prompt_input_ids) != len(prompt_attention_mask):
            raise ValueError("Prompt input ids and attention mask should have the same length.")

        return dict(
            prompt_input_ids=prompt_input_ids,
            prompt_attention_mask=prompt_attention_mask,
            input_ids=full_input_ids[response_token_ids_start_idx:],
            attention_mask=full_attention_mask[response_token_ids_start_idx:],
        )

    def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
        """Tokenize a single row from an ORPO-specific dataset.

        At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
        in case the prompt + chosen or prompt + rejected responses is/are too long. First
        we truncate the prompt (to `self.max_prompt_length`); if we're still too long, we truncate
        the chosen/rejected responses (to `self.max_length - self.max_prompt_length`).

        We also create the labels for the chosen/rejected responses, which are of length equal to
        the sum of the length of the prompt and the chosen/rejected response, with
        `label_pad_token_id` for the prompt tokens.

        Args:
            feature: A dataset row with "prompt", "chosen" and "rejected" fields (must be `str` for
                decoder-only models).
            model: Only used on the encoder-decoder path: if it has
                `prepare_decoder_input_ids_from_labels`, decoder input ids are added for both responses.

        Returns:
            A flat dict. Decoder-only models get "chosen_"/"rejected_" prefixed `input_ids`,
            `attention_mask` and `labels`, plus the "prompt_"-prefixed tokenizer fields.
            Encoder-decoder models get "chosen_labels", "rejected_labels", "prompt_input_ids",
            "prompt_attention_mask" and optionally "chosen_/rejected_decoder_input_ids" (tensors).
            When TorchXLA is available every field is right-padded to `self.max_length`.

        Raises:
            ValueError: If a field is not a `str` (decoder-only path), if the chosen and rejected prompt
                tokenizations differ by more than one token, or if truncation is needed and
                `self.truncation_mode` is neither "keep_start" nor "keep_end".
        """
        batch = {}
        prompt = feature["prompt"]
        chosen = feature["chosen"]
        rejected = feature["rejected"]

        if not self.is_encoder_decoder:
            # Check issues below for more details
            #  1. https://github.com/huggingface/trl/issues/907
            #  2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
            #  3. https://github.com/LianjiaTech/BELLE/issues/337

            if not isinstance(prompt, str):
                raise ValueError(f"prompt should be an str but got {type(prompt)}")
            prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
            # Prefix every tokenizer field with "prompt_" (e.g. "prompt_input_ids", "prompt_attention_mask").
            prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}

            if not isinstance(chosen, str):
                raise ValueError(f"chosen should be an str but got {type(chosen)}")
            chosen_tokens = self.build_tokenized_answer(prompt, chosen)

            if not isinstance(rejected, str):
                raise ValueError(f"rejected should be an str but got {type(rejected)}")
            rejected_tokens = self.build_tokenized_answer(prompt, rejected)

            # Last prompt token might get merged by tokenizer and
            # it should not be included for generation if that happens
            # NOTE(review): this value is overwritten a few lines below and has no effect.
            prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])

            chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
            rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
            # Use the shorter of the two split prompts so a boundary token merged into either
            # response is dropped from the standalone (generation) prompt.
            prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)

            for k, v in prompt_tokens.items():
                prompt_tokens[k] = v[:prompt_len_input_ids]

            # Make sure prompts only have one different token at most and
            # length only differs by 1 at most
            num_diff_tokens = sum(
                [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
            )
            num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
            if num_diff_tokens > 1 or num_diff_len > 1:
                raise ValueError(
                    "Chosen and rejected prompt_input_ids might only differ on the "
                    "last token due to tokenizer merge ops."
                )

            # add BOS token to head of prompt. Avoid adding if it's already there
            prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
                self.processing_class.bos_token_id,
                prompt_len_input_ids,
                prompt_tokens,
                chosen_prompt_len_input_ids,
                chosen_tokens,
                rejected_prompt_len_input_ids,
                rejected_tokens,
            )

            # add EOS token to end of answer. Avoid adding if it's already there
            chosen_tokens, rejected_tokens = add_eos_token_if_needed(
                self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
            )

            # The length budget is checked against the longer response so chosen and rejected
            # (and the standalone prompt) are truncated consistently.
            longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

            # if combined sequence is too long, truncate the prompt
            # (prompt_tokens is included so the generation prompt gets the same truncation)
            for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
                if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                    if self.truncation_mode == "keep_start":
                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
                            answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
                    elif self.truncation_mode == "keep_end":
                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
                            answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
                    else:
                        raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

            # if that's still too long, truncate the response
            # (to the budget left over after a max_prompt_length prompt)
            for answer_tokens in [chosen_tokens, rejected_tokens]:
                if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                    for k in ["input_ids", "attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]

            # Create labels
            # Full sequences are prompt + response; the labels copy input_ids and then overwrite
            # the prompt positions with label_pad_token_id so only response tokens are labelled.
            chosen_sequence_tokens = {
                k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
            }
            rejected_sequence_tokens = {
                k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
            }
            chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
            chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
                self.label_pad_token_id
            ] * len(chosen_tokens["prompt_input_ids"])
            rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
            rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
                self.label_pad_token_id
            ] * len(rejected_tokens["prompt_input_ids"])

            # Flatten into one dict: "chosen_*", "rejected_*" and the already "prompt_"-prefixed fields.
            # NOTE(review): none of these dicts has an unprefixed "token_type_ids" key (prompt fields are
            # prefixed), so this skip appears never to trigger; "prompt_token_type_ids" would be kept.
            for k, toks in {
                "chosen_": chosen_sequence_tokens,
                "rejected_": rejected_sequence_tokens,
                "": prompt_tokens,
            }.items():
                for type_key, tokens in toks.items():
                    if type_key == "token_type_ids":
                        continue
                    batch[f"{k}{type_key}"] = tokens

        else:
            # Encoder-decoder: prompt and responses are tokenized independently (with special tokens
            # and truncation); the response ids are used directly as labels.
            chosen_tokens = self.processing_class(
                chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
            )
            rejected_tokens = self.processing_class(
                rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
            )
            prompt_tokens = self.processing_class(
                prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
            )

            batch["chosen_labels"] = chosen_tokens["input_ids"]
            batch["rejected_labels"] = rejected_tokens["input_ids"]
            batch["prompt_input_ids"] = prompt_tokens["input_ids"]
            batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]

            if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
                batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                    labels=torch.tensor(batch["rejected_labels"])
                )
                batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                    labels=torch.tensor(batch["chosen_labels"])
                )

        if is_torch_xla_available():
            # Pad the sequences to global max_length to avoid TorchXLA recompilation
            # NOTE(review): a key matching none of these branches (e.g. "prompt_token_type_ids") reuses the
            # previous pad_value, or raises UnboundLocalError if it is the first key — confirm inputs.
            for k in batch:
                if "labels" in k or self.is_encoder_decoder:
                    pad_value = self.label_pad_token_id
                elif k.endswith("_input_ids"):
                    pad_value = self.padding_value
                elif k.endswith("_attention_mask"):
                    pad_value = 0
                batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k]))
        return batch

    @staticmethod
    def concatenated_inputs(
        batch: dict[str, Union[list, torch.LongTensor]],
        is_encoder_decoder: bool = False,
        label_pad_token_id: int = -100,
        padding_value: int = 0,
        device: Optional[torch.device] = None,
    ) -> dict[str, torch.LongTensor]:
        """Concatenate the chosen and rejected inputs into a single tensor.

        Every tensor field starting with "chosen" and its "rejected" counterpart are right-padded to a common
        length and stacked along the batch dimension (chosen rows first) under a "concatenated" key.

        Args:
            batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids'
                (or 'chosen_labels' / 'rejected_labels' for encoder-decoder models), which are tensors of
                shape (batch_size, sequence_length).
            is_encoder_decoder: Whether the model is an encoder-decoder model.
            label_pad_token_id: The label pad token id.
            padding_value: The padding value to use for the concatenated inputs_ids.
            device: The device for the concatenated inputs.

        Returns:
            A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'
            (and the other 'concatenated_*' fields).
        """

        def _pad_value_for(key: str) -> int:
            # Labels (and every field of encoder-decoder models) are padded with the ignore id so
            # padding never contributes to the loss.
            if "labels" in key or is_encoder_decoder:
                return label_pad_token_id
            if key.endswith("_attention_mask"):
                return 0
            # "*_input_ids" and any other tensor field fall back to the token padding value
            # (previously an unrecognized key silently reused the previous key's pad value).
            return padding_value

        concatenated_batch = {}

        if is_encoder_decoder:
            max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
        else:
            max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])

        # Chosen rows go first so callers can split the result at len(chosen).
        for k in batch:
            if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
                concatenated_key = k.replace("chosen", "concatenated")
                concatenated_batch[concatenated_key] = pad_to_length(
                    batch[k], max_length, pad_value=_pad_value_for(k)
                )
        for k in batch:
            if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
                concatenated_key = k.replace("rejected", "concatenated")
                concatenated_batch[concatenated_key] = torch.cat(
                    (
                        concatenated_batch[concatenated_key],
                        pad_to_length(batch[k], max_length, pad_value=_pad_value_for(k)),
                    ),
                    dim=0,
                ).to(device=device)

        if is_encoder_decoder:
            # The encoder sees the same prompt for both the chosen and the rejected half.
            concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
            concatenated_batch["concatenated_attention_mask"] = (
                batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
            )

        return concatenated_batch

    def odds_ratio_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
    ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute ORPO's odds-ratio (OR) term from the policy model's log-probabilities.

        Follows Eqs. (4) and (7) of https://huggingface.co/papers/2403.07691. With p = exp(logp),
        log odds(p) = logp - log(1 - p), so the chosen-vs-rejected log odds ratio is
        (logp_c - logp_r) - (log1p(-exp(logp_c)) - log1p(-exp(logp_r))).

        Args:
            policy_chosen_logps: Policy log-probabilities of the chosen responses. Shape: (batch_size,)
            policy_rejected_logps: Policy log-probabilities of the rejected responses. Shape: (batch_size,)

        Returns:
            A tuple of five tensors:
            - losses: per-example ``beta * logsigmoid(log_odds)``; the caller subtracts its mean.
            - chosen_rewards: ``beta * policy_chosen_logps``, detached, on the accelerator device.
            - rejected_rewards: ``beta * policy_rejected_logps``, detached, on the accelerator device.
            - the mean of ``logsigmoid(log_odds)`` (for logging).
            - the mean log odds ratio (for logging).
        """
        # log(1 - p) for each side, computed from log p.
        log_one_minus_chosen = torch.log1p(-policy_chosen_logps.exp())
        log_one_minus_rejected = torch.log1p(-policy_rejected_logps.exp())
        log_odds = (policy_chosen_logps - policy_rejected_logps) - (log_one_minus_chosen - log_one_minus_rejected)

        log_sigmoid_odds = F.logsigmoid(log_odds)

        target_device = self.accelerator.device
        chosen_rewards = self.beta * policy_chosen_logps.to(target_device).detach()
        rejected_rewards = self.beta * policy_rejected_logps.to(target_device).detach()

        return (
            self.beta * log_sigmoid_odds,
            chosen_rewards,
            rejected_rewards,
            log_sigmoid_odds.mean(),
            log_odds.mean(),
        )

    @staticmethod
    def get_batch_logps(
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
        label_pad_token_id: int = -100,
        is_encoder_decoder: bool = False,
    ) -> torch.FloatTensor:
        """Score each sequence by the log-probability its logits assign to its labels.

        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
            labels: Labels for which to compute the log probabilities. Positions equal to
                `label_pad_token_id` are ignored. Shape: (batch_size, sequence_length)
            average_log_prob: If True, return the mean log probability over the non-masked tokens of
                each row; otherwise return their sum.
            label_pad_token_id: The label pad token id.
            is_encoder_decoder: Whether the model is an encoder-decoder model.

        Returns:
            A tensor of shape (batch_size,) with the averaged or summed log probabilities.

        Raises:
            ValueError: If the batch and sequence dimensions of `logits` and `labels` differ.
        """
        if labels.shape != logits.shape[:-1]:
            raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

        # Decoder-only: the logit at position t predicts the token at t + 1, so drop the first
        # label and the last logit to line them up.
        if not is_encoder_decoder:
            logits, labels = logits[:, :-1, :], labels[:, 1:]

        token_is_scored = labels != label_pad_token_id
        # Padded positions get the dummy (in-range) index 0; they are zeroed out by the mask below.
        safe_labels = labels.masked_fill(~token_is_scored, 0)
        masked_logps = selective_log_softmax(logits, safe_labels) * token_is_scored

        summed = masked_logps.sum(-1)
        return summed / token_is_scored.sum(-1) if average_log_prob else summed

    def concatenated_forward(
        self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
    ) -> tuple[torch.FloatTensor, ...]:
        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

        We do this to avoid doing two forward passes, because it's faster for FSDP.

        Returns:
            ``(chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)``, with the
            model's ``aux_loss`` appended as a sixth element when ``self.aux_loss_enabled`` is set.
            The log-probabilities are per-token averages over non-padded labels; the NLL loss is the
            cross-entropy on the chosen rows only.
        """
        concatenated_batch = self.concatenated_inputs(
            batch,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
            padding_value=self.padding_value,
            device=self.accelerator.device,
        )
        # The first len_chosen rows of the concatenated batch are the chosen examples.
        len_chosen = batch["chosen_labels"].shape[0]

        model_kwargs = (
            {
                "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
            }
            if self.is_encoder_decoder
            else {}
        )

        if self.aux_loss_enabled:
            model_kwargs["output_router_logits"] = True

        outputs = model(
            concatenated_batch["concatenated_input_ids"],
            attention_mask=concatenated_batch["concatenated_attention_mask"],
            use_cache=False,
            **model_kwargs,
        )
        all_logits = outputs.logits

        def cross_entropy_loss(logits, labels):
            if not self.is_encoder_decoder:
                # Shift so that tokens < n predict n
                logits = logits[..., :-1, :].contiguous()
                labels = labels[..., 1:].contiguous()
            # Ignore exactly the id used to mask padding/prompt positions. PyTorch's default
            # ignore_index (-100) only matches when label_pad_token_id is left at its default.
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_pad_token_id)
            # Flatten the tokens
            logits = logits.view(-1, logits.shape[-1])
            labels = labels.view(-1)
            # Enable model parallelism
            labels = labels.to(logits.device)
            loss = loss_fct(logits, labels)
            return loss

        if self.is_encoder_decoder:
            labels = concatenated_batch["concatenated_labels"].clone()
        else:
            labels = concatenated_batch["concatenated_input_ids"].clone()
            attention_mask = concatenated_batch["concatenated_attention_mask"]
            labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
        # orpo chosen nll loss is computed over the full prompt and response
        chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

        all_logps = self.get_batch_logps(
            all_logits,
            concatenated_batch["concatenated_labels"],
            average_log_prob=True,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        chosen_logps = all_logps[:len_chosen]
        rejected_logps = all_logps[len_chosen:]

        if not self.is_encoder_decoder:
            chosen_logits = all_logits[:len_chosen, :-1, :]
            rejected_logits = all_logits[len_chosen:, :-1, :]
        else:
            chosen_logits = all_logits[:len_chosen]
            rejected_logits = all_logits[len_chosen:]

        if self.aux_loss_enabled:
            return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)

        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)

    def get_batch_loss_metrics(
        self,
        model,
        batch: dict[str, Union[list, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the ORPO loss and other metrics for the given batch of inputs for train or test.

        Returns:
            ``(loss, metrics)``, where ``loss = nll_loss - mean(beta * logsigmoid(log_odds))`` (plus
            ``aux_loss_coef * aux_loss`` when router aux loss is enabled). ``metrics`` maps metric names,
            prefixed with ``"eval_"`` in eval mode, to Python floats gathered via
            ``accelerator.gather_for_metrics``.
        """
        forward_output = self.concatenated_forward(model, batch)
        chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss = forward_output[:5]
        aux_loss = forward_output[5] if self.aux_loss_enabled else None

        losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
            chosen_logps, rejected_logps
        )
        # full ORPO loss: SFT NLL on the chosen response minus the beta-scaled odds-ratio term
        loss = nll_loss - losses.mean()

        gather = self.accelerator.gather_for_metrics
        reward_accuracies = (chosen_rewards > rejected_rewards).float()
        prefix = "eval_" if train_eval == "eval" else ""
        # Gathers are issued in a fixed order on every process.
        metric_tensors = {
            "rewards/chosen": gather(chosen_rewards).mean(),
            "rewards/rejected": gather(rejected_rewards).mean(),
            "rewards/accuracies": gather(reward_accuracies).mean(),
            "rewards/margins": gather(chosen_rewards - rejected_rewards).mean(),
            "logps/rejected": gather(rejected_logps).detach().mean(),
            "logps/chosen": gather(chosen_logps).detach().mean(),
            "logits/rejected": gather(rejected_logits.detach().mean()).mean(),
            "logits/chosen": gather(chosen_logits.detach().mean()).mean(),
            "nll_loss": gather(nll_loss).detach().mean(),
            "log_odds_ratio": gather(log_odds_ratio).detach().mean(),
            "log_odds_chosen": gather(log_odds_chosen).detach().mean(),
        }
        if is_torch_xla_available():
            xm.mark_step()  # needed because .item() calls
        metrics = {f"{prefix}{name}": value.item() for name, value in metric_tensors.items()}

        if self.aux_loss_enabled:
            loss += self.aux_loss_coef * aux_loss

        return loss, metrics

    def compute_loss(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: dict[str, Union[torch.Tensor, Any]],
        return_outputs=False,
        num_items_in_batch=None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
        """Training-step loss hook called by the `Trainer`.

        Computes the ORPO loss (under CUDA autocast when a PEFT model was cast to bf16), records the batch
        metrics in the "train" buffer, and returns the loss (or ``(loss, metrics)`` if `return_outputs`).
        `num_items_in_batch` is accepted for `Trainer` compatibility but not used.
        """
        if self._peft_has_been_casted_to_bf16:
            autocast_ctx = amp.autocast("cuda")
        else:
            autocast_ctx = nullcontext()

        with autocast_ctx:
            loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

        # The Trainer accumulates the loss on args.device; move it there before returning.
        loss = loss.to(self.args.device)

        # force log the metrics
        self.store_metrics(metrics, train_eval="train")

        return (loss, metrics) if return_outputs else loss

    def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> list[str]:
        """Sample completions from the policy model for the prompts in `batch`.

        ORPO has no reference model, so only the policy is sampled.

        Args:
            model: A model exposing `generate`.
            batch: A collated batch containing "prompt_input_ids" and "prompt_attention_mask".

        Returns:
            A list with one decoded string per row (special tokens skipped), padded to `self.max_length`
            tokens before decoding. For decoder-only models the text presumably starts with the prompt
            (callers strip `len(prompt)` characters) — verify for encoder-decoder models.
        """

        # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
        # the torch cuda amp context manager as some hidden states are silently casted to full precision.
        generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

        with generate_context_manager:
            policy_output = model.generate(
                input_ids=batch["prompt_input_ids"],
                attention_mask=batch["prompt_attention_mask"],
                max_length=self.max_length,
                do_sample=True,
                pad_token_id=self.processing_class.pad_token_id,
            )

        policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
        policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)

        return policy_output_decoded

    def prediction_step(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[list[str]] = None,
    ):
        """Evaluation step: compute the ORPO loss and metrics without gradients.

        Metrics are recorded in the "eval" buffer. Returns ``(loss, None, None)`` when
        `prediction_loss_only`; otherwise ``(loss, logits, labels)``, where ``logits`` holds the
        "eval_logits/chosen" and "eval_logits/rejected" metric values not listed in `ignore_keys`, and
        ``labels`` is a zero tensor of matching length.
        """
        if not self.use_dpo_data_collator:
            warnings.warn(
                "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
                "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
            )
        if ignore_keys is None:
            ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) if hasattr(model, "config") else []

        autocast_ctx = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
        with torch.no_grad(), autocast_ctx:
            loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")

        # force log the metrics
        self.store_metrics(metrics, train_eval="eval")

        detached_loss = loss.detach()
        if prediction_loss_only:
            return (detached_loss, None, None)

        # logits for the chosen and rejected samples from model (scalar metric values)
        logits_dict = {key: metrics[key] for key in ("eval_logits/chosen", "eval_logits/rejected")}
        kept_values = [value for key, value in logits_dict.items() if key not in ignore_keys]
        logits = torch.tensor(kept_values, device=self.accelerator.device)
        labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

        return (detached_loss, logits, labels)

    def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
        """Append each metric value to the per-split buffer that `log` later averages and clears."""
        for metric_name in metrics:
            self._stored_metrics[train_eval][metric_name].append(metrics[metric_name])

    def evaluation_loop(
        self,
        dataloader: data.DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[list[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Overriding built-in evaluation loop to store metrics for each batch.
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.

        When `generate_during_eval` is set, one random batch (at most `eval_batch_size` rows, and never
        more rows than the eval dataset has) is sampled from the policy and logged as a "game_log" table
        to W&B and/or Comet, depending on `args.report_to`.
        """

        # Sample and save to game log if requested (for one batch to save time)
        if self.generate_during_eval:
            # Generate random indices within the range of the total number of samples.
            # random.sample raises ValueError when k exceeds the population, so clamp k for
            # eval sets smaller than one batch.
            num_samples = len(dataloader.dataset)
            sample_size = min(self.args.eval_batch_size, num_samples)
            random_indices = random.sample(range(num_samples), k=sample_size)

            # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
            random_batch_dataset = dataloader.dataset.select(random_indices)
            random_batch = self.data_collator(random_batch_dataset)
            random_batch = self._prepare_inputs(random_batch)

            policy_output_decoded = self.generate_from_model(self.model, random_batch)

            table = pd.DataFrame(
                columns=["Prompt", "Policy"],
                data=[
                    [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
                ],
            )
            if "wandb" in self.args.report_to:
                wandb.log({"game_log": wandb.Table(data=table)})

            if "comet_ml" in self.args.report_to:
                log_table_to_comet_experiment(
                    name="game_log.csv",
                    table=table,
                )

        # Base evaluation
        initial_output = super().evaluation_loop(
            dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
        )

        return initial_output

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        """
        Merge the averaged buffered metrics into `logs`, then hand off to the parent logger.

        Args:
            logs (`dict[str, float]`):
                The values to log; updated in place with the mean of every stored metric.
            start_time (`float` or `None`, *optional*, defaults to `None`):
                Start time of the training; only forwarded when transformers >= 4.47.0.dev0.
        """
        # Training logs carry 'loss'; evaluation logs carry 'eval_loss' instead.
        split = "train" if "loss" in logs else "eval"
        buffered = self._stored_metrics[split]
        for name, values in buffered.items():
            logs[name] = torch.tensor(values).mean().item()
        # Drop the buffer so the next logging window starts fresh.
        del self._stored_metrics[split]

        forwards_start_time = version.parse(transformers.__version__) >= version.parse("4.47.0.dev0")
        if not forwards_start_time:  # transformers<=4.46
            return super().log(logs)
        return super().log(logs, start_time)

    def _shift_right(self, input_ids):
        """
        Shift `input_ids` one position right along the last axis, prepending
        `self.decoder_start_token_id`; any -100 values in the result become
        `self.pad_token_id`.

        Raises:
            ValueError: if `decoder_start_token_id` or `pad_token_id` is None.
        """
        start_id = self.decoder_start_token_id
        if start_id is None:
            raise ValueError(
                "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
            )

        if not is_torch_fx_proxy(input_ids):
            shifted = input_ids.new_zeros(input_ids.shape)
            shifted[..., 1:] = input_ids[..., :-1].clone()
            shifted[..., 0] = start_id
        else:
            # FX proxies do not support item assignment, so build by concatenation.
            prefix = torch.full(input_ids.shape[:-1] + (1,), start_id)
            shifted = torch.cat([prefix, input_ids[..., :-1]], dim=-1)

        pad_id = self.pad_token_id
        if pad_id is None:
            raise ValueError("model.config.pad_token_id has to be defined.")
        # -100 is the ignore-index used in labels; it is not a valid input id.
        shifted.masked_fill_(shifted == -100, pad_id)

        return shifted

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Only the main process (world process zero) writes the card; it is saved as
        `README.md` in `self.args.output_dir`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card. The caller's list is
                never modified.
        """
        if not self.is_world_process_zero():
            return

        # Only reference a base model that is not a local directory path.
        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        if not tags:
            tags = []
        elif isinstance(tags, str):
            tags = [tags]
        else:
            # Copy so the append below cannot mutate the caller's list.
            tags = list(tags)

        if hasattr(self.model.config, "unsloth_version"):
            tags.append("unsloth")

        citation = textwrap.dedent("""\
        @article{hong2024orpo,
            title        = {{ORPO: Monolithic Preference Optimization without Reference Model}},
            author       = {Jiwoo Hong and Noah Lee and James Thorne},
            year         = 2024,
            eprint       = {arXiv:2403.07691}
        }""")

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="ORPO",
            trainer_citation=citation,
            paper_title="ORPO: Monolithic Preference Optimization without Reference Model",
            paper_id="2403.07691",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))

In [None]:
# Alias the patched subclass under the name the training cell below instantiates,
# so `ORPOTrainer(...)` builds a HijackedORPOTrainer.
ORPOTrainer = HijackedORPOTrainer

### **Train**

In [None]:
# Build the ORPO trainer: the chat model is tuned on chosen/rejected pairs
# using the training configuration defined earlier.
trainer = ORPOTrainer(
    model=gpt2chat_model,                 # model being preference-tuned
    processing_class=gpt2chat_tokenizer,  # tokenizer that encodes prompts and responses
    train_dataset=preference_dataset,     # preferred/rejected response pairs
    args=training_args,                   # training arguments from the configuration cell
)

Map:   0%|          | 0/62135 [00:00<?, ? examples/s]

Map:   0%|          | 0/62135 [00:00<?, ? examples/s]

Map:   0%|          | 0/62135 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (2167 > 2048). Running this sequence through the model will result in indexing errors


In [None]:
# Resume from the latest checkpoint only when a previous preference run exists
# (`preference_resume_path` is truthy). Background on the resulting
# "missing keys ... lm_head.weight" warning:
# https://discuss.huggingface.co/t/resuming-training-there-were-missing-keys-in-the-checkpoint-model-loaded-lm-head-weight/103831/3
try:
    trainer.train(resume_from_checkpoint=bool(preference_resume_path))
    # Persist the tuned model under MODEL_DIR/<MODEL_NAME>-orpo.
    trainer.save_model(os.path.join(MODEL_DIR, MODEL_NAME + "-orpo"))
except (Exception, KeyboardInterrupt) as err:
    print(f"Could not train the model due to {err}.")

There were missing keys in the checkpoint model loaded: ['lm_head.weight'].


Step,Training Loss
3000,9.4009
3600,9.3284
4200,9.2888
4800,9.2241
5400,9.2769
6000,9.2018
6600,9.2214
7200,9.3028
7800,9.2392
8400,9.1549


Step,Training Loss
3000,9.4009
3600,9.3284
4200,9.2888
4800,9.2241
5400,9.2769
6000,9.2018
6600,9.2214
7200,9.3028
7800,9.2392
8400,9.1549


## **Evaluation**

### **Utils**

In [None]:
def measure_tokens_per_second(model, tokenizer, dataset, max_new_tokens=120):
    """Measure the mean generation throughput (new tokens/second) of `model` on `dataset`.

    Each prompt is tokenized and truncated so that `max_new_tokens` still fit in
    `MAX_SEQ_LENGTH`. Its last tokens are overwritten with the assistant header
    so generation opens an assistant turn. One sampled completion is then timed.

    Args:
        model: Causal LM exposing `.generate()` and `.device`.
        tokenizer: Tokenizer matching `model`.
        dataset: Iterable of prompt strings.
        max_new_tokens: Upper bound on generated tokens per prompt. It is not
            modified across prompts.

    Returns:
        Mean tokens/second over prompts that generated successfully (NaN if none did).
    """
    # Run on whatever device the model already lives on.
    device = model.device
    tok_per_sec = list()

    ending_template = list(
        tokenizer(
            "<|eot_id|>\n<|start_header_id|>assistant"
            "<|end_header_id|>\n",
        )['input_ids']
    )

    for prompt in tqdm(dataset, desc="Measuring tokens/sec"):
        inputs = tokenizer(
            prompt,
            truncation=True,
            max_length=MAX_SEQ_LENGTH - max_new_tokens,
            return_tensors="pt",
            return_attention_mask=True,
        )
        inputs = {
            key: tensor.to(device) for key, tensor in inputs.items()
        }
        prompt_len = inputs['input_ids'].shape[-1]

        # Writing the header over the tail needs at least that many positions.
        if prompt_len < len(ending_template):
            print(f"Skipping example shorter than the assistant header ({prompt_len} tokens).")
            continue

        # Force the prompt to end with the assistant header, even after truncation.
        for offset, token_id in enumerate(ending_template, start=-len(ending_template)):
            inputs['input_ids'][:, offset] = token_id

        # Per-prompt budget; a local name so `max_new_tokens` stays fixed.
        num_new_tokens = min(
            max(0, MAX_SEQ_LENGTH - prompt_len),
            max_new_tokens,
        )

        try:
            start_time = time.perf_counter()
            with torch.no_grad(): # Disable gradient calculations for inference
                output_sequences = model.generate(
                    input_ids=inputs['input_ids'],
                    attention_mask=inputs["attention_mask"], # Pass attention mask
                    max_new_tokens=num_new_tokens,           # Max tokens to generate *after* the prompt
                    do_sample=True,                          # Use sampling for more varied output
                    temperature=0.69,                        # Controls randomness (lower -> more focused)
                    top_k=42,                                # Consider only top K tokens for sampling
                    top_p=0.96,                              # Nucleus sampling: consider tokens summing up to P probability
                    num_return_sequences=1,                  # Number of responses to generate
                    pad_token_id=tokenizer.pad_token_id,     # Set pad token ID for generation
                )
            if device.type == "cuda":
                # Make sure queued GPU work is finished before stopping the clock.
                torch.cuda.synchronize(device)
            elapsed_time = time.perf_counter() - start_time
        except Exception as e:
            print(f"Error processing example: {e}")
            continue

        if elapsed_time <= 0:
            continue

        # generate() returns prompt + completion; count only the new tokens.
        num_generated = output_sequences.shape[-1] - prompt_len
        tok_per_sec.append(num_generated / elapsed_time)

    return np.mean(np.array(tok_per_sec))

In [None]:
def calculate_perplexity(model, tokenizer, dataset):
    """Calculate the mean per-example perplexity of `model` on `dataset`.

    Each text is tokenized (truncated to `MAX_SEQ_LENGTH` tokens) and scored with
    the model's built-in LM loss, using the input ids as labels. An example's
    perplexity is `exp(loss)`; the function returns the mean over examples.

    Args:
        model: Causal LM that returns `.loss` when called with `labels` and
            exposes `.device`.
        tokenizer: Tokenizer matching `model`.
        dataset: Iterable of text strings.

    Returns:
        Mean perplexity over scored examples (NaN if none could be scored).
        Examples with fewer than two tokens, or that raise during tokenization
        or the forward pass, are skipped.
    """
    ppl = list()

    for text in tqdm(dataset, desc="Calculating perplexity"):
        try:
            encoded = tokenizer(
                text,
                truncation=True,
                max_length=MAX_SEQ_LENGTH,
                return_tensors="pt",
                return_attention_mask=True,
            )
            encoded = {
                key: tensor.to(model.device) for key, tensor in encoded.items()
            }
            input_ids = encoded['input_ids']
            seq_len = input_ids.shape[-1]

            # A causal LM needs at least two tokens to have a next-token target;
            # otherwise the loss is undefined and would turn the mean into NaN.
            if seq_len < 2:
                continue

            # `input_ids` is (batch=1, seq): test the sequence axis and replace only
            # the final position, so truncated texts still end with EOS.
            if seq_len >= MAX_SEQ_LENGTH:
                input_ids[:, -1] = tokenizer.eos_token_id

            # Evaluation only: no autograd graph is needed.
            with torch.no_grad():
                outputs = model(
                    input_ids=input_ids,
                    labels=input_ids,
                    attention_mask=encoded['attention_mask'],
                )

            # `outputs.loss` is the mean token cross-entropy, so exp(loss) is the
            # example's perplexity.
            ppl.append(math.exp(outputs.loss.item()))
        except Exception as e:
            print(f"Error processing example: {e}")
            continue

    return np.mean(np.array(ppl))

In [None]:
def measure_flops_per_token(model, tokenizer, dataset, max_new_tokens=120):
    """Estimate the inference FLOPs per generated token of `model` on `dataset`.

    This is an analytic estimate, not a profiler measurement. Each generated
    token is charged one FLOP per model parameter. The common "2 x parameters"
    rule counts a multiply-add as two FLOPs and would double the figure.
    Generation is still run for every prompt, so only prompts the model can
    generate on contribute to the average. Works on CPU and CUDA.

    Args:
        model: Causal LM exposing `.generate()`, `.parameters()` and `.device`.
        tokenizer: Tokenizer matching `model`.
        dataset: Iterable of prompt strings.
        max_new_tokens: Upper bound on generated tokens per prompt. It is not
            modified across prompts.

    Returns:
        Mean estimated FLOPs/token over prompts that generated at least one
        token (NaN if none did).
    """
    device = model.device
    # The parameter count is fixed for a given model, so compute it once.
    num_params = sum(p.numel() for p in model.parameters())
    flops_per_token = list()

    ending_template = list(
        tokenizer(
            "<|eot_id|>\n<|start_header_id|>assistant"
            "<|end_header_id|>\n",
        )['input_ids']
    )

    for prompt in tqdm(dataset, desc="Measuring FLOPs/token"):
        inputs = tokenizer(
            prompt,
            truncation=True,
            max_length=MAX_SEQ_LENGTH - max_new_tokens,
            return_tensors="pt",
            return_attention_mask=True,
        )
        inputs = {
            key: tensor.to(device) for key, tensor in inputs.items()
        }
        prompt_len = inputs['input_ids'].shape[-1]

        # Writing the header over the tail needs at least that many positions.
        if prompt_len < len(ending_template):
            print(f"Skipping example shorter than the assistant header ({prompt_len} tokens).")
            continue

        # Force the prompt to end with the assistant header, even after truncation.
        for offset, token_id in enumerate(ending_template, start=-len(ending_template)):
            inputs['input_ids'][:, offset] = token_id

        # Per-prompt budget; a local name so `max_new_tokens` stays fixed.
        num_new_tokens = min(
            max(0, MAX_SEQ_LENGTH - prompt_len),
            max_new_tokens,
        )

        try:
            with torch.no_grad():
                output_sequences = model.generate(
                    input_ids=inputs['input_ids'],
                    attention_mask=inputs["attention_mask"], # Pass attention mask
                    max_new_tokens=num_new_tokens,           # Max tokens to generate *after* the prompt
                    do_sample=True,                          # Use sampling for more varied output
                    temperature=0.69,                        # Controls randomness (lower -> more focused)
                    top_k=42,                                # Consider only top K tokens for sampling
                    top_p=0.96,                              # Nucleus sampling: consider tokens summing up to P probability
                    num_return_sequences=1,                  # Number of responses to generate
                    pad_token_id=tokenizer.pad_token_id,     # Set pad token ID for generation
                )
        except Exception as e:
            print(f"Error processing example: {e}")
            continue

        # generate() returns prompt + completion; count only the new tokens.
        num_generated = output_sequences.shape[-1] - prompt_len
        if num_generated <= 0:
            continue

        # Rough estimate: one FLOP per parameter per generated token.
        estimated_flops = num_params * num_generated
        flops_per_token.append(estimated_flops / num_generated)

    return np.mean(np.array(flops_per_token))

### **Result**

In [None]:
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_NAME + "-orpo")

eval_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {eval_device}")
# Compare the device *type*: a torch.device does not compare equal to the plain
# string "cuda", so comparing the device object itself never picked float16.
eval_precision = (
    torch.float16
    if eval_device.type == "cuda"
    else torch.float32
)

try:
    gpt2chat_model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=eval_precision,
    )
    gpt2chat_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
except Exception as e:
    # Presumably the saved config uses the custom GPT2ChatConfig type, which the
    # Auto* factories only resolve after registration; register it and retry.
    print(f"Could not load the model due to {e}.")
    print("Registering GPT2Chat...")
    GPT2LMHeadModel.config_class = GPT2ChatConfig
    AutoConfig.register(f"{BASE_MODEL}chat", GPT2ChatConfig)
    AutoModelForCausalLM.register(GPT2ChatConfig, GPT2LMHeadModel)
    print("GPT2Chat registered.")

    gpt2chat_model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=eval_precision,
    )
    gpt2chat_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

gpt2chat_model.to(eval_device)
gpt2chat_model.eval()
print("Model and tokenizer loaded successfully.")

eval_dataset_raw = load_dataset(EVAL_DATASET, split=EVAL_SUBSET)
eval_conversations_raw = build_conversation_threads(eval_dataset_raw)

eval_conversations = list()      # generation prompts (all but the final message)
eval_conversations_ppl = list()  # full conversations for perplexity

for example in eval_conversations_raw:
    # Drop a trailing unpaired message so the thread has an even length
    # (slicing also copies, so the inserts below do not touch the raw data).
    example = example[:len(example) - (len(example) % 2)]
    if not len(example):
        continue
    if len(example) > 2:
        # Multi-turn thread: wrap the earlier turns in context markers; the
        # end marker goes right before the last two messages.
        example.insert(
            0,
            {
                "role": "start_context",
                "content": "",
            },
        )
        example.insert(
            -2,
            {
                "role": "end_context",
                "content": "",
            },
        )
    conversation_ppl = gpt2chat_tokenizer.apply_chat_template(
        example,
        tokenize=False,
        add_generation_prompt=False,
    )
    eval_conversations_ppl.append(
        gpt2chat_tokenizer.bos_token
        + conversation_ppl
        + gpt2chat_tokenizer.eos_token
    )
    conversation = gpt2chat_tokenizer.apply_chat_template(
        example[:-1],
        tokenize=False,
        add_generation_prompt=True,
    )
    # NOTE(review): drops the last 12 characters of the rendered prompt;
    # presumably a trailing piece of the generation prompt — confirm against
    # the chat template.
    eval_conversations.append(conversation[:-12])

try:
    tokens_per_second = measure_tokens_per_second(
        gpt2chat_model, gpt2chat_tokenizer, eval_conversations
    )
    print(f"\nGPT2Chat Tokens/second: {tokens_per_second}\n")

    perplexity = calculate_perplexity(
        gpt2chat_model, gpt2chat_tokenizer, eval_conversations_ppl
    )
    print(f"\nGPT2Chat Perplexity: {perplexity}\n")

    flops_per_token = measure_flops_per_token(
        gpt2chat_model, gpt2chat_tokenizer, eval_conversations,
    )
    print(f"\nGPT2Chat FLOPs/token: {flops_per_token}")
except (Exception, KeyboardInterrupt) as e:
    print(f"Could not measure the model due to {e}.")

Using device: cuda
Model and tokenizer loaded successfully.


Traversing the message tree:   0%|          | 0/191 [00:00<?, ?it/s]

Building conversations:   0%|          | 0/191 [00:00<?, ?it/s]

Measuring tokens/sec:   0%|          | 0/1831 [00:00<?, ?it/s]


GPT2Chat Tokens/second: 90.21437679291938



Calculating perplexity:   0%|          | 0/1831 [00:00<?, ?it/s]

`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.



GPT2Chat Perplexity: 12.88924653326722



Measuring FLOPs/token:   0%|          | 0/1831 [00:00<?, ?it/s]


GPT2Chat FLOPs/token: 125231616.0


## **Inference**

### **Initialize**

In [None]:
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_NAME + "-orpo")

INFERENCE_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {INFERENCE_DEVICE}")

# Compare the device *type*: a torch.device does not compare equal to the plain
# string "cuda", so comparing the device object itself never picked float16.
INFERENCE_PRECISION = (
    torch.float16
    if INFERENCE_DEVICE.type == "cuda"
    else torch.float32
)

print(f"Loading model from: {MODEL_PATH}")
try:
    INFERENCE_MODEL = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=INFERENCE_PRECISION,
    )
    INFERENCE_TOKENIZER = AutoTokenizer.from_pretrained(MODEL_PATH)
except Exception as e:
    print(f"Could not load the model due to {e}.")

    # Presumably the saved config uses the custom GPT2ChatConfig type, which the
    # Auto* factories only resolve after registration; register it and retry.
    print("Registering GPT2Chat...")

    GPT2LMHeadModel.config_class = GPT2ChatConfig
    AutoConfig.register(f"{BASE_MODEL}chat", GPT2ChatConfig)
    AutoModelForCausalLM.register(GPT2ChatConfig, GPT2LMHeadModel)

    print("GPT2Chat registered.")

    try:
        INFERENCE_MODEL = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            torch_dtype=INFERENCE_PRECISION,
        )
        INFERENCE_TOKENIZER = AutoTokenizer.from_pretrained(MODEL_PATH)
    except Exception as e:
        # Last resort: retry in full precision.
        print(f"Could not load the model due to {e}.")
        INFERENCE_MODEL = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch.float32,
        )
        INFERENCE_TOKENIZER = AutoTokenizer.from_pretrained(MODEL_PATH)

INFERENCE_MODEL.to(INFERENCE_DEVICE)
INFERENCE_MODEL.eval()
print("Model and tokenizer loaded successfully.")

Using device: cpu
Loading model from: /content/drive/MyDrive/GPT2Chatbot/model/gpt2chat-finetuned-lightning-orpo
Model and tokenizer loaded successfully.


### **Live Conversation**

In [None]:
# @markdown # **Let's Talk to GPT2Chat!**
# @markdown To end the conversation, please say: "exit", "quit", or "bye".
# @markdown * * *
# @markdown ### **Conversation Setting**
# @markdown Apply before running!
top_k = 36 # @param {type:"integer"}
top_p = 0.96 # @param {type:"slider", min:0.00, max:1.0, step:0.01}
temperature = 0.72  #@param {type:"slider", min:0.00, max:2.0, step:0.01}

# Rolling history of {"role", "content"} messages from previous turns.
context = list()

console = Console()

keep_n_conversation = 2  # past user/assistant exchanges kept as context
max_new_tokens = 120     # generation budget per reply (never modified below)
# Token ids of the assistant header; written over the tail of every prompt so
# generation always opens an assistant turn, even after truncation.
ending_template = list(
    INFERENCE_TOKENIZER(
        "<|eot_id|>\n<|start_header_id|>assistant"
        "<|end_header_id|>\n"
    )['input_ids']
)


try:
    while True:
        messages = list()

        console.clear()
        user_input = Prompt.ask("[bold bright_magenta]User[/bold bright_magenta]")
        os.system('cls' if os.name == 'nt' else 'clear')

        print()
        print()

        if user_input.lower() in ["exit", "quit", "bye"]:
            console.print(Panel(
                "Ok, bye!",
                title="GPT2Chat",
                border_style="bold bright_green",
                title_align="center",
                padding=(1, 2)
            ))
            print()
            print()
            console.clear()
            os.system('cls' if os.name == 'nt' else 'clear')
            break

        # Keep only the most recent `keep_n_conversation` exchanges.
        if len(context) > 2 * keep_n_conversation:
            context = context[-(2 * keep_n_conversation):]

        messages.extend([
            {
                "role": "start_context",
                "content": "",
            },
        ])
        messages.extend(context)
        messages.extend([
            {
                "role": "end_context",
                "content": "",
            },
        ])
        messages.extend([
            {"role": "user", "content": user_input},
        ])

        prompt = get_chat_template(
            is_jinja=False,
            messages=messages,
        )
        prompt = (
            INFERENCE_TOKENIZER.bos_token
            + prompt
        )

        # Truncate so the full generation budget still fits in MAX_SEQ_LENGTH.
        inputs = INFERENCE_TOKENIZER(
            prompt,
            truncation=True,
            max_length=MAX_SEQ_LENGTH - max_new_tokens,
            return_tensors="pt",
            return_attention_mask=True,
        )
        inputs = {
            key: tensor.to(INFERENCE_DEVICE) for key, tensor in inputs.items()
        }

        for id, token_id in enumerate(ending_template):
            inputs['input_ids'][
                :,
                id - len(ending_template)
            ] = token_id

        # Per-turn budget under a separate name so `max_new_tokens` never shrinks.
        num_new_tokens = min(
            max(
                0,
                MAX_SEQ_LENGTH - inputs['input_ids'].shape[-1],
            ),
            max_new_tokens,
        )

        with torch.no_grad(): # Disable gradient calculations for inference
            output_sequences = INFERENCE_MODEL.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs["attention_mask"], # Pass attention mask
                max_new_tokens=num_new_tokens, # Max tokens to generate *after* the prompt
                do_sample=True,             # Use sampling for more varied output
                temperature=temperature,    # Controls randomness (lower -> more focused)
                top_k=top_k,                # Consider only top K tokens for sampling
                top_p=top_p,                 # Nucleus sampling: consider tokens summing up to P probability
                num_return_sequences=1,     # Number of responses to generate
                pad_token_id=INFERENCE_TOKENIZER.pad_token_id # Set pad token ID for generation
            )

        # Decode only the newly generated tokens (generate() returns prompt + completion).
        assistant_response = INFERENCE_TOKENIZER.decode(output_sequences[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True)

        console.print(Panel(
            f"{assistant_response}",
            title="GPT2Chat",
            border_style="bold bright_green",
            title_align="center",
            padding=(1, 2)
        ))
        print()
        print()

        context.extend([
            {
                "role": "user",
                "content": user_input,
            },
            {
                "role": "assistant",
                "content": assistant_response,
            },
        ])

except (KeyboardInterrupt, Exception) as e:
    console.clear()
    os.system('cls' if os.name == 'nt' else 'clear')
    # Print after clearing so the reason stays visible.
    print(f"\nStopping conversation. Reason: {e}")


torch.cuda.empty_cache()

Tell me a joke about programming!








Write a story about this joke!








Write a poem about this story!








Write a letter to a friend about this poem!








It seems you forget about the greeting and introduction!








Uhh, Okay? I guess.








You seem unwell. Let's wrap it up!








Ok, thanks! Good bye! See you later!








bye








### **Chatbot**

#### **LangChain**

In [None]:
class GPT2Chat(LLM):
    """LangChain `LLM` wrapper around the fine-tuned GPT2Chat model.

    Generation uses the module-level `INFERENCE_MODEL`, `INFERENCE_TOKENIZER`
    and `INFERENCE_DEVICE` objects rather than state held on the instance.
    """

    max_length_response: int = 4096 # Telegram maximum text length
    max_new_tokens: int = 120
    # Sampling settings; the defaults equal the values previously hard-coded in `_call`.
    temperature: float = 0.72
    top_k: int = 36
    top_p: float = 0.96
    # Token ids of the assistant header. They are written over the last positions
    # of every prompt so generation always opens an assistant turn, even when the
    # prompt had to be truncated.
    ending_template: list = list(
        INFERENCE_TOKENIZER(
            "<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n"
        )['input_ids']
    )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Run the LLM on the given input.

        Args:
            prompt: The fully formatted chat prompt to generate from.
            stop: Optional stop substrings. The completion is cut off at the
                first occurrence of any of them.
            run_manager: Callback manager for the run (unused).
            **kwargs: Arbitrary additional keyword arguments (ignored).

        Returns:
            Only the generated completion (the prompt is excluded), cut at the
            first stop substring and truncated to `max_length_response` characters.
        """

        # Truncate so the full generation budget still fits in MAX_SEQ_LENGTH.
        inputs = INFERENCE_TOKENIZER(
            prompt,
            truncation=True,
            max_length=MAX_SEQ_LENGTH - self.max_new_tokens,
            return_tensors="pt",
            return_attention_mask=True,
        )
        inputs = {
            key: tensor.to(INFERENCE_DEVICE) for key, tensor in inputs.items()
        }

        for id, token_id in enumerate(self.ending_template):
            inputs['input_ids'][
                :,
                id - len(self.ending_template)
            ] = token_id

        num_new_tokens = min(
            max(
                0,
                MAX_SEQ_LENGTH - inputs['input_ids'].shape[-1],
            ),
            self.max_new_tokens,
        )

        with torch.no_grad():
            try:
                output_sequences = INFERENCE_MODEL.generate(
                    input_ids=inputs['input_ids'],
                    attention_mask=inputs["attention_mask"],
                    max_new_tokens=num_new_tokens,
                    do_sample=True,
                    temperature=self.temperature,
                    top_k=self.top_k,
                    top_p=self.top_p,
                    num_return_sequences=1,
                    pad_token_id=INFERENCE_TOKENIZER.pad_token_id
                )
            except Exception as e:
                print(f"Error generating output: {e}")
                raise

        # generate() returns prompt + completion; decode only the new tokens.
        response = INFERENCE_TOKENIZER.decode(
            output_sequences[0][inputs['input_ids'].shape[-1]:],
            skip_special_tokens=True,
        )

        if stop:
            # Honor the LangChain contract: cut at the earliest stop substring.
            hits = [response.find(s) for s in stop if s]
            hits = [pos for pos in hits if pos >= 0]
            if hits:
                response = response[:min(hits)]

        return response[:self.max_length_response]

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters."""
        return {
            # The model name allows users to specify custom token counting
            # rules in LLM monitoring applications (e.g., in LangSmith users
            # can provide per token pricing for their model and monitor
            # costs for the given LLM.)
            "model_name": "GPT2Chat",
            # Sampling settings change the output, so they identify the LLM too.
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
        }

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model.
        Used for logging purposes only."""
        return "custom"

In [None]:
# Define a new graph
workflow = StateGraph(state_schema=MessagesState)

llm = GPT2Chat()


keep_n_conversation = 2  # past user/assistant exchanges rendered as context

# Prompt layout: system turn, recent history between context markers, then the
# current user turn and an open assistant header for the model to complete.
gpt2chat_template = """<|beginoftext|>
<|start_header_id|>system<|end_header_id|>
You are a helpful assistant who converses with the user.<|eot_id|>
<|start_context|>
{context}
<|end_context|>
<|start_header_id|>user<|end_header_id|>
{user_input}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""
# Built once rather than on every call; it depends only on the template and `llm`.
gpt2chat_chain = PromptTemplate.from_template(gpt2chat_template) | llm


# Define the function that calls the model
def call_model(state: MessagesState):
    """Generate the assistant reply to the latest message in `state`.

    The last message is the current user input. Up to `2 * keep_n_conversation`
    earlier messages are rendered into the prompt's context section, each
    labelled by its message type.

    Returns:
        A state update that appends the reply as an assistant message.
    """
    history = state['messages']
    user_input = history[-1].content

    context = "\n"
    for msg in history[-(2 * keep_n_conversation + 1):-1]:
        # Use the stored message type instead of position parity, so a missing
        # reply (e.g. after a failed turn) cannot shift every later role.
        role = "assistant" if msg.type == "ai" else "user"
        context += (
            "<|start_header_id|>"
            + role
            + "<|end_header_id|>\n"
            + msg.content
            + "<|eot_id|>\n"
        )

    response = gpt2chat_chain.invoke(
        {
            'user_input': user_input,
            'context': context,
        }
    )
    # A bare string would be coerced into a *human* message by the messages
    # reducer; the role/content dict is stored as an assistant (AI) message.
    return {"messages": [{"role": "assistant", "content": response}]}

# Define the (single) node in the graph
workflow.add_node("model", call_model)
workflow.add_edge(START, "model")

# Add memory
memory = MemorySaver()
app = workflow.compile(checkpointer=memory)

#### **Telegram**

In [None]:
""" Please provide your Telegram bot's API TOKEN in Colab's secret! """

# Read the bot token from the secret named "TOKEN" instead of hard-coding it,
# then create the TeleBot client that the handlers below register on.
TOKEN = userdata.get("TOKEN")
bot = telebot.TeleBot(TOKEN)

In [None]:
@bot.message_handler(commands=['start'])
def start(message):
    """Greet the user when they issue the /start command."""
    greeting = "GPT2Chat. Talk something random to me!"
    bot.reply_to(message, greeting)

In [None]:
@bot.message_handler(commands=['help'])
def help(message):
    """Explain how to use the bot when the user issues the /help command."""
    usage_hint = "Just type and send texts, it will reply."
    bot.reply_to(message, usage_hint)

In [None]:
@bot.message_handler(func=lambda m: True)
def reply_text(message):
    """Reply to a user's message with GPT2Chat.

    The Telegram chat id is used as the LangGraph thread id, so the
    checkpointer keeps a separate conversation per chat.
    """

    config = {"configurable": {"thread_id": message.chat.id}}

    user_input = message.text

    try:
        output = app.invoke({"messages": user_input}, config)
        response = output["messages"][-1].content
    except Exception as e:
        # Handle the failure here instead of letting it escape the handler
        # (and, presumably, the bot.polling() loop that dispatches it).
        print(f"Could not generate a reply due to {e}.")
        response = ""

    # Telegram rejects empty message text, so never send an empty reply.
    if not response or not response.strip():
        response = "Sorry, I couldn't come up with a reply. Please try again."

    bot.reply_to(message, response)

In [None]:
# Block here, long-polling Telegram and dispatching updates to the handlers above.
try:
    bot.polling()
except (Exception, KeyboardInterrupt) as stop_reason:
    print(f"Stopping bot. Reason: {stop_reason}")