### Expert-Choice MoR Model Analysis

This notebook analyzes the number of recursion steps required per token in expert-choice Mixture-of-Recursions (MoR) models.  
It visualizes how many recursion steps each subword token undergoes to predict the *next token*.

In [None]:
# Environment setup. The order of these statements matters: each one prepares
# state that a later import or statement reads.
import os
# Expose only GPU 0 to this kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from pathlib import Path
# The notebook is expected to be started from a sub-directory of the project;
# its parent directory is treated as the project root.
PROJECT_DIR = Path.cwd().parent
# Switch to the project root (presumably so that the project-local modules
# imported below, such as `paths`, `model` and `util`, can be resolved -- confirm).
os.chdir(PROJECT_DIR)
# Point HF_HOME at the project's cache directory.
# NOTE(review): this is done before `transformers` is imported further down,
# presumably so the library picks the value up at import time -- confirm.
from paths import HF_CACHE_DIR; os.environ["HF_HOME"] = HF_CACHE_DIR
from typing import Callable, List, Optional, Tuple, Union
from dataclasses import dataclass
from omegaconf import DictConfig, OmegaConf
from copy import deepcopy
from IPython.display import HTML, display

import re
import json
import math
import html
import pandas as pd
import numpy as np
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils
from transformers import AutoConfig, AutoTokenizer
from transformers.utils import ModelOutput
from transformers.processing_utils import Unpack
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs

from model.util import load_model_from_config
from model.sharing_strategy import SHARING_STRATEGY
from model.kv_caches.cache_utils import Cache, DynamicCache, RecursiveDynamicCache
from model.mor_model.modeling_llama import MoRLlamaForCausalLM, MoRLlamaModel, MoRBaseModelOutputWithPast, MoRBaseModelOutputWithPast
from model.mor_model.expert_choice_router import MoRLlamaDecoderLayer
from model.mor_model.util import MoRLayerOutputWithPast
from lm_dataset.load_dataset import load_dataset_from_config
from util.config import preprocess_config
from util.misc import get_torch_dtype

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Index of the dataset sample that will be analyzed and visualized.
sampling_idx = 0

# Experiment to analyze; swap the active line to inspect another run.
# exp_name = "250720_pretrain_smollm-360m_rec2_middle_cycle_random_lr3e-3_mor_expert_linear_alpha_0.1_sigmoid_aux_loss_0.001"
exp_name = "250720_pretrain_smollm-360m_rec3_middle_cycle_random_lr3e-3_mor_expert_linear_alpha_0.1_sigmoid_aux_loss_0.001"
# exp_name = "250720_pretrain_smollm-360m_rec4_middle_cycle_random_lr3e-3_mor_expert_linear_alpha_0.1_sigmoid_aux_loss_0.001"
# exp_name = "250720_pretrain_smollm-360m_kv-share_rec3_middle_cycle_random_lr3e-3_mor_expert_linear_alpha_0.1_sigmoid_aux_loss_0.001"

# The recursion depth is encoded in the experiment name as "rec<N>";
# fall back to 3 when the name carries no such tag.
match = re.search(r"rec(\d+)", exp_name)
num_recursion = int(match.group(1)) if match else 3

# Everything below assumes expert-choice routing, so refuse other experiments early.
if "expert" not in exp_name:
    raise ValueError("This script is for expert choice routing only.")

