In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%%html
<style>
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [3]:
from collections import deque
import torch
from torch.nn.attention import flex_attention
from torchtune.modules import attention_utils

score_mod: flex_attention._score_mod_signature | None = None
# score_mods: deque[flex_attention._score_mod_signature | None] = deque()


# We cannot do nested compile, but flex attention only has perf benefits
# when compiled. To insulate it from the compiler, we wrap it with
# compiler.disable so that it can be used regardless of whether the model
# is compiled or not, and flex attention always remains compiled.
@torch.compiler.disable(recursive=False)
def compile_friendly_flex_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_mask: flex_attention.BlockMask,
) -> torch.Tensor:
    return attention_utils.flex_attention_compiled(
        q,
        k,
        v,
        score_mod=score_mod,
        block_mask=block_mask,
        kernel_options={
            # "BLOCK_M": 64,
            # "BLOCK_N": 64,
            # "BLOCK_M1": 64,
            # "BLOCK_N1": 64,
            "BLOCK_M2": 64,
            "BLOCK_N2": 64,
        },
    )  # type: ignore


attention_utils.compile_friendly_flex_attention = compile_friendly_flex_attention

In [4]:
import art
from art.torchtune.config import (
    ModelConfig,
    MetricLoggerConfig,
    OptimizerConfig,
    CheckpointerConfig,
    RecipeConfig,
)
from art.torchtune.recipe import FullFinetuneRecipeDistributed
import asyncio
import glob
import os

process = await asyncio.subprocess.create_subprocess_exec(
    "huggingface-cli",
    "download",
    "Qwen/Qwen3-14B",
    stdout=asyncio.subprocess.PIPE,
    stderr=asyncio.subprocess.PIPE,
)
stdout, _ = await process.communicate()
checkpoint_dir = stdout.decode("utf-8").splitlines()[-1].strip()
safetensor_files = glob.glob(f"{checkpoint_dir}/*.safetensors")
checkpoint_files = sorted(os.path.basename(f) for f in safetensor_files)
print("checkpoint_dir", checkpoint_dir)
print("checkpoint_files", checkpoint_files)

cfg = RecipeConfig(
    model=ModelConfig(_component_="torchtune.models.qwen3.qwen3_14b_instruct"),
    metric_logger=MetricLoggerConfig(
        _component_="torchtune.training.metric_logging.StdoutLogger"
    ),
    optimizer=OptimizerConfig(_component_="torch.optim.AdamW"),
    checkpointer=CheckpointerConfig(
        _component_="torchtune.training.FullModelHFCheckpointer",  # type: ignore
        model_type="QWEN3",
        checkpoint_dir=checkpoint_dir,  # type: ignore
        checkpoint_files=checkpoint_files,  # type: ignore
    ),
    output_dir=os.path.abspath("../.art/temporal-clue/models/002"),
    enable_activation_checkpointing=True,
    enable_activation_offloading=True
)

recipe = FullFinetuneRecipeDistributed(cfg=cfg)  # type: ignore
recipe.setup(cfg=cfg)

checkpoint_dir /home/ubuntu/.cache/huggingface/hub/models--Qwen--Qwen3-14B/snapshots/8268fe3026cb304910457689366670e803a6fd56
checkpoint_files ['model-00001-of-00008.safetensors', 'model-00002-of-00008.safetensors', 'model-00003-of-00008.safetensors', 'model-00004-of-00008.safetensors', 'model-00005-of-00008.safetensors', 'model-00006-of-00008.safetensors', 'model-00007-of-00008.safetensors', 'model-00008-of-00008.safetensors']


INFO:torchtune.utils._logging:Distributed training is enabled. Instantiating model and loading checkpoint on Rank 0 ...
INFO:torchtune.utils._logging:Instantiating model and loading checkpoint took 4.06 secs
INFO:torchtune.utils._logging:Memory stats after model init:
	GPU peak memory active: 27.66 GiB
	GPU peak memory alloc: 27.66 GiB
	GPU peak memory reserved: 27.66 GiB
INFO:torchtune.utils._logging:Optimizer is initialized.


In [5]:
micro_batches, batch = recipe._get_micro_batches(curr_epoch=0)
inputs = micro_batches[0]

In [6]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-14B")

In [None]:
print("".join(tokenizer.decode(token_id) for token_id in inputs["tokens"][0]))

In [8]:
end_index = inputs["tokens"][0].tolist().index(151645, 3190)
end_index

22470

In [9]:
for key in inputs:
    inputs[key] = inputs[key][:, :end_index]

In [10]:
dict(zip(range(-101, -1), [tokenizer.decode(token_id) for token_id in inputs["tokens"][0]][-100:]))

