In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%%html
<style>
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [3]:
import torch
from torch.nn.attention import flex_attention
from torchtune.modules import attention_utils


# Score modification handed to flex attention on every call; `None` disables it.
# Later cells rebind this module-level name to change what attention uses.
score_mod: flex_attention._score_mod_signature | None = None


# We cannot do nested compile, but flex attention only has perf benefits
# when compiled. To insulate it from the compiler, we wrap it with
# compiler.disable so that it can be used regardless of whether the model
# is compiled or not, and flex attention always remains compiled.
@torch.compiler.disable(recursive=False)
def compile_friendly_flex_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_mask: flex_attention.BlockMask,
) -> torch.Tensor:
    """Call torchtune's compiled flex attention with the notebook's current ``score_mod``.

    ``score_mod`` is looked up as a module global at call time (not captured at
    definition time), so reassigning it in a later cell takes effect immediately.
    The active value is printed on every call as a debugging aid.
    """
    active_score_mod = score_mod
    print("Using score_mod:", active_score_mod)
    flex_kwargs = {"score_mod": active_score_mod, "block_mask": block_mask}
    return attention_utils.flex_attention_compiled(q, k, v, **flex_kwargs)  # type: ignore


# Replace torchtune's helper with the wrapper above so that code looking up
# `attention_utils.compile_friendly_flex_attention` at call time uses it.
setattr(attention_utils, "compile_friendly_flex_attention", compile_friendly_flex_attention)

In [4]:
import art
from art.torchtune.config import (
    ModelConfig,
    MetricLoggerConfig,
    OptimizerConfig,
    CheckpointerConfig,
    RecipeConfig,
)
from art.torchtune.recipe import FullFinetuneRecipeDistributed
import asyncio
import glob
import os

process = await asyncio.subprocess.create_subprocess_exec(
    "huggingface-cli",
    "download",
    "Qwen/Qwen3-14B",
    stdout=asyncio.subprocess.PIPE,
    stderr=asyncio.subprocess.PIPE,
)
stdout, _ = await process.communicate()
checkpoint_dir = stdout.decode("utf-8").splitlines()[-1].strip()
safetensor_files = glob.glob(f"{checkpoint_dir}/*.safetensors")
checkpoint_files = sorted(os.path.basename(f) for f in safetensor_files)
print("checkpoint_dir", checkpoint_dir)
print("checkpoint_files", checkpoint_files)

cfg = RecipeConfig(
    model=ModelConfig(_component_="torchtune.models.qwen3.qwen3_14b_instruct"),
    metric_logger=MetricLoggerConfig(
        _component_="torchtune.training.metric_logging.StdoutLogger"
    ),
    optimizer=OptimizerConfig(_component_="torch.optim.AdamW"),
    checkpointer=CheckpointerConfig(
        _component_="torchtune.training.FullModelHFCheckpointer",  # type: ignore
        model_type="QWEN3",
        checkpoint_dir=checkpoint_dir,  # type: ignore
        checkpoint_files=checkpoint_files,  # type: ignore
    ),
    output_dir=os.path.abspath("../.art/temporal-clue/models/002"),
    enable_activation_checkpointing=True,
    enable_activation_offloading=True
)

recipe = FullFinetuneRecipeDistributed(cfg=cfg)  # type: ignore
recipe.setup(cfg=cfg)

checkpoint_dir /home/ubuntu/.cache/huggingface/hub/models--Qwen--Qwen3-14B/snapshots/8268fe3026cb304910457689366670e803a6fd56
checkpoint_files ['model-00001-of-00008.safetensors', 'model-00002-of-00008.safetensors', 'model-00003-of-00008.safetensors', 'model-00004-of-00008.safetensors', 'model-00005-of-00008.safetensors', 'model-00006-of-00008.safetensors', 'model-00007-of-00008.safetensors', 'model-00008-of-00008.safetensors']


INFO:torchtune.utils._logging:Distributed training is enabled. Instantiating model and loading checkpoint on Rank 0 ...
INFO:torchtune.utils._logging:Instantiating model and loading checkpoint took 0.24 secs
INFO:torchtune.utils._logging:Memory stats after model init:
	GPU peak memory active: 27.66 GiB
	GPU peak memory alloc: 27.66 GiB
	GPU peak memory reserved: 27.66 GiB
INFO:torchtune.utils._logging:Optimizer is initialized.


