In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os 
# Walk the Kaggle input mount and echo the full path of every file found.
for root, _, names in os.walk('/kaggle/input'):
    for path in (os.path.join(root, name) for name in names):
        print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [5]:
import torch
import triton
import triton.language as tl

In [6]:
# CUDA device handle (no explicit device index).
# NOTE(review): the other visible cells pass device='cuda' directly instead of using this constant.
DEVICE = torch.device('cuda')


In [26]:
@triton.jit
def fused_relu_kernel(x_ptr, bias_ptr, output_ptr, M: int, N: int, BLOCK_SIZE: tl.constexpr):
    """Compute ``output = max(x + bias, 0)`` for one row per program instance.

    The program selected by ``tl.program_id(0)`` processes the row that starts
    ``row * N`` elements into ``x_ptr`` / ``output_ptr`` and walks it in tiles
    of ``BLOCK_SIZE`` columns. ``bias_ptr`` is indexed by column only, so the
    same bias values are added to every row.

    Args:
        x_ptr: Pointer to the input; consecutive rows are addressed ``N``
            elements apart with unit column stride.
        bias_ptr: Pointer to the bias; read at column offsets ``0..N-1``.
        output_ptr: Pointer to the output, addressed exactly like ``x_ptr``.
        M: Number of rows. Not read inside the kernel; the launch grid
            decides which rows are processed.
        N: Number of columns per row.
        BLOCK_SIZE: Compile-time tile width along the column axis.
    """
    # Promote the row index to 64-bit before scaling by N so the element
    # offset cannot wrap around when M * N exceeds the 32-bit range.
    row = tl.program_id(0).to(tl.int64)
    row_start = row * N
    x_ptr += row_start
    output_ptr += row_start

    for offset in range(0, N, BLOCK_SIZE):
        cols = offset + tl.arange(0, BLOCK_SIZE)
        # The last tile may run past the end of the row; mask those lanes off.
        in_row = cols < N
        x_tile = tl.load(x_ptr + cols, mask=in_row, other=0.0)
        bias_tile = tl.load(bias_ptr + cols, mask=in_row, other=0.0)
        pre_act = x_tile + bias_tile
        out = tl.maximum(pre_act, 0.0)
        tl.store(output_ptr + cols, out, mask=in_row)
        


def triton_relu(x:torch.Tensor,bias:torch.Tensor):
    """Return ``relu(x + bias)`` computed by ``fused_relu_kernel``.

    Args:
        x: 2-D tensor of shape ``(M, N)``. Unpacking ``x.shape`` raises
            ``ValueError`` for any other rank.
        bias: Tensor with exactly ``N`` elements, added to every row of ``x``.

    Returns:
        A new contiguous tensor with the shape and dtype of ``x``.

    Raises:
        ValueError: If ``bias`` does not hold exactly ``N`` elements.
    """
    M, N = x.shape
    if bias.numel() != N:
        # The kernel reads bias at column offsets 0..N-1 without any other
        # bound, so a short bias would be read past its end.
        raise ValueError(
            f"bias must have {N} elements to match x of shape {tuple(x.shape)}, "
            f"got {bias.numel()}"
        )
    if x.numel() == 0:
        # Nothing to compute; also avoids a zero-sized launch grid (M == 0)
        # and a zero BLOCK_SIZE (N == 0).
        return torch.empty_like(x)

    # The kernel addresses row r at element offset r * N with unit column
    # stride, so every buffer handed to it must be densely packed.
    x = x.contiguous()
    bias = bias.contiguous()
    # Every in-range element is written by the kernel, so no zero-fill is needed.
    y = torch.empty_like(x)

    # Tile width: next power of two covering a row, capped at 256 (the same
    # rule triton_relu_backward uses).
    BLOCK_SIZE = min(256, triton.next_power_of_2(N))
    num_warps = 8
    fused_relu_kernel[(M,)](
        x, bias, y, M, N, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
    )
    return y


