## Problem statement

Your task is to speed up the `forward pass` of the `swiglu` function:
```python
def swiglu(a, b):
    return torch.nn.functional.silu(a) * b
```

Your implementation must be written in `Triton`, be called via the interface
`torch.ops.llm_scaling_week.swiglu_fwd(a, b)` and return only a single
`torch.Tensor` – the result of the operation.

The implementation will be checked for correctness and performance.
To pass the correctness test, the result of your function must match the output of
the eager `swiglu` implementation under `torch.allclose`.
To pass the performance test, your function must take **≤ 75%** of the time of the
eager `swiglu` implementation.

**It is guaranteed that the inputs will be `contiguous` tensors. The types and shapes
of the input tensors `a` and `b` are the same.**
Note that the function must work with both `fp32` and `bf16` tensors.
The function must work efficiently regardless of the tensor shape.

The reference solution passes all tests both on H100 and in Google Colab.

## Note

You can view the test logs by downloading the output of test 1 on the contest website.
Do not rename the `solution.py` file. Your solution must be in this file.



In [1]:
!pip install -q triton

import torch
# Report whether PyTorch can see a CUDA device in this runtime.
print("CUDA available:", torch.cuda.is_available())

CUDA available: True


In [2]:
# First try

# import torch
# import triton
# import triton.language as tl
# from torch.library import Library

# _lib = Library("llm_scaling_week", "DEF")
# _lib.define("swiglu_fwd(Tensor a, Tensor b) -> Tensor") 

# @triton.jit
# def swiglu_kernel(a_ptr, b_ptr, out_ptr, n_elements: tl.constexpr, BLOCK_SIZE: tl.constexpr):
#     pid = tl.program_id(axis=0)                       
#     offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 
#     mask = offs < n_elements                           
    
#     a_vals = tl.load(a_ptr + offs, mask=mask, other=0.0)
#     b_vals = tl.load(b_ptr + offs, mask=mask, other=0.0)
    
#     a_vals_f32 = a_vals.to(tl.float32)
#     b_vals_f32 = b_vals.to(tl.float32)
#     silu_a = a_vals_f32 * tl.sigmoid(a_vals_f32)
#     out_f32 = silu_a * b_vals_f32
    
#     out_vals = out_f32.to(a_vals.dtype)  
#     tl.store(out_ptr + offs, out_vals, mask=mask)     
    

# def _swiglu_fwd_impl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
#     if not a.is_cuda or not b.is_cuda:
#         raise RuntimeError("swiglu_fwd: both inputs must be CUDA tensors")
#     if a.dtype not in (torch.float32, torch.bfloat16) or b.dtype not in (torch.float32, torch.bfloat16):
#         raise RuntimeError("swiglu_fwd: inputs must be float32 or bfloat16 tensors")
#     if a.dtype != b.dtype:
#         raise RuntimeError("swiglu_fwd: inputs must have the same dtype")
#     if a.shape != b.shape:
#         raise RuntimeError("swiglu_fwd: inputs must have the same shape")
#     if not a.is_contiguous():
#         a = a.contiguous()
#     if not b.is_contiguous():
#         b = b.contiguous()
#     out = torch.empty_like(a)
    
#     n_elements = a.numel()
#     grid = (triton.cdiv(n_elements, 256),) 
#     swiglu_kernel[grid](a, b, out, n_elements, BLOCK_SIZE=256, num_warps=4)
#     return out

# _impl_lib = Library("llm_scaling_week", "IMPL")
# _impl_lib.impl("swiglu_fwd", _swiglu_fwd_impl, "CUDA")


import torch
import triton
import triton.language as tl
from torch.library import Library

# Declare the schema of `swiglu_fwd` in the "llm_scaling_week" operator namespace.
# A RuntimeError from either statement is deliberately ignored -- presumably so
# this cell can be re-run in a session where the namespace was already defined
# (NOTE(review): confirm that this is the only RuntimeError expected here).
try:
    _def_lib = Library("llm_scaling_week", "DEF")
    _def_lib.define("swiglu_fwd(Tensor a, Tensor b) -> Tensor")