{-101: 'Studio',
 -100: ' (',
 -99: '0',
 -98: '3',
 -97: ')**',
 -96: ' at',
 -95: ' **',
 -94: '1',
 -93: '1',
 -92: ':',
 -91: '3',
 -90: '0',
 -89: ' PM',
 -88: '**',
 -87: ').\n',
 -86: '-',
 -85: ' So',
 -84: ',',
 -83: ' **',
 -82: 'Studio',
 -81: ' (',
 -80: '0',
 -79: '3',
 -78: ')**',
 -77: ' is',
 -76: ' the',
 -75: ' correct',
 -74: ' room',
 -73: ',',
 -72: ' assuming',
 -71: ' Colonel',
 -70: ' Must',
 -69: 'ard',
 -68: ' moved',
 -67: ' from',
 -66: ' the',
 -65: ' **',
 -64: 'G',
 -63: 'aze',
 -62: 'bo',
 -61: ' (',
 -60: '0',
 -59: '4',
 -58: ')**',
 -57: ' to',
 -56: ' the',
 -55: ' **',
 -54: 'Studio',
 -53: ' (',
 -52: '0',
 -51: '3',
 -50: ')**',
 -49: ' by',
 -48: ' **',
 -47: '1',
 -46: '1',
 -45: ':',
 -44: '0',
 -43: '0',
 -42: ' PM',
 -41: '**',
 -40: '.\n\n',
 -39: '---\n\n',
 -38: '###',
 -37: ' Final',
 -36: ' Answers',
 -35: ':\n',
 -34: 'A',
 -33: '.',
 -32: ' Miss',
 -31: ' Scarlet',
 -30: '  \n',
 -29: 'B',
 -28: '.',
 -27: ' Rev',
 -26: 'olver',
 -25: 

In [11]:
from torchtune.utils import batch_to_device

batch_to_device(inputs, device=recipe._device) # type: ignore

In [12]:
import torch
from torch.nn.attention.flex_attention import create_block_mask, BlockMask

def make_block_mask(
    group_ids: torch.Tensor,  # [B, S]  int32/64
    parent_ids: torch.Tensor,  # [B, S]  int32/64
    block_size: int = 256,  # Reduced from 128 to 64 to avoid OOM
) -> BlockMask:
    """
    FlexAttention equivalent of

        causal_mask & (group_ids[q]==group_ids[kv]  |  parent_ids[kv]==group_ids[q])

    * group_ids : id shared by all tokens of the same sampled trajectory
    * parent_ids: id identifying the prompt that produced each token
    """
    B, S = group_ids.shape  # batch, sequence length

    # the closure captures the two id tensors; that's fine for torch.compile
    def mask_mod(b, h, q_idx, kv_idx):
        # causal constraint
        causal = kv_idx <= q_idx

        same_group = group_ids[b, q_idx] == group_ids[b, kv_idx]
        prompt_link = parent_ids[b, q_idx] == group_ids[b, kv_idx]

        return causal & (same_group | prompt_link)

    return create_block_mask(
        mask_mod,
        B=B,
        H=None,
        Q_LEN=S,
        KV_LEN=S,
        BLOCK_SIZE=block_size,
    )

block_mask = make_block_mask(
    group_ids=inputs["group_ids"],
    parent_ids=inputs["parent_ids"],
)

In [13]:
def calculate_tensor_memory_gb(
    shape: tuple[int, ...],
    dtype: torch.dtype = torch.float32,
    binary_gb: bool = True,
) -> float:
    """
    Calculate the memory usage of a tensor in gigabytes without creating it.

    Args:
        shape: The shape/dimensions of the tensor (e.g., (1024, 1024, 512))
        dtype: The data type of the tensor (e.g., torch.float32, torch.float16)
        binary_gb: If True, use binary GB (1024^3), if False use decimal GB (10^9)

    Returns:
        Memory usage in gigabytes
    """
    # Calculate total number of elements
    total_elements = 1
    for dim in shape:
        total_elements *= dim

    # Get bytes per element based on dtype
    bytes_per_element = torch.tensor([], dtype=dtype).element_size()

    # Calculate total bytes
    total_bytes = total_elements * bytes_per_element

    # Convert to GB
    if binary_gb:
        gb_divisor = 1024**3  # 1 GiB = 1024^3 bytes
    else:
        gb_divisor = 10**9  # 1 GB = 10^9 bytes

    return total_bytes / gb_divisor

calculate_tensor_memory_gb((1, inputs["tokens"].shape[1], inputs["tokens"].shape[1]), dtype=torch.bfloat16)

0.9404512122273445

In [14]:
attn_bias = torch.zeros((1, inputs["tokens"].shape[1], inputs["tokens"].shape[1]), dtype=torch.bfloat16, device=recipe._device, requires_grad=True)
attn_bias

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0',
       dtype=torch.bfloat16, requires_grad=True)

In [15]:
score_mod = lambda score, b, h, q_idx, kv_idx: score + attn_bias[b, q_idx, kv_idx]

In [16]:
# score_mods_list = [score_mod if i % 2 == 0 else None for i in range(len(recipe._model.layers))]
# score_mods = deque(score_mods_list)

In [17]:
# score_mods = deque([score_mod] + [None] * (len(recipe._model.layers) - 1))
# score_mods = deque([None] * (len(recipe._model.layers) - 1) + [score_mod])

In [18]:
with recipe.activations_handling_ctx:
    hidden_states = recipe._model(
        tokens=inputs["tokens"],
        mask=block_mask,
        input_pos=inputs["input_pos"],
    )

