From 1eb7aba5c4a24449e37b30923a31b887c5152a37 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Fri, 31 Jan 2025 17:10:52 -0800 Subject: [PATCH 01/14] add repro script --- debugging/repros/repro.py | 209 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 209 insertions(+) create mode 100644 debugging/repros/repro.py diff --git a/debugging/repros/repro.py b/debugging/repros/repro.py new file mode 100644 index 0000000000..dcd8941069 --- /dev/null +++ b/debugging/repros/repro.py @@ -0,0 +1,209 @@ +import logging +import os +import socket +from argparse import ArgumentParser, Namespace +from collections import defaultdict +from datetime import datetime, timedelta + +import torch +import torch.distributed as dist +from torch import nn +from torch.distributed._composable.fsdp import fully_shard +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper as ptd_checkpoint_wrapper, +) +from torch.utils.checkpoint import ( + CheckpointPolicy, + create_selective_checkpoint_contexts, +) +from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config +from torchao.float8.float8_linear_utils import convert_to_float8_training + +TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S" + +# Keep a max of 100,000 alloc/free events in the recorded history +# leading up to the snapshot. 
+MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT: int = 100000 + +# logging +logging.basicConfig( + format="%(levelname)s:%(asctime)s %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", +) +logger: logging.Logger = logging.getLogger(__name__) +logger.setLevel(level=logging.INFO) + + +class LinearModel(nn.Module): + def __init__(self, num_layers: int = 1): + super(LinearModel, self).__init__() + self.layers = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(num_layers)]) + + def forward(self, x): + return self.layers(x) + + +def main(args: Namespace): + try: + model = LinearModel(num_layers=args.num_layers).to(torch.bfloat16).cuda() + x = torch.randn(16, 4096, dtype=torch.bfloat16).cuda() + + # fp8 rowwise quant + if args.float8: + apply_fp8_rowwise_quant(model) + + # selective per op AC + if args.per_op_ac: + model = apply_ac(model) + + # compile + if args.compile: + model = apply_compile(model) + + # FSDP2 (2 GPUs or more required to avoid _scaled_mm error: + # "RuntimeError: Only bf16 high precsion output types are supported for row-wise scaling." 
+ if args.fsdp: + setup_distributed() + apply_fsdp(model) + + # memory profile one fwd+bwd + start_record_memory_history() + + out = model(x) + out.sum().backward() + + # only 1 process should snapshot memory + if not (args.fsdp and dist.get_rank() != 0): + export_memory_snapshot(args.snapshot_file) + + stop_record_memory_history() + finally: + if args.fsdp: + clean_up_distributed() + + +def apply_compile(model: nn.Module): + model = torch.compile(model, fullgraph=True) + logger.info("Compiled model") + return model + + +def apply_ac(model: nn.Module): + """Apply activation checkpointing to the model.""" + model = _apply_ac_to_transformer_block(model) + logger.info(f"Applied selective per op activation checkpointing to the model") + return model + + +def apply_fp8_rowwise_quant(model: nn.Module): + recipe = Float8LinearRecipeName("all_axiswise") + config = recipe_name_to_linear_config(recipe) + convert_to_float8_training(model, config=config) + logger.info("Applied fp8 rowwise quantization to model") + + +def apply_fsdp(model: nn.Module): + fully_shard(model) + logger.info("Applied FSDP2 to model") + + +def _apply_ac_to_transformer_block(module: nn.Module): + _save_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + # for low precision training, it's useful to always save + # the result of max(abs(tensor)) + torch.ops.aten.abs.default, + torch.ops.aten.max.default, + } + + def _get_custom_policy(meta): + def _custom_policy(ctx, func, *args, **kwargs): + mode = "recompute" if ctx.is_recompute else "forward" + mm_count_key = f"{mode}_mm_count" + if func == torch.ops.aten.mm.default: + meta[mm_count_key] += 1 + # Saves output of all compute ops, except every second mm + to_save = func in _save_list and not ( + func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 + ) + return ( + 
CheckpointPolicy.MUST_SAVE + if to_save + else CheckpointPolicy.PREFER_RECOMPUTE + ) + + return _custom_policy + + def selective_checkpointing_context_fn(): + meta = defaultdict(int) + return create_selective_checkpoint_contexts(_get_custom_policy(meta)) + + return ptd_checkpoint_wrapper( + module, + context_fn=selective_checkpointing_context_fn, + preserve_rng_state=False, + ) + + +# memory snapshotting functions from: +# https://pytorch.org/blog/understanding-gpu-memory-1/#appendix-a---resnet50-memory-snapshot-code-example +def start_record_memory_history() -> None: + if not torch.cuda.is_available(): + logger.info("CUDA unavailable. Not recording memory history") + return + + logger.info("Starting snapshot record_memory_history") + torch.cuda.memory._record_memory_history( + max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT + ) + + +def stop_record_memory_history() -> None: + if not torch.cuda.is_available(): + logger.info("CUDA unavailable. Not recording memory history") + return + + logger.info("Stopping snapshot record_memory_history") + torch.cuda.memory._record_memory_history(enabled=None) + + +def export_memory_snapshot(filepath: str) -> None: + if not torch.cuda.is_available(): + logger.info("CUDA unavailable. 
Not exporting memory snapshot") + return + + try: + logger.info(f"Saving snapshot to local file: {filepath}") + torch.cuda.memory._dump_snapshot(f"{filepath}") + except Exception as e: + logger.error(f"Failed to capture memory snapshot {e}") + return + + +def setup_distributed(): + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + dist.init_process_group("nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + logger.info("Set up process group") + + +def clean_up_distributed(): + dist.destroy_process_group() + logger.info("Destroyed process group") + + +if __name__ == "__main__": + argparser = ArgumentParser() + argparser.add_argument("--float8", action="store_true") + argparser.add_argument("--fsdp", action="store_true") + argparser.add_argument("--compile", action="store_true") + argparser.add_argument("--per-op-ac", action="store_true") + argparser.add_argument("--num-layers", type=int, default=1) + argparser.add_argument("--snapshot-file", type=str, required=True) + args = argparser.parse_args() + main(args) From 95c0a213f2f96d2551278e7480045d0efcc239d2 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Fri, 31 Jan 2025 17:14:32 -0800 Subject: [PATCH 02/14] add sources --- debugging/repros/repro.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/debugging/repros/repro.py b/debugging/repros/repro.py index dcd8941069..2ffdba760f 100644 --- a/debugging/repros/repro.py +++ b/debugging/repros/repro.py @@ -89,9 +89,13 @@ def apply_compile(model: nn.Module): return model +# modified version of per op AC implementation from torchtitan. +# this applies per op selective AC to a model, without assuming it is a transformer model, +# and supports no other AC settings. 
+# source: https://github.com/pytorch/torchtitan/blob/cca07028e440de6a13189d251c28337bd34256ef/torchtitan/parallelisms/parallelize_llama.py#L288 def apply_ac(model: nn.Module): """Apply activation checkpointing to the model.""" - model = _apply_ac_to_transformer_block(model) + model = _apply_per_op_ac_to_model(model) logger.info(f"Applied selective per op activation checkpointing to the model") return model @@ -108,7 +112,7 @@ def apply_fsdp(model: nn.Module): logger.info("Applied FSDP2 to model") -def _apply_ac_to_transformer_block(module: nn.Module): +def _apply_per_op_ac_to_model(module: nn.Module): _save_list = { torch.ops.aten.mm.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, From 507e35dfdb6098bdb5d5ef2591603fe34c308de1 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Fri, 31 Jan 2025 18:11:20 -0800 Subject: [PATCH 03/14] add ffn and peak memory usage --- debugging/repros/repro.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/debugging/repros/repro.py b/debugging/repros/repro.py index 2ffdba760f..672a6985fd 100644 --- a/debugging/repros/repro.py +++ b/debugging/repros/repro.py @@ -7,6 +7,7 @@ import torch import torch.distributed as dist +import torch.nn.functional as F from torch import nn from torch.distributed._composable.fsdp import fully_shard from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( @@ -44,10 +45,31 @@ def forward(self, x): return self.layers(x) +# Simplified FFN from Llama3 https://github.com/pytorch/torchtitan/blob/cca07028e440de6a13189d251c28337bd34256ef/torchtitan/models/llama/model.py#L217 +class FFN(nn.Module): + def __init__(self, dim: int, hidden_dim: int): + super(FFN, self).__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def 
init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + def main(args: Namespace): + assert torch.cuda.is_available() try: - model = LinearModel(num_layers=args.num_layers).to(torch.bfloat16).cuda() - x = torch.randn(16, 4096, dtype=torch.bfloat16).cuda() + device = torch.device("cuda") + torch.cuda.reset_peak_memory_stats(device) + + model = FFN(4096, 4 * 4096).to(torch.bfloat16).to(device) + x = torch.randn(16, 4096, dtype=torch.bfloat16).to(device) # fp8 rowwise quant if args.float8: @@ -78,6 +100,9 @@ def main(args: Namespace): export_memory_snapshot(args.snapshot_file) stop_record_memory_history() + + peak_memory = torch.cuda.max_memory_allocated(device) + print(f"Peak GPU memory usage: {peak_memory / (1024 ** 2):.2f} MB") finally: if args.fsdp: clean_up_distributed() From 17af5ec801190367e6414416a59e27e37080f23b Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Fri, 31 Jan 2025 18:16:41 -0800 Subject: [PATCH 04/14] record model allocation --- debugging/repros/repro.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/debugging/repros/repro.py b/debugging/repros/repro.py index 672a6985fd..2953085a9e 100644 --- a/debugging/repros/repro.py +++ b/debugging/repros/repro.py @@ -20,8 +20,6 @@ from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config from torchao.float8.float8_linear_utils import convert_to_float8_training -TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S" - # Keep a max of 100,000 alloc/free events in the recorded history # leading up to the snapshot. 
MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT: int = 100000 @@ -68,6 +66,10 @@ def main(args: Namespace): device = torch.device("cuda") torch.cuda.reset_peak_memory_stats(device) + # start memory profile + start_record_memory_history() + + # allocate model and inputs model = FFN(4096, 4 * 4096).to(torch.bfloat16).to(device) x = torch.randn(16, 4096, dtype=torch.bfloat16).to(device) @@ -89,9 +91,7 @@ def main(args: Namespace): setup_distributed() apply_fsdp(model) - # memory profile one fwd+bwd - start_record_memory_history() - + # one fwd + backward out = model(x) out.sum().backward() From dbd679140ee35391fcb4ae46acdd2f662fdc3965 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Fri, 31 Jan 2025 18:38:53 -0800 Subject: [PATCH 05/14] configurable model type --- debugging/repros/repro.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/debugging/repros/repro.py b/debugging/repros/repro.py index 2953085a9e..2365714100 100644 --- a/debugging/repros/repro.py +++ b/debugging/repros/repro.py @@ -70,7 +70,12 @@ def main(args: Namespace): start_record_memory_history() # allocate model and inputs - model = FFN(4096, 4 * 4096).to(torch.bfloat16).to(device) + if args.model_type == "linear": + model = LinearModel(args.num_layers).to(torch.bfloat16).to(device) + elif args.model_type == "ffn": + model = FFN(4096, 4 * 4096).to(torch.bfloat16).to(device) + else: + raise ValueError(f"invalid model type: {args.model_type}") x = torch.randn(16, 4096, dtype=torch.bfloat16).to(device) # fp8 rowwise quant @@ -233,6 +238,9 @@ def clean_up_distributed(): argparser.add_argument("--compile", action="store_true") argparser.add_argument("--per-op-ac", action="store_true") argparser.add_argument("--num-layers", type=int, default=1) - argparser.add_argument("--snapshot-file", type=str, required=True) + argparser.add_argument("--model-type", type=str, required=True) + argparser.add_argument( + "--snapshot-file", type=str, required=True, help="[linear,ffn]" + ) args = 
argparser.parse_args() main(args) From e9f372e52e0a000ec271d4afbbead1ab27f8b556 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Fri, 31 Jan 2025 19:30:39 -0800 Subject: [PATCH 06/14] add attention repro --- debugging/repros/attention.py | 205 ++++++++++++++++++++++++++++++++++ debugging/repros/repro.py | 34 ++++-- 2 files changed, 232 insertions(+), 7 deletions(-) create mode 100644 debugging/repros/attention.py diff --git a/debugging/repros/attention.py b/debugging/repros/attention.py new file mode 100644 index 0000000000..44aec798e3 --- /dev/null +++ b/debugging/repros/attention.py @@ -0,0 +1,205 @@ +import torch +import torch.nn.functional as F +from torch import nn + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Reshape frequency tensor for broadcasting it with another tensor. 
+ + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + and the first seqlen elements will be sliced, but dim must match x. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert 0 <= 1 < ndim + seqlen = x.shape[1] + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. + xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
+ """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + torch.unsqueeze(x, dim=3) + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +# Simplified FFN from Llama3 https://github.com/pytorch/torchtitan/blob/cca07028e440de6a13189d251c28337bd34256ef/torchtitan/models/llama/model.py#L128 +class Attention(nn.Module): + """ + Multi-head attention module. + + Attributes: + n_kv_heads (int): Number of key and value heads. + n_heads (int): Number of query heads. + n_rep (int): Number of repetitions for local heads. + head_dim (int): Dimension size of each attention head. + wq (Linear): Linear transformation for queries. + wk (Linear): Linear transformation for keys. + wv (Linear): Linear transformation for values. + wo (Linear): Linear transformation for output. 
+ + """ + + def __init__( + self, + head_dim: int, + num_heads: int, + num_kv_heads: int, + max_seq_len: int = 1024, + rope_theta: int = 10000, + ): + super().__init__() + self.n_heads = num_heads + self.n_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.n_rep = self.n_heads // self.n_kv_heads + self.head_dim = head_dim // num_heads + + self.wq = nn.Linear(head_dim, num_heads * self.head_dim, bias=False) + self.wk = nn.Linear(head_dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(head_dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(num_heads * self.head_dim, head_dim, bias=False) + self.max_seq_len = max_seq_len + self.rope_theta = rope_theta + # TODO persistent should be set to false, since this buffer can be recomputed. + # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411, + # compile or pipeline-tracer will not correctly handle non-persistent buffers, + # so we need to fix that. (2) if we initialize pipeline-parallel models from + # a seed checkpoint rather than calling init_weights, we need freqs_cis to be + # initialized by the checkpoint, or we need to add a separate initializer for + # just the non-persistent buffers that is called after loading checkpoints. 
+ self.register_buffer( + "freqs_cis", + self._precompute_freqs_cis(head_dim, num_heads, max_seq_len, rope_theta), + persistent=True, + ) + + def _precompute_freqs_cis( + self, + head_dim: int, + num_heads: int, + max_seq_len: int, + rope_theta: int, + ): + return precompute_freqs_cis( + head_dim // num_heads, + # Need to compute until at least the max token limit for generation + # TODO: explain in docs/composability.md why we removed the 2x + # relaxing in our CP enablement PR + max_seq_len, + rope_theta, + ) + + def init_weights(self, init_std: float): + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + def forward(self, x: torch.Tensor): + """ + Forward pass of the attention module. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed frequency tensor. + + Returns: + torch.Tensor: Output tensor after attention. + + """ + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual + # local heads from sizes of xq, xk, and xv as TP may have sharded them + # after the above linear ops. 
+ xq = xq.view(bs, seqlen, -1, self.head_dim) + xk = xk.view(bs, seqlen, -1, self.head_dim) + xv = xv.view(bs, seqlen, -1, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis) + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + + # we use casual mask for training + output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) + output = output.view(bs, seqlen, -1) + return self.wo(output) diff --git a/debugging/repros/repro.py b/debugging/repros/repro.py index 2365714100..7976f44b10 100644 --- a/debugging/repros/repro.py +++ b/debugging/repros/repro.py @@ -8,6 +8,8 @@ import torch import torch.distributed as dist import torch.nn.functional as F + +from attention import Attention, precompute_freqs_cis from torch import nn from torch.distributed._composable.fsdp import fully_shard from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( @@ -73,10 +75,19 @@ def main(args: Namespace): if args.model_type == "linear": model = LinearModel(args.num_layers).to(torch.bfloat16).to(device) elif args.model_type == "ffn": - model = FFN(4096, 4 * 4096).to(torch.bfloat16).to(device) + dim = 4096 + hidden_dim = 4 * dim + model = FFN(dim, hidden_dim).to(torch.bfloat16).to(device) + elif args.model_type == "attn": + head_dim = 4096 + heads = 4 + kv_heads = 4 + model = Attention(head_dim, heads, kv_heads).to(torch.bfloat16).to(device) else: - raise ValueError(f"invalid model type: {args.model_type}") - x = torch.randn(16, 4096, dtype=torch.bfloat16).to(device) + raise ValueError( + f"invalid model 
type: {args.model_type} (must be one of: linear,ffn,attn)" + ) + x = torch.randn(1, 16, 4096, dtype=torch.bfloat16).to(device) # fp8 rowwise quant if args.float8: @@ -124,9 +135,18 @@ def apply_compile(model: nn.Module): # and supports no other AC settings. # source: https://github.com/pytorch/torchtitan/blob/cca07028e440de6a13189d251c28337bd34256ef/torchtitan/parallelisms/parallelize_llama.py#L288 def apply_ac(model: nn.Module): - """Apply activation checkpointing to the model.""" - model = _apply_per_op_ac_to_model(model) - logger.info(f"Applied selective per op activation checkpointing to the model") + if hasattr(model, "layers"): + for layer_id, layer in model.layers.named_children(): + layer = _apply_per_op_ac_to_model(transformer_block, ac_config) + model.layers.register_module(layer_id, layer) + logger.info( + f"Applied selective per op activation checkpoitning to multi-layer model" + ) + else: + model = _apply_per_op_ac_to_model(model) + logger.info( + f"Applied selective per op activation checkpointing to single layer model" + ) return model @@ -240,7 +260,7 @@ def clean_up_distributed(): argparser.add_argument("--num-layers", type=int, default=1) argparser.add_argument("--model-type", type=str, required=True) argparser.add_argument( - "--snapshot-file", type=str, required=True, help="[linear,ffn]" + "--snapshot-file", type=str, required=True, help="[linear,ffn,attn]" ) args = argparser.parse_args() main(args) From 96cee89aac35d7ec85d47d048d0351cf4f6023ec Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Sat, 1 Feb 2025 15:31:10 -0800 Subject: [PATCH 07/14] add training mode --- debugging/repros/attention.py | 205 ---------------------- debugging/repros/repro.py | 319 ++++++++++++++++++++++++++++++---- 2 files changed, 283 insertions(+), 241 deletions(-) delete mode 100644 debugging/repros/attention.py diff --git a/debugging/repros/attention.py b/debugging/repros/attention.py deleted file mode 100644 index 44aec798e3..0000000000 --- 
a/debugging/repros/attention.py +++ /dev/null @@ -1,205 +0,0 @@ -import torch -import torch.nn.functional as F -from torch import nn - - -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: - """ - Precompute the frequency tensor for complex exponentials (cis) with given dimensions. - - This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' - and the end index 'end'. The 'theta' parameter scales the frequencies. - The returned tensor contains complex values in complex64 data type. - - Args: - dim (int): Dimension of the frequency tensor. - end (int): End index for precomputing frequencies. - theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. - - Returns: - torch.Tensor: Precomputed frequency tensor with complex exponentials. - """ - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device) - freqs = torch.outer(t, freqs).float() - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - - -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - """ - Reshape frequency tensor for broadcasting it with another tensor. - - This function reshapes the frequency tensor to have the same shape as the target tensor 'x' - for the purpose of broadcasting the frequency tensor during element-wise operations. - - The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), - and the first seqlen elements will be sliced, but dim must match x. - - Args: - freqs_cis (torch.Tensor): Frequency tensor to be reshaped. - x (torch.Tensor): Target tensor for broadcasting compatibility. - - Returns: - torch.Tensor: Reshaped frequency tensor. 
- """ - ndim = x.ndim - assert 0 <= 1 < ndim - seqlen = x.shape[1] - freqs_cis = freqs_cis[0:seqlen] - assert freqs_cis.shape == (seqlen, x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary embeddings to input tensors using the given frequency tensor. - - This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided - frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor - is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are - returned as real tensors. - - Args: - xq (torch.Tensor): Query tensor to apply rotary embeddings. - xk (torch.Tensor): Key tensor to apply rotary embeddings. - freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
- """ - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" - bs, slen, n_kv_heads, head_dim = x.shape - if n_rep == 1: - return x - return ( - torch.unsqueeze(x, dim=3) - .expand(bs, slen, n_kv_heads, n_rep, head_dim) - .reshape(bs, slen, n_kv_heads * n_rep, head_dim) - ) - - -# Simplified FFN from Llama3 https://github.com/pytorch/torchtitan/blob/cca07028e440de6a13189d251c28337bd34256ef/torchtitan/models/llama/model.py#L128 -class Attention(nn.Module): - """ - Multi-head attention module. - - Attributes: - n_kv_heads (int): Number of key and value heads. - n_heads (int): Number of query heads. - n_rep (int): Number of repetitions for local heads. - head_dim (int): Dimension size of each attention head. - wq (Linear): Linear transformation for queries. - wk (Linear): Linear transformation for keys. - wv (Linear): Linear transformation for values. - wo (Linear): Linear transformation for output. 
- - """ - - def __init__( - self, - head_dim: int, - num_heads: int, - num_kv_heads: int, - max_seq_len: int = 1024, - rope_theta: int = 10000, - ): - super().__init__() - self.n_heads = num_heads - self.n_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.n_rep = self.n_heads // self.n_kv_heads - self.head_dim = head_dim // num_heads - - self.wq = nn.Linear(head_dim, num_heads * self.head_dim, bias=False) - self.wk = nn.Linear(head_dim, self.n_kv_heads * self.head_dim, bias=False) - self.wv = nn.Linear(head_dim, self.n_kv_heads * self.head_dim, bias=False) - self.wo = nn.Linear(num_heads * self.head_dim, head_dim, bias=False) - self.max_seq_len = max_seq_len - self.rope_theta = rope_theta - # TODO persistent should be set to false, since this buffer can be recomputed. - # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411, - # compile or pipeline-tracer will not correctly handle non-persistent buffers, - # so we need to fix that. (2) if we initialize pipeline-parallel models from - # a seed checkpoint rather than calling init_weights, we need freqs_cis to be - # initialized by the checkpoint, or we need to add a separate initializer for - # just the non-persistent buffers that is called after loading checkpoints. 
- self.register_buffer( - "freqs_cis", - self._precompute_freqs_cis(head_dim, num_heads, max_seq_len, rope_theta), - persistent=True, - ) - - def _precompute_freqs_cis( - self, - head_dim: int, - num_heads: int, - max_seq_len: int, - rope_theta: int, - ): - return precompute_freqs_cis( - head_dim // num_heads, - # Need to compute until at least the max token limit for generation - # TODO: explain in docs/composability.md why we removed the 2x - # relaxing in our CP enablement PR - max_seq_len, - rope_theta, - ) - - def init_weights(self, init_std: float): - for linear in (self.wq, self.wk, self.wv): - nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) - nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) - - def forward(self, x: torch.Tensor): - """ - Forward pass of the attention module. - - Args: - x (torch.Tensor): Input tensor. - freqs_cis (torch.Tensor): Precomputed frequency tensor. - - Returns: - torch.Tensor: Output tensor after attention. - - """ - bs, seqlen, _ = x.shape - xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - - # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual - # local heads from sizes of xq, xk, and xv as TP may have sharded them - # after the above linear ops. 
- xq = xq.view(bs, seqlen, -1, self.head_dim) - xk = xk.view(bs, seqlen, -1, self.head_dim) - xv = xv.view(bs, seqlen, -1, self.head_dim) - - xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis) - - # repeat k/v heads if n_kv_heads < n_heads - keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) - values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) - - xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - - # we use casual mask for training - output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) - output = output.transpose( - 1, 2 - ).contiguous() # (bs, seqlen, n_local_heads, head_dim) - output = output.view(bs, seqlen, -1) - return self.wo(output) diff --git a/debugging/repros/repro.py b/debugging/repros/repro.py index 7976f44b10..248e9d8d0b 100644 --- a/debugging/repros/repro.py +++ b/debugging/repros/repro.py @@ -4,13 +4,14 @@ from argparse import ArgumentParser, Namespace from collections import defaultdict from datetime import datetime, timedelta +from functools import partial import torch import torch.distributed as dist import torch.nn.functional as F -from attention import Attention, precompute_freqs_cis from torch import nn +from torch.autograd.profiler import record_function from torch.distributed._composable.fsdp import fully_shard from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper as ptd_checkpoint_wrapper, @@ -22,9 +23,6 @@ from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config from torchao.float8.float8_linear_utils import convert_to_float8_training -# Keep a max of 100,000 alloc/free events in the recorded history -# leading up to the snapshot. 
-MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT: int = 100000 # logging logging.basicConfig( @@ -36,32 +34,6 @@ logger.setLevel(level=logging.INFO) -class LinearModel(nn.Module): - def __init__(self, num_layers: int = 1): - super(LinearModel, self).__init__() - self.layers = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(num_layers)]) - - def forward(self, x): - return self.layers(x) - - -# Simplified FFN from Llama3 https://github.com/pytorch/torchtitan/blob/cca07028e440de6a13189d251c28337bd34256ef/torchtitan/models/llama/model.py#L217 -class FFN(nn.Module): - def __init__(self, dim: int, hidden_dim: int): - super(FFN, self).__init__() - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - - def forward(self, x): - return self.w2(F.silu(self.w1(x)) * self.w3(x)) - - def init_weights(self, init_std: float): - nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) - for linear in (self.w2, self.w3): - nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) - - def main(args: Namespace): assert torch.cuda.is_available() try: @@ -87,7 +59,6 @@ def main(args: Namespace): raise ValueError( f"invalid model type: {args.model_type} (must be one of: linear,ffn,attn)" ) - x = torch.randn(1, 16, 4096, dtype=torch.bfloat16).to(device) # fp8 rowwise quant if args.float8: @@ -107,11 +78,29 @@ def main(args: Namespace): setup_distributed() apply_fsdp(model) - # one fwd + backward - out = model(x) - out.sum().backward() + x = torch.randn(1, 16, 4096, dtype=torch.bfloat16).to(device) - # only 1 process should snapshot memory + # if training is enabled, perform 5 training iterations with optimizer steps. 
+ if args.train: + logger.info("Training for 5 steps") + optimizer = torch.optim.AdamW(model.parameters()) + label = torch.ones((1,), device=device, dtype=torch.bfloat16) + for _ in range(5): + out = model(x) + F.mse_loss(out.sum().unsqueeze(-1), label).backward() + optimizer.step() + optimizer.zero_grad(set_to_none=True) + else: + logger.info( + "Performing one forward+backward iteration with no optimizer step" + ) + # if training is not enabled, do one fwd+bwd pass without any optimizer steps. + out = model(x) + out.sum().backward() + + torch.cuda.synchronize() + + # snapshot memory. only 1 process should snapshot memory if not (args.fsdp and dist.get_rank() != 0): export_memory_snapshot(args.snapshot_file) @@ -124,6 +113,11 @@ def main(args: Namespace): clean_up_distributed() +################################ +# Compile/FSDP2/SAC/Float8 utils +################################ + + def apply_compile(model: nn.Module): model = torch.compile(model, fullgraph=True) logger.info("Compiled model") @@ -203,7 +197,15 @@ def selective_checkpointing_context_fn(): ) -# memory snapshotting functions from: +################## +# Memory profiling +################## + +# Keep a max of 100,000 alloc/free events in the recorded history +# leading up to the snapshot. 
+MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT: int = 100000 + + # https://pytorch.org/blog/understanding-gpu-memory-1/#appendix-a---resnet50-memory-snapshot-code-example def start_record_memory_history() -> None: if not torch.cuda.is_available(): @@ -238,7 +240,14 @@ def export_memory_snapshot(filepath: str) -> None: return +################### +# Distributed utils +################### + + def setup_distributed(): + assert "RANK" in os.environ, "env var RANK must be set for FSDP" + assert "WORLD_SIZE" in os.environ, "env var WORLD_SIZE must be set for FSDP" rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) dist.init_process_group("nccl", rank=rank, world_size=world_size) @@ -251,6 +260,239 @@ def clean_up_distributed(): logger.info("Destroyed process group") +################### +# Layer definitions +################### + + +class LinearModel(nn.Module): + def __init__(self, num_layers: int = 1): + super(LinearModel, self).__init__() + self.layers = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(num_layers)]) + + def forward(self, x): + return self.layers(x) + + +# Simplified FFN from Llama3 https://github.com/pytorch/torchtitan/blob/cca07028e440de6a13189d251c28337bd34256ef/torchtitan/models/llama/model.py#L217 +class FFN(nn.Module): + def __init__(self, dim: int, hidden_dim: int): + super(FFN, self).__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + +# MHA layer from Llama3 https://github.com/pytorch/torchtitan/blob/cca07028e440de6a13189d251c28337bd34256ef/torchtitan/models/llama/model.py#L128 +class Attention(nn.Module): + """ + Multi-head attention 
module. + + Attributes: + n_kv_heads (int): Number of key and value heads. + n_heads (int): Number of query heads. + n_rep (int): Number of repetitions for local heads. + head_dim (int): Dimension size of each attention head. + wq (Linear): Linear transformation for queries. + wk (Linear): Linear transformation for keys. + wv (Linear): Linear transformation for values. + wo (Linear): Linear transformation for output. + + """ + + def __init__( + self, + head_dim: int, + num_heads: int, + num_kv_heads: int, + max_seq_len: int = 1024, + rope_theta: int = 10000, + ): + super().__init__() + self.n_heads = num_heads + self.n_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.n_rep = self.n_heads // self.n_kv_heads + self.head_dim = head_dim // num_heads + + self.wq = nn.Linear(head_dim, num_heads * self.head_dim, bias=False) + self.wk = nn.Linear(head_dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(head_dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(num_heads * self.head_dim, head_dim, bias=False) + self.max_seq_len = max_seq_len + self.rope_theta = rope_theta + # TODO persistent should be set to false, since this buffer can be recomputed. + # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411, + # compile or pipeline-tracer will not correctly handle non-persistent buffers, + # so we need to fix that. (2) if we initialize pipeline-parallel models from + # a seed checkpoint rather than calling init_weights, we need freqs_cis to be + # initialized by the checkpoint, or we need to add a separate initializer for + # just the non-persistent buffers that is called after loading checkpoints. 
+ self.register_buffer( + "freqs_cis", + self._precompute_freqs_cis(head_dim, num_heads, max_seq_len, rope_theta), + persistent=True, + ) + + def _precompute_freqs_cis( + self, + head_dim: int, + num_heads: int, + max_seq_len: int, + rope_theta: int, + ): + return precompute_freqs_cis( + head_dim // num_heads, + # Need to compute until at least the max token limit for generation + # TODO: explain in docs/composability.md why we removed the 2x + # relaxing in our CP enablement PR + max_seq_len, + rope_theta, + ) + + def init_weights(self, init_std: float): + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + def forward(self, x: torch.Tensor): + """ + Forward pass of the attention module. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed frequency tensor. + + Returns: + torch.Tensor: Output tensor after attention. + + """ + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual + # local heads from sizes of xq, xk, and xv as TP may have sharded them + # after the above linear ops. 
+ xq = xq.view(bs, seqlen, -1, self.head_dim) + xk = xk.view(bs, seqlen, -1, self.head_dim) + xv = xv.view(bs, seqlen, -1, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis) + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + + # we use casual mask for training + output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) + output = output.view(bs, seqlen, -1) + return self.wo(output) + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. 
+ """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + and the first seqlen elements will be sliced, but dim must match x. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert 0 <= 1 < ndim + seqlen = x.shape[1] + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. + xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. 
+ + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + torch.unsqueeze(x, dim=3) + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + if __name__ == "__main__": argparser = ArgumentParser() argparser.add_argument("--float8", action="store_true") @@ -262,5 +504,10 @@ def clean_up_distributed(): argparser.add_argument( "--snapshot-file", type=str, required=True, help="[linear,ffn,attn]" ) + argparser.add_argument( + "--train", + action="store_true", + help="If set, train for 5 steps w/ AdamW optimizer and MSE loss. 
Otherwise, only do one fwd+bwd with no optimizer step.", + ) args = argparser.parse_args() main(args) From 89a39e9116fda3d73562a70be9e0da2b9f8a60e9 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Sat, 1 Feb 2025 15:58:11 -0800 Subject: [PATCH 08/14] make memory snapshots optional --- debugging/repros/repro.py | 48 ++++++++++++++++++++++++++------------- 1 file changed, 32 insertions(+), 16 deletions(-) diff --git a/debugging/repros/repro.py b/debugging/repros/repro.py index 248e9d8d0b..edfd53e070 100644 --- a/debugging/repros/repro.py +++ b/debugging/repros/repro.py @@ -36,52 +36,62 @@ def main(args: Namespace): assert torch.cuda.is_available() + + fsdp_enabled = args.fsdp + memory_snapshotting_enabled = args.snapshot_file is not None + use_float8 = args.float8 + use_compile = args.compile + model_type = args.model_type + use_per_op_ac = args.per_op_ac + run_training_loop = args.train + try: device = torch.device("cuda") torch.cuda.reset_peak_memory_stats(device) # start memory profile - start_record_memory_history() + if memory_snapshotting_enabled: + start_record_memory_history() # allocate model and inputs - if args.model_type == "linear": + if model_type == "linear": model = LinearModel(args.num_layers).to(torch.bfloat16).to(device) - elif args.model_type == "ffn": + elif model_type == "ffn": dim = 4096 hidden_dim = 4 * dim model = FFN(dim, hidden_dim).to(torch.bfloat16).to(device) - elif args.model_type == "attn": + elif model_type == "attn": head_dim = 4096 heads = 4 kv_heads = 4 model = Attention(head_dim, heads, kv_heads).to(torch.bfloat16).to(device) else: raise ValueError( - f"invalid model type: {args.model_type} (must be one of: linear,ffn,attn)" + f"invalid model type: {model_type} (must be one of: linear,ffn,attn)" ) # fp8 rowwise quant - if args.float8: + if use_float8: apply_fp8_rowwise_quant(model) # selective per op AC - if args.per_op_ac: + if use_per_op_ac: model = apply_ac(model) # compile - if args.compile: + if use_compile: model = 
apply_compile(model) # FSDP2 (2 GPUs or more required to avoid _scaled_mm error: # "RuntimeError: Only bf16 high precsion output types are supported for row-wise scaling." - if args.fsdp: + if fsdp_enabled: setup_distributed() apply_fsdp(model) x = torch.randn(1, 16, 4096, dtype=torch.bfloat16).to(device) # if training is enabled, perform 5 training iterations with optimizer steps. - if args.train: + if run_training_loop: logger.info("Training for 5 steps") optimizer = torch.optim.AdamW(model.parameters()) label = torch.ones((1,), device=device, dtype=torch.bfloat16) @@ -101,10 +111,12 @@ def main(args: Namespace): torch.cuda.synchronize() # snapshot memory. only 1 process should snapshot memory - if not (args.fsdp and dist.get_rank() != 0): - export_memory_snapshot(args.snapshot_file) + if memory_snapshotting_enabled: + is_rank_0 = fsdp_enabled and dist.get_rank() == 0 + if not fsdp_enabled or (fsdp_enabled and is_rank_0): + export_memory_snapshot(args.snapshot_file) - stop_record_memory_history() + stop_record_memory_history() peak_memory = torch.cuda.max_memory_allocated(device) print(f"Peak GPU memory usage: {peak_memory / (1024 ** 2):.2f} MB") @@ -131,7 +143,7 @@ def apply_compile(model: nn.Module): def apply_ac(model: nn.Module): if hasattr(model, "layers"): for layer_id, layer in model.layers.named_children(): - layer = _apply_per_op_ac_to_model(transformer_block, ac_config) + layer = _apply_per_op_ac_to_model(layer) model.layers.register_module(layer_id, layer) logger.info( f"Applied selective per op activation checkpoitning to multi-layer model" @@ -500,9 +512,13 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: argparser.add_argument("--compile", action="store_true") argparser.add_argument("--per-op-ac", action="store_true") argparser.add_argument("--num-layers", type=int, default=1) - argparser.add_argument("--model-type", type=str, required=True) argparser.add_argument( - "--snapshot-file", type=str, required=True, help="[linear,ffn,attn]" 
+ "--model-type", type=str, required=True, help="[linear,ffn,attn]" + ) + argparser.add_argument( + "--snapshot-file", + type=str, + help="where to write the memory snapshot pickle file", ) argparser.add_argument( "--train", From 031ca419794bfef43c99ea8605bec095a6278da4 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Sat, 1 Feb 2025 16:24:51 -0800 Subject: [PATCH 09/14] add transformer blocker --- debugging/repros/repro.py => repro.py | 103 +++++++++++++++++++++----- 1 file changed, 84 insertions(+), 19 deletions(-) rename debugging/repros/repro.py => repro.py (87%) diff --git a/debugging/repros/repro.py b/repro.py similarity index 87% rename from debugging/repros/repro.py rename to repro.py index edfd53e070..b8fe7911b0 100644 --- a/debugging/repros/repro.py +++ b/repro.py @@ -23,6 +23,8 @@ from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config from torchao.float8.float8_linear_utils import convert_to_float8_training +from torchtitan.models.norms import build_norm + # logging logging.basicConfig( @@ -53,23 +55,36 @@ def main(args: Namespace): if memory_snapshotting_enabled: start_record_memory_history() + ffwd_dim = 4096 + ffwd_hidden = 4 * ffwd_dim + head_dim = 4096 + num_heads = 4 + num_kv_heads = 4 + # allocate model and inputs if model_type == "linear": - model = LinearModel(args.num_layers).to(torch.bfloat16).to(device) + model = LinearModel(args.num_layers) elif model_type == "ffn": - dim = 4096 - hidden_dim = 4 * dim - model = FFN(dim, hidden_dim).to(torch.bfloat16).to(device) + model = FeedForward(ffwd_dim, ffwd_hidden) elif model_type == "attn": - head_dim = 4096 - heads = 4 - kv_heads = 4 - model = Attention(head_dim, heads, kv_heads).to(torch.bfloat16).to(device) + model = Attention(head_dim, num_heads, num_kv_heads) + elif model_type == "transformer_block": + layer_id = 0 + model = TransformerBlock( + layer_id, + num_heads, + num_kv_heads, + head_dim, + ffwd_dim, + ffwd_hidden, + ) else: raise ValueError( f"invalid 
model type: {model_type} (must be one of: linear,ffn,attn)" ) + model = model.bfloat16().cuda() + # fp8 rowwise quant if use_float8: apply_fp8_rowwise_quant(model) @@ -286,10 +301,10 @@ def forward(self, x): return self.layers(x) -# Simplified FFN from Llama3 https://github.com/pytorch/torchtitan/blob/cca07028e440de6a13189d251c28337bd34256ef/torchtitan/models/llama/model.py#L217 -class FFN(nn.Module): +# Simplified FeedForward from Llama3 https://github.com/pytorch/torchtitan/blob/cca07028e440de6a13189d251c28337bd34256ef/torchtitan/models/llama/model.py#L217 +class FeedForward(nn.Module): def __init__(self, dim: int, hidden_dim: int): - super(FFN, self).__init__() + super(FeedForward, self).__init__() self.w1 = nn.Linear(dim, hidden_dim, bias=False) self.w2 = nn.Linear(hidden_dim, dim, bias=False) self.w3 = nn.Linear(dim, hidden_dim, bias=False) @@ -340,13 +355,7 @@ def __init__( self.wo = nn.Linear(num_heads * self.head_dim, head_dim, bias=False) self.max_seq_len = max_seq_len self.rope_theta = rope_theta - # TODO persistent should be set to false, since this buffer can be recomputed. - # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411, - # compile or pipeline-tracer will not correctly handle non-persistent buffers, - # so we need to fix that. (2) if we initialize pipeline-parallel models from - # a seed checkpoint rather than calling init_weights, we need freqs_cis to be - # initialized by the checkpoint, or we need to add a separate initializer for - # just the non-persistent buffers that is called after loading checkpoints. 
+ self.register_buffer( "freqs_cis", self._precompute_freqs_cis(head_dim, num_heads, max_seq_len, rope_theta), @@ -505,6 +514,59 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: ) +# source: https://github.com/pytorch/torchtitan/blob/cca07028e440de6a13189d251c28337bd34256ef/torchtitan/models/llama/model.py#L261 +class TransformerBlock(nn.Module): + + def __init__( + self, + layer_id: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + ffwd_dim: int, + ffwd_hidden: int, + norm_type: str = "rmsnorm", + max_seq_len: int = 1024, + rope_theta: int = 10000, + ): + super().__init__() + self.n_heads = num_heads + self.attention = Attention(head_dim, num_heads, num_kv_heads) + self.feed_forward = FeedForward( + dim=ffwd_dim, + hidden_dim=ffwd_hidden, + ) + self.layer_id = layer_id + + self.attention_norm = build_norm( + norm_type, + dim=head_dim, + ) + self.ffn_norm = build_norm( + norm_type, + dim=ffwd_dim, + ) + + def forward( + self, + x: torch.Tensor, + ): + """ + Perform a forward pass through the TransformerBlock. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + Returns: + torch.Tensor: Output tensor after applying attention and feedforward layers. 
+ + """ + h = x + self.attention(self.attention_norm(x)) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + if __name__ == "__main__": argparser = ArgumentParser() argparser.add_argument("--float8", action="store_true") @@ -513,7 +575,10 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: argparser.add_argument("--per-op-ac", action="store_true") argparser.add_argument("--num-layers", type=int, default=1) argparser.add_argument( - "--model-type", type=str, required=True, help="[linear,ffn,attn]" + "--model-type", + type=str, + required=True, + help="[linear,ffn,attn,transformer_block]", ) argparser.add_argument( "--snapshot-file", From 611f1c2d728c5a9b1293bd5b3c3eddf10889b553 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Sun, 2 Feb 2025 16:10:46 -0800 Subject: [PATCH 10/14] support more than one transformer block --- repro.py | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/repro.py b/repro.py index b8fe7911b0..78da3f7927 100644 --- a/repro.py +++ b/repro.py @@ -69,15 +69,28 @@ def main(args: Namespace): elif model_type == "attn": model = Attention(head_dim, num_heads, num_kv_heads) elif model_type == "transformer_block": - layer_id = 0 - model = TransformerBlock( - layer_id, - num_heads, - num_kv_heads, - head_dim, - ffwd_dim, - ffwd_hidden, - ) + + class Transformer(nn.Module): + def __init__(self): + super(Transformer, self).__init__() + self.layers = nn.Sequential( + *[ + TransformerBlock( + layer_id, + num_heads, + num_kv_heads, + head_dim, + ffwd_dim, + ffwd_hidden, + ) + for layer_id in range(args.num_layers) + ] + ) + + def forward(self, x: torch.Tensor): + return self.layers(x) + + model = Transformer() else: raise ValueError( f"invalid model type: {model_type} (must be one of: linear,ffn,attn)" @@ -161,7 +174,7 @@ def apply_ac(model: nn.Module): layer = _apply_per_op_ac_to_model(layer) model.layers.register_module(layer_id, layer) logger.info( - f"Applied selective per 
op activation checkpoitning to multi-layer model" + f"Applied selective per op activation checkpointing to multi-layer model" ) else: model = _apply_per_op_ac_to_model(model) From ca95fcc60dce849907e95f74d235f3960ea03e35 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Mon, 3 Feb 2025 21:51:58 -0800 Subject: [PATCH 11/14] add ffn fsdp kernels --- logs/ffn_bf16_fwd.py | 180 +++++++++ logs/ffn_fp8_fwd.py | 939 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 1119 insertions(+) create mode 100644 logs/ffn_bf16_fwd.py create mode 100644 logs/ffn_fp8_fwd.py diff --git a/logs/ffn_bf16_fwd.py b/logs/ffn_bf16_fwd.py new file mode 100644 index 0000000000..20dda9de0f --- /dev/null +++ b/logs/ffn_bf16_fwd.py @@ -0,0 +1,180 @@ +# AOT ID: ['0_forward'] +import math +import os +import random +import tempfile +from ctypes import c_int, c_long, c_void_p +from math import inf, nan + +import torch +import triton +import triton.language as tl +from torch import device, empty_strided +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.codegen.memory_planning import _align as align +from torch._inductor.codegen.multi_kernel import MultiKernelCall +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.runtime.triton_heuristics import ( + cooperative_reduction_grid, + end_graph, + grid, + grid_combo_kernels, + split_scan_grid, + start_graph, +) +from torch._inductor.select_algorithm import extern_kernels +from torch._inductor.utils import maybe_profile + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +reinterpret_tensor = 
torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /tmp/torchinductor_danvm/gj/cgjugimza7cnhbqbxv4oyj66lp24pgam7flmze4edtdeubwbldku.py +# Topologically Sorted Source Nodes: [silu, mul], Original ATen: [aten.silu, aten.mul] +# Source node to ATen node mapping: +# mul => mul_1 +# silu => convert_element_type_2, convert_element_type_3, mul, sigmoid +# Graph fragment: +# %convert_element_type_2 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_1, torch.float32), kwargs = {}) +# %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%convert_element_type_2,), kwargs = {}) +# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %sigmoid), kwargs = {}) +# %convert_element_type_3 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {}) +# %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_3, %view_3), kwargs = {}) +triton_poi_fused_mul_silu_0 = async_compile.triton( + "triton_poi_fused_mul_silu_0", + """ +import triton +import triton.language as tl +from triton.compiler.compiler import AttrsDescriptor + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 262144}, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, 
major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, + inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_silu_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'DB4EC0BC06A1FCBFCDA04BA16907EC3B1E867E352F9777F2A8CBA8D490D26C32', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_silu_0(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 262144 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = tl.full([XBLOCK], True, tl.int1) + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32) + tmp5 = tl.load(in_out_ptr0 + (x0), None).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.sigmoid(tmp1) + tmp3 = tmp1 * tmp2 + tmp4 = tmp3.to(tl.float32) + tmp6 = tmp4 * tmp5 + tl.store(in_out_ptr0 + (x0), tmp6, None) +""", + device_str="cuda", +) + + +async_compile.wait(globals()) +del async_compile + + +def call(args): + primals_1, primals_2, primals_3, primals_4 = args + args.clear() + assert_size_stride(primals_1, (1, 16, 4096), (65536, 4096, 1)) + assert_size_stride(primals_2, (16384, 4096), (4096, 1)) + assert_size_stride(primals_3, (16384, 4096), (4096, 1)) + assert_size_stride(primals_4, (4096, 16384), (16384, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf0 = empty_strided_cuda((16, 16384), (16384, 1), 
torch.bfloat16) + # EXTRA MEM: buf0=(16*16384) + + # Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm] + extern_kernels.mm( + reinterpret_tensor(primals_1, (16, 4096), (4096, 1), 0), + reinterpret_tensor(primals_2, (4096, 16384), (1, 4096), 0), + out=buf0, + ) + del primals_2 + buf1 = empty_strided_cuda((16, 16384), (16384, 1), torch.bfloat16) + # EXTRA MEM: buf0=(16*16384) + buf1=(16*16384) + + # Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm] + extern_kernels.mm( + reinterpret_tensor(primals_1, (16, 4096), (4096, 1), 0), + reinterpret_tensor(primals_3, (4096, 16384), (1, 4096), 0), + out=buf1, + ) + buf2 = reinterpret_tensor(buf1, (1, 16, 16384), (262144, 16384, 1), 0) + del buf1 # reuse + # EXTRA MEM: buf0=(16*16384) + buf2=(1*16*16384) + + # Topologically Sorted Source Nodes: [silu, mul], Original ATen: [aten.silu, aten.mul] + stream1 = get_raw_stream(1) + triton_poi_fused_mul_silu_0.run( + buf2, buf0, 262144, grid=grid(262144), stream=stream1 + ) + buf3 = empty_strided_cuda((16, 4096), (4096, 1), torch.bfloat16) + # EXTRA MEM: buf0=(16*16384) + buf2=(1*16*16384) + buf3=(16*4096) + + # Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm] + extern_kernels.mm( + reinterpret_tensor(buf2, (16, 16384), (16384, 1), 0), + reinterpret_tensor(primals_4, (16384, 4096), (1, 16384), 0), + out=buf3, + ) + del buf2 + # EXTRA MEM: buf0=(16*16384) + buf3=(16*4096) + + # PEAK EXTRA MEM was line 134: (16*16384) + (1*16*16384) + (16*4096) = 589824 * 2 bytes for bf16 = 655360 bytes + return ( + reinterpret_tensor(buf3, (1, 16, 4096), (65536, 4096, 1), 0), + primals_1, + primals_3, + primals_4, + buf0, + ) + + # RETURNS (save for backward?) 
only buf0 and buf3 => buf0=(16*16384) + buf3=(1*16*4096) * 2 bytes for bf16 = 393216 bytes + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + + primals_1 = rand_strided( + (1, 16, 4096), (65536, 4096, 1), device="cuda:1", dtype=torch.bfloat16 + ) + primals_2 = rand_strided( + (16384, 4096), (4096, 1), device="cuda:1", dtype=torch.bfloat16 + ) + primals_3 = rand_strided( + (16384, 4096), (4096, 1), device="cuda:1", dtype=torch.bfloat16 + ) + primals_4 = rand_strided( + (4096, 16384), (16384, 1), device="cuda:1", dtype=torch.bfloat16 + ) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + + compiled_module_main("None", benchmark_compiled_module) diff --git a/logs/ffn_fp8_fwd.py b/logs/ffn_fp8_fwd.py new file mode 100644 index 0000000000..6029338921 --- /dev/null +++ b/logs/ffn_fp8_fwd.py @@ -0,0 +1,939 @@ +# AOT ID: ['0_forward'] +import math +import os +import random +import tempfile +from ctypes import c_int, c_long, c_void_p +from math import inf, nan + +import torch +import triton +import triton.language as tl +from torch import device, empty_strided +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.codegen.memory_planning import _align as align +from torch._inductor.codegen.multi_kernel import MultiKernelCall +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.runtime.triton_heuristics import ( + cooperative_reduction_grid, + end_graph, + grid, + grid_combo_kernels, + split_scan_grid, + start_graph, +) +from torch._inductor.select_algorithm import extern_kernels +from torch._inductor.utils import maybe_profile + +aten = torch.ops.aten +inductor_ops = 
torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /tmp/torchinductor_danvm/j3/cj3aaa4oa3cxacjerv7baomyivvg5jvg7usxhp2tlidhqirw7j2s.py +# Topologically Sorted Source Nodes: [output, output_1], Original ATen: [aten.abs, aten.amax, aten._to_copy, aten.clamp, aten.reciprocal, aten.mul, aten._scaled_mm] +# Source node to ATen node mapping: +# output => _scaled_mm, abs_1, amax, clamp_max_1, clamp_min_2, clamp_min_3, convert_element_type_4, convert_element_type_5, convert_element_type_6, convert_element_type_7, mul_2, mul_3, reciprocal_1, reciprocal_2, reciprocal_3 +# output_1 => _scaled_mm_1, clamp_max_3, clamp_min_6, clamp_min_7, convert_element_type_14, convert_element_type_15, convert_element_type_16, convert_element_type_17, mul_7, mul_8, reciprocal_5, reciprocal_7 +# Graph fragment: +# %abs_1 : [num_users=2] = call_function[target=torch.ops.aten.abs.default](args = (%primals_1,), kwargs = {}) +# %amax : [num_users=1] = call_function[target=torch.ops.aten.amax.default](args = (%abs_1, [-1], True), kwargs = {}) +# %convert_element_type_4 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%amax_1, torch.float64), kwargs = {}) +# %clamp_min_2 : [num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%convert_element_type_4, 1e-12), kwargs = {}) +# %reciprocal_1 : [num_users=1] = call_function[target=torch.ops.aten.reciprocal.default](args = (%clamp_min_2,), kwargs = {}) +# %mul_2 : [num_users=1] = 
call_function[target=torch.ops.aten.mul.Tensor](args = (%reciprocal_1, 448.0), kwargs = {}) +# %convert_element_type_5 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_2, torch.float32), kwargs = {}) +# %convert_element_type_6 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute, torch.float32), kwargs = {}) +# %mul_3 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_6, %convert_element_type_5), kwargs = {}) +# %clamp_min_3 : [num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%mul_3, -448.0), kwargs = {}) +# %clamp_max_1 : [num_users=1] = call_function[target=torch.ops.aten.clamp_max.default](args = (%clamp_min_3, 448.0), kwargs = {}) +# %convert_element_type_7 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%clamp_max_1, torch.float8_e4m3fn), kwargs = {}) +# %reciprocal_2 : [num_users=3] = call_function[target=torch.ops.aten.reciprocal.default](args = (%view_1,), kwargs = {}) +# %reciprocal_3 : [num_users=1] = call_function[target=torch.ops.aten.reciprocal.default](args = (%convert_element_type_5,), kwargs = {}) +# %_scaled_mm : [num_users=2] = call_function[target=torch.ops.aten._scaled_mm.default](args = (%view, %convert_element_type_7, %reciprocal_2, %reciprocal_3, None, None, torch.bfloat16, True), kwargs = {}) +# %convert_element_type_14 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%amax_3, torch.float64), kwargs = {}) +# %clamp_min_6 : [num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%convert_element_type_14, 1e-12), kwargs = {}) +# %reciprocal_5 : [num_users=1] = call_function[target=torch.ops.aten.reciprocal.default](args = (%clamp_min_6,), kwargs = {}) +# %mul_7 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%reciprocal_5, 448.0), kwargs = 
{}) +# %convert_element_type_15 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_7, torch.float32), kwargs = {}) +# %convert_element_type_16 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_1, torch.float32), kwargs = {}) +# %mul_8 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_16, %convert_element_type_15), kwargs = {}) +# %clamp_min_7 : [num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%mul_8, -448.0), kwargs = {}) +# %clamp_max_3 : [num_users=1] = call_function[target=torch.ops.aten.clamp_max.default](args = (%clamp_min_7, 448.0), kwargs = {}) +# %convert_element_type_17 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%clamp_max_3, torch.float8_e4m3fn), kwargs = {}) +# %reciprocal_7 : [num_users=1] = call_function[target=torch.ops.aten.reciprocal.default](args = (%convert_element_type_15,), kwargs = {}) +# %_scaled_mm_1 : [num_users=1] = call_function[target=torch.ops.aten._scaled_mm.default](args = (%view, %convert_element_type_17, %reciprocal_2, %reciprocal_7, None, None, torch.bfloat16, True), kwargs = {}) +triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_0 = async_compile.triton( + "triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_0", + """ +import triton +import triton.language as tl +from triton.compiler.compiler import AttrsDescriptor + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 
'out_ptr0': '*bf16', 'out_ptr2': '*fp32', 'out_ptr3': '*fp8e4nv', 'out_ptr4': '*fp8e4nv', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, + inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'DB4EC0BC06A1FCBFCDA04BA16907EC3B1E867E352F9777F2A8CBA8D490D26C32', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_0(in_ptr0, out_ptr0, out_ptr2, out_ptr3, out_ptr4, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16 + r0_numel = 4096 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = 
tl_math.abs(tmp0) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = triton_helpers.maximum(_tmp3, tmp2) + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tl.store(out_ptr0 + (r0_1 + 4096*x0), tmp1, r0_mask & xmask) + tmp3 = triton_helpers.max2(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.float64) + tmp6 = tl.full([1, 1], 1e-12, tl.float64) + tmp7 = triton_helpers.maximum(tmp5, tmp6) + tmp8 = tl.full([1, 1], 1, tl.int32) + tmp9 = tmp8 / tmp7 + tmp10 = tl.full([1, 1], 448.0, tl.float64) + tmp11 = tmp9 * tmp10 + tmp12 = tmp11.to(tl.float32) + tmp13 = tmp8 / tmp12 + tl.store(out_ptr2 + (x0), tmp13, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp14 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp15 = tmp14.to(tl.float32) + tmp16 = tmp15 * tmp12 + tmp17 = -448.0 + tmp18 = triton_helpers.maximum(tmp16, tmp17) + tmp19 = 448.0 + tmp20 = triton_helpers.minimum(tmp18, tmp19) + tmp21 = tmp20.to(tl.float8e4nv) + tl.store(out_ptr3 + (r0_1 + 4096*x0), tmp21, r0_mask & xmask) + tl.store(out_ptr4 + (r0_1 + 4096*x0), tmp21, r0_mask & xmask) +""", + device_str="cuda", +) + + +# kernel path: /tmp/torchinductor_danvm/ii/ciixp3lqrhj65bjje7eugwloffecuwrhde43psb7fse27wrv3bex.py +# Topologically Sorted Source Nodes: [output], Original ATen: [aten.abs, aten.amax, aten._to_copy, aten.clamp, aten.reciprocal, aten.mul, aten._scaled_mm] +# Source node to ATen node mapping: +# output => _scaled_mm, abs_2, amax_1, clamp_max_1, clamp_min_2, clamp_min_3, convert_element_type_4, convert_element_type_5, convert_element_type_6, convert_element_type_7, mul_2, mul_3, reciprocal_1, reciprocal_3 +# Graph fragment: +# %abs_2 : [num_users=1] = call_function[target=torch.ops.aten.abs.default](args = (%permute,), kwargs = {}) +# %amax_1 : [num_users=1] = 
call_function[target=torch.ops.aten.amax.default](args = (%abs_2, [0], True), kwargs = {}) +# %convert_element_type_4 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%amax_1, torch.float64), kwargs = {}) +# %clamp_min_2 : [num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%convert_element_type_4, 1e-12), kwargs = {}) +# %reciprocal_1 : [num_users=1] = call_function[target=torch.ops.aten.reciprocal.default](args = (%clamp_min_2,), kwargs = {}) +# %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%reciprocal_1, 448.0), kwargs = {}) +# %convert_element_type_5 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_2, torch.float32), kwargs = {}) +# %convert_element_type_6 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute, torch.float32), kwargs = {}) +# %mul_3 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_6, %convert_element_type_5), kwargs = {}) +# %clamp_min_3 : [num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%mul_3, -448.0), kwargs = {}) +# %clamp_max_1 : [num_users=1] = call_function[target=torch.ops.aten.clamp_max.default](args = (%clamp_min_3, 448.0), kwargs = {}) +# %convert_element_type_7 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%clamp_max_1, torch.float8_e4m3fn), kwargs = {}) +# %reciprocal_3 : [num_users=1] = call_function[target=torch.ops.aten.reciprocal.default](args = (%convert_element_type_5,), kwargs = {}) +# %_scaled_mm : [num_users=2] = call_function[target=torch.ops.aten._scaled_mm.default](args = (%view, %convert_element_type_7, %reciprocal_2, %reciprocal_3, None, None, torch.bfloat16, True), kwargs = {}) +triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_1 = async_compile.triton( + 
"triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_1", + """ +import triton +import triton.language as tl +from triton.compiler.compiler import AttrsDescriptor + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr1': '*fp8e4nv', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, + inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'DB4EC0BC06A1FCBFCDA04BA16907EC3B1E867E352F9777F2A8CBA8D490D26C32', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_1(in_ptr0, out_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 4096 + rnumel = r0_numel + RBLOCK: 
tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tl_math.abs(tmp0) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = triton_helpers.maximum(_tmp3, tmp2) + _tmp3 = tl.where(r0_mask, tmp4, _tmp3) + tmp3 = triton_helpers.max2(_tmp3, 1)[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp3.to(tl.float64) + tmp8 = tl.full([1, 1], 1e-12, tl.float64) + tmp9 = triton_helpers.maximum(tmp7, tmp8) + tmp10 = tl.full([1, 1], 1, tl.int32) + tmp11 = tmp10 / tmp9 + tmp12 = tl.full([1, 1], 448.0, tl.float64) + tmp13 = tmp11 * tmp12 + tmp14 = tmp13.to(tl.float32) + tmp15 = tmp6 * tmp14 + tmp16 = -448.0 + tmp17 = triton_helpers.maximum(tmp15, tmp16) + tmp18 = 448.0 + tmp19 = triton_helpers.minimum(tmp17, tmp18) + tmp20 = tmp19.to(tl.float8e4nv) + tl.store(out_ptr1 + (r0_1 + 4096*x0), tmp20, r0_mask) + tmp21 = tmp3.to(tl.float64) + tmp22 = tl.full([1, 1], 1e-12, tl.float64) + tmp23 = triton_helpers.maximum(tmp21, tmp22) + tmp24 = tl.full([1, 1], 1, tl.int32) + tmp25 = tmp24 / tmp23 + tmp26 = tl.full([1, 1], 448.0, tl.float64) + tmp27 = tmp25 * tmp26 + tmp28 = tmp27.to(tl.float32) + tmp29 = tmp24 / tmp28 + tl.store(out_ptr2 + (x0), tmp29, None) +""", + 
device_str="cuda", +) + + +# kernel path: /tmp/torchinductor_danvm/iv/civvg45odupbusgmig3rlca5gn6ph5bbbcnun7xax2e344vikhe7.py +# Topologically Sorted Source Nodes: [output_1], Original ATen: [aten.abs, aten.amax, aten._to_copy, aten.clamp, aten.reciprocal, aten.mul, aten._scaled_mm] +# Source node to ATen node mapping: +# output_1 => _scaled_mm_1, abs_4, amax_3, clamp_max_3, clamp_min_6, clamp_min_7, convert_element_type_14, convert_element_type_15, convert_element_type_16, convert_element_type_17, mul_7, mul_8, reciprocal_5, reciprocal_7 +# Graph fragment: +# %abs_4 : [num_users=2] = call_function[target=torch.ops.aten.abs.default](args = (%permute_1,), kwargs = {}) +# %amax_3 : [num_users=1] = call_function[target=torch.ops.aten.amax.default](args = (%abs_4, [0], True), kwargs = {}) +# %convert_element_type_14 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%amax_3, torch.float64), kwargs = {}) +# %clamp_min_6 : [num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%convert_element_type_14, 1e-12), kwargs = {}) +# %reciprocal_5 : [num_users=1] = call_function[target=torch.ops.aten.reciprocal.default](args = (%clamp_min_6,), kwargs = {}) +# %mul_7 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%reciprocal_5, 448.0), kwargs = {}) +# %convert_element_type_15 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_7, torch.float32), kwargs = {}) +# %convert_element_type_16 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_1, torch.float32), kwargs = {}) +# %mul_8 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_16, %convert_element_type_15), kwargs = {}) +# %clamp_min_7 : [num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%mul_8, -448.0), kwargs = {}) +# %clamp_max_3 : [num_users=1] = 
call_function[target=torch.ops.aten.clamp_max.default](args = (%clamp_min_7, 448.0), kwargs = {}) +# %convert_element_type_17 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%clamp_max_3, torch.float8_e4m3fn), kwargs = {}) +# %reciprocal_7 : [num_users=1] = call_function[target=torch.ops.aten.reciprocal.default](args = (%convert_element_type_15,), kwargs = {}) +# %_scaled_mm_1 : [num_users=1] = call_function[target=torch.ops.aten._scaled_mm.default](args = (%view, %convert_element_type_17, %reciprocal_2, %reciprocal_7, None, None, torch.bfloat16, True), kwargs = {}) +triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_2 = async_compile.triton( + "triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_2", + """ +import triton +import triton.language as tl +from triton.compiler.compiler import AttrsDescriptor + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'out_ptr2': '*fp8e4nv', 'out_ptr3': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, + inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_2', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 
'num_load': 2, 'num_reduction': 1, 'backend_hash': 'DB4EC0BC06A1FCBFCDA04BA16907EC3B1E867E352F9777F2A8CBA8D490D26C32', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_2(in_ptr0, out_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 4096 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tl_math.abs(tmp0) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = triton_helpers.maximum(_tmp3, tmp2) + _tmp3 = tl.where(r0_mask, tmp4, _tmp3) + tl.store(out_ptr0 + (r0_1 + 4096*x0), tmp1, r0_mask) + tmp3 = triton_helpers.max2(_tmp3, 1)[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp3.to(tl.float64) + tmp8 = tl.full([1, 1], 1e-12, tl.float64) + tmp9 = 
triton_helpers.maximum(tmp7, tmp8) + tmp10 = tl.full([1, 1], 1, tl.int32) + tmp11 = tmp10 / tmp9 + tmp12 = tl.full([1, 1], 448.0, tl.float64) + tmp13 = tmp11 * tmp12 + tmp14 = tmp13.to(tl.float32) + tmp15 = tmp6 * tmp14 + tmp16 = -448.0 + tmp17 = triton_helpers.maximum(tmp15, tmp16) + tmp18 = 448.0 + tmp19 = triton_helpers.minimum(tmp17, tmp18) + tmp20 = tmp19.to(tl.float8e4nv) + tl.store(out_ptr2 + (r0_1 + 4096*x0), tmp20, r0_mask) + tmp21 = tmp3.to(tl.float64) + tmp22 = tl.full([1, 1], 1e-12, tl.float64) + tmp23 = triton_helpers.maximum(tmp21, tmp22) + tmp24 = tl.full([1, 1], 1, tl.int32) + tmp25 = tmp24 / tmp23 + tmp26 = tl.full([1, 1], 448.0, tl.float64) + tmp27 = tmp25 * tmp26 + tmp28 = tmp27.to(tl.float32) + tmp29 = tmp24 / tmp28 + tl.store(out_ptr3 + (x0), tmp29, None) +""", + device_str="cuda", +) + + +# kernel path: /tmp/torchinductor_danvm/26/c26iyqqjm3qargz5uwszhy2bsjvfs7pieaclnhfferte4nt6cje4.py +# Topologically Sorted Source Nodes: [silu, mul, output_2], Original ATen: [aten.silu, aten.mul, aten.abs, aten.amax] +# Source node to ATen node mapping: +# mul => mul_9 +# output_2 => abs_5, amax_4 +# silu => convert_element_type_8, convert_element_type_9, mul_4, sigmoid +# Graph fragment: +# %convert_element_type_8 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_2, torch.float32), kwargs = {}) +# %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%convert_element_type_8,), kwargs = {}) +# %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_8, %sigmoid), kwargs = {}) +# %convert_element_type_9 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_4, torch.bfloat16), kwargs = {}) +# %mul_9 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_9, %view_5), kwargs = {}) +# %abs_5 : [num_users=1] = 
call_function[target=torch.ops.aten.abs.default](args = (%mul_9,), kwargs = {}) +# %amax_4 : [num_users=1] = call_function[target=torch.ops.aten.amax.default](args = (%abs_5, [-1], True), kwargs = {}) +triton_red_fused_abs_amax_mul_silu_3 = async_compile.triton( + "triton_red_fused_abs_amax_mul_silu_3", + """ +import triton +import triton.language as tl +from triton.compiler.compiler import AttrsDescriptor + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, + inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_abs_amax_mul_silu_3', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'DB4EC0BC06A1FCBFCDA04BA16907EC3B1E867E352F9777F2A8CBA8D490D26C32', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def 
triton_red_fused_abs_amax_mul_silu_3(in_ptr0, in_ptr1, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 32 + r0_numel = 8192 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp9 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp5 = tl.load(in_ptr1 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.sigmoid(tmp1) + tmp3 = tmp1 * tmp2 + tmp4 = tmp3.to(tl.float32) + tmp6 = tmp4 * tmp5 + tmp7 = tl_math.abs(tmp6) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = triton_helpers.maximum(_tmp9, tmp8) + _tmp9 = tl.where(r0_mask & xmask, tmp10, _tmp9) + tmp9 = triton_helpers.max2(_tmp9, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp9, xmask) +""", + device_str="cuda", +) + + +# kernel path: /tmp/torchinductor_danvm/tg/ctgfrwoni6kr2t4tajwdeqfskectpmqoe27bprocl3iryjmtdroa.py +# Topologically Sorted Source Nodes: [silu, mul, output_2], Original ATen: [aten.silu, aten.mul, aten.abs, aten.amax, aten._to_copy, aten.clamp, aten.reciprocal, aten._scaled_mm] +# Source node to ATen node mapping: +# mul => mul_9 +# output_2 => _scaled_mm_2, abs_5, amax_4, clamp_max_5, clamp_min_10, clamp_min_11, convert_element_type_22, convert_element_type_23, convert_element_type_24, convert_element_type_25, mul_12, mul_13, reciprocal_10, reciprocal_11, reciprocal_9 +# silu => convert_element_type_8, convert_element_type_9, mul_4, sigmoid +# Graph fragment: +# 
%convert_element_type_8 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_2, torch.float32), kwargs = {}) +# %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%convert_element_type_8,), kwargs = {}) +# %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_8, %sigmoid), kwargs = {}) +# %convert_element_type_9 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_4, torch.bfloat16), kwargs = {}) +# %mul_9 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_9, %view_5), kwargs = {}) +# %abs_5 : [num_users=1] = call_function[target=torch.ops.aten.abs.default](args = (%mul_9,), kwargs = {}) +# %amax_4 : [num_users=1] = call_function[target=torch.ops.aten.amax.default](args = (%abs_5, [-1], True), kwargs = {}) +# %convert_element_type_22 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%amax_5, torch.float64), kwargs = {}) +# %clamp_min_10 : [num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%convert_element_type_22, 1e-12), kwargs = {}) +# %reciprocal_9 : [num_users=1] = call_function[target=torch.ops.aten.reciprocal.default](args = (%clamp_min_10,), kwargs = {}) +# %mul_12 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%reciprocal_9, 448.0), kwargs = {}) +# %convert_element_type_23 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_12, torch.float32), kwargs = {}) +# %convert_element_type_24 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_2, torch.float32), kwargs = {}) +# %mul_13 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_24, %convert_element_type_23), kwargs = {}) +# %clamp_min_11 : [num_users=1] = 
call_function[target=torch.ops.aten.clamp_min.default](args = (%mul_13, -448.0), kwargs = {}) +# %clamp_max_5 : [num_users=1] = call_function[target=torch.ops.aten.clamp_max.default](args = (%clamp_min_11, 448.0), kwargs = {}) +# %convert_element_type_25 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%clamp_max_5, torch.float8_e4m3fn), kwargs = {}) +# %reciprocal_10 : [num_users=1] = call_function[target=torch.ops.aten.reciprocal.default](args = (%view_7,), kwargs = {}) +# %reciprocal_11 : [num_users=1] = call_function[target=torch.ops.aten.reciprocal.default](args = (%convert_element_type_23,), kwargs = {}) +# %_scaled_mm_2 : [num_users=1] = call_function[target=torch.ops.aten._scaled_mm.default](args = (%view_6, %convert_element_type_25, %reciprocal_10, %reciprocal_11, None, None, torch.bfloat16, True), kwargs = {}) +triton_per_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_silu_4 = async_compile.triton( + "triton_per_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_silu_4", + """ +import triton +import triton.language as tl +from triton.compiler.compiler import AttrsDescriptor + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 16, 'r0_': 2}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3), 
'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, + inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_silu_4', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'DB4EC0BC06A1FCBFCDA04BA16907EC3B1E867E352F9777F2A8CBA8D490D26C32', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_silu_4(in_ptr0, out_ptr0, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 16 + r0_numel = 2 + R0_BLOCK: tl.constexpr = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (r0_1 + 2*x0), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None] + tmp5 = tmp4.to(tl.float64) + tmp6 = tl.full([1, 1], 1e-12, tl.float64) + tmp7 = triton_helpers.maximum(tmp5, tmp6) + tmp8 = tl.full([1, 1], 1, tl.int32) + tmp9 = tmp8 / tmp7 + tmp10 = tl.full([1, 1], 448.0, tl.float64) + tmp11 = tmp9 * tmp10 + tmp12 = tmp11.to(tl.float32) + tmp13 = tmp8 / tmp12 + tl.store(out_ptr1 + (x0), tmp13, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) +""", + device_str="cuda", +) + + +# kernel path: 
/tmp/torchinductor_danvm/j7/cj7bfkg5gexx2hvk4lidnvmdtyno47zcq6r5spgx32odg5wququs.py +# Topologically Sorted Source Nodes: [output_2], Original ATen: [aten._to_copy, aten.clamp, aten.reciprocal, aten.mul, aten._scaled_mm] +# Source node to ATen node mapping: +# output_2 => _scaled_mm_2, clamp_max_5, clamp_min_10, clamp_min_11, convert_element_type_22, convert_element_type_23, convert_element_type_24, convert_element_type_25, mul_12, mul_13, reciprocal_10, reciprocal_11, reciprocal_9 +# Graph fragment: +# %convert_element_type_22 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%amax_5, torch.float64), kwargs = {}) +# %clamp_min_10 : [num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%convert_element_type_22, 1e-12), kwargs = {}) +# %reciprocal_9 : [num_users=1] = call_function[target=torch.ops.aten.reciprocal.default](args = (%clamp_min_10,), kwargs = {}) +# %mul_12 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%reciprocal_9, 448.0), kwargs = {}) +# %convert_element_type_23 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_12, torch.float32), kwargs = {}) +# %convert_element_type_24 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_2, torch.float32), kwargs = {}) +# %mul_13 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_24, %convert_element_type_23), kwargs = {}) +# %clamp_min_11 : [num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%mul_13, -448.0), kwargs = {}) +# %clamp_max_5 : [num_users=1] = call_function[target=torch.ops.aten.clamp_max.default](args = (%clamp_min_11, 448.0), kwargs = {}) +# %convert_element_type_25 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%clamp_max_5, torch.float8_e4m3fn), kwargs = {}) +# %reciprocal_10 : [num_users=1] = 
call_function[target=torch.ops.aten.reciprocal.default](args = (%view_7,), kwargs = {}) +# %reciprocal_11 : [num_users=1] = call_function[target=torch.ops.aten.reciprocal.default](args = (%convert_element_type_23,), kwargs = {}) +# %_scaled_mm_2 : [num_users=1] = call_function[target=torch.ops.aten._scaled_mm.default](args = (%view_6, %convert_element_type_25, %reciprocal_10, %reciprocal_11, None, None, torch.bfloat16, True), kwargs = {}) +triton_poi_fused__scaled_mm__to_copy_clamp_mul_reciprocal_5 = async_compile.triton( + "triton_poi_fused__scaled_mm__to_copy_clamp_mul_reciprocal_5", + """ +import triton +import triton.language as tl +from triton.compiler.compiler import AttrsDescriptor + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 262144}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*fp8e4nv', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, + inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_mm__to_copy_clamp_mul_reciprocal_5', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'DB4EC0BC06A1FCBFCDA04BA16907EC3B1E867E352F9777F2A8CBA8D490D26C32', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 
'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=2 +) +@triton.jit +def triton_poi_fused__scaled_mm__to_copy_clamp_mul_reciprocal_5(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 262144 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = tl.full([XBLOCK], True, tl.int1) + x2 = xindex + x1 = xindex // 16384 + tmp0 = tl.load(in_ptr0 + (x2), None).to(tl.float32) + tmp5 = tl.load(in_ptr1 + (x2), None).to(tl.float32) + tmp8 = tl.load(in_ptr2 + (x1), None, eviction_policy='evict_last').to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.sigmoid(tmp1) + tmp3 = tmp1 * tmp2 + tmp4 = tmp3.to(tl.float32) + tmp6 = tmp4 * tmp5 + tmp7 = tmp6.to(tl.float32) + tmp9 = tmp8.to(tl.float64) + tmp10 = tl.full([1], 1e-12, tl.float64) + tmp11 = triton_helpers.maximum(tmp9, tmp10) + tmp12 = tl.full([1], 1, tl.int32) + tmp13 = tmp12 / tmp11 + tmp14 = tl.full([1], 448.0, tl.float64) + tmp15 = tmp13 * tmp14 + tmp16 = tmp15.to(tl.float32) + tmp17 = tmp7 * tmp16 + tmp18 = -448.0 + tmp19 = triton_helpers.maximum(tmp17, tmp18) + tmp20 = 448.0 + tmp21 = triton_helpers.minimum(tmp19, tmp20) + tmp22 = tmp21.to(tl.float8e4nv) + tl.store(out_ptr0 + (x2), tmp22, None) +""", + device_str="cuda", +) + + +# kernel path: /tmp/torchinductor_danvm/iu/ciu56dezu4d3prdcizpttizlavmyqmxxi5zt7zpfgpsyibjjb7su.py +# Topologically Sorted Source Nodes: [output_2], Original ATen: [aten.abs, aten.amax, aten._to_copy, aten.clamp, aten.reciprocal, aten.mul, aten._scaled_mm] +# Source node to ATen node mapping: +# output_2 => _scaled_mm_2, abs_6, amax_5, clamp_max_5, clamp_min_10, clamp_min_11, convert_element_type_22, convert_element_type_23, convert_element_type_24, convert_element_type_25, mul_12, mul_13, reciprocal_10, reciprocal_11, reciprocal_9 
+# Graph fragment: +# %abs_6 : [num_users=1] = call_function[target=torch.ops.aten.abs.default](args = (%permute_2,), kwargs = {}) +# %amax_5 : [num_users=1] = call_function[target=torch.ops.aten.amax.default](args = (%abs_6, [0], True), kwargs = {}) +# %convert_element_type_22 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%amax_5, torch.float64), kwargs = {}) +# %clamp_min_10 : [num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%convert_element_type_22, 1e-12), kwargs = {}) +# %reciprocal_9 : [num_users=1] = call_function[target=torch.ops.aten.reciprocal.default](args = (%clamp_min_10,), kwargs = {}) +# %mul_12 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%reciprocal_9, 448.0), kwargs = {}) +# %convert_element_type_23 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_12, torch.float32), kwargs = {}) +# %convert_element_type_24 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_2, torch.float32), kwargs = {}) +# %mul_13 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_24, %convert_element_type_23), kwargs = {}) +# %clamp_min_11 : [num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%mul_13, -448.0), kwargs = {}) +# %clamp_max_5 : [num_users=1] = call_function[target=torch.ops.aten.clamp_max.default](args = (%clamp_min_11, 448.0), kwargs = {}) +# %convert_element_type_25 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%clamp_max_5, torch.float8_e4m3fn), kwargs = {}) +# %reciprocal_10 : [num_users=1] = call_function[target=torch.ops.aten.reciprocal.default](args = (%view_7,), kwargs = {}) +# %reciprocal_11 : [num_users=1] = call_function[target=torch.ops.aten.reciprocal.default](args = (%convert_element_type_23,), kwargs = {}) +# %_scaled_mm_2 : 
[num_users=1] = call_function[target=torch.ops.aten._scaled_mm.default](args = (%view_6, %convert_element_type_25, %reciprocal_10, %reciprocal_11, None, None, torch.bfloat16, True), kwargs = {}) +triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_6 = async_compile.triton( + "triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_6", + """ +import triton +import triton.language as tl +from triton.compiler.compiler import AttrsDescriptor + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr1': '*fp8e4nv', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, + inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_6', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'DB4EC0BC06A1FCBFCDA04BA16907EC3B1E867E352F9777F2A8CBA8D490D26C32', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 
'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_6(in_ptr0, out_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 16384*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tl_math.abs(tmp0) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = triton_helpers.maximum(_tmp3, tmp2) + _tmp3 = tl.where(r0_mask, tmp4, _tmp3) + tmp3 = triton_helpers.max2(_tmp3, 1)[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 16384*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp3.to(tl.float64) + tmp8 = tl.full([1, 1], 1e-12, tl.float64) + tmp9 = triton_helpers.maximum(tmp7, tmp8) + tmp10 = tl.full([1, 1], 1, tl.int32) + tmp11 = tmp10 / tmp9 + tmp12 = tl.full([1, 1], 448.0, tl.float64) + tmp13 = tmp11 * tmp12 + tmp14 = tmp13.to(tl.float32) + tmp15 = tmp6 * tmp14 + tmp16 = -448.0 + tmp17 = triton_helpers.maximum(tmp15, tmp16) + tmp18 = 448.0 + tmp19 = triton_helpers.minimum(tmp17, tmp18) + tmp20 = tmp19.to(tl.float8e4nv) + tl.store(out_ptr1 + (r0_1 + 16384*x0), tmp20, r0_mask) + tmp21 = tmp3.to(tl.float64) + tmp22 = tl.full([1, 1], 1e-12, 
tl.float64) + tmp23 = triton_helpers.maximum(tmp21, tmp22) + tmp24 = tl.full([1, 1], 1, tl.int32) + tmp25 = tmp24 / tmp23 + tmp26 = tl.full([1, 1], 448.0, tl.float64) + tmp27 = tmp25 * tmp26 + tmp28 = tmp27.to(tl.float32) + tmp29 = tmp24 / tmp28 + tl.store(out_ptr2 + (x0), tmp29, None) +""", + device_str="cuda", +) + + +async_compile.wait(globals()) +del async_compile + + +def call(args): + primals_1, primals_2, primals_3, primals_4 = args + args.clear() + assert_size_stride(primals_1, (1, 16, 4096), (65536, 4096, 1)) + assert_size_stride(primals_2, (16384, 4096), (4096, 1)) + assert_size_stride(primals_3, (16384, 4096), (4096, 1)) + assert_size_stride(primals_4, (4096, 16384), (16384, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((1, 16, 4096), (65536, 4096, 1), torch.bfloat16) + buf3 = empty_strided_cuda((16, 1), (1, 1), torch.float32) + buf4 = empty_strided_cuda((16, 4096), (4096, 1), torch.float8_e4m3fn) + buf10 = empty_strided_cuda((16, 4096), (4096, 1), torch.float8_e4m3fn) + import pdb + + pdb.set_trace() + + # EXTRA MEM: buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf4:fp8=(16, 4096) + buf10:fp8=(16, 4096) + # total bytes: 262,208 + + # Topologically Sorted Source Nodes: [output, output_1], Original ATen: [aten.abs, aten.amax, aten._to_copy, aten.clamp, aten.reciprocal, aten.mul, aten._scaled_mm] + stream0 = get_raw_stream(0) + triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_0.run( + primals_1, buf0, buf3, buf4, buf10, 16, 4096, grid=grid(16), stream=stream0 + ) + buf5 = empty_strided_cuda((4096, 16384), (1, 4096), torch.float8_e4m3fn) + buf6 = empty_strided_cuda((1, 16384), (16384, 1), torch.float32) + + # EXTRA MEM: buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf4:fp8=(16, 4096) + buf10:fp8=(16, 4096) + buf5:fp8=(4096, 16384) + buf6:fp32=(1, 16384) + # total bytes: 67,436,608 + + # Topologically Sorted Source Nodes: [output], Original ATen: [aten.abs, aten.amax, aten._to_copy, 
aten.clamp, aten.reciprocal, aten.mul, aten._scaled_mm] + stream0 = get_raw_stream(0) + triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_1.run( + primals_2, buf5, buf6, 16384, 4096, grid=grid(16384), stream=stream0 + ) + del primals_2 + buf7 = empty_strided_cuda((16, 16384), (16384, 1), torch.bfloat16) + + # EXTRA MEM: buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf4:fp8=(16, 4096) + buf10:fp8=(16, 4096) + buf5:fp8=(4096, 16384) + buf6:fp32=(1, 16384) \ + # + buf7:bf16=(16, 16384) + # total bytes: 67,960,896 + + # Topologically Sorted Source Nodes: [output], Original ATen: [aten._to_copy, aten.clamp, aten.reciprocal, aten.mul, aten._scaled_mm] + extern_kernels._scaled_mm( + buf4, + buf5, + buf3, + buf6, + out_dtype=torch.bfloat16, + use_fast_accum=True, + out=buf7, + ) + del buf4 + + # EXTRA MEM: buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf10:fp8=(16, 4096) + buf5:fp8=(4096, 16384) + buf6:fp32=(1, 16384) + buf7:bf16=(16, 16384) + # total bytes: 67,895,360 + + buf8 = empty_strided_cuda((4096, 16384), (1, 4096), torch.bfloat16) + + # EXTRA MEM: buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf10:fp8=(16, 4096) + buf5:fp8=(4096, 16384) + buf6:fp32=(1, 16384) + buf7:bf16=(16, 16384) \ + # + buf8:bf16=(4096, 16384) + # total bytes: 202,113,088 + + buf11 = buf5 + + # EXTRA MEM: buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf10:fp8=(16, 4096) + buf5:fp8=(4096, 16384) + buf11:fp8=(4096, 16384) + buf6:fp32=(1, 16384) + buf7:bf16=(16, 16384) \ + # + buf8:bf16=(4096, 16384) + # total bytes: 269,221,952 + + del buf5 # reuse + + # EXTRA MEM: buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf10:fp8=(16, 4096) + buf11:fp8=(4096, 16384) + buf6:fp32=(1, 16384) + buf7:bf16=(16, 16384) \ + # + buf8:bf16=(4096, 16384) + # total bytes: 202,113,088 + + buf12 = buf6 + # EXTRA MEM: buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf10:fp8=(16, 4096) + buf11:fp8=(4096, 16384) + buf6:fp32=(1, 16384) + buf12:fp32=(1, 16384)+ buf7:bf16=(16, 16384) \ + # + 
buf8:bf16=(4096, 16384) + # total bytes: 202,178,624 + + del buf6 # reuse + + # EXTRA MEM: buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf10:fp8=(16, 4096) + buf11:fp8=(4096, 16384) + buf12:fp32=(1, 16384) + buf7:bf16=(16, 16384) \ + # + buf8:bf16=(4096, 16384) + # total bytes: 202,113,088 + + # Topologically Sorted Source Nodes: [output_1], Original ATen: [aten.abs, aten.amax, aten._to_copy, aten.clamp, aten.reciprocal, aten.mul, aten._scaled_mm] + stream0 = get_raw_stream(0) + triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_2.run( + primals_3, buf8, buf11, buf12, 16384, 4096, grid=grid(16384), stream=stream0 + ) + buf13 = empty_strided_cuda((16, 16384), (16384, 1), torch.bfloat16) + + # EXTRA MEM: buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf10:fp8=(16, 4096) + buf11:fp8=(4096, 16384) + buf12:fp32=(1, 16384) + buf7:bf16=(16, 16384) \ + # + buf8:bf16=(4096, 16384) + buf13:bf16=(16, 16384) + # total bytes: 202,637,376 + + # Topologically Sorted Source Nodes: [output_1], Original ATen: [aten._to_copy, aten.clamp, aten.reciprocal, aten.mul, aten._scaled_mm] + extern_kernels._scaled_mm( + buf10, + buf11, + buf3, + buf12, + out_dtype=torch.bfloat16, + use_fast_accum=True, + out=buf13, + ) + del buf10 + del buf12 + + # EXTRA MEM: buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf11:fp8=(4096, 16384) + buf7:bf16=(16, 16384) + buf8:bf16=(4096, 16384) + buf13:bf16=(16, 16384) + # total bytes: 202506304 + + buf14 = empty_strided_cuda((1, 16, 1, 2), (32, 2, 32, 1), torch.float32) + + # EXTRA MEM: buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf11:fp8=(4096, 16384) + buf7:bf16=(16, 16384) + buf8:bf16=(4096, 16384) + buf13:bf16=(16, 16384) \ + # + buf14:fp32=(1, 16, 1, 2) + # total bytes: 202506432 + + # Topologically Sorted Source Nodes: [silu, mul, output_2], Original ATen: [aten.silu, aten.mul, aten.abs, aten.amax] + stream0 = get_raw_stream(0) + triton_red_fused_abs_amax_mul_silu_3.run( + buf7, buf13, buf14, 32, 8192, grid=grid(32), stream=stream0 
+ ) + buf15 = empty_strided_cuda((1, 16, 1), (16, 1, 16), torch.bfloat16) + buf19 = empty_strided_cuda((16, 1), (1, 1), torch.float32) + + # EXTRA MEM: buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf11:fp8=(4096, 16384) + buf7:bf16=(16, 16384) + buf8:bf16=(4096, 16384) + buf13:bf16=(16, 16384) \ + # + buf14:fp32=(1, 16, 1, 2) + buf15:bf16=(1, 16, 1) + buf19:fp32=(16, 1) + # total bytes: 202506528 + + # Topologically Sorted Source Nodes: [silu, mul, output_2], Original ATen: [aten.silu, aten.mul, aten.abs, aten.amax, aten._to_copy, aten.clamp, aten.reciprocal, aten._scaled_mm] + stream0 = get_raw_stream(0) + triton_per_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_silu_4.run( + buf14, buf15, buf19, 16, 2, grid=grid(16), stream=stream0 + ) + del buf14 + + # EXTRA MEM: buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf11:fp8=(4096, 16384) + buf7:bf16=(16, 16384) + buf8:bf16=(4096, 16384) + buf13:bf16=(16, 16384) \ + # + buf15:bf16=(1, 16, 1) + buf19:fp32=(16, 1) + # total bytes: 202506400 + + buf17 = empty_strided_cuda((16, 16384), (16384, 1), torch.float8_e4m3fn) + + # EXTRA MEM: buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf11:fp8=(4096, 16384) + buf7:bf16=(16, 16384) + buf8:bf16=(4096, 16384) + buf13:bf16=(16, 16384) \ + # + buf15:bf16=(1, 16, 1) + buf19:fp32=(16, 1) + buf17:fp8=(16, 16384) + # total bytes: 202768544 + + # Topologically Sorted Source Nodes: [output_2], Original ATen: [aten._to_copy, aten.clamp, aten.reciprocal, aten.mul, aten._scaled_mm] + stream0 = get_raw_stream(0) + triton_poi_fused__scaled_mm__to_copy_clamp_mul_reciprocal_5.run( + buf7, buf13, buf15, buf17, 262144, grid=grid(262144), stream=stream0 + ) + del buf13 + del buf15 + + # EXTRA MEM: buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf11:fp8=(4096, 16384) + buf7:bf16=(16, 16384) + buf8:bf16=(4096, 16384) \ + # + buf19:fp32=(16, 1) + buf17:fp8=(16, 16384) + # total bytes: 202244224 + + buf18 = reinterpret_tensor(buf11, (16384, 4096), (1, 16384), 0) + # EXTRA MEM: 
buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf11:fp8=(4096, 16384) + buf7:bf16=(16, 16384) + buf8:bf16=(4096, 16384) \ + # + buf19:fp32=(16, 1) + buf17:fp8=(16, 16384) + buf18:fp8=(16384, 4096) + # total bytes: 269353088 + + del buf11 # reuse + + # EXTRA MEM: buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf7:bf16=(16, 16384) + buf8:bf16=(4096, 16384) \ + # + buf19:fp32=(16, 1) + buf17:fp8=(16, 16384) + buf18:fp8=(16384, 4096) + # total bytes: 202244224 + + buf20 = empty_strided_cuda((1, 4096), (4096, 1), torch.float32) + + # EXTRA MEM: buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf7:bf16=(16, 16384) + buf8:bf16=(4096, 16384) \ + # + buf19:fp32=(16, 1) + buf17:fp8=(16, 16384) + buf18:fp8=(16384, 4096) + buf20:fp32=(1, 4096) + # total bytes: 202260608 + + # Topologically Sorted Source Nodes: [output_2], Original ATen: [aten.abs, aten.amax, aten._to_copy, aten.clamp, aten.reciprocal, aten.mul, aten._scaled_mm] + stream0 = get_raw_stream(0) + triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_6.run( + primals_4, buf18, buf20, 4096, 16384, grid=grid(4096), stream=stream0 + ) + buf21 = empty_strided_cuda((16, 4096), (4096, 1), torch.bfloat16) + + # EXTRA MEM: buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf7:bf16=(16, 16384) + buf8:bf16=(4096, 16384) \ + # + buf19:fp32=(16, 1) + buf17:fp8=(16, 16384) + buf18:fp8=(16384, 4096) + buf20:fp32=(1, 4096) + buf21:bf16=(16, 4096) + # total bytes: 202391680 + + # Topologically Sorted Source Nodes: [output_2], Original ATen: [aten._to_copy, aten.clamp, aten.reciprocal, aten.mul, aten._scaled_mm] + extern_kernels._scaled_mm( + buf17, + buf18, + buf19, + buf20, + out_dtype=torch.bfloat16, + use_fast_accum=True, + out=buf21, + ) + del buf17 + del buf18 + del buf19 + del buf20 + # EXTRA MEM: buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf7:bf16=(16, 16384) + buf8:bf16=(4096, 16384) \ + # + buf21:bf16=(16, 4096) + # total bytes: 135004224 + return ( + reinterpret_tensor( + buf21, (1, 16, 4096), (65536, 4096, 
1), 0 + ), # buf21:bf16=(16, 4096) => small + primals_1, + primals_3, + primals_4, + buf0, # buf0:bf16=(1, 16, 4096) => small + buf3, # buf3:fp32=(16, 1) => small + buf7, # buf7:bf16=(16, 16384) => small + buf8, # buf8:bf16=(4096, 16384) => huge, what is this? used in triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_2 + ) + + # RETURNS (save for backward): buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf7:bf16=(16, 16384) + buf8:bf16=(4096, 16384) = 134,873,152 bytes + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + + primals_1 = rand_strided( + (1, 16, 4096), (65536, 4096, 1), device="cuda:0", dtype=torch.bfloat16 + ) + primals_2 = rand_strided( + (16384, 4096), (4096, 1), device="cuda:0", dtype=torch.bfloat16 + ) + primals_3 = rand_strided( + (16384, 4096), (4096, 1), device="cuda:0", dtype=torch.bfloat16 + ) + primals_4 = rand_strided( + (4096, 16384), (16384, 1), device="cuda:0", dtype=torch.bfloat16 + ) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4]) + return print_performance(fn, times=times, repeat=repeat) + + +def total_bytes(tensors: list[torch.Tensor]) -> int: + total = 0 + for tensor in tensors: + total += tensor.element_size() * tensor.numel() + return total + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + + compiled_module_main("None", benchmark_compiled_module) From 0e00c202a5f9e7888a5a502dea8ce6911004bd49 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Mon, 3 Feb 2025 22:18:01 -0800 Subject: [PATCH 12/14] continue annotating fp8 fwd kernel --- logs/ffn_fp8_fwd.py | 105 +++++++++++++++++++++++++++----------------- 1 file changed, 65 insertions(+), 40 deletions(-) diff --git a/logs/ffn_fp8_fwd.py b/logs/ffn_fp8_fwd.py index 6029338921..a77a9ffb74 100644 --- a/logs/ffn_fp8_fwd.py +++ b/logs/ffn_fp8_fwd.py @@ -115,7 +115,7 @@ def 
triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_0(in_ptr0 tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) tmp4 = triton_helpers.maximum(_tmp3, tmp2) _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) - tl.store(out_ptr0 + (r0_1 + 4096*x0), tmp1, r0_mask & xmask) + tl.store(out_ptr0 + (r0_1 + 4096*x0), tmp1, r0_mask & xmask) <------ amaxes tmp3 = triton_helpers.max2(_tmp3, 1)[:, None] tmp5 = tmp3.to(tl.float64) tmp6 = tl.full([1, 1], 1e-12, tl.float64) @@ -126,7 +126,7 @@ def triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_0(in_ptr0 tmp11 = tmp9 * tmp10 tmp12 = tmp11.to(tl.float32) tmp13 = tmp8 / tmp12 - tl.store(out_ptr2 + (x0), tmp13, xmask) + tl.store(out_ptr2 + (x0), tmp13, xmask) <------ scale for r0_offset in range(0, r0_numel, R0_BLOCK): r0_index = r0_offset + r0_base r0_mask = r0_index < r0_numel @@ -141,8 +141,8 @@ def triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_0(in_ptr0 tmp19 = 448.0 tmp20 = triton_helpers.minimum(tmp18, tmp19) tmp21 = tmp20.to(tl.float8e4nv) - tl.store(out_ptr3 + (r0_1 + 4096*x0), tmp21, r0_mask & xmask) - tl.store(out_ptr4 + (r0_1 + 4096*x0), tmp21, r0_mask & xmask) + tl.store(out_ptr3 + (r0_1 + 4096*x0), tmp21, r0_mask & xmask) <---- quantized values + tl.store(out_ptr4 + (r0_1 + 4096*x0), tmp21, r0_mask & xmask) <---- quantized values """, device_str="cuda", ) @@ -210,7 +210,7 @@ def triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_1(in_ptr0 tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) tmp4 = triton_helpers.maximum(_tmp3, tmp2) _tmp3 = tl.where(r0_mask, tmp4, _tmp3) - tmp3 = triton_helpers.max2(_tmp3, 1)[:, None] + tmp3 = triton_helpers.max2(_tmp3, 1)[:, None] # <--- max along dim 1 for r0_offset in range(0, r0_numel, R0_BLOCK): r0_index = r0_offset + r0_base r0_mask = r0_index < r0_numel @@ -223,27 +223,27 @@ def triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_1(in_ptr0 tmp8 = tl.full([1, 1], 1e-12, tl.float64) tmp9 = 
triton_helpers.maximum(tmp7, tmp8) tmp10 = tl.full([1, 1], 1, tl.int32) - tmp11 = tmp10 / tmp9 + tmp11 = tmp10 / tmp9 <------ scale tmp12 = tl.full([1, 1], 448.0, tl.float64) tmp13 = tmp11 * tmp12 - tmp14 = tmp13.to(tl.float32) + tmp14 = tmp13.to(tl.float32) <--- apply scale tmp15 = tmp6 * tmp14 tmp16 = -448.0 - tmp17 = triton_helpers.maximum(tmp15, tmp16) + tmp17 = triton_helpers.maximum(tmp15, tmp16) <-- clamp tmp18 = 448.0 - tmp19 = triton_helpers.minimum(tmp17, tmp18) - tmp20 = tmp19.to(tl.float8e4nv) - tl.store(out_ptr1 + (r0_1 + 4096*x0), tmp20, r0_mask) + tmp19 = triton_helpers.minimum(tmp17, tmp18) <-- clamp + tmp20 = tmp19.to(tl.float8e4nv) <--- convert dtype + tl.store(out_ptr1 + (r0_1 + 4096*x0), tmp20, r0_mask) <----- quantized values tmp21 = tmp3.to(tl.float64) - tmp22 = tl.full([1, 1], 1e-12, tl.float64) - tmp23 = triton_helpers.maximum(tmp21, tmp22) + tmp22 = tl.full([1, 1], 1e-12, tl.float64) <--- EPS + tmp23 = triton_helpers.maximum(tmp21, tmp22) <--- apply min EPS tmp24 = tl.full([1, 1], 1, tl.int32) - tmp25 = tmp24 / tmp23 - tmp26 = tl.full([1, 1], 448.0, tl.float64) - tmp27 = tmp25 * tmp26 - tmp28 = tmp27.to(tl.float32) - tmp29 = tmp24 / tmp28 - tl.store(out_ptr2 + (x0), tmp29, None) + tmp25 = tmp24 / tmp23 <---- reciprocal of clamped w/ min EPS + tmp26 = tl.full([1, 1], 448.0, tl.float64) <--- max for fp8 dtype (448) + tmp27 = tmp25 * tmp26 <--- dtype max / clamped w/ EPS + tmp28 = tmp27.to(tl.float32) + tmp29 = tmp24 / tmp28 <--- reciprocal of scale + tl.store(out_ptr2 + (x0), tmp29, None) <--- return reciprocal of scale """, device_str="cuda", ) @@ -311,7 +311,7 @@ def triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_2(in_ptr0 tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) tmp4 = triton_helpers.maximum(_tmp3, tmp2) _tmp3 = tl.where(r0_mask, tmp4, _tmp3) - tl.store(out_ptr0 + (r0_1 + 4096*x0), tmp1, r0_mask) + tl.store(out_ptr0 + (r0_1 + 4096*x0), tmp1, r0_mask) <---- full tensor abs values? 
tmp3 = triton_helpers.max2(_tmp3, 1)[:, None] for r0_offset in range(0, r0_numel, R0_BLOCK): r0_index = r0_offset + r0_base @@ -335,7 +335,7 @@ def triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_2(in_ptr0 tmp18 = 448.0 tmp19 = triton_helpers.minimum(tmp17, tmp18) tmp20 = tmp19.to(tl.float8e4nv) - tl.store(out_ptr2 + (r0_1 + 4096*x0), tmp20, r0_mask) + tl.store(out_ptr2 + (r0_1 + 4096*x0), tmp20, r0_mask) <--- quantized values tmp21 = tmp3.to(tl.float64) tmp22 = tl.full([1, 1], 1e-12, tl.float64) tmp23 = triton_helpers.maximum(tmp21, tmp22) @@ -345,7 +345,7 @@ def triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_2(in_ptr0 tmp27 = tmp25 * tmp26 tmp28 = tmp27.to(tl.float32) tmp29 = tmp24 / tmp28 - tl.store(out_ptr3 + (x0), tmp29, None) + tl.store(out_ptr3 + (x0), tmp29, None) <---- scale """, device_str="cuda", ) @@ -687,10 +687,10 @@ def triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_6(in_ptr0 def call(args): primals_1, primals_2, primals_3, primals_4 = args args.clear() - assert_size_stride(primals_1, (1, 16, 4096), (65536, 4096, 1)) - assert_size_stride(primals_2, (16384, 4096), (4096, 1)) - assert_size_stride(primals_3, (16384, 4096), (4096, 1)) - assert_size_stride(primals_4, (4096, 16384), (16384, 1)) + assert_size_stride(primals_1, (1, 16, 4096), (65536, 4096, 1)) # input + assert_size_stride(primals_2, (16384, 4096), (4096, 1)) # w1 + assert_size_stride(primals_3, (16384, 4096), (4096, 1)) # w3 + assert_size_stride(primals_4, (4096, 16384), (16384, 1)) # w2 with torch.cuda._DeviceGuard(0): torch.cuda.set_device(0) buf0 = empty_strided_cuda((1, 16, 4096), (65536, 4096, 1), torch.bfloat16) @@ -705,9 +705,18 @@ def call(args): # total bytes: 262,208 # Topologically Sorted Source Nodes: [output, output_1], Original ATen: [aten.abs, aten.amax, aten._to_copy, aten.clamp, aten.reciprocal, aten.mul, aten._scaled_mm] + # INPUTS -> get amax, scales, and quantized values stream0 = get_raw_stream(0) 
triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_0.run( - primals_1, buf0, buf3, buf4, buf10, 16, 4096, grid=grid(16), stream=stream0 + primals_1, + buf0, # amaxes + buf3, # scales + buf4, # quantized values + buf10, # copy of quantized values + 16, + 4096, + grid=grid(16), + stream=stream0, ) buf5 = empty_strided_cuda((4096, 16384), (1, 4096), torch.float8_e4m3fn) buf6 = empty_strided_cuda((1, 16384), (16384, 1), torch.float32) @@ -716,9 +725,16 @@ def call(args): # total bytes: 67,436,608 # Topologically Sorted Source Nodes: [output], Original ATen: [aten.abs, aten.amax, aten._to_copy, aten.clamp, aten.reciprocal, aten.mul, aten._scaled_mm] + # W1 -> get quantized values and scale (or reciprocal of scale?) stream0 = get_raw_stream(0) triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_1.run( - primals_2, buf5, buf6, 16384, 4096, grid=grid(16384), stream=stream0 + primals_2, + buf5, # output quantized values + buf6, # output scales + 16384, + 4096, + grid=grid(16384), + stream=stream0, ) del primals_2 buf7 = empty_strided_cuda((16, 16384), (16384, 1), torch.bfloat16) @@ -728,14 +744,15 @@ def call(args): # total bytes: 67,960,896 # Topologically Sorted Source Nodes: [output], Original ATen: [aten._to_copy, aten.clamp, aten.reciprocal, aten.mul, aten._scaled_mm] + # SCALED_MM WITH FP8 INPUTS AND W1 extern_kernels._scaled_mm( - buf4, - buf5, - buf3, - buf6, + buf4, # quantized inputs + buf5, # quantized W1 + buf3, # input scales + buf6, # w1 scales out_dtype=torch.bfloat16, use_fast_accum=True, - out=buf7, + out=buf7, # w1(x) ) del buf4 @@ -772,9 +789,17 @@ def call(args): # total bytes: 202,113,088 # Topologically Sorted Source Nodes: [output_1], Original ATen: [aten.abs, aten.amax, aten._to_copy, aten.clamp, aten.reciprocal, aten.mul, aten._scaled_mm] + # W3 -> get quantized values and scale stream0 = get_raw_stream(0) triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_2.run( - primals_3, buf8, buf11, buf12, 16384, 
4096, grid=grid(16384), stream=stream0 + primals_3, # W3 + buf8, # full W3 abs values?! + buf11, # quantized values + buf12, # scales + 16384, + 4096, + grid=grid(16384), + stream=stream0, ) buf13 = empty_strided_cuda((16, 16384), (16384, 1), torch.bfloat16) @@ -784,13 +809,13 @@ def call(args): # Topologically Sorted Source Nodes: [output_1], Original ATen: [aten._to_copy, aten.clamp, aten.reciprocal, aten.mul, aten._scaled_mm] extern_kernels._scaled_mm( - buf10, - buf11, - buf3, - buf12, + buf10, # quantized inputs + buf11, # quantized W3 + buf3, # input scales + buf12, # W3 scales out_dtype=torch.bfloat16, use_fast_accum=True, - out=buf13, + out=buf13, # W3(x) ) del buf10 del buf12 @@ -900,7 +925,7 @@ def call(args): buf0, # buf0:bf16=(1, 16, 4096) => small buf3, # buf3:fp32=(16, 1) => small buf7, # buf7:bf16=(16, 16384) => small - buf8, # buf8:bf16=(4096, 16384) => huge, what is this? used in triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_2 + buf8, # buf8:bf16=(4096, 16384) => huge, this is the full abs(W3) ) # RETURNS (save for backward): buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf7:bf16=(16, 16384) + buf8:bf16=(4096, 16384) = 134,873,152 bytes From 679381dad1bf994e70dbc66688e8bcff23486a12 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Tue, 4 Feb 2025 09:27:51 -0800 Subject: [PATCH 13/14] more kernel annotations --- logs/ffn_bf16_fwd.py | 12 ++++++------ logs/ffn_fp8_fwd.py | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/logs/ffn_bf16_fwd.py b/logs/ffn_bf16_fwd.py index 20dda9de0f..236714b404 100644 --- a/logs/ffn_bf16_fwd.py +++ b/logs/ffn_bf16_fwd.py @@ -96,10 +96,10 @@ def triton_poi_fused_mul_silu_0(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.conste def call(args): primals_1, primals_2, primals_3, primals_4 = args args.clear() - assert_size_stride(primals_1, (1, 16, 4096), (65536, 4096, 1)) - assert_size_stride(primals_2, (16384, 4096), (4096, 1)) - assert_size_stride(primals_3, (16384, 
4096), (4096, 1)) - assert_size_stride(primals_4, (4096, 16384), (16384, 1)) + assert_size_stride(primals_1, (1, 16, 4096), (65536, 4096, 1)) # input + assert_size_stride(primals_2, (16384, 4096), (4096, 1)) # w1 + assert_size_stride(primals_3, (16384, 4096), (4096, 1)) # w3 + assert_size_stride(primals_4, (4096, 16384), (16384, 1)) # w2 with torch.cuda._DeviceGuard(1): torch.cuda.set_device(1) buf0 = empty_strided_cuda((16, 16384), (16384, 1), torch.bfloat16) @@ -144,11 +144,11 @@ def call(args): # PEAK EXTRA MEM was line 134: (16*16384) + (1*16*16384) + (16*4096) = 589824 * 2 bytes for bf16 = 655360 bytes return ( - reinterpret_tensor(buf3, (1, 16, 4096), (65536, 4096, 1), 0), + reinterpret_tensor(buf3, (1, 16, 4096), (65536, 4096, 1), 0), # FFN output primals_1, primals_3, primals_4, - buf0, + buf0, # w1(x) ) # RETURNS (save for backward?) only buf0 and buf3 => buf0=(16*16384) + buf3=(1*16*4096) * 2 bytes for bf16 = 393216 bytes diff --git a/logs/ffn_fp8_fwd.py b/logs/ffn_fp8_fwd.py index a77a9ffb74..74d62b8535 100644 --- a/logs/ffn_fp8_fwd.py +++ b/logs/ffn_fp8_fwd.py @@ -115,7 +115,7 @@ def triton_red_fused__scaled_mm__to_copy_abs_amax_clamp_mul_reciprocal_0(in_ptr0 tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) tmp4 = triton_helpers.maximum(_tmp3, tmp2) _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) - tl.store(out_ptr0 + (r0_1 + 4096*x0), tmp1, r0_mask & xmask) <------ amaxes + tl.store(out_ptr0 + (r0_1 + 4096*x0), tmp1, r0_mask & xmask) <------ full input abs? 
tmp3 = triton_helpers.max2(_tmp3, 1)[:, None] tmp5 = tmp3.to(tl.float64) tmp6 = tl.full([1, 1], 1e-12, tl.float64) @@ -918,14 +918,14 @@ def call(args): return ( reinterpret_tensor( buf21, (1, 16, 4096), (65536, 4096, 1), 0 - ), # buf21:bf16=(16, 4096) => small + ), # buf21:bf16=(16, 4096) => small (FFN output) primals_1, primals_3, primals_4, - buf0, # buf0:bf16=(1, 16, 4096) => small - buf3, # buf3:fp32=(16, 1) => small - buf7, # buf7:bf16=(16, 16384) => small - buf8, # buf8:bf16=(4096, 16384) => huge, this is the full abs(W3) + buf0, # buf0:bf16=(1, 16, 4096) => abs(input) => small + buf3, # buf3:fp32=(16, 1) => rowwise scales for inputs + buf7, # buf7:bf16=(16, 16384) => W1(x) => small + buf8, # buf8:bf16=(4096, 16384) => abs(W3) => huge ) # RETURNS (save for backward): buf0:bf16=(1, 16, 4096) + buf3:fp32=(16, 1) + buf7:bf16=(16, 16384) + buf8:bf16=(4096, 16384) = 134,873,152 bytes From 946a733a1c4cbf019132327b6f28ba4679f7941e Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Tue, 4 Feb 2025 09:31:54 -0800 Subject: [PATCH 14/14] mem calculations --- logs/ffn_bf16_fwd.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/logs/ffn_bf16_fwd.py b/logs/ffn_bf16_fwd.py index 236714b404..ca1a871afc 100644 --- a/logs/ffn_bf16_fwd.py +++ b/logs/ffn_bf16_fwd.py @@ -142,16 +142,18 @@ def call(args): del buf2 # EXTRA MEM: buf0=(16*16384) + buf3=(16*4096) - # PEAK EXTRA MEM was line 134: (16*16384) + (1*16*16384) + (16*4096) = 589824 * 2 bytes for bf16 = 655360 bytes + # PEAK EXTRA MEM was line 134: (16*16384) + (1*16*16384) + (16*4096) = 589824 * 2 bytes for bf16 = 1,179,648 bytes return ( - reinterpret_tensor(buf3, (1, 16, 4096), (65536, 4096, 1), 0), # FFN output + reinterpret_tensor( + buf3, (1, 16, 4096), (65536, 4096, 1), 0 + ), # FFN output => (1,16,4096) in bf16 = 131,072 bytes primals_1, primals_3, primals_4, - buf0, # w1(x) + buf0, # w1(x) => (16,16384) in bf16 = 524,288 bytes ) - # RETURNS (save for backward?)
only buf0 and buf3 => buf0=(16*16384) + buf3=(1*16*4096) * 2 bytes for bf16 = 393216 bytes + # RETURNS (save for backward?) only buf0 and buf3 => (buf0=(16*16384) + buf3=(1*16*4096)) * 2 bytes for bf16 = 655,360 bytes def benchmark_compiled_module(times=10, repeat=10):