In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%%html
<style>
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [None]:
from collections import deque
import torch
from torch.nn.attention import flex_attention
from torchtune.modules import attention_utils

# Notebook-level score modifier forwarded to flex attention by
# `compile_friendly_flex_attention`. The function reads this global each time it is
# called, so rebinding `score_mod` in a later cell changes what subsequent forward
# passes use. The initial `None` is passed through as-is.
score_mod: flex_attention._score_mod_signature | None = None
# Abandoned alternative: a queue holding one score_mod per layer.
# score_mods: deque[flex_attention._score_mod_signature | None] = deque()


# We cannot do nested compile, but flex attention only has perf benefits
# when compiled. To insulate it from the compiler, we wrap it with
# compiler.disable so that it can be used regardless of whether the model
# is compiled or not, and flex attention always remains compiled.
@torch.compiler.disable(recursive=False)
def compile_friendly_flex_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_mask: flex_attention.BlockMask,
) -> torch.Tensor:
    """Call torchtune's compiled flex attention with the notebook's ``score_mod``.

    Args:
        q: Query tensor, forwarded unchanged.
        k: Key tensor, forwarded unchanged.
        v: Value tensor, forwarded unchanged.
        block_mask: Block mask forwarded as ``block_mask``.

    Returns:
        Whatever ``attention_utils.flex_attention_compiled`` returns.

    The module-level ``score_mod`` is looked up at call time, so rebinding it
    between forward passes changes the modifier that is used.
    """
    # Only the BLOCK_M2 / BLOCK_N2 kernel options are overridden; the
    # BLOCK_M / BLOCK_N / BLOCK_M1 / BLOCK_N1 overrides were tried and left disabled.
    tile_overrides = {
        "BLOCK_M2": 64,
        "BLOCK_N2": 64,
    }
    return attention_utils.flex_attention_compiled(  # type: ignore
        q,
        k,
        v,
        score_mod=score_mod,
        block_mask=block_mask,
        kernel_options=tile_overrides,
    )


# Monkeypatch torchtune: rebind `attention_utils.compile_friendly_flex_attention` to the
# notebook version defined above, so callers that resolve it through the module
# attribute get the custom score_mod and kernel options.
# NOTE(review): code that imported the original function by name before this cell ran
# would keep the old reference — confirm how torchtune's attention layers look it up.
attention_utils.compile_friendly_flex_attention = compile_friendly_flex_attention

In [None]:
import art
from art.torchtune.config import (
    ModelConfig,
    MetricLoggerConfig,
    OptimizerConfig,
    CheckpointerConfig,
    RecipeConfig,
)
from art.torchtune.recipe import FullFinetuneRecipeDistributed
import asyncio
import glob
import os

process = await asyncio.subprocess.create_subprocess_exec(
    "huggingface-cli",
    "download",
    "Qwen/Qwen3-14B",
    stdout=asyncio.subprocess.PIPE,
    stderr=asyncio.subprocess.PIPE,
)
stdout, _ = await process.communicate()
checkpoint_dir = stdout.decode("utf-8").splitlines()[-1].strip()
safetensor_files = glob.glob(f"{checkpoint_dir}/*.safetensors")
checkpoint_files = sorted(os.path.basename(f) for f in safetensor_files)
print("checkpoint_dir", checkpoint_dir)
print("checkpoint_files", checkpoint_files)

# Recipe configuration. Each `_component_` string names the object the recipe should
# build: the Qwen3-14B instruct model builder, a stdout metric logger, AdamW, and the
# HF-format full-model checkpointer pointed at the files discovered above.
cfg = RecipeConfig(
    model=ModelConfig(_component_="torchtune.models.qwen3.qwen3_14b_instruct"),
    metric_logger=MetricLoggerConfig(
        _component_="torchtune.training.metric_logging.StdoutLogger"
    ),
    optimizer=OptimizerConfig(_component_="torch.optim.AdamW"),
    checkpointer=CheckpointerConfig(
        _component_="torchtune.training.FullModelHFCheckpointer",  # type: ignore
        model_type="QWEN3",
        checkpoint_dir=checkpoint_dir,  # type: ignore
        checkpoint_files=checkpoint_files,  # type: ignore
    ),
    # Resolved relative to the notebook's working directory.
    output_dir=os.path.abspath("../.art/temporal-clue/models/002"),
    # Both activation checkpointing and activation offloading are switched on.
    enable_activation_checkpointing=True,
    enable_activation_offloading=True
)