In [3]:
class MoRMetricLlamaDecoderLayer(MoRLlamaDecoderLayer):
    """Expert-choice MoR decoder layer that also reports which positions it selected.

    Besides the routed hidden states, ``forward`` returns ``dead_token_seq``:
    for every position of the full input sequence, the number of batch items
    whose token at that position was selected by this layer's router. The
    tensor is detached and moved to CPU before being returned.

    NOTE(review): the routing logic is presumably a copy of the parent
    ``MoRLlamaDecoderLayer.forward`` with the metric added -- confirm that the
    two stay in sync.
    """
    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        prev_selected_tokens: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs]
    ):          
        """Route the top-k tokens through ``self.block`` and add the result back.

        Args:
            x: Hidden states; unpacked below as ``(bs, seq_len, hidden_dim)``.
            attention_mask: Optional mask. Without KV sharing only a 4-D mask
                is handled; a 2-D mask (or any other rank) raises.
            position_ids: Optional position ids. Without KV sharing they are
                truncated to the first ``top_k`` columns.
            past_key_value: Cache object passed unchanged to every block.
            output_attentions: Passed unchanged to every block.
            use_cache: Passed unchanged to every block.
            cache_position: Optional cache positions. Without KV sharing they
                are gathered at the selected positions.
            position_embeddings: Optional tuple of embeddings. Without KV
                sharing each is gathered at the selected positions.
            prev_selected_tokens: Optional indices of the positions kept by the
                previous MoR layer (expanded along the last axis, so presumably
                ``[bs, k, 1]`` -- TODO confirm). When given, the router only
                sees those tokens.
            **kwargs: Extra keyword arguments passed to every block.

        Returns:
            ``MoRLayerOutputWithPast`` holding the updated hidden states, the
            selected positions (indices into the full sequence), optional
            sampling loss/accuracy, ``dead_token_seq`` on CPU, and the router
            z-loss.

        Raises:
            NotImplementedError: For an unknown ``router_func`` or an
                unsupported attention-mask rank.
        """
        # Alias, not a copy (the clone is commented out). The final
        # torch.scatter_add below is the functional form and rebinds total_x.
        total_x = x # clone()
        bs, seq_len, hidden_dim = total_x.shape
        
        if self.training:
            # Capacity warm-up: cosine schedule that starts at 1.0 (all tokens)
            # and reaches self.capacity_factor after cap_warmup_step steps.
            self.training_step += 1
            if self.cap_warmup_step > 0:
                step_ratio = min(1.0, self.training_step / self.cap_warmup_step)
                decay_factor = 0.5 * (1.0 + math.cos(math.pi * step_ratio))
            else:
                decay_factor = 0.0
            capacity_factor = self.capacity_factor + (1.0 - self.capacity_factor) * decay_factor
        else:
            capacity_factor = self.capacity_factor
    
        # self.top_k = int(self.capacity_factor * config.max_position_embeddings) # max value
        # top_k = min(self.top_k, int(capacity_factor * seq_len))
        # top_k is derived from the full sequence length, even when x is
        # narrowed to prev_selected_tokens just below.
        top_k = int(capacity_factor * seq_len)
        
        # gather the tokens that were processed in the previous layer
        if prev_selected_tokens is not None:
            x = torch.gather(x, 1, index=prev_selected_tokens.expand(-1, -1, hidden_dim))
        
        """STEP 1: get logits and top_k tokens"""
        if not self.cfg.mor.rand_router:
            _router_weights = self.mor_router(x / self.cfg.mor.temp) # [bs, seq_len, 1]
        
            # router_weights: activated router output; router_probs: the value
            # used both for ranking (top-k) and, later, for gating.
            if self.router_func is None:
                router_probs = router_weights = _router_weights
            elif self.router_func == "sigmoid":
                router_weights = F.sigmoid(_router_weights)
                router_probs = router_weights * self.cfg.mor.expert.alpha
            elif self.router_func == "tanh":
                router_weights = F.tanh(_router_weights)
                router_probs = router_weights * self.cfg.mor.expert.alpha
            else:
                raise NotImplementedError("Router function is not implemented")
            
        else:
            # Random-router mode: uniform random scores replace the learned router.
            router_weights = _router_weights = torch.rand(bs, x.shape[1], 1, device=x.device, dtype=x.dtype)
            router_probs = router_weights * self.cfg.mor.expert.get("alpha", 0.1)
            
        weights, selected_tokens = torch.topk(router_probs, top_k, dim=1, sorted=False) # [bs, k, 1]
        # IMPORTANT: need to sort indices to keep causal order for those tokens that are processed in a block
        selected_tokens, index = torch.sort(selected_tokens, dim=1)
        # Reorder the gate weights so they stay aligned with the sorted indices.
        weights = torch.gather(weights, dim=1, index=index)
        
        """STEP 2: expand indices to process batches with _reduced_ seqlen"""
        # We need to expand indices' dimensions from
        # [bs, k, 1] to [bs, k, hidden_size] for gathering
        indices_expanded = selected_tokens.expand(-1, -1, hidden_dim)
        top_k_tokens = torch.gather(x, dim=1, index=indices_expanded)
        
        sampling_loss = None
        sampling_acc = None
        topk_acc = None
        uniformity = None
        dead_token_seq = None
    
        # Binary selection mask over the router's input positions (the possibly
        # narrowed x, not the full sequence): 1 where a token was selected.
        targets = torch.zeros_like(router_probs, dtype=router_probs.dtype)
        src = torch.ones_like(selected_tokens, dtype=targets.dtype)
        targets.scatter_(1, selected_tokens, src)
        
        if self.sampling == "aux_router":
            # A separate router (mlp_router) is fitted on detached inputs to
            # predict the selection mask.
            logits = self.mlp_router(x.clone().detach())
            sampling_loss = self.bce_loss(logits.view(-1), targets.view(-1)) / (bs * logits.shape[1])
            prediction = (F.sigmoid(logits) >= 0.5)
            correct_predictions = (prediction == targets).view(-1)
            sampling_acc = correct_predictions.sum() / (bs * logits.shape[1])
            
            # Fraction of the main router's selected tokens that also appear in
            # the auxiliary router's own top-k.
            aux_router_topk = torch.topk(logits, top_k, dim=1, sorted=False)[1]
            topk_acc = torch.tensor(0.0, device=logits.device)
            for b in range(bs):
                topk_acc += torch.isin(selected_tokens[b].view(-1), aux_router_topk[b].view(-1)).sum()
            topk_acc = topk_acc / (bs * top_k)
            
        elif self.sampling == "aux_loss":
            # The main router's raw outputs are trained directly against the
            # selection mask; only the decision threshold depends on router_func.
            if self.router_func is None or self.router_func == "sigmoid":
                sampling_loss = self.bce_loss(_router_weights.view(-1), targets.view(-1)) / (bs * router_weights.shape[1])
                prediction = (router_weights >= 0.5)
            elif self.router_func == "tanh":
                sampling_loss = self.bce_loss(_router_weights.view(-1), targets.view(-1)) / (bs * router_weights.shape[1])
                prediction = (router_weights >= 0.)
            correct_predictions = (prediction == targets).view(-1)
            sampling_acc = correct_predictions.sum() / (bs * router_weights.shape[1])
            topk_acc = None
            
        """STEP 3: based on total seqlen, prepare input for block forward"""
        # recompute selected_tokens based on total tokens        
        if prev_selected_tokens is not None:
            # Translate indices into the narrowed x back to full-sequence positions.
            selected_tokens = torch.gather(prev_selected_tokens, dim=1, index=selected_tokens)
            indices_expanded = selected_tokens.expand(-1, -1, hidden_dim)

        with torch.no_grad():
            # Metric: mark the selected full-sequence positions with 1 and sum
            # over the batch axis, giving one count per sequence position.
            _targets = torch.zeros((bs, seq_len, 1), dtype=selected_tokens.dtype, device=selected_tokens.device)
            _targets.scatter_(1, selected_tokens, torch.ones_like(selected_tokens))
            # After sum(dim=0) the tensor is 1-D; the trailing squeeze(0) only
            # has an effect when seq_len == 1.
            dead_token_seq = _targets.squeeze(-1).sum(dim=0).squeeze(0)
        
        """STEP 4: forward block"""     
        if 'kv_sharing' in self.cfg and self.cfg.kv_sharing.enable:
            # KV sharing: every token goes through the blocks; only the
            # selected tokens' outputs are kept afterwards.
            top_k_tokens = total_x.clone()
                
            for blk in self.block:
                outputs = blk(
                    top_k_tokens,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs
                )
                top_k_tokens = outputs[0]
            top_k_tokens_processed = torch.gather(outputs[0], dim=1, index=indices_expanded)
            
        else:
            # No KV sharing: only the selected tokens go through the blocks, so
            # every positional input is restricted to those tokens first.
            if attention_mask is not None: 
                if attention_mask.dim() == 4: # attn_implementation == "eager"
                    # Keep the mask rows, then the mask columns, of the selected tokens.
                    row_indices = selected_tokens.unsqueeze(1).expand(bs, 1, top_k, attention_mask.shape[-1])  
                    mask_rows_selected = torch.gather(attention_mask, 2, row_indices)
                    col_indices = selected_tokens.unsqueeze(1).transpose(2, 3).expand(bs, 1, top_k, top_k)
                    attention_mask = torch.gather(mask_rows_selected, 3, col_indices)
                elif attention_mask.dim() == 2: # TODO
                    raise NotImplementedError("Attention mask is not implemented for inference phase of MoR")
                else: 
                    raise NotImplementedError("Attention mask has unexpected dimensions")
            
            if position_ids is not None: 
                # NOTE(review): position_ids are truncated to the first top_k
                # entries rather than gathered at selected_tokens (unlike
                # position_embeddings and cache_position) -- confirm intended.
                position_ids = position_ids[:, :top_k]            
            if position_embeddings is not None:
                head_dim = position_embeddings[0].shape[-1]
                position_embeddings = tuple([torch.gather(emb.expand(bs, -1, -1), dim=1, index=selected_tokens.expand(-1, -1, head_dim)) 
                                                for emb in position_embeddings])
            if cache_position is not None:
                cache_position = torch.gather(cache_position.expand(bs, -1), dim=1, index=selected_tokens.squeeze(-1)) 
                
            for blk in self.block:
                outputs = blk(
                    top_k_tokens,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs
                )
                top_k_tokens = outputs[0]
            top_k_tokens_processed = outputs[0]
            
        """STEP 5: combine results"""
        # "weighted" gating (the default) scales each processed token by its router score.
        _src = top_k_tokens_processed * weights if self.cfg.mor.expert.get("gating", "weighted") == "weighted" else top_k_tokens_processed
        # Add the processed tokens onto the layer input at their original
        # positions; positions that were not selected keep their input value.
        total_x = torch.scatter_add(
            total_x,
            dim=1,
            index=indices_expanded,
            src=_src,
        )
        
        # Router z-loss: mean squared log-sum-exp of the raw router outputs
        # over their last axis.
        router_z_loss = torch.logsumexp(_router_weights, dim=-1)
        router_z_loss = torch.square(router_z_loss)
        router_z_loss = router_z_loss.mean()
                
        return MoRLayerOutputWithPast(
            hidden_state=total_x,
            attention_weights=outputs[1:],
            selected_tokens=selected_tokens,
            sampling_loss=sampling_loss,
            sampling_acc=sampling_acc,  
            sampling_topk_acc=topk_acc,
            uniformity=uniformity,
            dead_token_seq=dead_token_seq.detach().cpu(),
            balancing_loss=None,
            balancing_ratio=None,
            router_z_loss=router_z_loss,
        )

