In [3]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the path of every file found (recursively) under /kaggle/input.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        full_path = os.path.join(dirname, filename)
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [4]:
import triton.language as tl
import torch
import triton

In [5]:
def cross_entropy_ref(logits, targets):
    """Reference (unfused) cross-entropy loss, summed over rows.

    Args:
        logits: tensor of shape (seq_len, vocab_size) with unnormalised scores.
        targets: integer tensor of shape (seq_len,) giving the target column per row.

    Returns:
        0-d tensor equal to sum_i -log_softmax(logits, dim=-1)[i, targets[i]].
    """
    # log_softmax applies the log-sum-exp trick, so a large gap between the max logit
    # and the target logit does not underflow to probability 0. The previous
    # softmax -> -log(p + 1e-9) path silently clipped such losses at ~20.7.
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    target_log_probs = torch.gather(log_probs, 1, targets.unsqueeze(1))
    return -target_log_probs.sum()

@triton.jit
def indexed_matmul_kernel(E_ptr,C_ptr,target_ptr,output_ptr,embed_dim,seq_len,vocab_size,stride_e_seq_len,stride_e_embed_dim,
                         stride_c_embed_dim,stride_c_vocab_size,BLOCK_SIZE_N:tl.constexpr,BLOCK_SIZE_D:tl.constexpr):
    """Gathered dot products: output[i] = sum_d E[i, d] * C[d, targets[i]].

    Each program handles BLOCK_SIZE_N consecutive rows and walks the embedding
    dimension in chunks of BLOCK_SIZE_D, accumulating in float32.
    Rows whose target index lies outside [0, vocab_size) produce 0 instead of
    reading out of bounds of C.
    """
    pid_n = tl.program_id(0)
    start_id = BLOCK_SIZE_N * pid_n
    token_range_offsets = start_id + tl.arange(0, BLOCK_SIZE_N)
    mask_n = token_range_offsets < seq_len
    # Integer fill value for the integer targets buffer.
    target_indices = tl.load(target_ptr + token_range_offsets, mask=mask_n, other=0)
    # Guard every C access against out-of-range target indices.
    valid_target = mask_n & (target_indices >= 0) & (target_indices < vocab_size)
    # Widen indices before multiplying by strides so large tensors don't overflow int32 offsets.
    row_ids = token_range_offsets.to(tl.int64)
    target_cols = target_indices.to(tl.int64)
    acc = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32)
    for embed_id in range(tl.cdiv(embed_dim, BLOCK_SIZE_D)):
        embed_offsets = embed_id * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
        embed_mask = embed_offsets < embed_dim
        E_offsets = row_ids[:, None] * stride_e_seq_len + embed_offsets[None, :] * stride_e_embed_dim
        E_chunk = tl.load(E_ptr + E_offsets, mask=mask_n[:, None] & embed_mask[None, :], other=0.0)
        # Load the gathered C columns directly as a (token, embed) tile matching
        # E_chunk's layout, so no transpose is needed.
        C_offsets = target_cols[:, None] * stride_c_vocab_size + embed_offsets[None, :] * stride_c_embed_dim
        C_chunk = tl.load(C_ptr + C_offsets, mask=valid_target[:, None] & embed_mask[None, :], other=0.0)
        # Upcast before multiplying so fp16/bf16 inputs are reduced in float32.
        acc += tl.sum(E_chunk.to(tl.float32) * C_chunk.to(tl.float32), axis=1)
    tl.store(output_ptr + token_range_offsets, acc, mask=mask_n)

        
        