# Build the recipe from the config, then run its setup step with the same config.
recipe = FullFinetuneRecipeDistributed(cfg=cfg)  # type: ignore
recipe.setup(cfg=cfg)

In [None]:
# Fetch the micro-batches for epoch 0 and keep the first one as the working example.
# NOTE(review): `_get_micro_batches` is a private recipe method; the (micro_batches,
# batch) return structure is taken from this unpacking only — confirm against the recipe.
micro_batches, batch = recipe._get_micro_batches(curr_epoch=0)
inputs = micro_batches[0]

In [None]:
from transformers import AutoTokenizer

# Tokenizer for the same model id that was downloaded above; used in later cells to
# decode token ids for inspection.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-14B")

In [None]:
# Show the text of the first sequence in the micro-batch by decoding each token id
# separately and concatenating the pieces.
print("".join(tokenizer.decode(token_id) for token_id in inputs["tokens"][0]))

In [None]:
# Position of the first token with id 151645 at or after index 3190 in the first
# sequence; `list.index` raises ValueError if there is none.
# NOTE(review): 151645 is presumably the tokenizer's `<|im_end|>` id and 3190 a
# hand-picked offset for this particular sample — confirm, and consider named constants.
end_index = inputs["tokens"][0].tolist().index(151645, 3190)
end_index

In [None]:
# Truncate every entry of the micro-batch along dim 1 to the first `end_index`
# positions (the token at `end_index` itself is dropped). Entries of `inputs` are
# rebound in place, so the untruncated micro-batch is no longer reachable via `inputs`.
# NOTE(review): assumes every value in `inputs` has at least two dimensions — confirm.
for key in inputs:
    inputs[key] = inputs[key][:, :end_index]

In [None]:
# Map negative offsets to the decoded text of the last 100 tokens. The keys run from
# -101 to -2, so each key is one less than the token's own negative index.
# NOTE(review): presumably deliberate, so that a key names the position that predicts
# the token (next-token alignment) — confirm; otherwise the range is off by one.
dict(zip(range(-101, -1), [tokenizer.decode(token_id) for token_id in inputs["tokens"][0]][-100:]))

In [None]:
from torchtune.utils import batch_to_device

# Move the micro-batch to the recipe's device.
# NOTE(review): the return value is ignored, so this presumably mutates `inputs` in
# place — confirm against torchtune's `batch_to_device`.
batch_to_device(inputs, device=recipe._device) # type: ignore

In [None]:
import torch
from torch.nn.attention.flex_attention import create_block_mask, BlockMask

def make_block_mask(
    group_ids: torch.Tensor,  # [B, S]  int32/64
    parent_ids: torch.Tensor,  # [B, S]  int32/64
    block_size: int = 256,  # Reduced from 128 to 64 to avoid OOM
) -> BlockMask:
    """
    FlexAttention equivalent of

        causal_mask & (group_ids[q]==group_ids[kv]  |  parent_ids[kv]==group_ids[q])

    * group_ids : id shared by all tokens of the same sampled trajectory
    * parent_ids: id identifying the prompt that produced each token
    """
    B, S = group_ids.shape  # batch, sequence length

    # the closure captures the two id tensors; that's fine for torch.compile
    def mask_mod(b, h, q_idx, kv_idx):
        # causal constraint
        causal = kv_idx <= q_idx

        same_group = group_ids[b, q_idx] == group_ids[b, kv_idx]
        prompt_link = parent_ids[b, q_idx] == group_ids[b, kv_idx]

        return causal & (same_group | prompt_link)

    return create_block_mask(
        mask_mod,
        B=B,
        H=None,
        Q_LEN=S,
        KV_LEN=S,
        BLOCK_SIZE=block_size,
    )

# Block mask for the (truncated) micro-batch, using the default block size.
block_mask = make_block_mask(
    group_ids=inputs["group_ids"],
    parent_ids=inputs["parent_ids"],
)