In [4]:
class MoRMetricLlamaModel(MoRLlamaModel):
    """MoR Llama backbone whose forward collects per-layer routing metrics.

    ``forward`` gathers the ``dead_token_seq`` tensor returned by every
    expert-type MoR layer into a list (one entry per such layer, in layer
    order) and returns it in the model output.
    """
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, MoRBaseModelOutputWithPast]:
        """Run all decoder layers and collect ``dead_token_seq`` from MoR layers.

        Args:
            input_ids: Token ids; exactly one of ``input_ids`` and
                ``inputs_embeds`` must be given.
            attention_mask: Optional mask, turned into a causal mask through
                ``self._update_causal_mask``.
            position_ids: Optional positions; defaults to ``cache_position``
                with a leading batch axis.
            past_key_values: Optional cache; created here when ``use_cache`` is
                set and none is supplied.
            inputs_embeds: Optional precomputed embeddings used instead of
                ``input_ids``.
            use_cache, output_attentions, output_hidden_states, return_dict:
                Fall back to the corresponding ``self.config`` values when None.
            cache_position: Optional cache positions; defaults to a range that
                starts at the cache's current sequence length.
            **flash_attn_kwargs: Passed to every decoder layer.

        Returns:
            ``MoRBaseModelOutputWithPast`` (or its tuple form when
            ``return_dict`` is false). In this implementation the sampling,
            balancing and z-loss fields are never updated after being
            initialized, so they are returned as zeros; only
            ``dead_token_seq`` carries collected per-layer data.

        Raises:
            ValueError: If both or neither of ``input_ids`` and
                ``inputs_embeds`` are provided.
        """
        
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # True when both inputs are given or when neither is.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            # Pick the cache type based on whether a KV-sharing config is present.
            if "kv_sharing" in self.config and self.config.kv_sharing is not None:
                kwargs = self.config.kv_sharing
                past_key_values = RecursiveDynamicCache(kwargs["base_depth"], kwargs["num_recursion"], kwargs["sharing"])
            else:
                past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        
        # Positions kept by the most recent expert-type MoR layer; handed to the
        # next MoR layer so it routes only among those tokens.
        prev_selected_tokens = None
        # The accumulators below other than dead_token_seq_list are not updated
        # anywhere in this forward; they only fill the output fields.
        sampling_loss = torch.tensor(0.0, device=hidden_states.device)
        sampling_acc_list = []
        sampling_topk_acc_list = []
        uniformity = None # torch.tensor(0.0, device=hidden_states.device)
        dead_token_seq_list = []
        balancing_loss = torch.tensor(0.0, device=hidden_states.device)
        balancing_ratio = torch.tensor(0.0, device=hidden_states.device)
        router_z_loss = torch.tensor(0.0, device=hidden_states.device)
        
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:            
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # TODO: support MoRLlamaDecoderLayer
                # Note: this path passes neither prev_selected_tokens nor the
                # flash-attention kwargs, and collects no routing metrics.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                # MoR layers are recognized by a truthy `mor` attribute.
                if hasattr(decoder_layer, "mor") and decoder_layer.mor:
                    layer_outputs = decoder_layer(
                        hidden_states,
                        attention_mask=causal_mask,
                        position_ids=position_ids,
                        past_key_value=past_key_values,
                        output_attentions=output_attentions,
                        use_cache=use_cache,
                        cache_position=cache_position,
                        position_embeddings=position_embeddings,
                        prev_selected_tokens=prev_selected_tokens,
                        **flash_attn_kwargs,
                    )
                    if decoder_layer.mor_type == "expert":
                        # Chain the selection to the next layer and record this
                        # layer's per-position selection counts.
                        prev_selected_tokens = layer_outputs.selected_tokens
                        if layer_outputs.dead_token_seq is not None:
                            dead_token_seq_list.append(layer_outputs.dead_token_seq)
                else:
                    # Regular decoder layer: no routing arguments.
                    layer_outputs = decoder_layer(
                        hidden_states,
                        attention_mask=causal_mask,
                        position_ids=position_ids,
                        past_key_value=past_key_values,
                        output_attentions=output_attentions,
                        use_cache=use_cache,
                        cache_position=cache_position,
                        position_embeddings=position_embeddings,
                        **flash_attn_kwargs,
                    )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        output = MoRBaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            sampling_loss=sampling_loss,
            sampling_acc=sum(sampling_acc_list)/len(sampling_acc_list) if len(sampling_acc_list) > 0 else torch.tensor(0.0, device=hidden_states.device),
            sampling_topk_acc=sum(sampling_topk_acc_list)/len(sampling_topk_acc_list) if len(sampling_topk_acc_list) > 0 else torch.tensor(0.0, device=hidden_states.device),
            uniformity=uniformity,
            dead_token_seq=dead_token_seq_list,
            balancing_loss=balancing_loss,
            balancing_ratio=balancing_ratio,
            router_z_loss=router_z_loss,
        )
        return output if return_dict else output.to_tuple()