def cut_cross_entropy(E, C, targets, vocab_chunk_size=4096):
    """Summed cross-entropy of logits = E @ C without materialising the full logits matrix.

    loss = sum_i ( logsumexp_v (E @ C)[i, v]  -  (E @ C)[i, targets[i]] )

    Args:
        E: (seq_len, embed_dim) embeddings.
        C: (embed_dim, vocab_size) classifier / unembedding matrix.
        targets: (seq_len,) integer target indices, expected in [0, vocab_size).
        vocab_chunk_size: number of vocabulary columns whose logits are materialised
            at once while computing the log-sum-exp (extra memory ~ seq_len * vocab_chunk_size).

    Returns:
        (loss, target_logits): 0-d float32 summed loss, and the (seq_len,) target
        logits (dtype E.dtype) computed by indexed_matmul_kernel.

    Raises:
        ValueError: if the shapes of E, C and targets are inconsistent, or
            vocab_chunk_size is not positive.
    """
    seq_len, embed_dim = E.shape
    c_embed_dim, vocab_size = C.shape
    if c_embed_dim != embed_dim:
        raise ValueError(f"E has embed_dim={embed_dim} but C has {c_embed_dim} rows")
    if targets.shape != (seq_len,):
        raise ValueError(f"targets must have shape ({seq_len},), got {tuple(targets.shape)}")
    # The kernel reads targets with a unit stride.
    targets = targets.contiguous()

    target_logits = torch.empty(seq_len, device=E.device, dtype=E.dtype)
    if seq_len == 0:
        return torch.zeros((), device=E.device, dtype=torch.float32), target_logits

    BLOCK_SIZE_N = 128
    BLOCK_SIZE_D = 128
    grid = (triton.cdiv(seq_len, BLOCK_SIZE_N),)
    indexed_matmul_kernel[grid](
        E, C, targets, target_logits, embed_dim, seq_len, vocab_size,
        E.stride(0), E.stride(1), C.stride(0), C.stride(1),
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
    )

    # Both terms must be fully computed before combining them.
    lse = _chunked_logsumexp(E, C, vocab_chunk_size)
    loss = (lse - target_logits.float()).sum()
    return loss, target_logits


def _chunked_logsumexp(E, C, vocab_chunk_size):
    """Row-wise float32 logsumexp of E @ C, computed one vocabulary chunk at a time."""
    if vocab_chunk_size <= 0:
        raise ValueError(f"vocab_chunk_size must be positive, got {vocab_chunk_size}")
    lse = torch.full((E.shape[0],), float("-inf"), device=E.device, dtype=torch.float32)
    for start in range(0, C.shape[1], vocab_chunk_size):
        chunk_logits = (E @ C[:, start:start + vocab_chunk_size]).float()
        # Merge this chunk's logsumexp into the running value in a numerically stable way.
        lse = torch.logaddexp(lse, torch.logsumexp(chunk_logits, dim=-1))
    return lse
    



def test_cross_entropy_forward(seq_len=128, vocab_size=1024, embed_dim=256, seed=0):
    """Check cut_cross_entropy against the dense PyTorch reference on random CUDA data.

    Asserts that:
    - the Triton target logits match gathering from the dense logits ``E @ C``;
    - the summed loss matches ``cross_entropy_ref``.

    Returns:
        (logits, target_logits, targets): dense logits (seq_len, vocab_size),
        Triton-computed target logits (seq_len,), and the sampled targets (seq_len,).
    """
    # Dedicated generator: reproducible without touching the global RNG state.
    gen = torch.Generator(device='cuda').manual_seed(seed)
    E = torch.randn(seq_len, embed_dim, device='cuda', dtype=torch.float32, generator=gen)
    # Scale C so the logits have O(1) std. A raw randn @ randn product has std
    # sqrt(embed_dim), which is an unrealistic numeric range for a loss check.
    C = torch.randn(embed_dim, vocab_size, device='cuda', dtype=torch.float32, generator=gen) / embed_dim ** 0.5
    targets = torch.randint(0, vocab_size, (seq_len,), device='cuda', generator=gen)

    logits = E @ C
    ref_loss = cross_entropy_ref(logits, targets)
    cce_loss, target_logits = cut_cross_entropy(E, C, targets)

    ref_target_logits = torch.gather(logits, 1, targets.unsqueeze(1)).squeeze(1)
    torch.testing.assert_close(target_logits, ref_target_logits, rtol=1e-4, atol=1e-4)
    torch.testing.assert_close(cce_loss, ref_loss, rtol=1e-4, atol=1e-3)
    return logits, target_logits, targets

In [6]:
# Run the forward check once; the returned tensors stay available for later cells.
logits, logits_triton, targets = test_cross_entropy_forward()

In [7]:
# Dense reference: pick logits[i, targets[i]] for every row i.
target_logits_ref = logits.gather(dim=1, index=targets[:, None]).squeeze(1)