In [None]:
def calculate_tensor_memory_gb(
    shape: tuple[int, ...],
    dtype: torch.dtype = torch.float32,
    binary_gb: bool = True,
) -> float:
    """
    Estimate how many gigabytes a tensor would occupy, without allocating it.

    Args:
        shape: Dimensions of the hypothetical tensor, e.g. (1024, 1024, 512).
            An empty tuple is treated as a single element.
        dtype: Element type used to determine the bytes per element.
        binary_gb: True divides by 1024^3 (GiB); False divides by 10^9 (GB).

    Returns:
        The tensor's size in gigabytes as a float.
    """
    # Number of elements is the product of all dimensions
    element_count = 1
    for extent in shape:
        element_count = element_count * extent

    # A zero-length tensor of the requested dtype reports the per-element byte width
    itemsize = torch.empty(0, dtype=dtype).element_size()

    # 1 GiB = 1024^3 bytes, 1 GB = 10^9 bytes
    bytes_per_gb = 1024**3 if binary_gb else 10**9

    return (element_count * itemsize) / bytes_per_gb

# Size in GiB of one bfloat16 [1, S, S] matrix for the current sequence length S —
# the same shape and dtype as the `attn_bias` tensor allocated in the next cell.
calculate_tensor_memory_gb((1, inputs["tokens"].shape[1], inputs["tokens"].shape[1]), dtype=torch.bfloat16)

In [None]:
# All-zero additive attention bias of shape [1, S, S] on the recipe's device. Because
# it is zero, adding it leaves the scores unchanged; `requires_grad=True` makes it a
# leaf so the gradient of the loss with respect to each bias entry can be taken later.
attn_bias = torch.zeros((1, inputs["tokens"].shape[1], inputs["tokens"].shape[1]), dtype=torch.bfloat16, device=recipe._device, requires_grad=True)
attn_bias

In [None]:
# Rebind the notebook-level `score_mod` (read by `compile_friendly_flex_attention` at
# call time) to add `attn_bias[b, q_idx, kv_idx]` to each attention score. The head
# index `h` is unused, so the same bias entry is applied for every head.
score_mod = lambda score, b, h, q_idx, kv_idx: score + attn_bias[b, q_idx, kv_idx]

In [None]:
# score_mods_list = [score_mod if i % 2 == 0 else None for i in range(len(recipe._model.layers))]
# score_mods = deque(score_mods_list)

In [None]:
# score_mods = deque([score_mod] + [None] * (len(recipe._model.layers) - 1))
# score_mods = deque([None] * (len(recipe._model.layers) - 1) + [score_mod])

In [None]:
# Forward pass through the recipe's model inside the recipe's activation-handling
# context, using the custom block mask.
# NOTE(review): the result is later passed through `recipe._model.output`, so it is
# presumably pre-projection hidden states rather than logits — confirm.
with recipe.activations_handling_ctx:
    hidden_states = recipe._model(
        tokens=inputs["tokens"],
        mask=block_mask,
        input_pos=inputs["input_pos"],
    )

# Drop the notebook's reference to the mask. Re-running this cell therefore requires
# re-running the cell that builds `block_mask` first.
del block_mask

In [None]:
from art.unsloth.train import shift_tensor
from typing import cast