In [5]:
# Fetch the epoch-0 micro-batches through the recipe's private API and work on the
# first one; `batch` is kept for its `dev_config` in the log-prob cell.
micro_batches, batch = recipe._get_micro_batches(curr_epoch=0)
inputs = micro_batches[0]

In [6]:
from transformers import AutoTokenizer

# Tokenizer for the same base model; used below only to decode token ids for inspection.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-14B")

In [7]:
# Decode the whole sequence in one call. Decoding token-by-token and joining costs one
# `decode` call per token and can garble characters whose bytes span several tokens.
print(tokenizer.decode(inputs["tokens"][0].tolist()))

<|im_start|>user
On a dark winter night, wealthy and enigmatic Mr. John Q. Boddy hosted a small, but lavish, dinner party for some of his closest associates. However, the night ended in tragedy when Mr. Boddy was found dead in one of the rooms of Tudor Mansion in the early hours of the morning. The following persons of interest have been identified as suspects:

• Miss Peach
• Mr. Green
• Mrs. Peacock
• Mrs. White
• Monsieur Brunette
• Miss Scarlet
• Professor Plum
• Colonel Mustard

And the following weapons were found on the premises:

• Poison
• Wrench
• Knife
• Rope
• Horseshoe
• Revolver
• Lead Pipe
• Candlestick

The murder could only have occured in one of the following rooms:

01. Study
02. Carriage House
03. Conservatory
04. Lounge
05. Hall
06. Courtyard
07. Billiard Room
08. Cloak Room
09. Gazebo
10. Kitchen

The rooms are laid out as follows:

  NN NN NN  
W 01|02|03 E
W 04|05|06 E
W 07|08|09 E
W 10|-|- E
  SS SS SS  

The exact time of the murder is a bit uncertain, but it 

In [9]:
# Chat turns in this data end with `<|im_end|>` (ChatML); look the id up rather than
# hard-coding it (151645 for the Qwen3 tokenizer).
IM_END_TOKEN_ID = tokenizer.convert_tokens_to_ids("<|im_end|>")
# NOTE(review): hand-picked offset, presumably chosen to skip the `<|im_end|>` that
# closes the prompt turn — confirm before reusing on other samples.
END_SEARCH_START = 1910
end_index = inputs["tokens"][0].tolist().index(IM_END_TOKEN_ID, END_SEARCH_START)
end_index

24265

In [10]:
# Cut every entry of the micro-batch back to its first `end_index` positions along
# dim 1, mutating `inputs` in place (so `micro_batches[0]` sees it too).
inputs.update({name: tensor[:, :end_index] for name, tensor in inputs.items()})

In [11]:
from torchtune.utils import batch_to_device

# Move the micro-batch to the recipe's device. The return value is discarded, so this
# relies on `batch_to_device` updating `inputs` in place — TODO confirm.
batch_to_device(inputs, device=recipe._device) # type: ignore

In [12]:
import torch
from torch.nn.attention.flex_attention import create_block_mask, BlockMask

def make_block_mask(
    group_ids: torch.Tensor,  # [B, S]  int32/64
    parent_ids: torch.Tensor,  # [B, S]  int32/64
    block_size: int = 128,  # Reduced from 128 to 64 to avoid OOM
) -> BlockMask:
    """
    FlexAttention equivalent of

        causal_mask & (group_ids[q]==group_ids[kv]  |  parent_ids[kv]==group_ids[q])

    * group_ids : id shared by all tokens of the same sampled trajectory
    * parent_ids: id identifying the prompt that produced each token
    """
    B, S = group_ids.shape  # batch, sequence length

    # the closure captures the two id tensors; that's fine for torch.compile
    def mask_mod(b, h, q_idx, kv_idx):
        # causal constraint
        causal = kv_idx <= q_idx

        same_group = group_ids[b, q_idx] == group_ids[b, kv_idx]
        prompt_link = parent_ids[b, q_idx] == group_ids[b, kv_idx]

        return causal & (same_group | prompt_link)

    return create_block_mask(
        mask_mod,
        B=B,
        H=None,
        Q_LEN=S,
        KV_LEN=S,
        BLOCK_SIZE=block_size,
    )

# Build the trajectory/prompt-aware causal mask for the truncated micro-batch.
block_mask = make_block_mask(
    group_ids=inputs["group_ids"],
    parent_ids=inputs["parent_ids"],
)

In [13]:
# (batch, seq_len) of the truncated micro-batch.
inputs["tokens"].shape

torch.Size([1, 24265])

