## Imports

In [35]:
import torch
import flux_triton.modules.layer_norm as triton_layer_norm
import flux.modules.layers as torch_layers
import functools

## Setup Tensors

In [28]:
device = "cuda" if torch.cuda.is_available() else "cpu"
SHAPE = [1, 256]
x = torch.randn(SHAPE).to(device)


## Construct Layers

In [29]:
torch_ln = torch_layers.LayerNorm(SHAPE).to(device)
liger_ln = triton_layer_norm.LigerLayerNorm(SHAPE)

In [34]:
triton_layer_norm.LigerLayerNorm

flux_triton.modules.layer_norm.LigerLayerNorm

## Profile

In [40]:
# Change these values to where you want to save
LAYER_NAME = "LN"

TORCH_PROFILE_DIR = f"./tmp_log/{LAYER_NAME}"
LIGER_PROFILE_DIR = f"./tmp_log/{LAYER_NAME}"

In [41]:
def run_profile(
    x: torch.Tensor,
    torch_layer: torch.nn.Module,
    liger_layer: torch.nn.Module,
):
    profile_fn = functools.partial(torch.profiler.profile, 
        with_stack=True,
        profile_memory=True,
        with_flops=True,
        use_cuda=True,
        record_shapes=True
    )
    with profile_fn(on_trace_ready=torch.profiler.tensorboard_trace_handler(TORCH_PROFILE_DIR)):
        _ = torch_layer(x)
    with profile_fn(on_trace_ready=torch.profiler.tensorboard_trace_handler(LIGER_PROFILE_DIR)):
        _ = liger_layer(x)

In [42]:
run_profile(
    x = x,
    torch_layer = torch_ln,
    liger_layer = liger_ln
)

  with profile_fn(on_trace_ready=torch.profiler.tensorboard_trace_handler(TORCH_PROFILE_DIR)):
  with profile_fn(on_trace_ready=torch.profiler.tensorboard_trace_handler(LIGER_PROFILE_DIR)):