# Shift the assistant mask and the token ids with `shift_tensor`, then keep only the
# masked positions.
# NOTE(review): presumably this pairs each kept position with the id of the token it
# should predict; the shift direction and fill semantics live in `shift_tensor`
# (not shown) — confirm.
assistant_mask = shift_tensor(inputs["assistant_mask"], False)
# Rebinds `hidden_states` to the masked rows only. Re-running this cell without
# re-running the forward pass would index the already-filtered tensor.
hidden_states = hidden_states[assistant_mask]
next_token_ids = shift_tensor(inputs["tokens"], 0)[assistant_mask]
chunk_size = batch.dev_config.get("logprob_calculation_chunk_size", 1024)
loss = torch.tensor(0.0, device=recipe._device)
all_new_logprobs = []
# Project to vocabulary logits one chunk of rows at a time, so that only
# `chunk_size` rows of logits are referenced by the notebook at once.
for i in range(0, hidden_states.size(0), chunk_size):
    chunk_end = min(i + chunk_size, hidden_states.size(0))
    # [chunk_size, hidden_size] @ [hidden_size, vocab_size]
    logits = cast(
        torch.Tensor, recipe._model.output(hidden_states[i:chunk_end,:])
    )  # [chunk_size, vocab_size]
    # Logit of the target token at each row
    selected_logits = torch.gather(
        logits, dim=-1, index=next_token_ids[i:chunk_end].unsqueeze(-1)
    ).squeeze(
        -1
    )  # [chunk_size]
    # log-softmax at the target token: logit minus log-sum-exp over the vocabulary
    logsumexp = torch.logsumexp(logits, dim=-1)  # [chunk_size]
    new_logprobs = selected_logits - logsumexp
    all_new_logprobs.append(new_logprobs)
    # The active "loss" is the plain sum of all logits; the commented lines are
    # alternative objectives that were tried (sum of log-probs, sum of probs).
    # loss += new_logprobs.sum()
    loss += logits.sum()
    # loss += torch.exp(new_logprobs).sum()
    # Only removes the loop-local names; the appended log-probs stay alive in the list.
    del logits, selected_logits, logsumexp, new_logprobs

# Per-token log-probabilities for all kept positions, in order
new_logprobs = torch.cat(all_new_logprobs, dim=0)
new_logprobs.shape

In [None]:
# # add L1 loss for attn_bias
# loss += torch.abs(attn_bias).sum() * 10

In [None]:
# import gc
# import torch

# for _ in range(3):
#     gc.collect()
#     torch.cuda.empty_cache()

In [None]:
# score_mods = deque([None] * (len(recipe._model.layers) - 1) + [score_mod])
# score_mods = deque(reversed(score_mods_list))
# Gradient of the loss with respect to the additive attention bias, returned as a
# 1-tuple holding a tensor shaped like `attn_bias`.
# `retain_graph=True` keeps the autograd graph alive: by default it is freed after this
# call, which would make the later `loss.backward()` cell fail when the notebook is run
# top to bottom. The retained graph costs extra memory until `loss` is released.
result = torch.autograd.grad(loss, attn_bias, retain_graph=True)

In [None]:
# Sum the [S, S] bias gradient of batch element 0 over dim 1. `attn_bias` is indexed
# as [b, q_idx, kv_idx], so this sums over key/value positions and leaves one value
# per query position.
result[0][0].sum(dim=1)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Gradient matrix for batch element 0. `attn_bias` is indexed as [b, q_idx, kv_idx], so
# rows are query positions and columns are key/value positions.
# Upcast to float32 (lossless from bfloat16) after moving to the CPU; float16 has a
# much smaller range and would turn large gradients into inf.
grad_matrix = result[0][0].cpu().to(torch.float32).numpy()