except RuntimeError:
    pass


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 512}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=8),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=8),
    ],
    # The best configuration is selected (and cached) per distinct element count.
    key=["n_elements"],
)
@triton.jit
def swiglu_kernel(a_ptr, b_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Element-wise SwiGLU forward: ``out[i] = a[i] * sigmoid(a[i]) * b[i]``.

    The three buffers are addressed as flat arrays of ``n_elements`` items.
    Each program instance processes ``BLOCK_SIZE`` consecutive elements and
    masks off positions at or beyond ``n_elements``.

    Args:
        a_ptr: pointer to the gate input.
        b_ptr: pointer to the value input.
        out_ptr: pointer to the output buffer.
        n_elements: number of elements to process.
        BLOCK_SIZE: compile-time number of elements per program instance.
    """
    # Widen the program id before multiplying: in 32-bit arithmetic
    # `pid * BLOCK_SIZE` wraps around once the tensor reaches 2**31 elements.
    pid = tl.program_id(0).to(tl.int64)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements

    a_vals = tl.load(a_ptr + offs, mask=mask, other=0.0)
    b_vals = tl.load(b_ptr + offs, mask=mask, other=0.0)

    # Do the arithmetic in float32 regardless of the storage dtype.
    a_f32 = a_vals.to(tl.float32)
    b_f32 = b_vals.to(tl.float32)

    silu = a_f32 * tl.sigmoid(a_f32)
    out_f32 = silu * b_f32

    # Cast back to the dtype the inputs were loaded with before storing.
    out_vals = out_f32.to(a_vals.dtype)
    tl.store(out_ptr + offs, out_vals, mask=mask)


def _swiglu_fwd_cuda(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """CUDA implementation of ``swiglu_fwd``: returns ``silu(a) * b``.

    Args:
        a: gate tensor (float32 or bfloat16) on a CUDA device.
        b: value tensor with the same shape, dtype and device as ``a``.

    Returns:
        A new tensor with the shape and dtype of ``a``.

    Raises:
        RuntimeError: if an input is not on CUDA, if dtypes, shapes or devices
            differ, or if the dtype is neither float32 nor bfloat16.
    """
    if not (a.is_cuda and b.is_cuda):
        raise RuntimeError("swiglu_fwd: CUDA only")
    if a.dtype != b.dtype:
        raise RuntimeError("swiglu_fwd: dtypes must match")
    if a.shape != b.shape:
        raise RuntimeError("swiglu_fwd: shapes must match")
    if a.dtype not in (torch.float32, torch.bfloat16):
        raise RuntimeError("swiglu_fwd: supports only float32 and bfloat16")
    if a.device != b.device:
        # The kernel dereferences raw pointers, so both buffers must be
        # reachable from the same GPU.
        raise RuntimeError("swiglu_fwd: inputs must be on the same device")

    # The kernel indexes the buffers linearly, so it needs dense layouts.
    if not a.is_contiguous():
        a = a.contiguous()
    if not b.is_contiguous():
        b = b.contiguous()

    out = torch.empty_like(a)
    n_elements = a.numel()
    if n_elements == 0:
        # Nothing to compute; skip autotuning and the launch entirely.
        return out

    def grid(meta):
        # One program per BLOCK_SIZE-sized chunk, rounded up.
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # Launch on the device that owns the tensors, which is not necessarily
    # the current CUDA device in a multi-GPU process.
    with torch.cuda.device(a.device):
        swiglu_kernel[grid](a, b, out, n_elements)

    return out


def _swiglu_fwd_cpu(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Eager implementation of ``swiglu_fwd``: ``silu(a) * b`` with PyTorch ops."""
    gate = torch.nn.functional.silu(a)
    return gate * b


# Attach one implementation per dispatch key to the `swiglu_fwd` operator.
_impl_lib = Library("llm_scaling_week", "IMPL")
for _dispatch_key, _kernel_fn in (
    ("CUDA", _swiglu_fwd_cuda),
    ("CPU", _swiglu_fwd_cpu),
):
    _impl_lib.impl("swiglu_fwd", _kernel_fn, _dispatch_key)

In [3]:
# Sanity check: compare the registered op with eager SwiGLU on random inputs.
device = "cuda"

for dtype in [torch.float32]:
    a = torch.randn((4096, 4096), dtype=dtype, device=device)
    b = torch.randn_like(a)

    out = torch.ops.llm_scaling_week.swiglu_fwd(a, b)
    ref = torch.nn.functional.silu(a) * b

    matches = torch.allclose(ref, out, rtol=1e-3, atol=1e-3)
    print(dtype, "allclose =", matches)


torch.float32 allclose = True


In [4]:
import time
def bench(fn, iters=100):
    """Time ``iters`` back-to-back calls of ``fn``.

    When CUDA is available the device is synchronized before the clock starts
    and again before it stops, so the measurement covers completed GPU work
    rather than just the kernel launches. Without CUDA the synchronization is
    skipped, so the helper also works for CPU-only callables.

    Args:
        fn: zero-argument callable to benchmark.
        iters: number of times ``fn`` is called.

    Returns:
        A ``(start, elapsed)`` tuple: ``start`` is the wall-clock timestamp
        (``time.time()``) taken when timing began, ``elapsed`` is the total
        duration of all calls in seconds.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()

    wall_start = time.time()
    # perf_counter is monotonic and high-resolution; time.time() is neither
    # guaranteed, which makes it a poor clock for measuring durations.
    tick = time.perf_counter()
    for _ in range(iters):
        fn()
    if use_cuda:
        torch.cuda.synchronize()
    return (wall_start, time.perf_counter() - tick)

# Benchmark inputs: one large float32 pair on the device selected earlier.
a = torch.randn((8192, 4096), dtype=torch.float32, device=device)
b = torch.randn_like(a)


def _run_eager():
    """Reference path: eager PyTorch SwiGLU."""
    return torch.nn.functional.silu(a) * b


def _run_op():
    """Path under test: the registered custom operator."""
    return torch.ops.llm_scaling_week.swiglu_fwd(a, b)


# Warm-up iterations, run before any timing is taken.
for _ in range(10):
    _run_eager()
    _run_op()

_, t_eager = bench(_run_eager)
_, t_triton = bench(_run_op)

print("eager :", t_eager)
print("ops   :", t_triton)
print("ratio :", t_triton / t_eager)

eager : 0.2787649631500244
ops   : 0.15988707542419434
ratio : 0.573555132673352


In [5]:
# Output

'''
============================= test session starts ==============================
platform linux -- Python 3.11.14, pytest-9.0.0, pluggy-1.6.0 -- /opt/conda/bin/python3.11
cachedir: .pytest_cache
hypothesis profile 'default'
rootdir: /workspace
plugins: hypothesis-6.141.0
collecting ... collected 16 items

tests.py::test_swiglu_fwd_quality_check[shape0-dtype0] PASSED            [  6%]
tests.py::test_swiglu_fwd_quality_check[shape0-dtype1] PASSED            [ 12%]
tests.py::test_swiglu_fwd_quality_check[shape1-dtype0] PASSED            [ 18%]
tests.py::test_swiglu_fwd_quality_check[shape1-dtype1] PASSED            [ 25%]
tests.py::test_swiglu_fwd_quality_check[shape2-dtype0] PASSED            [ 31%]
tests.py::test_swiglu_fwd_quality_check[shape2-dtype1] PASSED            [ 37%]
tests.py::test_swiglu_fwd_quality_check[shape3-dtype0] PASSED            [ 43%]
tests.py::test_swiglu_fwd_quality_check[shape3-dtype1] PASSED            [ 50%]
tests.py::test_swiglu_fwd_speed_vs_eager_check[shape0-dtype0] PASSED     [ 56%]
tests.py::test_swiglu_fwd_speed_vs_eager_check[shape0-dtype1] PASSED     [ 62%]
tests.py::test_swiglu_fwd_speed_vs_eager_check[shape1-dtype0] PASSED     [ 68%]
tests.py::test_swiglu_fwd_speed_vs_eager_check[shape1-dtype1] PASSED     [ 75%]
tests.py::test_swiglu_fwd_speed_vs_eager_check[shape2-dtype0] PASSED     [ 81%]
tests.py::test_swiglu_fwd_speed_vs_eager_check[shape2-dtype1] PASSED     [ 87%]
tests.py::test_swiglu_fwd_speed_vs_eager_check[shape3-dtype0] PASSED     [ 93%]
tests.py::test_swiglu_fwd_speed_vs_eager_check[shape3-dtype1] PASSED     [100%]

============================== 16 passed in 8.57s ==============================
'''