In [14]:
def calculate_tensor_memory_gb(
    shape: tuple[int, ...],
    dtype: torch.dtype = torch.float32,
    binary_gb: bool = True,
) -> float:
    """
    Calculate the memory usage of a tensor in gigabytes without creating it.

    Only the element payload is counted (number of elements * element size);
    allocator rounding and tensor metadata are not included.

    Args:
        shape: The shape/dimensions of the tensor (e.g., (1024, 1024, 512));
            an empty shape () describes a scalar (one element)
        dtype: The data type of the tensor (e.g., torch.float32, torch.float16)
        binary_gb: If True, use binary GB (1024^3), if False use decimal GB (10^9)

    Returns:
        Memory usage in gigabytes

    Raises:
        ValueError: If any dimension is negative.
    """
    import math

    if any(dim < 0 for dim in shape):
        raise ValueError(f"Tensor dimensions must be non-negative, got {tuple(shape)}")

    total_elements = math.prod(shape)  # math.prod(()) == 1, i.e. a scalar
    # An empty tensor is a cheap way to ask torch for the element size of any dtype.
    bytes_per_element = torch.tensor([], dtype=dtype).element_size()
    total_bytes = total_elements * bytes_per_element

    # 1 GiB = 1024^3 bytes; 1 GB = 10^9 bytes
    gb_divisor = 1024**3 if binary_gb else 10**9
    return total_bytes / gb_divisor

# Memory for a dense (batch, seq_len, seq_len) bf16 bias, sized from the truncated
# batch instead of a hard-coded sequence length.
batch_size, seq_len = inputs["tokens"].shape
calculate_tensor_memory_gb((batch_size, seq_len, seq_len), dtype=torch.bfloat16)

1.0967072565108538

In [15]:
# Dense additive attention bias with one entry per (query, key) position, created with
# requires_grad so its gradient can be read after the backward pass. Sized from the
# truncated batch rather than a hard-coded sequence length.
batch_size, seq_len = inputs["tokens"].shape
attn_bias = torch.zeros(
    (batch_size, seq_len, seq_len),
    dtype=torch.bfloat16,
    device=recipe._device,
    requires_grad=True,
)
attn_bias

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0',
       dtype=torch.bfloat16, requires_grad=True)

In [16]:
def score_mod(score, b, h, q_idx, kv_idx):
    """Add the dense ``attn_bias`` entry for (batch, query, key) to the attention score.

    The head index ``h`` is unused, so the same bias applies to every head. Rebinding
    the module-level ``score_mod`` makes the flex-attention wrapper pick this up.
    """
    return score + attn_bias[b, q_idx, kv_idx]

In [17]:
# Forward pass with the BlockMask, which routes attention through flex attention
# (and therefore through the current `score_mod`). The result is presumably the
# pre-output-projection hidden states: the next cell applies `recipe._model.output`.
with recipe.activations_handling_ctx:
    hidden_states = recipe._model(
        tokens=inputs["tokens"],
        mask=block_mask,
        input_pos=inputs["input_pos"],
    )

DEBUG:torchtune.utils._logging:Using flex attention for attention computation since a BlockMask was passed in.


Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using scor

In [18]:
from art.unsloth.train import shift_tensor
from typing import cast

# Sum of the model's log-probabilities for the assistant tokens. It is named `loss`
# for the autograd calls below, but it is a log-likelihood (not negated).
# `shift_tensor` presumably shifts left by one position so position t lines up with
# token t+1, padding with the given fill value — TODO confirm.
assistant_mask = shift_tensor(inputs["assistant_mask"], False)
# Use a new name so the full `hidden_states` survives and this cell can be re-run.
assistant_hidden_states = hidden_states[assistant_mask]
next_token_ids = shift_tensor(inputs["tokens"], 0)[assistant_mask]
chunk_size = batch.dev_config.get("logprob_calculation_chunk_size", 1024)
num_tokens = assistant_hidden_states.size(0)
loss = torch.tensor(0.0, device=recipe._device)
# Project to the vocabulary in chunks to bound the size of each logits tensor.
for chunk_start in range(0, num_tokens, chunk_size):
    chunk_end = min(chunk_start + chunk_size, num_tokens)
    # [chunk_size, hidden_size] @ [hidden_size, vocab_size]
    logits = cast(
        torch.Tensor, recipe._model.output(assistant_hidden_states[chunk_start:chunk_end])
    )  # [chunk_size, vocab_size]
    selected_logits = torch.gather(
        logits, dim=-1, index=next_token_ids[chunk_start:chunk_end].unsqueeze(-1)
    ).squeeze(-1)  # [chunk_size]
    logsumexp = torch.logsumexp(logits, dim=-1)  # [chunk_size]
    new_logprobs = selected_logits - logsumexp  # log-softmax at the target token
    loss += new_logprobs.sum()
    del logits, selected_logits, logsumexp, new_logprobs

