# **Instruction Fine-tuning of the GPT2MoE Model: GPT-2 with Mixture-of-Experts**

## **Important Libraries**

### **Mount Google Drive**

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### **Install**

In [2]:
!curl -LsSf https://astral.sh/uv/install.sh | sh

downloading uv 0.6.17 x86_64-unknown-linux-gnu
no checksums to verify
installing to /usr/local/bin
  uv
  uvx
everything's installed!


In [3]:
# https://github.com/astral-sh/uv/issues/12724
!mkdir /backend-container
!mkdir /backend-container/containers
!touch /backend-container/containers/build.constraints
!touch /backend-container/containers/requirements.constraints

In [4]:
try:
    !uv pip install -q --no-cache-dir --system lightning pytelegrambotapi
    import lightning
    !uv pip install -q --no-cache-dir --system transformers datasets
    import datasets
    !uv pip install -q --no-cache-dir --system langchain langgraph
    import langchain
    del lightning, transformers, datasets, langchain
except Exception as e:
    print(e)
    !pip install -q --no-cache-dir lightning pytelegrambotapi
    !pip install -q --no-cache-dir transformers datasets
    !pip install -q --no-cache-dir langchain langgraph

name 'transformers' is not defined


### **Import**

In [28]:
import lightning as L
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torch.nn.functional as F

from typing import Any, Dict, List, Optional, Tuple, Union

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.prompts import PromptTemplate

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph

from google.colab.patches import cv2_imshow
from google.colab import userdata

from tqdm.auto import tqdm

from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    DataCollatorForLanguageModeling,
    AutoModelForCausalLM,
    AutoConfig,
    AutoTokenizer,
    GenerationMixin,
    get_cosine_with_hard_restarts_schedule_with_warmup,
)
from transformers.modeling_outputs import (
    MoeModelOutputWithPast,
    MoeCausalLMOutputWithPast
)

import datasets
from datasets import Dataset, load_dataset

from rich.console import Console
from rich.panel import Panel
from rich.prompt import Prompt

import cv2
import telebot
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
import copy
import math
import time
import shutil
import random
import warnings

warnings.filterwarnings("ignore")

%matplotlib inline
plt.rcParams['axes.facecolor'] = 'lightgray'
plt.rcParams['mathtext.fontset'] = 'cm'
plt.rcParams['font.family'] = 'STIXGeneral'

## **Configuration**

In [6]:
BATCH_SIZE = 2
ACCUMULATE_GRAD_BATCH = 4  # Increase effective batch size
MAX_SEQ_LENGTH = 1024
LEARNING_RATE = 6.9e-5
ADAM_EPSILON = 1e-8
WEIGHT_DECAY = 1e-2
WARMUP_STEP = 420
NUM_EPOCH = 12
PRECISION = (
    16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else "16-mixed"
)  # Use '16-mixed' (fp16) or 'bf16-mixed' if supported, else 32
NUM_CYCLE = 3
NUM_GPU = torch.cuda.device_count()  # Use all available GPUs
NUM_WORKER = 4
SAVE_INTERVAL = 1000
LOG_INTERVAL = 1999

In [7]:
BASE_MODEL = "gpt2"
DATASET_NAME = "tatsu-lab/alpaca"
DATASET_SUBSET = "train"
MODEL_NAME = f"{BASE_MODEL}moe-alpaca-finetuned-lightning"

In [8]:
EXPERIMENT_DIR = "/content/drive/MyDrive/GPT2MoEInstruct"
TRAINING_DIR = os.path.join(EXPERIMENT_DIR, "training")
DATASET_DIR = os.path.join(EXPERIMENT_DIR, "dataset")
MODEL_DIR = os.path.join(EXPERIMENT_DIR, "model")

os.makedirs(EXPERIMENT_DIR, exist_ok=True)
os.makedirs(TRAINING_DIR, exist_ok=True)
os.makedirs(DATASET_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)

In [9]:
METRIC_TO_MONITOR = "train_loss"
METRIC_MODE = "min"

In [10]:
SEED = int(np.random.randint(2147483647))
print(f"Random seed: {SEED}")

Random seed: 244610225


## **Dataset**

### **Utils**

In [None]:
class InstructionDataModule(L.LightningDataModule):
    def __init__(
        self,
        model_name: str = BASE_MODEL,
        dataset_name: str = DATASET_NAME,
        dataset_subset: str = DATASET_SUBSET,
        batch_size: int = BATCH_SIZE,
        max_seq_length: int = MAX_SEQ_LENGTH,
        num_workers: int = NUM_WORKER,
    ):
        super().__init__()
        self.model_name = model_name
        self.dataset_name = dataset_name
        self.dataset_subset = dataset_subset
        self.batch_size = batch_size
        self.max_seq_length = max_seq_length
        self.num_workers = num_workers

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, trust_remote_code=True
        )
        # Set padding token if necessary
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Define data collator for causal LM. MLM=False prepares
        # labels automatically.
        self.data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer, mlm=False
        )

        self.train_dataset = None

    def prepare_data(self):
        # Download data. This runs only on the main process.
        load_dataset(self.dataset_name, split=self.dataset_subset)

    def setup(self, stage: str = None):
        # Load data and tokenize. This runs on all processes.
        raw_dataset = load_dataset(self.dataset_name, split=self.dataset_subset)

        # Apply the formatting function
        def formatting_prompts_func(example):
            output_texts = list()
            for i in range(len(example["instruction"])):
                text = (
                    f"### Instruction:\n{example['instruction'][i]}\n\n"
                    f"### Input:\n{example['input'][i]}\n\n"
                    f"### Response:\n{example['output'][i]}"
                )
                if example["instruction"][i] and example["output"][i]:
                    output_texts.append(
                        text + self.tokenizer.eos_token # Add EOS token
                    )
                else:
                    pass
            return {"text": output_texts}

        formatted_dataset = raw_dataset.map(
            formatting_prompts_func,
            batched=True,
            remove_columns=raw_dataset.column_names,  # Keep only 'text'
        )

        # Define tokenization function
        def tokenize_function(examples):
            # Tokenize the text column
            return self.tokenizer(
                examples["text"],
                truncation=True,
                padding=False,  # Data collator will handle padding
                max_length=self.max_seq_length,
            )

        # Tokenize the dataset
        tokenized_dataset = formatted_dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=["text"],  # Keep only tokenized columns
        )

        # Assign to train_dataset
        self.train_dataset = tokenized_dataset
        print(
            "Dataset setup complete."
            f"Train dataset size: {len(self.train_dataset)}"
        )

    def train_dataloader(self):
        if self.train_dataset is None:
            raise RuntimeError(
                "Train dataset not initialized. Call setup() first."
            )

        return data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=self.data_collator,  # Use the data collator here!
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,  # Improves performance when using GPUs
        )