@triton.jit
def fused_relu_backward(x_ptr, bias_ptr, grad_output_ptr, grad_x_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    """Backward of ``relu(x + bias)`` with respect to ``x``, one row per program.

    For every element, writes ``grad_output`` where ``x + bias > 0`` and zero
    elsewhere. The program selected by ``tl.program_id(0)`` handles the row
    that starts ``row * N`` elements into ``x_ptr``, ``grad_output_ptr`` and
    ``grad_x_ptr``; ``bias_ptr`` is indexed by column only.

    Args:
        x_ptr: Pointer to the forward input; rows addressed ``N`` elements apart.
        bias_ptr: Pointer to the bias; read at column offsets ``0..N-1``.
        grad_output_ptr: Pointer to the upstream gradient, addressed like ``x_ptr``.
        grad_x_ptr: Pointer to the result buffer, addressed like ``x_ptr``.
        M: Number of rows. Not read inside the kernel; the launch grid
            decides which rows are processed.
        N: Number of columns per row.
        BLOCK_SIZE: Compile-time tile width along the column axis.
    """
    # Promote the row index to 64-bit before scaling by N so the element
    # offset cannot wrap around when M * N exceeds the 32-bit range.
    row = tl.program_id(0).to(tl.int64)
    row_start = row * N
    x_ptr += row_start
    grad_output_ptr += row_start
    grad_x_ptr += row_start

    for offset in range(0, N, BLOCK_SIZE):
        cols = offset + tl.arange(0, BLOCK_SIZE)
        # The last tile may run past the end of the row; mask those lanes off.
        in_row = cols < N
        x_tile = tl.load(x_ptr + cols, mask=in_row, other=0.0)
        grad_out_tile = tl.load(grad_output_ptr + cols, mask=in_row, other=0.0)
        bias_tile = tl.load(bias_ptr + cols, mask=in_row, other=0.0)
        active = (x_tile + bias_tile) > 0
        # Select rather than multiply by the boolean mask: a product would
        # turn an Inf/NaN upstream gradient at an inactive position into NaN
        # instead of the zero that ReLU's derivative prescribes there.
        grad_x_tile = tl.where(active, grad_out_tile, 0.0)
        tl.store(grad_x_ptr + cols, grad_x_tile, mask=in_row)
        
        
    
    


def triton_relu_backward(x,bias,grad_output):
    """Gradient of ``relu(x + bias)`` with respect to ``x`` via ``fused_relu_backward``.

    Args:
        x: Forward input of shape ``(M, N)``.
        bias: Bias with exactly ``N`` elements.
        grad_output: Upstream gradient with the same shape as ``x``. It may
            be non-contiguous (for example an expanded tensor); it is packed
            before the kernel launch.

    Returns:
        A new contiguous tensor with the shape and dtype of ``x`` holding
        ``grad_output`` where ``x + bias > 0`` and zero elsewhere.

    Raises:
        ValueError: If ``grad_output`` and ``x`` differ in shape, or ``bias``
            does not hold exactly ``N`` elements.
    """
    M, N = x.shape
    if grad_output.shape != x.shape:
        raise ValueError(
            f"grad_output shape {tuple(grad_output.shape)} does not match "
            f"x shape {tuple(x.shape)}"
        )
    if bias.numel() != N:
        raise ValueError(
            f"bias must have {N} elements to match x of shape {tuple(x.shape)}, "
            f"got {bias.numel()}"
        )
    if x.numel() == 0:
        # Nothing to compute; also avoids a zero-sized launch grid (M == 0)
        # and a zero BLOCK_SIZE (N == 0).
        return torch.empty_like(x)

    # The kernel addresses row r at element offset r * N with unit column
    # stride, so every buffer handed to it must be densely packed.
    x = x.contiguous()
    bias = bias.contiguous()
    grad_output = grad_output.contiguous()
    grad_x = torch.empty_like(x)

    # Tile width: next power of two covering a row, capped at 256.
    BLOCK_SIZE = min(256, triton.next_power_of_2(N))
    num_warps = 8
    fused_relu_backward[(M,)](
        x, bias, grad_output, grad_x, M, N, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
    )
    return grad_x
    
    



class FusedReluFunction(torch.autograd.Function):
    """Autograd wrapper around the fused bias-add + ReLU Triton routines.

    ``forward`` returns ``triton_relu(x, bias)`` and saves both inputs;
    ``backward`` obtains the masked gradient from ``triton_relu_backward``
    and reduces it over the row axis to form the bias gradient.
    """

    @staticmethod
    def forward(ctx, x, bias):
        """Compute the fused output and stash ``x`` and ``bias`` for backward."""
        ctx.save_for_backward(x, bias)
        return triton_relu(x, bias)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(grad_x, grad_bias)``; an entry is ``None`` when that input needs no gradient."""
        x, bias = ctx.saved_tensors
        # Gradient w.r.t. the pre-activation (x + bias); both input gradients
        # are derived from it, so it is always computed.
        grad_pre_act = triton_relu_backward(x, bias, grad_output)
        needs_x, needs_bias = ctx.needs_input_grad[:2]

        grad_x = grad_pre_act if needs_x else None
        grad_bias = None
        if needs_bias:
            # bias is broadcast across rows, so its gradient is the column-wise
            # sum; reshape so the result matches bias's saved shape exactly.
            grad_bias = grad_pre_act.sum(dim=0).reshape(bias.shape)
        return grad_x, grad_bias

def fused_relu(x, bias):
    """Apply the fused bias-add + ReLU through ``FusedReluFunction``.

    Thin functional entry point: forwards ``x`` and ``bias`` unchanged to
    ``FusedReluFunction.apply`` and returns whatever it produces.
    """
    result = FusedReluFunction.apply(x, bias)
    return result

In [27]:
def test_relu(M,N,atol=1e-5):
    """Check ``triton_relu`` against ``torch.nn.functional.relu(x + bias)``.

    Args:
        M: Number of rows of the random float32 input created on 'cuda'.
        N: Number of columns of the input and length of the random bias.
        atol: Absolute tolerance forwarded to ``triton.testing.assert_close``
            (the relative tolerance stays fixed at 1e-5).
    """
    x = torch.randn(M, N, dtype=torch.float32, device='cuda')
    bias = torch.randn(N, dtype=torch.float32, device='cuda')
    torch_out = torch.nn.functional.relu(x + bias)
    tri_out = triton_relu(x, bias)
    # Forward the caller's tolerance; it was previously accepted but unused.
    triton.testing.assert_close(torch_out, tri_out, atol=atol, rtol=1e-5)

def test_backward(M=4, N=100, atol=1e-5, rtol=1e-5):
    """Compare ``fused_relu`` with ``torch.relu(x + bias)`` in forward and backward.

    Random float32 inputs are created on 'cuda'. The Triton path and the
    PyTorch path each get their own leaf tensors holding the same values, and
    both are back-propagated with the same random upstream gradient.

    Args:
        M: Number of rows of the input (default 4).
        N: Number of columns of the input and length of the bias (default 100).
        atol: Absolute tolerance for every comparison.
        rtol: Relative tolerance for every comparison.
    """
    x = torch.randn(M, N, dtype=torch.float32, device='cuda', requires_grad=True)
    bias = torch.randn(N, dtype=torch.float32, device='cuda', requires_grad=True)
    # Independent leaves with identical values so each path accumulates its
    # gradients separately.
    x_ref = x.clone().detach().requires_grad_(True)
    bias_ref = bias.clone().detach().requires_grad_(True)

    y_triton = fused_relu(x, bias)
    y_torch = torch.relu(x_ref + bias_ref)
    # Check the forward values as well, so a gradient mismatch is not caused
    # by differing outputs.
    triton.testing.assert_close(y_triton.detach(), y_torch.detach(), atol=atol, rtol=rtol)

    grad_output = torch.randn_like(y_triton)
    y_triton.backward(grad_output)
    y_torch.backward(grad_output)

    triton.testing.assert_close(x.grad, x_ref.grad, atol=atol, rtol=rtol)
    triton.testing.assert_close(bias.grad, bias_ref.grad, atol=atol, rtol=rtol)


In [28]:
# Run the backward-pass comparison against PyTorch autograd; it allocates its tensors on 'cuda'.
test_backward()