In [488]:

from triton import heuristics, jit
from triton import language as tl
from triton import next_power_of_2

import triton


def num_warps(N):
    if N < 2048:
        return 4
    elif N < 8192:
        return 8
    return 16



def num_warps(N):
    if N < 2048:
        return 4
    elif N < 8192:
        return 8
    return 16

def num_warps(N):
    if N < 2048:
        return 4
    elif N < 8192:
        return 8
    return 16


@heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
@heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
@jit
def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    idx = tl.load(IDX + row)
    ignore_index = -100
    # pointers to logit and probs
    LOGITS = LOGITS + row * N + cols
    WRIT_PROBS = PROBS + row * N + cols
    READ_PROBS = PROBS + row * N + idx
    # write-back negative log-probs
    logits = tl.load(LOGITS, mask=cols < N, other=-float('inf'))
    logits = logits.to(tl.float32)
    logits = logits - tl.max(logits, 0)
    probs_left = tl.log(tl.sum(tl.exp(logits), 0))
    probs = probs_left - logits

    probs_loss = probs_left - tl.sum(tl.where(cols == idx, logits, 0.0))
    probs_loss = tl.where(idx == ignore_index, 0.0, probs_loss)
    # tl.store(WRIT_PROBS, probs, mask=cols < N)

    # There is a bug in the compiler, which fails to insert a barrier here.
    # We add it explicitly for now. Will be fixed soon.
    # tl.debug_barrier()
    # write-back loss
    # probs_loss = tl.load(READ_PROBS)
    # probs_loss = tl.where(idx == ignore_index, 0.0, probs_loss)
    tl.store(LOSS + row, probs_loss)

    tl.debug_barrier()
    probs = -probs
    probs = tl.exp(probs.to(tl.float32))
    delta = cols == idx
    din = (probs - delta)
    din = tl.where(idx == ignore_index, 0.0, din)
    tl.store(WRIT_PROBS, din, mask=cols < N)

class _cross_entropy(torch.autograd.Function):

    @classmethod
    def forward(cls, ctx, hidden_states, indices, weights):
        logits = torch.matmul(hidden_states, weights.T)
        logits = logits.float()
        # make sure we can use triton
        assert (indices.dtype == torch.int64), "Indices are expected to be of type long."
        # make kernel
        device, dtype = logits.device, logits.dtype
        n_cols = logits.shape[-1]
        # run the kernel
        result = torch.empty_like(indices, dtype=dtype, device=device)
        neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device)
        grid = lambda opt: (logits.numel() // n_cols, )
        _forward[grid](logits, neg_logprobs, indices, result, n_cols)
        # save for backward
        neg_logprobs = neg_logprobs.to(torch.bfloat16)
        grad_input = neg_logprobs @ weights

        ignore_index = -100
        mask = (indices != ignore_index)

        if hasattr(weights, 'grad') and weights.grad != None:
            torch.addmm(
                    weights.grad,
                    neg_logprobs.T,
                    hidden_states,
                    out=weights.grad,
                )
        else:
            weights.grad = neg_logprobs.T @ hidden_states

        
        if hasattr(weights, 'mul') and weights.mul != None:
            weights.mul += torch.sum(indices != ignore_index) 
        else:
            weights.mul = torch.sum(indices != ignore_index) 
            
        weights.grad_mul = False
        
        neg_logprobs = None

        ctx.save_for_backward(grad_input, weights)
        return result[mask].mean()

    @classmethod
    def backward(cls, ctx, dneg_logprobs):
        """We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
        so we initialize the gradient as neg_logprobs, so we can just exponentiate
        to get p[k], which is most of what we need...  neg_logprobs will be
        modified in place to become the gradient we want
        """
        # load saved tensors
        neg_logprobs, weights = ctx.saved_tensors
        
        dneg_logprobs = dneg_logprobs / weights.mul
        if weights.grad_mul is False:
            weights.grad *= dneg_logprobs
            weights.grad_mul = True
        neg_logprobs *= dneg_logprobs
        
        return neg_logprobs, None, weights.grad


class FusedCrossEntropyLMhead(nn.Module):
    def __init__(
        self,
        original_weight = None
    ):
        super().__init__()
        if original_weight is None:
            self.LM_head_weight = nn.Parameter(torch.empty(hidden_size, vocab_size))
        else:
            self.LM_head_weight = original_weight
        self.cross_entropy = _cross_entropy.apply

    def forward(self, hidden_states, labels):
        ignore_index = -100
        loss = self.cross_entropy(hidden_states, labels, self.LM_head_weight)
        # mask = (labels != ignore_index)
        return loss


In [489]:
M = 8
N = 8

In [490]:
from triton.ops import cross_entropy
import numpy as np

dtype = torch.bfloat16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.randn(M, N, dtype=dtype, device=device, requires_grad=True)

random_sequence = np.random.randint(low=0, high=N, size=(M,))
idx = torch.tensor(random_sequence).cuda() 

tt_y = triton.ops.cross_entropy(x, idx)
th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx)

print(tt_y, th_y, torch.sum(tt_y - th_y))