In [5]:
class MoRMetricLlamaForCausalLM(MoRLlamaForCausalLM):
    """Causal-LM head on top of ``MoRMetricLlamaModel``.

    Uses the metric-collecting backbone and decoder layers defined in this
    notebook so that a forward pass also yields per-layer ``dead_token_seq``.
    """

    def __init__(self, config):
        super().__init__(config)
        # Replace the backbone built by the parent with the metric-collecting variant.
        self.model = MoRMetricLlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def transform_layer_to_mor_expert(self, cfg):
        """Regroup ``self.model.layers`` into expert-choice MoR layers.

        The existing decoder layers are split into ``num_recursion`` groups of
        ``base_depth`` consecutive layers, and each group is wrapped in a
        ``MoRMetricLlamaDecoderLayer``.

        Args:
            cfg: Experiment config. Reads ``cfg.mor.capacity`` (comma-separated
                capacity factors, one per recursion), ``cfg.mor.expert``
                (optional ``cap_warmup_step``), ``cfg.recursive.sharing``,
                ``cfg.recursive.num_recursion`` and, when no explicit warm-up
                is configured, ``cfg.num_warmup_steps`` and
                ``cfg.gradient_accumulation_steps``.

        Raises:
            ValueError: If ``cfg.recursive.sharing`` is neither ``"cycle"`` nor
                ``"middle_cycle"``. Previously such a value left the model
                untouched without any warning.
        """
        capacity = [float(cap) for cap in cfg.mor.capacity.split(',')]
        # warmup_step for capacity_factor
        if "cap_warmup_step" in cfg.mor.expert and cfg.mor.expert.cap_warmup_step is not None:
            cap_warmup_step = cfg.mor.expert.cap_warmup_step
        else:
            cap_warmup_step = cfg.num_warmup_steps * cfg.gradient_accumulation_steps

        sharing = cfg.recursive.sharing
        num_recursion = cfg.recursive.num_recursion
        layers = self.model.layers
        num_hidden_layers = len(layers)

        # Cycle sharing is for early-exiting mechanism
        if sharing == "cycle":
            # Every layer belongs to a recursion group.
            offset = 0
            base_depth = num_hidden_layers // num_recursion
            head, tail = [], []
        elif sharing == "middle_cycle":
            # The first and last layers stay as they are; only the layers in
            # between are grouped into recursion blocks.
            offset = 1
            base_depth = (num_hidden_layers - 2) // num_recursion
            head, tail = [layers[0]], [layers[-1]]
        else:
            # Fail loudly: a silently untransformed model has no MoR layers and
            # therefore produces no routing metrics.
            raise ValueError(
                f"Unsupported sharing strategy {sharing!r} for expert-choice MoR; "
                "expected 'cycle' or 'middle_cycle'."
            )

        mor_layers = [
            MoRMetricLlamaDecoderLayer(
                self.config,
                nn.ModuleList(
                    [layers[offset + layer_idx + recur_idx * base_depth] for layer_idx in range(base_depth)]
                ),
                cfg, capacity[recur_idx], cap_warmup_step,
            )
            for recur_idx in range(num_recursion)
        ]
        self.model.layers = nn.ModuleList(head + mor_layers + tail)

