In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Row i is [0, 1, 2, 3] scaled by 2**i: a (1, 4) row of values broadcast
# against a (4, 1) column of powers of two, multiplied into a float matrix.
a = torch.ones(4, 4) * (torch.arange(4).view(1, 4) * 2 ** torch.arange(4).view(4, 1))
a

tensor([[ 0.,  1.,  2.,  3.],
        [ 0.,  2.,  4.,  6.],
        [ 0.,  4.,  8., 12.],
        [ 0.,  8., 16., 24.]])

In [3]:
# Softmax over the last axis normalises each row of `a`;
# summing along that same axis should therefore give 1 for every row.
b = a.softmax(dim=-1)
b.sum(dim=-1)

tensor([1.0000, 1.0000, 1.0000, 1.0000])

In [4]:
class Net(nn.Module):
    """Embedding -> two-layer MLP -> vocabulary logits, with tied weights.

    Args:
        dims: Embedding width; the hidden layer is four times this size.
        vocab: Number of token ids, used for both the embedding table and
            the output projection.
    """

    def __init__(self, dims, vocab):
        super().__init__()
        hidden = 4 * dims
        self.embedding = nn.Embedding(vocab, dims)
        self.l1 = nn.Linear(dims, hidden)
        # The attribute is called `relu`, but the activation it holds is GELU.
        self.relu = nn.GELU()
        self.l2 = nn.Linear(hidden, dims)
        self.lm_head = nn.Linear(dims, vocab, bias=False)

        # Weight tying: rebind the embedding table to the output projection's
        # Parameter so both modules use one shared (vocab, dims) matrix.
        self.embedding.weight = self.lm_head.weight

    def forward(self, idx):
        """Map token ids `idx` to logits over the vocabulary."""
        hidden = self.relu(self.l1(self.embedding(idx)))
        return self.lm_head(self.l2(hidden))

In [5]:
# Small model: 16-dim embeddings over a 50-token vocabulary.
net = Net(dims=16, vocab=50)
# Five input ids (0..4) and five random target ids drawn from [0, 50).
idx = torch.arange(5, dtype=torch.long)
y = torch.randint(low=0, high=50, size=(5,))
print(*(t.shape for t in (idx, y)))
# parameters() yields the tied weight once, so the optimizer tracks it a single time.
optim = torch.optim.AdamW(net.parameters())

torch.Size([5]) torch.Size([5])


In [6]:
# Forward pass on the input ids `idx`, scored against the targets `y`.
# Passing `y` as the input would feed the model its own labels and leave
# `idx` unused.
loss = F.cross_entropy(net(idx), y)
# Populate .grad on every parameter, including the tied embedding/lm_head weight.
loss.backward()

In [7]:
# The two modules hold the same Parameter, so both names reach the very same .grad tensor.
net.lm_head.weight.grad is net.embedding.weight.grad

True

In [8]:
# Complementary check: the gradients are identical because the weights
# themselves are one shared Parameter object (tied in Net.__init__).
net.embedding.weight is net.lm_head.weight

True

In [9]:
# Value-level confirmation: the gradient seen through either module matches numerically.
net.embedding.weight.grad.allclose(net.lm_head.weight.grad)

True