tensor([3.3438, 2.4062, 3.0312, 4.1250, 2.7656, 3.1094, 3.0625, 4.6562],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<_cross_entropyBackward>) tensor([3.3438, 2.4062, 3.0312, 4.1250, 2.7656, 3.1094, 3.0625, 4.6562],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<NllLossBackward0>) tensor(0., device='cuda:0', dtype=torch.bfloat16, grad_fn=<SumBackward0>)


In [491]:
from triton.ops import cross_entropy

dtype = torch.bfloat16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.randn(M, N, dtype=dtype, device=device, requires_grad=True)
random_sequence = np.random.randint(low=0, high=N, size=(M,))
idx = torch.tensor(random_sequence).cuda() 
lm_head = nn.Linear(N, N, bias=False).cuda()
lm_head.bfloat16()


tt_y = triton.ops.cross_entropy(lm_head(x), idx)
th_y = torch.nn.CrossEntropyLoss(reduction="none")(lm_head(x), idx)
print(tt_y, th_y, torch.sum(tt_y - th_y))

tensor([2.5781, 2.7500, 1.8828, 2.4844, 2.7656, 1.4688, 2.9844, 2.0312],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<_cross_entropyBackward>) tensor([2.5781, 2.7500, 1.8828, 2.4844, 2.7656, 1.4688, 2.9844, 2.0312],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<NllLossBackward0>) tensor(0., device='cuda:0', dtype=torch.bfloat16, grad_fn=<SumBackward0>)


In [492]:
from triton.ops import cross_entropy

dtype = torch.bfloat16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.randn(M, N, dtype=dtype, device=device, requires_grad=True)
x_test = x.detach().clone()
x_test.requires_grad=True
random_sequence = np.random.randint(low=0, high=N, size=(M,))
idx = torch.tensor(random_sequence).cuda() 
idx[-2:] = -100
print(idx)
lm_head = nn.Linear(N, N, bias=False).cuda()
lm_head.bfloat16()

fuse = FusedCrossEntropyLMhead(lm_head.weight.detach().clone())

th_y = torch.nn.CrossEntropyLoss(reduction="mean")(lm_head(x[:-1, :]).float(), idx[1:])
tt_y = fuse(x_test[:-1, :], idx[1:])

th_y.backward()
tt_y.backward()

print(lm_head.weight.grad, fuse.LM_head_weight.grad)
print(x.grad , x_test.grad)
print(torch.sum(lm_head.weight.grad - fuse.LM_head_weight.grad))
print(torch.sum(x.grad - x_test.grad))

print(tt_y, th_y, torch.sum(tt_y - th_y))

tensor([   4,    7,    4,    7,    0,    0, -100, -100], device='cuda:0')


TypeError: unsupported operand type(s) for +=: 'builtin_function_or_method' and 'Tensor'

In [527]:

from triton import heuristics, jit
from triton import language as tl
from triton import next_power_of_2

import triton


def num_warps(N):
    if N < 2048:
        return 4
    elif N < 8192:
        return 8
    return 16



def num_warps(N):
    if N < 2048:
        return 4
    elif N < 8192:
        return 8
    return 16

def num_warps(N):
    if N < 2048:
        return 4
    elif N < 8192:
        return 8
    return 16


@heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
@heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
@jit
def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    idx = tl.load(IDX + row)
    ignore_index = -100
    # pointers to logit and probs
    LOGITS = LOGITS + row * N + cols
    WRIT_PROBS = PROBS + row * N + cols
    READ_PROBS = PROBS + row * N + idx
    # write-back negative log-probs
    logits = tl.load(LOGITS, mask=cols < N, other=-float('inf'))
    logits = logits.to(tl.float32)
    logits = logits - tl.max(logits, 0)
    probs_left = tl.log(tl.sum(tl.exp(logits), 0))
    probs = probs_left - logits

    probs_loss = probs_left - tl.sum(tl.where(cols == idx, logits, 0.0))
    probs_loss = tl.where(idx == ignore_index, 0.0, probs_loss)
    # tl.store(WRIT_PROBS, probs, mask=cols < N)

    # There is a bug in the compiler, which fails to insert a barrier here.
    # We add it explicitly for now. Will be fixed soon.
    # tl.debug_barrier()
    # write-back loss
    # probs_loss = tl.load(READ_PROBS)
    # probs_loss = tl.where(idx == ignore_index, 0.0, probs_loss)
    tl.store(LOSS + row, probs_loss)

    tl.debug_barrier()
    probs = -probs
    probs = tl.exp(probs.to(tl.float32))
    delta = cols == idx
    din = (probs - delta)
    din = tl.where(idx == ignore_index, 0.0, din)
    tl.store(WRIT_PROBS, din, mask=cols < N)