INFO 07-19 03:15:51 [__init__.py:244] Automatically detected platform cuda.


In [19]:
# Gradient of the summed assistant log-probs w.r.t. the dense attention bias; returns a
# tuple with one gradient per input. `retain_graph` is not set, so autograd frees the
# graph after this call — re-run the forward/loss cells before any further backward.
result = torch.autograd.grad(loss, attn_bias)

Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using score_mod: <function <lambda> at 0x72cfd8b7d120>
Using scor

In [26]:
# Gradient for `attn_bias` (the only input passed to autograd.grad).
result[0]

tensor([[[ 2.6822e-07,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-6.0312e+00,  6.1250e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-6.5625e+00, -1.2688e+01,  1.9125e+01,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         ...,
         [ 9.5367e-05,  1.1349e-04, -2.1458e-05,  ...,  2.5988e-05,
           0.0000e+00,  0.0000e+00],
         [ 1.0872e-04, -4.7922e-05, -4.3392e-05,  ...,  7.0095e-05,
          -9.2983e-05,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]]], device='cuda:0', dtype=torch.bfloat16)

In [42]:
# Std over the whole (seq, seq) gradient of batch entry 0, in bfloat16. This includes
# the structurally-zero entries seen in the output above (e.g. the upper triangle),
# which pull the value toward 0.
result[0][0].std()

tensor(0.0023, device='cuda:0', dtype=torch.bfloat16)

In [28]:
# Decode only the last 100 tokens instead of decoding the whole sequence and slicing.
[tokenizer.decode(token_id) for token_id in inputs["tokens"][0, -100:].tolist()]

[' **',
 'H',
 '.',
 ' Lounge',
 '**',
 ' is',
 ' where',
 ' the',
 ' Poison',
 ' was',
 ' at',
 ' ',
 '1',
 '2',
 ':',
 '3',
 '0',
 ' AM',
 ',',
 ' based',
 ' on',
 ' the',
 ' clue',
 ' about',
 ' it',
 ' being',
 ' north',
 ' of',
 ' Mrs',
 '.',
 ' Pe',
 'acock',
 '’s',
 ' room',
 ' (',
 'Bill',
 'iard',
 ' Room',
 ' at',
 ' ',
 '1',
 '2',
 ':',
 '0',
 '0',
 ' AM',
 ').',
 '  \n\n',
 '---',
 '  \n',
 '**',
 'Final',
 ' Answers',
 ':**',
 '  \n',
 'A',
 '.',
 ' Professor',
 ' Plum',
 '  \n',
 'B',
 '.',
 ' Rev',
 'olver',
 '  \n',
 'C',
 '.',
 ' G',
 'aze',
 'bo',
 '  \n',
 'D',
 '.',
 ' ',
 '1',
 '2',
 ':',
 '3',
 '0',
 ' AM',
 '  \n',
 'E',
 '.',
 ' Amb',
 'ition',
 '  \n',
 'F',
 '.',
 ' Bet',
 'ray',
 'al',
 '  \n',
 'G',
 '.',
 ' Conserv',
 'atory',
 '  \n',
 'H',
 '.',
 ' Lounge']

In [None]:
# NOTE(review): `torch.autograd.grad(loss, attn_bias)` above ran without
# retain_graph=True, so this presumably raises ("backward through the graph a second
# time") unless the forward and loss cells are re-run first.
loss.backward()

In [None]:
# Reset parameter gradients left by `loss.backward()` before the next experiment.
recipe._optimizer.zero_grad()

In [None]:
import torch  # NOTE(review): already imported in an earlier cell; redundant here

# Return cached-but-unused CUDA allocator blocks to the driver (live tensors are not freed).
torch.cuda.empty_cache()

In [None]:
# Debug: shape of `hidden_states` as currently bound in the kernel.
hidden_states.shape

In [None]:
from transformers import AutoTokenizer

# NOTE(review): duplicate of the earlier tokenizer cell; redundant on Restart & Run All.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-14B")

In [None]:
# Debug: inspect the model's `layers` attribute (its stack of decoder layers, presumably).
recipe._model.layers