del block_mask

DEBUG:torchtune.utils._logging:Using flex attention for attention computation since a BlockMask was passed in.


In [19]:
from art.unsloth.train import shift_tensor
from typing import cast

assistant_mask = shift_tensor(inputs["assistant_mask"], False)
hidden_states = hidden_states[assistant_mask]
next_token_ids = shift_tensor(inputs["tokens"], 0)[assistant_mask]
chunk_size = batch.dev_config.get("logprob_calculation_chunk_size", 1024)
loss = torch.tensor(0.0, device=recipe._device)
all_new_logprobs = []
for i in range(0, hidden_states.size(0), chunk_size):
    chunk_end = min(i + chunk_size, hidden_states.size(0))
    # [chunk_size, hidden_size] @ [hidden_size, vocab_size]
    logits = cast(
        torch.Tensor, recipe._model.output(hidden_states[i:chunk_end,:])
    )  # [chunk_size, vocab_size]
    selected_logits = torch.gather(
        logits, dim=-1, index=next_token_ids[i:chunk_end].unsqueeze(-1)
    ).squeeze(
        -1
    )  # [chunk_size]
    logsumexp = torch.logsumexp(logits, dim=-1)  # [chunk_size]
    new_logprobs = selected_logits - logsumexp
    all_new_logprobs.append(new_logprobs)
    # loss += new_logprobs.sum()
    loss += logits.sum()
    # loss += torch.exp(new_logprobs).sum()
    del logits, selected_logits, logsumexp, new_logprobs

new_logprobs = torch.cat(all_new_logprobs, dim=0)
new_logprobs.shape

INFO 07-21 18:49:51 [__init__.py:244] Automatically detected platform cuda.


torch.Size([19276])

In [20]:
# add L1 loss for attn_bias
loss += torch.abs(attn_bias).sum() * 10

In [21]:
import gc
import torch

for _ in range(3):
    gc.collect()
    torch.cuda.empty_cache()

In [22]:
# score_mods = deque([None] * (len(recipe._model.layers) - 1) + [score_mod])
# score_mods = deque(reversed(score_mods_list))
result = torch.autograd.grad(loss, attn_bias)

In [25]:
from illustrate import illustrate

# -44: Professor Plum
# -39: Revolver
# -34: Gazebo
# -18: Ambition
# -13: Betrayal
# -7: Conservatory
# -2: Lounge

print(
    illustrate(
        list(
            zip(
                list(tokenizer.decode(token_id) for token_id in inputs["tokens"][0]),
                (result[0][0][-22]).tolist(),
                # result[0].sum(dim=1).squeeze(0).tolist(),
            )
        ),
        gradient="one-dark-simple",
    )
)

[90m╔═════════════════════════════════════════════════════════════╗[0m
[90m║ Stats: scale=20352.000, min=-20352.000, max=15744.000, mean=-0.004 ║[0m
[90m╚═════════════════════════════════════════════════════════════╝[0m

[38;2;198;71;66m<|im_start|>[0m[38;2;171;176;184muser[0m[38;2;164;174;188m
[0m[38;2;168;170;184mOn[0m[38;2;171;175;184m a[0m[38;2;171;176;185m dark[0m[38;2;171;176;185m winter[0m[38;2;172;176;184m night[0m[38;2;171;175;184m,[0m[38;2;171;175;184m wealthy[0m[38;2;171;175;184m and[0m[38;2;171;175;184m en[0m[38;2;172;176;184migmatic[0m[38;2;170;175;185m Mr[0m[38;2;170;173;184m.[0m[38;2;171;176;184m John[0m[38;2;170;174;184m Q[0m[38;2;172;176;184m.[0m[38;2;167;162;183m Bod[0m[38;2;171;176;184mdy[0m[38;2;172;176;184m hosted[0m[38;2;171;175;184m a[0m[38;2;171;176;185m small[0m[38;2;171;176;184m,[0m[38;2;172;176;184m but[0m[38;2;169;172;184m lavish[0m[38;2;171;176;184m,[0m[38;2;168;175;186m dinner[0m[38;2;171;176;1

In [None]:
from illustrate import illustrate

tokens = [tokenizer.decode(token_id) for token_id in inputs["tokens"][0]]
logprobs = torch.exp(new_logprobs).tolist()
print(
    illustrate(
        list(
            zip(
                tokens,
                [0] * (len(tokens) - len(logprobs)) + logprobs,
            )
        ),
        gradient="one-dark-simple",
    )
)

In [None]:
new_logprobs

In [None]:
([0] * (len(tokens) - len(logprobs)) + logprobs)[-3]

In [None]:
result[0].sum(dim=1).squeeze(0).tolist()

In [None]:
[tokenizer.decode(token_id) for token_id in inputs["tokens"][0]][-100:]

In [None]:
loss.backward()

In [None]:
recipe._optimizer.step()
recipe._optimizer.zero_grad()

In [None]:
hidden_states.shape

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-14B")

In [None]:
recipe._model.layers