<!-- @format -->

### 1) How to Fine-tune Gemma 2 for Advanced Reasoning in Communication, Translation and Multilingual Tasks

Large Language Models (LLMs) like Gemma 2 have shown remarkable capabilities, but they can struggle with complex translation and cross-lingual communication tasks that require nuanced reasoning. Traditional fine-tuning with Chain-of-Thought (CoT) provides some improvements but has inherent limitations when dealing with multilingual scenarios.

In this tutorial, we'll explore an innovative approach to enhance Gemma 2's reasoning capabilities by implementing the Coconut (Chain of Continuous Thought) paradigm introduced by [Hao et al. (2024)](https://arxiv.org/pdf/2412.06769). Instead of constraining the model to reason in language space, we'll leverage continuous latent representations to enable more flexible and powerful reasoning patterns, particularly beneficial for translation and cross-lingual tasks.

By utilizing the last hidden state as a "continuous thought" and feeding it back directly as input embeddings, we can help the model develop more sophisticated reasoning strategies. This approach allows Gemma 2 to:

- Explore multiple reasoning paths simultaneously
- Avoid premature commitment to single translations
- Handle complex linguistic nuances more effectively
- Reduce token overhead during inference
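
The feedback loop described above can be illustrated with a toy sketch in plain Python (no Gemma involved). `toy_model`, the `EMBEDDINGS` table, and the vector sizes are all invented for illustration. The only idea taken from the tutorial is the contrast: a language-space step collapses the hidden state to one token and re-embeds it, whereas a continuous thought feeds the hidden state back unchanged.

```python
import math

# Hypothetical 3-token vocabulary with 2-dimensional embeddings (illustration only).
EMBEDDINGS = {"yes": [1.0, 0.0], "no": [0.0, 1.0], "maybe": [0.7, 0.7]}

def toy_model(sequence):
    """Stand-in for a transformer: returns a 'last hidden state' for the sequence."""
    dims = len(sequence[0])
    mean = [sum(vec[d] for vec in sequence) / len(sequence) for d in range(dims)]
    return [math.tanh(x) for x in mean]

def nearest_token(hidden):
    """Language space: collapse the hidden state to the single closest token."""
    return max(EMBEDDINGS, key=lambda tok: sum(h * e for h, e in zip(hidden, EMBEDDINGS[tok])))

def language_step(sequence):
    """Decode to a token, then re-embed it (anything outside that token is dropped)."""
    token = nearest_token(toy_model(sequence))
    return sequence + [EMBEDDINGS[token]], token

def latent_step(sequence):
    """Continuous thought: append the hidden state itself as the next input embedding."""
    hidden = toy_model(sequence)
    return sequence + [hidden], hidden

prompt = [EMBEDDINGS["yes"], EMBEDDINGS["maybe"]]
lang_seq, token = language_step(prompt)
latent_seq, hidden = latent_step(prompt)

print("language-space step appended:", token, lang_seq[-1])
print("latent-space step appended:  ", [round(x, 3) for x in hidden])
```

In the latent step the appended vector is not tied to any vocabulary entry, which is what lets it carry a blend of several candidate continuations.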

This tutorial will guide you through implementing this cutting-edge fine-tuning approach using a real-world dataset, helping you transform Gemma 2 into a more capable multilingual reasoning system.

<img src="https://res.cloudinary.com/vickie/image/upload/v1735371553/ogylhcgz3o8trjtnjcov.png" alt="Gemini Reasoning Finetuning" width="1000"/>


Here are some key concepts that will help you follow this tutorial:

##### <strong>Key Concepts</strong>
<strong>Language Space</strong>: This is the discrete, symbolic representation of language (words, sentences).<br/>
<strong>Continuous Thought Space (Latent Space)</strong>: This is the high-dimensional, continuous internal representation used by models for reasoning.

###### <strong> Why Continuous Thought Space?</strong>
<strong>Flexibility</strong>: Explore multiple reasoning paths simultaneously.<br/>
<strong>Efficiency</strong>: Reduces token overhead by avoiding intermediate text. <br/>
<strong>Nuance</strong>: Better captures complex linguistic and cross-lingual relationships.


Having said this, let's begin by setting up the necessary environment.


### 2) Get Access to Gemma 2

To get access to Gemma 2, follow these steps:

1. Sign in or register on [Kaggle](https://kaggle.com)

2. Visit the [Gemma 2 Model Card](https://www.kaggle.com/models/google/gemma/frameworks/transformers) 

3. Accept the terms and conditions to gain access to the model

4. Once approved, you can use the model in your Kaggle notebooks or download locally

Note: The Gemma 2 family comes in several sizes (2B, 9B, and 27B parameters). This tutorial uses the 2B version to keep fine-tuning manageable on typical hardware. We tested on RTX A6000, L20, and L40 GPUs. Other GPUs should also work, but if you hit a CUDA out-of-memory error, switch to a GPU with more memory.


Next, install `kagglehub`:
`pip install kagglehub`

Copy the code below into a Python script named `setup.py` and run `python setup.py`. The script downloads the Gemma model to a local directory and prints that path.

```

import kagglehub

kagglehub.login()

# Download latest version
path = kagglehub.model_download("google/gemma-2/transformers/gemma-2-2b")

print("Path to model files:", path)

```

Copy the printed path and use it as the value of `path` in [3) Configuration](#3-configuration).

In [2]:
%reload_ext autoreload
%autoreload 2

In [3]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# CUDA_DEVICE_ORDER=1 CUDA_VISIBLE_DEVICES=1

In [4]:
# import torch

# # Check the number of available CUDA devices
# num_devices = torch.cuda.device_count()
# print(f"Number of available CUDA devices: {num_devices}")

# # Print the name of each CUDA device
# for i in range(num_devices):
#     print(f"Device {i}: {torch.cuda.get_device_name(i)}")

In [5]:
# pip install -q -U wandb nltk rouge-score thefuzz python-Levenshtein bert-score evaluate transformers peft datasets janome numpy fuzzywuzzy bitsandbytes ml_dtypes tf_keras torch torchvision pytorch-lightning tensorflow scikit-learn tokenizers==0.20.1 huggingface_hub

#### 3) Configuration

In [6]:
path = "unsloth/gemma-2-2b" # Make sure to change this to the path where the model is stored

In [7]:
config = {

    # Core Learning Parameters
    "learning_rate": 5e-5,                  # How fast the model learns (0.00005)
    "continuous_thoughts": 4,               # Number of latent space reasoning steps
    "stages": 4,                            # Number of training curriculum stages
    "training_thoughts_sequence_length": 50, # Hidden-state positions kept per continuous thought

    # Inference and Evaluation Params       
    "fuzzy_matcher_threshold": 80,          # Fuzzy matcher threshold at 80%
    "cot_decoding_k": 5,                    # Number of paths to try before finding the best answer

    # Model Setup
    "max_length": 256,                      # Maximum text length to process
    "model_name": path,                     # Path to Gemma model
    "batch_size": 4,                        # Number of examples processed together
    "weight_decay": 0.01,                   # Helps prevent overfitting

    # Special Tokens
    "bot_id": "<bot>",                      # Marks start of latent reasoning
    "eot_id": "<eot>",                      # Marks end of latent reasoning
    "answer_id": "<answer>",                # Marks the beginning of the answer
    "debug": True,                          # Enables debugging output and lets you inspect the model's thoughts

    # Training Optimizations
    "bf16": True,                           # Uses BFloat16 for faster training
    "per_device_train_batch_size": 1,       # Samples per GPU/CPU
    "optim": "adamw_torch",                 # AdamW optimizer for efficiency
    "wandb_project": "gemma2-finetuning",   # Tracks training on Weights & Biases
    "logging_steps": 1,                     # How often to log training progress
    "bf16_full_eval": True,                 # Uses BFloat16 for evaluation
    "gradient_accumulation_steps": 1,       # Batches accumulated before each weight update
    "save_steps": 10000,                    # How often to save model
    "warmup_steps": 0.1,                    # Warmup setting; a fractional value like 0.1 reads as a ratio of total steps, not a step count
    "output_dir": "output",                 # Where to save model files
    "diversity_weight": 0.1,                # Reasoning diversity weight
    "coherence_weight": 0.1                 # Reasoning coherence weight
}
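
The interaction between `continuous_thoughts`, `stages`, and `training_thoughts_sequence_length` is easier to see with numbers. The sketch below assumes the arithmetic that `train_forward` uses in the Modelling section (`current_stage * continuous_thoughts` latent passes, each appending `training_thoughts_sequence_length` positions) and assumes stages are indexed from 0 as in `range(config["stages"])`. `latent_schedule` is a hypothetical helper, not part of the tutorial code.

```python
# Values copied from the tutorial's config; the helper name is hypothetical.
config_excerpt = {
    "continuous_thoughts": 4,
    "stages": 4,
    "training_thoughts_sequence_length": 50,
}

def latent_schedule(cfg):
    """Return (stage, latent_passes, appended_positions) for each stage index."""
    rows = []
    for stage in range(cfg["stages"]):
        passes = stage * cfg["continuous_thoughts"]
        positions = passes * cfg["training_thoughts_sequence_length"]
        rows.append((stage, passes, positions))
    return rows

schedule = latent_schedule(config_excerpt)
for stage, passes, positions in schedule:
    print(f"stage {stage}: {passes} latent passes, {positions} appended positions")
```

With these defaults the appended positions grow quickly relative to `max_length` (256), and `thoughts_forward` truncates the sequence to `max_length`, so a smaller `training_thoughts_sequence_length` may be worth trying when experimenting.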

#### 4) Dataset Overview: Japanese-English Translation/Communication
<i>Using the llm-japanese-dataset created by [Hirano et al. (2023)](https://arxiv.org/pdf/2305.12720)</i>

We'll be using the llm-japanese-dataset (8.4M records) to fine-tune Gemma 2, focusing on Japanese-English translation and communication. The dataset follows this format:

```
### Instruction:
Please translate to English.

### Input:
こんにちは、元気ですか？

### Response:
Hello, how are you?
```

Most of the data (about 80%) consists of translations like this, which makes it well suited to improving Gemma 2's Japanese-English capabilities.
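
As a rough illustration of that layout, the sketch below renders one record into the Instruction/Input/Response template. `format_record` is a hypothetical helper, and skipping the Input block when the field is empty is an assumption of this sketch rather than documented dataset behavior.

```python
def format_record(instruction: str, input_text: str, response: str) -> str:
    """Render one record in the Instruction/Input/Response layout shown above."""
    parts = [f"### Instruction:\n{instruction}"]
    if input_text.strip():  # Some records have an empty input field
        parts.append(f"### Input:\n{input_text}")
    parts.append(f"### Response:\n{response}")
    return "\n\n".join(parts)

example = format_record(
    "Please translate to English.",
    "こんにちは、元気ですか？",
    "Hello, how are you?",
)
print(example)
```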

In [8]:
from datasets import load_dataset, DatasetDict
from datasets import config as dataset_config
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

dataset_name = "izumi-lab/llm-japanese-dataset"
dataset = load_dataset(dataset_name)

# For this tutorial, take the first 30,000 samples from each split
item = 30000

truncated_dataset = DatasetDict({
    split: dataset[split].select(range(item))
    for split in dataset.keys()
})


dataset = truncated_dataset
eval_dataset = dataset


Now let's look at a few examples from the dataset.

In [9]:
for i in range(3):
    example = dataset['train'][i]  # Fetch one row instead of loading whole columns
    print("Instruction: ", example["instruction"], "\n")
    print("Input: ", example["input"], "\n")
    print("Output: ", example["output"], "\n")
    print(f"{'='*200}\n")



Instruction:  「abc ～the first～」へようこそ！さて、ABC・・・と始まるアルファベットは、全部で何文字でしょう？ 

Input:   

Output:  26文字 


Instruction:  人気漫画『ドラえもん』の登場人物で、ジャイアンの苗字は剛田ですが、スネ夫の苗字は何でしょう？ 

Input:   

Output:  骨川（滑川も正解） 


Instruction:  格闘家ボブ・サップの出身国はどこでしょう？ 

Input:   

Output:  アメリカ 




##### 5) Language Detector

When handling multilingual content, detecting languages accurately is crucial. Our language detector analyzes both the structure and composition of mixed-language text.

For example, let's examine this input:
```python
"「abc ～the first～」へようこそ！さて、ABC・・・と始まるアルファベットは、全部で何文字でしょう？"
```

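The detector reports Japanese as the primary language and English as secondary, since Latin-script fragments (abc, the first, ABC) are embedded in the Japanese sentence. It does this by counting characters per Unicode script range (Hiragana, Katakana, Kanji, Latin) and converting the counts into percentages. Characters outside those ranges, such as 「」 or ～, count toward the total but not toward any language, which is why the percentages need not sum to 100.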

Beyond basic detection, this analysis guides translation strategy and validates output language alignment. By understanding the input language composition, Gemma 2 can better reason about how to process and generate appropriate bilingual responses. The detector becomes especially valuable when building chain-of-thought reasoning patterns across languages.

Let's implement this detector...

In [10]:
# Implementing the LanguageDetector class

from typing import Dict, List, Optional
from dataclasses import dataclass

@dataclass
class ScriptRange:
    """Represents a Unicode range for a writing system"""
    start: int
    end: int
    name: str
    
class LanguageDetector:
    def __init__(self):
        self.scripts: List[ScriptRange] = []
        self.language_mappings: Dict[str, List[str]] = {}
        
    def add_script(self, name: str, start: int, end: int) -> None:
        """
        Add a new script range to the detector
        
        Args:
            name: Name of the script (e.g., 'Hiragana', 'Latin')
            start: Starting Unicode code point
            end: Ending Unicode code point
        """
        self.scripts.append(ScriptRange(start, end, name))
    
    def map_scripts_to_language(self, language: str, script_names: List[str]) -> None:
        """
        Map multiple scripts to a single language
        
        Args:
            language: Name of the language (e.g., 'Japanese')
            script_names: List of script names that belong to this language
        """
        self.language_mappings[language] = script_names
    
    def detect(self, text: str) -> Dict[str, float]:
        """
        Detect the percentage of different languages/scripts in the text
        
        Args:
            text: Input text to analyze
            
        Returns:
            Dictionary mapping language/script names to their percentage presence
        """
        # Count characters in each script
        char_counts: Dict[str, int] = {script.name: 0 for script in self.scripts}
        total_chars = 0
        
        for char in text:
            if char.isspace() or char in '.,!?()[]{}':
                continue
                
            code = ord(char)
            total_chars += 1
            
            # Check which script range the character falls into
            for script in self.scripts:
                if script.start <= code <= script.end:
                    char_counts[script.name] += 1
                    break
        
        if total_chars == 0:
            return {}
            
        # Calculate initial percentages
        percentages = {
            script: (count / total_chars) * 100
            for script, count in char_counts.items()
            if count > 0
        }
        
        # Combine scripts into languages where applicable
        final_percentages = {}
        used_scripts = set()
        
        # First, handle mapped languages
        for language, script_names in self.language_mappings.items():
            total = sum(percentages.get(script, 0) for script in script_names)
            if total > 0:
                final_percentages[language] = total
                used_scripts.update(script_names)
        
        # Then add remaining unmapped scripts
        for script, percentage in percentages.items():
            if script not in used_scripts:
                final_percentages[script] = percentage
        
        return {k: round(v, 1) for k, v in sorted(
            final_percentages.items(),
            key=lambda x: x[1],
            reverse=True
        )}

# Example setup and usage
def create_default_detector() -> LanguageDetector:
    """Create a detector with Japanese and English support"""
    detector = LanguageDetector()
    
    # Add Japanese scripts
    detector.add_script('Hiragana', 0x3040, 0x309F)
    detector.add_script('Katakana', 0x30A0, 0x30FF)
    detector.add_script('Kanji', 0x4E00, 0x9FFF)

    # Add the Latin script (this range also covers ASCII digits and punctuation)
    detector.add_script('Latin', 0x0000, 0x024F)

    
    # Map scripts to languages
    detector.map_scripts_to_language('Japanese', ['Hiragana', 'Katakana', 'Kanji'])
    detector.map_scripts_to_language('English', ['Latin'])
    
    return detector

if __name__ == "__main__":
    detector = create_default_detector()
    
    test_texts = [
        'スナフキン',
        'レベッカ(REBECCA)',
        'Hello World',
        'こんにちは World!'
    ]
    
    for text in test_texts:
        result = detector.detect(text)
        print(f"Text: {text} ===>>>> {result}")

Text: スナフキン ===>>>> {'Japanese': 100.0}
Text: レベッカ(REBECCA) ===>>>> {'English': 63.6, 'Japanese': 36.4}
Text: Hello World ===>>>> {'English': 100.0}
Text: こんにちは World! ===>>>> {'Japanese': 50.0, 'English': 50.0}
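
One caveat is worth knowing before relying on these percentages: the `Latin` range registered above starts at `0x0000`, so ASCII digits and any punctuation not in the skip list are counted as English. The standalone sketch below reproduces the counting logic with the same ranges for a short string such as `26文字`, and then shows one possible adjustment (ignoring digits). The `language_shares` helper and its `skip_digits` flag are assumptions of this sketch, not part of the tutorial's detector.

```python
# Same Unicode ranges as create_default_detector(), grouped by language.
RANGES = {
    "Japanese": [(0x3040, 0x309F), (0x30A0, 0x30FF), (0x4E00, 0x9FFF)],
    "English": [(0x0000, 0x024F)],
}
SKIPPED = ".,!?()[]{}"

def language_shares(text, skip_digits=False):
    """Percentage of counted characters per language (hypothetical helper)."""
    counts = {language: 0 for language in RANGES}
    total = 0
    for char in text:
        if char.isspace() or char in SKIPPED:
            continue
        if skip_digits and char.isdigit():
            continue
        total += 1
        for language, ranges in RANGES.items():
            if any(start <= ord(char) <= end for start, end in ranges):
                counts[language] += 1
                break
    if total == 0:
        return {}
    return {lang: round(100 * n / total, 1) for lang, n in counts.items() if n > 0}

print(language_shares("26文字"))                    # digits counted as English
print(language_shares("26文字", skip_digits=True))  # digits ignored
```

Whether digits should count toward a language depends on your data, so treat the flag as one option to evaluate rather than a recommended default.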


##### 6) Dataset Preprocessing

The preprocessing pipeline prepares our dataset for Continuous Latent Reasoning by transforming raw translation pairs into structured training data. Here's what each component does:

`preprocess_function`: 
Takes raw examples and generates Chain-of-Thought reasoning steps. For each sample, it:
1. Analyzes the language composition of the input and output
2. Generates reasoning steps labeled in the detected language
3. Formats the prompt with the special tokens `<bot>` and `<eot>`, as per Hao et al. (2024)

Example flow:
```
Input: "「abc ～the first～」へようこそ！"
Steps:
- Detect languages (Japanese: 60%, English: 40%)
- Generate understanding steps
- Format with special tokens
Output: "<bos> Input <eos><bot><eot> Step 1... Step 2... Answer <eos>"
```
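
The `stages` argument controls how many of the four textual steps remain in the prompt. The sketch below mirrors only the slicing and numbering used by `preprocess_function` (the last `stages` steps are kept when `stages > 0`, otherwise all steps are kept). `visible_steps` and the placeholder step texts are hypothetical.

```python
# Placeholder step texts; the real ones are built from language_config.
ALL_STEPS = [
    "Question language detection",
    "Understand the question",
    "Understand the answer",
    "Response language detection",
]

def visible_steps(stages: int, step_label: str = "Step"):
    """Keep the last `stages` steps (all of them when stages == 0), numbered from 1."""
    steps = ALL_STEPS[-stages:] if stages > 0 else list(ALL_STEPS)
    return [f"{step_label} {n} : {text}" for n, text in enumerate(steps, start=1)]

for stage in range(4):
    print(f"stages={stage}:")
    for line in visible_steps(stage):
        print("   ", line)
```

Note that `stages=0` keeps all four steps, while `stages=1` keeps only the final one, so the number of visible steps does not shrink monotonically with the stage index.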

`tokenizer_function`:
Converts text into model inputs by:
1. Tokenizing the formatted text
2. Creating attention masks
3. Preparing labels (masking question/thought tokens)
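
Step 3 is the least obvious part, so here is a list-based sketch of the masking rule that `tokenizer_function` applies with tensors: labels before the last `<eot>` and labels on padding positions are set to `-100`, the value PyTorch's cross-entropy loss ignores by default. The token ids and the `mask_labels` helper are made up for illustration.

```python
IGNORE_INDEX = -100  # Label value skipped by the loss

def mask_labels(input_ids, attention_mask, eot_id):
    """Copy input_ids to labels, hiding everything before the last <eot> and all padding."""
    labels = list(input_ids)
    eot_positions = [pos for pos, tok in enumerate(input_ids) if tok == eot_id]
    if eot_positions:
        last_eot = eot_positions[-1]
        labels[:last_eot] = [IGNORE_INDEX] * last_eot  # The <eot> position itself is kept
    return [
        IGNORE_INDEX if mask == 0 else label
        for label, mask in zip(labels, attention_mask)
    ]

# Hypothetical ids: 1=<bos>, 10-11=question, 2=<eos>, 5=<bot>, 6=<eot>, 20-21=answer, 0=pad
input_ids      = [1, 10, 11, 2, 5, 6, 20, 21, 2, 0, 0]
attention_mask = [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
masked = mask_labels(input_ids, attention_mask, eot_id=6)
print(masked)
```

Only the `<eot>` token, the reasoning steps, and the answer contribute to the loss; the question and padding do not.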


This preprocessing ensures our data is properly structured for training Gemma 2 in continuous latent space reasoning.

In [11]:
from transformers import PreTrainedTokenizer

def preprocess_function(
    examples, 
    detector=None,  # Make detector optional
    stages=1, 
    eos_token="<eos>",
    bos_token="<bos>",
    language_config=None,
):
    """
    Preprocess the input examples by constructing the prompt with reasoning steps.

    Args:
        examples (dict): A dictionary containing the input examples with keys "instruction", "input", and "output".
        detector: A language detection object or function that detects the language of a given text.
        stages (int): The number of reasoning stages to include in the prompt.
        eos_token (str): The end-of-sequence token.
        bos_token (str): The beginning-of-sequence token.
        language_config (dict): A dictionary mapping language keys to their respective translations for steps and labels.

    Returns:
        dict: A dictionary containing the preprocessed prompts.
    """

    if language_config is None:
        language_config = {
            "English": {
                "language_detection": "Question language detection",
                "understand_question": "Understand the question",
                "understand_answer": "Understand the answer",
                "response_language_detection": "Response language detection",
                "answer_label": "Answer:",
                "step_label": "Step",
            },
            "Japanese": {
                "language_detection": "言語の検出",
                "understand_question": "質問を理解する",
                "understand_answer": "答えを理解する",
                "response_language_detection": "応答言語の検出",
                "answer_label": "答え：",
                "step_label": "ステップ",
            },
            # Add more languages here as needed
        }

    instructions = examples["instruction"]
    inputs = examples["input"]
    outputs = examples["output"]

    bot = config["bot_id"]
    eot = config["eot_id"]
    answer_token = config["answer_id"]

    # Initialize output dictionaries with lists
    result = []

    for i in range(len(instructions)):
        instruction = instructions[i]
        input = inputs[i]
        output = outputs[i]

        if len(input) > 1:
            input = instruction + input
        else:
            input = instruction

        # Use the provided detector to detect languages
        input_language = detector.detect(input) if detector else {"English": 100.0}  # Default to English if no detector
        output_language = detector.detect(output) if detector else {"English": 100.0}  # Default to English if no detector

        steps = []

        # Determine the primary input and output languages
        # Use the language key from the detector's output that matches a key in language_config
        input_lang = next((lang for lang in input_language if lang in language_config), "English")
        output_lang = next((lang for lang in output_language if lang in language_config), "English")

        # Get the language-specific labels
        input_labels = language_config.get(input_lang, language_config["English"])
        output_labels = language_config.get(output_lang, language_config["English"])

        # Input language detection
        input_lang_str = ", ".join([f"{k}: {v}%" for k, v in input_language.items()])
        steps.append(f"{input_labels['language_detection']}: {input_lang_str}")
        steps.append(f"{input_labels['understand_question']}: {input}")
        steps.append(f"{input_labels['understand_answer']}: {output}")

        # Output language detection
        output_lang_str = ", ".join([f"{k}: {v}%" for k, v in output_language.items()]) if output_language else "Unknown"
        steps.append(f"{output_labels['response_language_detection']}: {output_lang_str}")

        # Include only the steps relevant to the current stage
        if stages > 0:
            steps = steps[-stages:]  # Keep the last `stages` steps

        # Number the remaining steps from 1. Numbering once, after slicing, avoids
        # re-splitting on ' : ', which would truncate steps that contain that separator.
        steps = [f"{output_labels['step_label']} {n} : {step}" for n, step in enumerate(steps, start=1)]

        # Construct the prompt
        prompt = bos_token + "\n" + input + eos_token + bot + eot + "\n" + "\n".join(steps) + "\n" + answer_token + output_labels['answer_label'] + output + eos_token

        result.append(prompt)

    return {
        "prompt": result
    }

In [12]:
import torch

def tokenizer_function(examples, tokenizer):
    """
    Tokenize the input prompt and prepare the input_ids, attention_mask, and labels for training.

    Args:
        examples (dict): A dictionary containing the input prompts.
        tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenization.

    Returns:
        dict: A dictionary containing the tokenized input_ids, attention_mask, and labels.
    """

    prompt = examples["prompt"]
    eot = config["eot_id"]

    tokenized = tokenizer(
        prompt,
        max_length=config["max_length"],
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # Keep the batch dimension: squeeze(0) would collapse a batch of size 1 and break the loop below
    input_ids = tokenized["input_ids"]
    attention_mask = tokenized["attention_mask"]

    labels = input_ids.clone()
    batch_size = labels.shape[0]
    eot_id = tokenizer.convert_tokens_to_ids(eot)

    for i in range(batch_size):
        # Find the positions of <eot> in the input_ids
        eot_pos = (input_ids[i] == eot_id).nonzero(as_tuple=True)

        if len(eot_pos[0]) > 0:
            # Get the last occurrence of <eot>
            last_eot_pos = eot_pos[0][-1].item()
            
            # Mask everything before the last <eot>; the <eot> token itself stays unmasked
            labels[i, :last_eot_pos] = -100

        # Mask padding
        labels[i, attention_mask[i] == 0] = -100


    value =  {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    }


    if torch.cuda.device_count() > 1:
        return value
    else:
        value["labels"] = labels
        return value

Language config: to handle more languages, add entries to this configuration and pick a suitable language detector, since the `LanguageDetector` class above is only set up for `English` and `Japanese`. Note that the keys (`English`, `Japanese`, or any others you add) must match the keys returned by your language detector, for example `{'English': 63.6, 'Japanese': 36.4}` for `レベッカ(REBECCA)`.

In [13]:
language_config = {
    "English": {
        "language_detection": "Question language detection",
        "understand_question": "Understand the question",
        "understand_answer": "Understand the answer",
        "response_language_detection": "Response language detection",
        "answer_label": "Answer:",
        "step_label": "Step",
    },
    "Japanese": {
        "language_detection": "言語の検出",
        "understand_question": "質問を理解する",
        "understand_answer": "答えを理解する",
        "response_language_detection": "応答言語の検出",
        "answer_label": "答え：",
        "step_label": "ステップ",
    },
    # Add more languages here as needed
}
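
To see why the keys must match, the sketch below isolates the lookup that `preprocess_function` performs: it takes the first detected language that has a config entry and otherwise falls back to `English`. The `pick_labels` helper, the trimmed-down config, and the Korean entry are illustrative assumptions, not part of the tutorial code.

```python
# Trimmed-down config: only the keys needed for this demonstration.
language_config_demo = {
    "English": {"step_label": "Step", "answer_label": "Answer:"},
    "Japanese": {"step_label": "ステップ", "answer_label": "答え："},
}

def pick_labels(detected: dict, config: dict, default: str = "English") -> dict:
    """Return labels for the first detected language that has a config entry."""
    language = next((lang for lang in detected if lang in config), default)
    return config[language]

# A detector result with known keys selects the matching labels.
print(pick_labels({"Japanese": 59.3, "English": 25.9}, language_config_demo)["step_label"])

# An unknown key silently falls back to the English labels.
print(pick_labels({"Korean": 100.0}, language_config_demo)["step_label"])

# After adding an entry under the same key, the new labels are used instead.
language_config_demo["Korean"] = {"step_label": "단계", "answer_label": "답:"}
print(pick_labels({"Korean": 100.0}, language_config_demo)["step_label"])
```

Because the fallback is silent, a mismatch between detector keys and config keys will not raise an error; it simply produces English step labels.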

Now let's visualize what our dataset looks like preprocessed.

In [14]:
# Use only 5 samples per split so the preview runs quickly
truncated_dataset = DatasetDict({
    split: dataset[split].select(range(5))
    for split in dataset.keys()
})

for stage in range(config["stages"]):
    dataset_ = truncated_dataset.map(
        (lambda x: preprocess_function(
            x, 
            detector=detector,
            stages=stage, 
            language_config=language_config
        )),
        batched=True,
        batch_size=config["batch_size"],
    )

    print(f"Stage: ========================>>>>>>>>>>>>>>>>> {stage}")
    for i in range(5):
        print("Input: ", dataset_['train']["prompt"][i], "\n")
        print(f"{'='*100}\n")



Input:  <bos>
「abc ～the first～」へようこそ！さて、ABC・・・と始まるアルファベットは、全部で何文字でしょう？<eos><bot><eot>
ステップ 1 : 言語の検出: Japanese: 59.3%, English: 25.9%
ステップ 2 : 質問を理解する: 「abc ～the first～」へようこそ！さて、ABC・・・と始まるアルファベットは、全部で何文字でしょう？
ステップ 3 : 答えを理解する: 26文字
ステップ 4 : 応答言語の検出: Japanese: 50.0%, English: 50.0%
<answer>答え：26文字<eos> 


Input:  <bos>
人気漫画『ドラえもん』の登場人物で、ジャイアンの苗字は剛田ですが、スネ夫の苗字は何でしょう？<eos><bot><eot>
ステップ 1 : 言語の検出: Japanese: 89.1%
ステップ 2 : 質問を理解する: 人気漫画『ドラえもん』の登場人物で、ジャイアンの苗字は剛田ですが、スネ夫の苗字は何でしょう？
ステップ 3 : 答えを理解する: 骨川（滑川も正解）
ステップ 4 : 応答言語の検出: Japanese: 77.8%
<answer>答え：骨川（滑川も正解）<eos> 


Input:  <bos>
格闘家ボブ・サップの出身国はどこでしょう？<eos><bot><eot>
ステップ 1 : 言語の検出: Japanese: 95.2%
ステップ 2 : 質問を理解する: 格闘家ボブ・サップの出身国はどこでしょう？
ステップ 3 : 答えを理解する: アメリカ
ステップ 4 : 応答言語の検出: Japanese: 100.0%
<answer>答え：アメリカ<eos> 


Input:  <bos>
ロシア語で「城」という意味がある、ロシアの大統領府の別名は何でしょう？<eos><bot><eot>
ステップ 1 : 言語の検出: Japanese: 88.6%
ステップ 2 : 質問を理解する: ロシア語で「城」という意味がある、ロシアの大統領府の別名は何でしょう？
ステップ 3 : 答えを理解する: クレムリン
ステップ 4 : 応答言語の検出: Japanese: 100.0%
<answer>答え：クレムリ

##### 6) Modelling

Our `LatentReasoningGemmaForCausalLM` extends GemmaForCausalLM to enable continuous latent reasoning, following the architecture from Hao et al. (2024). The model implements two key forward paths:

`infer_forward`: During inference, transforms input text into continuous thought representations before generating the final output. It maintains a chain of latent states between the `<bot>` and `<eot>` tokens, allowing for more nuanced reasoning across languages.

`train_forward`: During training, processes the sequence in stages, gradually building up continuous thought representations while masking appropriate parts of the input. It helps the model learn to reason in latent space while maintaining language understanding.

Let's proceed with creating the model.

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GemmaForCausalLM, DynamicCache, PreTrainedTokenizer
from typing import Optional, List, Union, Dict, Any
import logging

logger = logging.getLogger(__name__)

class LatentReasoningGemmaForCausalLM(GemmaForCausalLM):
    """
    A custom implementation of GemmaForCausalLM that supports latent reasoning 
    using the Coconut (Chain of Continuous Thought) paradigm.
    """

    DEFAULT_CONFIG = {
        # Core Learning Parameters
        "continuous_thoughts": 4,               # Number of latent space reasoning steps
        "stages": 4,                            # Number of training curriculum stages
        "training_thoughts_sequence_length": 50, # Hidden-state positions kept per continuous thought

        # Inference and Evaluation Params       
        "fuzzy_matcher_threshold": 80,          # Fuzzy matcher threshold at 80%
        "cot_decoding_k": 5,                    # Number of paths to try before finding the best answer

        # Model Setup
        "max_length": 256,                      # Maximum text length to process

        # Special Tokens
        "bot_id": "<bot>",                      # Marks start of latent reasoning
        "eot_id": "<eot>",                      # Marks end of latent reasoning
        "answer_id": "<answer>",                # Marks the beginning of the answer
        "debug": True,                          # Enables debugging output and lets you inspect the model's thoughts

    }
    
    def __init__(self, config):
        super().__init__(config)
        self.tokenizer: PreTrainedTokenizer = None
        self.current_stage = 0
        self.model_config = dict(type(self).DEFAULT_CONFIG)  # Copy so instances do not share one mutable dict
        self.debug = self.model_config.get("debug", False)
        self.diversity_weight = self.model_config.get("diversity_weight", 0.1)
        self.coherence_weight = self.model_config.get("coherence_weight", 0.1)

    def get_input_ids(self, inputs_embeds):
        """Helper method to get input ids from embeddings."""
        embedding_matrix = self.get_input_embeddings().weight
        similarities = torch.matmul(inputs_embeds, embedding_matrix.T)
        token_ids = torch.argmax(similarities, dim=-1)
        return token_ids

    def thoughts_forward(self, num_thoughts, thought_ids, thought_mask, num_of_thought_tokens = 1):
        """
        Generate continuous thought embeddings.
        """
        all_thought_outputs = []
        batch_size = thought_ids.shape[0]
        
        # Get initial embeddings
        initial_embeds = self.get_input_embeddings()(thought_ids)
        current_embeds = initial_embeds
        current_mask = thought_mask

        for t in range(num_thoughts):
            # Forward pass through transformer
            outputs = self.model.forward(
                inputs_embeds=current_embeds,
                attention_mask=current_mask,
                past_key_values=None,
                use_cache=False,
                return_dict=True,
                output_hidden_states=True,  # Get hidden states from all layers
            )
            
            # Get hidden states from all layers for better representation
            hidden_states = outputs.hidden_states
            
            # Combine hidden states from all layers with softmax weights (note: the weights are sampled at random on every call, so this mix is not learned)
            layer_attention = torch.softmax(
                torch.randn(len(hidden_states), device=hidden_states[0].device), 
                dim=0
            )
            weighted_states = sum(w * h for w, h in zip(layer_attention, hidden_states))
            
            n = num_of_thought_tokens
            last_hidden = weighted_states[:, -n:, :]  # [batch_size, n, hidden_size]
            
            # Project to a lower dimension for thought space (note: these layers are re-created with fresh random weights on every iteration)
            thought_proj = nn.Sequential(
                nn.Linear(last_hidden.shape[-1], self.config.hidden_size // 2),
                nn.LayerNorm(self.config.hidden_size // 2),
                nn.GELU()
            ).to(last_hidden.device)
            projected_thought = thought_proj(last_hidden)  # [batch_size, n, hidden_size // 2]
            
            # Add noise to increase diversity
            noise = torch.randn_like(projected_thought) * 0.1  # Adjust noise scale as needed
            projected_thought = projected_thought + noise
            
            # Project back to embedding space
            embed_proj = nn.Linear(
                self.config.hidden_size // 2,
                self.config.hidden_size,
                device=projected_thought.device
            )
            next_token_embeds = embed_proj(projected_thought)  # [batch_size, n, hidden_size]
            
            # Apply layer normalization for stability
            next_token_embeds = nn.LayerNorm(
                self.config.hidden_size,
                device=next_token_embeds.device
            )(next_token_embeds)
            
            # Update embeddings and mask
            current_embeds = torch.cat([current_embeds, next_token_embeds], dim=1)
            current_mask = torch.cat([
                current_mask,
                torch.ones((batch_size, n), device=current_mask.device)
            ], dim=1)
            
            all_thought_outputs.append(last_hidden)

        # Ensure reasonable sequence length
        max_seq_len = self.model_config.get("max_length", 512)
        if current_embeds.shape[1] > max_seq_len:
            current_embeds = current_embeds[:, :max_seq_len, :]
            current_mask = current_mask[:, :max_seq_len]
        
        return all_thought_outputs, current_embeds, current_mask


    def train_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_logits_to_keep: int = 0,
        **kwargs,
    ):
        """
        Training forward pass with continuous thought generation and CoT alignment.
        """
        self.train()

        # Keep original labels if none provided
        if labels is None:
            labels = input_ids.clone()
            batch_size = labels.shape[0]
            eot_id = self.tokenizer.convert_tokens_to_ids(self.model_config["eot_id"])

            for i in range(batch_size):
                # Find the positions of <eot> in the input_ids
                eot_pos = (input_ids[i] == eot_id).nonzero(as_tuple=True)

                if len(eot_pos[0]) > 0:
                    # Get the last occurrence of <eot>
                    last_eot_pos = eot_pos[0][-1].item()
                    
                    # Mask everything before the last <eot>; the <eot> token itself stays unmasked
                    labels[i, :last_eot_pos] = -100

                # Mask padding
                labels[i, attention_mask[i] == 0] = -100

        # Get input embeddings if not provided
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)


        # Generate continuous thoughts
        if self.current_stage > 0:
            num_thoughts = self.current_stage * self.model_config["continuous_thoughts"]
            all_thoughts, final_embeds, final_mask = self.thoughts_forward(
                num_thoughts=num_thoughts,
                thought_ids=input_ids,
                thought_mask=attention_mask,
                num_of_thought_tokens = self.model_config["training_thoughts_sequence_length"]
            )

            # Add auxiliary losses
            auxiliary_losses = []

            # Thought coherence loss
            if len(all_thoughts) > 1:
                coherence_loss = 0
                for t1, t2 in zip(all_thoughts[:-1], all_thoughts[1:]):
                    sim = F.cosine_similarity(t1, t2, dim=-1)
                    coherence_loss += (1 - sim).mean()
                auxiliary_losses.append(coherence_loss * self.coherence_weight)

            batch_size = labels.shape[0]

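            # Resolve the <eot> token id once instead of once per token
            eot_token_id = self.tokenizer.convert_tokens_to_ids(self.model_config["eot_id"])

            for i in range(batch_size):
                # The CoT starts right after the last <eot> that is still present in the labels
                cot_start = None
                eot_positions = (labels[i] == eot_token_id).nonzero(as_tuple=True)[0]
                if len(eot_positions) > 0:
                    cot_start = eot_positions[-1].item() + 1

                # Compare the CoT tokens with the decoded latent thoughts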
                if cot_start is not None:
                    # Extract CoT tokens
                    cot_tokens = labels[i, cot_start:]  # [cot_seq_len]

                    # Get the latent thoughts for this batch
                    latent_thoughts = all_thoughts[i]  # [thought_seq_len, hidden_size]

                    # Project latent thoughts to logits
                    thought_logits = self.lm_head(latent_thoughts)  # [thought_seq_len, vocab_size]
                    thought_token_ids = torch.argmax(thought_logits, dim=-1)  # [thought_seq_len]


                    # Debugging: Print CoT tokens and latent thoughts
                    if self.debug:
                        # Decode CoT tokens
                        cot_tokens_list = cot_tokens.squeeze().tolist()  # Convert to 1D list
                        if isinstance(cot_tokens_list, int):  # Handle single token case
                            cot_tokens_list = [cot_tokens_list]
                        cot_text = self.tokenizer.decode(cot_tokens_list, skip_special_tokens=True)
                        print(f" ==================== \n Debug: CoT for batch {i}: {cot_text} \n ====================")

                        # Decode latent thoughts
                        thought_token_ids_list = thought_token_ids.squeeze().tolist()  # Convert to list

                        # Ensure thought_token_ids_list is a flat list
                        if isinstance(thought_token_ids_list, list) and all(isinstance(item, list) for item in thought_token_ids_list):
                            # Flatten the nested list
                            thought_token_ids_list = [token for sublist in thought_token_ids_list for token in sublist]
                        elif isinstance(thought_token_ids_list, int):  # Handle single token case
                            thought_token_ids_list = [thought_token_ids_list]

                        # Decode the flat list of token IDs
                        thought_text = self.tokenizer.decode(thought_token_ids_list, skip_special_tokens=False)
                        print(f"==================== \n Debug: Latent thoughts for batch {i}: {thought_text} \n ========================")


            # Forward pass with thoughts
            outputs = super().forward(
                inputs_embeds=final_embeds,
                attention_mask=final_mask,
                labels=labels,
                **kwargs
            )

            # Add auxiliary losses
            if auxiliary_losses:
                outputs.loss += sum(auxiliary_losses)

        else:

            if inputs_embeds is None:
                # Standard forward pass for initial stage
                outputs = super().forward(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    inputs_embeds=inputs_embeds,
                    labels=labels,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                    **kwargs,
                )
            else:

                outputs = super().forward(
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    inputs_embeds=inputs_embeds,
                    labels=labels,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                    **kwargs,
                )

        return outputs

    
    def infer_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[DynamicCache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **kwargs,
    ):
        """
        Inference forward pass with continuous thought generation.
        """

        batch_size = input_ids.shape[0]

        # Insert <bot> token to initiate latent reasoning
        if input_ids.shape[1] > 1:
            input_ids = torch.cat(
                [
                    input_ids,
                    torch.tensor(
                        [[self.tokenizer.convert_tokens_to_ids(self.model_config["bot_id"])]] * batch_size,
                        device=input_ids.device,
                    ),
                ],
                dim=1,
            )
            attention_mask = torch.cat(
                [
                    attention_mask,
                    torch.ones((batch_size, 1), device=attention_mask.device),
                ],
                dim=1,
            )

        # Generate continuous thoughts
        if self.model_config["stages"] - 1 > 0 and input_ids.shape[1] > 1:
            num_thoughts = (self.model_config["stages"] - 1) * self.model_config["continuous_thoughts"]
            all_thoughts, final_embeds, final_mask = self.thoughts_forward(
                num_thoughts, input_ids, attention_mask
            )

            # Add <eot> token to mark the end of latent reasoning
            eot_embeds = self.get_input_embeddings()(
                torch.tensor(
                    [[self.tokenizer.convert_tokens_to_ids(self.model_config["eot_id"])]] * batch_size,
                    device=final_embeds.device,
                )
            )
            final_embeds = torch.cat([final_embeds, eot_embeds], dim=1)
            final_mask = torch.cat([final_mask, torch.ones((batch_size, 1), device=final_mask.device)], dim=1)

            # Generate final output in language mode
            outputs = super().forward(
                inputs_embeds=final_embeds,
                attention_mask=final_mask,
                past_key_values=None,  # Reset past_key_values for answer generation
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )
        else:
            # Standard forward pass (no latent thoughts)
            outputs = super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )

        return outputs

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[DynamicCache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **kwargs,
    ):
        """Main forward function that routes to either training or inference."""
        forward_fn = self.train_forward if self.training else self.infer_forward
        return forward_fn(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
            **kwargs,
        )
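
The label masking at the top of `train_forward` can be looked at in isolation. The following is a minimal sketch that uses plain Python lists instead of tensors; `mask_labels`, `IGNORE_INDEX`, and the token ids are hypothetical names and values chosen for illustration and are not part of the model class.

```python
from typing import List

IGNORE_INDEX = -100  # label value ignored by PyTorch's cross-entropy loss


def mask_labels(input_ids: List[int], attention_mask: List[int], eot_id: int) -> List[int]:
    """Copy input_ids into labels, hiding everything before the last <eot> and all padding."""
    labels = list(input_ids)

    # Position of the last <eot>, if any
    eot_positions = [pos for pos, tok in enumerate(input_ids) if tok == eot_id]
    if eot_positions:
        last_eot = eot_positions[-1]
        # Tokens before the last <eot> do not contribute to the loss
        for pos in range(last_eot):
            labels[pos] = IGNORE_INDEX

    # Padding positions never contribute to the loss
    for pos, keep in enumerate(attention_mask):
        if keep == 0:
            labels[pos] = IGNORE_INDEX

    return labels


# Hypothetical ids: 7 stands for <eot>, 0 for padding
example = mask_labels(
    input_ids=[5, 6, 7, 8, 9, 0],
    attention_mask=[1, 1, 1, 1, 1, 0],
    eot_id=7,
)
print(example)
```

Only the `<eot>` token and what follows it remain as training targets, which is why the later CoT lookup can still find `<eot>` in the labels.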

Initialize Models

In [16]:
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

# tokenizer
tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
tokenizer.pad_token = tokenizer.eos_token

# Add special tokens
special_tokens = {
    "additional_special_tokens": [config["bot_id"], config["eot_id"], config["answer_id"]]
}
num_added_tokens = tokenizer.add_special_tokens(special_tokens)

# Load the Reasoning model configuration
model_config = AutoConfig.from_pretrained(config["model_name"])
latent_config = LatentReasoningGemmaForCausalLM.DEFAULT_CONFIG
LatentReasoningGemmaForCausalLM.DEFAULT_CONFIG = {
    **latent_config,
    **config
}
updated_latent_config = LatentReasoningGemmaForCausalLM.DEFAULT_CONFIG
# Load the Reasoning model weights directly into the custom class
# (avoids building a randomly initialized copy first)
model = LatentReasoningGemmaForCausalLM.from_pretrained(
    config["model_name"],
    torch_dtype=torch.bfloat16,
)
model.tokenizer = tokenizer
model.resize_token_embeddings(len(tokenizer))


# Load the normal model for comparison
model_without_reasoning = AutoModelForCausalLM.from_pretrained(config["model_name"])
model_without_reasoning.resize_token_embeddings(len(tokenizer))
model_without_reasoning = model_without_reasoning.cuda()

You are using a model of type gemma2 to instantiate a model of type gemma. This is not supported for all configurations of models and can yield errors.
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Now let's create helper functions for inference.

### Inference Helper Functions

- Chain-of-Thought (CoT) Decoding: A technique where the model generates intermediate reasoning steps before producing the final answer.

- Top-K Sampling: Selecting the k most likely tokens to explore multiple possible continuations.

- Temperature: A parameter that controls the randomness of predictions. Higher values make the output more diverse, while lower values make it more deterministic.

- Min-Margin Confidence: A measure of how confident the model is in its predictions, based on the difference between the best and second-best probabilities.


Example workflow:

1. You ask a question: `What is the capital of France?`
2. `generate_answer` explores multiple candidate answers (e.g., "Paris", "London", "Berlin").
3. It calculates the confidence for each candidate and selects the one with the highest confidence (e.g., "Paris").
4. The final answer, "Paris", is returned.
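
To make the min-margin idea concrete, here is a minimal sketch in plain Python. It shows the general top-1 minus top-2 margin, not the exact computation in `calculate_answer_confidence`; the helper names `softmax` and `top2_margin` and the logits are assumptions made for illustration.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top2_margin(logits: List[float]) -> float:
    """Difference between the highest and second-highest probability."""
    probs = sorted(softmax(logits), reverse=True)
    return probs[0] - probs[1]


# Hypothetical logits over a 3-token vocabulary
confident = top2_margin([5.0, 1.0, 0.5])  # one clear winner
uncertain = top2_margin([2.0, 1.9, 0.5])  # two near-ties
print(f"confident margin: {confident:.3f}")
print(f"uncertain margin: {uncertain:.3f}")
```

A large margin means the model strongly prefers one token; a small margin means two candidates are nearly tied, so the path is treated as less reliable.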

In [17]:
from typing import List, Tuple
from transformers import TextStreamer
import torch
import torch.nn.functional as F
from transformers import PreTrainedModel, PreTrainedTokenizer
import logging

logger = logging.getLogger(__name__)

def generate_answer(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    question: str,
    max_length: int = 128,
    k: int = config["cot_decoding_k"],
    temperature: float = 1.0,
    **generation_kwargs
) -> str:
    """
    Generates answer using CoT decoding and returns the best path.
    
    Args:
        model: The language model
        tokenizer: The tokenizer
        question: Input question
        max_length: Maximum sequence length
        k: Number of alternative paths to consider
        temperature: Sampling temperature
        **generation_kwargs: Additional generation arguments
        
    Returns:
        Best decoded sequence with highest confidence
    """
    # Initialize streamer
    streamer = TextStreamer(tokenizer, skip_prompt=False, skip_special_tokens=False)
    
    # Tokenize input (truncate explicitly so max_length is honored without a warning)
    inputs = tokenizer(question, max_length=max_length, truncation=True, return_tensors="pt").to(model.device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Get initial logits for CoT paths
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
        return_dict=True
    )
    
    first_token_logits = outputs.logits[:, -1, :] / temperature
    
    # Get top-k tokens
    probs = F.softmax(first_token_logits, dim=-1)
    top_k_probs, top_k_tokens = torch.topk(probs, k, dim=-1)
    
    best_path = None
    best_confidence = -float('inf')
    
    # Generate continuation for each top-k token
    for i in range(k):
        # Prepare input with current top-k token
        curr_input_ids = torch.cat([
            input_ids,
            top_k_tokens[:, i:i+1]
        ], dim=1)
        
        curr_attention_mask = torch.cat([
            attention_mask,
            torch.ones((attention_mask.shape[0], 1), device=model.device)
        ], dim=1)
        
        # Generate a continuation for this candidate; only the first path is streamed
        outputs = model.generate(
            input_ids=curr_input_ids,
            attention_mask=curr_attention_mask,
            max_length=max_length,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            output_scores=True,
            return_dict_in_generate=True,
            streamer=streamer if i == 0 else None,  # Only stream first path
            **generation_kwargs
        )
        
        # Calculate confidence for this path
        _, confidence = calculate_answer_confidence(
            outputs.sequences[0].tolist(),
            outputs.scores[-1],
            tokenizer
        )
        
        # Update best path if confidence is higher
        if confidence > best_confidence:
            best_confidence = confidence
            best_path = outputs.sequences[0]
            
    # Return the path with highest confidence
    return tokenizer.decode(best_path, skip_special_tokens=True)

def calculate_answer_confidence(
    sequence: List[int],
    final_logits: torch.Tensor,
    tokenizer: PreTrainedTokenizer
) -> Tuple[str, float]:
    """Calculate confidence score using min-margin approach."""
    # Extract answer from sequence
    answer = extract_answer(sequence, tokenizer)
    
    if not answer:
        return "", 0.0
    
    # Get probabilities
    probs = F.softmax(final_logits, dim=-1)
    
    # Calculate margins for answer tokens
    answer_tokens = tokenizer.encode(answer, add_special_tokens=False)
    if not answer_tokens:
        return answer, 0.0
    
    # The second-best probability does not depend on the token, so compute it once
    sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
    second_best_prob = sorted_probs[0, 1].item()
    
    margins = []
    for token in answer_tokens:
        token_prob = probs[0, token].item()
        margins.append(token_prob - second_best_prob)
        
    confidence = sum(margins) / len(margins)
    return answer, confidence

def extract_answer(sequence: List[int], tokenizer: PreTrainedTokenizer) -> str:
    """
    Extract final answer from sequence using <eot> token.
    Finds the answer between the last occurrence of <eot> and the end of sequence.
    """
    # Convert sequence to string
    decoded = tokenizer.decode(sequence)
    
    # Find last <eot> position
    eot_position = decoded.rfind(config["eot_id"])
    
    if eot_position != -1:
        # Extract everything after the last <eot>
        answer = decoded[eot_position + len(config["eot_id"]):].strip()
        return answer
        
    return decoded

In [18]:
import time

tick_start = 0

def tick():
    global tick_start
    tick_start = time.time()

def tock():
    print(f"TOTAL TIME ELAPSED: {time.time() - tick_start:.2f}s")


def text_gen(prompt, model, tokenizer):
    tick()
    print(f"Question: {prompt} \n ==========================================")
    output = generate_answer(model=model, tokenizer=tokenizer, question=prompt, k=5, max_length=config["max_length"])
    print("Outputs: ========================")
    print(output)
    tock()
    print("\n\n\n\n")

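Let's test the capabilities of the base model before fine-tuning the reasoning model.

As the output below shows, the results are not very good. The model struggles to switch between languages for translation, its output is long and repetitive, it answers the second question incorrectly, and generation is slow (over 30 seconds for the first question).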

In [19]:

# Test the function
text_gen("格闘家ボブ・サップの出身国はどこでしょう？", model=model_without_reasoning, tokenizer=tokenizer)
text_gen("人気漫画『ドラえもん』の登場人物で、ジャイアンの苗字は剛田ですが、スネ夫の苗字は何でしょう？",  model=model_without_reasoning, tokenizer=tokenizer)
# text_gen("Translate 'Hello, how are you?' to Japanese.",  model=model_without_reasoning, tokenizer=tokenizer)
# text_gen("「お元気ですか」を英語に訳すと",  model=model_without_reasoning, tokenizer=tokenizer)
# text_gen("Translate to english `「ねえ、それは何のためにあるの？`", model=model_without_reasoning, tokenizer=tokenizer)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
The 'batch_size' argument of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'max_batch_size' argument instead.
The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.


Question: 格闘家ボブ・サップの出身国はどこでしょう？ 
<bos>格闘家ボブ・サップの出身国はどこでしょう？

ボブ・サップは、アメリカ合衆国出身の格闘家です。

ボブ・サップは、アメリカ合衆国出身の格闘家です。

ボブ・サップは、アメリカ合衆国出身の格闘家です。

ボブ・サップは、アメリカ合衆国出身の格闘家です。

ボブ・サップは、アメリカ合衆国出身の格闘家です。

ボブ・サップは、アメリカ合衆国出身の格闘家です。

ボブ・サップは、アメリカ合衆国出身の格闘家です。

ボブ・サップは、アメリカ合衆国出身の格闘家です。

ボブ・サップは、アメリカ合衆国出身の格闘家です。

ボブ・サップは、アメリカ合衆国出身の格闘家です。

ボブ・サップは、アメリカ合衆国出身の格闘家です。

ボブ・サップは、アメリカ合衆国出身の格闘家です。

ボブ・サップは、アメリカ合衆国出身の格闘家です。

ボブ・サップは
格闘家ボブ・サップの出身国はどこでしょう？ボブ・サップはアメリカ人です。ボブ・サップはアメリカ人です。ボブ・サップはアメリカ人です。ボブ・サップはアメリカ人です。ボブ・サップはアメリカ人です。ボブ・サップはアメリカ人です。ボブ・サップはアメリカ人です。ボブ・サップはアメリカ人です。ボブ・サップはアメリカ人です。ボブ・サップはアメリカ人です。ボブ・サップはアメリカ人です。ボブ・サップはアメリカ人です。ボブ・サップはアメリカ人です。ボブ・サップはアメリカ人です。ボブ・サップはアメリカ人です。ボブ・サップはアメリカ人です。ボブ・サップはアメリカ人です。ボブ・サップはアメリカ人です。ボブ・サップはアメリカ人です。ボブ・サップはアメリカ人です。ボブ・サップはアメリカ人です。ボブ・サップはアメリカ人です。ボブ・サップはアメリカ人です。ボブ・サップはアメリカ人です。ボ
TOTAL TIME ELAPSED: 34.45s





Question: 人気漫画『ドラえもん』の登場人物で、ジャイアンの苗字は剛田ですが、スネ夫の苗字は何でしょう？ 
<bos>人気漫画『ドラえもん』の登場人物で、ジャイアンの苗字は剛田ですが、スネ夫の苗字は何でしょう？

スネ夫の苗字は、<strong>「剛田」</strong>です。

ジャイアンの苗字は「剛田」ですが、スネ夫の苗字は「剛田」です。



##### 7) Training

The training implementation follows a multi-stage curriculum based on Hao et al. (2024), gradually introducing continuous latent reasoning. Each stage represents a step in transitioning from pure language processing to latent space reasoning.


Training uses Hugging Face's `Trainer` with the following optimizations:
- BFloat16 precision
- 8-bit Adam optimizer 
- Gradient accumulation
- WandB tracking
- Checkpoint management

The model progressively learns to leverage continuous thought states while preserving translation capabilities, with each stage building upon the previous one's learned representations.
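
The curriculum can be sketched as a simple schedule. The snippet below mirrors the expression `num_thoughts = self.current_stage * self.model_config["continuous_thoughts"]` from `train_forward` and the `range(config["stages"] + 1)` training loop; the function name `thought_schedule` and the example values are assumptions made for illustration.

```python
from typing import List, Tuple


def thought_schedule(stages: int, continuous_thoughts: int) -> List[Tuple[int, int]]:
    """Return (stage, number of latent thoughts) for stages 0..stages inclusive."""
    return [(stage, stage * continuous_thoughts) for stage in range(stages + 1)]


# Hypothetical values; the tutorial reads these from config["stages"]
# and config["continuous_thoughts"]
for stage, num_thoughts in thought_schedule(stages=3, continuous_thoughts=2):
    mode = "language-only" if num_thoughts == 0 else f"{num_thoughts} latent thoughts"
    print(f"stage {stage}: {mode}")
```

Stage 0 is a standard forward pass with no latent thoughts, and every later stage adds another block of continuous thoughts.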

To train, have your Weights & Biases (wandb) API key ready, since the training logs are reported to wandb.
To get your key, visit [Wandb](https://wandb.ai/quickstart?utm_source=app-resource-center&utm_medium=app&utm_term=quickstart).

Let's see this in action...

In [None]:
from transformers import (
    Trainer,
    TrainingArguments
) 
import wandb
import os
import torch
import evaluate
import numpy as np

# Initialize WandB
wandb.init(project=config["wandb_project"], config=config)

# Set up training arguments
training_args = TrainingArguments(
    output_dir=config["output_dir"],
    per_device_train_batch_size=config["per_device_train_batch_size"],
    gradient_accumulation_steps=config["gradient_accumulation_steps"],
    learning_rate=config["learning_rate"],
    warmup_steps=config["warmup_steps"],  # assumes the config value is a step count, not a ratio
    logging_steps=config["logging_steps"],
    save_steps=config["save_steps"],
    bf16=config["bf16"],
    bf16_full_eval=config["bf16_full_eval"],
    optim=config["optim"],
    report_to="wandb",
    remove_unused_columns=False,
    dataloader_pin_memory=True,
    # gradient_checkpointing=True,
)

# Move model to GPU and wrap with DataParallel if multiple GPUs available
if torch.cuda.is_available():
    # Check if model is not already on CUDA
    if not next(model.parameters()).is_cuda:
        model = model.cuda()
    if torch.cuda.device_count() > 1:
        # Check if model isn't already wrapped with DataParallel
        if not isinstance(model, torch.nn.DataParallel):
            # Use DataParallel with explicit device IDs
            model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))

def stage_trainer(stage=0):

    if isinstance(model, torch.nn.DataParallel):
        model.module.current_stage = stage
    else:
        model.current_stage = stage

    current_output_dir = f"{config['output_dir']}_stage{stage}"
    training_args.output_dir = current_output_dir
    training_args.num_train_epochs = 3
        

    # Build the stage-specific prompts for this curriculum stage
    dataset_ = dataset.map(
        (lambda x: preprocess_function(
            x, 
            detector=detector,
            stages=stage, 
            eos_token=tokenizer.eos_token,
            bos_token=tokenizer.bos_token,
            language_config=language_config
        )),
        batched=True,
        batch_size=config["batch_size"]
    )

    # Tokenize the dataset
    dataset_ = dataset_.map(
        (lambda x: tokenizer_function(
            x, 
            tokenizer=tokenizer,
        )),
        batched=True,
        batch_size=config["batch_size"],
        remove_columns=["input", "instruction", "output", "prompt"]
    )
    
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset_["train"]
    )
    
    trainer.train()

    # Save the tokenizer and the current (unwrapped) model into each checkpoint folder
    for folder in os.listdir(current_output_dir):
        if folder.startswith("checkpoint-"):
            checkpoint_folder = os.path.join(current_output_dir, folder)
            if os.path.isdir(checkpoint_folder):
                tokenizer.save_pretrained(checkpoint_folder)
                # If using DataParallel, save the base model
                model_to_save = model.module if hasattr(model, 'module') else model
                model_to_save.save_pretrained(checkpoint_folder)

# Run training stages
for stage in range(config["stages"] + 1):
    stage_trainer(stage)

[34m[1mwandb[0m: Currently logged in as: [33mwassname[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


    There is an imbalance between your GPUs. You may want to exclude GPU 1 which
    has less than 75% of the memory or cores of GPU 0. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable.


  0%|          | 0/45000 [00:00<?, ?it/s]

Once training is done, let's load our fine-tuned model for inference.

In [None]:
from transformers import AutoTokenizer, AutoConfig
import torch
torch.cuda.empty_cache()


def load_model(model_name="output_stage1/checkpoint-10000"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Load the fine-tuned weights directly into the custom class
    model = LatentReasoningGemmaForCausalLM.from_pretrained(model_name)
    model.tokenizer = tokenizer

    model = model.cuda()

    return model, tokenizer


# Make sure to load the model from your specified path. In our case our path is "output_stage1/checkpoint-10000"
model, tokenizer = load_model(model_name= "output_stage1/checkpoint-10000")

Test after fine-tuning

In [None]:
text_gen("格闘家ボブ・サップの出身国はどこでしょう？", model=model, tokenizer=tokenizer)
text_gen("人気漫画『ドラえもん』の登場人物で、ジャイアンの苗字は剛田ですが、スネ夫の苗字は何でしょう？", model=model, tokenizer=tokenizer)
text_gen("「お元気ですか」を英語に訳すと ", model=model, tokenizer=tokenizer)
text_gen("Translate to english `「ねえ、それは何のためにあるの？`", model=model, tokenizer=tokenizer)
text_gen("「abc ～the first～」へようこそ！さて、ABC・・・と始まるアルファベットは、全部で何文字でしょう？`", model=model, tokenizer=tokenizer)

##### 8) Evaluation

Our evaluation framework employs multiple metrics to provide a thorough assessment of model performance, going beyond simple exact matching to capture various aspects of answer quality. Before evaluating, we need to preprocess the dataset.

In [None]:
import nltk

# Download each tokenizer resource only if it is missing
for resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{resource}")
    except LookupError:
        nltk.download(resource)


In [None]:
import torch


def preprocess_eval_dataset_function(examples):
    """
    Preprocess the evaluation examples by concatenating each instruction with its input.

    Args:
        examples (dict): A batch of examples with keys "instruction", "input", and "output".
    Returns:
        dict: The concatenated inputs, the original outputs, and the original instructions.
    """

    instructions = examples["instruction"]
    inputs = examples["input"]
    outputs = examples["output"]

    # Prepend the instruction to every input
    new_inputs = [instruction + text for instruction, text in zip(instructions, inputs)]

    return {"input": new_inputs, "output": outputs, "instructions": instructions}



# Preprocess eval dataset
eval_dataset_ = eval_dataset.map(
    preprocess_eval_dataset_function,
    batched=True,
    batch_size=config["batch_size"]
)

#### Evaluation Metrics
We implement several complementary metrics to evaluate model performance:

- Fuzzy Matching Accuracy

    - Uses `fuzz.ratio` from `thefuzz` (formerly FuzzyWuzzy) to compute a string similarity score from 0 to 100
    - Tolerates minor variations in text (e.g., spacing, capitalization)
    - Considers an answer correct when the similarity reaches a configured threshold (80 by default)


- BLEU Score

    - Evaluates the precision of n-gram matches
    - Complements fuzzy matching with a token-level overlap measure
    - Useful for assessing translation quality aspects of the answers


- BERTScore

    - Leverages contextual embeddings to capture semantic similarity
    - More robust to paraphrasing than n-gram based metrics
    - Correlates well with human judgments
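
As an illustration of the BLEU bullet, here is a minimal sketch of unigram BLEU, the setting that `compute_metrics` selects by calling NLTK's `sentence_bleu` with `weights=(1.0,)`. This is a simplified re-implementation for a single reference, not NLTK's code, and the function name `unigram_bleu` is an assumption.

```python
import math
from collections import Counter
from typing import List


def unigram_bleu(reference: List[str], hypothesis: List[str]) -> float:
    """BLEU restricted to unigrams: clipped precision times a brevity penalty."""
    if not hypothesis:
        return 0.0

    ref_counts = Counter(reference)
    hyp_counts = Counter(hypothesis)

    # Each hypothesis token is credited at most as often as it occurs in the reference
    clipped = sum(min(count, ref_counts[tok]) for tok, count in hyp_counts.items())
    precision = clipped / len(hypothesis)

    # Penalize hypotheses that are not longer than the reference
    if len(hypothesis) > len(reference):
        brevity_penalty = 1.0
    else:
        brevity_penalty = math.exp(1 - len(reference) / len(hypothesis))

    return brevity_penalty * precision


reference = "the capital of france is paris".split()
full_match = unigram_bleu(reference, "the capital of france is paris".split())
short_answer = unigram_bleu(reference, "paris".split())
print(full_match, short_answer)
```

The one-word answer has perfect precision but a heavy brevity penalty, which is one reason BLEU is paired with the other metrics for short answers.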


<img src="https://res.cloudinary.com/vickie/image/upload/v1735437159/uzbhzyhmkhyegetyyrmg.png" alt="https://ritikjain51.medium.com/llms-fine-tuning-and-evaluation-f019515b1c67" width="400"/>



In [22]:
import torch
from dataclasses import dataclass
from typing import Dict, List, Union

import nltk
import tqdm
from bert_score import score as bert_score
from nltk.tokenize import word_tokenize
from nltk.translate.bleu_score import sentence_bleu
from thefuzz import fuzz
from transformers import PreTrainedTokenizer, PreTrainedModel

@dataclass
class EvaluationMetrics:
    accuracy: float
    avg_fuzzy_score: float
    avg_bleu_score: float
    avg_bert_score_f1: float
    
    def to_dict(self) -> Dict[str, float]:
        return {
            'accuracy': self.accuracy,
            'avg_fuzzy_score': self.avg_fuzzy_score,
            'avg_bleu_score': self.avg_bleu_score,
            'avg_bert_score_f1': self.avg_bert_score_f1
        }


def extract_answer_from_predicted_answer(text: str) -> str:
    """
    Extract the text after '答え：' or 'Answer:' from the input text.
    
    Args:
        text (str): The input text containing the answer.
    
    Returns:
        str: The extracted answer, or an empty string if no match is found.
    """
    prefixes = ["答え：", "Answer:"]
    
    for prefix in prefixes:
        if prefix in text:
            return text.split(prefix, 1)[1].strip()
    
    return text.strip()  # Return stripped text if no prefix found



# Detect if the text contains Japanese characters
def contains_japanese(text):
    # Hiragana (3040-309F), Katakana (30A0-30FF), Kanji (4E00-9FFF)
    for char in text:
        if ('\u3040' <= char <= '\u309F' or  # Hiragana
            '\u30A0' <= char <= '\u30FF' or  # Katakana
            '\u4E00' <= char <= '\u9FFF'):   # Kanji
            return True
    return False


def tokenize_text(text: str) -> List[str]:
    """
    Tokenize text based on language (Japanese or English).
    For Japanese, splits on whitespace and ASCII punctuation.
    For English, lowercases the text and splits it into words and punctuation marks.
    """
    import re

    if contains_japanese(text):
        # Split on whitespace and ASCII punctuation (. , ! ?)
        tokens = re.findall(r'[^\s\.,!?]+|[。、！？]', text)
        return [token for token in tokens if token.strip()]

    # For English, use simple word and punctuation splitting
    return re.findall(r'\w+|[^\w\s]', text.lower())


def compute_metrics(pred_answer: str, target_answer: str, threshold: int = 80) -> Dict[str, Union[float, bool]]:
    """
    Compute multiple evaluation metrics for comparing predicted and target answers.
    """
    # Preprocess answers
    pred_clean = extract_answer_from_predicted_answer(pred_answer)
    target_clean = target_answer.strip()
    
    # Convert to lowercase for consistent comparison
    pred_lower = pred_clean.lower()
    target_lower = target_clean.lower()
    
    # Calculate fuzzy match score
    fuzzy_score = fuzz.ratio(pred_lower, target_lower)
    
    # Tokenize for BLEU score
    pred_tokens = word_tokenize(pred_lower)
    target_tokens = word_tokenize(target_lower)
    
    # Calculate BLEU score
    try:
        bleu = sentence_bleu([target_tokens], pred_tokens, weights=(1.0,))
    except ZeroDivisionError:
        bleu = 0.0

    
    
    # Set language based on content
    lang = 'ja' if contains_japanese(target_clean) else 'en'
    
    # Calculate BERTScore with appropriate language model
    P, R, F1 = bert_score([pred_clean], [target_clean], lang=lang, verbose=False)
    bert_f1 = F1.item()
    
    return {
        'fuzzy_match': fuzzy_score >= threshold,
        'fuzzy_score': fuzzy_score,
        'bleu_score': bleu,
        'bert_score_f1': bert_f1
    }
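
Because `sentence_bleu` is called with `weights=(1.0,)`, the BLEU value above reduces to unigram BLEU: clipped unigram precision multiplied by a brevity penalty. The following is a minimal pure-Python sketch of that computation for a single reference, written for illustration only. `unigram_bleu` is a hypothetical helper, not part of NLTK, and it ignores NLTK's smoothing options and warnings.

```python
import math
from collections import Counter
from typing import List


def unigram_bleu(reference: List[str], candidate: List[str]) -> float:
    """Clipped unigram precision times brevity penalty (single reference)."""
    if not candidate:
        return 0.0

    ref_counts = Counter(reference)
    cand_counts = Counter(candidate)

    # Each candidate token is credited at most as often as it occurs in the reference.
    clipped = sum(min(count, ref_counts[tok]) for tok, count in cand_counts.items())
    precision = clipped / len(candidate)
    if precision == 0.0:
        return 0.0

    # Candidates shorter than the reference are penalised.
    if len(candidate) >= len(reference):
        brevity_penalty = 1.0
    else:
        brevity_penalty = math.exp(1.0 - len(reference) / len(candidate))

    return brevity_penalty * precision


if __name__ == "__main__":
    print(unigram_bleu("the cat sat".split(), "the cat sat".split()))
    print(unigram_bleu("the cat sat on the mat".split(), "the cat".split()))
```

Since only unigrams are counted, this score rewards token overlap but is insensitive to word order.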

Here is our evaluation helper function. We loop through each batch, call the generate function to get the model's output, and compare it with the reference output from the dataset using the fuzzy matcher. If the fuzzy score reaches the threshold, the prediction counts as correct; otherwise it does not.

The fraction of correct predictions gives the accuracy, and the fuzzy, BLEU, and BERTScore values are averaged over all examples.

In [23]:

@torch.no_grad()
def evaluate(
    dataloader,
    tokenizer: PreTrainedTokenizer,
    model: PreTrainedModel,
    max_new_tokens: int,
    threshold: int = 80,
) -> EvaluationMetrics:
    """
    Evaluate the model using multiple metrics.
    
    Returns:
        EvaluationMetrics: Object containing all computed metrics
    """
    total_instances = 0
    total_correct = 0
    
    # Initialize metric aggregators
    total_metrics = {
        'fuzzy_score': 0,
        'bleu_score': 0,
        'bert_score_f1': 0
    }

    for batch in tqdm.tqdm(dataloader):
        inputs = batch["input"]
        outputs = batch["output"]
        batch_size = len(inputs)
        total_instances += batch_size

        for i in range(batch_size):
            input_text = inputs[i]
            target_answer = outputs[i]

            # Generate the answer
            pred_answer = generate_answer(
                model=model,
                tokenizer=tokenizer,
                question=input_text,
                max_length=max_new_tokens,
            )

            # Compute all metrics
            metrics = compute_metrics(pred_answer, target_answer, threshold)
            
            # Update counters
            if metrics['fuzzy_match']:
                total_correct += 1
            
            # Aggregate metrics
            for key in total_metrics:
                total_metrics[key] += metrics[key]

            if config["debug"]:
                pred_answer_extracted = extract_answer_from_predicted_answer(pred_answer)
                print(
                    f"Input: {input_text}\n"
                    f"Target: {target_answer}\n"
                    f"Predicted: {pred_answer_extracted}\n"
                    f"Metrics: {metrics}\n"
                )

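    # Calculate averages, guarding against an empty dataloader
    if total_instances == 0:
        raise ValueError("No examples were evaluated; the dataloader is empty.")
    accuracy = total_correct / total_instances
    for key in total_metrics:
        total_metrics[key] /= total_instances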

    return EvaluationMetrics(
        accuracy=accuracy,
        avg_fuzzy_score=total_metrics['fuzzy_score'],
        avg_bleu_score=total_metrics['bleu_score'],
        avg_bert_score_f1=total_metrics['bert_score_f1']
    )
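
The `accuracy` field is the fraction of examples whose fuzzy score reaches `threshold`. The toy sketch below illustrates that aggregation in isolation. It assumes the standard library's `difflib.SequenceMatcher` as a stand-in for `fuzz.ratio`, so its scores may differ from those of the fuzzy-matching library used in this tutorial. `similarity_score` and `fuzzy_accuracy` are hypothetical helper names, and the example pairs are made up for illustration.

```python
from difflib import SequenceMatcher
from typing import List, Tuple


def similarity_score(pred: str, target: str) -> float:
    """Character-level similarity on a 0-100 scale (stand-in for fuzz.ratio)."""
    return 100.0 * SequenceMatcher(None, pred.lower(), target.lower()).ratio()


def fuzzy_accuracy(pairs: List[Tuple[str, str]], threshold: float = 80.0) -> float:
    """Fraction of (prediction, target) pairs scoring at or above the threshold."""
    if not pairs:
        return 0.0  # avoid dividing by zero on an empty evaluation set
    correct = sum(similarity_score(p, t) >= threshold for p, t in pairs)
    return correct / len(pairs)


if __name__ == "__main__":
    # Illustrative (prediction, target) pairs, not taken from the dataset.
    examples = [
        ("骨川", "骨川"),          # identical strings
        ("Honekawa", "honekawa"),  # differs only by case
        ("Goda", "Honekawa"),      # a different answer
    ]
    print(fuzzy_accuracy(examples))
```

Raising `threshold` makes the accuracy stricter, while the averaged fuzzy score itself is unaffected by it.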

In [30]:
from torch.utils.data import DataLoader

# Load data for evaluation
dataloader = DataLoader(eval_dataset_["test"], batch_size=config["batch_size"], shuffle=False)

def test_evaluation(model, tokenizer):
    metrics = evaluate(dataloader, tokenizer, model, config["max_length"])
    print(f"Metrics: {metrics}")

In [None]:
# Evaluate the checkpoint saved at the end of every training stage

for i in range(config["stages"] + 1):
    model_name = f"output_stage{i}/checkpoint-10000"
    print(f"Model: {model_name}")
    model, tokenizer = load_model(model_name=model_name)
    test_evaluation(model, tokenizer=tokenizer)


### 9) Discussion

Our evaluation results demonstrate the significant impact of advanced reasoning techniques like **Chain-of-Thought (CoT)**, **Latent Reasoning**, and **Chain-of-Thought Decoding** on the performance of the fine-tuned Gemma 2 model. The metrics reveal a clear progression in model accuracy and robustness across training stages, highlighting the effectiveness of these methods.

#### Key Findings:
1. **Accuracy Improvement**:
   - **Stage 0**: Accuracy starts at **19%**, indicating the baseline performance before advanced reasoning techniques are fully applied.
   - **Stage 1**: Accuracy jumps to **83%**, showcasing the immediate benefits of incorporating CoT and Latent Reasoning.
   - **Stage 2**: Accuracy reaches **92%**, demonstrating the model's ability to refine its reasoning and decision-making processes further.

2. **Fuzzy Score**:
   - The average fuzzy score improves from **20.35** in Stage 0 to **92.23** in Stage 2, indicating that predictions match the expected outputs much more closely at the character level.

3. **BLEU Score**:
   - The BLEU score, computed here on unigrams only (`weights=(1.0,)`), increases from **0.17** in Stage 0 to **0.91** in Stage 2, reflecting much greater token overlap with the reference answers.

4. **BERTScore F1**:
   - The BERTScore F1 improves from **0.62** in Stage 0 to **0.97** in Stage 2, confirming that the model's outputs are more contextually and semantically aligned with the ground truth.

#### Metrics Summary:
| **Stage**       | **Accuracy** | **Fuzzy Score** | **BLEU Score** | **BERTScore F1** |
|------------------|--------------|-----------------|----------------|-------------------|
| **Stage 0**      | 0.19         | 20.35           | 0.17           | 0.62              |
| **Stage 1**      | 0.83         | 80.20           | 0.82           | 0.82              |
| **Stage 2**      | 0.92         | 92.23           | 0.91           | 0.97              |
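
To make the stage-to-stage gains explicit, the short sketch below recomputes absolute improvements from the figures in the table above. The values are copied from the table; the `stage_results` dictionary and the `absolute_gain` helper are illustrative names, not part of the tutorial's code.

```python
from typing import Dict

# Values copied from the Metrics Summary table.
stage_results: Dict[str, Dict[str, float]] = {
    "Stage 0": {"accuracy": 0.19, "fuzzy_score": 20.35, "bleu_score": 0.17, "bert_score_f1": 0.62},
    "Stage 1": {"accuracy": 0.83, "fuzzy_score": 80.20, "bleu_score": 0.82, "bert_score_f1": 0.82},
    "Stage 2": {"accuracy": 0.92, "fuzzy_score": 92.23, "bleu_score": 0.91, "bert_score_f1": 0.97},
}


def absolute_gain(results: Dict[str, Dict[str, float]], start: str, end: str) -> Dict[str, float]:
    """Absolute change in each metric between two stages, rounded to 2 decimals."""
    return {
        metric: round(results[end][metric] - results[start][metric], 2)
        for metric in results[start]
    }


if __name__ == "__main__":
    for start, end in [("Stage 0", "Stage 1"), ("Stage 1", "Stage 2"), ("Stage 0", "Stage 2")]:
        print(f"{start} -> {end}: {absolute_gain(stage_results, start, end)}")
```

Most of the accuracy gain occurs between Stage 0 and Stage 1 (+0.64), with a smaller increase (+0.09) from Stage 1 to Stage 2.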

#### Python Package:
To make these advancements accessible, we've compiled the `LatentReasoningGemmaForCausalLM` class, along with the Chain-of-Thought Decoding implementation, into a Python package. You can install it via:
```bash
pip install git+https://github.com/vicksEmmanuel/latent-gemma.git
```

This package provides a user-friendly interface for leveraging the fine-tuned Gemma 2 model with advanced reasoning capabilities.


Next, let's upload the model to Kaggle.

### 10) Upload the Model to Kaggle Models

Step 1:
- Go to the model folder.
- Find the `config.json` file.
- Replace the value of `_name_or_path` with the original Kaggle path `google/gemma-2/transformers/gemma-2-2b`.
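
Step 1 can also be done programmatically. The sketch below is one way to do it with the standard library only. `set_name_or_path` is a hypothetical helper, and the checkpoint path in the usage section is a placeholder to replace with your own model folder.

```python
import json
from pathlib import Path
from typing import Any, Dict


def set_name_or_path(config_path: str, new_value: str) -> Dict[str, Any]:
    """Rewrite the `_name_or_path` field of a config.json file in place."""
    path = Path(config_path)
    cfg = json.loads(path.read_text(encoding="utf-8"))
    cfg["_name_or_path"] = new_value
    # All other keys are written back unchanged.
    path.write_text(json.dumps(cfg, indent=2) + "\n", encoding="utf-8")
    return cfg


if __name__ == "__main__":
    # Placeholder path: point this at the config.json of your own checkpoint.
    config_file = Path("path-to-model/output_stage3/checkpoint-10000/config.json")
    if config_file.exists():
        updated = set_name_or_path(str(config_file), "google/gemma-2/transformers/gemma-2-2b")
        print(updated["_name_or_path"])
    else:
        print(f"No config found at {config_file}")
```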

Step 2:
- Zip the checkpoint folder by running `zip -r latent_gemma2_finetune.zip path-to-model/output_stage3/checkpoint-10000`.
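
If the `zip` command-line tool is unavailable, the archive can be built from Python instead. This is a minimal sketch using `shutil.make_archive`; `zip_checkpoint` is a hypothetical helper, and the paths in the usage section are placeholders. Unlike the `zip -r` command above, it stores only the checkpoint folder name inside the archive rather than the full relative path.

```python
import shutil
from pathlib import Path


def zip_checkpoint(checkpoint_dir: str, archive_stem: str) -> str:
    """Zip `checkpoint_dir`, keeping its folder name inside the archive.

    `archive_stem` is the output path without the `.zip` suffix.
    Returns the path of the created archive.
    """
    checkpoint = Path(checkpoint_dir).resolve()
    if not checkpoint.is_dir():
        raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint}")
    return shutil.make_archive(
        base_name=str(Path(archive_stem).resolve()),
        format="zip",
        root_dir=str(checkpoint.parent),  # archive paths are relative to this folder
        base_dir=checkpoint.name,         # only this folder is included
    )


if __name__ == "__main__":
    # Placeholder path: replace with your own checkpoint folder.
    source = "path-to-model/output_stage3/checkpoint-10000"
    if Path(source).is_dir():
        print(zip_checkpoint(source, "latent_gemma2_finetune"))
    else:
        print(f"No checkpoint found at {source}")
```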


Step 3: Go to Kaggle Models
- Log in to your Kaggle account.
- Navigate to the Kaggle Models page.

Step 4: Create a New Model
- Click on the "New Model" button.
- Fill in the required details:
<img src="https://res.cloudinary.com/vickie/image/upload/v1735816980/exvtxvjs7heemon26he2.png" alt="Gemini Reasoning Finetuning" width="1000"/>

- Select the `.zip` file created in Step 2.
- Click "Upload."

### 11) Usage

First, add the following script to a `setup.py` file:

```python
import kagglehub

kagglehub.login()

# Download latest version
path = kagglehub.model_download("victorumesiobi/gemma-2-japanese-english-reasoning/transformers/1")

print("Path to model files:", path)

```

Then run:

`python setup.py`


In [None]:
pip install -q git+https://github.com/vicksEmmanuel/latent-gemma.git

In [None]:
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from latent_gemma import LatentReasoningGemmaForCausalLM

model_path = "/home/featurize/.cache/kagglehub/models/victorumesiobi/gemma-2-japanese-english-reasoning/transformers/1/2"  # Replace with the path your model was downloaded to

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
model_config = AutoConfig.from_pretrained(model_path)

config = {
    "max_length": 256
}
latent_config = LatentReasoningGemmaForCausalLM.DEFAULT_CONFIG
LatentReasoningGemmaForCausalLM.DEFAULT_CONFIG = {
    **latent_config,
    **config
}
updated_latent_config = LatentReasoningGemmaForCausalLM.DEFAULT_CONFIG
model = LatentReasoningGemmaForCausalLM(config=model_config)
model = model.from_pretrained(model_path)
model.tokenizer = tokenizer

In [None]:
text = "人気漫画『ドラえもん』の登場人物で、ジャイアンの苗字は剛田ですが、スネ夫の苗字は何でしょう？"
output = model.generate_answer(
    model=model, 
    tokenizer=tokenizer, 
    question=text, 
    k=5, 
    max_length=256
)

print(f"output: {output}")

Alternatively, you can call the standard `generate` method directly:

In [None]:
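# The tokenizer returns input_ids and attention_mask, so name the result accordingly
inputs = tokenizer(text, return_tensors="pt")

outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))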