### **Load**

In [None]:
DATA_MODULE = InstructionDataModule()

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

## **Model**

### **Utils**

In [11]:
class AvgMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.scores = list()

    def update(self, val):
        self.scores.append(val)

    def show(self):
        scores = torch.stack(self.scores)
        return torch.mean(scores)

In [12]:
@torch.no_grad()
def model_weight_copy(
    source_model: PreTrainedModel,
    target_model: PreTrainedModel,
    model_name: str=BASE_MODEL,
):
    """
    Copy weights from a source model to a target model.

    Args:
        source_model (PreTrainedModel): The source model from which weights are copied.
        target_model (PreTrainedModel): The target model to which weights are copied.
        model_name (str, optional): The name of the model. Defaults to BASE_MODEL.

    Returns:
        PreTrainedModel: The target model with copied weights.
    """

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    source_model.eval()
    target_model.eval()

    source_state_dict = source_model.state_dict()
    target_state_dict = target_model.state_dict()

    copied_keys = 0
    missing_keys = []
    # Start with all keys in target model
    unexpected_keys = list(target_state_dict.keys())

    # Create a new state dict for our model based on source weights
    new_target_state_dict = {}

    for source_key, source_value in source_state_dict.items():
        # Check if the key exists directly in target model
        if source_key in target_state_dict:
            # Check if shapes match
            if target_state_dict[source_key].shape == source_value.shape:
                new_target_state_dict[source_key] = source_value
                copied_keys += 1
                if source_key in unexpected_keys:
                    unexpected_keys.remove(source_key)
            else:
                print(
                    f"Shape mismatch for key: {source_key}. \n"
                    f"Source: {source_value.shape}, \n"
                    f"Target: {target_state_dict[source_key].shape}\n"
                )
                missing_keys.append(f"{source_key} (Shape Mismatch)")
        else:
            missing_keys.append(source_key + " (Not found in target model)")


    print(f"Copied {copied_keys} parameter tensors.")
    if unexpected_keys:
        print(
            "\nWarning: Some keys in the target model were not found in the "
            "source state dict:"
        )
        for key in unexpected_keys:
            print(f" - {key}")
        print()
    if missing_keys:
        print(
            "\nWarning: Some keys from the source state dict were not loaded "
            "into the target model:"
        )
        for key in missing_keys:
            print(f" - {key}")
        print()

    # Load the mapped state dictionary
    try:
        # Use strict=False initially if debugging mismatches
        target_model.load_state_dict(new_target_state_dict, strict=False)
        print("\nWeights loaded successfully!\n")
    except Exception as e:
        print(f"\nError loading state dict: {e}\n")

    # --- Verification Step (Compare Outputs) ---
    print("\nVerifying model outputs...")
    prompt = "Hello, my name is"
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"]

    # Ensure models are on the same device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    source_model.to(device)
    target_model.to(device)
    input_ids = input_ids.to(device)

    # Get outputs from both models
    with torch.no_grad():
        source_logits = source_model(input_ids).logits
        target_logits = target_model(input_ids).logits

    # Compare the logits
    print("Logits shapes:", source_logits.shape, target_logits.shape)

    if source_logits.shape == target_logits.shape:
        # Compare the logits for the last token prediction
        source_last_logit = source_logits[0, -1, :]
        target_last_logit = target_logits[0, -1, :]

        # Calculate difference
        abs_diff = torch.abs(source_last_logit - target_last_logit)
        print(
            "Max absolute difference in last token "
            f"logits: {abs_diff.max().item()}"
        )
        print(
            "Mean absolute difference in last token "
            f"logits: {abs_diff.mean().item()}"
        )

        # Check if they are close enough (allowing for floating point precision
        # differences)
        # Adjust tolerance as needed
        if torch.allclose(source_last_logit, target_last_logit, atol=1e-5):
            print("Verification successful: Outputs are numerically close!\n")
        else:
            warnings.warn(
                "Verification failed: Outputs differ significantly.\n"
            )
    else:
        raise Exception("Verification failed: Output shapes do not match.\n")
        warnings.warn("Verification failed: Output shapes do not match.\n")

    target_model.train()

    return target_model

In [13]:
def load_balancing_loss_func(
    gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
    num_experts: Optional[int] = None,
    top_k=2,
    attention_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, int]:
    """
    Computes auxiliary load balancing loss as in Switch Transformer -
    implemented in Pytorch.

    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details.
    This function implements the loss function presented in equations (4) - (6)
    of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        gate_logits:
            Logits from the `gate`, should be a tuple of
            model.config.num_hidden_layers tensors of shape
            [batch_size X sequence_length, num_experts].
        num_experts:
            Number of experts
        top_k:
            The number of experts to route per-token, can be also interpreted
            as the `top-k` routing parameter.
        attention_mask (`torch.Tensor`, *optional*):
            The attention_mask used in forward function
            shape [batch_size X sequence_length] if not None.

    Returns:
        The auxiliary loss.
    """
    if gate_logits is None or not isinstance(gate_logits, tuple):
        return 0

    if isinstance(gate_logits, tuple):
        compute_device = gate_logits[0].device
        concatenated_gate_logits = torch.cat(
            [layer_gate.to(compute_device) for layer_gate in gate_logits],
            dim=0,
        )

    routing_weights = F.softmax(concatenated_gate_logits, dim=-1)

    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    expert_mask = F.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = (
            concatenated_gate_logits.shape[0]
            // (batch_size * sequence_length)
        )

        # Compute the mask that masks all padding tokens as 0 with the same
        # shape of expert_mask
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )

        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = (
            torch.sum(expert_mask.float() * expert_attention_mask, dim=0)
            / torch.sum(expert_attention_mask, dim=0)
        )

        # Compute the mask that masks all padding tokens as 0 with the same
        # shape of tokens_per_expert
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )

        # Compute the average probability of routing to these experts
        router_prob_per_expert = (
            torch.sum(routing_weights * router_per_expert_attention_mask, dim=0)
            / torch.sum(router_per_expert_attention_mask, dim=0)
        )

    overall_loss = torch.sum(
        tokens_per_expert * router_prob_per_expert.unsqueeze(0)
    )
    return overall_loss * num_experts