# Subsample the matrix for faster visualization
max_size = 1000  # Limit to 1000x1000 for performance
if grad_matrix.shape[0] > max_size or grad_matrix.shape[1] > max_size:
    # Ceiling division so the subsampled matrix never exceeds max_size along an axis
    # (floor division would leave e.g. a 1999-row matrix untouched).
    step_y = max(1, -(-grad_matrix.shape[0] // max_size))
    step_x = max(1, -(-grad_matrix.shape[1] // max_size))
    grad_matrix_viz = grad_matrix[::step_y, ::step_x]
    print(f"Subsampled from {grad_matrix.shape} to {grad_matrix_viz.shape} for visualization")
else:
    grad_matrix_viz = grad_matrix

# Create a heatmap with appropriate sizing
plt.figure(figsize=(12, 8))

# Use rasterized rendering for better performance with large matrices
im = plt.imshow(grad_matrix_viz, cmap='RdBu_r', aspect='auto', interpolation='nearest', rasterized=True)

# Add colorbar
plt.colorbar(im, label='Gradient Value')

# Axis labels follow the [q_idx, kv_idx] layout of the bias matrix
plt.xlabel('Key/Value Position')
plt.ylabel('Query Position')
plt.title(f'Attention Bias Gradients Heatmap\nOriginal Shape: {grad_matrix.shape}')

# Make it more readable
plt.tight_layout()
plt.show()

In [None]:
import torch

# Replace +inf entries of the bias gradient with the largest *finite* entry.
# Taking `.max()` over the whole tensor would return +inf whenever a +inf is present,
# which would make the replacement a no-op.
torch.where(
    torch.logical_and(torch.isinf(result[0][0]), result[0][0] > 0),
    result[0][0][torch.isfinite(result[0][0])].max(),
    result[0][0],
)

In [None]:
import polars as pl

# Flatten the bias gradient of batch element 0 into a one-column polars frame and keep
# only the non-zero entries.
pl.Series(result[0][0].reshape(-1).to(torch.float32).cpu().numpy()).to_frame(
    "value"
).filter(pl.col("value") != 0)

In [None]:
import pandas as pd

# Plot the absolute per-query gradient total (bias gradient summed over key/value
# positions). Converted via float32, as in the other cells: summed gradients can exceed
# the float16 range and would show up as inf.
pd.Series(torch.abs(result[0][0].sum(dim=1)).to(torch.float32).cpu().numpy()).plot()

In [None]:
from illustrate import illustrate

# Hand-noted negative offsets of interesting tokens in this sample:
# -44: Professor Plum
# -39: Revolver
# -34: Gazebo
# -18: Ambition
# -13: Betrayal
# -7: Conservatory
# -2: Lounge

# Colour each decoded token by one row of the bias gradient: row -22 of the [S, S]
# matrix, i.e. the gradient entries for the query at offset -22 across all key/value
# positions.
# NOTE(review): -22 is not among the offsets listed above — confirm which token it
# was meant to target.
print(
    illustrate(
        list(
            zip(
                list(tokenizer.decode(token_id) for token_id in inputs["tokens"][0]),
                result[0][0][-22].to(torch.float32).tolist(),
                # result[0][0].sum(dim=1).squeeze(0).tolist(),
            )
        ),
        gradient="one-dark-simple",
    )
)

In [None]:
from illustrate import illustrate

# Colour each decoded token by its model probability.
tokens = [tokenizer.decode(token_id) for token_id in inputs["tokens"][0]]
# Despite the name, `logprobs` holds probabilities: exp of the log-probabilities.
logprobs = torch.exp(new_logprobs).tolist()
print(
    illustrate(
        list(
            zip(
                tokens,
                # Left-pad with zeros up to the sequence length.
                # NOTE(review): this assumes the assistant-masked positions are the
                # trailing tokens and that values line up with the token they score —
                # confirm against the assistant mask and the shift applied earlier.
                [0] * (len(tokens) - len(logprobs)) + logprobs,
            )
        ),
        gradient="one-dark-simple",
    )
)

In [None]:
# Display the concatenated per-token log-probabilities computed in the chunked loop.
new_logprobs

In [None]:
# Third-from-last entry of the zero-padded probability list built for the illustration.
([0] * (len(tokens) - len(logprobs)) + logprobs)[-3]

In [None]:
# Sum the full [1, S, S] bias gradient over dim 1 (the query axis), giving one total per
# key/value position — a different axis from the earlier `result[0][0].sum(dim=1)`,
# which sums over key/value positions. The whole list is displayed.
result[0].sum(dim=1).squeeze(0).tolist()

In [None]:
# Decoded text of the last 100 tokens of the first sequence.
[tokenizer.decode(token_id) for token_id in inputs["tokens"][0]][-100:]

In [None]:
# Backpropagate the loss so gradients accumulate on every leaf that requires grad
# (including `attn_bias`).
# NOTE(review): an earlier cell already differentiates `loss` with
# `torch.autograd.grad`; confirm the autograd graph is still available here.
loss.backward()

In [None]:
# Apply one optimizer update using the gradients from the backward pass, then clear
# the gradients the optimizer manages.
recipe._optimizer.step()
recipe._optimizer.zero_grad()

In [None]:
# Shape of `hidden_states`. After the log-prob cell has run, this is the
# assistant-masked subset rather than the full forward-pass output.
hidden_states.shape

In [None]:
from transformers import AutoTokenizer

# Reloads the same tokenizer that an earlier cell already bound to `tokenizer`;
# kept as a convenience for re-running the lower half of the notebook on its own.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-14B")

In [None]:
# Display the model's layer container for inspection.
recipe._model.layers