# 手撕 Decoder-Only Loss 实现

In [2]:
import torch
import math

## 数据和标签

In [3]:
# Toy batch for the walkthrough: random "hidden states" as model input and
# random token ids as the labels the model is supposed to predict.
batch_size = 1      # number of sequences in the batch
length = 4          # tokens per sequence
embd_dim = 512      # width of each token embedding
vocab_size = 32000  # number of distinct token ids a label can take

# Build both tensors from the named sizes above so that changing `length`
# (or any other size) actually changes the data.
x = torch.randn(batch_size, length, embd_dim)  # input: (batch_size, length, embd_dim)
y = torch.randint(low=0, high=vocab_size, size=(batch_size, length), dtype=torch.long)  # labels: (batch_size, length)

print(x.shape)
print(y.shape)
print(y)

## Attention

In [4]:
# Single-head causal self-attention written out by hand.
# Sizes are read from `x` (built earlier as (batch_size, length, embd_dim))
# so the mask and the scaling factor always match the actual input.
n_tokens, d_model = x.shape[1], x.shape[2]

# Random projection weights: query, key, value and output.
q = torch.randn(d_model, d_model)
k = torch.randn(d_model, d_model)
v = torch.randn(d_model, d_model)
o = torch.randn(d_model, d_model)

# Lower-triangular mask: row i has ones in columns 0..i, zeros afterwards,
# so position i can only attend to itself and earlier positions.
mask = torch.tril(torch.ones(1, n_tokens, n_tokens))
print(mask)

# scaled dot-product attention
Q, K, V = x @ q, x @ k, x @ v
scores = Q @ K.transpose(1, 2) / math.sqrt(d_model)
# Masked (future) positions get -inf so softmax assigns them zero weight.
scores = scores.masked_fill(mask == 0, float('-inf'))
weight = torch.nn.functional.softmax(scores, dim=2)
attn = weight @ V
attn = attn @ o  # output projection
attn.shape

torch.Size([1, 4, 512])

## MLP

In [5]:
# Feed-forward block reduced to its two linear projections:
# widen 512 -> 1024, then narrow 1024 -> 512. No activation is applied
# between the two matmuls in this walkthrough.
mlp_up = torch.randn(512, 1024)    # up-projection weight
mlp_down = torch.randn(1024, 512)  # down-projection weight
mlp = torch.matmul(torch.matmul(attn, mlp_up), mlp_down)
mlp.shape

torch.Size([1, 4, 512])

## Output

In [6]:
# Language-model head: project every position's 512-wide hidden state
# onto 32000 vocabulary scores (one logit per token id).
lm_head = torch.randn(512, 32000)
logits = torch.matmul(mlp, lm_head)
logits.shape

torch.Size([1, 4, 32000])

## Loss

In [7]:
# probs
# Softmax over the vocabulary axis; shown for inspection only, since the loss
# below consumes the raw logits rather than these probabilities.
probs = torch.softmax(logits, dim=2)
print(probs.shape)  # model output probabilities
print(y)            # model labels
print(y.shape)

# Loss
# nn.CrossEntropyLoss expects the class axis second:
#   input  : (batch, vocab, length)  -> hence logits.transpose(1, 2)
#   target : (batch, length)
# Each position is scored independently against its own label:
#   position 0: vocab-sized logits vs. label 0
#   position 1: vocab-sized logits vs. label 1
#   ...
loss_fn = torch.nn.CrossEntropyLoss()
loss = loss_fn(logits.transpose(1, 2), y)
# With the default reduction ('mean') the result is already a scalar averaged
# over all positions, so no further .mean() is needed.
print(loss)

# pred
# Greedy prediction: the highest-scoring token id at every position.
pred = torch.argmax(logits, dim=2)
print(pred)  # model pred

## inference NEXT TOKEN

In [8]:
print(logits.shape)
pred = logits.argmax(dim=2)
print(pred)
# One token is predicted per input position, and each position corresponds
# to a different prediction task (prefix -> following token).

# The predictions for the earlier positions are not needed at inference time:
#   "我"       -> "很"
#   "我很"     -> "开"
#   "我很开"   -> "车"

# next token
#   "我很开心" -> "呀"

# Only the prediction at the last position is the actual next-token prediction.
print(pred[0][-1])

## Other

In [9]:
# about crossentropy loss input
# https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
# logits(Batchsize, Classnumber, length) as input
# label(Batchsize, length)               as target
# this code reference : https://stackoverflow.com/questions/73696319/what-is-cross-entropy-loss-really-doing-when-input-is-3d


import torch
import torch.nn as nn

# Check that CrossEntropyLoss gives the same loss and the same gradient
# whether it receives a 3-D (batch, class, length) input or the equivalent
# flattened 2-D (batch * length, class) input.
loss = nn.CrossEntropyLoss()

batch_size, seq_len, vocab_size = 4, 8, 3

inputs = torch.randn(batch_size, vocab_size, seq_len, requires_grad=True)
target = torch.randint(0, vocab_size, (batch_size, seq_len))

print(inputs.shape)
print(target.shape)

# Path 1: pass the 3-D input directly (class axis is dim 1).
loss1 = loss(inputs, target)
(grad1,) = torch.autograd.grad(loss1, inputs)

# Path 2: move the class axis last, then merge batch and length into one axis.
inputs_transposed = inputs.transpose(1, 2).reshape(-1, vocab_size)
target_transposed = target.flatten()

loss2 = loss(inputs_transposed, target_transposed)
(grad2,) = torch.autograd.grad(loss2, inputs)

print(torch.allclose(loss1, loss2))
print(torch.allclose(grad1, grad2))