In [6]:
# Model family name (cfg.model) -> metric-collecting model class.
MOR_MODEL_CLS = {"smollm": MoRMetricLlamaForCausalLM}


def load_model_from_config(cfg: DictConfig):
    """Build a randomly initialized metric-collecting MoR model from ``cfg``.

    Reads the Hugging Face config for ``cfg.model_name_or_path``, applies the
    optional ``cfg.model_config`` overrides and the ``cfg.max_length``
    override, and instantiates the class registered for ``cfg.model``.

    Raises:
        AssertionError: If MoR is not enabled in ``cfg``.
        KeyError: If ``cfg.model`` has no entry in ``MOR_MODEL_CLS``.
        ValueError: If an override key does not exist on the model config.
    """
    assert "mor" in cfg and cfg.mor.enable
    model_cls = MOR_MODEL_CLS[cfg.model]

    attn_implementation = cfg.get("attn_implementation", "flash_attention_2")
    torch_dtype = get_torch_dtype(cfg)
    init_kwargs = dict(attn_implementation=attn_implementation, torch_dtype=torch_dtype)

    print("Initializing model from scratch...")
    config = AutoConfig.from_pretrained(cfg.model_name_or_path, **init_kwargs)

    # Optional per-experiment overrides of architecture hyper-parameters.
    overrides = cfg.get("model_config")
    if overrides is not None:
        print("Using custom config for vanilla model...")
        for k, v in overrides.items():
            if not hasattr(config, k):
                raise ValueError(f"Config key {k} not found in model config.")
            print(f" {k}: {v}")
            setattr(config, k, v)

    # Align the positional range with the experiment's sequence length.
    if cfg.get("max_length") and cfg.max_length != config.max_position_embeddings:
        warnings.warn(f"original max_position_embeddings of {config.max_position_embeddings} is changed to {cfg.max_length}")
        config.max_position_embeddings = cfg.max_length

    return model_cls._from_config(config, **init_kwargs)

