<a href="https://colab.research.google.com/github/srimallya/T-JEPA/blob/main/t_jepa.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **T-decoder-JEPA: Integrating Temporal Joint Embedding Prediction into Decoder-Only Language Models**

Abstract

Decoder-only transformer models, pre-trained with causal language modeling (LM) objectives, have demonstrated remarkable capabilities. However, their reliance solely on predicting the immediate next token might limit the depth of semantic and structural understanding required for complex reasoning tasks. To address this, we propose T-decoder-JEPA, a novel architecture that integrates principles from the Joint Embedding Predictive Architecture (JEPA) into a standard decoder-only framework. T-decoder-JEPA augments the causal LM objective with a JEPA-inspired self-supervised task: predicting the representations of masked future segments (targets) using representations from an unmasked causal past (context). Crucially, the target representations are generated by a non-causal, exponential moving average (EMA) copy of the decoder backbone, providing a rich supervisory signal representing the model's accumulated "experience" applied to the full context. The prediction occurs in embedding space via a dedicated predictor module that leverages both self-attention and cross-attention to the causal backbone's outputs. This multi-task learning setup encourages the backbone decoder to learn richer, more predictive representations that capture longer-range dependencies and abstract structural information, potentially enhancing performance on downstream tasks requiring deep reasoning, such as mathematical problem-solving.

1. Introduction

Large Language Models (LLMs) based on the decoder-only transformer architecture (e.g., GPT series [Radford et al., 2018, 2019; Brown et al., 2020], LLaMA [Touvron et al., 2023]) have achieved state-of-the-art performance across a wide range of natural language processing tasks. Their core training objective, causal language modeling, involves predicting the next token in a sequence given the preceding tokens. While effective, this objective primarily focuses on local dependencies and statistical co-occurrence, which may not be sufficient to foster the deep understanding of structure, causality, and long-range relationships needed for complex reasoning tasks [Mallen et al., 2023].

Self-supervised learning (SSL) beyond simple next-token prediction offers a promising avenue to improve the representational quality of language models. Masked language modeling (MLM) [Devlin et al., 2019] and other reconstruction-based objectives have proven successful, particularly for encoder-based models. Recently, the Joint Embedding Predictive Architecture (JEPA) [LeCun, 2022; Assran et al., 2023] has emerged as a powerful SSL paradigm, particularly in computer vision (I-JEPA [Assran et al., 2023]). JEPA aims to learn abstract representations by predicting the representations of masked portions of the input (targets) from unmasked portions (context), operating entirely within the embedding space. This avoids the computational burden and potential semantic limitations of predicting raw pixels or tokens.

Inspired by JEPA's success, we propose T-decoder-JEPA (Temporal Decoder Joint Embedding Predictive Architecture), a novel approach to integrate JEPA principles directly into a decoder-only transformer. Our key contributions are:

1.  Novel Architecture: We present a decoder-only transformer enhanced with a JEPA-based self-supervised objective, operating alongside the standard causal LM task.
2.  Causal/Non-Causal Mechanism: We introduce a specific mechanism where target representations for the JEPA task are derived from a non-causal pass through an EMA target encoder, while the prediction is made using information from the standard causal backbone pass.
3.  Predictor Design: We detail a predictor module employing causal self-attention and cross-attention to the backbone decoder's states, enabling effective integration of context information for target prediction.
4.  Multi-Task Formulation: We combine the JEPA prediction loss (MSE in embedding space) with the causal LM loss (Cross-Entropy), allowing the model to benefit from both objectives simultaneously.

We hypothesize that T-decoder-JEPA encourages the learning of more robust and predictive representations, better suited for tasks requiring multi-step reasoning, as exemplified by mathematical datasets like GSM8K [Cobbe et al., 2021].

2. Related Work

*   Decoder-Only Language Models: Our work builds upon the standard decoder-only architecture [Vaswani et al., 2017] popularized by the GPT models [Radford et al., 2018, 2019; Brown et al., 2020]. These models are typically trained autoregressively using the causal LM objective. We retain this objective but augment it with our JEPA task.
*   Self-Supervised Learning in NLP: SSL has been pivotal in NLP. BERT [Devlin et al., 2019] introduced Masked Language Modeling (MLM) for bidirectional encoders. Denoising autoencoders [Vincent et al., 2008] and contrastive methods [Logeswaran & Lee, 2018; Gao et al., 2021] are other prominent approaches. T-decoder-JEPA differs by predicting representations in embedding space rather than reconstructing tokens (like MLM) or using contrastive losses.
*   Joint Embedding Predictive Architectures (JEPA): JEPA [LeCun, 2022] proposes learning predictive world models. I-JEPA [Assran et al., 2023] successfully applied this to vision, demonstrating strong performance by predicting representations of masked image patches from context patches using an EMA target encoder. Our work adapts this core idea to the sequential, temporal nature of language within a decoder framework.
*   Multi-Task Learning (MTL) in NLP: Combining different objectives is common in NLP [Caruana, 1997; Raffel et al., 2020]. T-decoder-JEPA employs MTL by combining the JEPA loss and the LM loss, aiming for synergistic benefits where the JEPA task regularizes and enriches the representations learned primarily for the LM task.
*   Reasoning in Language Models: Improving the reasoning capabilities of LLMs is an active research area [Wei et al., 2022; Nye et al., 2021]. While some approaches focus on fine-tuning or prompting techniques (e.g., Chain-of-Thought [Wei et al., 2022]), T-decoder-JEPA aims to enhance the foundational representational capacity of the pre-trained model itself through its novel training objectives.

3. Methodology: T-decoder-JEPA

The T-decoder-JEPA architecture integrates a JEPA prediction task into a standard decoder-only transformer backbone trained with a causal LM objective. The key components are the Backbone Decoder, the Target Encoder, the Span Selection Strategy, the Predictor, and the combined Loss Function.

3.1. Backbone Decoder (Causal)

The core of the model is a standard multi-layer Transformer decoder, denoted `f_θ`. It processes an input sequence `X = (x_1, ..., x_T)` autoregressively.
*   Input: Token sequence `X`.
*   Processing: Standard decoder blocks with causal self-attention (e.g., using RoPE [Su et al., 2024] for positional information) and feed-forward layers. Layer Normalization (Pre-LN [Ba et al., 2016; Xiong et al., 2020]) is used.
*   Output: A sequence of hidden states `H_c = (h_{c,1}, ..., h_{c,T}) = f_θ(X)`, where each `h_{c,t}` depends only on `x_1, ..., x_t`.
*   Role:
    1.  Generates representations for the standard causal LM loss.
    2.  Provides context representations (via cross-attention keys/values) to the Predictor.
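
The causal property stated above, that `h_{c,t}` depends only on `x_1, ..., x_t`, can be checked numerically. The following is a minimal sketch rather than the notebook's backbone: `self_attention` is a hypothetical single-head attention with identity projections, and the tensor sizes are arbitrary. It perturbs the last position and checks whether earlier outputs change, with and without the causal mask.

```python
import torch
import torch.nn.functional as F

def self_attention(x: torch.Tensor, causal: bool) -> torch.Tensor:
    """Single-head self-attention with identity projections (illustration only)."""
    T = x.size(1)
    scores = x @ x.transpose(-2, -1) / x.size(-1) ** 0.5      # (B, T, T)
    if causal:
        # Position t may attend only to positions <= t.
        future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(future, float("-inf"))
    return F.softmax(scores, dim=-1) @ x

torch.manual_seed(0)
x = torch.randn(1, 6, 8)
x_perturbed = x.clone()
x_perturbed[0, -1] += 1.0   # change only the last position

# Compare the outputs at all positions except the perturbed one.
causal_same = torch.allclose(
    self_attention(x, True)[0, :-1], self_attention(x_perturbed, True)[0, :-1]
)
noncausal_same = torch.allclose(
    self_attention(x, False)[0, :-1], self_attention(x_perturbed, False)[0, :-1]
)
print(causal_same, noncausal_same)
```

With the mask, earlier positions are unaffected by the change; without it, every position sees the perturbed token.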

3.2. Target Encoder (Non-Causal EMA)

The Target Encoder, denoted `f_θ'`, is structurally identical to the Backbone Decoder but its parameters `θ'` are an Exponential Moving Average (EMA) of the backbone parameters `θ`: `θ' ← α θ' + (1 - α) θ`. Crucially, it processes the input sequence non-causally.
*   Input: Token sequence `X`.
*   Processing: Identical transformer blocks as the backbone, but the self-attention mechanism is configured to be non-causal (bi-directional), allowing each position to attend to all other positions in the sequence (respecting padding).
*   Output: A sequence of hidden states `H_{nc} = (h_{nc,1}, ..., h_{nc,T}) = f_θ'(X)`, where each `h_{nc,t}` depends on the entire sequence `x_1, ..., x_T`.
*   Role: Generates the target representations for the JEPA prediction task. Its parameters are not updated via backpropagation; only through EMA updates from `f_θ`. These target representations reflect the model's accumulated "memory" or "experience" (`θ'`) applied to understand the full context of the current sequence `X`.
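
The EMA rule `θ' ← α θ' + (1 - α) θ` can be sketched in a few lines of PyTorch. The helper name `ema_update` and the `nn.Linear` stand-ins are assumptions made for illustration; the only thing taken from the text is the update rule itself and the fact that `θ'` receives no gradients.

```python
import copy
import torch
import torch.nn as nn

@torch.no_grad()
def ema_update(target: nn.Module, online: nn.Module, alpha: float) -> None:
    """In-place update: theta' <- alpha * theta' + (1 - alpha) * theta."""
    for p_target, p_online in zip(target.parameters(), online.parameters()):
        p_target.mul_(alpha).add_(p_online, alpha=1.0 - alpha)

online = nn.Linear(4, 4)                 # stand-in for the backbone f_theta
target = copy.deepcopy(online)           # stand-in for the target encoder f_theta'
for p in target.parameters():
    p.requires_grad_(False)              # the target is never trained by backprop

with torch.no_grad():
    online.weight.add_(1.0)              # stand-in for an optimizer step on theta

before = target.weight.clone()
ema_update(target, online, alpha=0.999)
```

With `α` close to 1 the target moves only a small fraction of the way toward the backbone at each step, which is what makes it a slowly varying source of targets.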

Analogy: The Detective and the Full Reader: To build intuition for the causal/non-causal dynamic, consider the causal Backbone Decoder (`f_θ`) as a detective reading a novel one page at a time. When encountering missing pages (target spans), the detective can only guess their content based on what has been read so far (the causal context `H_c`). The non-causal Target Encoder (`f_θ'`) is like someone who has already read the entire book and therefore knows how those missing pages fit into the overall narrative. Its output `H_{nc}` is not a perfect ground truth, since `f_θ'` is only a slowly updated copy of the backbone, but it is informed by the full sequence and serves as the regression target for what the detective should infer.

3.3. Span Selection and Masking

For the JEPA task, we sample multiple target spans within each sequence `X`.
*   A proportion of the sequence is designated as the JEPA context.
*   Several non-overlapping target spans `s_i = (start_i, end_i)` are randomly selected from the remaining portion.
*   Parameters control the number of target spans, minimum/maximum span length, and context/target ratios.
*   A `context_mask` indicates which positions belong to the context (1) and which belong to targets or padding (0).
*   An `attention_mask` indicates padding tokens (0) vs. real tokens (1).
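
One way to realize the span selection described above is rejection sampling over start positions. This is a sketch under assumptions: the function name `sample_target_spans`, its parameters, and the retry limit are hypothetical, spans are end-exclusive, and sequences are assumed to be right-padded to `block_size`. It is not necessarily the sampling procedure used in the notebook.

```python
import random

def sample_target_spans(seq_len, block_size, num_spans, min_len, max_len,
                        seed=None, max_tries=100):
    """Sample up to `num_spans` non-overlapping (start, end) spans among real tokens."""
    rng = random.Random(seed)
    # 1 = real token, 0 = padding (right-padded sequence assumed).
    attention_mask = [1] * seq_len + [0] * (block_size - seq_len)
    # 1 = context, 0 = target or padding; starts as a copy of the padding mask.
    context_mask = list(attention_mask)
    spans = []
    for _ in range(max_tries):
        if len(spans) == num_spans:
            break
        length = rng.randint(min_len, max_len)      # inclusive on both ends
        if length > seq_len:
            continue
        start = rng.randint(0, seq_len - length)
        end = start + length                        # exclusive end
        if any(context_mask[i] == 0 for i in range(start, end)):
            continue                                # overlaps an earlier span: retry
        for i in range(start, end):
            context_mask[i] = 0                     # mark as target
        spans.append((start, end))
    return sorted(spans), context_mask, attention_mask

spans, context_mask, attention_mask = sample_target_spans(
    seq_len=200, block_size=256, num_spans=3, min_len=16, max_len=32, seed=0
)
```

Rejection sampling can return fewer spans than requested when the sequence is short, so downstream code should not assume a fixed span count.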

3.4. Predictor

The Predictor, `g_φ`, is another multi-layer Transformer-based module responsible for predicting the target representations.
*   Input: A modified sequence where context positions retain their original embeddings (or potentially representations from `H_c`), while target positions are replaced with learnable `[MASK]` tokens/embeddings. Let this input be `X_masked`.
*   Processing: The predictor consists of blocks performing:
    1.  Causal Self-Attention: Operates on the `X_masked` sequence representation, using the `attention_mask` for padding but maintaining causality.
    2.  Cross-Attention: Attends to the output states `H_c` from the causal Backbone Decoder. Queries come from the predictor's self-attention output. Keys and Values come from `H_c`. The `context_mask` is used here to ensure the predictor only attends to keys/values corresponding to the context positions in the backbone output.
    3.  Feed-Forward Layers.
*   Output: Predicted representations `P = (p_1, ..., p_T) = g_φ(X_masked, H_c, context_mask)`. We are interested in the outputs corresponding to the target spans: `P[s_i] = (p_{start_i}, ..., p_{end_i})`.
*   Role: Learns to predict the non-causal target representations `H_{nc}[s_i]` using only the masked input and the causal context information from the backbone `H_c`. It acts like the detective's assistant, trying to reconstruct the meaning of the missing pages using only the clues gathered so far.
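
A single predictor block along these lines might look as follows. This is a hedged sketch built on `torch.nn.MultiheadAttention`; the class name `PredictorBlock`, the Pre-LN layout, and the 4x feed-forward width are assumptions, and the cross-attention is left unrestricted over context positions because the text does not say whether it is also causal.

```python
import torch
import torch.nn as nn

class PredictorBlock(nn.Module):
    """Causal self-attention, cross-attention to H_c, then a feed-forward layer."""
    def __init__(self, embed_dim: int, n_heads: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.self_attn = nn.MultiheadAttention(embed_dim, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.cross_attn = nn.MultiheadAttention(embed_dim, n_heads, batch_first=True)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim), nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim),
        )

    def forward(self, x, h_c, context_mask):
        T = x.size(1)
        # True = blocked. With right-padded inputs, the causal mask also keeps
        # real positions from attending to padding, which lies in their future.
        causal = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1
        )
        q = self.norm1(x)
        x = x + self.self_attn(q, q, q, attn_mask=causal, need_weights=False)[0]
        q = self.norm2(x)
        # Keys/values come from the backbone; non-context positions are ignored.
        blocked_keys = ~context_mask.bool()
        x = x + self.cross_attn(q, h_c, h_c, key_padding_mask=blocked_keys,
                                need_weights=False)[0]
        return x + self.ff(self.norm3(x))

torch.manual_seed(0)
block = PredictorBlock(embed_dim=16, n_heads=4).eval()
x_masked = torch.randn(2, 10, 16)          # stand-in for the embedded X_masked
h_c = torch.randn(2, 10, 16)               # stand-in for backbone states H_c
context_mask = torch.ones(2, 10)
context_mask[:, 4:7] = 0                   # positions 4..6 form a target span
out = block(x_masked, h_c, context_mask)   # shape (2, 10, 16)
```

Because the target positions are masked out as keys, changing `h_c` at those positions leaves the block's output unchanged, which is the behavior the `context_mask` is meant to enforce.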

3.5. Loss Functions

The model is trained with a combined loss:
`L_total = L_JEPA + λ * L_LM`

*   JEPA Loss (`L_JEPA`): The Mean Squared Error (MSE) between the predictor's output embeddings and the target encoder's output embeddings for all target spans `s_i`:
    `L_JEPA = (1 / N_spans) * Σ_{i} || P[s_i] - H_{nc}[s_i] ||^2`
    (Normalization by span length or averaging per-span loss might also be considered). The target `H_{nc}[s_i]` is detached from the computation graph. This loss measures the difference between the assistant's guess and the full reader's knowledge.
*   LM Loss (`L_LM`): The standard causal language modeling loss (Cross-Entropy) calculated using the output of the causal Backbone Decoder `H_c` and its associated LM head (which typically shares weights with the input embedding layer):
    `L_LM = CrossEntropy(LM_Head(H_c), X_shifted)`
*   `λ` is a hyperparameter balancing the two loss terms.
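
The combined objective can be sketched as a single function. The name `t_jepa_loss` and the `(batch_index, start, end)` span format are assumptions for illustration. The JEPA term here averages a per-span MSE, which is one of the normalization choices mentioned above rather than the raw squared norm.

```python
import torch
import torch.nn.functional as F

def t_jepa_loss(pred, target, logits, tokens, spans, lm_weight, pad_token=0):
    """Return (L_total, L_JEPA, L_LM) with L_total = L_JEPA + lm_weight * L_LM.

    pred, target: (B, T, D) predictor outputs and target-encoder states.
    logits:       (B, T, V) LM-head outputs from the causal backbone.
    tokens:       (B, T) input token ids.
    spans:        list of (batch_index, start, end) with exclusive end (assumed format).
    """
    # JEPA term: MSE per target span, averaged over spans; no gradient to targets.
    span_losses = [
        F.mse_loss(pred[b, s:e], target[b, s:e].detach()) for (b, s, e) in spans
    ]
    l_jepa = torch.stack(span_losses).mean()
    # LM term: predict token t+1 from position t, skipping padded targets.
    l_lm = F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        tokens[:, 1:].reshape(-1),
        ignore_index=pad_token,
    )
    return l_jepa + lm_weight * l_lm, l_jepa, l_lm
```

Detaching the target inside the loss is redundant when the target encoder is already run without gradients, but it makes the stop-gradient explicit.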

3.6. Training

Training proceeds by minimizing `L_total`. Parameters `θ` (backbone) and `φ` (predictor) are updated via gradient descent (e.g., AdamW [Loshchilov & Hutter, 2019]). Target encoder parameters `θ'` are updated using EMA after each optimizer step on `θ`.

3.7. Mathematical Formulation

To formalize the process during training and inference:

Notation:
*   `X = (x_1, ..., x_T)`: Input token sequence.
*   `θ, θ', φ`: Parameters for backbone, target encoder (EMA of θ), and predictor.
*   `f_θ`: Causal backbone decoder function.
*   `f_θ'`: Non-causal target encoder function.
*   `g_φ`: Predictor function.
*   `LM_Head`: Language modeling head (linear layer).
*   `H_c`: Hidden states from causal backbone. `H_c[t]` depends on `X_{1:t}`.
*   `H_{nc}`: Hidden states from non-causal target encoder. `H_{nc}[t]` depends on `X_{1:T}`.
*   `s = (start, end)`: Indices defining a target span.
*   `X_masked`: Input to predictor with target spans masked.
*   `Context(s)`: Indices of context tokens relevant for predicting span `s`.
*   `P[s]`: Predicted representation for span `s`.
*   `CE`: Cross-Entropy loss. `MSE`: Mean Squared Error loss. `detach`: Stop gradient.
*   `α`: EMA decay rate. `λ`: LM loss weight. `∇`: Gradient. `LR`: Learning rate.
*   `temp`: Temperature for sampling. `TopP`: Top-p sampling function.

Training Step:
1.  Causal Pass (Backbone): `H_c = f_θ(X)`
2.  Non-Causal Pass (Target Encoder): `H_{nc} = f_θ'(X)` (with `torch.no_grad()`)
3.  Prediction (Predictor): `P = g_φ(X_masked, H_c[Context(s)])` for all target spans `s`.
4.  LM Loss: `L_LM = CE(LM_Head(H_c[:, :-1]), X[:, 1:])` (ignoring padding)
5.  JEPA Loss: `L_JEPA = mean_s( MSE(P[s], detach(H_{nc}[s])) )`
6.  Total Loss: `L_total = L_JEPA + λ * L_LM`
7.  Parameter Update:
    *   `θ ← θ - LR * ∇_θ L_total`
    *   `φ ← φ - LR * ∇_φ L_total`
    *   `θ' ← α * θ' + (1 - α) * θ`
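
The ordering of the seven steps can be illustrated end to end with toy stand-ins. Everything below is an assumption made for brevity: `ToyEncoder` has no attention or masking, the predictor is a single linear layer, and the sizes and learning rate are arbitrary. Only the sequence of operations (causal pass, gradient-free target pass, prediction, two losses, optimizer step, EMA) follows the text.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
V, D, T = 32, 16, 12                       # toy vocab size, width, sequence length

class ToyEncoder(nn.Module):
    """Stand-in for f_theta / f_theta': embedding + linear layer, no attention."""
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(V, D)
        self.proj = nn.Linear(D, D)
    def forward(self, x):
        return self.proj(self.embed(x))

backbone = ToyEncoder()
target_encoder = copy.deepcopy(backbone).requires_grad_(False)
predictor = nn.Linear(D, D)                # stand-in for g_phi
lm_head = nn.Linear(D, V, bias=False)
trainable = (list(backbone.parameters()) + list(predictor.parameters())
             + list(lm_head.parameters()))
optimizer = torch.optim.AdamW(trainable, lr=1e-3)
alpha, lam = 0.999, 0.92                   # EMA decay and LM loss weight

def train_step(x, span):
    start, end = span
    h_c = backbone(x)                                       # 1. causal pass
    with torch.no_grad():
        h_nc = target_encoder(x)                            # 2. target pass
    p = predictor(h_c)                                      # 3. prediction (toy)
    l_lm = F.cross_entropy(lm_head(h_c[:, :-1]).reshape(-1, V),
                           x[:, 1:].reshape(-1))            # 4. LM loss
    l_jepa = F.mse_loss(p[:, start:end], h_nc[:, start:end])  # 5. JEPA loss
    loss = l_jepa + lam * l_lm                              # 6. total loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()                                        # 7a. update theta, phi
    with torch.no_grad():                                   # 7b. EMA update of theta'
        for p_t, p_o in zip(target_encoder.parameters(), backbone.parameters()):
            p_t.mul_(alpha).add_(p_o, alpha=1 - alpha)
    return loss.item()

x = torch.randint(0, V, (2, T))
w_target_before = target_encoder.proj.weight.clone()
w_backbone_before = backbone.proj.weight.clone()
loss_value = train_step(x, (4, 8))
```

The EMA update runs after the optimizer step so that `θ'` tracks the freshly updated `θ`.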

Inference (Autoregressive Generation):
Given context `X_context = (x_1, ..., x_k)`:
For `t` from `k` to `max_length - 1`:
1.  Get Last Causal State: `h_{c,t} = f_θ(X_context)_{last_token_state}`
2.  Get Logits: `logits_t = LM_Head(h_{c,t})`
3.  Sample Next Token:
    *   `probs_t = Softmax(TopP(logits_t / temp))`
    *   `x_{t+1} ~ Multinomial(probs_t)`
4.  Append: `X_context ← Append(X_context, x_{t+1})`
Return `X_context`.
*(Note: Only the causal backbone `f_θ` and `LM_Head` are used during inference)*.
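
The sampling step can be sketched as follows for a single 1-D logits vector. The function name `sample_top_p` is hypothetical. The sketch applies the softmax first and then filters and renormalizes the probabilities, which is equivalent to masking the discarded logits before the softmax.

```python
import torch

def sample_top_p(logits: torch.Tensor, temperature: float = 1.0,
                 top_p: float = 0.8) -> int:
    """Sample one token id from a 1-D logits vector using nucleus (top-p) sampling."""
    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Keep a token if the mass of all higher-ranked tokens is still below top_p;
    # the most likely token is therefore always kept.
    keep = (cumulative - sorted_probs) < top_p
    filtered = sorted_probs * keep
    filtered = filtered / filtered.sum()
    choice = torch.multinomial(filtered, num_samples=1)
    return sorted_idx[choice].item()

# One token holds almost all of the mass, so it is the only one in the nucleus.
peaked_logits = torch.tensor([10.0, 0.0, 0.0, 0.0])
samples = [sample_top_p(peaked_logits, temperature=1.0, top_p=0.8) for _ in range(20)]
```

Lower `top_p` values shrink the nucleus and make generation more deterministic; `temperature` rescales the logits before the nucleus is computed.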

4. Experiments (Planned)

*   Dataset: We plan to evaluate T-decoder-JEPA primarily on the GSM8K dataset [Cobbe et al., 2021], a collection of grade-school math word problems requiring multi-step reasoning.
*   Baselines:
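    1.  A standard decoder-only model of identical size and configuration, trained solely with the causal LM objective (i.e., with the `L_JEPA` term removed from `L_total`; note that setting `λ=0` would instead remove the LM term).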
    2.  (Optional) Other relevant SSL methods adapted to the decoder framework, if feasible.
*   Evaluation Metrics:
    1.  Accuracy on the GSM8K test set (final answer matching).
    2.  Perplexity on a standard language modeling benchmark (e.g., WikiText) to assess the impact on general language modeling capabilities.
    3.  JEPA prediction loss during training as a diagnostic metric.
*   Ablation Studies: We will investigate the contribution of key components:
    1.  Effect of the relative weighting between the JEPA and LM losses (varying `λ`).
    2.  Importance of the non-causal target encoder (vs. using causal targets).
    3.  Impact of the EMA update.
    4.  Sensitivity to predictor depth and architecture.
    5.  Influence of span selection hyperparameters.
*   Hypothesized Results: We expect T-decoder-JEPA to achieve higher accuracy on GSM8K compared to the LM-only baseline, demonstrating improved reasoning capabilities stemming from the enriched representations learned via the JEPA objective. We anticipate that general LM performance (perplexity) will remain competitive or potentially improve slightly due to the regularizing effect of the JEPA task.
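
The final-answer matching metric listed above could be computed along these lines. This is a sketch under assumptions: it presumes generations wrap the final answer in `<answer> </answer>` tags and that references follow the GSM8K convention of ending with `#### <number>`. The helper names are hypothetical.

```python
import re
from typing import List, Optional

def extract_tagged_answer(text: str) -> Optional[float]:
    """Return the first number inside <answer>...</answer>, or None if absent."""
    match = re.search(r"<answer>(.*?)</answer>", text, flags=re.DOTALL)
    if match is None:
        return None
    number = re.search(r"-?\d[\d,]*(?:\.\d+)?", match.group(1))
    return float(number.group(0).replace(",", "")) if number else None

def extract_reference_answer(answer_field: str) -> Optional[float]:
    """Return the number after the last '####' marker, or None if unparsable."""
    raw = answer_field.split("####")[-1].strip().replace(",", "")
    try:
        return float(raw)
    except ValueError:
        return None

def exact_match_accuracy(generations: List[str], references: List[str]) -> float:
    """Fraction of examples whose extracted answer equals the reference number."""
    correct = 0
    for generated, reference in zip(generations, references):
        predicted = extract_tagged_answer(generated)
        gold = extract_reference_answer(reference)
        if predicted is not None and gold is not None and predicted == gold:
            correct += 1
    return correct / max(len(references), 1)
```

Comparing parsed numbers rather than raw strings avoids penalizing formatting differences such as `1,200` versus `1200.0`.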

(Table 1: Placeholder for GSM8K accuracy results for T-decoder-JEPA vs. baselines.)
(Table 2: Placeholder for Perplexity results.)

5. Discussion

The core hypothesis behind T-decoder-JEPA is that forcing the model to predict future representations (derived non-causally) from a causal context compels it to learn more abstract and robust features. The non-causal target `H_{nc}` provides a rich signal about the complete sequence structure, which the predictor `g_φ` must learn to anticipate based only on the causally-generated context `H_c`. This implicit "lookahead" in the embedding space may encourage the backbone `f_θ` to encode information about long-range dependencies and underlying semantic/logical structure more effectively than relying solely on next-token prediction.

Analogy: The Student and Teacher: The learning dynamic can be likened to training a student (the predictor `g_φ`) to predict the ending (target representation `H_{nc}[s]`) of a story by only reading the beginning (causal context representations `H_c`). The teacher (the non-causal target encoder `f_θ'`) knows the full story and provides the perfect target representation. The student makes a guess (`P[s]`), and the difference (JEPA loss) signals the error. This error signal not only improves the student (`g_φ`) but crucially puts pressure back on the causal decoder (`f_θ`)—the provider of the initial information—to generate richer, more predictive "hints" (the hidden states `H_c`) in the first place. Over time, the causal decoder learns to encode predictive signals about the future within its current state, even though it cannot see the future directly.

Implications:
*   Enhanced Representations: The JEPA task acts as a regularizer, potentially leading to representations less sensitive to superficial correlations and more attuned to deeper structure. The causal hidden states `H_c` become implicitly conditioned to support future prediction.
*   Improved Reasoning & Downstream Tasks: By learning to anticipate representations related to future steps or outcomes, the model might develop capabilities more aligned with planning and multi-step reasoning. If so, this could improve performance on downstream tasks like generation or classification that rely on these enriched causal states; the planned experiments are intended to test this.

Limitations:
*   Computational Cost: The architecture involves forward passes through the backbone, target encoder, and predictor, increasing computational requirements compared to standard LM training.
*   Hyperparameter Sensitivity: Performance might be sensitive to the choice of span selection strategy, the weighting factor `λ`, EMA decay `α`, and predictor architecture.
*   Complexity: The interplay between the causal backbone, non-causal target, and predictor introduces complexity in analysis and debugging.

Future Work:
*   Scaling T-decoder-JEPA to larger models and datasets.
*   Applying the architecture to other domains requiring long-range understanding or planning.
*   In-depth analysis of the learned representations to understand the effects of the JEPA objective.
*   Exploring alternative predictor designs or target generation strategies.

6. Conclusion

We introduced T-decoder-JEPA, a novel architecture that integrates the principles of Joint Embedding Predictive Architecture into a decoder-only transformer. By augmenting the standard causal language modeling objective with a task that predicts non-causal future representations from a causal context—using an EMA target encoder for ground truth and a dedicated predictor—T-decoder-JEPA aims to learn richer, more predictive internal representations. The training dynamic encourages the causal backbone to generate hidden states that are implicitly predictive of future sequence structure. We hypothesize this approach will enhance the capabilities of decoder-only models on complex reasoning tasks. Planned experiments on datasets like GSM8K will evaluate the effectiveness of this architecture compared to standard LM baselines. T-decoder-JEPA offers a promising direction for improving the representational power and reasoning abilities of large language models through principled self-supervised learning.

7. References

[Assran et al., 2023] Assran, M., et al. (2023). Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture. arXiv preprint arXiv:2301.08243.
[Ba et al., 2016] Ba, J. L., Kiros, J. R., & Hinton, G. E. (2016). Layer normalization. arXiv preprint arXiv:1607.06450.
[Brown et al., 2020] Brown, T. B., et al. (2020). Language models are few-shot learners. Advances in neural information processing systems, 33, 1877-1901.
[Caruana, 1997] Caruana, R. (1997). Multitask learning. Machine learning, 28(1), 41-75.
[Cobbe et al., 2021] Cobbe, K., et al. (2021). Training Verifiers to Solve Math Word Problems. arXiv preprint arXiv:2110.14168.
[Devlin et al., 2019] Devlin, J., Chang, M. W., Lee, K., & Toutanova, K. (2019). BERT: Pre-training of deep bidirectional transformers for language understanding. Proceedings of NAACL-HLT 2019.
[Gao et al., 2021] Gao, T., Yao, X., & Chen, D. (2021). SimCSE: Simple Contrastive Learning of Sentence Embeddings. arXiv preprint arXiv:2104.08821.
[LeCun, 2022] LeCun, Y. (2022). A Path Towards Autonomous Machine Intelligence. OpenReview.
[Logeswaran & Lee, 2018] Logeswaran, L., & Lee, H. (2018). An efficient framework for learning sentence representations. arXiv preprint arXiv:1803.02893.
[Loshchilov & Hutter, 2019] Loshchilov, I., & Hutter, F. (2019). Decoupled weight decay regularization. ICLR 2019.
[Mallen et al., 2023] Mallen, A., et al. (2023). When Do Pre-Training Objectives Help? An Empirical Study of Multi-Task Transfer Learning. arXiv preprint arXiv:2304.14748.
[Nye et al., 2021] Nye, M., et al. (2021). Show Your Work: Scratchpads for Intermediate Computation with Language Models. arXiv preprint arXiv:2112.00114.
[Radford et al., 2018] Radford, A., Narasimhan, K., Salimans, T., & Sutskever, I. (2018). Improving language understanding by generative pre-training. OpenAI Blog.
[Radford et al., 2019] Radford, A., et al. (2019). Language models are unsupervised multitask learners. OpenAI Blog, 1(8), 9.
[Raffel et al., 2020] Raffel, C., et al. (2020). Exploring the limits of transfer learning with a unified text-to-text transformer. Journal of Machine Learning Research, 21(140), 1-67.
[Su et al., 2024] Su, J., et al. (2024). RoFormer: Enhanced transformer with rotary position embedding. Neurocomputing, 568, 127063.
[Touvron et al., 2023] Touvron, H., et al. (2023). LLaMA: Open and efficient foundation language models. arXiv preprint arXiv:2302.13971.
[Vaswani et al., 2017] Vaswani, A., et al. (2017). Attention is all you need. Advances in neural information processing systems, 30.
[Vincent et al., 2008] Vincent, P., Larochelle, H., Bengio, Y., & Manzagol, P. A. (2008). Extracting and composing robust features with denoising autoencoders. ICML 2008.
[Wei et al., 2022] Wei, J., et al. (2022). Chain-of-thought prompting elicits reasoning in large language models. Advances in Neural Information Processing Systems, 35, 24824-24837.
[Xiong et al., 2020] Xiong, R., et al. (2020). On layer normalization in the transformer architecture. ICML 2020.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import time
from typing import Optional, Tuple

# ==========================================
# 1) Hyperparameters
# ==========================================
def get_hyperparams():
    return {
        # Model Parameters
        'batch_size': 2,
        'block_size': 1024,               # Maximum sequence length in byte tokens; leaves room for several target spans
        'vocab_size': 256,
        'embed_dim': 512,
        'n_heads': 8,
        'n_layers': 12,                    # Number of Decoder Blocks

        # JEPA Parameters
        'context_span_ratio': 0.6,        # Ratio calculation might need tuning with larger block size
        'target_span_ratio': 0.2,         # Ratio calculation might need tuning with larger block size
        'num_target_spans': 8,            # Upper bound on target spans sampled per sequence
        'min_span_length': 32,            # Minimum target span length, in tokens

        # Training Parameters
        'num_epochs': 50,
        'steps_per_epoch': 1000,
        'eval_interval': 200,
        'eval_iters': 100,
        'ema_decay': 0.999,
        'accumulation_steps': 8,
        'lm_loss_weight': 0.92,

        # Special Tokens
        'bos_token': 254,
        'eos_token': 255,
        'pad_token': 0,

        # Generation Parameters
        'generate_num_tokens': 1024,      # Can match block_size or be different
        'top_p': 0.8,
        'start_prompt': "Problem: A bakery produces cakes for $10 each. It costs them $5 in ingredients per cake, and they have a fixed overhead of $200 per day. How many cakes do they need to sell each day to make a daily profit of $100?",

        # Special Tags
        'thinking_tag': "<think>",
        'thinking_end_tag': "</think>",
        'answer_tag': "<answer>",
        'answer_end_tag': "</answer>",

        # Paths & Modes
        'checkpoint_path': "t_jepa_mtl_decoder_rope_bs1024_checkpoint.pt", # Updated name for new block size
        'continue_training': True,
        'system_prompt': """Consider this math problem. Think step by step and provide your reasoning between <think> </think> tags, then give your final answer between <answer> </answer> tags."""
    }

# ==========================================
# 1.1) Select device
# ==========================================
def get_device():
    device = "mps" if torch.backends.mps.is_available() else \
             ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    return device

# ==========================================
# 1.2) Data Loading and Preprocessing for GSM8K
# ==========================================
def load_gsm8k_data():
    print("Loading GSM8K dataset...")

    try:
        # Try using the 'datasets' library first
        from datasets import load_dataset
        # Specify cache_dir to avoid potential permission issues in default locations
        cache_dir = os.path.expanduser("~/.cache/huggingface/datasets")
        os.makedirs(cache_dir, exist_ok=True)
        dataset = load_dataset("openai/gsm8k", "main", cache_dir=cache_dir)
        train_df = dataset["train"].to_pandas()
        test_df = dataset["test"].to_pandas()
        print("Dataset loaded using datasets library")
    except Exception as e:
        print(f"Error loading dataset with datasets library: {e}")
        print("Attempting alternative loading methods...")
        try:
            # Alternative: Load directly from Hugging Face Hub parquet files
            print("Attempting to load from Hugging Face Hub parquet files...")
            # Ensure you have pyarrow and fsspec installed: pip install pyarrow fsspec aiohttp
            splits = {'train': 'main/train-00000-of-00001.parquet',
                      'test': 'main/test-00000-of-00001.parquet'}
            train_df = pd.read_parquet("hf://datasets/openai/gsm8k/" + splits["train"])
            test_df = pd.read_parquet("hf://datasets/openai/gsm8k/" + splits["test"])
            print("Dataset loaded using parquet files from Hugging Face Hub")
        except Exception as e2:
            print(f"Failed to load dataset using parquet from Hub: {e2}")
            # Fallback to local path if available (adjust path if needed)
            local_path_train = "./gsm8k_data/train.jsonl" # Example local path
            local_path_test = "./gsm8k_data/test.jsonl"   # Example local path
            if os.path.exists(local_path_train) and os.path.exists(local_path_test):
                print("Attempting to load from local JSONL files...")
                train_df = pd.read_json(local_path_train, lines=True)
                test_df = pd.read_json(local_path_test, lines=True)
                print("Dataset loaded from local JSONL files.")
            else:
                print(f"Local files not found at {local_path_train} and {local_path_test}")
                raise RuntimeError("Unable to load the GSM8K dataset via datasets, parquet, or local files.")


    print(f"Training examples: {len(train_df)}")
    print(f"Test examples: {len(test_df)}")

    # Split training data into train/validation
    train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=42)

    print(f"Final training examples: {len(train_df)}")
    print(f"Validation examples: {len(val_df)}")
    print(f"Test examples: {len(test_df)}")

    return train_df, val_df, test_df


# ==========================================
# 1.3) Prepare data for JEPA training
# ==========================================
def prepare_batches_from_gsm8k(data_df, hyperparams, device):
    """Create training batches from GSM8K dataset with context and target spans for JEPA."""
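    # Convert to a plain list of ints so pandas positional indexing does not depend on tensor handling
    batch_indices = torch.randint(0, len(data_df), (hyperparams['batch_size'],)).tolist()
    batch_examples = data_df.iloc[batch_indices]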

    block_size = hyperparams['block_size']
    bos_token = hyperparams['bos_token']
    eos_token = hyperparams['eos_token']
    pad_token = hyperparams['pad_token']

    # JEPA specific parameters
    num_target_spans = hyperparams['num_target_spans']
    min_span_length = hyperparams['min_span_length']

    # Create storage for batches
    full_sequences = []
    context_masks = [] # Mask for JEPA context (1=context, 0=target/padding)
    target_spans_indices = [] # List of (start, end) tuples

    for _, row in batch_examples.iterrows():
        # Get question and answer
        question = row['question']
        answer = row['answer']

        # Format with system prompt and tags
        system_prompt = hyperparams['system_prompt']
        full_text = f"{system_prompt}\n\nProblem: {question}\n\n<think>{answer}</think>\n\n<answer>"

        # Extract the final answer from the explanation
        answer_lines = answer.strip().split('\n')
        final_answer = answer_lines[-1] if answer_lines else ""
        if "answer is" in final_answer.lower():
            final_answer = final_answer.split("answer is")[-1].strip()
        elif "=" in final_answer:
            final_answer = final_answer.split("=")[-1].strip()
        # Simple extraction, might need refinement: keep digits, '.', and '-' so negative answers keep their sign
        final_answer_numeric = ''.join(filter(lambda x: x.isdigit() or x in '.-', final_answer.split('####')[-1].strip()))

        full_text += f"{final_answer_numeric}</answer>"

        # Convert to byte sequence and add BOS/EOS tokens
        full_bytes = [bos_token] + [b for b in full_text.encode('utf-8', errors='replace')] + [eos_token]

        # Truncate or pad sequence to block_size
        seq_length = len(full_bytes)
        if seq_length > block_size:
            full_bytes = full_bytes[:block_size]
            seq_length = block_size # Actual length after potential truncation
        elif seq_length < block_size:
            padding_needed = block_size - seq_length
            full_bytes = full_bytes + [pad_token] * padding_needed
            # seq_length remains the original length before padding

        # Create context mask (1 = keep as context, 0 = mask for JEPA prediction)
        # Initialize all non-padding positions as potential context
        context_mask = torch.zeros(block_size, dtype=torch.float) # Use float for easier masking later
        context_mask[:seq_length] = 1 # Only real tokens can be context initially

        # Select random target spans for this example
        current_target_spans = []
        # Indices of real tokens available for masking
        available_indices = torch.where(context_mask[:seq_length] == 1)[0].tolist()

        for _ in range(num_target_spans):
            # Check if enough *remaining* tokens are available to form a min_span_length span
            if len(available_indices) < min_span_length:
                break # Not enough remaining tokens

            # Randomly choose target span length
            # Max length: limited by available indices and a fraction of total real tokens
            max_possible_len = min(len(available_indices), int(seq_length * 0.2)) # e.g., Max 20% of real tokens
            if max_possible_len < min_span_length:
                 continue # Skip if max possible length is too small

            # Sample span_length uniformly from [min_span_length, max_possible_len] (randint's upper bound is exclusive)
            span_length = torch.randint(min_span_length, max(min_span_length + 1, max_possible_len + 1), (1,)).item()

            # Choose random starting position *from the list of available indices*
            if len(available_indices) - span_length < 0:
                # This shouldn't happen if max_possible_len logic is correct, but safety check
                continue
            start_idx_in_available = torch.randint(0, len(available_indices) - span_length + 1, (1,)).item()
            start_pos = available_indices[start_idx_in_available]

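            # Calculate end position based on start_pos and span_length
            # (at least span_length available indices lie at or after start_pos, so the min() is only a safety net)
            end_pos = min(start_pos + span_length, seq_length)
            actual_span_length = end_pos - start_pos

            # Skip very short spans that might result from hitting the seq_length boundary
            if actual_span_length < min_span_length // 2:
                continue

            # available_indices is no longer contiguous once an earlier span has been removed,
            # so [start_pos, end_pos) can run into a previous target span. Skip such spans
            # so that target spans never overlap.
            available_set = set(available_indices)
            if any(pos not in available_set for pos in range(start_pos, end_pos)):
                continue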

            # Mark positions in target span on the context mask (set to 0)
            context_mask[start_pos:end_pos] = 0

            # Store span positions (start, end)
            current_target_spans.append((start_pos, end_pos))

            # Update available indices: remove indices used by the target span
            span_indices_set = set(range(start_pos, end_pos))
            available_indices = [idx for idx in available_indices if idx not in span_indices_set]


        # Add to batches
        full_sequences.append(full_bytes)
        context_masks.append(context_mask)
        target_spans_indices.append(current_target_spans)
        # if not current_target_spans: print(f"Warning: No spans generated for an example. SeqLen: {seq_length}") # Debug

    # Convert to tensors
    x = torch.tensor(full_sequences, dtype=torch.long).to(device)
    context_masks = torch.stack(context_masks).to(device) # [B, T], 1 for context, 0 for target/padding

    # Create attention mask (1 for real tokens including targets, 0 for padding)
    # This mask is used by all attention layers to ignore padding
    attention_mask = (x != pad_token).float().to(device) # [B, T], 1 for non-pad, 0 for pad

    return x, context_masks, target_spans_indices, attention_mask


# ==========================================
# 2) Rotary Positional Embedding (RoPE)
# ==========================================
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        # Adjust max_seq_len for RoPE based on the actual block_size
        self.max_seq_len = max_seq_len
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(self.max_seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, seq_len: int):
        # seq_len: number of positions to return embeddings for
        # returns: cos, sin buffers of shape [seq_len, dim]
        # The tables are precomputed up to max_seq_len (set from block_size). Extending them
        # on the fly is not supported, so fail loudly instead of silently clamping.
        if seq_len > self.max_seq_len:
            raise ValueError(f"RoPE sequence length {seq_len} exceeds precomputed max {self.max_seq_len}")

        return (
            self.cos_cached[:seq_len, ...],
            self.sin_cached[:seq_len, ...],
        )

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    """Applies RoPE to query and key tensors."""
    # Add leading batch and head dimensions so cos/sin broadcast against q, k of shape [B, H, T, D_head]
    cos = cos.unsqueeze(0).unsqueeze(0) # [1, 1, T, D_head]
    sin = sin.unsqueeze(0).unsqueeze(0) # [1, 1, T, D_head]

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# ==========================================
# 3) Improved Attention Mechanism (with RoPE and Causal Masking)
# ==========================================
class ImprovedAttention(nn.Module):
    def __init__(self, embed_dim, n_heads, is_self_attention=True, use_rope=True, max_seq_len=2048):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        assert self.head_dim * n_heads == self.embed_dim, "embed_dim must be divisible by n_heads"
        self.is_self_attention = is_self_attention
        self.use_rope = use_rope

        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # Instantiate RoPE only if used and needed
        if self.use_rope and self.is_self_attention:
            self.rotary_emb = RotaryEmbedding(self.head_dim, max_seq_len=max_seq_len)
        else:
            self.rotary_emb = None

        self.attn_dropout = nn.Dropout(0.1)
        self.out_dropout = nn.Dropout(0.1)

        # Buffer for causal mask (recreated if needed)
        self.register_buffer("causal_mask_cache", None, persistent=False)

    def _get_causal_mask(self, T, device):
        # Efficiently get or create causal mask
        if self.causal_mask_cache is None or self.causal_mask_cache.shape[-1] < T:
            # Create strictly upper-triangular mask (True = future position, to be masked)
            mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1)
            self.causal_mask_cache = mask
        # Return the sub-mask for the current sequence length T
        # Ensure it's on the correct device (might change between train/eval/inference)
        return self.causal_mask_cache[:T, :T].to(device=device)


    def forward(self, x, attn_mask=None, key_value_states=None, is_causal=False):
        """
        Args:
            x: Query input [B, T, C]
            attn_mask: Padding mask [B, T_k] or broadcastable. 1=Keep, 0=Mask.
            key_value_states: Optional key/value input for cross-attention [B, T_k, C].
            is_causal: If True, apply causal masking (for self-attention only).
        """
        B, T, C = x.size()
        is_cross_attn = key_value_states is not None
        # Determine if RoPE should be applied in this specific call
        use_rope_for_this_pass = self.use_rope and self.is_self_attention and not is_cross_attn and self.rotary_emb is not None

        # Project query
        q = self.q_proj(x)

        # Project keys and values
        if is_cross_attn:
            T_k = key_value_states.size(1)
            k = self.k_proj(key_value_states)
            v = self.v_proj(key_value_states)
            # Causal mask is ignored in cross-attention
            is_causal = False
        else:
            T_k = T # Self-attention
            k = self.k_proj(x)
            v = self.v_proj(x)

        # Reshape for multi-head attention
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)    # B, H, T, D
        k = k.view(B, T_k, self.n_heads, self.head_dim).transpose(1, 2)  # B, H, T_k, D
        v = v.view(B, T_k, self.n_heads, self.head_dim).transpose(1, 2)  # B, H, T_k, D

        # Apply RoPE if applicable
        if use_rope_for_this_pass:
            cos, sin = self.rotary_emb(T) # Get embeddings for query length T
            q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # RoPE is a pure rotation and preserves the norms of q and k, so the standard
        # 1/sqrt(head_dim) scaling is needed with or without it.
        scaling_factor = 1.0 / math.sqrt(self.head_dim)

        # Compute scaled attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) * scaling_factor # [B, H, T, T_k]

        # Apply combined masking (padding AND causal)
        final_mask_bool = None # Boolean mask: True indicates position should be masked (-inf)

        # 1. Process padding mask (attn_mask) -> masks Keys/Values
        if attn_mask is not None:
            # Input mask: 1=keep, 0=mask. We need True where mask should be applied.
            if attn_mask.dim() == 2: # Common case: [B, T_k]
                # Expand to broadcast: [B, T_k] -> [B, 1, 1, T_k]
                padding_mask_bool = ~attn_mask.bool().unsqueeze(1).unsqueeze(2)
            elif attn_mask.dim() == 4: # E.g., [B, 1, 1, T_k]
                padding_mask_bool = ~attn_mask.bool()
            else:
                raise ValueError(f"Unsupported attn_mask dimension: {attn_mask.dim()}")
            final_mask_bool = padding_mask_bool # [B, 1, 1, T_k]

        # 2. Process causal mask (if self-attention and is_causal=True) -> masks future Key positions for each Query
        if self.is_self_attention and is_causal:
            causal_mask_bool = self._get_causal_mask(T, x.device) # [T, T]
            # Expand to broadcast: [T, T] -> [1, 1, T, T]
            causal_mask_bool = causal_mask_bool.unsqueeze(0).unsqueeze(0)

            if final_mask_bool is not None:
                # Combine: mask if *either* padding mask *or* causal mask applies
                # Broadcasting works: [B, 1, 1, T_k] | [1, 1, T, T] -> [B, 1, T, T] (since T=T_k)
                final_mask_bool = final_mask_bool | causal_mask_bool
            else:
                final_mask_bool = causal_mask_bool # [1, 1, T, T]

        # Apply the combined mask to scores
        if final_mask_bool is not None:
             # Ensure mask shape is compatible with scores [B, H, T, T_k]
             # final_mask_bool is typically [B, 1, T, T_k] or [B, 1, 1, T_k]
             scores = scores.masked_fill(final_mask_bool, torch.finfo(scores.dtype).min)

        # Apply softmax and dropout
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)

        # Apply attention to values
        attn_output = torch.matmul(attn_weights, v) # [B, H, T, D]

        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_dropout(self.out_proj(attn_output))


# ==========================================
# 4) Transformer Decoder Block
# ==========================================
class DecoderBlock(nn.Module):
    def __init__(self, embed_dim, n_heads, dropout=0.1, max_seq_len=2048):
        super().__init__()
        # Using Pre-LN
        self.ln1 = nn.LayerNorm(embed_dim)
        self.self_attention = ImprovedAttention(embed_dim, n_heads, is_self_attention=True, use_rope=True, max_seq_len=max_seq_len)
        self.ln2 = nn.LayerNorm(embed_dim)
        hidden_dim = 4 * embed_dim
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(), # Consider SwiGLU later
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )
        self.dropout = nn.Dropout(dropout) # Dropout for residual connections

    def forward(self, x, attention_mask=None, is_causal=True):
        """
        Args:
            x: Input sequence [B, T, C].
            attention_mask: Padding mask [B, T]. 1 for real tokens, 0 for padding.
            is_causal: Whether the self-attention should be causal.
        """
        # --- Self-Attention (Causal or Bidirectional based on is_causal) ---
        residual = x
        x_norm = self.ln1(x)
        attn_output = self.self_attention(x_norm, attn_mask=attention_mask, is_causal=is_causal)
        x = residual + self.dropout(attn_output)

        # --- Feed-Forward ---
        residual = x
        x_norm = self.ln2(x)
        ff_output = self.feed_forward(x_norm)
        x = residual + self.dropout(ff_output)

        return x

# ==========================================
# 5) JEPA Predictor Block (Causal Self-Attn, Cross-Attn to Decoder)
# ==========================================
class JEPAPredictorBlock(nn.Module):
    def __init__(self, embed_dim, n_heads, dropout=0.1, max_seq_len=2048):
        super().__init__()
        # Pre-LN structure
        self.ln1 = nn.LayerNorm(embed_dim)
        # Causal Self-attention within the predictor (RoPE enabled)
        self.self_attention = ImprovedAttention(embed_dim, n_heads, is_self_attention=True, use_rope=True, max_seq_len=max_seq_len)

        self.ln_cross_attn_query = nn.LayerNorm(embed_dim) # LN before cross-attn query input
        self.ln_cross_attn_kv = nn.LayerNorm(embed_dim)    # LN before cross-attn key/value input
        # Cross-attention to backbone decoder output (non-causal, no RoPE)
        self.cross_attention = ImprovedAttention(embed_dim, n_heads, is_self_attention=False, use_rope=False, max_seq_len=max_seq_len)

        self.ln3 = nn.LayerNorm(embed_dim)
        hidden_dim = 4 * embed_dim
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, decoder_output, self_attention_mask=None, cross_attention_mask=None):
        """
        Args:
            x: Predictor input sequence [B, T, C].
            decoder_output: Output from the main BackboneDecoder [B, T_kv, C].
            self_attention_mask: Padding mask for predictor input [B, T]. 1=keep, 0=mask.
            cross_attention_mask: Mask for decoder_output (keys/values in cross-attn) [B, T_kv].
                                  Should be JEPA context_mask (1=context, 0=target/pad).
        """
        # --- Causal Self-Attention within Predictor ---
        residual = x
        x_norm = self.ln1(x)
        attn_output = self.self_attention(
            x_norm,
            attn_mask=self_attention_mask, # Use overall padding mask for self-attn
            is_causal=True                 # Self-attention is CAUSAL
        )
        x = residual + self.dropout(attn_output)

        # --- Cross-Attention to Decoder Output ---
        residual = x
        query_norm = self.ln_cross_attn_query(x)           # Normalize query input (from predictor state)
        kv_norm = self.ln_cross_attn_kv(decoder_output)    # Normalize key/value input (from decoder)

        cross_attn_output = self.cross_attention(
            query_norm,                           # Query from predictor
            attn_mask=cross_attention_mask,       # Mask K/V based on JEPA context_mask
            key_value_states=kv_norm              # K/V from (normalized) decoder output
        )
        x = residual + self.dropout(cross_attn_output)

        # --- Feed-Forward ---
        residual = x
        x_norm = self.ln3(x)
        ff_output = self.feed_forward(x_norm)
        x = residual + self.dropout(ff_output)

        return x

# ==========================================
# 6) Backbone Decoder (Replaces ContextEncoder)
# ==========================================
class BackboneDecoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, n_heads, n_layers, block_size):
        super().__init__()
        self.block_size = block_size

        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.dropout = nn.Dropout(0.1) # Dropout after embedding

        self.blocks = nn.ModuleList([
            DecoderBlock(embed_dim, n_heads, dropout=0.1, max_seq_len=block_size)
            for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(embed_dim)
        self.apply(self._init_weights) # Initialize weights

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, x, attention_mask=None, is_causal=True):
        """
        Args:
            x: Input token indices [B, T].
            attention_mask: Padding mask [B, T]. 1=keep, 0=mask.
            is_causal: Controls self-attention masking in DecoderBlocks.
        """
        B, T = x.size()
        assert T <= self.block_size, f"Sequence length {T} exceeds block size {self.block_size}"

        token_emb = self.token_embedding(x) # [B, T, C]
        x = self.dropout(token_emb)

        for block in self.blocks:
            x = block(x, attention_mask=attention_mask, is_causal=is_causal)

        x = self.ln_f(x)
        return x

# ==========================================
# 7) JEPA Predictor (Using causal self-attn)
# ==========================================
class JEPAPredictor(nn.Module):
    def __init__(self, embed_dim, n_heads, n_layers, block_size):
        super().__init__()
        self.block_size = block_size
        # Consider using fewer layers for predictor, e.g., predictor_layers = n_layers // 2
        predictor_layers = n_layers # Keep same depth for now

        # Learnable mask token embedding
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        torch.nn.init.normal_(self.mask_token, mean=0.0, std=0.02)

        # Predictor blocks (using JEPAPredictorBlock)
        self.blocks = nn.ModuleList([
            JEPAPredictorBlock(embed_dim, n_heads, max_seq_len=block_size)
            for _ in range(predictor_layers)
        ])
        self.ln_f = nn.LayerNorm(embed_dim) # Final layer norm
        self.apply(self._init_weights) # Initialize weights

    def _init_weights(self, module):
        # mask_token is an nn.Parameter, not a module, so apply() never visits it;
        # it keeps the explicit initialization done in __init__.
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, decoder_output_causal, target_spans_indices, context_mask, attention_mask):
        """
        Predict target span representations.

        Args:
            decoder_output_causal: [B, T, C] Embeddings from CAUSAL BackboneDecoder pass.
            target_spans_indices: List[List[Tuple[int, int]]] Target span indices.
            context_mask: [B, T] JEPA context mask (1=context, 0=target/padding).
            attention_mask: [B, T] Overall padding mask (1=real, 0=pad).

        Returns:
            List[List[Tensor]]: Predicted embeddings for target spans per batch item.
        """
        B, T, C = decoder_output_causal.size()

        # Initialize predictor input:
        # Use mask tokens for target positions, causal decoder output for context positions.
        predictor_input = torch.zeros_like(decoder_output_causal)
        mask_token_expanded = self.mask_token.expand(B, T, C)

        # Boolean masks for indexing
        is_context = context_mask.bool()            # Where JEPA context mask is 1
        is_target = (~is_context) & attention_mask.bool() # Where context is 0 AND not padding

        # Populate predictor input
        predictor_input[is_context] = decoder_output_causal[is_context]
        predictor_input[is_target] = mask_token_expanded[is_target]
        # Padding positions remain zero

        # Process through predictor blocks
        x = predictor_input
        for block in self.blocks:
            x = block(
                x,
                decoder_output=decoder_output_causal, # K/V for cross-attention comes from causal decoder
                self_attention_mask=attention_mask,   # Padding mask for predictor's CAUSAL self-attention
                cross_attention_mask=context_mask     # JEPA context mask to select K/V in cross-attention
            )

        x = self.ln_f(x) # Final normalization [B, T, C]

        # Extract predicted embeddings only for the target spans
        predicted_spans = []
        for b in range(B):
            batch_spans = []
            if not target_spans_indices[b]: # Handle cases where no spans were generated for this item
                predicted_spans.append(batch_spans)
                continue
            for start, end in target_spans_indices[b]:
                valid_end = min(end, T) # Ensure end index is within bounds
                if start < valid_end: # Ensure span has non-zero length
                    span_emb = x[b, start:valid_end] # Extract embeddings [SpanLen, C]
                    batch_spans.append(span_emb)
            predicted_spans.append(batch_spans)

        return predicted_spans

# ==========================================
# 8) Target Encoder (EMA copy of BackboneDecoder, runs NON-CAUSALLY)
# ==========================================
class TargetEncoder(nn.Module):
    def __init__(self, backbone_decoder, ema_decay=0.999):
        super().__init__()
        # Create a deep copy of the backbone decoder structure
        self.encoder = copy.deepcopy(backbone_decoder)
        self.ema_decay = ema_decay
        # Disable gradient computation for the target encoder
        for param in self.encoder.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def update_ema(self, backbone_decoder, decay_rate=None):
        """Update target encoder weights using exponential moving average"""
        decay_rate = decay_rate if decay_rate is not None else self.ema_decay
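        self.encoder.eval() # Ensure target is in eval mode
        # Do not switch backbone_decoder to eval mode here: the EMA only reads parameter values,
        # and calling eval() on the online backbone would silently disable its dropout for the rest of training.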

        source_params = dict(backbone_decoder.named_parameters())
        target_params = dict(self.encoder.named_parameters())
        assert source_params.keys() == target_params.keys(), "Parameter mismatch between backbone and target encoders!"

        for name, source_param in source_params.items():
            target_params[name].data.mul_(decay_rate).add_(source_param.data, alpha=1 - decay_rate)

    @torch.no_grad()
    def forward(self, x, attention_mask=None):
        """Forward pass for target encoder - runs NON-CAUSALLY"""
        self.encoder.eval() # Ensure target encoder is always in eval mode
        # Call the underlying decoder's forward pass, forcing is_causal=False
        return self.encoder(x, attention_mask=attention_mask, is_causal=False)

# ==========================================
# 9) Complete T-JEPA Model (Decoder Backbone)
# ==========================================
class TJEPAModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, n_heads, n_layers, block_size, ema_decay=0.999, lm_loss_weight=0.1, pad_token_id=0):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.lm_loss_weight = lm_loss_weight
        self.block_size = block_size # Store block_size

        # Main Backbone: Transformer Decoder
        self.decoder_backbone = BackboneDecoder(vocab_size, embed_dim, n_heads, n_layers, block_size)

        # JEPA Predictor
        self.predictor = JEPAPredictor(embed_dim, n_heads, n_layers, block_size)

        # Target Encoder (EMA copy, runs non-causally)
        self.target_encoder = TargetEncoder(self.decoder_backbone, ema_decay)
        # Perform initial weight copy after TargetEncoder is created
        self.target_encoder.update_ema(self.decoder_backbone, decay_rate=0.0)

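        # LM Head (predicts next token)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)

        # Weight tying: the LM head reuses the backbone's embedding matrix.
        # Tying in this direction keeps the backbone's initialized embedding (already copied
        # into the target encoder above) instead of overwriting it with the head's default init.
        self.lm_head.weight = self.decoder_backbone.token_embedding.weight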

    def forward(self, x, context_mask, target_spans_indices, attention_mask):
        """
        Orchestrates the forward pass for training.

        Args:
            x: [B, T] Input token sequence.
            context_mask: [B, T] JEPA context mask (1=context, 0=target/pad).
            target_spans_indices: List[List[Tuple[int, int]]] Target span indices.
            attention_mask: [B, T] Padding mask (1=real, 0=pad).

        Returns:
            Dictionary containing outputs needed for loss calculation.
        """

        # 1. Causal pass through the main decoder backbone
        # Used for LM loss and as context for the predictor's cross-attention.
        decoder_output_causal = self.decoder_backbone(
            x,
            attention_mask=attention_mask,
            is_causal=True # Standard causal operation
        ) # [B, T, C]

        # 2. Non-causal pass through the target encoder (EMA copy, no gradients)
        # Used to get the target representations for the JEPA loss.
        with torch.no_grad():
            self.target_encoder.eval() # Ensure target is in eval mode
            target_embeddings_full = self.target_encoder(
                x,
                attention_mask=attention_mask
                # Internally calls backbone with is_causal=False
            ) # [B, T, C]

        # 3. Predictor pass
        # Predicts representations for target spans using the causal decoder output
        # as context in cross-attention.
        predicted_spans_embeddings = self.predictor(
            decoder_output_causal=decoder_output_causal, # Context for cross-attention
            target_spans_indices=target_spans_indices,   # Which spans to predict
            context_mask=context_mask,                   # Mask for cross-attention K/V
            attention_mask=attention_mask                # Padding mask for predictor's self-attention
        ) # List[List[Tensor]]

        # 4. Extract actual target embeddings from the NON-CAUSAL target encoder output
        target_spans_embeddings = []
        for b in range(x.size(0)):
            batch_spans = []
            if not target_spans_indices[b]: # Handle empty span list for this batch item
                target_spans_embeddings.append(batch_spans)
                continue
            for start, end in target_spans_indices[b]:
                valid_end = min(end, x.size(1))
                if start < valid_end:
                    # Extract from the full target embeddings (non-causal)
                    span_emb = target_embeddings_full[b, start:valid_end]
                    batch_spans.append(span_emb)
            target_spans_embeddings.append(batch_spans) # List[List[Tensor]]

        # 5. Calculate LM Logits
        # Based on the output of the CAUSAL decoder backbone pass.
        lm_logits = self.lm_head(decoder_output_causal) # [B, T, VocabSize]

        return {
            "predicted_spans_embeddings": predicted_spans_embeddings, # From Predictor
            "target_spans_embeddings": target_spans_embeddings,     # From Target Encoder (non-causal)
            "lm_logits": lm_logits,                                 # From Backbone Decoder (causal)
            "input_sequence": x,                                    # For LM loss calculation
            "attention_mask": attention_mask,                       # For LM loss masking (optional)
        }

    def update_target_encoder(self):
        """Update target encoder weights using EMA"""
        self.target_encoder.update_ema(self.decoder_backbone)

    def compute_loss(self, outputs):
        """Compute combined JEPA (MSE) and LM (CrossEntropy) loss."""
        # --- JEPA MSE Loss ---
        predicted_spans = outputs["predicted_spans_embeddings"]
        target_spans = outputs["target_spans_embeddings"]
        batch_size = len(predicted_spans)
        jepa_losses = []
        num_valid_comparisons = 0 # Track how many span comparisons actually happen

        for b in range(batch_size):
            num_spans_in_batch_item = len(predicted_spans[b])
            # Ensure target list has same length (should always be true if data prep is correct)
            if num_spans_in_batch_item != len(target_spans[b]):
                # print(f"Warning: Mismatch in number of predicted ({num_spans_in_batch_item}) vs target ({len(target_spans[b])}) spans for batch item {b}.")
                continue # Skip this item if lengths mismatch

            if num_spans_in_batch_item == 0:
                continue # Skip if no spans were generated/extracted for this item

            span_losses_for_batch_item = []
            for i in range(num_spans_in_batch_item):
                pred_span = predicted_spans[b][i] # [SpanLen_pred, C]
                target_span = target_spans[b][i]  # [SpanLen_target, C]

                # Ensure spans are not empty and shapes match exactly
                if pred_span.nelement() > 0 and target_span.nelement() > 0 and pred_span.shape == target_span.shape:
                    # Optional: Normalize embeddings before MSE loss
                    # pred_span_norm = F.normalize(pred_span, p=2, dim=-1)
                    # target_span_norm = F.normalize(target_span, p=2, dim=-1)
                    # loss = F.mse_loss(pred_span_norm, target_span_norm)

                    loss = F.mse_loss(pred_span, target_span)
                    span_losses_for_batch_item.append(loss)
                    num_valid_comparisons += 1
                # else:
                    # Optional: Log why a comparison was skipped
                    # if pred_span.nelement() == 0: print(f"Debug: Skipped empty pred span {b},{i}")
                    # elif target_span.nelement() == 0: print(f"Debug: Skipped empty target span {b},{i}")
                    # else: print(f"Debug: Skipped shape mismatch {pred_span.shape} vs {target_span.shape} for {b},{i}")


            # Average loss across valid spans for this batch item
            if span_losses_for_batch_item:
                 jepa_losses.append(torch.stack(span_losses_for_batch_item).mean())

        # Average JEPA loss over the batch items that had valid spans
        if jepa_losses:
            final_jepa_loss = torch.stack(jepa_losses).mean()
        else:
            # Return zero loss if NO valid spans were compared across the entire batch
            example_tensor = outputs["lm_logits"] # Get device/dtype hint
            final_jepa_loss = torch.tensor(0.0, device=example_tensor.device, dtype=example_tensor.dtype)
            # if num_valid_comparisons == 0: print("Warning: JEPA loss is 0.0 because no valid span comparisons occurred in this batch.")


        # --- LM Cross Entropy Loss ---
        lm_logits = outputs["lm_logits"] # [B, T, V]
        input_sequence = outputs["input_sequence"] # [B, T]

        # Shift logits and labels for next token prediction
        shift_logits = lm_logits[:, :-1, :].contiguous() # [B, T-1, V]
        shift_labels = input_sequence[:, 1:].contiguous() # [B, T-1]

        # Flatten the tokens for CrossEntropyLoss
        shift_logits = shift_logits.view(-1, shift_logits.size(-1)) # [B*(T-1), V]
        shift_labels = shift_labels.view(-1) # [B*(T-1)]

        # Calculate loss, ignoring padding tokens
        lm_loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=self.pad_token_id)

        # --- Combine Losses ---
        total_loss = final_jepa_loss + self.lm_loss_weight * lm_loss

        return {
            "total_loss": total_loss,
            "jepa_loss": final_jepa_loss, # This should now be non-zero if spans are generated
            "lm_loss": lm_loss
        }

    @torch.no_grad()
    def generate(self, x, max_new_tokens, temperature=1.0, top_p=0.9):
        """Generate text autoregressively using the BackboneDecoder."""
        self.eval() # Ensure model is in evaluation mode
        pad_token_id = self.pad_token_id

        for _ in range(max_new_tokens):
            # Crop context if it exceeds block size
            x_cond = x if x.size(1) <= self.block_size else x[:, -self.block_size:]

            # Create attention mask for padding (1 for real tokens, 0 for padding)
            attention_mask = (x_cond != pad_token_id).float() # [B, T]

            # Get embeddings from the decoder backbone (CAUSALLY)
            decoder_output = self.decoder_backbone(
                x_cond,
                attention_mask=attention_mask,
                is_causal=True # Explicitly causal for generation
            ) # [B, T, C]

            # Get logits for the *next* token prediction (using the last token's embedding)
            # Apply LM head to the embedding of the very last token in the sequence
            logits = self.lm_head(decoder_output[:, -1, :])  # [B, C] -> [B, V]

            # Apply temperature scaling
            if temperature > 0 and temperature != 1.0:
                logits = logits / temperature

            # Apply top-p (nucleus) sampling
            if top_p > 0.0 and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right to keep the first token above the threshold
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                # Scatter mask back to original indices
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                # Apply mask by setting logits to -infinity
                logits[indices_to_remove] = float('-inf')

            # Sample from the potentially modified distribution
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1) # [B, 1]

            # Append sampled token to the running sequence
            x = torch.cat([x, next_token], dim=1)

            # Optional: Check if EOS token was generated in *all* sequences (for batch generation)
            # if hyperparams['eos_token'] is not None and (next_token == hyperparams['eos_token']).all():
            #     break

        return x

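# ------------------------------------------
# Illustrative sketch (not part of the original model code): the nucleus (top-p)
# filtering that `generate` performs inline above, factored into a standalone
# helper. The name `top_p_filter_sketch` is hypothetical. It assumes `logits`
# has shape [B, V] and, unlike the inline version, returns a filtered copy
# instead of modifying the input in place.
# ------------------------------------------
import torch
import torch.nn.functional as F


def top_p_filter_sketch(logits, top_p=0.9):
    """Return a copy of `logits` with tokens outside the top-p nucleus set to -inf."""
    if not (0.0 < top_p < 1.0):
        return logits.clone()  # Filtering disabled, mirroring the guard in `generate`
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    # Mark tokens once the cumulative probability exceeds top_p ...
    remove_sorted = cumulative_probs > top_p
    # ... then shift right so the first token that crosses the threshold is kept
    remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
    remove_sorted[..., 0] = False
    # Map the mask from sorted order back to vocabulary order
    remove = torch.zeros_like(remove_sorted).scatter(-1, sorted_indices, remove_sorted)
    return logits.masked_fill(remove, float('-inf'))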

# ==========================================
# 10) Evaluation Function
# ==========================================
@torch.no_grad()
def estimate_loss(model, train_df, val_df, hyperparams, device):
    """Estimates loss on train and validation splits."""
    out = {}
    model.eval() # Set model to evaluation mode

    for split, df in [('train', train_df), ('val', val_df)]:
        total_losses = torch.zeros(hyperparams['eval_iters'])
        jepa_losses = torch.zeros(hyperparams['eval_iters'])
        lm_losses = torch.zeros(hyperparams['eval_iters'])

        # Use tqdm for eval iterations if desired, but can be removed
        # pbar_eval = tqdm(range(hyperparams['eval_iters']), desc=f"Eval {split}", leave=False)
        # for k in pbar_eval:
        for k in range(hyperparams['eval_iters']):
            # Get a batch of data
            x, context_mask, target_spans_indices, attention_mask = prepare_batches_from_gsm8k(
                df, hyperparams, device
            )

            # Forward pass through the model
            outputs = model(x, context_mask, target_spans_indices, attention_mask)

            # Compute loss using the model's loss function
            loss_dict = model.compute_loss(outputs)

            # Store losses
            total_losses[k] = loss_dict['total_loss'].item()
            jepa_losses[k] = loss_dict['jepa_loss'].item()
            lm_losses[k] = loss_dict['lm_loss'].item()

        # Calculate average losses for the split
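        # Store plain Python floats so comparisons and checkpointed values do not carry tensors
        out[split + '_total'] = total_losses.mean().item()
        out[split + '_jepa'] = jepa_losses.mean().item()
        out[split + '_lm'] = lm_losses.mean().item()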

    # model.train() # Caller should reset mode after evaluation
    return out


# ==========================================
# 11) Generate Text Function (Uses model.generate)
# ==========================================
@torch.no_grad()
def generate_from_prompt(model, hyperparams, prompt_text=None, max_new_tokens=200, top_p=None, device="cuda"):
    """Generates text from a prompt using the model's generate method."""
    model.eval() # Ensure evaluation mode
    prompt_text = prompt_text if prompt_text is not None else hyperparams['start_prompt']
    top_p = top_p if top_p is not None else hyperparams['top_p']
    system_prompt = hyperparams['system_prompt']
    full_prompt = f"{system_prompt}\n\nProblem: {prompt_text}\n\n<think>" # Start generation within think tags

    # Encode prompt
    bos_token = hyperparams['bos_token']
    prompt_bytes = [bos_token] + [b for b in full_prompt.encode('utf-8', errors='replace')]
    context = torch.tensor(prompt_bytes, dtype=torch.long, device=device).unsqueeze(0) # [1, T_prompt]

    # Use the model's generate method
    full_output_tokens = model.generate(
        context,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        temperature=0.8 # Example temperature, could be hyperparameter
    ) # [1, T_prompt + T_new]

    # Decode the full output sequence
    full_output_list = full_output_tokens[0].tolist()

    # Decode bytes to text, handling potential errors and special tokens
    try:
        # Drop the leading BOS token that was prepended to the prompt
        if full_output_list and full_output_list[0] == bos_token:
            full_output_list = full_output_list[1:]

        # Find EOS token if present and truncate
        eos_token = hyperparams['eos_token']
        eos_pos = full_output_list.index(eos_token) if eos_token in full_output_list else -1
        if eos_pos != -1:
            full_output_list = full_output_list[:eos_pos] # Truncate at EOS

        # Remove padding tokens and any id outside the byte range, then decode.
        # bytes() raises ValueError for values outside 0..255, so filter first.
        pad_token = hyperparams['pad_token']
        decoded_bytes = bytes([tok for tok in full_output_list if tok != pad_token and 0 <= tok <= 255])
        generated_text = decoded_bytes.decode('utf-8', errors='replace')
        return generated_text
    except Exception as e:
        print(f"Decoding error during generation: {e}")
        # Fallback: return the raw token ids (calling bytes() again here could raise a second time)
        return str(full_output_list)


# ==========================================
# 12) Token-by-Token Generation (Manual loop)
# ==========================================
@torch.no_grad()
def generate_token_by_token(model, hyperparams, prompt_text, max_new_tokens=200, device="cuda"):
    """Generates token by token, printing output, using the decoder model."""
    model.eval() # Ensure evaluation mode
    system_prompt = hyperparams['system_prompt']
    full_prompt = f"{system_prompt}\n\nProblem: {prompt_text}\n\n<think>"
    bos_token = hyperparams['bos_token']
    pad_token = hyperparams['pad_token']
    eos_token = hyperparams['eos_token']

    # Encode prompt
    prompt_bytes = [bos_token] + [b for b in full_prompt.encode('utf-8', errors='replace')]
    context = torch.tensor(prompt_bytes, dtype=torch.long, device=device).unsqueeze(0) # [1, T_prompt]

    print(f"\n--- Generating from prompt ---\n{full_prompt}", end="", flush=True)

    generated_tokens = []
    current_byte_fragment = b''

    # Manually loop for token-by-token generation
    for _ in range(max_new_tokens):
        # --- Prepare input for this step ---
        # Crop context if it exceeds block size
        context_cond = context if context.size(1) <= model.block_size else context[:, -model.block_size:]
        # Create attention mask for padding
        attention_mask = (context_cond != pad_token).float() # [1, T_cond]

        # --- Forward pass (Causal) ---
        decoder_output = model.decoder_backbone(
            context_cond,
            attention_mask=attention_mask,
            is_causal=True
        ) # [1, T_cond, C]
        # Get logits for the next token prediction (using the last token's output)
        logits = model.lm_head(decoder_output[:, -1, :]) # [1, V]

        # --- Sampling (Top-p) ---
        top_p = hyperparams['top_p']
        temperature = 0.8 # Example temperature
        if temperature > 0 and temperature != 1.0:
            logits = logits / temperature
        if top_p > 0.0 and top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = float('-inf')

        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1) # [1, 1]

        # --- Update context and decode/print ---
        next_token_value = next_token.item()
        context = torch.cat([context, next_token], dim=1) # Append to context for next step
        generated_tokens.append(next_token_value)

        # Check for EOS token first so it is never treated as a text byte
        if next_token_value == eos_token:
            print("<EOS>", end="", flush=True)
            break

        # Padding tokens and ids outside 0..255 carry no text; bytes() would raise on the latter
        if next_token_value == pad_token or not 0 <= next_token_value <= 255:
            continue

        # Attempt to decode and print the new byte(s)
        current_byte_fragment += bytes([next_token_value])
        try:
            next_char = current_byte_fragment.decode('utf-8')
            print(next_char, end="", flush=True)
            current_byte_fragment = b'' # Reset fragment if decode succeeds
            time.sleep(0.01) # Small delay for visualization
        except UnicodeDecodeError:
            # If we can't decode (partial UTF-8 character), wait for more bytes
            if len(current_byte_fragment) > 3: # Avoid getting stuck on invalid sequences
                print("<?>", end="", flush=True) # Print placeholder for invalid sequence
                current_byte_fragment = b'' # Reset

    print("\n\n--- Generation completed ---")

    # Return the full generated text (including prompt) after loop finishes
    full_generated_list = prompt_bytes[1:] + generated_tokens # Drop the leading BOS token
    try:
        eos_pos = full_generated_list.index(eos_token) if eos_token in full_generated_list else -1
        if eos_pos != -1:
            full_generated_list = full_generated_list[:eos_pos]
        # bytes() only accepts values in 0..255, so drop padding and any out-of-range ids first
        decoded_bytes = bytes([tok for tok in full_generated_list if tok != pad_token and 0 <= tok <= 255])
        return decoded_bytes.decode('utf-8', errors='replace')
    except Exception as e:
        print(f"Final decoding error after token-by-token generation: {e}")
        # Fallback: return the raw token ids (calling bytes() again here could raise a second time)
        return str(full_generated_list)

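# ------------------------------------------
# Illustrative sketch (not part of the original code): an alternative to the
# manual byte-fragment buffering in `generate_token_by_token` above, using the
# standard library's incremental UTF-8 decoder, which buffers partial
# multi-byte characters itself. The name `stream_decode_sketch` is hypothetical,
# and it assumes every value is already a valid byte in 0..255.
# ------------------------------------------
import codecs


def stream_decode_sketch(byte_values):
    """Decode byte values one at a time and return the text pieces as they become printable."""
    decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
    pieces = []
    for value in byte_values:
        piece = decoder.decode(bytes([value]))  # '' while a multi-byte character is incomplete
        if piece:
            pieces.append(piece)
    # Flush: a dangling partial character becomes the replacement character
    tail = decoder.decode(b'', final=True)
    if tail:
        pieces.append(tail)
    return pieces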

# ==========================================
# 13) Training Implementation
# ==========================================
def train(continue_training=True):
    """Train the T-JEPA DECODER model on GSM8K."""
    # --- Setup ---
    hyperparams = get_hyperparams()
    device = get_device()
    train_df, val_df, test_df = load_gsm8k_data()

    # --- Model Initialization ---
    model = TJEPAModel(
        vocab_size=hyperparams['vocab_size'],
        embed_dim=hyperparams['embed_dim'],
        n_heads=hyperparams['n_heads'],
        n_layers=hyperparams['n_layers'],
        block_size=hyperparams['block_size'],
        ema_decay=hyperparams['ema_decay'],
        lm_loss_weight=hyperparams['lm_loss_weight'],
        pad_token_id=hyperparams['pad_token']
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model Block Size: {hyperparams['block_size']}")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}") # Includes decoder, predictor, LM head

    # --- Optimizer ---
    # Filter out parameters that don't require gradients (target encoder)
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=3e-4, # Initial LR, scheduler will adjust
        betas=(0.9, 0.95),
        weight_decay=0.1
    )

    # --- Checkpoint Loading ---
    start_epoch = 0
    best_val_loss = float('inf')
    current_step = 0
    checkpoint_path = hyperparams['checkpoint_path']

    if continue_training and os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}...")
        try:
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model_state = checkpoint['model_state']

            # --- Flexible State Dict Loading ---
            # Handles potential renames (e.g., encoder -> decoder) or minor architecture changes
            current_model_dict = model.state_dict()
            processed_state_dict = {}
            warned_keys = set()
            for k, v in model_state.items():
                new_k = k
                # Example rename: if checkpoint has 'context_encoder', map to 'decoder_backbone'
                if k.startswith("context_encoder."):
                    new_k = k.replace("context_encoder.", "decoder_backbone.", 1)

                if new_k in current_model_dict:
                    if v.shape == current_model_dict[new_k].shape:
                        processed_state_dict[new_k] = v
                    else:
                        if new_k not in warned_keys:
                            print(f"Warning: Shape mismatch for key '{new_k}'. Checkpoint: {v.shape}, Model: {current_model_dict[new_k].shape}. Skipping.")
                            warned_keys.add(new_k)
                # else:
                #     if k not in warned_keys and new_k not in warned_keys: # Avoid double warning if rename failed
                #          print(f"Warning: Key '{k}' (mapped to '{new_k}') not found in current model. Skipping.")
                #          warned_keys.add(k); warned_keys.add(new_k)

            missing_keys, unexpected_keys = model.load_state_dict(processed_state_dict, strict=False)
            if missing_keys: print(f"Warning: Missing keys in final state_dict load: {missing_keys}")
            if unexpected_keys: print(f"Warning: Unexpected keys in final state_dict load: {unexpected_keys}")
            # --- End Flexible State Dict Loading ---


            # Load optimizer state cautiously
            try:
                # Basic check: does the number of parameter groups match?
                if len(optimizer.param_groups) == len(checkpoint['optimizer_state']['param_groups']):
                    # More thorough check: do parameter IDs seem to align? (Heuristic)
                    # This is complex; often safer to reinitialize if model structure changed significantly.
                    # For simplicity, we'll try loading if group count matches.
                    optimizer.load_state_dict(checkpoint['optimizer_state'])
                    print("Optimizer state loaded.")
                else:
                    print("Warning: Optimizer parameter group mismatch. Reinitializing optimizer.")
                    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1)
            except Exception as e_optim:
                print(f"Warning: Could not load optimizer state: {e_optim}. Reinitializing.")
                optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1)

            # Load training progress
            start_epoch = checkpoint.get('epoch', 0) + 1
            best_val_loss = checkpoint.get('val_loss', float('inf'))
            current_step = checkpoint.get('current_step', start_epoch * hyperparams['steps_per_epoch'])
            print(f"Resuming from epoch {start_epoch}, step {current_step}.")

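            # IMPORTANT: Re-sync target encoder from the loaded backbone weights.
            # Note: decay_rate=0.0 is the same call used for the initial sync, so any EMA
            # target weights restored from the checkpoint are replaced by the backbone's.
            model.target_encoder.update_ema(model.decoder_backbone, decay_rate=0.0)
            print("Target encoder re-synced from loaded backbone weights.")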

        except Exception as e:
            print(f"Error loading checkpoint comprehensively: {e}")
            print("Starting training from scratch or with partially loaded state.")
            start_epoch = 0
            best_val_loss = float('inf')
            current_step = 0
            # Ensure target encoder is initialized correctly if loading failed
            model.target_encoder.update_ema(model.decoder_backbone, decay_rate=0.0)

    else:
        print("Starting training from scratch.")
        # Initial sync of target encoder
        model.target_encoder.update_ema(model.decoder_backbone, decay_rate=0.0)

    # --- LR Scheduler ---
    grad_clip = 1.0
    total_steps = hyperparams['num_epochs'] * hyperparams['steps_per_epoch']
    warmup_steps = 2000 # Example warmup steps
    base_lr = 3e-4
    min_lr = 1e-5

    def get_lr(step):
        # Cosine decay with warmup
        if step < warmup_steps:
            return base_lr * step / warmup_steps
        decay_steps = total_steps - warmup_steps
        steps_after_warmup = step - warmup_steps
        if steps_after_warmup >= decay_steps: # Avoid going past total steps
            return min_lr
        cosine_decay = 0.5 * (1 + math.cos(math.pi * steps_after_warmup / decay_steps))
        decayed_lr = min_lr + (base_lr - min_lr) * cosine_decay
        return max(min_lr, decayed_lr) # Ensure LR doesn't drop below min_lr

    # --- Training Loop ---
    print(f"Starting training on GSM8K dataset with T-JEPA (DECODER bs={hyperparams['block_size']} + RoPE + MTL)...")
    accumulation_steps = hyperparams['accumulation_steps']
    # Optional: Setup Mixed Precision
    # scaler = torch.cuda.amp.GradScaler(enabled=(device=='cuda'))

    for epoch in range(start_epoch, hyperparams['num_epochs']):
        print(f"\n--- Epoch {epoch+1}/{hyperparams['num_epochs']} ---")
        model.train() # Set model to training mode
        epoch_total_loss, epoch_jepa_loss, epoch_lm_loss = 0.0, 0.0, 0.0
        steps_in_epoch = hyperparams['steps_per_epoch']
        optimizer.zero_grad() # Zero gradients at the start of epoch / after optimizer step

        pbar = tqdm(range(steps_in_epoch), desc=f"Epoch {epoch+1}")
        for step_in_epoch in pbar:
            global_step = current_step

            # --- Periodic Evaluation ---
            if global_step > 0 and global_step % hyperparams['eval_interval'] == 0:
                model.eval() # Switch to eval mode
                losses = estimate_loss(model, train_df, val_df, hyperparams, device)
                print(f"\nStep {global_step} Eval:")
                print(f"  Train Total: {losses['train_total']:.4f}, JEPA: {losses['train_jepa']:.4f}, LM: {losses['train_lm']:.4f}")
                print(f"  Val Total:   {losses['val_total']:.4f}, JEPA: {losses['val_jepa']:.4f}, LM: {losses['val_lm']:.4f}")
                model.train() # Switch back to train mode

                # Save best model based on validation total loss
                current_val_loss = losses['val_total']
                if current_val_loss < best_val_loss:
                    best_val_loss = current_val_loss
                    save_path = checkpoint_path.replace('.pt', '_best.pt')
                    # Save model state, optimizer, epoch, step, loss
                    torch.save({
                        'model_state': model.state_dict(),
                        'optimizer_state': optimizer.state_dict(),
                        'epoch': epoch,
                        'current_step': global_step,
                        'val_loss': best_val_loss,
                        'hyperparams': hyperparams # Save hyperparams used for this checkpoint
                    }, save_path)
                    print(f"  New best model saved to {save_path}! Val loss: {best_val_loss:.4f}")

            # --- Data Batch ---
            try:
                x, context_mask, target_spans_indices, attention_mask = prepare_batches_from_gsm8k(
                    train_df, hyperparams, device)
            except Exception as data_err:
                print(f"\nError preparing batch at step {global_step}: {data_err}. Skipping step.")
                # Skip optimizer step if accumulation would be incomplete
                if (step_in_epoch + 1) % accumulation_steps == 0:
                    optimizer.zero_grad() # Reset grads if skipping step at accumulation boundary
                current_step += 1 # Still increment step counter
                continue

            # --- Forward and Loss Calculation ---
            # Optional: Use autocast for mixed precision
            # with torch.autocast(device_type=device if device != 'cpu' else 'cpu', dtype=torch.bfloat16 if device=='cuda' else torch.float32, enabled=(device=='cuda')):
            outputs = model(x, context_mask, target_spans_indices, attention_mask)
            loss_dict = model.compute_loss(outputs)
            total_loss = loss_dict['total_loss']
            jepa_loss = loss_dict['jepa_loss']
            lm_loss = loss_dict['lm_loss']

            # Scale loss for gradient accumulation
            scaled_loss = total_loss / accumulation_steps

            # --- Backward Pass ---
            # scaler.scale(scaled_loss).backward() # With AMP
            scaled_loss.backward() # Without AMP

            # Accumulate epoch losses for monitoring (average per step)
            epoch_total_loss += total_loss.item()
            epoch_jepa_loss += jepa_loss.item()
            epoch_lm_loss += lm_loss.item()

            # --- Optimizer Step (after accumulation) ---
            if (step_in_epoch + 1) % accumulation_steps == 0:
                # Unscale gradients before clipping (required for AMP)
                # scaler.unscale_(optimizer) # With AMP

                # Clip gradients to prevent explosion
                torch.nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, model.parameters()), grad_clip)

                # Compute the learning rate first so `lr` is always defined for the logging below,
                # even when this update is skipped because of invalid gradients
                lr = get_lr(global_step)

                # Check for NaN/Inf gradients *before* optimizer step
                found_nan_inf = False
                for p in filter(lambda p: p.requires_grad and p.grad is not None, model.parameters()):
                    if not torch.isfinite(p.grad).all():
                        print(f"\nWarning: NaN or Inf found in gradients at step {global_step}. Zeroing gradients for this step.")
                        found_nan_inf = True
                        break
                if found_nan_inf:
                    optimizer.zero_grad() # Skip update if grads are invalid
                else:
                    # Apply the learning rate for this global step
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr

                    # Perform optimizer step
                    # scaler.step(optimizer) # With AMP
                    optimizer.step() # Without AMP

                # Update target encoder using EMA after successful optimizer step
                if not found_nan_inf:
                    model.update_target_encoder()

                # Update scaler for next iteration (AMP)
                # scaler.update() # With AMP

                # Zero gradients for the next accumulation cycle
                optimizer.zero_grad()

                # --- Logging ---
                # Calculate average loss over steps completed so far in the epoch
                avg_total_loss = epoch_total_loss / (step_in_epoch + 1)
                avg_jepa_loss = epoch_jepa_loss / (step_in_epoch + 1)
                avg_lm_loss = epoch_lm_loss / (step_in_epoch + 1)
                # Update tqdm progress bar
                pbar.set_description(f"E{epoch+1}, S{global_step+1}/{total_steps}, LR: {lr:.6f}")
                pbar.set_postfix({
                    "AvgLoss": f"{avg_total_loss:.4f}",
                    "JEPA": f"{avg_jepa_loss:.4f}", # Should be non-zero now
                    "LM": f"{avg_lm_loss:.4f}",
                    "LastJEPA": f"{jepa_loss.item():.4f}", # Show last step's JEPA loss
                })

            current_step += 1 # Increment global step counter

        # --- End of Epoch ---
        # Generate sample text
        try:
            print("\nGenerating sample text at end of epoch...")
            model.eval() # Set to eval mode for generation
            sample_text = generate_from_prompt(
                model, hyperparams, hyperparams['start_prompt'],
                max_new_tokens=256, # Shorter sample for epoch end
                device=device
            )
            print(f"Sample: {sample_text}\n" + "-"*20)
            model.train() # Set back to train mode
        except Exception as e:
            print(f"Error generating sample: {e}")
            model.train() # Ensure model is back in train mode

        # Save end-of-epoch checkpoint
        torch.save({
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'epoch': epoch,
            'current_step': current_step,
            'val_loss': best_val_loss, # Save the best validation loss seen so far
            'hyperparams': hyperparams # Save hyperparams with checkpoint
        }, checkpoint_path)
        print(f"Checkpoint saved at end of epoch {epoch+1} to {checkpoint_path}.")

    print("Training complete!")

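# ------------------------------------------
# Illustrative sketch (not part of the original training code): the linear-warmup
# plus cosine-decay schedule that `train` defines as the nested `get_lr`, written
# as a standalone function with explicit parameters so it can be inspected in
# isolation. The name `cosine_lr_sketch` is hypothetical; the defaults mirror the
# values hard-coded in `train`.
# ------------------------------------------
import math


def cosine_lr_sketch(step, total_steps, warmup_steps=2000, base_lr=3e-4, min_lr=1e-5):
    """Learning rate at `step`: linear warmup to base_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps  # Linear ramp starting from 0 at step 0
    decay_steps = total_steps - warmup_steps
    steps_after_warmup = step - warmup_steps
    if steps_after_warmup >= decay_steps:
        return min_lr  # Past the end of the schedule
    cosine_decay = 0.5 * (1 + math.cos(math.pi * steps_after_warmup / decay_steps))
    return min_lr + (base_lr - min_lr) * cosine_decay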

# ==========================================
# 14) Inference Implementation
# ==========================================
def inference(model_path, prompt_text, hyperparams_override=None):
    """Run inference with trained DECODER model."""
    device = get_device()

    # --- Load Checkpoint and Hyperparameters ---
    if not os.path.exists(model_path):
        print(f"Error: Model checkpoint not found at {model_path}")
        return None

    print(f"Loading model checkpoint from {model_path}...")
    try:
        checkpoint = torch.load(model_path, map_location=device)
        # Load hyperparams from checkpoint if available, otherwise use defaults/overrides
        hyperparams_loaded = checkpoint.get('hyperparams', None)
        if hyperparams_loaded:
            print("Using hyperparameters loaded from checkpoint.")
            hyperparams = hyperparams_loaded
        else:
            print("Warning: Hyperparameters not found in checkpoint, using default values.")
            hyperparams = get_hyperparams()

        # Allow overriding specific hyperparameters for inference
        if hyperparams_override:
            print(f"Applying hyperparameter overrides: {hyperparams_override}")
            hyperparams.update(hyperparams_override)

        print(f"Using hyperparameters for inference: {hyperparams}") # Log effective hyperparams

    except Exception as e:
        print(f"Error loading checkpoint structure: {e}")
        return None

    # --- Create Model Structure based on loaded/effective Hyperparameters ---
    try:
        model = TJEPAModel(
            vocab_size=hyperparams['vocab_size'], embed_dim=hyperparams['embed_dim'],
            n_heads=hyperparams['n_heads'], n_layers=hyperparams['n_layers'],
            block_size=hyperparams['block_size'], ema_decay=hyperparams['ema_decay'], # Needed for structure
            lm_loss_weight=hyperparams['lm_loss_weight'], pad_token_id=hyperparams['pad_token']
        ).to(device)
    except KeyError as e:
        # str(KeyError) already includes quotes around the missing key
        print(f"Error: Missing hyperparameter {e} needed to build the model structure.")
        return None

    # --- Load Model State ---
    try:
        model_state = checkpoint['model_state']
        # Flexible loading (handle potential renames/missing/unexpected keys)
        current_model_dict = model.state_dict()
        processed_state_dict = {}
        for k, v in model_state.items():
            new_k = k
            if k.startswith("context_encoder."): new_k = k.replace("context_encoder.", "decoder_backbone.", 1)
            if new_k in current_model_dict and v.shape == current_model_dict[new_k].shape:
                processed_state_dict[new_k] = v
        missing, unexpected = model.load_state_dict(processed_state_dict, strict=False)
        if missing: print(f"  Info: Missing keys while loading state_dict: {missing}")
        if unexpected: print(f"  Info: Unexpected keys while loading state_dict: {unexpected}")
        print("Model state loaded successfully.")
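        loaded_epoch = checkpoint.get('epoch', -1)
        loaded_step = checkpoint.get('current_step', -1)
        # val_loss may be missing; format it numerically only when present, since
        # applying ':.4f' to the string 'N/A' would raise a ValueError
        loaded_val_loss = checkpoint.get('val_loss')
        val_loss_str = f"{float(loaded_val_loss):.4f}" if loaded_val_loss is not None else "N/A"
        print(f"  Checkpoint details: Epoch {loaded_epoch}, Step {loaded_step}, Val Loss {val_loss_str}")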
    except Exception as e:
        print(f"Error loading model state weights: {e}")
        print("Attempting inference with initialized model weights (may perform poorly).")

    # --- Run Generation ---
    model.eval() # Set to evaluation mode
    print(f"\n--- Generating response for prompt ---")
    print(f"Prompt: {prompt_text}")

    # Use token-by-token generation for streaming output
    result = generate_token_by_token(
        model, hyperparams, prompt_text=prompt_text,
        max_new_tokens=hyperparams.get('generate_num_tokens', 1024), # Use hyperparam, default 1024
        device=device
    )
    # Result is printed during generation

    return result

# ==========================================
# 15) Main Entry Point
# ==========================================
if __name__ == "__main__":
    # --- Configuration ---
    # Load default hyperparameters initially to get paths etc.
    default_hyperparams = get_hyperparams()

    # Choose mode: "train" or "inference"
    MODE = "train"
    # MODE = "inference"

    # Set prompt for inference mode
    INFERENCE_PROMPT = "A rectangle has a length of 15 cm and a width of 8 cm. What is its perimeter and area?"
    # Specify model path for inference (usually the best saved model)
    # Use path from default hyperparams, but it might be overridden if loaded from checkpoint in inference mode
    INFERENCE_MODEL_PATH = default_hyperparams['checkpoint_path'].replace('.pt', '_best.pt')
    # --- End Configuration ---


    if MODE == "train":
        print("Starting training...")
        # Pass continue_training flag from default hyperparams
        train(continue_training=default_hyperparams['continue_training'])

    elif MODE == "inference":
        print("Starting inference...")
        # Check if the specified best model path exists, otherwise try the regular checkpoint
        if not os.path.exists(INFERENCE_MODEL_PATH):
            print(f"Warning: Best model path '{INFERENCE_MODEL_PATH}' not found.")
            base_checkpoint_path = default_hyperparams['checkpoint_path']
            if os.path.exists(base_checkpoint_path):
                print(f"Attempting to use the base checkpoint path: '{base_checkpoint_path}'")
                INFERENCE_MODEL_PATH = base_checkpoint_path
            else:
                print(f"Error: Neither best model nor base checkpoint path found ('{base_checkpoint_path}'). Cannot run inference.")
                raise SystemExit(1) # Exit with a non-zero status if no model file is found

        # Run inference function
        inference(INFERENCE_MODEL_PATH, INFERENCE_PROMPT) # Hyperparams will be loaded from checkpoint

    else:
        print(f"Unknown mode: {MODE}. Choose 'train' or 'inference'.")

Starting training...
Using device: cuda
Loading GSM8K dataset...
Error loading dataset with datasets library: No module named 'datasets'
Attempting alternative loading methods...
Attempting to load from Hugging Face Hub parquet files...




Dataset loaded using parquet files from Hugging Face Hub
Training examples: 7473
Test examples: 1319
Final training examples: 6725
Validation examples: 748
Test examples: 1319
Model Block Size: 1024
Total parameters: 126,309,888
Trainable parameters: 88,367,616
Loading checkpoint from t_jepa_mtl_decoder_rope_bs1024_checkpoint.pt...
Optimizer state loaded.
Resuming from epoch 41, step 41000.
Target encoder re-synced from loaded backbone weights.
Starting training on GSM8K dataset with T-JEPA (DECODER bs=1024 + RoPE + MTL)...

--- Epoch 42/50 ---


Epoch 42:   0%|          | 0/1000 [00:00<?, ?it/s]


Step 41000 Eval:
  Train Total: 1.0670, JEPA: 0.8443, LM: 0.2421
  Val Total:   1.3609, JEPA: 0.8521, LM: 0.5531


E42, S41200/50000, LR: 0.000033:  20%|██        | 200/1000 [04:01<11:30,  1.16it/s, AvgLoss=1.0823, JEPA=0.8462, LM=0.2566, LastJEPA=0.9090]


Step 41200 Eval:
  Train Total: 1.0567, JEPA: 0.8421, LM: 0.2333
  Val Total:   1.3673, JEPA: 0.8525, LM: 0.5595


E42, S41400/50000, LR: 0.000032:  40%|████      | 400/1000 [08:03<08:34,  1.17it/s, AvgLoss=1.0896, JEPA=0.8487, LM=0.2619, LastJEPA=0.8420]


Step 41400 Eval:
  Train Total: 1.0660, JEPA: 0.8538, LM: 0.2306
  Val Total:   1.3528, JEPA: 0.8424, LM: 0.5548


E42, S41600/50000, LR: 0.000031:  60%|██████    | 600/1000 [12:05<05:46,  1.15it/s, AvgLoss=1.0829, JEPA=0.8462, LM=0.2572, LastJEPA=0.8483]


Step 41600 Eval:
  Train Total: 1.0516, JEPA: 0.8496, LM: 0.2196
  Val Total:   1.3555, JEPA: 0.8484, LM: 0.5512


E42, S41800/50000, LR: 0.000030:  80%|████████  | 800/1000 [16:06<02:51,  1.17it/s, AvgLoss=1.0811, JEPA=0.8454, LM=0.2562, LastJEPA=0.7833]


Step 41800 Eval:
  Train Total: 1.0568, JEPA: 0.8455, LM: 0.2298
  Val Total:   1.3415, JEPA: 0.8422, LM: 0.5427


E42, S42000/50000, LR: 0.000029: 100%|██████████| 1000/1000 [20:08<00:00,  1.21s/it, AvgLoss=1.0789, JEPA=0.8446, LM=0.2546, LastJEPA=0.7697]



Generating sample text at end of epoch...
Sample: Consider this math problem. Think step by step and provide your reasoning between <think> </think> tags, then give your final answer between <answer> </answer> tags.

Problem: Problem: A bakery produces cakes for $10 each. It costs them $5 in ingredients per cake, and they have a fixed overhead of $200 per day. How many cakes do they need to sell each day to make a daily profit of $100?

<think>The number of cakes that they produced is 10 * 10 = $<<10*10=100>>100.
The fixed overhead and sell cake then to sell 100 - 20 = $<<100-20=80>>80.
The number of cakes that they need to sell a day is 80 * $20 = $<<80*20=1600>>1600.
#### 1600</think>

<answer
--------------------
Checkpoint saved at end of epoch 42 to t_jepa_mtl_decoder_rope_bs1024_checkpoint.pt.

--- Epoch 43/50 ---


Epoch 43:   0%|          | 0/1000 [00:00<?, ?it/s]


Step 42000 Eval:
  Train Total: 1.0581, JEPA: 0.8399, LM: 0.2372
  Val Total:   1.3533, JEPA: 0.8435, LM: 0.5541


E43, S42200/50000, LR: 0.000028:  20%|██        | 200/1000 [04:02<11:34,  1.15it/s, AvgLoss=1.0731, JEPA=0.8428, LM=0.2504, LastJEPA=0.8019]


Step 42200 Eval:
  Train Total: 1.0472, JEPA: 0.8383, LM: 0.2271
  Val Total:   1.3633, JEPA: 0.8506, LM: 0.5572


E43, S42400/50000, LR: 0.000028:  40%|████      | 400/1000 [08:04<08:33,  1.17it/s, AvgLoss=1.0775, JEPA=0.8465, LM=0.2511, LastJEPA=0.8113]


Step 42400 Eval:
  Train Total: 1.0449, JEPA: 0.8430, LM: 0.2195
  Val Total:   1.3786, JEPA: 0.8553, LM: 0.5688


E43, S42600/50000, LR: 0.000027:  60%|██████    | 600/1000 [12:06<05:44,  1.16it/s, AvgLoss=1.0777, JEPA=0.8479, LM=0.2498, LastJEPA=0.8763]


Step 42600 Eval:
  Train Total: 1.0544, JEPA: 0.8490, LM: 0.2233
  Val Total:   1.3480, JEPA: 0.8413, LM: 0.5507


E43, S42800/50000, LR: 0.000026:  80%|████████  | 800/1000 [16:08<02:51,  1.16it/s, AvgLoss=1.0732, JEPA=0.8458, LM=0.2473, LastJEPA=0.7940]


Step 42800 Eval:
  Train Total: 1.0561, JEPA: 0.8468, LM: 0.2275
  Val Total:   1.3588, JEPA: 0.8426, LM: 0.5612


E43, S43000/50000, LR: 0.000025: 100%|██████████| 1000/1000 [20:10<00:00,  1.21s/it, AvgLoss=1.0747, JEPA=0.8464, LM=0.2481, LastJEPA=0.8829]



Generating sample text at end of epoch...
Sample: Consider this math problem. Think step by step and provide your reasoning between <think> </think> tags, then give your final answer between <answer> </answer> tags.

Problem: Problem: A bakery produces cakes for $10 each. It costs them $5 in ingredients per cake, and they have a fixed overhead of $200 per day. How many cakes do they need to sell each day to make a daily profit of $100?

<think>The fixed overhead was $10 x 12 = $<<10*12=240>>240.
Therefore, the fixed overhead was $240 x 5 = $<<240*5=1200>>1200.
The fixed overhead was $50 x 1200 = $<<50*1200=60000>>60000.
Therefore, the fixed overhead was $10000 - $60000 = $<<10000-60000=10000>>10
--------------------
Checkpoint saved at end of epoch 43 to t_jepa_mtl_decoder_rope_bs1024_checkpoint.pt.

--- Epoch 44/50 ---


Epoch 44:   0%|          | 0/1000 [00:00<?, ?it/s]


Step 43000 Eval:
  Train Total: 1.0539, JEPA: 0.8460, LM: 0.2259
  Val Total:   1.3610, JEPA: 0.8474, LM: 0.5583


E44, S43200/50000, LR: 0.000024:  20%|██        | 200/1000 [04:02<11:23,  1.17it/s, AvgLoss=1.0725, JEPA=0.8494, LM=0.2424, LastJEPA=0.8188]


Step 43200 Eval:
  Train Total: 1.0636, JEPA: 0.8464, LM: 0.2361
  Val Total:   1.3980, JEPA: 0.8596, LM: 0.5853


E44, S43400/50000, LR: 0.000023:  40%|████      | 400/1000 [08:04<08:34,  1.17it/s, AvgLoss=1.0695, JEPA=0.8461, LM=0.2428, LastJEPA=0.7845]


Step 43400 Eval:
  Train Total: 1.0617, JEPA: 0.8553, LM: 0.2243
  Val Total:   1.3456, JEPA: 0.8415, LM: 0.5480


E44, S43600/50000, LR: 0.000023:  60%|██████    | 600/1000 [12:05<05:45,  1.16it/s, AvgLoss=1.0688, JEPA=0.8457, LM=0.2425, LastJEPA=0.8156]


Step 43600 Eval:
  Train Total: 1.0525, JEPA: 0.8509, LM: 0.2191
  Val Total:   1.3626, JEPA: 0.8487, LM: 0.5586


E44, S43800/50000, LR: 0.000022:  80%|████████  | 800/1000 [16:07<02:51,  1.17it/s, AvgLoss=1.0682, JEPA=0.8458, LM=0.2418, LastJEPA=0.9056]


Step 43800 Eval:
  Train Total: 1.0486, JEPA: 0.8467, LM: 0.2195
  Val Total:   1.3486, JEPA: 0.8477, LM: 0.5445


E44, S44000/50000, LR: 0.000021: 100%|██████████| 1000/1000 [20:09<00:00,  1.21s/it, AvgLoss=1.0681, JEPA=0.8456, LM=0.2418, LastJEPA=0.9313]



Generating sample text at end of epoch...
Sample: Consider this math problem. Think step by step and provide your reasoning between <think> </think> tags, then give your final answer between <answer> </answer> tags.

Problem: Problem: A bakery produces cakes for $10 each. It costs them $5 in ingredients per cake, and they have a fixed overhead of $200 per day. How many cakes do they need to sell each day to make a daily profit of $100?

<think>The fixed ingredients produce is $10 x 12 = $<<10*12=120>>120.
The fixed ingredients produce is $10 x 120 = $<<10*120=1200>>1200.
Therefore, the fixed ingredients produced $1000 - $100= $<<1000-100=900>>900.
#### 900</think>

<answer>900</answer>
--------------------
Checkpoint saved at end of epoch 44 to t_jepa_mtl_decoder_rope_bs1024_checkpoint.pt.

--- Epoch 45/50 ---


Epoch 45:   0%|          | 0/1000 [00:00<?, ?it/s]


Step 44000 Eval:
  Train Total: 1.0630, JEPA: 0.8527, LM: 0.2286
  Val Total:   1.3441, JEPA: 0.8424, LM: 0.5452


E45, S44200/50000, LR: 0.000020:  20%|██        | 200/1000 [04:02<11:33,  1.15it/s, AvgLoss=1.0659, JEPA=0.8498, LM=0.2350, LastJEPA=0.8821]


Step 44200 Eval:
  Train Total: 1.0485, JEPA: 0.8507, LM: 0.2150
  Val Total:   1.3698, JEPA: 0.8446, LM: 0.5709


E45, S44400/50000, LR: 0.000020:  40%|████      | 400/1000 [08:04<08:36,  1.16it/s, AvgLoss=1.0668, JEPA=0.8512, LM=0.2343, LastJEPA=0.8749]


Step 44400 Eval:
  Train Total: 1.0566, JEPA: 0.8555, LM: 0.2185
  Val Total:   1.4002, JEPA: 0.8675, LM: 0.5790


E45, S44600/50000, LR: 0.000019:  60%|██████    | 600/1000 [12:06<05:41,  1.17it/s, AvgLoss=1.0668, JEPA=0.8499, LM=0.2358, LastJEPA=0.8400]


Step 44600 Eval:
  Train Total: 1.0381, JEPA: 0.8433, LM: 0.2118
  Val Total:   1.3929, JEPA: 0.8561, LM: 0.5835


E45, S44800/50000, LR: 0.000018:  80%|████████  | 800/1000 [16:08<02:52,  1.16it/s, AvgLoss=1.0661, JEPA=0.8493, LM=0.2357, LastJEPA=0.8827]


Step 44800 Eval:
  Train Total: 1.0545, JEPA: 0.8461, LM: 0.2265
  Val Total:   1.3863, JEPA: 0.8485, LM: 0.5846


E45, S45000/50000, LR: 0.000018: 100%|██████████| 1000/1000 [20:10<00:00,  1.21s/it, AvgLoss=1.0655, JEPA=0.8488, LM=0.2355, LastJEPA=0.9246]



Generating sample text at end of epoch...
Sample: Consider this math problem. Think step by step and provide your reasoning between <think> </think> tags, then give your final answer between <answer> </answer> tags.

Problem: Problem: A bakery produces cakes for $10 each. It costs them $5 in ingredients per cake, and they have a fixed overhead of $200 per day. How many cakes do they need to sell each day to make a daily profit of $100?

<think>The fixed ingredients produced in ingredients for $10 each ingredient is 10*10 = $<<10*10=100>>100
The fixed amount to the day profit in ingredients is $5*100 = $<<5*100=500>>500
The new needs 100/200 = <<100/200=5>>5 days to sell the amount they need to s
--------------------
Checkpoint saved at end of epoch 45 to t_jepa_mtl_decoder_rope_bs1024_checkpoint.pt.

--- Epoch 46/50 ---


Epoch 46:   0%|          | 0/1000 [00:00<?, ?it/s]


Step 45000 Eval:
  Train Total: 1.0465, JEPA: 0.8427, LM: 0.2215
  Val Total:   1.3687, JEPA: 0.8516, LM: 0.5620


E46, S45200/50000, LR: 0.000017:  20%|██        | 200/1000 [04:02<11:30,  1.16it/s, AvgLoss=1.0702, JEPA=0.8531, LM=0.2360, LastJEPA=0.8117]


Step 45200 Eval:
  Train Total: 1.0406, JEPA: 0.8463, LM: 0.2112
  Val Total:   1.3861, JEPA: 0.8609, LM: 0.5710


E46, S45400/50000, LR: 0.000017:  40%|████      | 400/1000 [08:03<08:34,  1.17it/s, AvgLoss=1.0657, JEPA=0.8523, LM=0.2320, LastJEPA=0.7619]


Step 45400 Eval:
  Train Total: 1.0525, JEPA: 0.8465, LM: 0.2238
  Val Total:   1.3824, JEPA: 0.8543, LM: 0.5740


E46, S45600/50000, LR: 0.000016:  60%|██████    | 600/1000 [12:05<05:43,  1.17it/s, AvgLoss=1.0648, JEPA=0.8510, LM=0.2324, LastJEPA=0.8332]


Step 45600 Eval:
  Train Total: 1.0295, JEPA: 0.8391, LM: 0.2069
  Val Total:   1.3897, JEPA: 0.8470, LM: 0.5899


E46, S45800/50000, LR: 0.000015:  80%|████████  | 800/1000 [16:07<02:51,  1.16it/s, AvgLoss=1.0649, JEPA=0.8522, LM=0.2312, LastJEPA=0.8531]


Step 45800 Eval:
  Train Total: 1.0461, JEPA: 0.8491, LM: 0.2142
  Val Total:   1.3893, JEPA: 0.8475, LM: 0.5889


E46, S46000/50000, LR: 0.000015: 100%|██████████| 1000/1000 [20:09<00:00,  1.21s/it, AvgLoss=1.0639, JEPA=0.8517, LM=0.2307, LastJEPA=0.7758]



Generating sample text at end of epoch...
Sample: Consider this math problem. Think step by step and provide your reasoning between <think> </think> tags, then give your final answer between <answer> </answer> tags.

Problem: Problem: A bakery produces cakes for $10 each. It costs them $5 in ingredients per cake, and they have a fixed overhead of $200 per day. How many cakes do they need to sell each day to make a daily profit of $100?

<think>The fixed and produces $200 x 10/100 = $<<200*10/100=20>>20 a day.
So, they need to sell $5 in ingredients per cake for $10 x 20 = $<<10*20=200>>200.
Therefore, they need to sell a total of $200 + $200 = $<<200+200=400>>400.
#### 400</think>

<answer>400</
--------------------
Checkpoint saved at end of epoch 46 to t_jepa_mtl_decoder_rope_bs1024_checkpoint.pt.

--- Epoch 47/50 ---


Epoch 47:   0%|          | 0/1000 [00:00<?, ?it/s]


Step 46000 Eval:
  Train Total: 1.0476, JEPA: 0.8606, LM: 0.2033
  Val Total:   1.3646, JEPA: 0.8525, LM: 0.5566


E47, S46200/50000, LR: 0.000014:  20%|██        | 200/1000 [04:02<11:29,  1.16it/s, AvgLoss=1.0589, JEPA=0.8521, LM=0.2247, LastJEPA=0.8357]


Step 46200 Eval:
  Train Total: 1.0635, JEPA: 0.8644, LM: 0.2164
  Val Total:   1.3697, JEPA: 0.8424, LM: 0.5732


E47, S46400/50000, LR: 0.000014:  40%|████      | 400/1000 [08:04<08:38,  1.16it/s, AvgLoss=1.0557, JEPA=0.8490, LM=0.2247, LastJEPA=0.9156]


Step 46400 Eval:
  Train Total: 1.0433, JEPA: 0.8557, LM: 0.2040
  Val Total:   1.3779, JEPA: 0.8518, LM: 0.5718


E47, S46600/50000, LR: 0.000014:  60%|██████    | 600/1000 [12:05<05:42,  1.17it/s, AvgLoss=1.0563, JEPA=0.8503, LM=0.2239, LastJEPA=0.7871]


Step 46600 Eval:
  Train Total: 1.0334, JEPA: 0.8474, LM: 0.2021
  Val Total:   1.3780, JEPA: 0.8538, LM: 0.5698


E47, S46800/50000, LR: 0.000013:  80%|████████  | 800/1000 [16:07<02:52,  1.16it/s, AvgLoss=1.0577, JEPA=0.8505, LM=0.2252, LastJEPA=0.8460]


Step 46800 Eval:
  Train Total: 1.0417, JEPA: 0.8498, LM: 0.2086
  Val Total:   1.3636, JEPA: 0.8388, LM: 0.5704


E47, S47000/50000, LR: 0.000013: 100%|██████████| 1000/1000 [20:09<00:00,  1.21s/it, AvgLoss=1.0578, JEPA=0.8506, LM=0.2252, LastJEPA=0.8780]



Generating sample text at end of epoch...
Sample: Consider this math problem. Think step by step and provide your reasoning between <think> </think> tags, then give your final answer between <answer> </answer> tags.

Problem: Problem: A bakery produces cakes for $10 each. It costs them $5 in ingredients per cake, and they have a fixed overhead of $200 per day. How many cakes do they need to sell each day to make a daily profit of $100?

<think>The number of cakes they have earned is $10 x 100 = $<<10*100=1000>>1000.
To sell hear off cakes cost $1000 x 10/100 = $<<1000*10/100=100>>100.
Therefore, the total number of cakes they need to sell it cost $1000 + $100 = $<<1000+100=1300.
#### 1300</think
--------------------
Checkpoint saved at end of epoch 47 to t_jepa_mtl_decoder_rope_bs1024_checkpoint.pt.

--- Epoch 48/50 ---


Epoch 48:   0%|          | 0/1000 [00:00<?, ?it/s]


Step 47000 Eval:
  Train Total: 1.0312, JEPA: 0.8453, LM: 0.2021
  Val Total:   1.4143, JEPA: 0.8684, LM: 0.5934


E48, S47200/50000, LR: 0.000012:  20%|██        | 200/1000 [04:02<11:33,  1.15it/s, AvgLoss=1.0635, JEPA=0.8504, LM=0.2317, LastJEPA=0.9261]


Step 47200 Eval:
  Train Total: 1.0465, JEPA: 0.8570, LM: 0.2059
  Val Total:   1.4166, JEPA: 0.8627, LM: 0.6021


E48, S47400/50000, LR: 0.000012:  40%|████      | 400/1000 [08:03<08:36,  1.16it/s, AvgLoss=1.0602, JEPA=0.8502, LM=0.2282, LastJEPA=0.8476]


Step 47400 Eval:
  Train Total: 1.0448, JEPA: 0.8582, LM: 0.2028
  Val Total:   1.3985, JEPA: 0.8589, LM: 0.5865


E48, S47600/50000, LR: 0.000012:  60%|██████    | 600/1000 [12:05<05:42,  1.17it/s, AvgLoss=1.0579, JEPA=0.8508, LM=0.2250, LastJEPA=0.8606]


Step 47600 Eval:
  Train Total: 1.0471, JEPA: 0.8551, LM: 0.2086
  Val Total:   1.3789, JEPA: 0.8510, LM: 0.5739


E48, S47800/50000, LR: 0.000012:  80%|████████  | 800/1000 [16:07<02:51,  1.17it/s, AvgLoss=1.0577, JEPA=0.8514, LM=0.2243, LastJEPA=0.8613]


Step 47800 Eval:
  Train Total: 1.0444, JEPA: 0.8566, LM: 0.2041
  Val Total:   1.3934, JEPA: 0.8520, LM: 0.5885


E48, S48000/50000, LR: 0.000011: 100%|██████████| 1000/1000 [20:09<00:00,  1.21s/it, AvgLoss=1.0565, JEPA=0.8507, LM=0.2237, LastJEPA=0.9021]



Generating sample text at end of epoch...
Sample: Consider this math problem. Think step by step and provide your reasoning between <think> </think> tags, then give your final answer between <answer> </answer> tags.

Problem: Problem: A bakery produces cakes for $10 each. It costs them $5 in ingredients per cake, and they have a fixed overhead of $200 per day. How many cakes do they need to sell each day to make a daily profit of $100?

<think>The number of cakes that they have and they cost $200 because 100 / 5 = <<100/5=20>>20
They need $5 left to sell because 5 x 10 = <<5*10=50>>50
They need $200 because 200 / 50 = <<200/50=20>>20
#### 20</think>

<answer>20</answer>
--------------------
Checkpoint saved at end of epoch 48 to t_jepa_mtl_decoder_rope_bs1024_checkpoint.pt.

--- Epoch 49/50 ---


Epoch 49:   0%|          | 0/1000 [00:00<?, ?it/s]


Step 48000 Eval:
  Train Total: 1.0375, JEPA: 0.8544, LM: 0.1990
  Val Total:   1.3531, JEPA: 0.8329, LM: 0.5655


E49, S48200/50000, LR: 0.000011:  20%|██        | 200/1000 [04:02<11:24,  1.17it/s, AvgLoss=1.0563, JEPA=0.8539, LM=0.2199, LastJEPA=0.7260]


Step 48200 Eval:
  Train Total: 1.0373, JEPA: 0.8556, LM: 0.1975
  Val Total:   1.3903, JEPA: 0.8517, LM: 0.5855


E49, S48400/50000, LR: 0.000011:  40%|████      | 400/1000 [08:04<08:35,  1.16it/s, AvgLoss=1.0552, JEPA=0.8511, LM=0.2219, LastJEPA=0.7851]


Step 48400 Eval:
  Train Total: 1.0318, JEPA: 0.8528, LM: 0.1946
  Val Total:   1.4218, JEPA: 0.8574, LM: 0.6135


E49, S48600/50000, LR: 0.000011:  60%|██████    | 600/1000 [12:06<05:47,  1.15it/s, AvgLoss=1.0565, JEPA=0.8521, LM=0.2222, LastJEPA=0.9391]


Step 48600 Eval:
  Train Total: 1.0352, JEPA: 0.8559, LM: 0.1950
  Val Total:   1.4152, JEPA: 0.8634, LM: 0.5998


E49, S48800/50000, LR: 0.000010:  80%|████████  | 800/1000 [16:07<02:51,  1.17it/s, AvgLoss=1.0553, JEPA=0.8532, LM=0.2196, LastJEPA=0.8347]


Step 48800 Eval:
  Train Total: 1.0403, JEPA: 0.8541, LM: 0.2024
  Val Total:   1.4002, JEPA: 0.8452, LM: 0.6033


E49, S49000/50000, LR: 0.000010: 100%|██████████| 1000/1000 [20:09<00:00,  1.21s/it, AvgLoss=1.0555, JEPA=0.8536, LM=0.2195, LastJEPA=0.8818]



Generating sample text at end of epoch...
Sample: Consider this math problem. Think step by step and provide your reasoning between <think> </think> tags, then give your final answer between <answer> </answer> tags.

Problem: Problem: A bakery produces cakes for $10 each. It costs them $5 in ingredients per cake, and they have a fixed overhead of $200 per day. How many cakes do they need to sell each day to make a daily profit of $100?

<think>The number of cakes that the fixed overhead was $10/3 = $<<10/3=0.50>>0.50.
To sell the overhead and cost $200 for a total of $10 for a total of $0.50 + $0.50 for a certain cost $0.50 + $0.50 for a total of $100 + 0.50 = $<<100+0.50=2.00>>2.00.
Therefore, 
--------------------
Checkpoint saved at end of epoch 49 to t_jepa_mtl_decoder_rope_bs1024_checkpoint.pt.

--- Epoch 50/50 ---


Epoch 50:   0%|          | 0/1000 [00:00<?, ?it/s]


Step 49000 Eval:
  Train Total: 1.0404, JEPA: 0.8565, LM: 0.1999
  Val Total:   1.3893, JEPA: 0.8496, LM: 0.5866


E50, S49200/50000, LR: 0.000010:  20%|██        | 200/1000 [04:02<11:24,  1.17it/s, AvgLoss=1.0513, JEPA=0.8533, LM=0.2153, LastJEPA=0.6665]


Step 49200 Eval:
  Train Total: 1.0227, JEPA: 0.8475, LM: 0.1905
  Val Total:   1.3964, JEPA: 0.8557, LM: 0.5877


E50, S49400/50000, LR: 0.000010:  40%|████      | 400/1000 [08:03<08:34,  1.17it/s, AvgLoss=1.0481, JEPA=0.8514, LM=0.2137, LastJEPA=0.7557]


Step 49400 Eval:
  Train Total: 1.0399, JEPA: 0.8573, LM: 0.1985
  Val Total:   1.3951, JEPA: 0.8590, LM: 0.5827


E50, S49600/50000, LR: 0.000010:  60%|██████    | 600/1000 [12:05<05:46,  1.15it/s, AvgLoss=1.0513, JEPA=0.8541, LM=0.2144, LastJEPA=0.7783]


Step 49600 Eval:
  Train Total: 1.0471, JEPA: 0.8590, LM: 0.2045
  Val Total:   1.3947, JEPA: 0.8572, LM: 0.5843


E50, S49800/50000, LR: 0.000010:  80%|████████  | 800/1000 [16:07<02:51,  1.17it/s, AvgLoss=1.0518, JEPA=0.8540, LM=0.2150, LastJEPA=0.9536]


Step 49800 Eval:
  Train Total: 1.0269, JEPA: 0.8455, LM: 0.1971
  Val Total:   1.4198, JEPA: 0.8656, LM: 0.6025


E50, S50000/50000, LR: 0.000010: 100%|██████████| 1000/1000 [20:08<00:00,  1.21s/it, AvgLoss=1.0524, JEPA=0.8538, LM=0.2158, LastJEPA=0.7730]



Generating sample text at end of epoch...
Sample: Consider this math problem. Think step by step and provide your reasoning between <think> </think> tags, then give your final answer between <answer> </answer> tags.

Problem: Problem: A bakery produces cakes for $10 each. It costs them $5 in ingredients per cake, and they have a fixed overhead of $200 per day. How many cakes do they need to sell each day to make a daily profit of $100?

<think>The fixed overhead will need $5 * 10 = $<<5*10=50>>50 per day.
The fixed overhead will need 50 * 10 = $<<50*10=500>>500 for the cakes.
Therefore, they need to sell a fixed overhead will need 500 / 20 = <<500/20=25>>25 cakes.
#### 25</think>

<answer>25</an
--------------------
Checkpoint saved at end of epoch 50 to t_jepa_mtl_decoder_rope_bs1024_checkpoint.pt.
Training complete!
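
For reference when reading the end-of-epoch samples above: the bakery prompt has a closed-form answer, since the required number of cakes is (fixed overhead + target profit) / (price - ingredient cost) = (200 + 100) / (10 - 5) = 60. The sketch below computes that value and pulls the text out of a completed `<answer>` tag so a sample can be checked against it. The helper names `break_even_units` and `extract_answer` are illustrative assumptions and are not part of the training script.

```python
import re
from typing import Optional


def break_even_units(price: float, unit_cost: float,
                     fixed_cost: float, target_profit: float) -> float:
    """Units u such that u * (price - unit_cost) - fixed_cost == target_profit."""
    margin = price - unit_cost
    if margin <= 0:
        raise ValueError("price must exceed the per-unit cost")
    return (fixed_cost + target_profit) / margin


def extract_answer(sample: str) -> Optional[str]:
    """Return the text inside the first complete <answer>...</answer> pair.

    Returns None when the closing tag is missing, e.g. when generation was
    cut off by the token limit.
    """
    match = re.search(r"<answer>(.*?)</answer>", sample, flags=re.DOTALL)
    return match.group(1).strip() if match else None


# Reference value for the bakery prompt used in the samples.
reference = break_even_units(price=10, unit_cost=5, fixed_cost=200, target_profit=100)
print(reference)  # 60.0

# A complete sample ending versus a truncated one.
print(extract_answer("#### 900</think>\n\n<answer>900</answer>"))  # 900
print(extract_answer("#### 25</think>\n\n<answer>25</an"))         # None
```

None of the final answers printed for epochs 42 to 50 equals 60, and several samples are cut off before the closing tag, in which case `extract_answer` returns `None`.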

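The periodic `Step N Eval:` blocks in the output above can be read programmatically to follow the gap between train and validation LM loss. This is a minimal sketch, assuming the log text has been captured into a string. `parse_eval_blocks` and the `lm_gap` field are hypothetical names introduced here, and the regular expression only matches the exact layout printed above.

```python
import re
from typing import Dict, List

# One "Step N Eval:" header followed by its Train and Val lines.
EVAL_BLOCK = re.compile(
    r"Step (\d+) Eval:\s+"
    r"Train Total: ([\d.]+), JEPA: ([\d.]+), LM: ([\d.]+)\s+"
    r"Val Total:\s+([\d.]+), JEPA: ([\d.]+), LM: ([\d.]+)"
)


def parse_eval_blocks(log_text: str) -> List[Dict[str, float]]:
    """Extract one record per eval block, adding the val-minus-train LM gap."""
    rows = []
    for m in EVAL_BLOCK.finditer(log_text):
        (train_total, train_jepa, train_lm,
         val_total, val_jepa, val_lm) = map(float, m.groups()[1:])
        rows.append({
            "step": int(m.group(1)),
            "train_total": train_total, "val_total": val_total,
            "train_jepa": train_jepa, "val_jepa": val_jepa,
            "train_lm": train_lm, "val_lm": val_lm,
            "lm_gap": round(val_lm - train_lm, 4),
        })
    return rows


# Two blocks copied from the output above (first and last eval of the run shown).
example_log = """
Step 41000 Eval:
  Train Total: 1.0670, JEPA: 0.8443, LM: 0.2421
  Val Total:   1.3609, JEPA: 0.8521, LM: 0.5531

Step 49800 Eval:
  Train Total: 1.0269, JEPA: 0.8455, LM: 0.1971
  Val Total:   1.4198, JEPA: 0.8656, LM: 0.6025
"""

for row in parse_eval_blocks(example_log):
    print(row["step"], row["lm_gap"])
```

On these two blocks the LM gap is about 0.31 at step 41000 and about 0.41 at step 49800, while the JEPA loss stays close to 0.85 on both splits.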

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy
import os
import time
from collections import Counter
from typing import Optional, Tuple, List, Dict, Any

# ==========================================
# Configuration
# ==========================================
MODEL_PATH = "t_jepa_mtl_decoder_rope_bs1024_checkpoint.pt" # Or your specific checkpoint path
PROMPT_TEXT = "what is five plus two?"
NUM_VOTES = 8 # K: Number of samples per token
MAX_NEW_TOKENS = 2048 # L: Maximum generation length
TEMPERATURE = 0.7 # Sampling temperature
TOP_P = 0.9 # Top-p nucleus sampling
SYSTEM_PROMPT = """Consider this math problem. Think step by step and provide your reasoning between <think> </think> tags, then give your final answer between <answer> </answer> tags."""
THINK_TAG_START = "<think>" # Used to start generation after the prompt

# ==========================================
# Model Definitions (Copied from training script)
# ==========================================

# --- 1) Hyperparameters Default ---
# These will be OVERRIDDEN by checkpoint if available
def get_default_hyperparams():
    # Provide sensible defaults in case checkpoint doesn't contain hyperparams
    return {
        'vocab_size': 256, 'embed_dim': 512, 'n_heads': 8, 'n_layers': 12,
        'block_size': 1024, 'ema_decay': 0.999, 'lm_loss_weight': 0.1,
        'bos_token': 254, 'eos_token': 255, 'pad_token': 0,
        'top_p': 0.8, # Default for generation if not overridden
        # JEPA params (not strictly needed for inference, but part of model structure)
        'context_span_ratio': 0.6, 'target_span_ratio': 0.2,
        'num_target_spans': 8, 'min_span_length': 32,
        # Tags (used in helper functions)
        'thinking_tag': "<think>", 'thinking_end_tag': "</think>",
        'answer_tag': "<answer>", 'answer_end_tag': "</answer>",
        'system_prompt': """Consider this math problem. Think step by step and provide your reasoning between <think> </think> tags, then give your final answer between <answer> </answer> tags."""
    }

# --- 2) RoPE ---
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        t = torch.arange(self.max_seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, seq_len: int):
        if seq_len > self.max_seq_len:
            # Clamp to the precomputed table instead of raising a ValueError
            print(f"Warning: RoPE seq_len {seq_len} > max_seq_len {self.max_seq_len}. Clamping.")
            seq_len = self.max_seq_len

        return (
            self.cos_cached[:seq_len, ...],
            self.sin_cached[:seq_len, ...],
        )

def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

# --- 3) Attention ---
class ImprovedAttention(nn.Module):
    def __init__(self, embed_dim, n_heads, is_self_attention=True, use_rope=True, max_seq_len=2048):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        assert self.head_dim * n_heads == self.embed_dim, "embed_dim must be divisible by n_heads"
        self.is_self_attention = is_self_attention
        self.use_rope = use_rope

        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        if self.use_rope and self.is_self_attention:
            self.rotary_emb = RotaryEmbedding(self.head_dim, max_seq_len=max_seq_len)
        else:
            self.rotary_emb = None

        self.attn_dropout = nn.Dropout(0.1)
        self.out_dropout = nn.Dropout(0.1)
        self.register_buffer("causal_mask_cache", None, persistent=False)

    def _get_causal_mask(self, T, device):
        if self.causal_mask_cache is None or self.causal_mask_cache.shape[-1] < T:
            mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1)
            self.causal_mask_cache = mask
        return self.causal_mask_cache[:T, :T].to(device=device)

    def forward(self, x, attn_mask=None, key_value_states=None, is_causal=False):
        B, T, C = x.size()
        is_cross_attn = key_value_states is not None
        use_rope_for_this_pass = self.use_rope and self.is_self_attention and not is_cross_attn and self.rotary_emb is not None

        q = self.q_proj(x)
        if is_cross_attn:
            T_k = key_value_states.size(1)
            k = self.k_proj(key_value_states)
            v = self.v_proj(key_value_states)
            is_causal = False
        else:
            T_k = T
            k = self.k_proj(x)
            v = self.v_proj(x)

        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T_k, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T_k, self.n_heads, self.head_dim).transpose(1, 2)

        if use_rope_for_this_pass:
            cos, sin = self.rotary_emb(T)
            q, k = apply_rotary_pos_emb(q, k, cos, sin)
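            # NOTE: the RoPE path applies no 1/sqrt(head_dim) scaling. Left unchanged because
            # this class is stated to be copied from the training script that produced the checkpoint.
            scaling_factor = 1.0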
        else:
            scaling_factor = 1.0 / math.sqrt(self.head_dim)

        scores = torch.matmul(q, k.transpose(-2, -1)) * scaling_factor

        final_mask_bool = None
        if attn_mask is not None:
            if attn_mask.dim() == 2:
                # (B, T_k) padding mask where nonzero means "keep"; invert and broadcast over heads and queries
                padding_mask_bool = ~attn_mask.bool().unsqueeze(1).unsqueeze(2)
            elif attn_mask.dim() == 4:
                padding_mask_bool = ~attn_mask.bool()
            else:
                raise ValueError(f"Unsupported attn_mask dimension: {attn_mask.dim()}")
            final_mask_bool = padding_mask_bool

        if self.is_self_attention and is_causal:
            causal_mask_bool = self._get_causal_mask(T, x.device).unsqueeze(0).unsqueeze(0)
            if final_mask_bool is not None:
                final_mask_bool = final_mask_bool | causal_mask_bool
            else:
                final_mask_bool = causal_mask_bool

        if final_mask_bool is not None:
            scores = scores.masked_fill(final_mask_bool, torch.finfo(scores.dtype).min)

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_dropout(self.out_proj(attn_output))

# --- 4) Decoder Block ---
class DecoderBlock(nn.Module):
    def __init__(self, embed_dim, n_heads, dropout=0.1, max_seq_len=2048):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.self_attention = ImprovedAttention(embed_dim, n_heads, is_self_attention=True, use_rope=True, max_seq_len=max_seq_len)
        self.ln2 = nn.LayerNorm(embed_dim)
        hidden_dim = 4 * embed_dim
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim), nn.Dropout(dropout)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attention_mask=None, is_causal=True):
        residual = x
        x_norm = self.ln1(x)
        attn_output = self.self_attention(x_norm, attn_mask=attention_mask, is_causal=is_causal)
        x = residual + self.dropout(attn_output)
        residual = x
        x_norm = self.ln2(x)
        ff_output = self.feed_forward(x_norm)
        x = residual + self.dropout(ff_output)
        return x

# --- 5) JEPA Predictor Block (Needed for model structure, not used in generation logic) ---
class JEPAPredictorBlock(nn.Module):
    def __init__(self, embed_dim, n_heads, dropout=0.1, max_seq_len=2048):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.self_attention = ImprovedAttention(embed_dim, n_heads, is_self_attention=True, use_rope=True, max_seq_len=max_seq_len)
        self.ln_cross_attn_query = nn.LayerNorm(embed_dim)
        self.ln_cross_attn_kv = nn.LayerNorm(embed_dim)
        self.cross_attention = ImprovedAttention(embed_dim, n_heads, is_self_attention=False, use_rope=False, max_seq_len=max_seq_len)
        self.ln3 = nn.LayerNorm(embed_dim)
        hidden_dim = 4 * embed_dim
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim), nn.Dropout(dropout)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, decoder_output, self_attention_mask=None, cross_attention_mask=None):
        # --- Self-Attention ---
        residual = x
        attn_output = self.self_attention(self.ln1(x), attn_mask=self_attention_mask, is_causal=True)
        x = residual + self.dropout(attn_output)
        # --- Cross-Attention ---
        residual = x
        cross_attn_output = self.cross_attention(
            self.ln_cross_attn_query(x),
            attn_mask=cross_attention_mask,
            key_value_states=self.ln_cross_attn_kv(decoder_output)
        )
        x = residual + self.dropout(cross_attn_output)
        # --- Feed-Forward ---
        residual = x
        ff_output = self.feed_forward(self.ln3(x))
        x = residual + self.dropout(ff_output)
        return x

# --- 6) Backbone Decoder ---
class BackboneDecoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, n_heads, n_layers, block_size):
        super().__init__()
        self.block_size = block_size
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.dropout = nn.Dropout(0.1)
        self.blocks = nn.ModuleList([
            DecoderBlock(embed_dim, n_heads, dropout=0.1, max_seq_len=block_size)
            for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(embed_dim)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, x, attention_mask=None, is_causal=True):
        B, T = x.size()
        # assert T <= self.block_size, f"Sequence length {T} exceeds block size {self.block_size}" # Allow longer during generation cropping
        token_emb = self.token_embedding(x)
        x = self.dropout(token_emb)
        for block in self.blocks: x = block(x, attention_mask=attention_mask, is_causal=is_causal)
        x = self.ln_f(x)
        return x

# --- 7) JEPA Predictor (Needed for model structure, not used in generation logic) ---
class JEPAPredictor(nn.Module):
    def __init__(self, embed_dim, n_heads, n_layers, block_size):
        super().__init__()
        self.block_size = block_size
        predictor_layers = n_layers
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        torch.nn.init.normal_(self.mask_token, mean=0.0, std=0.02)
        self.blocks = nn.ModuleList([
            JEPAPredictorBlock(embed_dim, n_heads, max_seq_len=block_size)
            for _ in range(predictor_layers)
        ])
        self.ln_f = nn.LayerNorm(embed_dim)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, decoder_output_causal, target_spans_indices, context_mask, attention_mask):
        # The predictor's forward pass is only used to compute the JEPA loss during
        # training and is never called during autoregressive generation. The class is
        # kept so that checkpoint state_dict keys still match the model structure.
        raise NotImplementedError("JEPAPredictor.forward is omitted from this inference-only script.")

# --- 8) Target Encoder (Needed for model structure, not used in generation logic) ---
class TargetEncoder(nn.Module):
    def __init__(self, backbone_decoder, ema_decay=0.999):
        super().__init__()
        self.encoder = copy.deepcopy(backbone_decoder)
        self.ema_decay = ema_decay
        for param in self.encoder.parameters(): param.requires_grad = False

    @torch.no_grad()
    def update_ema(self, backbone_decoder, decay_rate=None): pass # Not needed for inference

    @torch.no_grad()
    def forward(self, x, attention_mask=None): pass # Not needed for inference
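
# --- Sketch (not taken from the original training code): EMA update rule ---
# TargetEncoder.update_ema is stubbed out above because inference never calls it.
# For reference, an exponential moving average update conventionally has the form
#     target <- decay * target + (1 - decay) * online
# The helper below illustrates that rule on plain Python floats. The name
# `_ema_update_sketch` and the use of float lists instead of tensors are
# assumptions made purely for illustration.
def _ema_update_sketch(target_params, online_params, decay=0.999):
    """Return EMA-updated copies of target_params (both are lists of floats)."""
    if len(target_params) != len(online_params):
        raise ValueError("parameter lists must have the same length")
    return [decay * t + (1.0 - decay) * o
            for t, o in zip(target_params, online_params)]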

# --- 9) Complete T-JEPA Model ---
class TJEPAModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, n_heads, n_layers, block_size, ema_decay=0.999, lm_loss_weight=0.1, pad_token_id=0):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.lm_loss_weight = lm_loss_weight
        self.block_size = block_size

        self.decoder_backbone = BackboneDecoder(vocab_size, embed_dim, n_heads, n_layers, block_size)
        # Predictor and TargetEncoder needed for state_dict loading compatibility
        self.predictor = JEPAPredictor(embed_dim, n_heads, n_layers, block_size)
        self.target_encoder = TargetEncoder(self.decoder_backbone, ema_decay)

        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        self.decoder_backbone.token_embedding.weight = self.lm_head.weight # Weight tying

    # Forward methods related to training (JEPA loss) are omitted for inference clarity
    # We only need the backbone and lm_head for generation.

# ==========================================
# Helper Functions for Tokenization
# ==========================================
def _encode(text: str, bos_token: int) -> List[int]:
    """Encodes text to byte tokens, adding BOS."""
    return [bos_token] + [b for b in text.encode('utf-8', errors='replace')]

def _decode(tokens: List[int], bos_token: int, eos_token: int, pad_token: int) -> str:
    """Decodes byte tokens to text, removing special tokens."""
    try:
        # Find EOS if present and truncate
        eos_pos = tokens.index(eos_token) if eos_token in tokens else -1
        if eos_pos != -1:
            tokens = tokens[:eos_pos]

        # Filter out BOS and PAD, then decode
        filtered_bytes = bytes([tok for tok in tokens if tok != bos_token and tok != pad_token])
        return filtered_bytes.decode('utf-8', errors='replace')
    except Exception as e:
        print(f"Decoding error: {e}")
        # Report the raw token IDs; bytes(tokens) would fail again for IDs outside 0-255.
        return f"[Decoding Error] Raw tokens: {tokens}"

# ==========================================
# SR-ABI Inference Function
# ==========================================
@torch.no_grad()
def generate_sr_abi(
    model: TJEPAModel,
    prompt_text: str,
    num_votes: int,           # K
    max_new_tokens: int,      # L
    temperature: float,       # Part of Theta
    top_p: float,             # Part of Theta
    hyperparams: Dict[str, Any], # Contains BOS, EOS, PAD IDs etc.
    device: str
) -> str:
    """
    Generates text using State-Resetting Agreement-Based Inference (SR-ABI).
    """
    model.eval()
    bos_token = hyperparams['bos_token']
    eos_token = hyperparams['eos_token']
    pad_token = hyperparams['pad_token']
    block_size = hyperparams['block_size']

    # --- Initialization ---
    # a. Tokenize prompt (Including system prompt and starting tag)
    full_prompt = f"{hyperparams.get('system_prompt', '')}\n\nProblem: {prompt_text}\n\n{THINK_TAG_START}"
    prompt_tokens = _encode(full_prompt, bos_token)
    # b. Initialize current full sequence S
    S_list = prompt_tokens[:] # List of token IDs
    # c. Initialize generated sequence G
    G_list = [] # List of token IDs (only the generated part)

    print(f"\n--- Starting SR-ABI Generation (K={num_votes}) ---")
    print(f"Prompt:\n{full_prompt}", end="", flush=True)

    # --- Token Generation Loop ---
    for i in range(max_new_tokens):
        # --- a. Vote Collection ---
        votes = Counter()
        S_tensor = torch.tensor([S_list], dtype=torch.long, device=device) # Add batch dim [1, T]

        # Crop context if it exceeds block size for the forward pass
        S_cond = S_tensor if S_tensor.size(1) <= block_size else S_tensor[:, -block_size:]
        seq_len = S_cond.size(1)

        # Create attention mask for the current sequence
        attention_mask = (S_cond != pad_token).float().to(device) # [1, T_cond]

        for j in range(num_votes):
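            # --- i.1 & i.2: Reset State & Re-evaluate Context ---
            # Each vote runs a fresh forward pass over the *full* current sequence
            # S_cond. No KV cache is kept between calls, so all activations are
            # recomputed from scratch. Because the model is in eval mode (dropout
            # off), the logits should be the same for every vote; the votes differ
            # only through the sampling step below.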
            decoder_output = model.decoder_backbone(
                S_cond,
                attention_mask=attention_mask,
                is_causal=True # Standard causal generation
            ) # [1, T_cond, C]

            # Get logits for the *next* token prediction (using the last token's embedding)
            logits = model.lm_head(decoder_output[:, -1, :])  # [1, C] -> [1, V]

            # --- i.3: Sample Candidate ---
            # Apply temperature
            if temperature > 0 and temperature != 1.0:
                logits = logits / temperature

            # Apply top-p (nucleus) sampling
            if top_p > 0.0 and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = float('-inf')

            # Sample from the filtered distribution
            probs = F.softmax(logits, dim=-1)
            candidate_token_tensor = torch.multinomial(probs, num_samples=1) # [1, 1]
            candidate_token = candidate_token_tensor.item()

            # --- i.4: Record Vote ---
            votes[candidate_token] += 1

        # --- b. Agreement (Majority Vote) ---
        if not votes:
            print("\nWarning: No votes collected, stopping generation.")
            break # Should not happen if num_votes >= 1

        # Get the token with the most votes. Tie-breaking: implicitly handled by most_common
        # (returns items in order of first appearance among ties if counts are equal)
        winning_token, vote_count = votes.most_common(1)[0]
        # Optional: print vote distribution for debugging
        # print(f"\nVotes (Step {i+1}): {votes}")
        # print(f"Winner: {winning_token} ({vote_count}/{num_votes})")

        # --- c. Check for Termination ---
        if winning_token == eos_token:
            print("<EOS>", flush=True)
            break

        # --- d. Append Token ---
        G_list.append(winning_token)
        S_list.append(winning_token)

        # --- Print the winning token ---
        # Stream only complete UTF-8 characters: a multi-byte character is printed
        # once its final byte arrives, and invalid bytes are skipped.
        shown_before = bytes(G_list[:-1]).decode('utf-8', errors='ignore')
        shown_now = bytes(G_list).decode('utf-8', errors='ignore')
        print(shown_now[len(shown_before):], end="", flush=True)

        # Optional: Small delay for readability
        # time.sleep(0.01)

    print("\n--- Generation Complete ---")

    # --- Finalization ---
    # a. Detokenize the generated sequence G
    output_text = _decode(G_list, bos_token, eos_token, pad_token)

    return full_prompt + output_text # Return prompt + generated text
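
# --- Sketch (illustrative, standard library only): one SR-ABI voting step ---
# The loop in generate_sr_abi samples `num_votes` candidates from a top-p
# filtered distribution and keeps the most frequent one. The helpers below
# mirror that single step on a plain list of probabilities, without a model.
# The `_sketch` function names and the `seed` parameter are assumptions made
# for illustration; they are not part of the model code.
import random
from collections import Counter

def _top_p_filter_sketch(probs, top_p):
    """Keep the smallest high-probability set whose mass exceeds top_p, then renormalize."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = set(), 0.0
    for idx in order:
        kept.add(idx)             # the token that crosses top_p is still kept
        cumulative += probs[idx]
        if cumulative > top_p:
            break
    total = sum(probs[i] for i in kept)
    return [probs[i] / total if i in kept else 0.0 for i in range(len(probs))]

def _majority_vote_sketch(probs, num_votes, top_p=0.8, seed=None):
    """Sample num_votes token IDs and return (winning_token, vote_count)."""
    if num_votes < 1:
        raise ValueError("num_votes must be at least 1")
    rng = random.Random(seed)
    filtered = _top_p_filter_sketch(probs, top_p)
    samples = rng.choices(range(len(filtered)), weights=filtered, k=num_votes)
    # Ties go to the token that was sampled first, as with Counter.most_common.
    return Counter(samples).most_common(1)[0]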


# ==========================================
# Model Loading Function
# ==========================================
def load_model_for_inference(model_path: str, device: str) -> Tuple[Optional[TJEPAModel], Optional[Dict[str, Any]]]:
    """Loads the TJEPAModel from a checkpoint for inference."""
    if not os.path.exists(model_path):
        print(f"Error: Model checkpoint not found at {model_path}")
        return None, None

    print(f"Loading model checkpoint from {model_path}...")
    try:
        checkpoint = torch.load(model_path, map_location=device)
    except Exception as e:
        print(f"Error loading checkpoint file: {e}")
        return None, None

    # Load hyperparams from checkpoint or use defaults
    hyperparams_loaded = checkpoint.get('hyperparams', None)
    if hyperparams_loaded:
        print("Using hyperparameters loaded from checkpoint.")
        # Merge with defaults to ensure all needed keys exist
        hyperparams = get_default_hyperparams()
        hyperparams.update(hyperparams_loaded) # Loaded values override defaults
    else:
        print("Warning: Hyperparameters not found in checkpoint, using default values.")
        hyperparams = get_default_hyperparams()

    print(f"Effective Hyperparameters: {hyperparams}")

    # Create model instance based on loaded/effective hyperparams
    try:
        model = TJEPAModel(
            vocab_size=hyperparams['vocab_size'], embed_dim=hyperparams['embed_dim'],
            n_heads=hyperparams['n_heads'], n_layers=hyperparams['n_layers'],
            block_size=hyperparams['block_size'], ema_decay=hyperparams['ema_decay'],
            lm_loss_weight=hyperparams['lm_loss_weight'], pad_token_id=hyperparams['pad_token']
        ).to(device)
    except KeyError as e:
        print(f"Error: Missing hyperparameter {e} needed to build the model structure.")
        return None, None
    except Exception as e:
        print(f"Error creating model instance: {e}")
        return None, None

    # Load model state dictionary
    try:
        model_state = checkpoint['model_state']
        # Flexible loading
        current_model_dict = model.state_dict()
        processed_state_dict = {}
        warned_keys = set()
        loaded_keys_count = 0
        for k, v in model_state.items():
            new_k = k # Handle potential renames if needed in the future
            # Example rename: if k.startswith("old_prefix."): new_k = k.replace("old_prefix.", "new_prefix.", 1)
            if new_k in current_model_dict:
                if v.shape == current_model_dict[new_k].shape:
                    processed_state_dict[new_k] = v
                    loaded_keys_count += 1
                else:
                    if new_k not in warned_keys:
                        print(f"Warning: Shape mismatch for key '{new_k}'. Checkpoint: {v.shape}, Model: {current_model_dict[new_k].shape}. Skipping.")
                        warned_keys.add(new_k)
            # else:
            #     if k not in warned_keys and new_k not in warned_keys:
            #         print(f"Warning: Key '{k}' (mapped to '{new_k}') not found in current model. Skipping.")
            #         warned_keys.add(k); warned_keys.add(new_k)

        missing_keys, unexpected_keys = model.load_state_dict(processed_state_dict, strict=False)
        if missing_keys: print(f"  Info: Missing keys in final state_dict load: {missing_keys}")
        if unexpected_keys: print(f"  Info: Unexpected keys found in checkpoint but not used: {unexpected_keys}")
        print(f"Model state loaded successfully ({loaded_keys_count} tensors loaded).")
        loaded_epoch = checkpoint.get('epoch', -1); loaded_step = checkpoint.get('current_step', -1)
        val_loss = checkpoint.get('val_loss', 'N/A')
        val_loss_str = f"{val_loss:.4f}" if isinstance(val_loss, float) else str(val_loss)
        print(f"  Checkpoint details: Epoch {loaded_epoch}, Step {loaded_step}, Val Loss {val_loss_str}")

    except Exception as e:
        print(f"Error loading model state weights: {e}")
        print("Attempting inference with potentially uninitialized weights.")

    return model, hyperparams

# ==========================================
# Main Execution Block
# ==========================================
if __name__ == "__main__":
    # --- Setup Device ---
    device = "mps" if torch.backends.mps.is_available() else \
             ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- Load Model ---
    model, hyperparams = load_model_for_inference(MODEL_PATH, device)

    if model and hyperparams:
        # --- Run SR-ABI Generation ---
        start_time = time.time()
        generated_text = generate_sr_abi(
            model=model,
            prompt_text=PROMPT_TEXT,
            num_votes=NUM_VOTES,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            hyperparams=hyperparams,
            device=device
        )
        end_time = time.time()

        print("\n\n--- Final Output ---")
        print(generated_text)
        print(f"\nGeneration took {end_time - start_time:.2f} seconds.")
    else:
        print("Failed to load model. Exiting.")

Using device: cuda
Loading model checkpoint from t_jepa_mtl_decoder_rope_bs1024_checkpoint.pt...
Using hyperparameters loaded from checkpoint.
Effective Hyperparameters: {'vocab_size': 256, 'embed_dim': 512, 'n_heads': 8, 'n_layers': 12, 'block_size': 1024, 'ema_decay': 0.999, 'lm_loss_weight': 0.92, 'bos_token': 254, 'eos_token': 255, 'pad_token': 0, 'top_p': 0.8, 'context_span_ratio': 0.6, 'target_span_ratio': 0.2, 'num_target_spans': 8, 'min_span_length': 32, 'thinking_tag': '<think>', 'thinking_end_tag': '</think>', 'answer_tag': '<answer>', 'answer_end_tag': '</answer>', 'system_prompt': 'Consider this math problem. Think step by step and provide your reasoning between <think> </think> tags, then give your final answer between <answer> </answer> tags.', 'batch_size': 2, 'num_epochs': 50, 'steps_per_epoch': 1000, 'eval_interval': 200, 'eval_iters': 100, 'accumulation_steps': 8, 'generate_num_tokens': 1024, 'start_prompt': 'Problem: A bakery produces cakes for $10 each. It costs the