<a href="https://colab.research.google.com/github/srimallya/eop-pRL/blob/main/eop_pRL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# End-Only Penalty Progressive Reinforcement Learning: A Mathematical Analysis

**Abstract**

This paper presents a comprehensive mathematical analysis of End-Only Penalty Progressive Reinforcement Learning (EOP-PRL), a novel reinforcement learning algorithm designed for training language models on structured reasoning tasks. Unlike traditional token-level reinforcement learning approaches, EOP-PRL employs a unique reward mechanism that exclusively penalizes incomplete generations rather than token-level discrepancies. We formalize the mathematical foundations of this approach, analyze its convergence properties, and demonstrate how it creates an effective balance between exploration and exploitation. Our analysis reveals that EOP-PRL introduces a natural curriculum through progressive penalty scaling, directly addressing the exploration-exploitation dilemma inherent in language model fine-tuning. We prove that under reasonable assumptions, EOP-PRL converges to policies that prioritize the generation of complete reasoning paths while allowing for greater diversity in intermediate token selection, a crucial requirement for complex reasoning tasks.

**1. Introduction**

Language models have demonstrated remarkable capabilities in various natural language processing tasks, yet they often struggle with structured reasoning that requires step-by-step logical thinking. Traditional supervised learning approaches using maximum likelihood estimation (MLE) tend to produce models that imitate the surface patterns in the training data without developing robust reasoning abilities [1]. For instance, in tasks requiring multi-step deduction, an MLE-trained model might learn to predict the next word based on the immediate context without truly understanding the underlying logical structure.

Reinforcement Learning (RL) has emerged as a promising approach for fine-tuning language models beyond simple imitation. However, conventional RL methods for language models typically employ token-level reward structures that penalize any deviation from reference sequences [2, 3]. This approach can be counterproductive for reasoning tasks, where multiple valid reasoning paths may lead to the same conclusion. The strict token-level matching effectively restricts the model's ability to explore alternative reasoning approaches. Consider a mathematical proof where different intermediate steps can lead to the same final answer. A token-level penalty would discourage the model from exploring these alternative valid steps.

We introduce End-Only Penalty Progressive Reinforcement Learning (EOP-PRL), a mathematical framework that fundamentally redefines the reward structure for training language models on reasoning tasks. The key innovation of EOP-PRL is its novel reward function that:

* Rewards matching tokens with position-dependent scaling
* Applies zero penalties for non-matching tokens during generation
* Implements an end-only penalty exclusively for incomplete sequences
* Incorporates a progressive penalty scaling that increases with training progress

The main contributions of this paper are:

* The introduction of the End-Only Penalty Progressive Reinforcement Learning (EOP-PRL) algorithm.
* A rigorous mathematical formulation of the EOP-PRL reward function and policy gradient objective.
* A theoretical analysis demonstrating how EOP-PRL balances exploration and exploitation.
* A proof of convergence for EOP-PRL under standard policy gradient assumptions.
* The identification of a natural curriculum learning effect arising from the progressive penalty scaling.
* A theoretical comparison highlighting the advantages of EOP-PRL over traditional token-level penalty approaches in encouraging exploration.

This paper provides a rigorous mathematical analysis of EOP-PRL, demonstrating how it resolves the exploration-exploitation trade-off in language model fine-tuning while ensuring convergence to policies that produce complete, structured reasoning.

**2. Background and Related Work**

**2.1 Policy Gradient Methods**

Policy gradient methods are a family of reinforcement learning algorithms that directly optimize policy parameters by following the gradient of expected return [4]. For a policy $\pi_\theta$ parameterized by $\theta$, the standard policy gradient objective is:
$$J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[R(\tau)] \quad (1)$$
where $\tau$ represents a trajectory and $R(\tau)$ is the return. The REINFORCE algorithm [5] provides an unbiased estimate of the gradient:
$$\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[R(\tau) \nabla_\theta \log \pi_\theta(\tau)] \quad (2)$$
In the context of language models, a trajectory corresponds to a sequence of tokens, and the policy represents the probability distribution over tokens at each generation step.
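
The identity in Equation 2 can be checked numerically on a toy problem. The sketch below assumes a one-step "trajectory" with three possible actions, a softmax policy, and made-up returns (all illustrative, not taken from the notebook); it compares the gradient of Equation 1 obtained by automatic differentiation with the score-function form of Equation 2.

```python
import torch

# Hypothetical toy setup: one decision, three actions, fixed returns.
theta = torch.tensor([0.2, -0.1, 0.4], dtype=torch.float64, requires_grad=True)
returns = torch.tensor([1.0, 0.0, -0.5], dtype=torch.float64)

# Exact gradient of J(theta) = sum_a pi(a) * R(a), as in Eq. (1).
probs = torch.softmax(theta, dim=0)
(exact_grad,) = torch.autograd.grad((probs * returns).sum(), theta)

# Score-function form of Eq. (2): E_a[ R(a) * grad log pi(a) ],
# with the expectation taken exactly over the three actions.
score_grad = torch.zeros_like(theta)
for action in range(3):
    log_prob = torch.log_softmax(theta, dim=0)[action]
    (grad_log_prob,) = torch.autograd.grad(log_prob, theta)
    score_grad += probs[action].detach() * returns[action] * grad_log_prob

print(torch.allclose(exact_grad, score_grad))
```

In practice the expectation is replaced by sampled trajectories, which is what makes the estimator noisy but unbiased.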

**2.2 RL for Language Model Fine-tuning**

Recent work has applied RL to language model fine-tuning, with prominent examples including RLHF (Reinforcement Learning from Human Feedback) [6], PPO (Proximal Policy Optimization) for language models [7], and various approaches using KL-regularized policy optimization [8]. These methods often rely on reward functions that evaluate the quality of the entire generated sequence based on human preferences or external metrics. Token-level rewards based on exact matching with reference sequences are also common. For instance, in sequence generation tasks, a reward might be given for each token that matches the ground truth, and a penalty for each mismatch. Such matching can work when acceptable outputs stay close to a single reference, but it restricts the model's ability to discover alternative, yet valid, reasoning steps in more complex tasks.

**2.3 Curriculum Learning in RL**

Curriculum learning [9] involves training models on increasingly difficult examples or with progressively higher standards. In reinforcement learning, curriculum learning has been implemented through environment complexity [10], reward shaping [11], and adversarial approaches [12]. EOP-PRL introduces a novel form of curriculum learning through its progressive penalty scaling. Initially, the penalty for incomplete sequences is relatively mild, allowing the model to focus on generating any reasonable output. As training progresses, the penalty increases, gradually pushing the model to generate complete and structured reasoning paths. This mirrors the idea of learning a complex skill by starting with simpler aspects and gradually increasing the difficulty.

**3. Mathematical Formulation of EOP-PRL**

**3.1 Notation and Preliminaries**

We define the following notation:

* $\pi_\theta$: Policy network (language model with parameters $\theta$)
* $x$: Input prompt
* $y^*$: Reference output sequence
* $\hat{y}$: Generated output sequence
* $\hat{y}_t$: Token at position $t$ in the generated sequence
* $r_t$: Reward for token at position $t$
* $|\hat{y}|$: Length of generated sequence
* $|y^*|$: Length of reference sequence
* $e$: Current episode number
* $E$: Total number of episodes
* $\alpha$: Base reward value (default: 1.0)
* $\beta$: Base penalty value (default: -0.5)
* $\gamma$: Maximum penalty scaling factor (default: 2.5)

**3.2 Reward Function**

The core innovation of EOP-PRL lies in its novel reward function. For a token $\hat{y}_t$ at position $t$ in the generated sequence $\hat{y}$, the reward $r_t$ is defined as:

* **Position-scaled reward for matching tokens:**
    If $\hat{y}_t = y^*_t$, then:
    $$r_t = \alpha \cdot (0.1 + 0.9 \cdot \frac{t}{|\hat{y}|}) \quad (3)$$
    This formulation provides several key properties:
    * Rewards increase with position, prioritizing correct tokens later in the sequence. This encourages the model to maintain correctness as the reasoning path progresses. The scaling factor $(0.1 + 0.9 \cdot \frac{t}{|\hat{y}|})$ ensures that later correct tokens contribute more significantly to the overall reward.
    * The minimum reward (at $t = 0$) is $0.1\alpha$, ensuring even early matches receive some reward. This prevents the model from completely disregarding initial correct steps.
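    * The largest attainable reward, at the last position $t = |\hat{y}| - 1$, is $\alpha \cdot (1 - \frac{0.9}{|\hat{y}|})$, which approaches $\alpha$ from below as the sequence grows. $\alpha$ is therefore an upper bound that normalizes the scale of positive rewards.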

* **Zero penalty for non-matching tokens during generation:**
    If $\hat{y}_t \neq y^*_t$, then:
    $$r_t = 0 \quad (4)$$
    Unlike conventional approaches that would apply a negative reward, this formulation allows exploration of alternative reasoning paths without penalty. The model is free to try different intermediate tokens without immediate negative feedback, fostering exploration of the solution space.

* **End-only penalty for incomplete generations:**
    If $|\hat{y}| < |y^*|$ and $t = |\hat{y}| - 1$ (the last generated token), then:
    $$r_t \leftarrow r_t + \beta \cdot S(e, E) \cdot \frac{|y^*| - |\hat{y}|}{|y^*|} \quad (5)$$
    The end-only penalty is applied at the end of the generation process if the generated sequence length $|\hat{y}|$ is less than the reference sequence length $|y^*|$. This component penalizes only incomplete sequences, with penalties proportional to:
    * The degree of incompleteness: $\frac{|y^*| - |\hat{y}|}{|y^*|}$. The more incomplete the sequence, the larger the penalty.
    * The training progress: $S(e, E)$ increases from 1.0 to $\gamma$ over training, where $S(e, E)$ is the episode-dependent penalty scaling function:
        $$S(e, E) = 1.0 + \frac{e}{E-1} \cdot (\gamma - 1.0) \quad (6)$$
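
The three reward rules and the scaling function can be put together in a few lines. The following is a minimal sketch under stated assumptions, not the notebook's implementation: the function names `penalty_scale` and `eop_rewards` are hypothetical, tokens are plain integers, positions at or beyond the reference length are treated as non-matching, and $S(e, E)$ is taken to be 1.0 when $E = 1$ to avoid dividing by zero.

```python
from typing import List, Sequence


def penalty_scale(episode: int, total_episodes: int, gamma: float = 2.5) -> float:
    """Linear schedule S(e, E) of Eq. (6). Assumption: returns 1.0 if E <= 1."""
    if total_episodes <= 1:
        return 1.0
    return 1.0 + (episode / (total_episodes - 1)) * (gamma - 1.0)


def eop_rewards(
    generated: Sequence[int],
    reference: Sequence[int],
    episode: int,
    total_episodes: int,
    alpha: float = 1.0,
    beta: float = -0.5,
    gamma: float = 2.5,
) -> List[float]:
    """Per-token rewards following Eqs. (3)-(5)."""
    n_gen, n_ref = len(generated), len(reference)
    rewards: List[float] = []
    for t, token in enumerate(generated):
        if t < n_ref and token == reference[t]:
            # Eq. (3): position-scaled reward for a matching token.
            rewards.append(alpha * (0.1 + 0.9 * t / n_gen))
        else:
            # Eq. (4): no penalty for a non-matching token.
            rewards.append(0.0)
    if 0 < n_gen < n_ref:
        # Eq. (5): end-only penalty added to the last generated token.
        scale = penalty_scale(episode, total_episodes, gamma)
        rewards[-1] += beta * scale * (n_ref - n_gen) / n_ref
    return rewards


# Usage: a three-token generation against a four-token reference.
print(eop_rewards([1, 2, 9], [1, 2, 3, 4], episode=0, total_episodes=100))
```

An empty generation yields an empty reward list in this sketch, since there is no final token to attach the penalty to; a real implementation would need to decide how to handle that case.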

**3.3 Policy Gradient Objective**

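For training the policy $\pi_\theta$, we use the REINFORCE algorithm and minimize the following reward-weighted negative log-likelihood, in which each token's log-probability is weighted by its own reward $r_t$: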
$$J(\theta) = -\mathbb{E}_{\hat{y} \sim \pi_\theta(\cdot|x)}[\sum_{t=0}^{|\hat{y}|-1} r_t \cdot \log \pi_\theta(\hat{y}_t | x, \hat{y}_{<t})] \quad (7)$$
The gradient is approximated using a single sampled sequence per step:
$$\nabla_\theta J(\theta) \approx -\sum_{t=0}^{|\hat{y}|-1} r_t \cdot \nabla_\theta \log \pi_\theta(\hat{y}_t | x, \hat{y}_{<t}) \quad (8)$$
For implementation purposes, we can express this using a loss function $L$:
$$L = -\sum_{t=0}^{|\hat{y}|-1} r_t \cdot \log \pi_\theta(\hat{y}_t | x, \hat{y}_{<t}) \quad (9)$$
When using gradient accumulation over $G$ steps, each step's loss is scaled by $\frac{1}{G}$:
$$L' = \frac{L}{G} \quad (10)$$
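
Equations 9 and 10 translate directly into a few tensor operations. The sketch below is an illustration under assumptions rather than the notebook's training loop: `eop_loss` is a hypothetical helper, `logits` holds the model outputs at the generated positions with shape `(T, V)`, and the rewards are treated as constants with respect to $\theta$.

```python
import torch
import torch.nn.functional as F


def eop_loss(
    logits: torch.Tensor,      # (T, V) logits at each generated position
    generated: torch.Tensor,   # (T,) sampled token ids
    rewards: torch.Tensor,     # (T,) per-token rewards r_t
    accumulation_steps: int = 1,
) -> torch.Tensor:
    """Reward-weighted negative log-likelihood, Eq. (9), scaled as in Eq. (10)."""
    log_probs = F.log_softmax(logits, dim=-1)
    # Pick out log pi(y_t | x, y_<t) for the tokens that were actually sampled.
    token_log_probs = log_probs.gather(1, generated.unsqueeze(1)).squeeze(1)
    # Rewards carry no gradient; only the log-probabilities do.
    loss = -(rewards.detach() * token_log_probs).sum()
    return loss / accumulation_steps


# Usage with dummy values: uniform logits over a 4-token vocabulary.
logits = torch.zeros(3, 4, requires_grad=True)
generated = torch.tensor([0, 1, 2])
rewards = torch.tensor([0.1, 0.0, -0.125])
loss = eop_loss(logits, generated, rewards, accumulation_steps=4)
loss.backward()
print(loss.item(), logits.grad[1])
```

The gradient row for the position with $r_t = 0$ is zero, which is the mechanism behind the exploration argument in Section 4.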

**3.4 Mathematical Properties of the Reward Function**

The reward function of EOP-PRL has several important mathematical properties:

* **Bounded rewards:** Every per-token reward lies between $\beta \cdot \gamma$ and $\alpha$ (with $\beta < 0$ and $\gamma \geq 1$), providing numerical stability during training. The lower bound is approached by the end-only penalty at the end of training when almost the entire reference is missing, and the upper bound is approached by a matching token at the final position of a long sequence.
* **Monotonicity with sequence length:** For matching tokens, rewards increase monotonically with position, incentivizing the model to maintain correctness throughout longer sequences. This encourages the model to complete the reasoning process correctly.
* **Incomplete sequence penalty:** The penalty term creates a sharp discontinuity in the reward between complete and incomplete sequences. This strong negative signal discourages the model from stopping prematurely. The magnitude of the penalty increases with the degree of incompleteness.
* **Temporal curriculum:** The penalty scaling function $S(e, E)$ implements a temporal curriculum, with:
    * $S(0, E) = 1.0$ (initial scaling, providing mild penalties initially)
    * $S(E-1, E) = \gamma$ (final scaling, providing stronger penalties at the end of training)
    * $\frac{dS(e, E)}{de} = \frac{\gamma - 1.0}{E-1} > 0$, ensuring a linear and monotonic increase in penalty scaling throughout training.
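
As a worked illustration, take the default values from Section 3.1 ($\beta = -0.5$, $\gamma = 2.5$) and assume, purely for the sake of the example, $E = 100$ episodes, a reference of length $|y^*| = 4$, and a generation of length $|\hat{y}| = 3$:

$$S(0, 100) = 1.0, \qquad S(50, 100) = 1 + \frac{50}{99} \cdot 1.5 \approx 1.758, \qquad S(99, 100) = 2.5$$

$$\beta \cdot S(0, 100) \cdot \frac{1}{4} = -0.125, \qquad \beta \cdot S(99, 100) \cdot \frac{1}{4} = -0.3125$$

The same one-token shortfall is therefore penalized 2.5 times more heavily at the end of training than at the start, and both values stay within the bound $\beta \cdot \gamma = -1.25$.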

**4. Theoretical Analysis**

**4.1 Exploration-Exploitation Balance**

Traditional reinforcement learning for language models faces a fundamental tension between exploration (trying diverse token sequences) and exploitation (following known good paths). We analyze how EOP-PRL addresses this tension.

**Theorem 1:** Under the EOP-PRL reward structure, the gradient contribution of a non-matching token is exactly zero at every position except the final token of an incomplete sequence ($|\hat{y}| < |y^*|$), where it is proportional to the end-only penalty.

**Proof:** For a generated token $\hat{y}_t \neq y^*_t$, the reward is $r_t = 0$ whenever the end-only penalty does not apply, that is, whenever $|\hat{y}| \geq |y^*|$ or $t < |\hat{y}| - 1$. The gradient contribution to the loss function (Equation 9) for this token is:
$$-\nabla_\theta (0 \cdot \log \pi_\theta(\hat{y}_t | x, \hat{y}_{<t})) = 0$$
Therefore, no explicit gradient signal discourages exploration of alternative tokens, provided the sequence reaches the reference length. The model is free to explore different token choices without being penalized for deviating from the reference during intermediate steps.

However, if $|\hat{y}| < |y^*|$, and $t = |\hat{y}| - 1$ (the last generated token), the reward includes the penalty term:
$$r_t = \beta \cdot S(e, E) \cdot \frac{|y^*| - |\hat{y}|}{|y^*|}$$
This creates a strong negative reward and thus a significant gradient signal to discourage stopping generation prematurely. The model is penalized for not completing the reasoning path.

As $S(e, E)$ increases with the episode number $e$, the pressure to generate complete sequences monotonically increases with training progress. This progressive increase in penalty encourages the model to first learn to complete the sequences and then refine the intermediate steps to match the reference where possible to gain the positive position-scaled rewards.

**4.2 Convergence Analysis**

**Theorem 2:** Under standard assumptions for policy gradient methods (including finite state and action spaces, bounded rewards, and appropriate learning rate schedules), EOP-PRL converges to a policy that maximizes the expected reward, which favors complete sequences while allowing diversity in token choice.

**Proof Outline:**

The REINFORCE algorithm provides an unbiased estimate of the gradient of the expected return [5]. With appropriate learning rate schedules that satisfy the Robbins-Monro conditions [13] (i.e., $\sum_{k=1}^{\infty} \eta_k = \infty$ and $\sum_{k=1}^{\infty} \eta_k^2 < \infty$, where $\eta_k$ is the learning rate at iteration $k$), stochastic gradient ascent converges to a stationary point of the expected return, which is a local optimum under additional regularity conditions.
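
For a concrete instance of the Robbins-Monro conditions, consider the schedule $\eta_k = \eta_0 / k$ with an arbitrary $\eta_0 > 0$ (an illustrative choice, not one prescribed by EOP-PRL):

$$\sum_{k=1}^{\infty} \frac{\eta_0}{k} = \infty, \qquad \sum_{k=1}^{\infty} \frac{\eta_0^2}{k^2} = \frac{\pi^2}{6} \, \eta_0^2 < \infty$$

A constant learning rate $\eta_k = \eta$ satisfies the first condition but not the second, since $\sum_k \eta^2$ diverges, so the guarantee invoked here does not directly cover training with a fixed rate.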

The reward structure of EOP-PRL is bounded (as established in Section 3.4). The positive rewards lie between $0.1\alpha$ and $\alpha$, and the end-only penalty lies between $\beta \cdot \gamma$ and $0$, because the incompleteness fraction $\frac{|y^*| - |\hat{y}|}{|y^*|}$ is at most 1 and $S(e, E) \leq \gamma$. This boundedness ensures that the gradient estimates are also bounded, which is a requirement for convergence in many policy gradient theorems.

The EOP-PRL reward function is designed to prioritize the generation of complete sequences. The end-only penalty for incomplete sequences provides a strong negative signal, pushing the policy towards generating sequences of length $|y^*|$. Once the model learns to generate complete sequences, the position-scaled reward for matching tokens incentivizes it to align the generated tokens with the reference sequence. However, importantly, there is no penalty for non-matching tokens in complete sequences, allowing for flexibility and exploration of alternative reasoning paths that might also lead to a valid conclusion.

Therefore, the expected reward is maximized by policies that generate complete sequences that match the reference as closely as possible. While the proof relies on standard assumptions of policy gradient methods, the specific structure of the EOP-PRL reward function guides the convergence towards policies that achieve the desired behavior of complete and potentially diverse reasoning paths. A more rigorous proof would involve formally showing that the expected reward landscape under EOP-PRL has desirable properties (e.g., unimodality around optimal policies within the space of complete sequences), but this is beyond the scope of this initial analysis.

**4.3 Curriculum Learning Effects**

**Theorem 3:** The progressive penalty scaling in EOP-PRL implements an effective curriculum learning mechanism that gradually increases performance standards.

**Proof:** The penalty scaling function $S(e, E)$ has the following properties:

* $S(0, E) = 1.0$, providing mild penalties initially in training episodes.
* $S(E-1, E) = \gamma$, providing stronger penalties at the end of training.
* $\frac{dS(e, E)}{de} = \frac{\gamma - 1.0}{E-1} > 0$, ensuring a linear and monotonic increase in the penalty scaling factor as training progresses.

These properties ensure that the model initially learns under more lenient conditions, where the penalty for incomplete sequences is relatively small. This encourages the model to explore and learn the basic structure of the reasoning task without being overly penalized for premature termination. As training progresses, the increasing penalty forces the model to focus on generating complete sequences. This gradual increase in the strictness of the completeness requirement aligns perfectly with the definition of curriculum learning [9], where the model learns simpler aspects of the task first before tackling the more challenging requirement of generating complete and correct reasoning paths.

**4.4 Comparison with Token-Level Penalties**

To highlight the differences between EOP-PRL and traditional token-level penalty approaches, we define a standard token-level penalty function:
$$r^{std}_t = \begin{cases}
\alpha, & \text{if } \hat{y}_t = y^*_t \\
\beta, & \text{otherwise}
\end{cases} \quad (11)$$

**Theorem 4:** The magnitude of the expected gradient update for generating an incorrect token is significantly lower in EOP-PRL compared to standard token-level penalty approaches during the intermediate generation steps, leading to greater exploration.

**Proof:** For a token where $\hat{y}_t \neq y^*_t$:

* Under the standard approach: $r^{std}_t = \beta$, leading to a gradient contribution proportional to $\beta \cdot \nabla_\theta \log \pi_\theta(\hat{y}_t | x, \hat{y}_{<t})$. This directly discourages the generation of any token that does not match the reference at each step.

* Under EOP-PRL: $r_t = 0$ for all non-matching tokens in a generated sequence, as long as the sequence eventually reaches the reference length. The penalty is only applied at the very end if the sequence is incomplete. Therefore, for the vast majority of exploratory tokens generated in complete (or potentially complete) sequences, EOP-PRL applies no penalty, resulting in a gradient contribution of zero for these steps. This allows the model to explore different token choices without immediate negative feedback.

This lower gradient magnitude for exploratory tokens in EOP-PRL allows for greater diversity in the generated intermediate reasoning steps compared to the standard token-level penalty approach, which strongly discourages any deviation from the reference sequence at every step.
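
The contrast in Theorem 4 can be seen on a single token. The sketch below uses a made-up three-token vocabulary and arbitrary logits (both assumptions for illustration) and computes the gradient of $-r_t \log \pi_\theta(\hat{y}_t)$ with respect to the logits under the standard penalty of Equation 11 and under the zero reward of Equation 4.

```python
import torch
import torch.nn.functional as F


def token_gradient(logits: torch.Tensor, token: int, reward: float) -> torch.Tensor:
    """Gradient of -reward * log pi(token) with respect to the logits."""
    logits = logits.detach().clone().requires_grad_(True)
    log_prob = F.log_softmax(logits, dim=-1)[token]
    loss = -reward * log_prob
    loss.backward()
    return logits.grad


# Hypothetical logits over a 3-token vocabulary; token 1 does not match the reference.
logits = torch.tensor([2.0, 0.5, -1.0])
mismatched_token = 1

grad_std = token_gradient(logits, mismatched_token, reward=-0.5)  # Eq. (11), beta = -0.5
grad_eop = token_gradient(logits, mismatched_token, reward=0.0)   # Eq. (4)

print("standard penalty gradient norm:", grad_std.norm().item())
print("EOP-PRL gradient norm:", grad_eop.norm().item())
```

Under the standard penalty, a gradient descent step lowers the logit of the sampled token; under EOP-PRL the same token leaves the parameters untouched.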

**5. Implementation Considerations**

**5.1 Gradient Accumulation**

For practical implementation, EOP-PRL can effectively utilize gradient accumulation over $G$ steps before updating the policy network parameters. This technique is mathematically equivalent to using a batch size of $G$ but requires less memory, making it suitable for training large language models. For a sequence of losses $L_1, L_2, ..., L_G$ calculated from $G$ independent trajectories, the accumulated gradient is:
$$\nabla_\theta L_{accumulated} = \frac{1}{G} \sum_{i=1}^{G} \nabla_\theta L_i \quad (12)$$
This averaged gradient is then used to update the model's parameters.
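
The equivalence expressed by Equation 12 can be verified on a toy problem. The sketch below assumes a hypothetical quadratic loss and random inputs in place of real trajectories; it accumulates $G$ scaled gradients one at a time and compares the result with the gradient of the mean loss over the same $G$ samples.

```python
import torch

torch.manual_seed(0)
G = 4  # number of accumulation steps

# Hypothetical stand-ins for model parameters and per-trajectory inputs.
weights = torch.randn(3, dtype=torch.float64, requires_grad=True)
inputs = torch.randn(G, 3, dtype=torch.float64)


def toy_loss(w: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Arbitrary differentiable per-sample loss used only for illustration."""
    return (w * x).sum() ** 2


# Accumulation: each backward() call adds the scaled gradient into weights.grad.
for i in range(G):
    (toy_loss(weights, inputs[i]) / G).backward()
accumulated = weights.grad.clone()

# Reference: a single backward pass through the mean loss of the whole batch.
weights.grad = None
torch.stack([toy_loss(weights, inputs[i]) for i in range(G)]).mean().backward()
batched = weights.grad.clone()

print(torch.allclose(accumulated, batched))
```

The memory saving comes from the fact that only one sample's computation graph is alive at a time in the accumulation loop.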

**5.2 Memory-Efficient Attention**

When implementing EOP-PRL for large language models, memory efficiency is crucial due to the computational cost of attention mechanisms, especially for long sequences. We can employ a chunked attention computation approach that maintains mathematical equivalence to standard attention while reducing peak memory requirements. For a query matrix $Q \in \mathbb{R}^{B \times T \times H \times D}$ and key matrix $K \in \mathbb{R}^{B \times T \times H \times D}$, where $B$ is batch size, $T$ is sequence length, $H$ is number of heads, and $D$ is head dimension, the attention weights for each batch element and head are typically computed as:
$$A = \text{softmax}\left(\frac{QK^T}{\sqrt{D}}\right) \quad (13)$$
By chunking the query sequence of length $T$ into blocks of size $C$, the attention computation can be performed block by block. For a given query chunk $Q_{i:i+C}$, the scores are computed against the full key matrix $K$, so that the softmax still normalizes over all $T$ key positions:
$$A_{i} = \text{softmax}\left(\frac{Q_{i:i+C} K^T}{\sqrt{D}}\right) \quad (14)$$
This reduces the peak memory requirement from $\mathcal{O}(T^2)$ for the full attention matrix to roughly $\mathcal{O}(C \cdot T)$, where $C$ is the chunk size, allowing for training with longer sequences on limited computational resources. Applying the softmax separately to individual key chunks would not reproduce Equation 13 unless the partial results are renormalized. The choice of chunk size $C$ involves a trade-off between memory reduction and potential computational overhead due to the iterative nature of the chunked computation.
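
A minimal sketch of query-chunked attention is shown below, together with a check against the unchunked computation. It is written under assumptions that may differ from the notebook's implementation: tensors use a `(B, H, T, D)` layout, no causal mask is applied, and the function names are hypothetical.

```python
import math
import torch


def full_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Standard attention; materializes the full (T, T) weight matrix per head."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    return torch.softmax(scores, dim=-1) @ v


def chunked_attention(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, chunk_size: int
) -> torch.Tensor:
    """Process queries in blocks of `chunk_size`; each block attends to all keys."""
    outputs = []
    for start in range(0, q.size(-2), chunk_size):
        q_chunk = q[..., start:start + chunk_size, :]
        # Scores for this block have shape (..., C, T) instead of (..., T, T).
        scores = q_chunk @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
        # Softmax over the full key axis keeps the result exact.
        outputs.append(torch.softmax(scores, dim=-1) @ v)
    return torch.cat(outputs, dim=-2)


# Usage: compare both versions on random inputs.
torch.manual_seed(0)
q = torch.randn(2, 4, 10, 8, dtype=torch.float64)
k = torch.randn(2, 4, 10, 8, dtype=torch.float64)
v = torch.randn(2, 4, 10, 8, dtype=torch.float64)
out_full = full_attention(q, k, v)
out_chunked = chunked_attention(q, k, v, chunk_size=3)
print(torch.allclose(out_full, out_chunked))
```

Because only the queries are split, every softmax row still sees all key positions, which is why no renormalization across chunks is needed in this variant.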

**5.3 Hyperparameter Tuning and Stopping Criterion**

The performance of EOP-PRL, like other reinforcement learning algorithms, can be sensitive to the choice of hyperparameters. Key hyperparameters such as the base reward value ($\alpha$), the base penalty value ($\beta$), the maximum penalty scaling factor ($\gamma$), and the learning rate will likely require careful tuning based on the specific reasoning task and the language model being used. Additionally, a suitable stopping criterion for the training process needs to be defined, which could be based on the performance on a validation set or reaching a certain number of training episodes.

**6. Discussion**

**6.1 Advantages over Traditional RL Approaches**

The mathematical formulation of EOP-PRL offers several advantages over traditional RL approaches for language model fine-tuning:

* **Exploration-friendly reward structure:** By removing penalties for token-level mismatches during generation, EOP-PRL allows models to explore diverse reasoning paths without immediate punishment, as long as the sequences eventually reach the desired length. This is particularly beneficial for tasks where multiple valid reasoning paths exist.
* **Completion-focused learning:** The end-only penalty creates a mathematically precise incentive for models to complete their reasoning rather than stopping prematurely. The magnitude of the penalty scales with the degree of incompleteness, encouraging the model to generate longer and more complete sequences.
* **Automatic curriculum:** The progressive penalty scaling creates a mathematical curriculum where the standard for sequence completeness gradually increases as training progresses, without requiring manual curriculum design or pre-defined difficulty levels. The model starts by learning to generate any reasonable sequence and is progressively pushed to generate complete ones.
* **Memory efficiency:** The implementation considerations, such as gradient accumulation and chunked attention, allow for training with limited computational resources while maintaining mathematical rigor and enabling the application of EOP-PRL to large language models and long reasoning sequences.

**6.2 Limitations and Future Work**

While mathematically elegant, EOP-PRL has several limitations that warrant further investigation:

* **Reference length dependency:** The end-only penalty depends on the reference sequence length $|y^*|$, which might not always be the optimal target length for all prompts. For some prompts, a shorter or longer correct reasoning path might exist. Future work could explore methods for dynamically determining or adapting the target sequence length.
* **Linear scaling assumption:** The penalty scaling function $S(e, E)$ assumes a linear increase in standards is optimal. Non-linear scaling functions or adaptive scaling strategies based on the model's performance could potentially lead to faster or more stable training.
* **Gradient sparsity:** The zero rewards for non-matching tokens during generation can lead to sparse gradients, especially in the early stages of training before the model starts generating sequences close to the reference length. This sparsity could potentially slow down the learning process in some cases. Future research could explore techniques to mitigate gradient sparsity, such as incorporating small shaping rewards for certain intermediate steps without overly restricting exploration.
* **Task Specificity:** The current formulation is primarily designed for tasks where the goal is to generate a complete reasoning path with a defined length. Its applicability to tasks with more open-ended or variable-length outputs needs further investigation.

Future work should focus on:

* Empirical evaluation of EOP-PRL on a variety of structured reasoning tasks and comparison with existing state-of-the-art RL fine-tuning methods.
* Investigating the impact of different forms of the penalty scaling function (e.g., non-linear, adaptive).
* Exploring techniques to address the potential issue of gradient sparsity.
* Extending the EOP-PRL framework to handle reasoning tasks with variable or unknown optimal output lengths.
* Analyzing the diversity of reasoning paths learned by models trained with EOP-PRL.

**7. Conclusion**

This paper has presented a comprehensive mathematical analysis of End-Only Penalty Progressive Reinforcement Learning (EOP-PRL), demonstrating its theoretical foundations and advantages for training language models on reasoning tasks. By reformulating the reward structure to penalize only incomplete sequences rather than token-level mismatches, EOP-PRL creates a mathematically sound framework that encourages exploration of diverse reasoning paths while ensuring complete responses.

The mathematical analysis reveals that EOP-PRL effectively addresses the exploration-exploitation trade-off inherent in language model fine-tuning, implements a natural curriculum through progressive penalty scaling, and converges (under standard assumptions) to policies that prioritize the generation of complete reasoning paths while allowing for greater diversity in token selection. These properties make EOP-PRL particularly well-suited for complex reasoning tasks where multiple valid reasoning paths may lead to correct conclusions.

Future work will focus on extending the mathematical framework to address the identified limitations and exploring applications to a broader range of language model training scenarios.

**References**

[1] Brown, T. B., et al. (2020). Language models are few-shot learners. *Advances in Neural Information Processing Systems*, *33*, 1877-1901.
[2] Stiennon, N., et al. (2020). Learning to summarize from human feedback. *Advances in Neural Information Processing Systems*, *33*, 3008-3021.
[3] Ouyang, L., et al. (2022). Training language models to follow instructions with human feedback. *Advances in Neural Information Processing Systems*, *35*.
[4] Sutton, R. S., et al. (2000). Policy gradient methods for reinforcement learning with function approximation. *Advances in Neural Information Processing Systems*, *12*.
[5] Williams, R. J. (1992). Simple statistical gradient-following algorithms for connectionist reinforcement learning. *Machine Learning*, *8*(3-4), 229-256.
[6] Christiano, P. F., et al. (2017). Deep reinforcement learning from human preferences. *Advances in Neural Information Processing Systems*, *30*.
[7] Schulman, J., et al. (2017). Proximal policy optimization algorithms. *arXiv preprint arXiv:1707.06347*.
[8] Jaques, N., et al. (2017). Sequence tutor: Conservative fine-tuning of sequence generation models with KL-control. *International Conference on Machine Learning*.
[9] Bengio, Y., et al. (2009). Curriculum learning. *International Conference on Machine Learning*.
[10] Graves, A., et al. (2017). Automated curriculum learning for neural networks. *International Conference on Machine Learning*.
[11] Ng, A. Y., et al. (1999). Policy invariance under reward transformations: Theory and application to reward shaping. *International Conference on Machine Learning*.
[12] Sukhbaatar, S., et al. (2018). Intrinsic motivation and automatic curricula via asymmetric self-play. *International Conference on Learning Representations*.
[13] Robbins, H., & Monro, S. (1951). A stochastic approximation method. *The Annals of Mathematical Statistics*, *22*(3), 400-407.

In [None]:
## eop pRL

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import pandas as pd
from sklearn.model_selection import train_test_split
import random
from torch.amp import autocast, GradScaler  # For mixed precision
import gc  # For garbage collection
import matplotlib.pyplot as plt

# ==========================================
# 1) Hyperparameters
# ==========================================
hyperparams = {
    # Model Architecture
    'block_size': 1024,               # Sequence length for context
    'batch_size': 1,                  # Batch size (reduced from 2 to save memory)
    'embed_dim': 1024,                # Transformer embedding dimension
    'n_heads': 16,                    # Number of attention heads
    'n_layers': 24,                   # Number of Transformer blocks
    'memory_n_layers': 8,             # Number of layers in the original MemoryModule
    'vocab_size': 256,                # Fixed vocabulary size for byte tokenization

    # Memory Efficiency Settings
    'use_gradient_checkpointing': True,  # Use gradient checkpointing
    'gradient_accumulation_steps': 4,    # Accumulate gradients for X steps
    'chunk_size': 64,                    # Size of chunks for attention calculation
    'use_dynamic_quantization': True,    # Use dynamic quantization
    'limit_attention_memory': True,      # Use memory-efficient attention implementation

    # RL Training Parameters
    'n_prompt_ans_pairs': 5,          # Number of prompt-answer pairs to use for RL training
    'number_of_practice': 100,        # Number of practice episodes for RL training
    'rl_log_interval': 5,             # Log metrics every X episodes during RL training
    'rl_save_interval': 20,           # Save checkpoint every X episodes during RL training
    'base_reward': 1.0,               # Base reward value for correct predictions
    'base_penalty': -0.5,             # Base penalty value (end-only penalty for incomplete generations)
    'rl_learning_rate': 1e-6,         # Learning rate for RL fine-tuning
    'max_penalty_scale': 2.5,         # Maximum penalty scaling factor for episode progression

    # Mixed Precision Parameters
    'use_mixed_precision': True,      # Whether to use mixed precision training
    'grad_scale_init': 65536.0,       # Initial scale for gradient scaler
    'scale_growth_interval': 2000,    # Steps between gradient scaler growth

    # Generation Parameters
    'generate_num_tokens': 2048,      # Number of tokens to generate after each epoch
    'top_p': 0.8,                     # Top-p (nucleus) sampling parameter
    'start_prompt': "Explain why the statement 'I wore my lucky socks today, and I got an A on my test, so my socks must be lucky' is a logical fallacy.",

    # Special Tokens & Tags
    'thinking_tag': "<think>",        # Opening tag for thinking process
    'thinking_end_tag': "</think>",   # Closing tag for thinking process
    'answer_tag': "<answer>",         # Opening tag for final answer
    'answer_end_tag': "</answer>",    # Closing tag for final answer
    'bos_token': 254,                 # Beginning-of-sequence token (byte value)
    'eos_token': 255,                 # End-of-sequence token (byte value)

    # File Paths
    'pretrained_model_path': "threshold_transformer_checkpoint.pt",  # Path to load pretrained model
    'rl_checkpoint_path': "rl_transformer_checkpoint.pt",        # RL checkpoint path

    # System Prompt
    'system_prompt': """just think before answer."""
}
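
# ------------------------------------------------------------------
# Illustrative sketch (not part of the original pipeline): how the
# byte-level tokens implied by the config above could be built.
# Tokens are raw UTF-8 bytes; 254 and 255 can act as BOS/EOS markers
# because the bytes 0xFE and 0xFF never occur in valid UTF-8 text.
# The helper names below are hypothetical and nothing else calls them.
# ------------------------------------------------------------------
def encode_prompt_sketch(text, bos_token=254):
    """Return the UTF-8 bytes of `text` as ints, followed by the BOS marker."""
    return list(text.encode('utf-8')) + [bos_token]

def decode_tokens_sketch(tokens, special_tokens=(254, 255)):
    """Drop the special markers and decode the remaining bytes as UTF-8."""
    payload = bytes(t for t in tokens if t not in special_tokens)
    return payload.decode('utf-8', errors='replace')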

# ==========================================
# 1.1) Select device and optimize settings
# ==========================================
device = "mps" if torch.backends.mps.is_available() else \
         ("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Enable TF32 matmuls (Ampere or newer GPUs) for faster float32 math
if device == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    print("TF32 enabled for better performance")

    # Set up GPU for maximum memory efficiency
    torch.cuda.empty_cache()
    print(f"CUDA memory allocated before starting: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    print(f"CUDA memory reserved before starting: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

# ==========================================
# 1.2) Memory Management Functions
# ==========================================
def clear_memory():
    """Run garbage collection, then release cached CUDA memory."""
    # Collect first so tensors freed by the GC can be returned by empty_cache()
    gc.collect()
    if device == "cuda":
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

def print_memory_stats():
    """Print current memory usage statistics."""
    if device == "cuda":
        print(f"CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
        print(f"CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
        print(f"CUDA max memory allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")

# ==========================================
# 1.3) Data Loading and Preprocessing for COT Logic Reasoning
# ==========================================
def load_cot_logic_data():
    print("Loading COT Logic Reasoning dataset...")

    try:
        # Try standard pandas read_parquet first
        df = pd.read_parquet("isaiahbjork/cot-logic-reasoning/cot-logic-reasoning.parquet")
        print("Dataset loaded using standard path")
    except Exception as e:
        print(f"Error loading dataset with standard path: {e}")
        try:
            # Try with datasets library if available
            try:
                from datasets import load_dataset
                dataset = load_dataset("isaiahbjork/cot-logic-reasoning")
                df = dataset["train"].to_pandas()
                print("Dataset loaded using datasets library")
            except Exception:
                # If the datasets library is missing or fails, fall back to the hf:// protocol
                df = pd.read_parquet("hf://datasets/isaiahbjork/cot-logic-reasoning/cot-logic-reasoning.parquet")
                print("Dataset loaded using hf:// protocol")
        except Exception as e2:
            print(f"Failed to load dataset: {e2}")
            raise RuntimeError("Unable to load the COT Logic Reasoning dataset") from e2

    print(f"Data size: {len(df)}")

    # Split into train/validation/test sets (80/10/10)
    train_df, temp_df = train_test_split(df, test_size=0.2, random_state=42)
    val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

    print(f"Training examples: {len(train_df)}")
    print(f"Validation examples: {len(val_df)}")
    print(f"Test examples: {len(test_df)}")

    return train_df, val_df, test_df

# ==========================================
# 2) Improved Emergent Threshold Layer with Numerical Stability
# ==========================================
class ImprovedEmergentThresholdLayer(nn.Module):
    def __init__(self, feature_dim):
        super().__init__()
        self.feature_dim = feature_dim
        self.norm = nn.LayerNorm(feature_dim)
        self.register_buffer('running_mean', torch.zeros(feature_dim))
        self.register_buffer('running_var', torch.ones(feature_dim))
        self.adaptive_threshold = nn.Parameter(torch.ones(1) * 0.5)
        self.momentum = 0.01

    def forward(self, x):
        x_norm = self.norm(x)
        if self.training:
            with torch.no_grad():
                batch_mean = x_norm.mean(dim=(0, 1))
                batch_var = x_norm.var(dim=(0, 1), unbiased=False)
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var

        # More robust threshold calculation with clamping to prevent extremely small values
        threshold = torch.sigmoid(self.adaptive_threshold) * torch.sqrt(torch.clamp(self.running_var, min=1e-6))

        # Gate temperature of 1.0 (softer than a temperature of 0.1) for stability
        gate = torch.sigmoid((torch.abs(x_norm) - threshold.view(1, 1, -1)) / 1.0)

        alpha = torch.sigmoid(self.adaptive_threshold)

        # Clip outputs to prevent extreme values
        return torch.clamp(alpha * (gate * x) + (1 - alpha) * x, min=-100, max=100)

# ==========================================
# 3) Memory-Efficient Attention Mechanism
# ==========================================
class MemoryEfficientAttention(nn.Module):
    def __init__(self, embed_dim, n_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads

        # Standard attention projections
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # Attention score normalization
        self.attn_scale = nn.Parameter(torch.ones(1) * (1.0 / math.sqrt(self.head_dim)))

        # Threshold parameters for attention scores
        self.register_buffer('score_running_mean', torch.zeros(n_heads))
        self.register_buffer('score_running_var', torch.ones(n_heads))
        self.score_threshold = nn.Parameter(torch.ones(1) * 0.5)
        self.score_momentum = 0.01

    def forward(self, x, attn_mask=None):
        B, T, C = x.size()

        # Project to queries, keys, values
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # B, H, T, D
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # B, H, T, D
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # B, H, T, D

        # Chunked attention: scores are built one query chunk at a time, so only a
        # (chunk_size x T) score matrix per head is materialized at any step
        chunk_size = hyperparams['chunk_size']
        attn_output = torch.zeros_like(q)

        for i in range(0, T, chunk_size):
            i_end = min(i + chunk_size, T)

            # Get current query chunk
            q_chunk = q[:, :, i:i_end]

            # Compute scores for this chunk against all keys, in smaller sub-chunks
            scores_for_chunk = []
            for j in range(0, T, chunk_size):
                j_end = min(j + chunk_size, T)

                # Get key chunk and compute scores
                k_chunk = k[:, :, j:j_end]
                chunk_scores = torch.matmul(q_chunk, k_chunk.transpose(-2, -1)) * self.attn_scale

                # Apply causal mask if requested (attn_mask acts as an on/off flag):
                # hide every key whose absolute position lies after the query position
                if attn_mask is not None:
                    # Generate mask just for this chunk pair; True marks masked entries.
                    # Key j+col is in the future of query i+row when col - row >= i - j + 1.
                    mask_size = (i_end - i, j_end - j)
                    chunk_mask = torch.triu(torch.ones(mask_size, device=x.device), diagonal=i - j + 1).bool()
                    chunk_mask = chunk_mask.unsqueeze(0).unsqueeze(0).expand(B, self.n_heads, -1, -1)
                    chunk_scores = chunk_scores.masked_fill(chunk_mask, float('-inf'))

                scores_for_chunk.append(chunk_scores)

            # Concatenate all key chunks for this query chunk
            all_scores_for_chunk = torch.cat(scores_for_chunk, dim=-1)

            # Apply softmax across the full key dimension
            attn_weights = F.softmax(all_scores_for_chunk, dim=-1)

            # Multiply with values in chunks
            chunk_output = torch.zeros_like(q_chunk)
            start_idx = 0
            for j in range(0, T, chunk_size):
                j_end = min(j + chunk_size, T)

                # Get weights for this chunk and the corresponding values
                weights_chunk = attn_weights[:, :, :, start_idx:start_idx + (j_end - j)]
                v_chunk = v[:, :, j:j_end]

                # Accumulate the output for this chunk
                chunk_output += torch.matmul(weights_chunk, v_chunk)
                start_idx += (j_end - j)

            # Place the output for this query chunk in the right position
            attn_output[:, :, i:i_end] = chunk_output

        # Reshape output back to original dimensions
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)

        return self.out_proj(attn_output)

    # Method to handle compatibility with original MultiheadAttention
    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Map nn.MultiheadAttention parameters onto this module's separate q/k/v projections
        if f"{prefix}in_proj_weight" in state_dict:
            # MultiheadAttention uses a single in_proj_weight that combines q,k,v
            in_proj_weight = state_dict.pop(f"{prefix}in_proj_weight")
            in_proj_bias = state_dict.pop(f"{prefix}in_proj_bias", None)

            # Split the in_proj_weight into q, k, v parts
            q_weight, k_weight, v_weight = in_proj_weight.chunk(3, dim=0)
            state_dict[f"{prefix}q_proj.weight"] = q_weight
            state_dict[f"{prefix}k_proj.weight"] = k_weight
            state_dict[f"{prefix}v_proj.weight"] = v_weight

            if in_proj_bias is not None:
                q_bias, k_bias, v_bias = in_proj_bias.chunk(3, dim=0)
                state_dict[f"{prefix}q_proj.bias"] = q_bias
                state_dict[f"{prefix}k_proj.bias"] = k_bias
                state_dict[f"{prefix}v_proj.bias"] = v_bias

        # out_proj.weight and out_proj.bias carry the same names in nn.MultiheadAttention,
        # so they need no remapping

        # Call parent class method to handle the rest
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
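
# ------------------------------------------------------------------
# Illustrative sketch (pure Python, not used by the model): a reference
# definition of the causal mask for one (query chunk, key chunk) pair,
# useful for checking a chunked implementation by hand. A key at
# absolute position j + col must be hidden from a query at i + row
# whenever j + col > i + row, i.e. whenever col - row >= i - j + 1.
# The function name is hypothetical.
# ------------------------------------------------------------------
def causal_chunk_mask_sketch(i, i_end, j, j_end):
    """Return mask[row][col] == True where key j+col lies after query i+row."""
    return [[(j + col) > (i + row) for col in range(j_end - j)]
            for row in range(i_end - i)]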

# ==========================================
# 4) Improved Transformer Block with Memory Efficiency
# ==========================================
class ImprovedTransformerBlock(nn.Module):
    def __init__(self, embed_dim, n_heads):
        super().__init__()
        self.attention = MemoryEfficientAttention(embed_dim, n_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            ImprovedEmergentThresholdLayer(4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim)
        )
        self.threshold1 = ImprovedEmergentThresholdLayer(embed_dim)
        self.threshold2 = ImprovedEmergentThresholdLayer(embed_dim)

    def forward(self, x):
        # Use sequential processing to reduce memory usage
        attn_out = self.attention(x)
        x = x + self.threshold1(attn_out)

        # Explicitly delete to free memory
        del attn_out

        ff_out = self.feed_forward(x)
        x = x + self.threshold2(ff_out)

        # Explicitly delete to free memory
        del ff_out

        return x

# ==========================================
# 5) Improved Byte Transformer with Gradient Checkpointing
# ==========================================
class ImprovedByteTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, n_heads=4, n_layers=4, block_size=128):
        super().__init__()
        self.block_size = block_size
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embedding = nn.Embedding(self.block_size, embed_dim)
        self.blocks = nn.ModuleList([
            ImprovedTransformerBlock(embed_dim, n_heads)
            for _ in range(n_layers)
        ])
        self.final_threshold = ImprovedEmergentThresholdLayer(embed_dim)
        self.ln_f = nn.Linear(embed_dim, vocab_size)  # Output projection to vocabulary logits (not a LayerNorm, despite the name)
        # Learned gating parameter for combining memory outputs
        self.gate_param = nn.Parameter(torch.tensor(0.0))
        self.use_checkpointing = hyperparams['use_gradient_checkpointing']

    def forward_with_embeddings(self, x_emb):
        for i, block in enumerate(self.blocks):
            if self.use_checkpointing and self.training:
                # Ensure tensor requires gradients for checkpointing
                if not x_emb.requires_grad:
                    x_emb.requires_grad = True
                x_emb = torch.utils.checkpoint.checkpoint(block, x_emb, use_reentrant=False)
            else:
                x_emb = block(x_emb)
        x_emb = self.final_threshold(x_emb)
        logits = self.ln_f(x_emb)
        return logits

    def forward_with_two_memory(self, x_emb, memory_module2):
        """
        Extended forward pass with memory modules and gradient checkpointing
        """
        transformer_out = x_emb
        for i, block in enumerate(self.blocks):
            if self.use_checkpointing and self.training:
                # Ensure tensor requires gradients for checkpointing
                if not transformer_out.requires_grad:
                    transformer_out.requires_grad = True
                transformer_out = torch.utils.checkpoint.checkpoint(block, transformer_out, use_reentrant=False)
            else:
                transformer_out = block(transformer_out)

        transformer_out = self.final_threshold(transformer_out)

        if self.use_checkpointing and self.training:
            # Ensure tensor requires gradients for checkpointing
            if not transformer_out.requires_grad:
                transformer_out.requires_grad = True
            mem_out2 = torch.utils.checkpoint.checkpoint(memory_module2, transformer_out, use_reentrant=False)
        else:
            mem_out2 = memory_module2(transformer_out)

        # Gated combination
        alpha = torch.sigmoid(self.gate_param)
        combined = alpha * mem_out2 + (1 - alpha) * x_emb
        final_emb = self.final_threshold(combined)
        logits = self.ln_f(final_emb)
        return logits

    def forward(self, x):
        B, T = x.size()
        token_emb = self.token_embedding(x)
        positions = torch.arange(min(T, self.block_size), device=x.device).unsqueeze(0)
        pos_emb = self.pos_embedding(positions)
        x_emb = token_emb[:, :min(T, self.block_size)] + pos_emb
        return self.forward_with_embeddings(x_emb)

# ==========================================
# 6) Memory Module with Gradient Checkpointing
# ==========================================
class MemoryModule(nn.Module):
    def __init__(self, embed_dim, n_layers=8, expansion_factor=4):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            layer = nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, embed_dim * expansion_factor),
                nn.GELU(),
                nn.Linear(embed_dim * expansion_factor, embed_dim),
                nn.Dropout(0.1)
            )
            self.layers.append(layer)
        self.final_norm = nn.LayerNorm(embed_dim)
        self.use_checkpointing = hyperparams['use_gradient_checkpointing']

    def forward(self, x):
        out = x
        for layer in self.layers:
            if self.use_checkpointing and self.training:
                # Ensure tensor requires gradients for checkpointing
                if not out.requires_grad:
                    out.requires_grad = True
                residual = torch.utils.checkpoint.checkpoint(layer, out, use_reentrant=False)
                out = out + residual
            else:
                out = out + layer(out)
        out = self.final_norm(out)
        return out

# ==========================================
# 8) Progressive Reward RL Training with Gradient Accumulation and End-Only Penalty
# ==========================================
class ProgressiveRewardTrainer:
    def __init__(self, main_model, memory1, memory2,
                 base_reward=1.0, base_penalty=-0.5,
                 learning_rate=5e-6, max_penalty_scale=2.5):
        self.main_model = main_model
        self.memory1 = memory1
        self.memory2 = memory2
        self.base_reward = base_reward
        self.base_penalty = base_penalty
        self.max_penalty_scale = max_penalty_scale
        self.optimizer = torch.optim.Adam(
            list(main_model.parameters()) +
            list(memory1.parameters()) +
            list(memory2.parameters()),
            lr=learning_rate
        )
        # Create gradient scaler for mixed precision training
        self.scaler = GradScaler(
            init_scale=hyperparams['grad_scale_init'],
            growth_interval=hyperparams['scale_growth_interval'],
            enabled=hyperparams['use_mixed_precision']
        )
        # Gradient accumulation steps
        self.gradient_accumulation_steps = hyperparams['gradient_accumulation_steps']
        # Track accumulated batches
        self.accumulated_batches = 0
        # Store metrics for visualization
        self.metrics_history = {
            'episodes': [],
            'penalty_scale': [],
            'policy_loss': [],
            'avg_reward': []
        }

    def calculate_episode_penalty_scaling(self, current_episode, total_episodes):
        """Calculate the penalty scaling factor based on episode progress."""
        # Ensure the scaling starts at 1.0 and linearly increases to max_scale
        scale = 1.0 + (current_episode / max(1, total_episodes - 1)) * (self.max_penalty_scale - 1.0)
        return scale

    def compute_progressive_rewards(self, generated_tokens, reference_tokens, penalty_scale=1.0):
        """
        Compute rewards that only penalize for incomplete completions at the end.
        Matching tokens still receive rewards, but non-matching tokens don't get penalties.
        """
        rewards = []
        gen_len = len(generated_tokens)
        ref_len = len(reference_tokens)
        compare_len = min(gen_len, ref_len)

        # For tokens that exist in both sequences, only give rewards for matches
        for i in range(compare_len):
            # Position-based scaling factor (increases from 0.1 to 1.0)
            position_scale = 0.1 + 0.9 * (i / max(gen_len, 1))

            # Only rewards for matching tokens, no penalties for mismatches
            if generated_tokens[i] == reference_tokens[i]:
                reward = self.base_reward * position_scale
            else:
                # No penalty for incorrect tokens during generation
                reward = 0.0

            rewards.append(reward)

        # Only penalize if the generation is incomplete (shorter than reference)
        if gen_len < ref_len:
            # Calculate how many tokens are missing
            missing_tokens = ref_len - gen_len

            # Calculate the severity of incompleteness (higher if more is missing)
            incompleteness_ratio = missing_tokens / ref_len

            # Add a single penalty at the end for the incomplete generation
            # Scale by both episode progress and degree of incompleteness
            end_penalty = self.base_penalty * penalty_scale * incompleteness_ratio

            # Add to the last token's reward (or append if empty)
            if rewards:
                rewards[-1] += end_penalty
            else:
                rewards.append(end_penalty)

        return torch.tensor(rewards, dtype=torch.float32, device=device)

    def train_step(self, prompt, reference_answer, current_episode=0, total_episodes=1):
        """Execute one REINFORCE training step with progressive rewards and gradient accumulation."""
        # Calculate penalty scaling based on episode progress
        penalty_scale = self.calculate_episode_penalty_scaling(current_episode, total_episodes)

        # Ensure models are in training mode
        self.main_model.train()
        self.memory1.train()
        self.memory2.train()

        # Only zero gradients at the start of accumulation
        if self.accumulated_batches == 0:
            self.optimizer.zero_grad()

        # Prepare input context
        system_prompt = hyperparams['system_prompt']
        full_prompt = f"{system_prompt}\n\nQuestion: {prompt}"
        prompt_bytes = full_prompt.encode('utf-8')

        # Add BOS token to start generation
        prompt_tokens = list(prompt_bytes) + [hyperparams['bos_token']]
        context = torch.tensor([prompt_tokens], dtype=torch.long, device=device)

        # Reference answer tokens
        ref_tokens = list(reference_answer.encode('utf-8'))

        # Storage for generation
        log_probs = []
        generated_tokens = []

        # Auto-regressive generation with gradient tracking and mixed precision
        # Cap generation at twice the reference length (at most 4096 tokens) to bound memory
        max_tokens = min(4096, len(ref_tokens) * 2)

        # Enable mixed precision for forward passes
        with autocast(device_type='cuda' if device == 'cuda' else 'cpu', enabled=hyperparams['use_mixed_precision']):
            for step in range(max_tokens):
                # Clear CUDA cache periodically during generation
                if step % 50 == 0 and device == "cuda":
                    torch.cuda.empty_cache()

                # Get context within block size limit
                x_cond = context[:, -hyperparams['block_size']:] if context.size(1) > hyperparams['block_size'] else context

                # Get embeddings
                B, T = x_cond.shape
                token_emb = self.main_model.token_embedding(x_cond)

                # Handle the case where T > block_size
                effective_T = min(T, self.main_model.block_size)
                pos_indices = torch.arange(effective_T, device=x_cond.device).unsqueeze(0)
                pos_emb = self.main_model.pos_embedding(pos_indices)

                combined_emb = token_emb[:, :effective_T] + pos_emb

                # Forward pass through model
                mem_out1 = self.memory1(combined_emb)
                logits = self.main_model.forward_with_two_memory(mem_out1, self.memory2)

                # Get probabilities for next token
                next_token_logits = logits[:, -1, :]
                probs = F.softmax(next_token_logits, dim=-1)
                log_prob_dist = F.log_softmax(next_token_logits, dim=-1)

                # Sample token
                next_token = torch.multinomial(probs, num_samples=1)
                token_value = next_token.item()

                # Record log probability for policy gradient
                token_log_prob = log_prob_dist.gather(1, next_token).squeeze()
                log_probs.append(token_log_prob)

                # Add token to generated sequence
                generated_tokens.append(token_value)
                context = torch.cat([context, next_token], dim=1)

                # Stop conditions
                if token_value == hyperparams['eos_token']:
                    break

                # Check for answer end tag
                try:
                    last_tokens = [t for t in context[0, -30:].tolist() if t != 0]
                    recent_text = bytes(last_tokens).decode('utf-8', errors='replace')

                    if hyperparams['answer_end_tag'] in recent_text:
                        full_text = bytes([t for t in context[0].tolist() if t != 0]).decode('utf-8', errors='replace')
                        if (hyperparams['thinking_end_tag'] in full_text and
                            hyperparams['answer_end_tag'] in full_text):
                            break
                except ValueError:
                    # Token ids outside 0-255 cannot be converted to bytes; skip the tag check
                    pass

        # Calculate progressive rewards with penalty scaling
        rewards = self.compute_progressive_rewards(generated_tokens, ref_tokens, penalty_scale)

        # Match rewards to log_probs length
        if len(rewards) > len(log_probs):
            rewards = rewards[:len(log_probs)]
        elif len(log_probs) > len(rewards):
            log_probs = log_probs[:len(rewards)]

        # REINFORCE policy gradient loss with mixed precision handling
        loss_metrics = {"policy_loss": 0.0, "avg_reward": 0.0}

        if len(log_probs) > 0 and len(rewards) > 0:
            # Use full precision for loss calculation
            policy_loss = -torch.sum(torch.stack(log_probs) * rewards)

            # Scale loss for gradient accumulation
            policy_loss = policy_loss / self.gradient_accumulation_steps

            # Use scaler for mixed precision backpropagation
            self.scaler.scale(policy_loss).backward()

            # Record metrics
            loss_metrics["policy_loss"] = policy_loss.item() * self.gradient_accumulation_steps  # Unscale for reporting
            loss_metrics["avg_reward"] = rewards.mean().item()

            # Increment accumulated batches
            self.accumulated_batches += 1

            # Update parameters if we've accumulated enough gradients
            if self.accumulated_batches >= self.gradient_accumulation_steps:
                # Gradient clipping
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.main_model.parameters(), 1.0)
                torch.nn.utils.clip_grad_norm_(self.memory1.parameters(), 1.0)
                torch.nn.utils.clip_grad_norm_(self.memory2.parameters(), 1.0)

                # Update with scaler
                self.scaler.step(self.optimizer)
                self.scaler.update()

                # Reset accumulation counter
                self.accumulated_batches = 0

        # Clear memory
        if device == "cuda":
            torch.cuda.empty_cache()

        return {
            'policy_loss': loss_metrics["policy_loss"],
            'avg_reward': loss_metrics["avg_reward"],
            'generated_length': len(generated_tokens),
            'reference_length': len(ref_tokens),
            'scaler_scale': self.scaler.get_scale(),
            'optimizer_step_taken': len(log_probs) > 0 and len(rewards) > 0 and self.accumulated_batches == 0,  # True only if a loss was computed and the optimizer just stepped
            'penalty_scale': penalty_scale  # Track the penalty scaling for monitoring
        }

    def train(self, train_df, num_prompt_pairs=10, num_episodes=100, log_interval=5, save_interval=20):
        """Run full RL training procedure."""
        # Sample a fixed set of prompt-answer pairs for training
        if len(train_df) < num_prompt_pairs:
            print(f"Warning: Requested {num_prompt_pairs} pairs but dataset only has {len(train_df)} examples")
            selected_indices = list(range(len(train_df)))
        else:
            selected_indices = random.sample(range(len(train_df)), num_prompt_pairs)

        selected_pairs = train_df.iloc[selected_indices]
        print(f"Selected {len(selected_indices)} prompt-answer pairs for RL training")

        for episode in range(num_episodes):
            # Sample random prompt-answer pair from our selected pairs
            idx = random.randint(0, len(selected_pairs) - 1)
            prompt = selected_pairs.iloc[idx]['prompt']
            reference = selected_pairs.iloc[idx]['response']

            # Print memory stats before training step
            if (episode + 1) % log_interval == 0 and device == "cuda":
                print_memory_stats()

            # Execute training step with episode information
            metrics = self.train_step(prompt, reference,
                                     current_episode=episode,
                                     total_episodes=num_episodes)

            # Store metrics for visualization
            self.metrics_history['episodes'].append(episode + 1)
            self.metrics_history['penalty_scale'].append(metrics['penalty_scale'])
            self.metrics_history['policy_loss'].append(metrics['policy_loss'])
            self.metrics_history['avg_reward'].append(metrics['avg_reward'])

            # Logging
            if (episode + 1) % log_interval == 0:
                print(f"Episode {episode+1}/{num_episodes}, Metrics: {metrics}")

            # Save checkpoint
            if (episode + 1) % save_interval == 0:
                save_path = hyperparams['rl_checkpoint_path'].replace('.pt', f'_ep{episode+1}.pt')
                torch.save({
                    'main_model_state': self.main_model.state_dict(),
                    'memory1_state': self.memory1.state_dict(),
                    'memory2_state': self.memory2.state_dict(),
                    'episode': episode + 1,
                    'scaler': self.scaler.state_dict(),  # Save scaler state
                    'metrics_history': self.metrics_history  # Save metrics for visualization
                }, save_path)
                print(f"Checkpoint saved to {save_path}")

                # Visualize penalty effect after saving checkpoint
                if episode + 1 >= log_interval:
                    self.visualize_penalty_effect()

                # Force cleanup after checkpoint
                clear_memory()

    def visualize_penalty_effect(self):
        """Generate plots to visualize the effect of the progressive penalty scaling."""
        # Only create visualization if we have enough data points
        if len(self.metrics_history['episodes']) < 2:
            return

        # Create figure with multiple subplots
        fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 15), sharex=True)

        # Plot 1: Penalty Scale over episodes
        ax1.plot(self.metrics_history['episodes'], self.metrics_history['penalty_scale'],
                marker='o', linestyle='-', color='red')
        ax1.set_ylabel('Penalty Scale')
        ax1.set_title('Progressive Penalty Scaling over Episodes')
        ax1.grid(True)

        # Plot 2: Policy Loss over episodes
        ax2.plot(self.metrics_history['episodes'], self.metrics_history['policy_loss'],
                marker='x', linestyle='-', color='blue')
        ax2.set_ylabel('Policy Loss')
        ax2.set_title('Policy Loss over Episodes')
        ax2.grid(True)

        # Plot 3: Average Reward over episodes
        ax3.plot(self.metrics_history['episodes'], self.metrics_history['avg_reward'],
                marker='s', linestyle='-', color='green')
        ax3.set_xlabel('Episode')
        ax3.set_ylabel('Average Reward')
        ax3.set_title('Average Reward over Episodes')
        ax3.grid(True)

        # Adjust layout and save figure
        plt.tight_layout()
        plt.savefig(f"penalty_effect_visualization_ep{max(self.metrics_history['episodes'])}.png")
        plt.close(fig)
        print(f"Visualization saved to penalty_effect_visualization_ep{max(self.metrics_history['episodes'])}.png")
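
# ------------------------------------------------------------------
# Illustrative sketch (pure Python, not called by the trainer): a
# worked example of the end-only penalty rewards and the linear penalty
# schedule, mirroring the arithmetic of ProgressiveRewardTrainer.
# The function names are hypothetical; the defaults copy the
# hyperparameters used in this script.
# ------------------------------------------------------------------
def penalty_scale_sketch(episode, total_episodes, max_scale=2.5):
    """Linear ramp from 1.0 at the first episode to max_scale at the last."""
    return 1.0 + (episode / max(1, total_episodes - 1)) * (max_scale - 1.0)

def eop_rewards_sketch(generated, reference, base_reward=1.0,
                       base_penalty=-0.5, penalty_scale=1.0):
    """Per-token rewards: matches earn a position-scaled reward, mismatches earn 0,
    and a single penalty lands on the last token if the generation is too short."""
    gen_len, ref_len = len(generated), len(reference)
    rewards = []
    for i in range(min(gen_len, ref_len)):
        position_scale = 0.1 + 0.9 * (i / max(gen_len, 1))
        rewards.append(base_reward * position_scale if generated[i] == reference[i] else 0.0)
    if gen_len < ref_len:
        # One penalty, proportional to the fraction of the reference left ungenerated
        end_penalty = base_penalty * penalty_scale * (ref_len - gen_len) / ref_len
        if rewards:
            rewards[-1] += end_penalty
        else:
            rewards.append(end_penalty)
    return rewards

# Example: eop_rewards_sketch([1, 2], [1, 9, 3, 4]) rewards the first (matching)
# token, gives 0 for the mismatch, then adds a penalty for the two missing tokens.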

# ==========================================
# 7) Generate Text from Trained Model with Dynamic Quantization
# ==========================================
@torch.no_grad()
def generate_from_prompt(main_model, memory1, memory2, prompt_text=None, max_new_tokens=200, top_p=None):
    if prompt_text is None:
        prompt_text = hyperparams['start_prompt']

    # Use hyperparameter value if top_p not specified
    if top_p is None:
        top_p = hyperparams['top_p']

    # Apply system prompt to user prompt
    system_prompt = hyperparams['system_prompt']
    full_prompt = f"{system_prompt}\n\nQuestion: {prompt_text}"

    # Convert prompt to bytes (full_prompt is always a str because it is built with an f-string)
    prompt_bytes = full_prompt.encode('utf-8')

    # Dynamic quantization only runs on CPU backends, so skip it on CUDA and MPS
    if hyperparams['use_dynamic_quantization'] and device == "cpu":
        print("Quantizing models for inference...")
        # Quantize main model
        quantized_main_model = torch.quantization.quantize_dynamic(
            main_model,
            {nn.Linear},
            dtype=torch.qint8
        )
        # Quantize memory modules
        quantized_memory1 = torch.quantization.quantize_dynamic(
            memory1,
            {nn.Linear},
            dtype=torch.qint8
        )
        quantized_memory2 = torch.quantization.quantize_dynamic(
            memory2,
            {nn.Linear},
            dtype=torch.qint8
        )
        # Use quantized models
        use_main_model = quantized_main_model
        use_memory1 = quantized_memory1
        use_memory2 = quantized_memory2
        print("Models quantized for inference")
    else:
        # Dynamic quantization is CPU-only, so fall back to the original models elsewhere
        if hyperparams['use_dynamic_quantization'] and device != "cpu":
            print("Dynamic quantization is only supported on CPU, using original models")
        # Use original models
        use_main_model = main_model
        use_memory1 = memory1
        use_memory2 = memory2

    use_main_model.eval()
    use_memory1.eval()
    use_memory2.eval()

    # Create context from prompt
    context = torch.tensor([b for b in prompt_bytes], dtype=torch.long, device=device).unsqueeze(0)

    # Add BOS token to start the response generation
    bos_token = torch.tensor([[hyperparams['bos_token']]], dtype=torch.long, device=device)
    context = torch.cat([context, bos_token], dim=1)

    generated = []
    eos_found = False

    # Generate one token at a time until EOS is sampled or max_new_tokens is reached
    for _ in range(max_new_tokens):
        if eos_found:
            break

        # Only use the last block_size tokens for context to save memory
        x_cond = context[:, -hyperparams['block_size']:] if context.size(1) > hyperparams['block_size'] else context
        B, T = x_cond.shape
        token_emb = use_main_model.token_embedding(x_cond)

        # Handle the case where T > block_size
        effective_T = min(T, use_main_model.block_size)
        pos_indices = torch.arange(effective_T, device=x_cond.device).unsqueeze(0)
        pos_emb = use_main_model.pos_embedding(pos_indices)

        combined_emb = token_emb[:, :effective_T] + pos_emb

        # Forward pass with memory modules
        mem_out1 = use_memory1(combined_emb)
        logits = use_main_model.forward_with_two_memory(mem_out1, use_memory2)

        # Get next token distribution with top-p (nucleus) sampling
        logits = logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)

        # Sort probabilities in descending order
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)

        # Compute cumulative probabilities
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # Find indices where cumulative probability exceeds top_p
        sorted_indices_to_remove = cumulative_probs > top_p

        # Shift the mask right by one so the token that crosses top_p is kept
        # and at least one token always survives
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Create mask for indices to remove
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

        # Filter logits
        filtered_logits = logits.clone()
        filtered_logits[indices_to_remove] = -float('inf')

        # Get probabilities from filtered logits
        filtered_probs = F.softmax(filtered_logits, dim=-1)

        # Sample from the filtered distribution
        next_token = torch.multinomial(filtered_probs, num_samples=1)
        next_token_value = next_token.item()

        # Check for EOS token
        if next_token_value == hyperparams['eos_token']:
            eos_found = True

        generated.append(next_token_value)
        context = torch.cat([context, next_token], dim=1)

        # Free some memory periodically
        if _ % 50 == 0 and device == "cuda":
            torch.cuda.empty_cache()

    # Combine context with generated bytes and return as bytes object
    result_bytes = bytes(context.view(-1).tolist())

    # Clean up special tokens when returning result
    try:
        # Convert to list for easier manipulation
        byte_list = list(result_bytes)

        # Find all BOS tokens and remove them
        while hyperparams['bos_token'] in byte_list:
            byte_list.remove(hyperparams['bos_token'])

        # Find all EOS tokens and remove everything after the first one
        if hyperparams['eos_token'] in byte_list:
            eos_index = byte_list.index(hyperparams['eos_token'])
            byte_list = byte_list[:eos_index]

        # Convert back to bytes
        cleaned_bytes = bytes(byte_list)
        return cleaned_bytes
    except Exception:
        # If any error in cleaning, return the original bytes
        return result_bytes
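
# Illustrative sketch (not part of the training pipeline): the nucleus filter used
# in the loop above, restated in plain Python on a list of probabilities.
# The helper name top_p_keep_indices is hypothetical. It assumes probs sums to
# roughly 1 and mirrors the shifted mask: tokens are kept in descending-probability
# order up to and including the first one that pushes the cumulative mass above top_p.
def top_p_keep_indices(probs, top_p):
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep = []
    cumulative = 0.0
    for i in order:
        keep.append(i)          # kept even if this token crosses the threshold
        cumulative += probs[i]
        if cumulative > top_p:  # every later token would be masked out
            break
    return keep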

# ==========================================
# 9) Main RL Training Function
# ==========================================
def train_with_progressive_rewards():
    """Main function to run RL training with progressive rewards."""
    # Create models with original architecture
    main_model = ImprovedByteTransformer(
        vocab_size=hyperparams['vocab_size'],
        embed_dim=hyperparams['embed_dim'],
        n_heads=hyperparams['n_heads'],
        n_layers=hyperparams['n_layers'],
        block_size=hyperparams['block_size']
    ).to(device)

    memory1 = MemoryModule(
        embed_dim=hyperparams['embed_dim'],
        n_layers=hyperparams['memory_n_layers'],
        expansion_factor=4
    ).to(device)

    memory2 = MemoryModule(
        embed_dim=hyperparams['embed_dim'],
        n_layers=hyperparams['memory_n_layers'],
        expansion_factor=4
    ).to(device)

    # Set all models to training mode explicitly
    main_model.train()
    memory1.train()
    memory2.train()

    # Calculate model size
    num_params = sum(p.numel() for p in main_model.parameters() if p.requires_grad)
    num_params += sum(p.numel() for p in memory1.parameters() if p.requires_grad)
    num_params += sum(p.numel() for p in memory2.parameters() if p.requires_grad)
    print(f"Total trainable parameters: {num_params:,}")

    # Load pretrained model if available
    if os.path.exists(hyperparams['pretrained_model_path']):
        print(f"Loading pretrained model from {hyperparams['pretrained_model_path']}...")
        try:
            checkpoint = torch.load(hyperparams['pretrained_model_path'], map_location=device)
            main_model.load_state_dict(checkpoint['main_model_state'], strict=False)
            memory1.load_state_dict(checkpoint['memory1_state'])
            if 'memory2_state' in checkpoint:
                memory2.load_state_dict(checkpoint['memory2_state'])
            print("Pretrained model loaded successfully.")
        except Exception as e:
            print(f"Error loading pretrained model: {e}")
            print("Starting with randomly initialized weights.")
    else:
        print(f"Warning: No pretrained model found at {hyperparams['pretrained_model_path']}.")
        print("Starting with randomly initialized weights.")

    # Report whether gradient checkpointing was requested (this block only logs the setting)
    if hyperparams['use_gradient_checkpointing']:
        print("Gradient checkpointing enabled for memory efficiency")

    # Load dataset
    train_df, val_df, _ = load_cot_logic_data()

    # Create RL trainer with the max penalty scale
    trainer = ProgressiveRewardTrainer(
        main_model=main_model,
        memory1=memory1,
        memory2=memory2,
        base_reward=hyperparams['base_reward'],
        base_penalty=hyperparams['base_penalty'],
        learning_rate=hyperparams['rl_learning_rate'],
        max_penalty_scale=hyperparams['max_penalty_scale']
    )

    # Run training
    trainer.train(
        train_df=train_df,
        num_prompt_pairs=hyperparams['n_prompt_ans_pairs'],
        num_episodes=hyperparams['number_of_practice'],
        log_interval=hyperparams['rl_log_interval'],
        save_interval=hyperparams['rl_save_interval']
    )

    # Save final model
    torch.save({
        'main_model_state': main_model.state_dict(),
        'memory1_state': memory1.state_dict(),
        'memory2_state': memory2.state_dict(),
        'metrics_history': trainer.metrics_history
    }, "rl_final_model.pt")

    # Generate final visualization
    trainer.visualize_penalty_effect()

    print("RL training complete!")

# ==========================================
# 10) Test RL-trained Model
# ==========================================
def test_rl_model(model_path="rl_final_model.pt"):
    """Test the RL-trained model on a few examples."""
    # Create models
    main_model = ImprovedByteTransformer(
        vocab_size=hyperparams['vocab_size'],
        embed_dim=hyperparams['embed_dim'],
        n_heads=hyperparams['n_heads'],
        n_layers=hyperparams['n_layers'],
        block_size=hyperparams['block_size']
    ).to(device)

    memory1 = MemoryModule(
        embed_dim=hyperparams['embed_dim'],
        n_layers=hyperparams['memory_n_layers'],
        expansion_factor=4
    ).to(device)

    memory2 = MemoryModule(
        embed_dim=hyperparams['embed_dim'],
        n_layers=hyperparams['memory_n_layers'],
        expansion_factor=4
    ).to(device)

    # Load trained model
    if os.path.exists(model_path):
        print(f"Loading RL-trained model from {model_path}...")
        checkpoint = torch.load(model_path, map_location=device)
        main_model.load_state_dict(checkpoint['main_model_state'], strict=False)
        memory1.load_state_dict(checkpoint['memory1_state'])
        memory2.load_state_dict(checkpoint['memory2_state'])

        # If metrics history is available, visualize it
        if 'metrics_history' in checkpoint:
            visualize_saved_metrics(checkpoint['metrics_history'])
    else:
        print(f"Model path {model_path} not found. Exiting test.")
        return

    # Load dataset
    _, _, test_df = load_cot_logic_data()

    # Select up to three examples to test (sample() raises if asked for more rows than exist)
    test_examples = test_df.sample(min(3, len(test_df)))

    for i, (_, example) in enumerate(test_examples.iterrows()):
        prompt = example['prompt']
        reference = example['response']

        print(f"\n--- Test Example {i+1} ---")
        print(f"Prompt: {prompt[:100]}...")

        # Generate response
        generated_bytes = generate_from_prompt(
            main_model, memory1, memory2,
            prompt_text=prompt,
            max_new_tokens=512
        )

        try:
            generated_text = generated_bytes.decode('utf-8', errors='replace')
            print(f"\nGenerated Response: {generated_text[:2000]}...")

            # Find tags in response
            thinking_start = generated_text.find(hyperparams['thinking_tag'])
            thinking_end = generated_text.find(hyperparams['thinking_end_tag'])
            answer_start = generated_text.find(hyperparams['answer_tag'])
            answer_end = generated_text.find(hyperparams['answer_end_tag'])

            if thinking_start >= 0 and thinking_end > thinking_start:
                print("\nThinking Process:")
                print(generated_text[thinking_start:thinking_end + len(hyperparams['thinking_end_tag'])])

            if answer_start >= 0 and answer_end > answer_start:
                print("\nFinal Answer:")
                print(generated_text[answer_start:answer_end + len(hyperparams['answer_end_tag'])])

        except Exception as e:
            print(f"Error decoding response: {e}")
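
# Illustrative sketch: the tag lookup above could be factored into a small helper.
# extract_tagged is a hypothetical name, and the tag strings in the usage note are
# assumptions based on the <think>/<answer> response format; the actual values of
# hyperparams['thinking_tag'] and related keys are defined elsewhere.
def extract_tagged(text, open_tag, close_tag):
    """Return the text between the first open_tag and the next close_tag, or None."""
    start = text.find(open_tag)
    if start < 0:
        return None
    body_start = start + len(open_tag)
    end = text.find(close_tag, body_start)
    if end < 0:
        return None
    return text[body_start:end].strip()

# Usage (assumed tags): extract_tagged(generated_text, "<answer>", "</answer>")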

def visualize_saved_metrics(metrics_history):
    """Visualize metrics from a saved checkpoint."""
    if not metrics_history or len(metrics_history['episodes']) < 2:
        print("Not enough data to visualize metrics")
        return

    # Create figure with multiple subplots
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 15), sharex=True)

    # Plot 1: Penalty Scale over episodes
    ax1.plot(metrics_history['episodes'], metrics_history['penalty_scale'],
            marker='o', linestyle='-', color='red')
    ax1.set_ylabel('Penalty Scale')
    ax1.set_title('Progressive Penalty Scaling over Episodes')
    ax1.grid(True)

    # Plot 2: Policy Loss over episodes
    ax2.plot(metrics_history['episodes'], metrics_history['policy_loss'],
            marker='x', linestyle='-', color='blue')
    ax2.set_ylabel('Policy Loss')
    ax2.set_title('Policy Loss over Episodes')
    ax2.grid(True)

    # Plot 3: Average Reward over episodes
    ax3.plot(metrics_history['episodes'], metrics_history['avg_reward'],
            marker='s', linestyle='-', color='green')
    ax3.set_xlabel('Episode')
    ax3.set_ylabel('Average Reward')
    ax3.set_title('Average Reward over Episodes')
    ax3.grid(True)

    # Adjust layout and save figure
    plt.tight_layout()
    plt.savefig("loaded_model_metrics_visualization.png")
    plt.close(fig)
    print("Visualization of loaded model metrics saved to loaded_model_metrics_visualization.png")

# ==========================================
# 11) Main Entry Point
# ==========================================
if __name__ == "__main__":
    print("Starting Progressive Reward RL Training with Memory Optimizations...")
    print(f"Gradient checkpointing enabled: {hyperparams['use_gradient_checkpointing']}")
    print(f"Gradient accumulation steps: {hyperparams['gradient_accumulation_steps']}")
    print(f"Batch size: {hyperparams['batch_size']}")
    print(f"Using dynamic quantization: {hyperparams['use_dynamic_quantization']}")
    print(f"Maximum penalty scale: {hyperparams['max_penalty_scale']}")

    # Clear CUDA memory before starting
    if device == "cuda":
        torch.cuda.empty_cache()
        print_memory_stats()

    try:
        print("\nRunning RL training...")
        train_with_progressive_rewards()

        print("\nTesting RL-trained model...")
        test_rl_model()

        print("\nTraining and testing complete!")
    except Exception as e:
        print(f"Error during training or testing: {e}")
        import traceback
        traceback.print_exc()

In [None]:
## inference V2

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# ==========================================
# 1) Define Hyperparameters for Inference
# ==========================================
hyperparams = {
    'block_size': 1024,               # Sequence length for context
    'embed_dim': 1024,                # Transformer embedding dimension
    'n_heads': 16,                    # Number of attention heads
    'n_layers': 24,                   # Number of Transformer blocks
    'memory_n_layers': 8,             # Number of layers in the Memory modules
    'vocab_size': 256,                # Fixed vocabulary size for byte tokenization
    'bos_token': 254,                 # Beginning-of-sequence token (byte value)
    'eos_token': 255,                 # End-of-sequence token (byte value)
    'checkpoint_path': "threshold_transformer_checkpoint.pt", # Path to model checkpoint
    'top_p': 0.6,                     # Top-p sampling parameter (0-1)
    'system_prompt': """IMPORTANT: Your response format should have two parts:
    1. First, explain your thinking process in detail between <think> </think> tags.
    2. Then, provide your final answer between <answer> </answer> tags.
    For example: <think> Let me think about this problem carefully...
    [detailed reasoning process] </think> <answer> [concise answer] </answer> """
}

# ==========================================
# 2) Select device for inference
# ==========================================
device = "mps" if torch.backends.mps.is_available() else \
         ("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ==========================================
# 3) NEW: Thresholded Attention Implementation
# ==========================================
class ThresholdedAttention(nn.Module):
    def __init__(self, embed_dim, n_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads

        # Standard attention projections
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # Attention score normalization
        self.attn_scale = nn.Parameter(torch.ones(1) * (1.0 / math.sqrt(self.head_dim)))

        # Threshold parameters for attention scores
        self.register_buffer('score_running_mean', torch.zeros(n_heads))
        self.register_buffer('score_running_var', torch.ones(n_heads))
        self.score_threshold = nn.Parameter(torch.ones(1) * 0.5)
        self.score_momentum = 0.01
        self.temperature = 1.0

    def forward(self, x, attn_mask=None):
        B, T, C = x.size()

        # Project to queries, keys, values
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # B, H, T, D
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # B, H, T, D
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # B, H, T, D

        # Compute scaled attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.attn_scale  # B, H, T, T

        # Apply causal mask if provided
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask, float('-inf'))

        # Update running statistics of the attention scores (training only)
        if self.training:
            with torch.no_grad():
                # Compute per-head statistics over batch and token positions,
                # excluding entries set to -inf by the causal mask
                valid_mask = ~torch.isinf(scores)
                if valid_mask.any():
                    # Zero out masked entries with torch.where; multiplying -inf by 0 would give NaN
                    zeros = torch.zeros_like(scores)
                    valid_count = valid_mask.sum(dim=(0, 2, 3))
                    score_mean = torch.where(valid_mask, scores, zeros).sum(dim=(0, 2, 3)) / valid_count
                    centered = torch.where(valid_mask, scores - score_mean.view(1, -1, 1, 1), zeros)
                    score_var = (centered ** 2).sum(dim=(0, 2, 3)) / valid_count

                    # Update running statistics
                    self.score_running_mean = (1 - self.score_momentum) * self.score_running_mean + self.score_momentum * score_mean
                    self.score_running_var = (1 - self.score_momentum) * self.score_running_var + self.score_momentum * score_var

        # Calculate adaptive threshold for attention scores
        threshold_value = torch.sigmoid(self.score_threshold) * torch.sqrt(torch.clamp(self.score_running_var, min=1e-6))

        # Build a hard mask selecting finite scores that fall below the per-head threshold;
        # entries already set to -inf by the causal mask are left untouched
        mask = (~torch.isinf(scores)) & (scores < threshold_value.view(1, -1, 1, 1))
        # Use a large finite negative instead of -inf so a row whose scores all fall
        # below the threshold still produces a finite softmax
        scores = scores.masked_fill(mask, -1e4)

        # Apply softmax to get attention weights
        attn_weights = F.softmax(scores, dim=-1)

        # Apply attention to values
        attn_output = torch.matmul(attn_weights, v)  # B, H, T, D

        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(attn_output)

    # Method to handle compatibility with original MultiheadAttention
    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Map old MHA parameters to new ThresholdedAttention parameters
        if f"{prefix}in_proj_weight" in state_dict:
            # MultiheadAttention uses a single in_proj_weight that combines q,k,v
            in_proj_weight = state_dict.pop(f"{prefix}in_proj_weight")
            in_proj_bias = state_dict.pop(f"{prefix}in_proj_bias", None)

            # Split the in_proj_weight into q, k, v parts
            q_weight, k_weight, v_weight = in_proj_weight.chunk(3, dim=0)
            state_dict[f"{prefix}q_proj.weight"] = q_weight
            state_dict[f"{prefix}k_proj.weight"] = k_weight
            state_dict[f"{prefix}v_proj.weight"] = v_weight

            if in_proj_bias is not None:
                q_bias, k_bias, v_bias = in_proj_bias.chunk(3, dim=0)
                state_dict[f"{prefix}q_proj.bias"] = q_bias
                state_dict[f"{prefix}k_proj.bias"] = k_bias
                state_dict[f"{prefix}v_proj.bias"] = v_bias

        # out_proj.weight and out_proj.bias use the same key names in both layouts,
        # so they need no remapping

        # Call parent class method to handle the rest
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
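
# Illustrative sketch of the remapping above in plain Python. In its default
# configuration nn.MultiheadAttention stores the q, k, v projections stacked
# row-wise in one matrix of shape (3 * embed_dim, embed_dim); chunk(3, dim=0)
# cuts it into three equal row blocks. split_packed_qkv is a hypothetical helper
# that does the same on a list of rows.
def split_packed_qkv(packed_rows):
    if len(packed_rows) % 3 != 0:
        raise ValueError("packed projection must have 3 * embed_dim rows")
    n = len(packed_rows) // 3
    return packed_rows[:n], packed_rows[n:2 * n], packed_rows[2 * n:]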

# ==========================================
# 4) Model Architecture with Thresholded Attention
# ==========================================
class ImprovedEmergentThresholdLayer(nn.Module):
    def __init__(self, feature_dim):
        super().__init__()
        self.feature_dim = feature_dim
        self.norm = nn.LayerNorm(feature_dim)
        self.register_buffer('running_mean', torch.zeros(feature_dim))
        self.register_buffer('running_var', torch.ones(feature_dim))
        self.adaptive_threshold = nn.Parameter(torch.ones(1) * 0.5)
        self.momentum = 0.01

    def forward(self, x):
        x_norm = self.norm(x)
        if self.training:
            with torch.no_grad():
                batch_mean = x_norm.mean(dim=(0, 1))
                batch_var = x_norm.var(dim=(0, 1), unbiased=False)
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var

        # More robust threshold calculation with clamping to prevent extremely small values
        threshold = torch.sigmoid(self.adaptive_threshold) * torch.sqrt(torch.clamp(self.running_var, min=1e-6))

        # Soft gate on |x_norm| - threshold; the divisor 1.0 acts as a temperature
        # (raised from 0.1 for stability)
        gate = torch.sigmoid((torch.abs(x_norm) - threshold.view(1, 1, -1)) / 1.0)

        alpha = torch.sigmoid(self.adaptive_threshold)

        # Clip outputs to prevent extreme values
        return torch.clamp(alpha * (gate * x) + (1 - alpha) * x, min=-100, max=100)
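
# Illustrative scalar sketch of the gate computed in forward() above, assuming a
# single feature with raw value x, normalized value x_norm, and running variance
# running_var. threshold_gate_scalar is a hypothetical name, not part of the model.
import math

def threshold_gate_scalar(x, x_norm, adaptive_threshold=0.5, running_var=1.0):
    alpha = 1.0 / (1.0 + math.exp(-adaptive_threshold))        # sigmoid(adaptive_threshold)
    threshold = alpha * math.sqrt(max(running_var, 1e-6))      # scaled running std
    gate = 1.0 / (1.0 + math.exp(-(abs(x_norm) - threshold)))  # soft gate in (0, 1)
    out = alpha * (gate * x) + (1.0 - alpha) * x               # blend gated and raw value
    return max(-100.0, min(100.0, out))                        # clamp as in the layer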

class ImprovedTransformerBlock(nn.Module):
    def __init__(self, embed_dim, n_heads):
        super().__init__()
        # Use ThresholdedAttention instead of nn.MultiheadAttention
        self.attention = ThresholdedAttention(embed_dim, n_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            ImprovedEmergentThresholdLayer(4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim)
        )
        self.threshold1 = ImprovedEmergentThresholdLayer(embed_dim)
        self.threshold2 = ImprovedEmergentThresholdLayer(embed_dim)

    def forward(self, x):
        B, T, E = x.size()
        causal_mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        attn_out = self.attention(x, attn_mask=causal_mask)
        x = x + self.threshold1(attn_out)
        ff_out = self.feed_forward(x)
        x = x + self.threshold2(ff_out)
        return x
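
# Illustrative sketch: the boolean mask built by torch.triu(..., diagonal=1) above,
# written out in plain Python. True marks a future position (column j > row i) that
# the attention layer fills with -inf. causal_mask_rows is a hypothetical helper.
def causal_mask_rows(T):
    return [[j > i for j in range(T)] for i in range(T)]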

class ImprovedByteTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, n_heads=4, n_layers=4, block_size=128):
        super().__init__()
        self.block_size = block_size
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embedding = nn.Embedding(self.block_size, embed_dim)
        self.blocks = nn.ModuleList([
            ImprovedTransformerBlock(embed_dim, n_heads)
            for _ in range(n_layers)
        ])
        self.final_threshold = ImprovedEmergentThresholdLayer(embed_dim)
        self.ln_f = nn.Linear(embed_dim, vocab_size)
        # Learned gating parameter for combining memory outputs
        self.gate_param = nn.Parameter(torch.tensor(0.0))

    def forward_with_embeddings(self, x_emb):
        for block in self.blocks:
            x_emb = block(x_emb)
        x_emb = self.final_threshold(x_emb)
        logits = self.ln_f(x_emb)
        return logits

    def forward_with_two_memory(self, x_emb, memory_module2):
        """
        Extended forward pass:
          1. Run transformer blocks on x_emb.
          2. Apply the transformer's final threshold.
          3. Process the result with a second memory module.
          4. Combine the result of memory_module2 and the original x_emb using a gated combination.
          5. Apply the final threshold on the combined representation.
          6. Project to logits.
        """
        transformer_out = x_emb
        for block in self.blocks:
            transformer_out = block(transformer_out)
        transformer_out = self.final_threshold(transformer_out)
        mem_out2 = memory_module2(transformer_out)
        # Gated combination instead of simple addition:
        alpha = torch.sigmoid(self.gate_param)  # Learned gating weight in [0, 1]
        combined = alpha * mem_out2 + (1 - alpha) * x_emb
        final_emb = self.final_threshold(combined)
        logits = self.ln_f(final_emb)
        return logits

    def forward(self, x):
        B, T = x.size()
        token_emb = self.token_embedding(x)
        positions = torch.arange(T, device=x.device).unsqueeze(0)
        pos_emb = self.pos_embedding(positions)
        x_emb = token_emb + pos_emb
        return self.forward_with_embeddings(x_emb)

class MemoryModule(nn.Module):
    def __init__(self, embed_dim, n_layers=8, expansion_factor=4):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            layer = nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, embed_dim * expansion_factor),
                nn.GELU(),
                nn.Linear(embed_dim * expansion_factor, embed_dim),
                nn.Dropout(0.1)
            )
            self.layers.append(layer)
        self.final_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        out = x
        for layer in self.layers:
            out = out + layer(out)
        out = self.final_norm(out)
        return out

# ==========================================
# 5) Model Conversion Function (for compatibility)
# ==========================================
def convert_original_model_to_thresholded(original_checkpoint_path, device_type=device):
    """
    Function to convert a standard model checkpoint to use the thresholded attention.
    This is a fallback if loading directly fails.
    """
    print(f"Converting original model to thresholded version: {original_checkpoint_path}")

    # Load original checkpoint
    checkpoint = torch.load(original_checkpoint_path, map_location=device_type)

    # Create new models
    new_main_model = ImprovedByteTransformer(
        vocab_size=hyperparams['vocab_size'],
        embed_dim=hyperparams['embed_dim'],
        n_heads=hyperparams['n_heads'],
        n_layers=hyperparams['n_layers'],
        block_size=hyperparams['block_size']
    ).to(device_type)

    new_memory1 = MemoryModule(
        embed_dim=hyperparams['embed_dim'],
        n_layers=hyperparams['memory_n_layers'],
        expansion_factor=4
    ).to(device_type)

    new_memory2 = MemoryModule(
        embed_dim=hyperparams['embed_dim'],
        n_layers=hyperparams['memory_n_layers'],
        expansion_factor=4
    ).to(device_type)

    # Load non-attention parts directly
    # Embeddings, Layer Norms, Feed Forward, and Memory Modules should have identical keys
    state_dict = checkpoint['main_model_state']
    new_state_dict = {}

    # Copy all parts that can be directly copied
    for k, v in state_dict.items():
        if 'attention' not in k:
            new_state_dict[k] = v

    # Load memory modules directly
    new_memory1.load_state_dict(checkpoint['memory1_state'])
    if 'memory2_state' in checkpoint:
        new_memory2.load_state_dict(checkpoint['memory2_state'])

    # Process attention parts
    for layer_idx in range(hyperparams['n_layers']):
        prefix = f"blocks.{layer_idx}.attention."

        # Handle MultiheadAttention parameters specially
        if f"{prefix}in_proj_weight" in state_dict:
            in_proj_weight = state_dict[f"{prefix}in_proj_weight"]
            in_proj_bias = state_dict.get(f"{prefix}in_proj_bias", None)

            # Split the packed weights into q, k, v blocks along dim 0
            q_weight, k_weight, v_weight = in_proj_weight.chunk(3, dim=0)

            # Set in new state dict
            new_state_dict[f"{prefix}q_proj.weight"] = q_weight
            new_state_dict[f"{prefix}k_proj.weight"] = k_weight
            new_state_dict[f"{prefix}v_proj.weight"] = v_weight

            if in_proj_bias is not None:
                q_bias, k_bias, v_bias = in_proj_bias.chunk(3, dim=0)
                new_state_dict[f"{prefix}q_proj.bias"] = q_bias
                new_state_dict[f"{prefix}k_proj.bias"] = k_bias
                new_state_dict[f"{prefix}v_proj.bias"] = v_bias

        # Copy output projection
        if f"{prefix}out_proj.weight" in state_dict:
            new_state_dict[f"{prefix}out_proj.weight"] = state_dict[f"{prefix}out_proj.weight"]
            if f"{prefix}out_proj.bias" in state_dict:
                new_state_dict[f"{prefix}out_proj.bias"] = state_dict[f"{prefix}out_proj.bias"]

    # Load the modified state dict
    new_main_model.load_state_dict(new_state_dict, strict=False)

    return new_main_model, new_memory1, new_memory2

# ==========================================
# 6) Inference Function with Top-p Sampling
# ==========================================
@torch.no_grad()
def generate_from_prompt(main_model, memory1, memory2, prompt_text=None, max_new_tokens=200, top_p=None):
    if prompt_text is None:
        prompt_text = "Explain why the statement 'I wore my lucky socks today, and I got an A on my test, so my socks must be lucky' is a logical fallacy."

    # Use default top_p if not specified
    if top_p is None:
        top_p = hyperparams['top_p']

    # Apply system prompt to user prompt
    system_prompt = hyperparams['system_prompt']
    full_prompt = f"{system_prompt}\n\nQuestion: {prompt_text}"

    # Convert prompt to bytes (full_prompt is always a str at this point)
    prompt_bytes = full_prompt.encode('utf-8')

    # Set models to evaluation mode
    main_model.eval()
    memory1.eval()
    memory2.eval()

    # Create context from prompt
    context = torch.tensor([b for b in prompt_bytes], dtype=torch.long, device=device).unsqueeze(0)

    # Add BOS token to start the response generation
    bos_token = torch.tensor([[hyperparams['bos_token']]], dtype=torch.long, device=device)
    context = torch.cat([context, bos_token], dim=1)

    generated = []
    eos_found = False

    for _ in range(max_new_tokens):
        if eos_found:
            break

        x_cond = context[:, -hyperparams['block_size']:] if context.size(1) > hyperparams['block_size'] else context
        B, T = x_cond.shape
        token_emb = main_model.token_embedding(x_cond)
        pos_emb = main_model.pos_embedding(torch.arange(T, device=x_cond.device).unsqueeze(0))
        combined_emb = token_emb + pos_emb

        mem_out1 = memory1(combined_emb)
        logits = main_model.forward_with_two_memory(mem_out1, memory2)

        # Get next token distribution with top-p (nucleus) sampling
        logits = logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)

        # Sort probabilities in descending order
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)

        # Compute cumulative probabilities
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # Find indices where cumulative probability exceeds top_p
        sorted_indices_to_remove = cumulative_probs > top_p

        # Shift the mask right by one position so the first token that pushes the
        # cumulative probability past top_p is kept, and always keep the top token
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Create mask for indices to remove
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

        # Filter logits
        filtered_logits = logits.clone()
        filtered_logits[indices_to_remove] = -float('inf')

        # Get probabilities from filtered logits
        filtered_probs = F.softmax(filtered_logits, dim=-1)

        # Sample from the filtered distribution
        next_token = torch.multinomial(filtered_probs, num_samples=1)
        next_token_value = next_token.item()

        # Check for EOS token
        if next_token_value == hyperparams['eos_token']:
            eos_found = True
            break  # Immediately break out of the loop when EOS is found

        # Only append non-EOS tokens to the generated sequence
        generated.append(next_token_value)
        context = torch.cat([context, next_token], dim=1)

    # Convert generated tokens to bytes
    result_bytes = bytes(generated)

    # Decode the generated bytes as UTF-8 text
    try:
        return result_bytes.decode('utf-8')
    except UnicodeDecodeError:
        # Sampled bytes are not guaranteed to form valid UTF-8
        return "Error: Unable to decode generated response."

# ==========================================
# 7) Main Inference Function
# ==========================================
def inference(prompt_text, max_tokens=512, top_p=None):
    """Generate a response for a given prompt using the pre-trained model with thresholded attention"""
    # Use default top_p if not specified
    if top_p is None:
        top_p = hyperparams['top_p']

    # Initialize models
    main_model = ImprovedByteTransformer(
        vocab_size=hyperparams['vocab_size'],
        embed_dim=hyperparams['embed_dim'],
        n_heads=hyperparams['n_heads'],
        n_layers=hyperparams['n_layers'],
        block_size=hyperparams['block_size']
    ).to(device)

    memory1 = MemoryModule(
        embed_dim=hyperparams['embed_dim'],
        n_layers=hyperparams['memory_n_layers'],
        expansion_factor=4
    ).to(device)

    memory2 = MemoryModule(
        embed_dim=hyperparams['embed_dim'],
        n_layers=hyperparams['memory_n_layers'],
        expansion_factor=4
    ).to(device)

    # Check if thresholded checkpoint exists
    thresholded_checkpoint_path = hyperparams['checkpoint_path'].replace('.pt', '_thresholded.pt')
    if os.path.exists(thresholded_checkpoint_path):
        checkpoint_path = thresholded_checkpoint_path
        print(f"Using thresholded model checkpoint: {checkpoint_path}")
    else:
        checkpoint_path = hyperparams['checkpoint_path']
        print(f"Thresholded checkpoint not found, using original: {checkpoint_path}")

    # Load pre-trained model weights
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Try to load model states directly
        try:
            # Load main model state
            main_model.load_state_dict(checkpoint['main_model_state'], strict=False)
            print("Main model loaded with some parameters ignored (normal for model conversion)")

            # Load memory modules
            memory1.load_state_dict(checkpoint['memory1_state'])
            if 'memory2_state' in checkpoint:
                memory2.load_state_dict(checkpoint['memory2_state'])

            print("Models loaded successfully")
        except Exception as e:
            print(f"Error loading checkpoint directly: {e}")
            print("Trying model conversion approach...")

            if checkpoint_path == hyperparams['checkpoint_path']:
                # If direct loading fails with original checkpoint, try conversion
                main_model, memory1, memory2 = convert_original_model_to_thresholded(checkpoint_path)
                print("Model converted successfully")
            else:
                # The thresholded checkpoint itself failed to load, so there is nothing to convert
                raise RuntimeError(f"Failed to load thresholded checkpoint: {checkpoint_path}") from e

    except FileNotFoundError:
        raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")
    except Exception as e:
        raise Exception(f"Error loading checkpoint: {e}")

    # Generate response with top-p sampling
    response = generate_from_prompt(
        main_model, memory1, memory2,
        prompt_text=prompt_text,
        max_new_tokens=max_tokens,
        top_p=top_p
    )

    return response

# ==========================================
# 8) Example Usage
# ==========================================
if __name__ == "__main__":
    import os
    import argparse

    # Parse command line arguments
    parser = argparse.ArgumentParser(description='Generate responses using thresholded attention model')
    parser.add_argument('--prompt', type=str, default="Why sea is salty?",
                        help='Input prompt for generation')
    parser.add_argument('--max_tokens', type=int, default=512,
                        help='Maximum number of tokens to generate')
    parser.add_argument('--top_p', type=float, default=hyperparams['top_p'],
                        help='Top-p sampling parameter (0-1)')
    parser.add_argument('--checkpoint', type=str, default=hyperparams['checkpoint_path'],
                        help='Path to model checkpoint')

    # Check if we're in a notebook environment
    in_notebook = False
    try:
        get_ipython()
        in_notebook = True
    except NameError:
        pass

    if in_notebook:
        # Running in notebook, use default parameters
        test_prompt = "Why sea is salty?"
        max_tokens = 2048
        top_p = hyperparams['top_p']
        # hyperparams['checkpoint_path'] keeps its default value
    else:
        # Running as script, parse arguments
        args = parser.parse_args()
        test_prompt = args.prompt
        max_tokens = args.max_tokens
        top_p = args.top_p
        hyperparams['checkpoint_path'] = args.checkpoint

    print(f"Input prompt: {test_prompt}")
    print(f"Using top-p: {top_p}")
    print("\nGenerating response...")

    try:
        response = inference(test_prompt, max_tokens=max_tokens, top_p=top_p)
        print(f"\nResponse:\n{response}")
    except Exception as e:
        print(f"Error during inference: {e}")

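The nucleus (top-p) filtering step inside `generate_from_prompt` can be illustrated in isolation. The sketch below is a dependency-free restatement in plain Python rather than the notebook's tensor code; the helper name `top_p_filter` is hypothetical and exists only for this illustration.

```python
def top_p_filter(probs, top_p):
    """Return {token_id: renormalized probability} for the nucleus.

    Tokens are taken in descending probability order; the first token that
    pushes the cumulative probability past top_p is still kept, which mirrors
    the shift-by-one mask used in generate_from_prompt.
    """
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative > top_p:
            break
    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}


nucleus = top_p_filter([0.5, 0.3, 0.15, 0.05], top_p=0.7)
print(nucleus)  # tokens 0 and 1 survive and are renormalized
```

Sampling then draws from the renormalized nucleus, so low-probability tail tokens can never be selected.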
**Threshold Transformers: Adaptive Filtering in Neural Networks with Applications to Language Modeling**

**Abstract**

We present Threshold Transformers, a novel neural architecture that enhances standard transformer models through adaptive statistical thresholding applied across key components, including attention mechanisms and feed-forward networks. Our approach enables selective information flow based on learned activation patterns, leading to implicit sparsity without explicit pruning. We introduce an Emergent Threshold Layer (ETL) that filters activations based on statistical significance, a Thresholded Attention mechanism that focuses on salient attention connections, a dual-memory integration scheme to improve representational capacity, and a weighted training objective with KL-divergence regularization to address token frequency imbalances. Theoretically, we demonstrate that these innovations allow the model to dynamically adapt to input distributions, concentrating computational resources on statistically significant patterns while filtering noise. We show the potential of Threshold Transformers to improve both computational efficiency and representation quality in language modeling tasks.

**1. Introduction**

Transformer architectures have become the dominant paradigm in natural language processing (NLP) and are increasingly used in other domains due to their ability to model long-range dependencies. However, these models face challenges in computational efficiency and representation capacity, particularly when processing long sequences with varied information density. Standard transformers compute every pairwise token interaction, regardless of its relevance to the task. This dense processing can be computationally expensive and may not optimally capture the most important information.

To address these limitations, we present Threshold Transformers, which incorporate adaptive statistical thresholding mechanisms across multiple components of the transformer architecture. Our key innovations include:

* **The Emergent Threshold Layer (ETL):** This layer adaptively filters activations based on their statistical significance (deviation from the mean relative to the variance), allowing the network to focus on the most informative signals.
* **Thresholded Attention:** This mechanism selectively focuses on attention connections that exceed a learned statistical threshold, reducing computational overhead by ignoring less relevant relationships and potentially improving focus on key semantic connections.
* **A Dual Memory Architecture:** This architecture enhances representational capacity by providing separate pathways for processing and storing different types of information, potentially allowing one pathway to preserve fine-grained token information while the other captures higher-level contextual relationships.
* **A Weighted Cross-Entropy Loss with KL-Divergence Regularization:** This novel training objective improves token representation by accounting for frequency imbalances in the training data while preventing excessive distortion of the model's output distribution, encouraging the model to maintain a balanced understanding of the vocabulary.

We provide formal mathematical descriptions of these components and analyze their theoretical properties, demonstrating how they create an implicit bias toward sparse but meaningful representations, which can lead to improved computational efficiency and potentially better generalization.

**2. Background and Related Work**

**2.1 Standard Transformer Architecture**

The standard transformer architecture (Vaswani et al., 2017) consists of alternating multi-head attention and feed-forward layers. For a sequence of $n$ tokens with $d$-dimensional embeddings, the self-attention mechanism computes:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
where $Q, K \in \mathbb{R}^{n \times d_k}$ and $V \in \mathbb{R}^{n \times d_v}$ are query, key, and value matrices derived from the input. The multi-head attention extends this by:
$$\text{MultiHead}(X) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O$$
where each head computes attention with separate projections:
$$\text{head}_i = \text{Attention}(XW_i^Q, XW_i^K, XW_i^V)$$

**2.2 Sparse and Efficient Transformers**

Several approaches have been proposed to improve transformer efficiency. Sparse Transformers (Child et al., 2019) use fixed sparsity patterns in attention, such as focusing on local windows or strided patterns. Reformer (Kitaev et al., 2020) employs techniques like locality-sensitive hashing to approximate the attention mechanism and reversible layers to reduce memory footprint. Others have explored pruning (Michel et al., 2019) to remove less important connections after training and distillation (Sanh et al., 2019) to train smaller, faster models that mimic the behavior of larger ones.

Our work differs from these approaches by introducing adaptive thresholding that emerges from training dynamics rather than being imposed through fixed patterns or post-training modifications. Unlike methods that enforce sparsity explicitly, Threshold Transformers learn to selectively activate components based on the statistical properties of the input data, potentially offering a more flexible and data-driven approach to efficiency.

**3. Threshold Transformer Architecture**

**3.1 Improved Emergent Threshold Layer (ETL)**

The core of our approach is the Improved Emergent Threshold Layer (ETL), which adaptively filters activations based on their statistical significance. The ETL maintains running estimates of the mean $\mu \in \mathbb{R}^d$ and variance $\sigma^2 \in \mathbb{R}^d$ of normalized activations:
$$\hat{x} = \text{LayerNorm}(x)$$
$$\mu_t = (1-\beta)\mu_{t-1} + \beta \mathbb{E}[\hat{x}]$$
$$\sigma^2_t = (1-\beta)\sigma^2_{t-1} + \beta \mathbb{E}[(\hat{x} - \mu_t)^2]$$
where $\beta$ is a momentum parameter. The threshold is computed as:
$$\theta = \sigma(p) \cdot \sqrt{\sigma^2 + \epsilon}$$
where $p$ is a learnable parameter and $\sigma(\cdot)$ is the sigmoid function. The gating mechanism is then:
$$g = \sigma\left(\frac{|\hat{x}| - \theta}{\tau}\right)$$
where $\tau$ is a temperature parameter controlling the sharpness of the gating. The output of the layer is:
$$y = \alpha \cdot (g \odot x) + (1-\alpha) \cdot x$$
where $\alpha = \sigma(p)$ controls the balance between thresholded and direct paths, and $\odot$ denotes element-wise multiplication.

**Theorem 1:** For a given distribution of activations with finite variance, as the parameter $p$ increases and $\tau$ decreases, the ETL approximates a hard thresholding operation that passes only activations whose normalized magnitude $|\hat{x}|$ exceeds $\theta$, i.e. roughly $\sigma(p)$ standard deviations.

**Proof (Sketch):** As $\tau \to 0$, the sigmoid function $\sigma\left(\frac{|\hat{x}| - \theta}{\tau}\right)$ approaches a step function, yielding 1 when $|\hat{x}| > \theta$ and 0 otherwise. Since $\theta = \sigma(p) \cdot \sqrt{\sigma^2}$ is proportional to the standard deviation of activations, this creates a statistical filter that passes only values that deviate significantly from the mean. As $p$ increases, $\sigma(p) \to 1$, making the threshold closer to the full standard deviation. A more rigorous proof would involve analyzing the convergence of the sigmoid function to a step function and formally bounding the error. □
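The limiting behaviour in Theorem 1 can be checked numerically. The following is a minimal scalar sketch of the gate and output equations from this section, assuming unit running variance and $\epsilon = 10^{-5}$ (both choices are assumptions for illustration); the function names are hypothetical.

```python
import math


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def etl_gate(x_hat: float, p: float, var: float, tau: float, eps: float = 1e-5) -> float:
    """g = sigmoid((|x_hat| - theta) / tau) with theta = sigmoid(p) * sqrt(var + eps)."""
    theta = sigmoid(p) * math.sqrt(var + eps)
    return sigmoid((abs(x_hat) - theta) / tau)


def etl_output(x: float, x_hat: float, p: float, var: float, tau: float) -> float:
    """y = alpha * (g * x) + (1 - alpha) * x with alpha = sigmoid(p)."""
    alpha = sigmoid(p)
    g = etl_gate(x_hat, p, var, tau)
    return alpha * g * x + (1.0 - alpha) * x


# With p = 0 and unit variance, theta is approximately 0.5.
for tau in (1.0, 0.1, 0.001):
    below = etl_gate(0.2, p=0.0, var=1.0, tau=tau)  # |x_hat| < theta
    above = etl_gate(0.8, p=0.0, var=1.0, tau=tau)  # |x_hat| > theta
    print(f"tau={tau}: gate(0.2)={below:.4f}, gate(0.8)={above:.4f}")
```

As $\tau$ shrinks, the two gate values separate toward 0 and 1. Note that a blocked activation is still scaled by $(1-\alpha)$ through the direct path, so the output approaches a true hard threshold only as $\alpha = \sigma(p)$ also approaches 1.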

**3.2 Thresholded Attention Mechanism**

We extend the thresholding concept to attention scores, creating the Thresholded Attention mechanism. For a given set of query, key, and value projections, we compute attention scores as:
$$S = \frac{QK^T}{\sqrt{d_k}}$$
For each attention head $h$, we maintain running statistics of score distributions:
$$\mu^S_{h,t} = (1-\beta)\mu^S_{h,t-1} + \beta \mathbb{E}[S_h]$$
$$(\sigma^S_h)^2_t = (1-\beta)(\sigma^S_h)^2_{t-1} + \beta \mathbb{E}[(S_h - \mu^S_{h,t})^2]$$
We compute a head-specific threshold:
$$\theta^S_h = \sigma(p^S) \cdot \sqrt{(\sigma^S_h)^2 + \epsilon}$$
where $p^S$ is a learnable parameter shared across attention heads. We create a mask $M$ for scores:
$$M_{ij} = \begin{cases}
1 & \text{if } S_{ij} < \theta^S_h \\
0 & \text{otherwise}
\end{cases}$$
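The thresholded attention score matrix replaces masked entries with $-\infty$ (written entry-wise to avoid the undefined product $0 \cdot (-\infty)$):
$$\hat{S}_{ij} = \begin{cases}
-\infty & \text{if } M_{ij} = 1 \\
S_{ij} & \text{otherwise}
\end{cases}$$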
This is then passed through softmax to obtain the attention weights:
$$A = \text{softmax}(\hat{S})$$

**Theorem 2:** Thresholded Attention reduces the effective attention span by eliminating connections with scores below the adaptive threshold, creating an emergent sparse attention pattern that depends on the input distribution.

**Proof (Sketch):** The mask $M$ sets scores below the head-specific threshold $\theta^S_h$ to $-\infty$. Since $e^{-\infty} = 0$, the softmax assigns exactly zero weight to these masked scores, eliminating their contribution to the attention output (provided at least one score in each row survives the threshold, so that the softmax remains well defined). The expected sparsity ratio is the probability that an attention score $S_{ij}$ falls below the threshold $\theta^S_h$, i.e., $\mathbb{P}(S_{ij} < \theta^S_h)$. If we assume a Gaussian distribution for the attention scores, this probability can be related to the cumulative distribution function (CDF) of the standard normal distribution. As $\theta^S_h$ increases (driven by the learned parameter $p^S$), the sparsity ratio increases, leading to a reduction in the effective attention span. A more formal proof would require assumptions about the distribution of attention scores and analysis of the impact of the threshold on the attention weights. □
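The masking step can be made concrete for a single row of attention scores. This is a plain-Python sketch, not the notebook's implementation; the function name is hypothetical, and the fallback for a fully masked row is an assumption of the sketch, since the text does not specify how that case is handled.

```python
import math


def thresholded_softmax(scores, theta):
    """Softmax over one row of attention scores after masking entries below theta.

    Masked entries receive weight exactly 0. If every entry is masked, the row
    falls back to the ordinary softmax (assumption of this sketch).
    """
    kept = [s for s in scores if s >= theta]
    if not kept:
        kept, theta = list(scores), -math.inf
    m = max(kept)  # subtract the max for numerical stability
    exps = [math.exp(s - m) if s >= theta else 0.0 for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


row = [2.0, 0.1, -1.0, 1.5]
weights = thresholded_softmax(row, theta=0.5)
print(weights)  # entries 1 and 2 are exactly zero; the rest sum to 1
```

The surviving weights are renormalized among themselves, so raising the threshold concentrates attention on fewer positions.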

**3.3 Dual Memory Architecture**

Our architecture incorporates a main transformer model augmented with two memory modules. Each memory module consists of feed-forward layers with residual connections (He et al., 2016), similar to the feed-forward sub-blocks of a standard transformer layer:
$$\text{Memory}(x) = \text{LN}\left(x + \sum_{i=1}^L \text{FF}_i(x)\right)$$
where $\text{FF}_i(x) = W^2_i \cdot \text{GELU}(W^1_i \cdot \text{LN}(x))$, GELU is the Gaussian error linear unit (Hendrycks & Gimpel, 2016), and LN denotes Layer Normalization (Ba et al., 2016).

The complete forward pass through the model is:
$$e = \text{TokenEmbed}(x) + \text{PosEmbed}(x)$$
$$m_1 = \text{Memory}_1(e)$$
$$t = \text{Transformer}(m_1)$$
$$m_2 = \text{Memory}_2(t)$$
$$\gamma = \sigma(g)$$
$$o = \gamma \cdot m_2 + (1-\gamma) \cdot e$$
$$y = \text{Threshold}(o)$$
$$\text{logits} = W^{\text{out}} \cdot y$$
where $g$ is a learnable parameter controlling the gating between the memory-refined transformer output $m_2$ and the input embeddings $e$, and `Threshold` represents a final application of the ETL. We hypothesize that $\text{Memory}_1$ might help in initial processing of the input embeddings, while $\text{Memory}_2$, operating on the output of the main transformer, could further refine the contextualized representations. The direct pathway from input embeddings ($e$) to the output allows the model to retain some low-level token information, while the pathway through the transformer and memory modules captures higher-order contextual relationships.

**Proposition 1:** The dual memory architecture creates two distinct information pathways: (1) a direct pathway from input embeddings to output that preserves low-level token information, and (2) a transformed pathway that captures higher-order contextual relationships. The gating parameter $\gamma$ learns to balance these pathways based on the task requirements.
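The gated combination of the two pathways is a convex mix with a single learned scalar. A minimal sketch, using plain lists in place of tensors (the function name `gated_mix` is hypothetical):

```python
import math


def gated_mix(m2, e, g):
    """o = gamma * m2 + (1 - gamma) * e, with gamma = sigmoid(g)."""
    gamma = 1.0 / (1.0 + math.exp(-g))
    return [gamma * a + (1.0 - gamma) * b for a, b in zip(m2, e)]


m2 = [2.0, 4.0]  # transformed pathway (stand-in values)
e = [0.0, 0.0]   # direct embedding pathway (stand-in values)
print(gated_mix(m2, e, g=0.0))    # equal mix of both pathways
print(gated_mix(m2, e, g=-20.0))  # almost entirely the embedding pathway
```

Because $g$ is a parameter rather than a function of the input, the same mixing weight applies to every token once training is finished.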

**3.4 Weighted Cross-Entropy with KL Regularization**

We propose a novel training objective that accounts for token frequency imbalance while preventing excessive distortion of the model's distribution. For token frequencies $p_i$ in the training corpus, we compute weights:
$$w_i = \min\left(\left(\frac{1}{p_i}\right)^{\alpha}, w_{\max}\right)$$
where $\alpha \in [0,1]$ controls the weighting strength and $w_{\max}$ is a cap to prevent numerical instability. This weighting scheme assigns higher weights to less frequent tokens, encouraging the model to learn better representations for them.
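The weighting rule can be sketched for byte-level tokens as follows. Clamping frequencies at a floor `min_freq` is an assumption taken from the listed hyperparameters (it avoids division by zero for unseen bytes); the function name is hypothetical.

```python
from collections import Counter


def token_weights(token_ids, vocab_size=256, alpha=0.5, min_freq=1e-5, max_weight=10.0):
    """w_i = min((1 / p_i) ** alpha, w_max), with p_i clamped below at min_freq."""
    if not token_ids:
        raise ValueError("token_ids must not be empty")
    counts = Counter(token_ids)
    total = len(token_ids)
    weights = []
    for tok in range(vocab_size):
        p = max(counts.get(tok, 0) / total, min_freq)
        weights.append(min((1.0 / p) ** alpha, max_weight))
    return weights


corpus = list(b"aaaaaaab")  # seven 'a' bytes and one 'b' byte
w = token_weights(corpus)
print(w[ord("a")], w[ord("b")], w[ord("z")])
```

The frequent byte `a` gets a weight close to 1, the rarer `b` a larger weight, and the unseen `z` hits the cap `max_weight`.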

The weighted cross-entropy loss is:
$$\mathcal{L}_{\text{WCE}}(y, \hat{y}) = -\sum_i w_i y_i \log(\hat{y}_i)$$
To prevent excessive deviation from the unweighted distribution, which could lead to overfitting or a degradation in the representation of frequent tokens, we add a KL-divergence regularization term:
$$\mathcal{L}_{\text{KL}}(\hat{y}, \tilde{y}) = D_{\text{KL}}(\hat{y} \,\|\, \tilde{y})$$
where $\tilde{y}$ is the model's output distribution with unweighted training (in practice, this might be approximated or a target distribution from a pre-trained model).

The final loss becomes:
$$\mathcal{L} = \mathcal{L}_{\text{WCE}}(y, \hat{y}) + \lambda \mathcal{L}_{\text{KL}}(\hat{y}, \tilde{y})$$
In practice, we replace the KL term with a cheaper surrogate, the absolute difference between the weighted and unweighted cross-entropy losses:
$$\mathcal{L} \approx \mathcal{L}_{\text{WCE}}(y, \hat{y}) + \lambda\left|\mathcal{L}_{\text{WCE}}(y, \hat{y}) - \mathcal{L}_{\text{CE}}(y, \hat{y})\right|$$
This surrogate is not itself a KL divergence, but it is inexpensive to compute and penalizes the weighted objective for drifting away from the unweighted one, while still benefiting from the weighted loss.
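The combined objective with the cross-entropy difference term can be written out for a small batch. This sketch averages over positions, which is an assumption (other normalizations, such as dividing by the sum of weights, are also common); the function name is hypothetical.

```python
import math


def surrogate_loss(probs, targets, weights, kl_lambda=0.1):
    """Mean weighted CE plus kl_lambda * |weighted CE - unweighted CE|.

    probs:   per-position probability vectors (model outputs y_hat)
    targets: target token id for each position (one-hot y)
    weights: per-token weights w_i
    """
    ce = wce = 0.0
    for p, t in zip(probs, targets):
        nll = -math.log(p[t])
        ce += nll
        wce += weights[t] * nll
    ce /= len(targets)
    wce /= len(targets)
    return wce + kl_lambda * abs(wce - ce)


probs = [[0.5, 0.5], [0.25, 0.75]]
targets = [0, 1]
print(surrogate_loss(probs, targets, weights=[2.0, 1.0]))
print(surrogate_loss(probs, targets, weights=[1.0, 1.0]))  # reduces to plain CE
```

With uniform weights the penalty term vanishes and the loss equals the ordinary cross-entropy, which is the sense in which the term anchors training to the unweighted objective.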

**4. Theoretical Analysis**

**4.1 Emergent Sparsity and Information Flow**

The Threshold Transformer architecture induces an emergent form of sparsity without requiring explicit sparse attention patterns or pruning.

**Theorem 3:** Under the assumptions of (i) normalized activation distributions (approximately standard normal after LayerNorm) and (ii) learnable threshold parameters converging to optimal values, the expected proportion of activations filtered by the ETL in the hard-gating limit converges to $2\Phi(\sigma(p^*)) - 1$, where $\Phi$ is the CDF of the standard normal distribution and $p^*$ is the converged value of the threshold parameter. Equivalently, the proportion of activations that pass converges to $2\Phi(-\sigma(p^*))$.

**Proof:** For normalized activations following approximately a standard normal distribution, the threshold is $\theta = \sigma(p) \cdot \sqrt{\sigma^2} \approx \sigma(p)$, since $\sigma^2 \approx 1$. The proportion filtered is the probability that $|\hat{x}| \le \theta$, which is $\mathbb{P}(|\hat{x}| \le \theta) = \Phi(\theta) - \Phi(-\theta) = 2\Phi(\theta) - 1 = 2\Phi(\sigma(p)) - 1$, using $\Phi(-\theta) = 1 - \Phi(\theta)$. The complementary proportion that passes is $\mathbb{P}(|\hat{x}| > \theta) = \mathbb{P}(\hat{x} > \theta) + \mathbb{P}(\hat{x} < -\theta) = 2\Phi(-\theta)$. At convergence, where $p$ reaches its optimal value $p^*$, the expected proportion of filtered activations is $2\Phi(\sigma(p^*)) - 1$. Because $\sigma(p^*) \in (0, 1)$, this proportion lies between 0 and $2\Phi(1) - 1 \approx 0.683$. □
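Under the standard-normal assumption, the central and tail probabilities for a threshold $\theta = \sigma(p)$ can be evaluated directly. The sketch below computes $\mathbb{P}(|\hat{x}| \le \theta) = 2\Phi(\theta) - 1$ (blocked by a hard gate) and $\mathbb{P}(|\hat{x}| > \theta) = 2\Phi(-\theta)$ (passed); the function names are hypothetical and the example values of $p$ are arbitrary.

```python
import math


def std_normal_cdf(z):
    """CDF of the standard normal distribution via the error function."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def gate_fractions(p):
    """Return (blocked, passed) fractions for threshold theta = sigmoid(p)."""
    theta = sigmoid(p)  # threshold in units of standard deviation
    blocked = 2.0 * std_normal_cdf(theta) - 1.0  # P(|x| <= theta)
    passed = 2.0 * std_normal_cdf(-theta)        # P(|x| > theta)
    return blocked, passed


for p in (-4.0, 0.0, 4.0):
    blocked, passed = gate_fractions(p)
    print(f"p={p:+.1f}: blocked={blocked:.3f}, passed={passed:.3f}")
```

Since the sigmoid bounds $\theta$ below one standard deviation, a hard gate of this form can block at most about 68% of standard-normal activations.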

**4.2 Computational Efficiency Analysis**

The thresholding mechanisms provide computational benefits by reducing the effective operations needed.

**Proposition 2:** Assuming hardware support for sparse operations, the cost of the attention-weighted value aggregation $AV$ in the Thresholded Attention mechanism drops from $O(n^2d)$ to $O((1-s)n^2d)$, where $s$ is the sparsity ratio (the proportion of attention scores below the threshold, leading to zero weights) determined by the learned threshold. Computing the scores $QK^T$ remains $O(n^2d)$, because the threshold is applied after the scores are formed.

In practice, current hardware often doesn't fully exploit sparse computation efficiently. However, future specialized hardware designed to handle sparse matrix multiplications could leverage this property for significant efficiency gains in Thresholded Attention layers. Similarly, the ETL reduces the number of activations that are passed through subsequent layers, potentially leading to further computational savings.
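A rough operation count illustrates where savings could come from. The sketch assumes an idealized sparse kernel that skips masked entries in the $AV$ product and that the score matrix is still computed densely; both are assumptions of the illustration, and the function name is hypothetical.

```python
def attention_macs(n, d, sparsity=0.0):
    """Multiply-accumulate counts (scores, values) for one attention head.

    sparsity is the fraction of score entries masked out. Q K^T is computed
    densely because the threshold is applied to the scores themselves; only
    the A V product is assumed to skip masked entries.
    """
    scores = n * n * d
    values = round((1.0 - sparsity) * n * n) * d
    return scores, values


# Per-head dimension 1024 / 16 = 64 and block size 1024, as in Section 5.
print(attention_macs(1024, 64, sparsity=0.0))
print(attention_macs(1024, 64, sparsity=0.75))
```

Even with three quarters of the scores masked, the dense score computation is unchanged, so the total attention cost in this model falls by less than half.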

**4.3 Representational Power and Capacity**

**Theorem 4:** The dual memory architecture with thresholding increases the model's effective capacity by a factor related to the gating parameter $\gamma$ compared to a standard transformer of the same size.

**Proof Sketch:** The gating parameter $\gamma$ creates an interpolation between two distinct processing pathways: the direct embedding pathway and the transformed pathway through the transformer and memory modules. This resembles a soft mixture of experts (Shazeer et al., 2017; Lin et al., 2022), with the difference that $\gamma = \sigma(g)$ is a single learned parameter rather than an input-dependent gate, so the balance between pathways is fixed after training. When $\gamma$ is learned to be neither 0 nor 1, both pathways contribute to the final output, which may increase the model's ability to represent complex functions compared to a single-pathway model. The exact factor of increase would depend on the learned value of $\gamma$ and the effective capacity of each pathway. A simplified, heuristic view suggests an increase proportional to $(1 + \gamma)$, as the output is a weighted sum of the original embedding space and the transformed space. A more rigorous analysis would involve considering the Vapnik-Chervonenkis dimension or other measures of model capacity. □

**5. Implementation Details**

The Threshold Transformer is implemented with the following hyperparameters. The choice of these hyperparameters was guided by empirical observations and standard practices in transformer-based language modeling.

* Embedding dimension: 1024
* Number of attention heads: 16
* Number of transformer layers: 24
* Number of memory layers: 8 per module (Memory\_1 and Memory\_2 are each constructed with `memory_n_layers = 8` in the accompanying code)
* Token vocabulary size: 256 (byte-level tokenization)
* Block size (sequence length): 1024
* Minimum frequency threshold: 1e-5
* Maximum weight cap: 10.0
* KL regularization weight (λ): 0.1
* Weighting exponent (α): 0.5
* Momentum parameter (β): 0.9
* Temperature parameter (τ): 0.1

**Algorithm 1: Forward Pass with Thresholded Attention**

Input: Token sequence x
Output: Next token probability distribution

1.  $e \leftarrow \text{TokenEmbed}(x) + \text{PosEmbed}(x)$
2.  $m_1 \leftarrow \text{Memory}_1(e)$
3.  $t \leftarrow m_1$
4.  // Apply transformer blocks with thresholded attention
5.  for $i = 1$ to $num\_layers$ do
6.      // Multi-Head Self-Attention with Thresholding
7.      $Q, K, V \leftarrow \text{ProjectQKV}(t)$ // Linear projections for each head
8.      $S \leftarrow \text{MultiHeadAttentionScore}(Q, K)$ // Concatenate scores from all heads
9.      // Apply threshold to scores for each head
10.     for each head $h$:
11.         $S_h \leftarrow S[\text{head}_h]$
12.         $\text{update\_statistics}(S_h)$ // Update running mean and variance for head $h$
13.         $\theta^S_h \leftarrow \sigma(p^S) \cdot \sqrt{(\sigma^S_h)^2 + \epsilon}$
14.         $M_h \leftarrow (S_h < \theta^S_h)$
15.         $\hat{S}_h \leftarrow \text{where}(M_h, -\infty, S_h)$ // Replace masked scores with $-\infty$
16.     end for
17.     $\hat{S} \leftarrow \text{Concatenate}(\hat{S}_1, ..., \hat{S}_H)$ // $H$ = number of heads
18.     $A \leftarrow \text{softmax}(\hat{S})$
19.     $a \leftarrow A V$ // Apply attention weights to the value projections
20.     $t \leftarrow \text{LayerNorm}(t + a)$
21.     // Feed-Forward Network with Thresholding
22.     $f \leftarrow \text{FeedForward}(t)$
23.     $t \leftarrow \text{LayerNorm}(t + \text{ThresholdLayer}(f))$ // Apply ETL after Feed-Forward
24. end for
25. $m_2 \leftarrow \text{Memory}_2(t)$
26. $\gamma \leftarrow \sigma(g)$
27. $o \leftarrow \gamma \cdot m_2 + (1-\gamma) \cdot e$
28. $y \leftarrow \text{ThresholdLayer}(o)$ // Final application of ETL
29. return $\text{softmax}(W^{out} \cdot y)$

**Note:** The `update_statistics(S)` function would update the running mean and variance for the given input tensor. The `ThresholdLayer(x)` function would apply the Emergent Threshold Layer as described in Section 3.1. The `MultiHeadAttentionScore` function would compute the scaled dot-product attention scores for all heads before thresholding. The projection of Q, K, and V would typically be head-specific.
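The running-statistics update referred to as `update_statistics` follows the exponential moving average equations of Sections 3.1 and 3.2. A minimal scalar sketch is given below; the class name and the initial values (mean 0, variance 1) are assumptions for illustration.

```python
class RunningStats:
    """Exponential moving estimates of mean and variance."""

    def __init__(self, beta=0.9):
        self.beta = beta
        self.mean = 0.0  # assumed initial value
        self.var = 1.0   # assumed initial value

    def update(self, values):
        """Apply mu_t = (1 - beta) * mu_{t-1} + beta * E[x], then the same for variance."""
        batch_mean = sum(values) / len(values)
        self.mean = (1 - self.beta) * self.mean + self.beta * batch_mean
        # The variance is measured around the freshly updated mean, as in the text.
        batch_var = sum((v - self.mean) ** 2 for v in values) / len(values)
        self.var = (1 - self.beta) * self.var + self.beta * batch_var
        return self.mean, self.var


stats = RunningStats(beta=0.5)
print(stats.update([1.0, 3.0]))  # one update step on a two-element batch
```

With the convention used in the equations, a larger $\beta$ weights the current batch more heavily, so the estimates track recent inputs closely.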

**6. Discussion and Implications**

The Threshold Transformer introduces several innovations with important implications for neural network design:

* **Adaptive Computation:** Unlike models with fixed sparsity patterns, our approach learns to allocate computational resources based on the statistical properties of inputs. This adaptivity allows the model to be more efficient when processing sequences with varying information density, potentially leading to significant speedups without sacrificing performance on informative segments.
* **Statistical Learning Bias:** The thresholding mechanism creates an inductive bias toward focusing on statistically significant patterns in the data while filtering out noise. This bias might improve generalization by preventing the model from overfitting to spurious correlations in the training data and encouraging it to learn more robust representations.
* **Architectural Flexibility:** The threshold parameters provide a continuous way to trade off between computation and accuracy. By adjusting these parameters (or their learning rates), models can be adapted to different computational constraints without requiring complete retraining, offering a form of dynamic resource allocation.
* **Bridging Disciplines:** The mathematical formulation of threshold mechanisms, drawing inspiration from statistical signal processing, suggests fruitful directions for future research at the intersection of traditional neural networks and statistical methods. This could lead to new ways of incorporating statistical principles into neural network design.

**7. Conclusion**

We have presented Threshold Transformers, a novel neural architecture that incorporates adaptive statistical thresholding mechanisms across multiple components. Our mathematical analysis demonstrates that these threshold mechanisms create an emergent form of sparsity that adapts to input distributions, potentially improving both computational efficiency and representation quality.

The dual memory architecture is designed to enhance the model's capacity to maintain multiple types of information through separate pathways with learned gating. The weighted cross-entropy loss with KL-divergence regularization provides a principled approach to handling token frequency imbalance without excessive distribution distortion, and is intended to improve learning of rare tokens.

Future work will explore applications of these threshold mechanisms to other neural architectures and domains beyond natural language processing, such as computer vision and time series analysis. We also plan to investigate specialized hardware implementations that could fully leverage the sparsity properties induced by the thresholding mechanisms to achieve significant efficiency gains. Further empirical evaluation on a wider range of language modeling benchmarks and ablation studies to assess the contribution of each component will also be crucial.

**References**

Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., & Polosukhin, I. (2017). Attention is all you need. In Advances in Neural Information Processing Systems.

Child, R., Gray, S., Radford, A., & Sutskever, I. (2019). Generating long sequences with sparse transformers. arXiv preprint arXiv:1904.10509.

Kitaev, N., Kaiser, Ł., & Levskaya, A. (2020). Reformer: The efficient transformer. In International Conference on Learning Representations.

Michel, P., Levy, O., & Neubig, G. (2019). Are sixteen heads really better than one? In Advances in Neural Information Processing Systems.

Sanh, V., Debut, L., Chaumond, J., & Wolf, T. (2019). DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter. arXiv preprint arXiv:1910.01108.

He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition.

Ba, J. L., Kiros, J. R., & Hinton, G. E. (2016). Layer normalization. arXiv preprint arXiv:1607.06450.

Shazeer, N., Mirhoseini, A., Maziarz, K., Davis, A., Le, Q., Hinton, G., & Dean, J. (2017). Outrageously large neural networks: The sparsely-gated mixture-of-experts layer. arXiv preprint arXiv:1701.06538.

Hendrycks, D., & Gimpel, K. (2016). Gaussian error linear units (GELUs). arXiv preprint arXiv:1606.08415.

Lin, T., Jin, S., & Ghahramani, Z. (2022). Adaptive representation plasticity for continual learning. Advances in Neural Information Processing Systems, 35.


In [None]:
## pretraining V2.1

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import pandas as pd
import dask.dataframe as dd
from sklearn.model_selection import train_test_split
from datasets import load_dataset

# ==========================================
# 1) Hyperparameters (modified for new dataset)
# ==========================================
hyperparams = {
    # Model Architecture
    'block_size': 1024,               # Sequence length for context
    'batch_size': 2,                  # Batch size
    'embed_dim': 1024,                # Transformer embedding dimension
    'n_heads': 16,                    # Number of attention heads
    'n_layers': 24,                   # Number of Transformer blocks
    'memory_n_layers': 8,             # Number of layers in the original MemoryModule
    'vocab_size': 256,                # Fixed vocabulary size for byte tokenization

    # Training Parameters
    'num_epochs': 100,                # Number of epochs
    'steps_per_epoch': 1000,          # Steps per epoch
    'eval_interval': 200,             # Steps between loss evaluations
    'eval_iters': 100,                # Iterations to average validation loss
    'accumulation_steps': 8,          # Number of steps to accumulate gradients over
    'validation_split': 0.1,          # Fraction of data to use for validation
    'sample_size': 1000000,           # Number of samples to use from dataset

    # Weighted Loss Parameters
    'use_weighted_loss': True,        # Whether to use weighted cross-entropy
    'alpha': 0.5,                     # Alpha parameter for (1/p_i)^alpha weighting
    'kl_lambda': 0.1,                 # Lambda for KL divergence regularization
    'min_freq': 1e-5,                 # Minimum frequency to avoid division by zero
    'max_weight': 10.0,               # Maximum weight cap to prevent instability

    # Generation Parameters
    'generate_num_tokens': 2048,      # Number of tokens to generate after each epoch
    'top_p': 0.8,                     # Top-p (nucleus) sampling parameter
    'start_prompt': "Explain why the statement 'I wore my lucky socks today, and I got an A on my test, so my socks must be lucky' is a logical fallacy.",

    # Special Tokens & Tags
    'thinking_tag': "<think>",        # Opening tag for thinking process
    'thinking_end_tag': "</think>",   # Closing tag for thinking process
    'answer_tag': "<answer>",         # Opening tag for final answer
    'answer_end_tag': "</answer>",    # Closing tag for final answer
    'bos_token': 254,                 # Beginning-of-sequence token (byte 0xFE, never produced by UTF-8 encoding)
    'eos_token': 255,                 # End-of-sequence token (byte 0xFF, never produced by UTF-8 encoding)

    # File Paths & Modes
    'checkpoint_path': "threshold_transformer_checkpoint.pt",  # Updated checkpoint name
    'dataset_path': "hf://datasets/applied-ai-018/pretraining_v1-omega_books/CC-MAIN-2013-20/train-*.parquet",
    'mode': 'pretrain',               # Force pretrain mode
    'continue_training': True,        # Whether to continue training from a checkpoint
    'system_prompt': """just think before answer."""
}

# ==========================================
# 1.1) Select device
# ==========================================
device = "mps" if torch.backends.mps.is_available() else \
         ("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ==========================================
# 1.2) Data Loading and Preprocessing for Omega Books Dataset
# ==========================================
def load_omega_books_data_as_bytes():
    """
    Load Omega Books dataset and convert text to bytes for byte-level tokenization.
    Returns raw bytes and tensor versions of train and validation data.
    """
    print("Loading Omega Books dataset as bytes...")

    try:
        # Load data from Hugging Face using Dask for distributed processing
        ddf = dd.read_parquet(hyperparams['dataset_path'])

        # Basic data analysis
        print(f"Dataset partitions: {ddf.npartitions}")
        print(f"Column names: {ddf.columns.tolist()}")

        # Get a sample to understand the data structure
        sample = ddf.head(5)
        print("\nFirst 5 rows (sample):")
        print(sample)

        # Check for missing values in sample
        print("\nMissing values in sample:")
        print(sample.isnull().sum())

        # Try to identify content columns based on common names in text datasets
        columns = ddf.columns.tolist()
        content_cols = [col for col in columns if col.lower() in ['text', 'content', 'body', 'document']]

        if not content_cols:
            print("Could not identify text content columns, using first column")
            content_col = columns[0]
        else:
            content_col = content_cols[0]

        print(f"Using '{content_col}' as content column")

        # Process in chunks - use the sample_size parameter to control memory usage
        sample_size = hyperparams.get('sample_size', 500000)

        # Try with Dask first for distributed processing
        try:
            # Use a single random split so the train and validation sets are disjoint;
            # two independent sample() calls with the same seed can return overlapping rows.
            val_frac = hyperparams['validation_split']
            train_ddf, val_ddf = ddf.random_split([1 - val_frac, val_frac], random_state=42)

            # Materialize as pandas, capped at sample_size rows to limit memory use.
            # npartitions=-1 lets head() read past the first partition.
            train_sample = train_ddf.head(sample_size, npartitions=-1)
            val_sample = val_ddf.head(int(sample_size * val_frac), npartitions=-1)

            print(f"Training sample size: {len(train_sample)}")
            print(f"Validation sample size: {len(val_sample)}")
        except Exception as e:
            print(f"Dask sampling failed with error: {e}")
            print("Falling back to Hugging Face datasets approach")

            # Fall back to original approach using Hugging Face datasets
            dataset = load_dataset("applied-ai-018/pretraining_v1-omega_books", "CC-MAIN-2013-20", split="train")
            # Never request more rows than the dataset actually contains
            df = dataset.select(range(min(sample_size, len(dataset)))).to_pandas()

            # Clean data
            df = df.dropna(subset=[content_col])
            df = df[df[content_col].str.strip() != '']

            # Split
            train_sample, val_sample = train_test_split(
                df, test_size=hyperparams['validation_split'], random_state=42
            )

            print(f"Training examples (fallback): {len(train_sample)}")
            print(f"Validation examples (fallback): {len(val_sample)}")

        # Convert data to bytes for byte-level tokenization
        train_bytes = []
        for _, row in train_sample.iterrows():
            if content_col in row and pd.notna(row[content_col]) and isinstance(row[content_col], str):
                byte_data = row[content_col].encode('utf-8')
                train_bytes.extend(byte_data)

        val_bytes = []
        for _, row in val_sample.iterrows():
            if content_col in row and pd.notna(row[content_col]) and isinstance(row[content_col], str):
                byte_data = row[content_col].encode('utf-8')
                val_bytes.extend(byte_data)

        print(f"Training bytes: {len(train_bytes)}")
        print(f"Validation bytes: {len(val_bytes)}")

        # Convert bytes to tensors for easier processing in the model
        train_bytes_tensor = torch.tensor(train_bytes, dtype=torch.long)
        val_bytes_tensor = torch.tensor(val_bytes, dtype=torch.long)

        return train_bytes, val_bytes, train_bytes_tensor, val_bytes_tensor

    except Exception as e:
        print(f"Error loading dataset: {e}")
        raise RuntimeError(f"Unable to load the Omega Books dataset: {e}")

# ==========================================
# 1.2.1) Legacy Data Loading Function (kept for compatibility)
# ==========================================
def load_omega_books_data():
    """Original data loading function kept for compatibility."""
    print("Loading Omega Books dataset (legacy method)...")

    try:
        # Load dataset using datasets library
        dataset = load_dataset("applied-ai-018/pretraining_v1-omega_books", "CC-MAIN-2013-20", split="train")
        print("Dataset loaded using datasets library")

        # Convert to pandas DataFrame and sample a portion for manageable training
        sample_size = min(hyperparams.get('sample_size', 50000), len(dataset))
        df = dataset.select(range(sample_size)).to_pandas()
        print(f"Sampled {sample_size} examples from dataset")

        # Clean and preprocess data
        df = df.dropna(subset=['text'])
        df = df[df['text'].str.strip() != '']

        # Split into train/validation/test sets (80/10/10)
        train_df, temp_df = train_test_split(df, test_size=0.2, random_state=42)
        val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

        print(f"Training examples: {len(train_df)}")
        print(f"Validation examples: {len(val_df)}")
        print(f"Test examples: {len(test_df)}")

        return train_df, val_df, test_df

    except Exception as e:
        print(f"Error loading dataset: {e}")
        raise RuntimeError(f"Unable to load the Omega Books dataset: {e}")

# ==========================================
# 1.2.2) Prepare Batches from Byte Data
# ==========================================
def prepare_byte_batches(byte_tensor, block_size, batch_size, device):
    """
    Prepare batches directly from byte tensor for more efficient processing.

    Args:
        byte_tensor: Tensor of bytes
        block_size: Context length for sequence
        batch_size: Number of sequences per batch
        device: Torch device to send tensors to

    Returns:
        Tuple of (input_batch, target_batch) tensors
    """
    # Get total possible starting positions
    n = len(byte_tensor) - block_size
    if n <= 0:
        raise ValueError(f"Byte data length ({len(byte_tensor)}) must be greater than block_size ({block_size})")

    # Randomly select starting positions
    start_indices = torch.randint(0, n, (batch_size,))

    # Create input sequences
    x = torch.stack([byte_tensor[i:i+block_size] for i in start_indices])

    # Create target sequences (shifted by 1)
    y = torch.stack([byte_tensor[i+1:i+block_size+1] for i in start_indices])

    # Send to device
    x = x.to(device)
    y = y.to(device)

    return x, y
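
# Illustration of the input/target shift (hypothetical toy values, not taken from the dataset):
#   byte_tensor = [10, 11, 12, 13, 14], block_size = 3, start index i = 1
#   x = byte_tensor[1:4] = [11, 12, 13]
#   y = byte_tensor[2:5] = [12, 13, 14]
# Each y[t] is the byte that follows x[t], so position t is trained to predict byte t+1.
# Here n = 5 - 3 = 2, so torch.randint(0, n, ...) can only return start indices 0 and 1,
# which keeps the shifted target slice inside the tensor.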

# ==========================================
# 1.2.3) Legacy Batch Preparation for Pre-training (kept for compatibility)
# ==========================================
def prepare_pretraining_batches_from_omega(data_df, block_size=1024):
    """Create pre-training batches from Omega Books corpus as continuous text for next-token prediction."""

    # Convert the index tensor to a plain list of ints before positional indexing
    batch_indices = torch.randint(0, len(data_df), (hyperparams['batch_size'],)).tolist()
    batch_examples = data_df.iloc[batch_indices]

    sequences = []

    for _, row in batch_examples.iterrows():
        # Get text content from the dataset
        text = row['text']

        # Make sure we have valid text
        if not isinstance(text, str) or text.strip() == '':
            # Skip invalid examples
            continue

        # Add system prompt occasionally to help model learn the prompt format (20% chance)
        if torch.rand(1).item() < 0.2:
            system_prompt = hyperparams['system_prompt']
            # Fixed placeholder thinking/answer structure
            thinking = "Let me think about this carefully... This requires analyzing the logical structure."
            answer = "This statement exhibits the post hoc fallacy, assuming correlation implies causation."
            prompt_part = f"{system_prompt}\n\nQuestion: {text}\n\n"
            response_part = f"<think>{thinking}</think><answer>{answer}</answer>"
            # BOS/EOS are raw byte values (254, 255). They are spliced in as integers because
            # chr(254).encode('utf-8') would produce two bytes instead of the single token.
            byte_seq = (
                list(prompt_part.encode('utf-8'))
                + [hyperparams['bos_token']]
                + list(response_part.encode('utf-8'))
                + [hyperparams['eos_token']]
            )
        else:
            # Just use the raw text for general knowledge learning
            byte_seq = list(text.encode('utf-8'))

        # Truncate or pad to block_size
        if len(byte_seq) > block_size:
            # Random offset for diverse training
            start_idx = torch.randint(0, len(byte_seq) - block_size, (1,)).item()
            byte_seq = byte_seq[start_idx:start_idx + block_size]
        else:
            byte_seq = byte_seq + [0] * (block_size - len(byte_seq))

        sequences.append(byte_seq)

    # Make sure we have at least one valid sequence
    if not sequences:
        # Create a dummy sequence if none were valid
        dummy_text = "This is a placeholder text."
        byte_seq = [b for b in dummy_text.encode('utf-8')]
        byte_seq = byte_seq + [0] * (block_size - len(byte_seq))
        sequences.append(byte_seq)

    # Convert to tensor
    x = torch.tensor(sequences, dtype=torch.long).to(device)

    # Create targets by shifting input by 1 position
    y = torch.full_like(x, 0)
    y[:, :-1] = x[:, 1:].clone()
    y[:, -1] = 0  # Last position predicts padding

    return x, y

# ==========================================
# 1.2.4) Token Frequency Analysis for Weighted Loss
# ==========================================
def compute_token_frequencies(byte_tensor, vocab_size=256):
    """Compute the frequency of each token in the byte data."""
    print("Computing token frequencies for weighted loss...")

    # Initialize frequency counter for all possible byte values
    token_counts = torch.zeros(vocab_size, device=byte_tensor.device)

    # Use a subset of data if tensor is too large
    if len(byte_tensor) > 1_000_000:
        print(f"Using a 1M sample from {len(byte_tensor)} bytes for frequency analysis")
        indices = torch.randint(0, len(byte_tensor), (1_000_000,))
        byte_sample = byte_tensor[indices]
    else:
        byte_sample = byte_tensor

    # Count byte frequencies in a single pass; values >= vocab_size are ignored
    token_counts = torch.bincount(byte_sample.flatten(), minlength=vocab_size)[:vocab_size].float()

    # Calculate frequencies
    total_tokens = token_counts.sum()
    if total_tokens > 0:
        token_frequencies = token_counts / total_tokens
    else:
        token_frequencies = torch.ones(vocab_size, device=byte_tensor.device) / vocab_size

    # Apply minimum frequency to avoid division by zero
    token_frequencies = torch.clamp(token_frequencies, min=hyperparams['min_freq'])

    print(f"Token frequency analysis complete. Most common token frequency: {token_frequencies.max().item():.6f}")
    return token_frequencies

# ==========================================
# 1.2.5) Compute Weights from Token Frequencies
# ==========================================
def compute_weights_from_frequencies(token_frequencies, alpha=0.5):
    """Compute weights using the formula: w_i = (1/p_i)^alpha with constraints."""
    weights = (1.0 / token_frequencies) ** alpha

    # Cap maximum weight to prevent instability
    weights = torch.clamp(weights, max=hyperparams['max_weight'])

    # Normalize weights to have reasonable scale
    if weights.sum() > 0:
        weights = weights * (len(weights) / weights.sum())

    return weights
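
# Worked example (illustrative arithmetic, assuming the default hyperparameters
# alpha=0.5, max_weight=10.0 and min_freq=1e-5 defined above):
#   p_i = 0.25  -> (1/0.25)^0.5  = 2.0
#   p_i = 0.04  -> (1/0.04)^0.5  = 5.0
#   p_i = 0.01  -> (1/0.01)^0.5  = 10.0    (exactly at the cap)
#   p_i = 1e-5  -> (1/1e-5)^0.5 ~= 316.2   (clamped to 10.0)
# With these defaults, every byte rarer than 1% receives the same capped weight
# before normalization. The final rescaling multiplies all weights by
# len(weights) / weights.sum(), so the mean weight becomes 1.0 and the final
# value for a capped byte is generally no longer 10.0.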

# ==========================================
# 2) Improved Emergent Threshold Layer with Numerical Stability
# ==========================================
class ImprovedEmergentThresholdLayer(nn.Module):
    def __init__(self, feature_dim):
        super().__init__()
        self.feature_dim = feature_dim
        self.norm = nn.LayerNorm(feature_dim)
        self.register_buffer('running_mean', torch.zeros(feature_dim))
        self.register_buffer('running_var', torch.ones(feature_dim))
        self.adaptive_threshold = nn.Parameter(torch.ones(1) * 0.5)
        self.momentum = 0.01

    def forward(self, x):
        x_norm = self.norm(x)
        if self.training:
            with torch.no_grad():
                batch_mean = x_norm.mean(dim=(0, 1))
                batch_var = x_norm.var(dim=(0, 1), unbiased=False)
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var

        # More robust threshold calculation with clamping to prevent extremely small values
        threshold = torch.sigmoid(self.adaptive_threshold) * torch.sqrt(torch.clamp(self.running_var, min=1e-6))

        # Soft gate on the margin between |x_norm| and the threshold (unit temperature)
        gate = torch.sigmoid(torch.abs(x_norm) - threshold.view(1, 1, -1))

        alpha = torch.sigmoid(self.adaptive_threshold)

        # Clip outputs to prevent extreme values
        return torch.clamp(alpha * (gate * x) + (1 - alpha) * x, min=-100, max=100)
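
# Worked example of the gate at initialization (illustrative arithmetic, assuming
# adaptive_threshold = 0.5 and running_var = 1 as set in __init__):
#   alpha = threshold = sigmoid(0.5) ~= 0.62
#   x_norm = 0    -> gate = sigmoid(0 - 0.62) ~= 0.35 -> output ~= (0.62 * 0.35 + 0.38) * x ~= 0.59 * x
#   |x_norm| = 3  -> gate = sigmoid(3 - 0.62) ~= 0.92 -> output ~= (0.62 * 0.92 + 0.38) * x ~= 0.95 * x
# Small-magnitude features are therefore attenuated rather than zeroed, since the
# (1 - alpha) * x path always passes part of the input through, while
# large-magnitude features pass almost unchanged.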

# ==========================================
# 3) Thresholded Attention Mechanism
# ==========================================
class ThresholdedAttention(nn.Module):
    def __init__(self, embed_dim, n_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads

        # Standard attention projections
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # Attention score normalization
        self.attn_scale = nn.Parameter(torch.ones(1) * (1.0 / math.sqrt(self.head_dim)))

        # Threshold parameters for attention scores
        self.register_buffer('score_running_mean', torch.zeros(n_heads))
        self.register_buffer('score_running_var', torch.ones(n_heads))
        self.score_threshold = nn.Parameter(torch.ones(1) * 0.5)
        self.score_momentum = 0.01
        self.temperature = 1.0

    def forward(self, x, attn_mask=None):
        B, T, C = x.size()

        # Project to queries, keys, values
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # B, H, T, D
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # B, H, T, D
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # B, H, T, D

        # Compute scaled attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.attn_scale  # B, H, T, T

        # Apply causal mask if provided
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask, float('-inf'))

        # Track per-head statistics of the unmasked attention scores
        if self.training:
            with torch.no_grad():
                # Exclude causally masked (-inf) entries from the statistics.
                # They are zeroed with masked_fill rather than multiplied by the mask,
                # because -inf * 0 evaluates to NaN and would poison the running buffers.
                valid_mask = torch.isfinite(scores)
                if valid_mask.any():
                    count = valid_mask.sum(dim=(0, 2, 3)).clamp(min=1)

                    # Get head-wise mean and variance
                    score_mean = scores.masked_fill(~valid_mask, 0.0).sum(dim=(0, 2, 3)) / count
                    centered = (scores - score_mean.view(1, -1, 1, 1)).masked_fill(~valid_mask, 0.0)
                    score_var = (centered ** 2).sum(dim=(0, 2, 3)) / count

                    # Update running statistics
                    self.score_running_mean = (1 - self.score_momentum) * self.score_running_mean + self.score_momentum * score_mean
                    self.score_running_var = (1 - self.score_momentum) * self.score_running_var + self.score_momentum * score_var

        # Calculate adaptive threshold for attention scores
        threshold_value = torch.sigmoid(self.score_threshold) * torch.sqrt(torch.clamp(self.score_running_var, min=1e-6))

        # Hard-mask the visible scores that fall below the per-head threshold.
        # Entries already at -inf (from the causal mask) are left untouched.
        # Note: the comparison is not differentiable, so score_threshold receives no gradient from it.
        mask = torch.isfinite(scores) & (scores < threshold_value.view(1, -1, 1, 1))
        # Fill with a large negative finite value rather than -inf, so a row whose
        # visible scores are all suppressed still yields a valid softmax.
        scores = scores.masked_fill(mask, -1e4)

        # Apply softmax to get attention weights
        attn_weights = F.softmax(scores, dim=-1)

        # Apply attention to values
        attn_output = torch.matmul(attn_weights, v)  # B, H, T, D

        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(attn_output)

    # Method to handle compatibility with original MultiheadAttention
    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Map old MHA parameters to new ThresholdedAttention parameters
        if f"{prefix}in_proj_weight" in state_dict:
            # MultiheadAttention uses a single in_proj_weight that combines q,k,v
            in_proj_weight = state_dict.pop(f"{prefix}in_proj_weight")
            in_proj_bias = state_dict.pop(f"{prefix}in_proj_bias", None)

            # Split the in_proj_weight into q, k, v parts
            q_weight, k_weight, v_weight = in_proj_weight.chunk(3, dim=0)
            state_dict[f"{prefix}q_proj.weight"] = q_weight
            state_dict[f"{prefix}k_proj.weight"] = k_weight
            state_dict[f"{prefix}v_proj.weight"] = v_weight

            if in_proj_bias is not None:
                q_bias, k_bias, v_bias = in_proj_bias.chunk(3, dim=0)
                state_dict[f"{prefix}q_proj.bias"] = q_bias
                state_dict[f"{prefix}k_proj.bias"] = k_bias
                state_dict[f"{prefix}v_proj.bias"] = v_bias

        # out_proj.weight and out_proj.bias use the same names in both layouts,
        # so they need no remapping

        # Call parent class method to handle the rest
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

# ==========================================
# 4) Improved Transformer Block with Thresholded Attention
# ==========================================
class ImprovedTransformerBlock(nn.Module):
    def __init__(self, embed_dim, n_heads):
        super().__init__()
        self.attention = ThresholdedAttention(embed_dim, n_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            ImprovedEmergentThresholdLayer(4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim)
        )
        self.threshold1 = ImprovedEmergentThresholdLayer(embed_dim)
        self.threshold2 = ImprovedEmergentThresholdLayer(embed_dim)

    def forward(self, x):
        B, T, E = x.size()
        causal_mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        attn_out = self.attention(x, attn_mask=causal_mask)
        x = x + self.threshold1(attn_out)
        ff_out = self.feed_forward(x)
        x = x + self.threshold2(ff_out)
        return x

# ==========================================
# 5) Improved Byte Transformer
# ==========================================
class ImprovedByteTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, n_heads=4, n_layers=4, block_size=128):
        super().__init__()
        self.block_size = block_size
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embedding = nn.Embedding(self.block_size, embed_dim)
        self.blocks = nn.ModuleList([
            ImprovedTransformerBlock(embed_dim, n_heads)
            for _ in range(n_layers)
        ])
        self.final_threshold = ImprovedEmergentThresholdLayer(embed_dim)
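        # Output projection to vocabulary logits (despite the name, this is a Linear layer, not a LayerNorm)
        self.ln_f = nn.Linear(embed_dim, vocab_size)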
        # Learned gating parameter for combining memory outputs
        self.gate_param = nn.Parameter(torch.tensor(0.0))

    def forward_with_embeddings(self, x_emb):
        for block in self.blocks:
            x_emb = block(x_emb)
        x_emb = self.final_threshold(x_emb)
        logits = self.ln_f(x_emb)
        return logits

    def forward_with_two_memory(self, x_emb, memory_module2):
        """
        Extended forward pass:
          1. Run transformer blocks on x_emb.
          2. Apply the transformer's final threshold.
          3. Process the result with a second memory module.
          4. Combine the result of memory_module2 and the original x_emb using a gated combination.
          5. Apply the final threshold on the combined representation.
          6. Project to logits.
        """
        transformer_out = x_emb
        for block in self.blocks:
            transformer_out = block(transformer_out)
        transformer_out = self.final_threshold(transformer_out)
        mem_out2 = memory_module2(transformer_out)
        # Gated combination instead of simple addition:
        alpha = torch.sigmoid(self.gate_param)  # Learned gating weight in [0, 1]
        combined = alpha * mem_out2 + (1 - alpha) * x_emb
        final_emb = self.final_threshold(combined)
        logits = self.ln_f(final_emb)
        return logits

    def forward(self, x):
        B, T = x.size()
        token_emb = self.token_embedding(x)
        positions = torch.arange(T, device=x.device).unsqueeze(0)
        pos_emb = self.pos_embedding(positions)
        x_emb = token_emb + pos_emb
        return self.forward_with_embeddings(x_emb)

# ==========================================
# 6) Memory Module (Original)
# ==========================================
class MemoryModule(nn.Module):
    def __init__(self, embed_dim, n_layers=8, expansion_factor=4):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            layer = nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, embed_dim * expansion_factor),
                nn.GELU(),
                nn.Linear(embed_dim * expansion_factor, embed_dim),
                nn.Dropout(0.1)
            )
            self.layers.append(layer)
        self.final_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        out = x
        for layer in self.layers:
            out = out + layer(out)
        out = self.final_norm(out)
        return out

# ==========================================
# 7) Weighted Cross-Entropy Loss with KL Divergence Constraint
# ==========================================
class WeightedCrossEntropyLoss(nn.Module):
    def __init__(self, token_frequencies, alpha=0.5, kl_lambda=0.1):
        super().__init__()
        self.register_buffer('weights', compute_weights_from_frequencies(token_frequencies, alpha))
        self.kl_lambda = kl_lambda

    def forward(self, logits, targets, mask=None):
        """
        Compute weighted cross-entropy loss with a KL-style regularization proxy.

        Args:
            logits: Model output logits of shape [B*T, C]
            targets: Target indices of shape [B*T]
            mask: Optional mask for padding of shape [B*T]

        Returns:
            total_loss: Weighted CE plus kl_lambda times the regularization proxy
            weighted_ce: Frequency-weighted cross-entropy
            ce_diff: Absolute gap between weighted and unweighted CE (the proxy term)
        """
        # Per-token standard (unweighted) CE loss
        standard_ce = F.cross_entropy(logits, targets, reduction='none')

        # Scale each token's loss by the weight of its target class
        weighted_ce = standard_ce * self.weights[targets]

        # Apply mask if provided
        if mask is not None:
            standard_ce = standard_ce * mask
            weighted_ce = weighted_ce * mask

            # Normalize by sum of mask
            mask_sum = mask.sum() + 1e-9
            standard_ce = standard_ce.sum() / mask_sum
            weighted_ce = weighted_ce.sum() / mask_sum
        else:
            standard_ce = standard_ce.mean()
            weighted_ce = weighted_ce.mean()

        # Heuristic proxy for a KL constraint: penalize the gap between the weighted
        # and unweighted losses. This is not a true KL divergence between two
        # distributions; it only discourages the weighted objective from drifting
        # far from the standard one.
        ce_diff = torch.abs(weighted_ce - standard_ce)

        # Total loss with the proxy term as regularization
        total_loss = weighted_ce + self.kl_lambda * ce_diff

        return total_loss, weighted_ce, ce_diff

# ==========================================
# 7.1) Pre-training Evaluation Function
# ==========================================
@torch.no_grad()
def estimate_loss_pretrain(main_model, memory1, memory2, train_bytes, val_bytes, weighted_loss_fn=None):
    """
    Estimate loss on training and validation byte data.
    This version works directly with byte tensors.
    """
    out = {}
    main_model.eval()
    memory1.eval()
    memory2.eval()

    for split, byte_tensor in [('train', train_bytes), ('val', val_bytes)]:
        losses = torch.zeros(hyperparams['eval_iters'])

        for k in range(hyperparams['eval_iters']):
            # Get batches directly from byte tensors
            try:
                inputs, targets = prepare_byte_batches(
                    byte_tensor,
                    hyperparams['block_size'],
                    hyperparams['batch_size'],
                    device
                )

                # Forward pass
                B, T = inputs.shape
                token_emb = main_model.token_embedding(inputs)
                pos_emb = main_model.pos_embedding(torch.arange(T, device=device).unsqueeze(0))
                combined_emb = token_emb + pos_emb

                mem_out1 = memory1(combined_emb)
                logits = main_model.forward_with_two_memory(mem_out1, memory2)

                # Calculate loss (only on non-padding tokens)
                B, T, C = logits.shape
                logits_flat = logits.view(B * T, C)
                targets_flat = targets.view(B * T)

                # Create mask for non-padding tokens
                mask = (targets_flat != 0).float()

                # Use weighted loss if provided, otherwise standard CE
                if weighted_loss_fn is not None and hyperparams['use_weighted_loss']:
                    loss, _, _ = weighted_loss_fn(logits_flat, targets_flat, mask)
                else:
                    # Compute loss only on non-padding tokens with standard CE
                    loss = F.cross_entropy(logits_flat, targets_flat, reduction='none')
                    loss = (loss * mask).sum() / (mask.sum() + 1e-9)

                losses[k] = loss.item()
            except Exception as e:
                print(f"Error during evaluation: {e}")
                losses[k] = float('inf')  # Use a large value to indicate error

        # Use median instead of mean to be more robust to outliers/errors.
        # isfinite drops both the inf sentinel from failed iterations and any NaN losses.
        valid_losses = losses[torch.isfinite(losses)]
        if len(valid_losses) > 0:
            out[split] = valid_losses.median().item()
        else:
            out[split] = float('inf')

    main_model.train()
    memory1.train()
    memory2.train()
    return out

# ==========================================
# 8) Generate Text from Trained Model
# ==========================================
@torch.no_grad()
def generate_from_prompt(main_model, memory1, memory2, prompt_text=None, max_new_tokens=200, top_p=None):
    if prompt_text is None:
        prompt_text = hyperparams['start_prompt']

    # Use hyperparameter value if top_p not specified
    if top_p is None:
        top_p = hyperparams['top_p']

    # Apply system prompt to user prompt
    system_prompt = hyperparams['system_prompt']
    full_prompt = f"{system_prompt}\n\nQuestion: {prompt_text}"

    # Convert prompt to bytes (full_prompt is always a str at this point)
    prompt_bytes = full_prompt.encode('utf-8')

    main_model.eval()
    memory1.eval()
    memory2.eval()

    # Create context from prompt
    context = torch.tensor([b for b in prompt_bytes], dtype=torch.long, device=device).unsqueeze(0)

    # Add BOS token to start the response generation
    bos_token = torch.tensor([[hyperparams['bos_token']]], dtype=torch.long, device=device)
    context = torch.cat([context, bos_token], dim=1)

    generated = []
    eos_found = False

    for _ in range(max_new_tokens):
        if eos_found:
            break

        x_cond = context[:, -hyperparams['block_size']:] if context.size(1) > hyperparams['block_size'] else context
        B, T = x_cond.shape
        token_emb = main_model.token_embedding(x_cond)
        pos_emb = main_model.pos_embedding(torch.arange(T, device=x_cond.device).unsqueeze(0))
        combined_emb = token_emb + pos_emb

        mem_out1 = memory1(combined_emb)
        logits = main_model.forward_with_two_memory(mem_out1, memory2)

        # Get next token distribution with top-p (nucleus) sampling
        logits = logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)

        # Sort probabilities in descending order
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)

        # Compute cumulative probabilities
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # Find indices where cumulative probability exceeds top_p
        sorted_indices_to_remove = cumulative_probs > top_p

        # Shift to create first index (0) as False to always keep at least one token
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Create mask for indices to remove
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

        # Filter logits
        filtered_logits = logits.clone()
        filtered_logits[indices_to_remove] = -float('inf')

        # Get probabilities from filtered logits
        filtered_probs = F.softmax(filtered_logits, dim=-1)

        # Sample from the filtered distribution
        next_token = torch.multinomial(filtered_probs, num_samples=1)
        next_token_value = next_token.item()

        # Check for EOS token
        if next_token_value == hyperparams['eos_token']:
            eos_found = True

        generated.append(next_token_value)
        context = torch.cat([context, next_token], dim=1)

    # Combine context with generated bytes and return as bytes object
    result_bytes = bytes(context.view(-1).tolist())

    # Clean up special tokens when returning result
    try:
        # Convert to list for easier manipulation
        byte_list = list(result_bytes)

        # Find all BOS tokens and remove them
        while hyperparams['bos_token'] in byte_list:
            byte_list.remove(hyperparams['bos_token'])

        # Find all EOS tokens and remove everything after the first one
        if hyperparams['eos_token'] in byte_list:
            eos_index = byte_list.index(hyperparams['eos_token'])
            byte_list = byte_list[:eos_index]

        # Convert back to bytes
        cleaned_bytes = bytes(byte_list)
        return cleaned_bytes
    except Exception:
        # If any error in cleaning, return the original bytes
        return result_bytes

# ==========================================
# 9) Pre-training Implementation
# ==========================================
def pretrain(continue_training=True):
    """Pre-train the model on Omega Books corpus with causal language modeling."""
    # Load Omega Books data as bytes
    _, _, train_bytes_tensor, val_bytes_tensor = load_omega_books_data_as_bytes()

    # Create models
    main_model = ImprovedByteTransformer(
        vocab_size=hyperparams['vocab_size'],
        embed_dim=hyperparams['embed_dim'],
        n_heads=hyperparams['n_heads'],
        n_layers=hyperparams['n_layers'],
        block_size=hyperparams['block_size']
    ).to(device)

    memory1 = MemoryModule(
        embed_dim=hyperparams['embed_dim'],
        n_layers=hyperparams['memory_n_layers'],
        expansion_factor=4
    ).to(device)

    memory2 = MemoryModule(
        embed_dim=hyperparams['embed_dim'],
        n_layers=hyperparams['memory_n_layers'],
        expansion_factor=4
    ).to(device)

    # Calculate model size
    num_params = sum(p.numel() for p in main_model.parameters() if p.requires_grad)
    num_params += sum(p.numel() for p in memory1.parameters() if p.requires_grad)
    num_params += sum(p.numel() for p in memory2.parameters() if p.requires_grad)
    print(f"Total trainable parameters: {num_params:,}")

    # Compute token frequencies and initialize weighted loss if enabled
    weighted_loss_fn = None
    if hyperparams['use_weighted_loss']:
        # Compute token frequencies from training data
        token_frequencies = compute_token_frequencies(
            train_bytes_tensor,
            vocab_size=hyperparams['vocab_size']
        )

        # Initialize weighted loss
        weighted_loss_fn = WeightedCrossEntropyLoss(
            token_frequencies=token_frequencies,
            alpha=hyperparams['alpha'],
            kl_lambda=hyperparams['kl_lambda']
        ).to(device)

        print(f"Using weighted cross-entropy loss with alpha={hyperparams['alpha']}, kl_lambda={hyperparams['kl_lambda']}")
    else:
        print("Using standard cross-entropy loss")

    # Optimizer setup
    group1_params = list(main_model.parameters()) + list(memory1.parameters())
    group2_params = list(memory2.parameters())
    base_lr = 3e-4
    optimizer = torch.optim.AdamW([
        {'params': group1_params, 'lr': base_lr},
        {'params': group2_params, 'lr': base_lr}
    ], betas=(0.9, 0.95), weight_decay=0.1)

    start_epoch = 0
    best_val_loss = float('inf')

    # Load checkpoint if continuing training
    if continue_training and os.path.exists(hyperparams['checkpoint_path']):
        try:
            print(f"Loading checkpoint from {hyperparams['checkpoint_path']}...")
            checkpoint = torch.load(hyperparams['checkpoint_path'], map_location=device)

            try:
                # Try to load model states directly
                main_model.load_state_dict(checkpoint['main_model_state'], strict=False)
                memory1.load_state_dict(checkpoint['memory1_state'])
                if 'memory2_state' in checkpoint:
                    memory2.load_state_dict(checkpoint['memory2_state'])
                if 'optimizer_state' in checkpoint:
                    optimizer.load_state_dict(checkpoint['optimizer_state'])
                start_epoch = checkpoint.get('epoch', 0)
                best_val_loss = checkpoint.get('val_loss', float('inf'))
                print(f"Checkpoint loaded. Resuming from epoch {start_epoch}.")
            except Exception as e:
                print(f"Error loading checkpoint directly: {e}")
                print("Starting pre-training from scratch.")
        except Exception as e:
            print(f"Failed to load checkpoint: {e}")
            print("Starting pre-training from scratch.")
    else:
        print("Starting pre-training from scratch.")

    # Training setup
    grad_clip = 1.0
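    # The schedule is indexed by optimizer updates, which happen once every
    # `accumulation_steps` micro-steps, so count both totals in updates.
    updates_per_epoch = hyperparams['steps_per_epoch'] // hyperparams['accumulation_steps']
    total_steps = hyperparams['num_epochs'] * updates_per_epoch
    current_step = start_epoch * updates_per_epoch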

    # Learning rate scheduler
    def get_lr(step, warmup_steps=2000, base_lr=base_lr, min_lr=1e-5):
        # Learning rate schedule with warmup and cosine decay
        if step < warmup_steps:
            return base_lr * step / warmup_steps
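        # Guard against a non-positive decay span and clamp progress to [0, 1]
        # so the cosine never swings back up past the end of training
        decay_steps = max(1, total_steps - warmup_steps)
        progress = min(1.0, (step - warmup_steps) / decay_steps)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        return min_lr + (base_lr - min_lr) * cosine_decay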

    print("Starting pre-training on Omega Books corpus...")
    for epoch in range(start_epoch, hyperparams['num_epochs']):
        print(f"\n--- Epoch {epoch+1}/{hyperparams['num_epochs']} ---")

        for step in range(hyperparams['steps_per_epoch']):
            # Periodic evaluation
            if step % hyperparams['eval_interval'] == 0:
                losses = estimate_loss_pretrain(main_model, memory1, memory2, train_bytes_tensor, val_bytes_tensor, weighted_loss_fn)
                print(f"Step {step}, train loss: {losses['train']:.4f}, val loss: {losses['val']:.4f}")

                # Save best model
                if losses['val'] < best_val_loss:
                    best_val_loss = losses['val']
                    torch.save({
                        'main_model_state': main_model.state_dict(),
                        'memory1_state': memory1.state_dict(),
                        'memory2_state': memory2.state_dict(),
                        'optimizer_state': optimizer.state_dict(),
                        'epoch': epoch,
                        'val_loss': best_val_loss
                    }, hyperparams['checkpoint_path'].replace('.pt', '_best.pt'))
                    print(f"New best model saved! Val loss: {best_val_loss:.4f}")

            # Get batches for this step directly from byte tensors
            inputs, targets = prepare_byte_batches(
                train_bytes_tensor,
                hyperparams['block_size'],
                hyperparams['batch_size'],
                device
            )

            # Zero gradients
            if step % hyperparams['accumulation_steps'] == 0:
                optimizer.zero_grad()

            # Forward pass
            B, T = inputs.shape
            token_emb = main_model.token_embedding(inputs)
            pos_emb = main_model.pos_embedding(torch.arange(T, device=device).unsqueeze(0))
            combined_emb = token_emb + pos_emb

            mem_out1 = memory1(combined_emb)
            logits = main_model.forward_with_two_memory(mem_out1, memory2)

            # Calculate loss
            B, T, C = logits.shape
            logits_flat = logits.view(B * T, C)
            targets_flat = targets.view(B * T)
            mask = (targets_flat != 0).float()

            # Use weighted loss if enabled
            if weighted_loss_fn is not None and hyperparams['use_weighted_loss']:
                loss, weighted_ce, ce_diff = weighted_loss_fn(logits_flat, targets_flat, mask)
            else:
                # Standard CE loss
                loss = F.cross_entropy(logits_flat, targets_flat, reduction='none')
                loss = (loss * mask).sum() / (mask.sum() + 1e-9)

            # Scale loss for gradient accumulation
            scaled_loss = loss / hyperparams['accumulation_steps']
            scaled_loss.backward()

            # Check for NaN or Inf gradients in all three modules
            has_nan_inf = False
            for module in (main_model, memory1, memory2):
                for param in module.parameters():
                    if param.grad is not None and not torch.isfinite(param.grad).all():
                        has_nan_inf = True
                        param.grad = torch.zeros_like(param.grad)

            if has_nan_inf:
                print(f"NaN or Inf gradients detected and zeroed at step {step}")

            # Apply optimizer step
            if (step + 1) % hyperparams['accumulation_steps'] == 0:
                # Update learning rate
                lr = get_lr(current_step)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

                # Clip gradients
                torch.nn.utils.clip_grad_norm_(main_model.parameters(), grad_clip)
                torch.nn.utils.clip_grad_norm_(memory1.parameters(), grad_clip)
                torch.nn.utils.clip_grad_norm_(memory2.parameters(), grad_clip)

                optimizer.step()
                current_step += 1

        # Generate sample at end of epoch
        try:
            print("\nGenerating sample text...")
            sample_text = generate_from_prompt(
                main_model, memory1, memory2,
                prompt_text=hyperparams['start_prompt'],
                max_new_tokens=256
            )
            # Try to decode the bytes to show readable text
            try:
                decoded_text = sample_text.decode('utf-8', errors='replace')
                print(f"Sample: {decoded_text[:500]}")
            except Exception:
                print(f"Sample (raw bytes, could not decode): {sample_text[:200]}")
        except Exception as e:
            print(f"Error generating sample: {e}")

        # End of epoch checkpoint
        torch.save({
            'main_model_state': main_model.state_dict(),
            'memory1_state': memory1.state_dict(),
            'memory2_state': memory2.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'epoch': epoch + 1,
            'val_loss': best_val_loss
        }, hyperparams['checkpoint_path'])
        print(f"Checkpoint saved at epoch {epoch+1} to {hyperparams['checkpoint_path']}.")

    print("Pre-training complete!")

# ==========================================
# 10) Script Main Entry Point
# ==========================================
if __name__ == "__main__":
    print("Starting pre-training on Omega Books corpus...")
    pretrain(continue_training=hyperparams['continue_training'])
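
The warmup-plus-cosine learning-rate schedule used by `get_lr` above can be checked in isolation. The following is a minimal sketch, not the notebook's exact function: `lr_at` is a hypothetical standalone name, the schedule is indexed by optimizer updates rather than micro-steps, and the run size (120 epochs of 1000 micro-steps with 8 accumulation steps) is an assumption chosen only for illustration.

```python
import math

def lr_at(step, total_steps, warmup_steps=2000, base_lr=3e-4, min_lr=1e-5):
    """Linear warmup to base_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    # Clamp progress to [0, 1] so steps past the end stay at min_lr
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (base_lr - min_lr) * cosine_decay

# Assumed run size: one optimizer update per 8 accumulated micro-steps
total_updates = 120 * 1000 // 8
for s in (0, 1000, 2000, total_updates // 2, total_updates):
    print(f"update {s:>6}: lr = {lr_at(s, total_updates):.2e}")
```

Printing a few points like this is a quick way to confirm that the peak rate is reached exactly at the end of warmup and that the rate has decayed to `min_lr` by the final update.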

In [None]:
## pretraining V2.2 (sft dataset) fixed ce loss

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import pandas as pd
from sklearn.model_selection import train_test_split

# ==========================================
# 1) Hyperparameters (same as original with additions)
# ==========================================
hyperparams = {
    # Model Architecture
    'block_size': 1024,               # Sequence length for context
    'batch_size': 2,                  # Batch size
    'embed_dim': 1024,                # Transformer embedding dimension
    'n_heads': 16,                    # Number of attention heads
    'n_layers': 24,                   # Number of Transformer blocks
    'memory_n_layers': 8,             # Number of layers in the original MemoryModule
    'vocab_size': 256,                # Fixed vocabulary size for byte tokenization

    # Training Parameters
    'num_epochs': 120,                 # Number of epochs
    'steps_per_epoch': 1000,          # Steps per epoch
    'eval_interval': 200,             # Steps between loss evaluations
    'eval_iters': 100,                # Iterations to average validation loss
    'accumulation_steps': 8,          # Number of steps to accumulate gradients over

    # Generation Parameters
    'generate_num_tokens': 2048,      # Number of tokens to generate after each epoch
    'top_p': 0.8,                     # Top-p (nucleus) sampling parameter
    'start_prompt': "Explain why the statement 'I wore my lucky socks today, and I got an A on my test, so my socks must be lucky' is a logical fallacy.",

    # Special Tokens & Tags
    'thinking_tag': "<think>",        # Opening tag for thinking process
    'thinking_end_tag': "</think>",   # Closing tag for thinking process
    'answer_tag': "<answer>",         # Opening tag for final answer
    'answer_end_tag': "</answer>",    # Closing tag for final answer
    'bos_token': 254,                 # Beginning-of-sequence token (byte value)
    'eos_token': 255,                 # End-of-sequence token (byte value)

    # File Paths & Modes
    'checkpoint_path': "threshold_transformer_checkpoint.pt",  # Single unified checkpoint
    'mode': 'pretrain',               # Force pretrain mode
    'continue_training': True,        # Whether to continue training from a checkpoint
    'system_prompt': """just think before answer."""
}

# ==========================================
# 1.1) Select device
# ==========================================
device = "mps" if torch.backends.mps.is_available() else \
         ("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ==========================================
# 1.2) Data Loading and Preprocessing for COT Logic Reasoning
# ==========================================
def load_cot_logic_data():
    print("Loading COT Logic Reasoning dataset...")

    try:
        # Try standard pandas read_parquet first
        df = pd.read_parquet("isaiahbjork/cot-logic-reasoning/cot-logic-reasoning.parquet")
        print("Dataset loaded using standard path")
    except Exception as e:
        print(f"Error loading dataset with standard path: {e}")
        try:
            # Fall back to the datasets library if it is installed
            from datasets import load_dataset
            dataset = load_dataset("isaiahbjork/cot-logic-reasoning")
            df = dataset["train"].to_pandas()
            print("Dataset loaded using datasets library")
        except Exception as e2:
            print(f"Error loading dataset with datasets library: {e2}")
            try:
                # Last resort: read the parquet file through the hf:// protocol
                df = pd.read_parquet("hf://datasets/isaiahbjork/cot-logic-reasoning/cot-logic-reasoning.parquet")
                print("Dataset loaded using hf:// protocol")
            except Exception as e3:
                print(f"Failed to load dataset: {e3}")
                raise RuntimeError("Unable to load the COT Logic Reasoning dataset") from e3

    print(f"Data size: {len(df)}")

    # Split into train/validation/test sets (80/10/10)
    train_df, temp_df = train_test_split(df, test_size=0.2, random_state=42)
    val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

    print(f"Training examples: {len(train_df)}")
    print(f"Validation examples: {len(val_df)}")
    print(f"Test examples: {len(test_df)}")

    return train_df, val_df, test_df

# ==========================================
# 1.3) Prepare data for pre-training (continuous text with BOS/EOS tokens)
# ==========================================
def prepare_pretraining_batches_from_cot(data_df, block_size=1024):
    """Create pre-training batches from COT corpus as continuous text for next-token prediction,
    but with BOS/EOS tokens around answers."""

    # Sample row positions with replacement; convert to a plain list for DataFrame.iloc
    batch_indices = torch.randint(0, len(data_df), (hyperparams['batch_size'],)).tolist()
    batch_examples = data_df.iloc[batch_indices]

    sequences = []

    for _, row in batch_examples.iterrows():
        # Get prompt and response
        prompt = row['prompt']
        response = row['response']

        # Add system prompt to the user prompt
        system_prompt = hyperparams['system_prompt']
        full_prompt = f"{system_prompt}\n\nQuestion: {prompt}"

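        # Insert BOS before response and EOS after response
        # But no formatting tags - treat as continuous text.
        # The special tokens are appended as raw byte values: chr(254) and chr(255)
        # would each be UTF-8 encoded as two bytes (0xC3 0xBE and 0xC3 0xBF), so the
        # model would never see token ids 254 or 255.
        byte_seq = (
            list(full_prompt.encode('utf-8'))
            + [hyperparams['bos_token']]
            + list(response.encode('utf-8'))
            + [hyperparams['eos_token']]
        )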

        # Truncate or pad to block_size
        if len(byte_seq) > block_size:
            # Random offset for diverse training
            # randint's upper bound is exclusive; +1 lets the final window be sampled too
            start_idx = torch.randint(0, len(byte_seq) - block_size + 1, (1,)).item()
            byte_seq = byte_seq[start_idx:start_idx + block_size]
        else:
            byte_seq = byte_seq + [0] * (block_size - len(byte_seq))

        sequences.append(byte_seq)

    # Convert to tensor
    x = torch.tensor(sequences, dtype=torch.long).to(device)

    # Create targets by shifting input by 1 position
    y = torch.zeros_like(x)
    y[:, :-1] = x[:, 1:]
    # The last position has no next token; it stays 0 (padding) and is masked out of the loss

    return x, y

# ==========================================
# 2) Improved Emergent Threshold Layer with Numerical Stability
# ==========================================
class ImprovedEmergentThresholdLayer(nn.Module):
    def __init__(self, feature_dim):
        super().__init__()
        self.feature_dim = feature_dim
        self.norm = nn.LayerNorm(feature_dim)
        self.register_buffer('running_mean', torch.zeros(feature_dim))
        self.register_buffer('running_var', torch.ones(feature_dim))
        self.adaptive_threshold = nn.Parameter(torch.ones(1) * 0.5)
        self.momentum = 0.01

    def forward(self, x):
        x_norm = self.norm(x)
        if self.training:
            with torch.no_grad():
                batch_mean = x_norm.mean(dim=(0, 1))
                batch_var = x_norm.var(dim=(0, 1), unbiased=False)
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var

        # More robust threshold calculation with clamping to prevent extremely small values
        threshold = torch.sigmoid(self.adaptive_threshold) * torch.sqrt(torch.clamp(self.running_var, min=1e-6))

        # Increase denominator from 0.1 to 1.0 for stability
        gate = torch.sigmoid((torch.abs(x_norm) - threshold.view(1, 1, -1)) / 1.0)

        alpha = torch.sigmoid(self.adaptive_threshold)

        # Clip outputs to prevent extreme values
        return torch.clamp(alpha * (gate * x) + (1 - alpha) * x, min=-100, max=100)

# ==========================================
# 3) Thresholded Attention Mechanism
# ==========================================
class ThresholdedAttention(nn.Module):
    def __init__(self, embed_dim, n_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads

        # Standard attention projections
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # Attention score normalization
        self.attn_scale = nn.Parameter(torch.ones(1) * (1.0 / math.sqrt(self.head_dim)))

        # Threshold parameters for attention scores
        self.register_buffer('score_running_mean', torch.zeros(n_heads))
        self.register_buffer('score_running_var', torch.ones(n_heads))
        self.score_threshold = nn.Parameter(torch.ones(1) * 0.5)
        self.score_momentum = 0.01
        self.temperature = 1.0

    def forward(self, x, attn_mask=None):
        B, T, C = x.size()

        # Project to queries, keys, values
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # B, H, T, D
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # B, H, T, D
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # B, H, T, D

        # Compute scaled attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.attn_scale  # B, H, T, T

        # Apply causal mask if provided
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask, float('-inf'))

        # Apply thresholding to attention scores
        if self.training:
            with torch.no_grad():
                # Compute statistics of attention scores across batch and tokens.
                # Masked positions hold -inf; zero them out before summing, because
                # multiplying -inf by a False mask entry yields NaN, not 0.
                valid_mask = torch.isfinite(scores)
                if valid_mask.any():
                    # Get head-wise mean and variance over valid positions only
                    valid_count = valid_mask.sum(dim=(0, 2, 3)).clamp(min=1)
                    safe_scores = scores.masked_fill(~valid_mask, 0.0)
                    score_mean = safe_scores.sum(dim=(0, 2, 3)) / valid_count
                    sq_dev = (safe_scores - score_mean.view(1, -1, 1, 1)) ** 2
                    score_var = sq_dev.masked_fill(~valid_mask, 0.0).sum(dim=(0, 2, 3)) / valid_count

                    # Update running statistics
                    self.score_running_mean = (1 - self.score_momentum) * self.score_running_mean + self.score_momentum * score_mean
                    self.score_running_var = (1 - self.score_momentum) * self.score_running_var + self.score_momentum * score_var

        # Calculate adaptive threshold for attention scores
        threshold_value = torch.sigmoid(self.score_threshold) * torch.sqrt(torch.clamp(self.score_running_var, min=1e-6))

        # Build a hard mask of scores that fall below the per-head threshold.
        # Positions already set to -inf by the causal mask are excluded so they stay -inf.
        mask = (~torch.isinf(scores)) & (scores < threshold_value.view(1, -1, 1, 1))
        # Use a large negative value rather than -inf so that a row whose valid
        # entries are all thresholded still produces a finite softmax
        scores = scores.masked_fill(mask, -1e4)

        # Apply softmax to get attention weights
        attn_weights = F.softmax(scores, dim=-1)

        # Apply attention to values
        attn_output = torch.matmul(attn_weights, v)  # B, H, T, D

        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(attn_output)

    # Method to handle compatibility with original MultiheadAttention
    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Map old MHA parameters to new ThresholdedAttention parameters
        if f"{prefix}in_proj_weight" in state_dict:
            # MultiheadAttention uses a single in_proj_weight that combines q,k,v
            in_proj_weight = state_dict.pop(f"{prefix}in_proj_weight")
            in_proj_bias = state_dict.pop(f"{prefix}in_proj_bias", None)

            # Split the in_proj_weight into q, k, v parts
            q_weight, k_weight, v_weight = in_proj_weight.chunk(3, dim=0)
            state_dict[f"{prefix}q_proj.weight"] = q_weight
            state_dict[f"{prefix}k_proj.weight"] = k_weight
            state_dict[f"{prefix}v_proj.weight"] = v_weight

            if in_proj_bias is not None:
                q_bias, k_bias, v_bias = in_proj_bias.chunk(3, dim=0)
                state_dict[f"{prefix}q_proj.bias"] = q_bias
                state_dict[f"{prefix}k_proj.bias"] = k_bias
                state_dict[f"{prefix}v_proj.bias"] = v_bias

        # out_proj.weight / out_proj.bias use the same key names in
        # nn.MultiheadAttention, so they need no remapping

        # Call parent class method to handle the rest
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

# ==========================================
# 4) Improved Transformer Block with Thresholded Attention
# ==========================================
class ImprovedTransformerBlock(nn.Module):
    def __init__(self, embed_dim, n_heads):
        super().__init__()
        self.attention = ThresholdedAttention(embed_dim, n_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            ImprovedEmergentThresholdLayer(4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim)
        )
        self.threshold1 = ImprovedEmergentThresholdLayer(embed_dim)
        self.threshold2 = ImprovedEmergentThresholdLayer(embed_dim)

    def forward(self, x):
        B, T, E = x.size()
        causal_mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        attn_out = self.attention(x, attn_mask=causal_mask)
        x = x + self.threshold1(attn_out)
        ff_out = self.feed_forward(x)
        x = x + self.threshold2(ff_out)
        return x

# ==========================================
# 5) Improved Byte Transformer
# ==========================================
class ImprovedByteTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, n_heads=4, n_layers=4, block_size=128):
        super().__init__()
        self.block_size = block_size
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embedding = nn.Embedding(self.block_size, embed_dim)
        self.blocks = nn.ModuleList([
            ImprovedTransformerBlock(embed_dim, n_heads)
            for _ in range(n_layers)
        ])
        self.final_threshold = ImprovedEmergentThresholdLayer(embed_dim)
        self.ln_f = nn.Linear(embed_dim, vocab_size)
        # Learned gating parameter for combining memory outputs
        self.gate_param = nn.Parameter(torch.tensor(0.0))

    def forward_with_embeddings(self, x_emb):
        for block in self.blocks:
            x_emb = block(x_emb)
        x_emb = self.final_threshold(x_emb)
        logits = self.ln_f(x_emb)
        return logits

    def forward_with_two_memory(self, x_emb, memory_module2):
        """
        Extended forward pass:
          1. Run transformer blocks on x_emb.
          2. Apply the transformer's final threshold.
          3. Process the result with a second memory module.
          4. Combine the result of memory_module2 and the original x_emb using a gated combination.
          5. Apply the final threshold on the combined representation.
          6. Project to logits.
        """
        transformer_out = x_emb
        for block in self.blocks:
            transformer_out = block(transformer_out)
        transformer_out = self.final_threshold(transformer_out)
        mem_out2 = memory_module2(transformer_out)
        # Gated combination instead of simple addition:
        alpha = torch.sigmoid(self.gate_param)  # Learned gating weight in [0, 1]
        combined = alpha * mem_out2 + (1 - alpha) * x_emb
        final_emb = self.final_threshold(combined)
        logits = self.ln_f(final_emb)
        return logits

    def forward(self, x):
        B, T = x.size()
        token_emb = self.token_embedding(x)
        positions = torch.arange(T, device=x.device).unsqueeze(0)
        pos_emb = self.pos_embedding(positions)
        x_emb = token_emb + pos_emb
        return self.forward_with_embeddings(x_emb)

# ==========================================
# 6) Memory Module (Original)
# ==========================================
class MemoryModule(nn.Module):
    def __init__(self, embed_dim, n_layers=8, expansion_factor=4):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            layer = nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, embed_dim * expansion_factor),
                nn.GELU(),
                nn.Linear(embed_dim * expansion_factor, embed_dim),
                nn.Dropout(0.1)
            )
            self.layers.append(layer)
        self.final_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        out = x
        for layer in self.layers:
            out = out + layer(out)
        out = self.final_norm(out)
        return out

# ==========================================
# 7) Pre-training Evaluation Function
# ==========================================
@torch.no_grad()
def estimate_loss_pretrain(main_model, memory1, memory2, train_df, val_df):
    out = {}
    main_model.eval()
    memory1.eval()
    memory2.eval()

    for split, df in [('train', train_df), ('val', val_df)]:
        losses = torch.zeros(hyperparams['eval_iters'])
        for k in range(hyperparams['eval_iters']):
            # Get pre-training batches
            inputs, targets = prepare_pretraining_batches_from_cot(df, hyperparams['block_size'])

            # Forward pass
            B, T = inputs.shape
            token_emb = main_model.token_embedding(inputs)
            pos_emb = main_model.pos_embedding(torch.arange(T, device=device).unsqueeze(0))
            combined_emb = token_emb + pos_emb

            mem_out1 = memory1(combined_emb)
            logits = main_model.forward_with_two_memory(mem_out1, memory2)

            # Calculate loss (only on non-padding tokens)
            B, T, C = logits.shape
            logits_flat = logits.view(B * T, C)
            targets_flat = targets.view(B * T)

            # Create mask for non-padding tokens
            mask = (targets_flat != 0).float()

            # Compute loss only on non-padding tokens
            loss = F.cross_entropy(logits_flat, targets_flat, reduction='none')
            masked_loss = (loss * mask).sum() / (mask.sum() + 1e-9)

            losses[k] = masked_loss.item()

        # Return a plain float so it can be compared, formatted, and checkpointed
        out[split] = losses.mean().item()

    main_model.train()
    memory1.train()
    memory2.train()
    return out

# ==========================================
# 8) Generate Text from Trained Model
# ==========================================
@torch.no_grad()
def generate_from_prompt(main_model, memory1, memory2, prompt_text=None, max_new_tokens=200, top_p=None):
    if prompt_text is None:
        prompt_text = hyperparams['start_prompt']

    # Use hyperparameter value if top_p not specified
    if top_p is None:
        top_p = hyperparams['top_p']

    # Apply system prompt to user prompt
    system_prompt = hyperparams['system_prompt']
    full_prompt = f"{system_prompt}\n\nQuestion: {prompt_text}"

    # Convert prompt to bytes (full_prompt is always a str because it is built with an f-string)
    prompt_bytes = full_prompt.encode('utf-8')

    main_model.eval()
    memory1.eval()
    memory2.eval()

    # Create context from prompt
    context = torch.tensor([b for b in prompt_bytes], dtype=torch.long, device=device).unsqueeze(0)

    # Add BOS token to start the response generation
    bos_token = torch.tensor([[hyperparams['bos_token']]], dtype=torch.long, device=device)
    context = torch.cat([context, bos_token], dim=1)

    generated = []
    eos_found = False

    for _ in range(max_new_tokens):
        if eos_found:
            break

        x_cond = context[:, -hyperparams['block_size']:] if context.size(1) > hyperparams['block_size'] else context
        B, T = x_cond.shape
        token_emb = main_model.token_embedding(x_cond)
        pos_emb = main_model.pos_embedding(torch.arange(T, device=x_cond.device).unsqueeze(0))
        combined_emb = token_emb + pos_emb

        mem_out1 = memory1(combined_emb)
        logits = main_model.forward_with_two_memory(mem_out1, memory2)

        # Get next token distribution with top-p (nucleus) sampling
        logits = logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)

        # Sort probabilities in descending order
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)

        # Compute cumulative probabilities
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # Find indices where cumulative probability exceeds top_p
        sorted_indices_to_remove = cumulative_probs > top_p

        # Shift the mask right by one so the token that first crosses top_p is kept,
        # then force index 0 to False so the most likely token always survives
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Create mask for indices to remove
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

        # Filter logits
        filtered_logits = logits.clone()
        filtered_logits[indices_to_remove] = -float('inf')

        # Get probabilities from filtered logits
        filtered_probs = F.softmax(filtered_logits, dim=-1)

        # Sample from the filtered distribution
        next_token = torch.multinomial(filtered_probs, num_samples=1)
        next_token_value = next_token.item()

        # Check for EOS token
        if next_token_value == hyperparams['eos_token']:
            eos_found = True

        generated.append(next_token_value)
        context = torch.cat([context, next_token], dim=1)

    # Collect prompt + generated token ids as a plain list
    token_ids = context.view(-1).tolist()

    # Clean up special tokens: drop every BOS token
    bos_id, eos_id = hyperparams['bos_token'], hyperparams['eos_token']
    token_ids = [t for t in token_ids if t != bos_id]

    # Keep only what precedes the first EOS token, if one was generated
    if eos_id in token_ids:
        token_ids = token_ids[:token_ids.index(eos_id)]

    # bytes() only accepts values in 0..255, so skip any id outside that range
    # instead of raising ValueError when a special-token id is 256 or larger
    return bytes(t for t in token_ids if 0 <= t < 256)
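
# Illustrative sketch (not part of the original pipeline): the nucleus (top-p) filtering
# step from generate_from_prompt as a standalone function, so its behavior can be checked
# on a toy distribution. The name `top_p_filter` is hypothetical, and it works on
# probabilities directly rather than on logits.
import torch

def top_p_filter(probs, top_p):
    """Zero out tokens outside the top-p nucleus and renormalize. probs: (B, V)."""
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    remove_sorted = cumulative > top_p
    # Shift right so the token that first crosses top_p stays inside the nucleus
    remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
    remove_sorted[..., 0] = False
    # Map the mask from sorted order back to vocabulary order
    remove = torch.zeros_like(remove_sorted).scatter(1, sorted_indices, remove_sorted)
    kept = probs.masked_fill(remove, 0.0)
    return kept / kept.sum(dim=-1, keepdim=True)

# Usage: for probs [[0.15, 0.5, 0.05, 0.3]] and top_p=0.7, only the two most likely
# tokens survive, renormalized to 0.625 and 0.375.
#   next_token = torch.multinomial(top_p_filter(probs, 0.7), num_samples=1)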

# ==========================================
# 9) Pre-training Implementation
# ==========================================
def pretrain(continue_training=True):
    """Pre-train the model on COT corpus with causal language modeling."""
    # Load COT data
    train_df, val_df, test_df = load_cot_logic_data()

    # Create models
    main_model = ImprovedByteTransformer(
        vocab_size=hyperparams['vocab_size'],
        embed_dim=hyperparams['embed_dim'],
        n_heads=hyperparams['n_heads'],
        n_layers=hyperparams['n_layers'],
        block_size=hyperparams['block_size']
    ).to(device)

    memory1 = MemoryModule(
        embed_dim=hyperparams['embed_dim'],
        n_layers=hyperparams['memory_n_layers'],
        expansion_factor=4
    ).to(device)

    memory2 = MemoryModule(
        embed_dim=hyperparams['embed_dim'],
        n_layers=hyperparams['memory_n_layers'],
        expansion_factor=4
    ).to(device)

    # Calculate model size
    num_params = sum(p.numel() for p in main_model.parameters() if p.requires_grad)
    num_params += sum(p.numel() for p in memory1.parameters() if p.requires_grad)
    num_params += sum(p.numel() for p in memory2.parameters() if p.requires_grad)
    print(f"Total trainable parameters: {num_params:,}")

    # Optimizer setup
    group1_params = list(main_model.parameters()) + list(memory1.parameters())
    group2_params = list(memory2.parameters())
    base_lr = 3e-4
    optimizer = torch.optim.AdamW([
        {'params': group1_params, 'lr': base_lr},
        {'params': group2_params, 'lr': base_lr}
    ], betas=(0.9, 0.95), weight_decay=0.1)

    start_epoch = 0
    best_val_loss = float('inf')

    # Load checkpoint if continuing training
    if continue_training and os.path.exists(hyperparams['checkpoint_path']):
        try:
            print(f"Loading checkpoint from {hyperparams['checkpoint_path']}...")
            checkpoint = torch.load(hyperparams['checkpoint_path'], map_location=device)

            try:
                # Try to load model states directly
                main_model.load_state_dict(checkpoint['main_model_state'], strict=False)
                memory1.load_state_dict(checkpoint['memory1_state'])
                if 'memory2_state' in checkpoint:
                    memory2.load_state_dict(checkpoint['memory2_state'])
                if 'optimizer_state' in checkpoint:
                    optimizer.load_state_dict(checkpoint['optimizer_state'])
                start_epoch = checkpoint.get('epoch', 0)
                best_val_loss = checkpoint.get('val_loss', float('inf'))
                print(f"Checkpoint loaded. Resuming from epoch {start_epoch}.")
            except Exception as e:
                print(f"Error loading checkpoint directly: {e}")
                print("Starting pre-training from scratch.")
        except Exception as e:
            print(f"Failed to load checkpoint: {e}")
            print("Starting pre-training from scratch.")
    else:
        print("Starting pre-training from scratch.")

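    # Training setup
    grad_clip = 1.0
    # The schedule is indexed by optimizer updates, which happen once every
    # `accumulation_steps` batches, so count updates rather than batches
    updates_per_epoch = hyperparams['steps_per_epoch'] // hyperparams['accumulation_steps']
    total_steps = hyperparams['num_epochs'] * updates_per_epoch
    current_step = start_epoch * updates_per_epoch

    # Learning rate scheduler
    def get_lr(step, warmup_steps=2000, base_lr=base_lr, min_lr=1e-5):
        # Learning rate schedule with linear warmup and cosine decay
        if step < warmup_steps:
            return base_lr * step / warmup_steps
        decay_steps = max(1, total_steps - warmup_steps)  # avoid division by zero on short runs
        step_ = min(step - warmup_steps, decay_steps)  # hold at min_lr once decay has finished
        cosine_decay = 0.5 * (1 + math.cos(math.pi * step_ / decay_steps))
        return min_lr + (base_lr - min_lr) * cosine_decay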

    print("Starting pre-training on COT Logic Reasoning corpus...")
    for epoch in range(start_epoch, hyperparams['num_epochs']):
        print(f"\n--- Epoch {epoch+1}/{hyperparams['num_epochs']} ---")

        for step in range(hyperparams['steps_per_epoch']):
            # Periodic evaluation
            if step % hyperparams['eval_interval'] == 0:
                losses = estimate_loss_pretrain(main_model, memory1, memory2, train_df, val_df)
                print(f"Step {step}, train loss: {losses['train']:.4f}, val loss: {losses['val']:.4f}")

                # Save best model
                if losses['val'] < best_val_loss:
                    best_val_loss = losses['val']
                    torch.save({
                        'main_model_state': main_model.state_dict(),
                        'memory1_state': memory1.state_dict(),
                        'memory2_state': memory2.state_dict(),
                        'optimizer_state': optimizer.state_dict(),
                        'epoch': epoch,
                        'val_loss': best_val_loss
                    }, hyperparams['checkpoint_path'].replace('.pt', '_best.pt'))
                    print(f"New best model saved! Val loss: {best_val_loss:.4f}")

            # Get batches for this step
            inputs, targets = prepare_pretraining_batches_from_cot(train_df, hyperparams['block_size'])

            # Zero gradients
            if step % hyperparams['accumulation_steps'] == 0:
                optimizer.zero_grad()

            # Forward pass
            B, T = inputs.shape
            token_emb = main_model.token_embedding(inputs)
            pos_emb = main_model.pos_embedding(torch.arange(T, device=device).unsqueeze(0))
            combined_emb = token_emb + pos_emb

            mem_out1 = memory1(combined_emb)
            logits = main_model.forward_with_two_memory(mem_out1, memory2)

            # Calculate loss
            B, T, C = logits.shape
            logits_flat = logits.view(B * T, C)
            targets_flat = targets.view(B * T)
            mask = (targets_flat != 0).float()
            loss = F.cross_entropy(logits_flat, targets_flat, reduction='none')
            masked_loss = (loss * mask).sum() / (mask.sum() + 1e-9)

            # Scale loss for gradient accumulation
            scaled_loss = masked_loss / hyperparams['accumulation_steps']
            scaled_loss.backward()

            # Check for NaN or Inf gradients in all three modules, not only main_model
            has_nan_inf = False
            for module in (main_model, memory1, memory2):
                for param in module.parameters():
                    if param.grad is not None and not torch.isfinite(param.grad).all():
                        has_nan_inf = True
                        param.grad = torch.zeros_like(param.grad)

            if has_nan_inf:
                print(f"NaN or Inf gradients detected and zeroed at step {step}")

            # Apply optimizer step
            if (step + 1) % hyperparams['accumulation_steps'] == 0:
                # Update learning rate
                lr = get_lr(current_step)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

                # Clip gradients
                torch.nn.utils.clip_grad_norm_(main_model.parameters(), grad_clip)
                torch.nn.utils.clip_grad_norm_(memory1.parameters(), grad_clip)
                torch.nn.utils.clip_grad_norm_(memory2.parameters(), grad_clip)

                optimizer.step()
                current_step += 1

        # Generate sample at end of epoch
        try:
            print("\nGenerating sample text...")
            sample_text = generate_from_prompt(
                main_model, memory1, memory2,
                prompt_text=hyperparams['start_prompt'],
                max_new_tokens=256
            )
            print(f"Sample (first 200 bytes): {sample_text[:200]}")
        except Exception as e:
            print(f"Error generating sample: {e}")

        # End of epoch checkpoint
        torch.save({
            'main_model_state': main_model.state_dict(),
            'memory1_state': memory1.state_dict(),
            'memory2_state': memory2.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'epoch': epoch + 1,
            'val_loss': best_val_loss
        }, hyperparams['checkpoint_path'])
        print(f"Checkpoint saved at epoch {epoch+1} to {hyperparams['checkpoint_path']}.")

    print("Pre-training complete!")
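
# Illustrative sketch (not part of the original pipeline): the warmup + cosine schedule
# used inside pretrain, written as a pure function so its endpoints can be checked
# without training. The name `warmup_cosine_lr` is hypothetical; the defaults mirror the
# values in pretrain (warmup 2000, base 3e-4, floor 1e-5). Clamping the decay progress
# to 1.0 is an added assumption so the rate stays at `min_lr` past `total_steps`.
import math

def warmup_cosine_lr(step, total_steps, warmup_steps=2000, base_lr=3e-4, min_lr=1e-5):
    """Linear warmup from 0 to base_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(step - warmup_steps, decay_steps) / decay_steps
    return min_lr + (base_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

# Usage: with total_steps=10000 the rate is 0 at step 0, reaches base_lr at step 2000,
# and decays to min_lr at step 10000.
#   schedule = [warmup_cosine_lr(s, 10000) for s in range(0, 10001, 1000)]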

# ==========================================
# 10) Script Main Entry Point
# ==========================================
if __name__ == "__main__":
    print("Starting pre-training on COT corpus...")
    pretrain(continue_training=hyperparams['continue_training'])