### **Module**

#### **GPT2MoEConfig**

In [14]:
class GPT2MoEConfig(PretrainedConfig):
    model_type = "gpt2moe"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "hidden_size": "n_embd",
        "max_position_embeddings": "n_positions",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        vocab_size=50257,
        n_positions=1024,
        n_embd=768,
        n_layer=12,
        n_head=12,
        n_inner=None,
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=0.1,
        scale_attn_weights=True,
        use_cache=True,
        bos_token_id=50256,
        eos_token_id=50256,
        scale_attn_by_inverse_layer_idx=False,
        reorder_and_upcast_attn=False,
        n_expert=10,
        top_k_expert=3,
        router_aux_loss_coef=4e-2,
        scale_down_ffn=3,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.summary_type = summary_type
        self.summary_use_proj = summary_use_proj
        self.summary_activation = summary_activation
        self.summary_first_dropout = summary_first_dropout
        self.summary_proj_to_labels = summary_proj_to_labels
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        self.n_expert = n_expert
        self.top_k_expert = top_k_expert
        self.router_aux_loss_coef = router_aux_loss_coef

        self.scale_down_ffn = scale_down_ffn

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

#### **GPT2MoEPreTrainedModel**

In [15]:
class GPT2MoEPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization.
    """

    config_class = GPT2MoEConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["GPT2MoEBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, Conv1D)):
            # Slightly different from the TF version which uses truncated_normal
            # for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range
            )
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range
            )
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper
        # Scheme:
        #   > A modified initialization which accounts for the accumulation on
        #     the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of
        #     1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM):
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name == "c_proj.weight":
                # Special Scaled Initialization --> There are 2 Layer Norms
                # per Transformer Block
                p.data.normal_(
                    mean=0.0, std=(
                        self.config.initializer_range
                        / math.sqrt(2 * self.config.n_layer)
                    )
                )

#### **Conv1D**

In [16]:
class Conv1D(nn.Module):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT
    (and also used in GPT-2).

    Basically works like a linear layer but the weights are transposed.

    Args:
        nf (`int`): The number of output features.
        nx (`int`): The number of input features.
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        self.nx = nx
        self.weight = nn.Parameter(torch.empty(nx, nf))
        self.bias = nn.Parameter(torch.zeros(nf))
        nn.init.normal_(self.weight, std=0.02)

    def __repr__(self) -> str:
        return "Conv1D(nf={nf}, nx={nx})".format(**self.__dict__)

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(size_out)
        return x

#### **GPT2MoEAttention**

In [17]:
class GPT2MoEAttention(nn.Module):
    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        max_positions = config.max_position_embeddings
        self.scale_attn_by_inverse_layer_idx = (
            config.scale_attn_by_inverse_layer_idx
        )
        self.register_buffer(
            "bias",
            torch.tril(
                torch.ones(
                    (max_positions, max_positions), dtype=torch.bool
                )
            ).view(
                1, 1, max_positions, max_positions
            ),
            persistent=False,
        )

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "`embed_dim` must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.is_causal = True

        self.scale_attn_weights = config.scale_attn_weights

    def attention_forward(
        self,
        query,
        key,
        value,
        attention_mask,
        head_mask=None,
        **kwargs
    ):
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [], value.size(-1) ** 0.5,
                dtype=attn_weights.dtype,
                device=attn_weights.device,
            )

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        # implements causal mask
        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = self.bias[
            :, :, key_length - query_length : key_length, :key_length
        ]
        mask_value = torch.finfo(attn_weights.dtype).min
        # Need to be a tensor, otherwise we get error:
        #     `RuntimeError: expected scalar type float but found double`.
        # Need to be on the same device, otherwise
        #     `RuntimeError: ..., x and y to be on the same device`
        mask_value = torch.full(
            [],
            mask_value,
            dtype=attn_weights.dtype,
            device=attn_weights.device
        )
        attn_weights = torch.where(
            causal_mask, attn_weights.to(attn_weights.dtype), mask_value
        )

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision)
        # -- No-Op otherwise
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        attn_output = attn_output.transpose(1, 2)

        return attn_output

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        *args,
        **kwargs,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        query_states, key_states, value_states = self.c_attn(
            hidden_states
        ).split(self.split_size, dim=2)

        shape_q = (*query_states.shape[:-1], -1, self.head_dim)
        shape_kv = (*key_states.shape[:-1], -1, self.head_dim)

        query_states = query_states.view(shape_q).transpose(1, 2)
        key_states = key_states.view(shape_kv).transpose(1, 2)
        value_states = value_states.view(shape_kv).transpose(1, 2)

        if layer_past is not None:
            past_key, past_value = layer_past
            key_states = torch.cat((past_key, key_states), dim=-2)
            value_states = torch.cat((past_value, value_states), dim=-2)

        if use_cache is True:
            present = (key_states, value_states)
        else:
            present = None

        is_causal = attention_mask is None and query_states.shape[-2] > 1

        attn_output = self.attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            head_mask=head_mask,
            dropout=self.attn_dropout.p if self.training else 0.0,
            is_causal=is_causal,
            **kwargs,
        )

        attn_output = attn_output.reshape(
            *attn_output.shape[:-2], -1
        ).contiguous()
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)

        return outputs

#### **GPT2MoEFeedForward**

In [18]:
class GPT2MoEFeedForward(nn.Module):
    """This class implements the feed-forward network derived from Llama2.
    """
    def __init__(self, intermediate_size, config):
        super().__init__()

        self.config = config

        self.w1 = nn.Linear(config.n_embd, intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, config.n_embd, bias=False)
        self.w3 = nn.Linear(config.n_embd, intermediate_size, bias=False)
        self.activation = nn.SiLU()
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]]
    ) -> torch.FloatTensor:
        hidden_states = self.w2(
            self.activation(self.w1(hidden_states))
            * self.w3(hidden_states)
        )
        hidden_states = self.dropout(hidden_states)
        return hidden_states

#### **Mixture-Of-Experts**

In [19]:
class MixtureOfExperts(nn.Module):
    """This class implements the Mixture-Of-Experts derived from Mixtral.
    """
    def __init__(self, intermediate_size, config):
        super().__init__()
        self.config = config

        self.num_expert = config.n_expert
        self.k = config.top_k_expert

        self.experts = nn.ModuleList([
            GPT2MoEFeedForward(intermediate_size, config)
            for _ in range(self.num_expert)
        ])

        self.gating_network = nn.Linear(
            config.n_embd,
            self.num_expert,
            bias=False,
        )

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]]
    ) -> torch.FloatTensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        router_logits = self.gating_network(hidden_states)

        router_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(
            router_weights, k=self.k, dim=-1
        )
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be
        # sollicitated
        expert_mask = torch.nn.functional.one_hot(
            selected_experts,
            num_classes=self.num_expert
        ).permute(2, 1, 0)

        expert_hitted = (
            expert_mask.sum(dim=(-1, -2)) > 0
        ).nonzero(as_tuple=True)[0].tolist()
        for expert_idx in expert_hitted:
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])
            # Index the correct hidden states and compute the expert hidden
            # state for the current expert. We need to make sure to multiply
            # the output hidden states by `routing_weights` on the
            # corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            current_hidden_states = (
                expert_layer(current_state)
                * routing_weights[top_x, idx, None]
            )

            # However `index_add_` only support torch tensors for indexing so
            # we'll use the `top_x` tensor here.
            final_hidden_states.index_add_(
                0, top_x, current_hidden_states.to(hidden_states.dtype)
            )

        final_hidden_states = final_hidden_states.reshape(
            batch_size, sequence_length, hidden_dim
        )
        return final_hidden_states, router_logits

#### **GPT2MoEBlock**

In [20]:
class GPT2MoEBlock(nn.Module):
    def __init__(self, config, layer_idx=None):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = (
            config.n_inner
            if config.n_inner is not None
            else hidden_size // config.scale_down_ffn
        )

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPT2MoEAttention(config=config, layer_idx=layer_idx)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.moe = MixtureOfExperts(inner_dim, config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_router_logits: Optional[bool] = None,
        use_cache: Optional[bool] = False,
        *args,
        **kwargs,
    ) -> Union[
        Tuple[torch.Tensor],
        Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]
    ]:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
        )
        attn_output = attn_outputs[0]
        outputs = attn_outputs[1:]
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states, router_logits = self.moe(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            outputs = (hidden_states,) + outputs[1:]

        if output_router_logits:
            outputs = outputs + (router_logits,)

        return outputs

#### **GPT2MoEModel**

In [21]:
class GPT2MoEModel(GPT2MoEPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.embed_dim = config.hidden_size

        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList(
            [
                GPT2MoEBlock(config, layer_idx=i)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_router_logits: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        *args,
        **kwargs,
    ) -> Union[Tuple, MoeModelOutputWithPast]:
        use_cache = (
            use_cache
            if use_cache is not None
            else self.config.use_cache
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the "
                "same time"
            )
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(
                input_ids, attention_mask
            )
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError(
                "You have to specify either input_ids or inputs_embeds"
            )

        device = (
            input_ids.device
            if input_ids is not None
            else inputs_embeds.device
        )

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0][0].size(-2)

        position_ids = torch.arange(
            past_length,
            input_shape[-1] + past_length,
            dtype=torch.long,
            device=device,
        )
        position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)

        # Attention mask.
        if attention_mask is not None:
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to
            #     [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of
            # causal attention used in OpenAI GPT, we just need to prepare the
            # broadcast dimension here.
            attention_mask = attention_mask[:, None, None, :]

            # Since attention_mask is 1.0 for positions we want to attend
            # and 0.0 for masked positions, this operation will create a tensor
            # which is 0.0 for positions we want to attend and the dtype's
            # smallest value for masked positions.
            # Since we are adding it to the raw scores before the softmax,
            # this is effectively the same as removing these entirely.
            # fp16 compatibility
            attention_mask = attention_mask.to(dtype=self.dtype)
            attention_mask = (
                (1.0 - attention_mask)
                * torch.finfo(self.dtype).min
            )

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.drop(hidden_states)

        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

        presents = () if use_cache else None
        all_router_logits = () if output_router_logits else None

        for i in range(len(self.h)):
            block, layer_past = self.h[i], past_key_values[i]

            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                output_router_logits=output_router_logits,
                use_cache=use_cache,
            )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_router_logits:
                all_router_logits += (outputs[-1],)

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(output_shape)

        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            router_logits=all_router_logits,
        )

#### **GPT2MoEForCausalLM**

In [22]:
class GPT2MoEForCausalLM(GPT2MoEPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = GPT2MoEModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.num_expert = config.n_expert
        self.k = config.top_k_expert
        self.router_aux_loss_coef = config.router_aux_loss_coef

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_router_logits: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        *args,
        **kwargs,
    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
        r"""
        labels (
          `torch.LongTensor` of shape `(batch_size, sequence_length)`,
          *optional*
        ):
            Labels for language modeling. Note that the labels **are shifted**
            inside the model, i.e. you can set `labels = input_ids` Indices are
            selected in `[-100, 0, ..., config.vocab_size]` All labels set
            to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size]`
        """

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_router_logits=output_router_logits,
            use_cache=use_cache,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Flatten the tokens
            loss = self.loss_function(
                lm_logits,
                labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                transformer_outputs.router_logits,
                self.num_expert,
                self.k,
                attention_mask,
            )

            if labels is not None:
                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            router_logits=transformer_outputs.router_logits,
        )

### **Wrapper**

In [23]:
class GPT2MoEFineTuner(L.LightningModule):
    def __init__(
        self,
        model_name: str = BASE_MODEL,
        learning_rate: float = LEARNING_RATE,
        adam_epsilon: float = ADAM_EPSILON,
        weight_decay: float = WEIGHT_DECAY,
        warmup_steps: int = WARMUP_STEP,
        num_cycles: int = NUM_CYCLE,
        log_interval: int = LOG_INTERVAL,
    ):
        super().__init__()

        self.model_name = model_name
        self.learning_rate = learning_rate
        self.adam_epsilon = adam_epsilon
        self.warmup_steps = warmup_steps
        self.weight_decay = weight_decay
        self.num_cycles = num_cycles
        self.log_interval = log_interval

        print(f"Loading base model: {self.model_name}")

        source_model = AutoModelForCausalLM.from_pretrained(self.model_name)

        target_config = GPT2MoEConfig(
            block_size=source_model.config.n_positions,
            vocab_size=source_model.config.vocab_size,
            n_layer=source_model.config.n_layer,
            n_head=source_model.config.n_head,
            n_embd=source_model.config.n_embd,
        )
        target_model = GPT2MoEForCausalLM(target_config)
        self.model = model_weight_copy(
            source_model,
            target_model,
            self.model_name,
        )

        self.train_loss = list()
        self.train_loss_recorder = AvgMeter()

        self.steps = list()

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        output_router_logits=None,
    ):
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,  # Pass labels to compute the loss
            output_router_logits=output_router_logits,
        )

    def training_step(self, batch, batch_idx):
        # batch comes from the DataLoader, prepared by
        #     DataCollatorForLanguageModeling
        # It should contain 'input_ids', 'attention_mask', and 'labels'
        outputs = self(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
            output_router_logits=True,
        )

        loss = outputs.loss  # Extract the loss from the model outputs

        # Log the training loss
        self.log(
            "train_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.train_loss_recorder.update(loss.data)

        aux_loss = outputs.aux_loss

        if aux_loss is not None:
            self.log(
                "aux_loss",
                aux_loss,
                on_step=True,
                on_epoch=True,
                prog_bar=True,
                logger=True,
            )

        if (batch_idx + 1) % self.log_interval == 0:
            self.train_loss.append(
                self.train_loss_recorder.show().data.cpu().numpy()
            )
            self.train_loss_recorder = AvgMeter()
            self.steps.append(
                self.current_epoch
                * (len(DATA_MODULE.train_dataset) // BATCH_SIZE)
                + batch_idx
            )

        return loss  # Return the loss tensor

    def on_train_epoch_end(self):
        self._save_plot_curves()

    def _save_plot_curves(self):
        # Loss
        img_file = os.path.join(
            TRAINING_DIR,
            f"{MODEL_NAME}_loss_plot.png",
        )
        plt.plot(self.steps, self.train_loss, color="b", label="loss")
        plt.title("Loss Curves")
        plt.xlabel("Step")
        plt.ylabel("Loss")
        plt.legend()
        plt.grid()
        plt.savefig(img_file)
        plt.clf()

    def configure_optimizers(self):
        # Apply weight decay to parameters, excluding bias and LayerNorm weights
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p
                    for n, p in self.model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": self.weight_decay,
            },
            {
                "params": [
                    p
                    for n, p in self.model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]

        # Create optimizer
        optimizer = optim.AdamW(
            optimizer_grouped_parameters,
            lr=self.learning_rate,
            eps=self.adam_epsilon,
        )

        # Create learning rate scheduler (optional)
        # Calculate total training steps if needed for scheduler
        # Available after trainer initialization
        num_training_steps = self.trainer.estimated_stepping_batches

        # LR scheduler with warmup
        scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=num_training_steps,
            num_cycles=self.num_cycles,
        )

        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

## **Training**

In [None]:
try:
    seed_everything(SEED, workers=True)


    MODEL_MODULE = GPT2MoEFineTuner()


    callback = list()

    checkpoint_callback = ModelCheckpoint(
        dirpath=MODEL_DIR,
        filename=MODEL_NAME,
        save_top_k=1,
        monitor=METRIC_TO_MONITOR,
        mode=METRIC_MODE,
        save_last=True,
        every_n_train_steps=SAVE_INTERVAL,
    )
    callback.append(checkpoint_callback)

    if os.path.exists(os.path.join(MODEL_DIR, "last.ckpt")):
        os.rename(
            os.path.join(MODEL_DIR, "last.ckpt"),
            os.path.join(MODEL_DIR, f"{MODEL_NAME}.ckpt"),
        )
        CKPT_PATH = os.path.join(MODEL_DIR, f"{MODEL_NAME}.ckpt")
    else:
        CKPT_PATH = None

    trainer = Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=NUM_GPU if torch.cuda.is_available() else 1,
        max_epochs=NUM_EPOCH,
        accumulate_grad_batches=ACCUMULATE_GRAD_BATCH,
        precision=PRECISION if torch.cuda.is_available() else 32,
        callbacks=callback,
        log_every_n_steps=20,
        logger=False,
    )

    print("Starting training...")
    trainer.fit(
        model=MODEL_MODULE,
        datamodule=DATA_MODULE,
        ckpt_path=CKPT_PATH,
    )
    print("Training finished.")


    final_model_path = os.path.join(MODEL_DIR, MODEL_NAME)

    print(f"Saving final model to: {final_model_path}")

    MODEL_MODULE.model.save_pretrained(final_model_path)
    DATA_MODULE.tokenizer.save_pretrained(final_model_path)

    print("Model saving complete.")

    print("Registering GPT2MoE-Instruct...")

    AutoConfig.register(f"{BASE_MODEL}moe", GPT2MoEConfig)
    AutoModelForCausalLM.register(GPT2MoEConfig, GPT2MoEForCausalLM)

    print("GPT2MoE-Instruct registered.")

except Exception as e:

    print(f"Training has been stopped. Reason: {e}")


torch.cuda.empty_cache()

INFO: Seed set to 244610225
INFO:lightning.fabric.utilities.seed:Seed set to 244610225


Loading base model: gpt2


Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Copied 101 parameter tensors.

 - transformer.h.0.moe.experts.0.w1.weight
 - transformer.h.0.moe.experts.0.w2.weight
 - transformer.h.0.moe.experts.0.w3.weight
 - transformer.h.0.moe.experts.1.w1.weight
 - transformer.h.0.moe.experts.1.w2.weight
 - transformer.h.0.moe.experts.1.w3.weight
 - transformer.h.0.moe.experts.2.w1.weight
 - transformer.h.0.moe.experts.2.w2.weight
 - transformer.h.0.moe.experts.2.w3.weight
 - transformer.h.0.moe.experts.3.w1.weight
 - transformer.h.0.moe.experts.3.w2.weight
 - transformer.h.0.moe.experts.3.w3.weight
 - transformer.h.0.moe.experts.4.w1.weight
 - transformer.h.0.moe.experts.4.w2.weight
 - transformer.h.0.moe.experts.4.w3.weight
 - transformer.h.0.moe.experts.5.w1.weight
 - transformer.h.0.moe.experts.5.w2.weight
 - transformer.h.0.moe.experts.5.w3.weight
 - transformer.h.0.moe.experts.6.w1.weight
 - transformer.h.0.moe.experts.6.w2.weight
 - transformer.h.0.moe.experts.6.w3.weight
 - transformer.h.0.moe.experts.7.w1.weight
 - transformer.h.0.moe.

INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


Starting training...


README.md:   0%|          | 0.00/7.47k [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


(…)-00000-of-00001-a09b74b3ef9c3b56.parquet:   0%|          | 0.00/24.2M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/52002 [00:00<?, ? examples/s]

Map:   0%|          | 0/52002 [00:00<?, ? examples/s]

Map:   0%|          | 0/51974 [00:00<?, ? examples/s]

INFO: Restoring states from the checkpoint path at /content/drive/MyDrive/GPT2MoEInstruct/model/gpt2moe-alpaca-finetuned-lightning.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Restoring states from the checkpoint path at /content/drive/MyDrive/GPT2MoEInstruct/model/gpt2moe-alpaca-finetuned-lightning.ckpt


Dataset setup complete.Train dataset size: 51974


INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: Loading `train_dataloader` to estimate number of stepping batches.
INFO:lightning.pytorch.utilities.rank_zero:Loading `train_dataloader` to estimate number of stepping batches.
INFO: 
  | Name  | Type               | Params | Mode 
-----------------------------------------------------
0 | model | GPT2MoEForCausalLM | 138 M  | train
-----------------------------------------------------
138 M     Trainable params
0         Non-trainable params
138 M     Total params
554.567   Total estimated model params size (MB)
860       Modules in train mode
0         Modules in eval mode
INFO:lightning.pytorch.callbacks.model_summary:
  | Name  | Type               | Params | Mode 
-----------------------------------------------------
0 | model | GPT2MoEForCausalLM | 138 M  | train
-----------------------------------------------------
138 M     Trainable params
0  

Training: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=12` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=12` reached.


Training finished.
Saving final model to: /content/drive/MyDrive/GPT2MoEInstruct/model/gpt2moe-alpaca-finetuned-lightning
Model saving complete.
Registering GPT2MoE-Instruct...
GPT2MoE-Instruct registered.


<Figure size 640x480 with 0 Axes>

## **Evaluation**

### **Utils**

In [24]:
def measure_tokens_per_second(model, tokenizer, dataset):
    """Measures the tokens per second of a model on a dataset."""

    tok_per_sec = list()

    for example in tqdm(dataset, desc="Measuring tokens/sec"):
        try:
            if not example['prompt']:
                raise ValueError("Prompt is empty or None.")
        except Exception as e:
            print(f"Error processing example: {e}")
            continue

        prompt = (
            f"### Instruction:\n{example['prompt']}\n\n"
            f"### Input:\n''\n\n"
            "### Response:\n"
        )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        inputs = tokenizer(
            prompt,
            truncation=True,
            max_length=1024,
            return_tensors="pt",
            return_attention_mask=True,
        )
        inputs = {
            key: tensor.to(device) for key, tensor in inputs.items()
        }

        max_new_tokens = min(
            1024 - inputs['input_ids'].shape[-1],
            96,
        )

        start_time = time.time()
        with torch.no_grad(): # Disable gradient calculations for inference
            output_sequences = model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs["attention_mask"], # Pass attention mask
                max_new_tokens=max_new_tokens,           # Max tokens to generate *after* the prompt
                do_sample=True,                          # Use sampling for more varied output
                temperature=0.69,                        # Controls randomness (lower -> more focused)
                top_k=42,                                # Consider only top K tokens for sampling
                top_p=0.96,                              # Nucleus sampling: consider tokens summing up to P probability
                num_return_sequences=1,                  # Number of responses to generate
                pad_token_id=tokenizer.eos_token_id      # Set pad token ID for generation
            )
        end_time = time.time()
        elapsed_time = end_time - start_time

        tokens_per_second = len(
            output_sequences[0][inputs['input_ids'].shape[-1]:]
        ) / elapsed_time
        tok_per_sec.append(tokens_per_second)

    return np.mean(np.array(tok_per_sec))

In [25]:
def calculate_perplexity(model, tokenizer, dataset):
    """Calculates the perplexity of a model on a dataset."""

    ppl = list()

    for example in tqdm(dataset, desc="Calculating perplexity"):
        try:
            if not example['prompt']:
                raise ValueError("Prompt is empty or None.")

            if not example['completion']:
                raise ValueError("Completion is empty or None.")
        except Exception as e:
            print(f"Error processing example: {e}")
            continue

        log_likelihood = 0
        num_tokens = 0

        prompt = (
            f"### Instruction:\n{example['prompt']}\n\n"
            f"### Input:\n''\n\n"
            f"### Response:\n{example['completion']}"
        )
        try:
            input_ids = tokenizer(
                prompt,
                truncation=True,
                max_length=1024,
                return_tensors="pt",
                return_attention_mask=True,
            )

            input_ids = {
                key: tensor.to(model.device) for key, tensor in input_ids.items()
            }

            outputs = model(
                input_ids=input_ids['input_ids'],
                labels=input_ids['input_ids'],
                attention_mask=input_ids['attention_mask'],
            )
            # Cross-entropy loss
            loss = outputs.loss
            # Accumulate log-likelihood
            log_likelihood += loss.item() * input_ids['input_ids'].shape[1]
            num_tokens += input_ids['input_ids'].shape[1]
        except Exception as e:
            print(f"Error processing example: {e}")
            continue

        perplexity = math.exp(log_likelihood / num_tokens)
        ppl.append(perplexity)

    return np.mean(np.array(ppl))

In [26]:
def measure_flops_per_token(model, tokenizer, dataset):
    """Measures the FLOPs per token of a model on a dataset."""

    flops_per_token = list()
    for example in tqdm(dataset, desc="Measuring FLOPs/token"):
        try:
            if not example['prompt']:
                raise ValueError("Prompt is empty or None.")
        except Exception as e:
            print(f"Error processing example: {e}")
            continue

        prompt = (
            f"### Instruction:\n{example['prompt']}\n\n"
            f"### Input:\n''\n\n"
            "### Response:\n"
        )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        inputs = tokenizer(
            prompt,
            truncation=True,
            max_length=1024,
            return_tensors="pt",
            return_attention_mask=True,
        )
        inputs = {
            key: tensor.to(device) for key, tensor in inputs.items()
        }

        max_new_tokens = min(1024 - inputs['input_ids'].shape[-1], 96)

        # Use a profiler to measure FLOPs
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        with torch.no_grad():
            _ = model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs["attention_mask"], # Pass attention mask
                max_new_tokens=max_new_tokens,           # Max tokens to generate *after* the prompt
                do_sample=True,                          # Use sampling for more varied output
                temperature=0.69,                        # Controls randomness (lower -> more focused)
                top_k=42,                                # Consider only top K tokens for sampling
                top_p=0.96,                              # Nucleus sampling: consider tokens summing up to P probability
                num_return_sequences=1,                  # Number of responses to generate
                pad_token_id=tokenizer.eos_token_id      # Set pad token ID for generation

            )
        end_event.record()
        torch.cuda.synchronize()  # Wait for GPU operations to finish

        elapsed_time_ms = start_event.elapsed_time(end_event)
        elapsed_time_s = elapsed_time_ms / 1000

        # Estimate FLOPs (This is a rough estimate)
        num_params = sum(p.numel() for p in model.parameters())
        # Simplified FLOP estimation
        estimated_flops = num_params * max_new_tokens

        flops_per_token.append(estimated_flops / max_new_tokens)

    return np.mean(np.array(flops_per_token))

### **Result**

In [29]:
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_NAME)

hfh4_dataset = load_dataset("HuggingFaceH4/instruction-dataset", split="test")

eval_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {eval_device}")
eval_precision = (
    torch.float16
    if eval_device == "cuda"
    else torch.float32
)

try:
    gpt2moe_model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=eval_precision,
    )
    gpt2moe_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
except Exception as e:
    print(f"Could not load the model due to {e}.")
    print("Registering GPT2MoE-Instruct...")
    AutoConfig.register("gpt2moe", GPT2MoEConfig)
    AutoModelForCausalLM.register(GPT2MoEConfig, GPT2MoEForCausalLM)
    print("GPT2MoE-Instruct registered.")

    gpt2moe_model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=eval_precision,
    )
    gpt2moe_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

gpt2moe_model.to(eval_device)
gpt2moe_model.eval()

tokens_per_second = measure_tokens_per_second(
    gpt2moe_model, gpt2moe_tokenizer, hfh4_dataset
)
print(f"GPT2MoE Instruct Tokens/second: {tokens_per_second}")

perplexity = calculate_perplexity(
    gpt2moe_model, gpt2moe_tokenizer, hfh4_dataset
)
print(f"GPT2MoE Instruct Perplexity: {perplexity}")

flops_per_token = measure_flops_per_token(
    gpt2moe_model, gpt2moe_tokenizer, hfh4_dataset,
)
print(f"GPT2MoE Instruct FLOPs/token: {flops_per_token}")

Using device: cuda


Measuring tokens/sec:   0%|          | 0/327 [00:00<?, ?it/s]

GPT2MoE Instruct Tokens/second: 34.0699482741552


Calculating perplexity:   0%|          | 0/327 [00:00<?, ?it/s]

GPT2MoE Instruct Perplexity: 71.73608923131344


Measuring FLOPs/token:   0%|          | 0/327 [00:00<?, ?it/s]

GPT2MoE Instruct FLOPs/token: 138641664.0


## **Inference**

### **Initialize**

In [30]:
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_NAME)

INFERENCE_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {INFERENCE_DEVICE}")

INFERENCE_PRECISION = (
    torch.float16
    if INFERENCE_DEVICE == "cuda"
    else torch.float32
)

print(f"Loading model from: {MODEL_PATH}")
INFERENCE_TOKENIZER = AutoTokenizer.from_pretrained(BASE_MODEL)
try:
    INFERENCE_MODEL = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=INFERENCE_PRECISION,
    )
except Exception as e:
    print(f"Could not load the model due to {e}.")

    print("Registering GPT2MoE-Instruct...")

    AutoConfig.register(f"{BASE_MODEL}moe", GPT2MoEConfig)
    AutoModelForCausalLM.register(GPT2MoEConfig, GPT2MoEForCausalLM)

    print("GPT2MoE-Instruct registered.")

    try:
        INFERENCE_MODEL = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            torch_dtype=INFERENCE_PRECISION,
        )
    except Exception as e:
        print(f"Could not load the model due to {e}.")
        INFERENCE_MODEL = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch.float32,
        )

INFERENCE_MODEL.to(INFERENCE_DEVICE)
INFERENCE_MODEL.eval()
print("Model and tokenizer loaded successfully.")

Using device: cuda
Loading model from: /content/drive/MyDrive/GPT2MoEInstruct/model/gpt2moe-alpaca-finetuned-lightning


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Model and tokenizer loaded successfully.


### **Live Conversation**

In [None]:
# @markdown # **Let's Chat with GPT2MoE-Instruct!**
# @markdown To end the conversation, please say: "exit", "quit", or "bye".
# @markdown * * *
# @markdown ### **Conversation Setting**
# @markdown Apply before run!
top_k = 42 # @param {type:"integer"}
top_p = 0.96 # @param {type:"slider", min:0.00, max:1.0, step:0.01}
temperature = 0.69  #@param {type:"slider", min:0.00, max:2.0, step:0.01}
context = ""

console = Console()


try:
    while True:
        console.clear()
        user_input = Prompt.ask("[bold bright_magenta]User[/bold bright_magenta]")
        os.system('cls' if os.name == 'nt' else 'clear')

        print()
        print()

        if user_input.lower() in ["exit", "quit", "bye"]:
            console.print(Panel(
                "Ok, bye!",
                title="GPT2MoE-Instruct",
                border_style="bold bright_green",
                title_align="center",
                padding=(1, 2)
            ))
            print()
            print()
            console.clear()
            os.system('cls' if os.name == 'nt' else 'clear')
            break

        prompt = (
            f"### Instruction:\n{user_input}\n\n"
            f"### Input:\n{context}\n\n"
            "### Response:\n"
        )

        inputs = INFERENCE_TOKENIZER(
            prompt,
            truncation=True,
            max_length=MAX_SEQ_LENGTH,
            return_tensors="pt",
            return_attention_mask=True,
        )
        inputs = {
            key: tensor.to(INFERENCE_DEVICE) for key, tensor in inputs.items()
        }

        max_new_tokens = min(
            MAX_SEQ_LENGTH - inputs['input_ids'].shape[-1],
            96,
        )

        with torch.no_grad(): # Disable gradient calculations for inference
            output_sequences = INFERENCE_MODEL.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs["attention_mask"], # Pass attention mask
                max_new_tokens=max_new_tokens, # Max tokens to generate *after* the prompt
                do_sample=True,             # Use sampling for more varied output
                temperature=temperature,    # Controls randomness (lower -> more focused)
                top_k=top_k,                # Consider only top K tokens for sampling
                top_p=top_p,                 # Nucleus sampling: consider tokens summing up to P probability
                num_return_sequences=1,     # Number of responses to generate
                pad_token_id=INFERENCE_TOKENIZER.eos_token_id # Set pad token ID for generation
            )

        assistant_response = INFERENCE_TOKENIZER.decode(output_sequences[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True)

        console.print(Panel(
            f"{assistant_response}",
            title="GPT2MoE-Instruct",
            border_style="bold bright_green",
            title_align="center",
            padding=(1, 2)
        ))
        print()
        print()

        context = (
            user_input[:69]
            + ' '
            + assistant_response[:69]
        )
        context = context[:42]

except (KeyboardInterrupt, Exception) as e:
    print(f"\nStopping conversation. Reason: {e}")
    console.clear()
    os.system('cls' if os.name == 'nt' else 'clear')


torch.cuda.empty_cache()

Who are you? Tell me everything about yourself!








Alright, explain AI to me then!








Can you elaborate on more?








I mean, what is AI by definition?








Regarding AI, what did you know about machine learning? What is their relationship?








Let me repeat: what is machine learning?








How about deep learning? Is it different?








Yes, I know. But, what is deep learning?








Write a poem about what we have talked about so far!








If you fancy AI, what is the best model to suit your liking?








Tell me a joke about that then!








I don't get it. Let me clarify: tell me a joke about NLP!








Forget that. Now, tell me, which one is better CV or NLP?








So, CV is not interesting to you, eh?








I mean CV as in computer vision.








Who is John? He has nothing to do with computer vision!








Forget it! Move on to this question: what is computer vision?








So he is a computer vision engineer then?








What is computer vision?








Can you do some computer vision tasks?








How can we integrate it with NLP?








So, NLP and computer vision can be combined, no?








Last one, what did you know about reinforcement learning?








Let me repeat: what is reinforcement learning?








Tell me a story about it then!








Now, for the final discussion, tell me about artificial neural networks!








What? Let me clarify: what are artificial neural networks?








What is the relationship of it to deep learning?








I mean: are artificial neural networks and deep learning related?








Tell me a funny poetic story about it!








bye








### **Chatbot**

#### **LangChain**

In [109]:
class GPT2MoEInstruct(LLM):
    max_length_response: int = 4096 # Telegram maximum text length
    max_new_tokens: int = 120

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Run the LLM on the given input.

        Override this method to implement the LLM logic.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off at
                the first occurrence of any of the stop substrings.
                If stop tokens are not supported consider raising
                NotImplementedError.
            run_manager: Callback manager for the run.
            **kwargs: Arbitrary additional keyword arguments. These are usually
                passed to the model provider API call.

        Returns:
            The model output as a string. Actual completions SHOULD NOT include
            the prompt.
        """

        inputs = INFERENCE_TOKENIZER(
            prompt,
            truncation=True,
            max_length=MAX_SEQ_LENGTH,
            return_tensors="pt",
            return_attention_mask=True,
        )
        inputs = {
            key: tensor.to(INFERENCE_DEVICE) for key, tensor in inputs.items()
        }

        num_new_tokens = min(
            MAX_SEQ_LENGTH - inputs['input_ids'].shape[-1],
            self.max_new_tokens,
        )

        with torch.no_grad():
            output_sequences = INFERENCE_MODEL.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=num_new_tokens,
                do_sample=True,
                temperature=0.72,
                top_k=60,
                top_p=0.96,
                num_return_sequences=1,
                pad_token_id=INFERENCE_TOKENIZER.eos_token_id
            )

        response = INFERENCE_TOKENIZER.decode(
            output_sequences[0][inputs['input_ids'].shape[-1]:],
            skip_special_tokens=True,
        )

        return response[:self.max_length_response]

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters."""
        return {
            # The model name allows users to specify custom token counting
            # rules in LLM monitoring applications (e.g., in LangSmith users
            # can provide per token pricing for their model and monitor
            # costs for the given LLM.)
            "model_name": "GPT2MoEInstruct",
        }

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model.
        Used for logging purposes only."""
        return "custom"

In [110]:
# Define a new graph
workflow = StateGraph(state_schema=MessagesState)

llm = GPT2MoEInstruct()

# Define the function that calls the model
def call_model(state: MessagesState):
    template = (
        """### Instruction:\n{user_input}\n\n"""
        """### Input:\n{context}\n\n"""
        """### Response:\n"""
    )
    prompt = PromptTemplate.from_template(template)
    chain = prompt | llm

    messages = list()
    for msg in state['messages']:
        messages.append(msg.content)

    user_input = messages[-1]

    context = ""
    messages = messages[-3:-1]
    for idx, msg in enumerate(messages):
        # prefix = "User: " if idx % 2 == 0 else "Assistant: "
        # if idx < len(messages) - 1:
        #     context += prefix + msg[:100] + "\n"
        # else:
        #     context += prefix + msg[:100]
        if idx < len(messages) - 1:
            context += msg[:69] + " "
        else:
            context += msg[:69]
    context = context[:42]

    response = chain.invoke(
        {
            'user_input': user_input,
            'context': context,
        }
    )
    return {"messages": response}

# Define the (single) node in the graph
workflow.add_edge(START, "model")
workflow.add_node("model", call_model)

# Add memory
memory = MemorySaver()
app = workflow.compile(checkpointer=memory)

#### **Telegram**

In [111]:
""" Please provide your Telegram bot's API TOKEN in Colab's secret! """

TOKEN = userdata.get("TOKEN")
bot = telebot.TeleBot(TOKEN)

In [112]:
@bot.message_handler(commands=['start'])
def start(message):
    """Send a message when the command /start is issued."""

    bot.reply_to(
        message,
        "GPT2MoE-Instruct: GPT-2 with MoE Chatbot. "
        "Talk something random to me!",
    )

In [113]:
@bot.message_handler(commands=['help'])
def help(message):
    """Send a message when the command /help is issued."""

    bot.reply_to(
        message,
        "Just type and send texts, it will reply.",
    )

In [114]:
@bot.message_handler(func=lambda m: True)
def reply_text(message):
    """Reply text input from the user message."""

    config = {"configurable": {"thread_id": message.chat.id}}

    user_input = message.text

    output = app.invoke({"messages": user_input}, config)
    response = output["messages"][-1].content

    bot.reply_to(message, response)

In [115]:
try:
    bot.polling()
except Exception as e:
    print(f"Stopping bot. Reason: {e}")