In [7]:
# Load the experiment's pretraining config and run the project's preprocessing
# on it, with the per-device batch size set to 1.
config_path = os.path.join(PROJECT_DIR, "conf/pretrain", f"{exp_name}.yaml")
cfg = OmegaConf.load(config_path)
cfg.per_device_train_batch_size = 1
cfg = preprocess_config(cfg)

model = load_model_from_config(cfg)

# KV cache sharing strategy
recursion_enabled = cfg.recursive.get("enable")
if recursion_enabled:
    model, lora_init_dict = SHARING_STRATEGY[cfg.model](cfg, model)

kv_sharing_enabled = "kv_sharing" in cfg and cfg.kv_sharing.get("enable")
if kv_sharing_enabled:
    model.set_kv_sharing_config(cfg)

# Wrap the decoder layers into MoR layers of the configured routing type.
mor_enabled = "mor" in cfg and cfg.mor.get("enable")
if mor_enabled:
    if cfg.mor.type == "expert":
        model.transform_layer_to_mor_expert(cfg)
    elif cfg.mor.type == "token":
        model.transform_layer_to_mor_token(cfg)
    else:
        raise ValueError(f"Unknown MoR type {cfg.mor.type}.")

print(model)


-------------------------------Preprocess Config -------------------------------


Automatically determining batch size based on `total_batch_size`
total_batch_size              : 1024 (given)
torch.cuda.device_count()     : 1
per_device_train_batch_size   : 1 (given)
gradient_accumulation_steps   : 1024 (computed)
actual total batch size       : 1024
Setting wandb_run_name: 250720_pretrain_smollm-360m_rec3_middle_cycle_random_lr3e-3_mor_expert_linear_alpha_0.1_sigmoid_aux_loss_0.001
Setting output_dir  : 250720_pretrain_smollm-360m_rec3_middle_cycle_random_lr3e-3_mor_expert_linear_alpha_0.1_sigmoid_aux_loss_0.001
No checkpoint directories found matching /home/sangmin/mixture_of_recursions/results/pretrain/250720_pretrain_smollm-360m_rec3_middle_cycle_random_lr3e-3_mor_expert_linear_alpha_0.1_sigmoid_aux_loss_0.001/checkpoint-*
Using deepspeed config = /home/sangmin/mixture_of_recursions/ds_configs/stage2.config
--------------------------------------------------------------------------------
Initializing model from scratch...