class _cross_entropy(torch.autograd.Function):

    @classmethod
    def forward(cls, ctx, hidden_states, indices, weights):
        logits = torch.matmul(hidden_states, weights.T)
        logits = logits.float()
        # make sure we can use triton
        assert (indices.dtype == torch.int64), "Indices are expected to be of type long."
        # make kernel
        device, dtype = logits.device, logits.dtype
        n_cols = logits.shape[-1]
        # run the kernel
        result = torch.empty_like(indices, dtype=dtype, device=device)
        neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device)
        grid = lambda opt: (logits.numel() // n_cols, )
        _forward[grid](logits, neg_logprobs, indices, result, n_cols)
        # save for backward
        neg_logprobs = neg_logprobs.to(torch.bfloat16)
        grad_input = neg_logprobs @ weights

        ignore_index = -100
        mask = (indices != ignore_index)

        if hasattr(weights, 'grad') and weights.grad != None:
            torch.addmm(
                    weights.grad,
                    neg_logprobs.T,
                    hidden_states,
                    out=weights.grad,
                )
        else:
            weights.grad = neg_logprobs.T @ hidden_states
        weights.grad_mul = False
        
        neg_logprobs = None

        ctx.save_for_backward(grad_input, weights)
        return result[mask].sum()

    @classmethod
    def backward(cls, ctx, dneg_logprobs):
        """We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
        so we initialize the gradient as neg_logprobs, so we can just exponentiate
        to get p[k], which is most of what we need...  neg_logprobs will be
        modified in place to become the gradient we want
        """
        # load saved tensors
        neg_logprobs, weights = ctx.saved_tensors
        # dneg_logprobs = dneg_logprobs / weights.mul
        if weights.grad_mul is False:
            weights.grad *= dneg_logprobs
            weights.grad_mul = True
        neg_logprobs *= dneg_logprobs
        
        return neg_logprobs, None, weights.grad


class FusedCrossEntropyLMhead(nn.Module):
    def __init__(
        self,
        original_weight = None
    ):
        super().__init__()
        if original_weight is None:
            self.LM_head_weight = nn.Parameter(torch.empty(hidden_size, vocab_size))
        else:
            self.LM_head_weight = original_weight
        self.cross_entropy = _cross_entropy.apply

    def forward(self, hidden_states, labels):
        ignore_index = -100
        loss = self.cross_entropy(hidden_states, labels, self.LM_head_weight)
        return loss


In [531]:
from triton.ops import cross_entropy

M = 1024
N = 1024

dtype = torch.bfloat16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.randn(M, N, dtype=dtype, device=device, requires_grad=True)
x_test = x.detach().clone()
x_test.requires_grad=True
random_sequence = np.random.randint(low=0, high=N, size=(M,))
idx = torch.tensor(random_sequence).cuda() 
idx[-5:] = -100
print(idx)
lm_head = nn.Linear(N, N, bias=False).cuda()
lm_head.bfloat16()

b_lm = lm_head.weight.detach().clone()

th_y = torch.nn.CrossEntropyLoss(reduction="mean")(lm_head(x[:-1, :]).float(), idx[1:])
# tt_y = fuse(x_test[:-1, :], idx[1:])

pretraining_tp = 8
tmp = M // pretraining_tp
hidden_states = x_test[:-1, :]
labels = idx[1:]

loss = None
b_lm.mul = None
for i in range(pretraining_tp):

    Fused = FusedCrossEntropyLMhead(b_lm)

    shift_hidden_states = hidden_states[i * tmp : (i+1)*tmp, :].contiguous()
    shift_labels = labels[i * tmp : (i+1)*tmp ].contiguous()

    loss_i = Fused(shift_hidden_states, shift_labels)

    print(loss_i)

    if loss is None:
        loss = loss_i
    else:
        loss = loss + loss_i
loss = loss / torch.sum(labels != -100)
th_y.backward()
loss.backward()

print(lm_head.weight.grad, b_lm.grad)
print(x.grad , x_test.grad)
print(torch.sum(lm_head.weight.grad - b_lm.grad))
print(torch.sum(x.grad - x_test.grad))

print(loss, th_y, torch.sum(loss - th_y))

tensor([ 263,  804,  414,  ..., -100, -100, -100], device='cuda:0')
tensor(920.7166, device='cuda:0', grad_fn=<_cross_entropyBackward>)
tensor(895.4290, device='cuda:0', grad_fn=<_cross_entropyBackward>)
tensor(903.4427, device='cuda:0', grad_fn=<_cross_entropyBackward>)
tensor(911.4877, device='cuda:0', grad_fn=<_cross_entropyBackward>)
tensor(906.2062, device='cuda:0', grad_fn=<_cross_entropyBackward>)
tensor(910.1755, device='cuda:0', grad_fn=<_cross_entropyBackward>)
tensor(918.3792, device='cuda:0', grad_fn=<_cross_entropyBackward>)
tensor(867.6902, device='cuda:0', grad_fn=<_cross_entropyBackward>)
tensor([[ 2.3346e-03,  3.3188e-04,  3.4180e-03,  ..., -1.8921e-03,
          1.0376e-03,  1.6708e-03],
        [-1.1086e-05, -3.9041e-06,  5.3346e-06,  ...,  1.7062e-06,
         -2.1219e-05, -5.0545e-05],
        [-6.0320e-05,  2.0027e-05,  2.8759e-06,  ..., -3.9339e-05,
          1.8001e-05, -3.5763e-05],
        ...,
        [-4.0054e-05,  4.2677e-05,  6.5565e-06,  ..., -7.8201e-05,