MoRMetricLlamaForCausalLM(
  (model): MoRMetricLlamaModel(
    (embed_tokens): Embedding(49152, 960)
    (layers): ModuleList(
      (0): LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=960, out_features=960, bias=False)
          (k_proj): Linear(in_features=960, out_features=320, bias=False)
          (v_proj): Linear(in_features=960, out_features=320, bias=False)
          (o_proj): Linear(in_features=960, out_features=960, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=960, out_features=2560, bias=False)
          (up_proj): Linear(in_features=960, out_features=2560, bias=False)
          (down_proj): Linear(in_features=2560, out_features=960, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((960,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((960,), eps=1e-05)
      )
      (1-3): 3 x MoRMetricLlamaDecoderLayer(
        (block): ModuleLis

In [8]:
# Directory that holds one sub-directory of trained weights per experiment.
SAVE_DIR = os.path.join(PROJECT_DIR, "checkpoints")

try:
    # map_location="cpu": load independently of the device the checkpoint was
    # saved from (the model is moved to the GPU in a later cell).
    # weights_only=True: restrict unpickling to tensors and plain containers
    # instead of executing arbitrary pickled objects.
    state_dict = torch.load(
        os.path.join(SAVE_DIR, exp_name, "pytorch_model.bin"),
        map_location="cpu",
        weights_only=True,
    )
except FileNotFoundError:
    # No pickle checkpoint: fall back to the safetensors file.
    from safetensors.torch import load_file
    state_dict = load_file(os.path.join(SAVE_DIR, exp_name, "model.safetensors"))
# Strict load: a missing or unexpected key raises. Kept as the last expression
# so the key-matching result is displayed as the cell output.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [9]:
# Evaluate on the held-out split: override the dataset named in the config.
cfg.dataset = "fineweb_test"

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")
# Fall back to token id 0 for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token_id = 0 # '<|endoftext|>'

dataset = load_dataset_from_config(cfg, tokenizer)

In [10]:
model.to("cuda")
model.eval()

dtype = get_torch_dtype(cfg)
# One slot per recursion step; each will hold that step's per-position
# selection counts (`dead_token_seq`) for the analyzed sample.
routing_list = [None for _ in range(cfg.recursive.num_recursion)]

# Pure inference: disable autograd so no graph or activations are retained.
with torch.no_grad():
    for i, sample in enumerate(dataset):
        # Skip ahead to the requested sample.
        if i != sampling_idx:
            continue

        output = model(
            input_ids=sample["input_ids"].unsqueeze(0).to(model.device),
            attention_mask=sample["attention_mask"].unsqueeze(0).to(model.device),
            labels=sample["labels"].unsqueeze(0).to(model.device),
            use_cache=True,
            output_attentions=True,
            output_hidden_states=True,
            return_dict=True,
        )

        for j in range(cfg.recursive.num_recursion):
            if routing_list[j] is None:
                routing_list[j] = output.dead_token_seq[j]
            else:
                routing_list[j] += output.dead_token_seq[j]

        torch.cuda.empty_cache()
        # Only a single sample is analyzed.
        break
    else:
        # The loop ended without reaching sampling_idx: fail here instead of
        # letting later cells run on an unrelated `sample` and empty routing data.
        raise IndexError(f"sampling_idx={sampling_idx} is out of range for the dataset.")

In [11]:
left_len = 120    # Number of tokens shown from the start of the sample
right_len = 60    # Number of tokens shown from the end of the sample

num_tokens = sample["input_ids"].shape[0]
tail_start = num_tokens - right_len   # Index of the first token of the tail part

# Decoded head tokens, a "... ..." separator, then decoded tail tokens.
# A token that falls into both parts is listed twice.
data = []
for i in range(num_tokens):
    in_head = i < left_len
    in_tail = i >= tail_start
    if not (in_head or in_tail):
        continue
    token_text = tokenizer.decode(
        sample["input_ids"][i].tolist(),
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    if in_head:
        data.append(token_text)
    if i == tail_start:
        data.append("... ...")
    if in_tail:
        data.append(token_text)

In [12]:
# Background colors used by generate_highlighted_text_html, keyed by name.
# That function maps recursion counts to keys as follows:
# 1 -> purple, 2 -> blue, 3 -> red, 4 -> yellow; anything else is left unhighlighted.
HIGHLIGHT_COLORS = dict(
    purple="#dabfda",
    red="#f7b0a9",
    blue="#aeddfd",
    yellow="#f2e4a5",
    no_highlight="transparent",
)

def get_layer_value_safe(dts_list, layer_idx, token_idx, default_val=0):
    """Return the scalar at ``dts_list[layer_idx][token_idx]`` or a default.

    Args:
        dts_list: Sequence of per-layer sequences (lists, tensors, ...).
        layer_idx: Index of the layer to read.
        token_idx: Index of the token inside that layer.
        default_val: Value returned when the lookup is not possible.

    Returns:
        ``value.item()`` when the element exposes a callable ``item``; the
        element itself when it is an int/float or has no length;
        ``default_val`` when an index is past the end, the layer has no
        length, or the element is itself a sized container.
    """
    if layer_idx >= len(dts_list):
        return default_val

    layer = dts_list[layer_idx]
    if not hasattr(layer, '__len__') or token_idx >= len(layer):
        return default_val

    val = layer[token_idx]
    # Tensor / numpy scalars: unwrap to a plain Python number.
    if callable(getattr(val, 'item', None)):
        return val.item()
    # Plain numbers, and anything else that is not a sized container.
    if isinstance(val, (int, float)) or not hasattr(val, '__len__'):
        return val
    return default_val

def generate_highlighted_text_html(tokenizer, sample_data, current_left_len, current_right_len, current_dead_token_seq_list, current_colors, tokens_per_line=15):
    """Render a token sequence as HTML, colored by per-token recursion depth.

    Only the first ``current_left_len`` and the last ``current_right_len``
    tokens are rendered; a "... ..." marker is placed between the two parts
    when tokens are skipped. A ``<br>`` is emitted after every
    ``tokens_per_line`` rendered tokens.

    Args:
        tokenizer: Object with a ``decode(ids, skip_special_tokens=...,
            clean_up_tokenization_spaces=...)`` method.
        sample_data: Mapping whose ``"input_ids"`` entry has ``.shape[0]``
            and supports integer indexing.
        current_left_len: Number of leading tokens to render.
        current_right_len: Number of trailing tokens to render.
        current_dead_token_seq_list: Per-recursion-step sequences of selection
            flags, read through ``get_layer_value_safe`` (up to four steps).
        current_colors: Mapping with the keys ``"yellow"``, ``"red"``,
            ``"blue"``, ``"purple"`` and ``"no_highlight"``.
        tokens_per_line: Number of tokens after which a line break is added.

    Returns:
        Tuple ``(html_string, recursion_counts)`` where ``recursion_counts``
        has one entry per token of the whole sequence (0 when the token's
        flags match none of the known patterns).
    """
    token_ids = sample_data["input_ids"]
    total_tokens = token_ids.shape[0]

    # (flags for steps 0..3, color key, recursion count), tested in this order.
    depth_patterns = (
        ((1, 1, 1, 1), "yellow", 4),
        ((1, 1, 1, 0), "red", 3),
        ((1, 1, 0, 0), "blue", 2),
        ((1, 0, 0, 0), "purple", 1),
    )

    # 1. Work out color and recursion count of every token up front.
    token_highlight_colors = [current_colors["no_highlight"]] * total_tokens
    recursion_counts = [0] * total_tokens
    for j in range(total_tokens):
        # Steps beyond the end of the list read as 0 (get_layer_value_safe's default).
        flags = [get_layer_value_safe(current_dead_token_seq_list, step, j) for step in range(4)]
        for pattern, color_key, depth in depth_patterns:
            if all(flag == expected for flag, expected in zip(flags, pattern)):
                token_highlight_colors[j] = current_colors[color_key]
                recursion_counts[j] = depth
                break

    # 2. Emit HTML fragments, tracking how many tokens sit on the current line.
    html_parts = []
    tokens_on_line = 0

    def emit_token(idx):
        """Append the (optionally highlighted) token at ``idx`` plus a due line break."""
        nonlocal tokens_on_line
        raw_id = token_ids[idx]
        ids_for_decode = raw_id.tolist() if hasattr(raw_id, 'tolist') else [raw_id]
        text = tokenizer.decode(
            ids_for_decode,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        # Drop line breaks inside a token, then escape for HTML.
        text = html.escape(text.replace('\n', ' ').replace('\r', ''))

        color = token_highlight_colors[idx]
        if color != current_colors["no_highlight"] and color is not None:
            html_parts.append(f'<span style="background-color: {color};">{text}</span>')
        else:
            html_parts.append(text)

        tokens_on_line += 1
        if tokens_on_line >= tokens_per_line:
            html_parts.append("<br>")
            tokens_on_line = 0

    # Head part.
    head_end = max(min(current_left_len, total_tokens), 0)
    for i in range(head_end):
        emit_token(i)

    # Separator: only when at least one token is skipped between head and tail.
    # It does not count towards the per-line token budget.
    tail_start = total_tokens - current_right_len
    if tail_start > head_end:
        html_parts.append('<span style="background-color: #FFFFFF;">... ...</span>')

    # Tail part, without repeating tokens already rendered in the head.
    for i in range(max(tail_start, head_end), total_tokens):
        emit_token(i)

    inner_html = "".join(html_parts)
    # white-space: pre-wrap keeps the tokens' own spaces while <br> still breaks lines.
    return f'<div style="color: black; white-space: pre-wrap;">{inner_html}</div>', recursion_counts

highlighted_html_content, recursion_counts = generate_highlighted_text_html(
    tokenizer,
    sample,
    left_len,
    right_len,
    routing_list,
    HIGHLIGHT_COLORS,
    tokens_per_line=60 # Number of tokens to display per line
)

# Show the highlighted text inline in the notebook.
display(HTML(highlighted_html_content))

# --- Save the same content as a standalone HTML page ---
html_file_to_save = "highlighted_text_multiline.html"

# The page declares lang="en" because the rendered sample is English text.
# The inner <div> already carries white-space: pre-wrap, so the body needs no
# white-space rule of its own.
html_page = "".join([
    "<!DOCTYPE html>\n",
    "<html lang=\"en\">\n",
    "<head>\n",
    "    <meta charset=\"UTF-8\">\n",
    "    <title>Highlighted Text Output (Multiline)</title>\n",
    "    <style>\n",
    "        body { \n",
    "            font-family: sans-serif; \n",
    "            line-height: 1.6; \n",
    "            margin: 20px; \n",
    "        }\n",
    "    </style>\n",
    "</head>\n",
    "<body>\n",
    highlighted_html_content,  # generated body content (including the div)
    "\n</body>\n",
    "</html>\n",
])

try:
    with open(html_file_to_save, "w", encoding="utf-8") as f:
        f.write(html_page)
    print(f"HTML file saved successfully as '{html_file_to_save}'.")
    print("Open this file directly in a web browser to see if it appears on multiple lines.")
except OSError as e:
    # Saving is best-effort: report file-system problems without aborting the
    # notebook, but let unrelated programming errors surface.
    print(f"Error saving file: {e}")

HTML file saved successfully as 'highlighted_text_multiline.html'.
Open this file directly in a web browser to see if it appears on multiple lines.
