In [1]:
from configuration import Config
from modeling import CasualLM
import torch
from torch.utils.data import TensorDataset, DataLoader

### 1 定义模型

In [2]:
config = Config(
    vocab_size=10,
    num_hiddens=8,
    num_layers=6,
    num_heads=4,
    num_mlp_intermediate=16,
    max_context_length=1024,
    dropout=0.1
)

model = CasualLM(config)

### 2 测试推理

In [3]:
x = torch.arange(0,10).repeat(2, 2)
print(x)
y = model(x)
print(y.logits)

tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
tensor([[[ 2.0295e-01,  7.7202e-01, -5.3938e-01,  1.1913e+00,  3.0952e-01,
          -5.8248e-01,  5.2985e-01, -1.6016e-01,  7.8455e-01,  7.9056e-01],
         [ 5.6074e-01,  9.1363e-01, -2.3251e-01,  1.5162e+00, -4.0787e-01,
          -5.9454e-01,  7.7063e-01,  1.4086e-02,  8.8369e-01,  6.6595e-01],
         [ 1.1240e+00,  1.0077e+00, -2.2039e-01,  1.0828e+00,  5.9379e-01,
          -8.8804e-02,  3.9066e-01, -5.5947e-01,  3.4589e-02, -1.1040e-01],
         [ 1.9495e-01,  7.1471e-01,  2.7624e-01,  9.9195e-01,  1.8915e-01,
          -3.5822e-02, -6.0732e-01, -5.2180e-01,  4.8105e-02, -2.7012e-01],
         [-9.7744e-01, -5.4651e-01,  2.0934e-01,  4.0143e-02,  4.9317e-01,
          -9.1393e-03, -2.1514e-01,  9.8829e-01,  7.3278e-01,  3.7638e-01],
         [ 3.5402e-01, -9.1471e-02, -2.1180e-01,  1.6680e+00, -1.6327e-01,
          -8.9716e-01,  7.553

### 3 打印参数量

In [4]:
print(sum([param.nelement() for param in model.parameters()]))

3594


### 4 测试训练

#### 4.1 定义超参数

In [5]:
num_epochs = 100
batch_size = 4
trainer = torch.optim.SGD(model.parameters(), lr=1e-3)

#### 4.2 构造训练数据

In [6]:
features = torch.randint(0, 10, (100, 15))
labels = torch.randint(0, 10, (100, 15))
labels[:, :10] = -100
dataset = TensorDataset(features, labels)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

#### 4.3 定义损失

In [7]:
lossfn = torch.nn.CrossEntropyLoss()
lossfn(model(features).logits.transpose(1, 2), labels)

tensor(2.5178, grad_fn=<NllLoss2DBackward0>)

#### 4.4 开始训练

In [8]:
for i in range(num_epochs):
    for feature, label in dataloader:
        loss = lossfn(model(feature).logits.transpose(1, 2), label)
        trainer.zero_grad()
        loss.backward()
        trainer.step()
    with torch.no_grad():
        loss = lossfn(model(features).logits.transpose(1, 2), labels)
        print(f"第{i}个epoch：loss大小为{loss}")

第0个epoch：loss大小为2.4987058639526367
第1个epoch：loss大小为2.4887895584106445
第2个epoch：loss大小为2.4834766387939453
第3个epoch：loss大小为2.4729599952697754
第4个epoch：loss大小为2.4457595348358154
第5个epoch：loss大小为2.4603073596954346
第6个epoch：loss大小为2.43985652923584
第7个epoch：loss大小为2.4296345710754395
第8个epoch：loss大小为2.4229400157928467
第9个epoch：loss大小为2.4184956550598145
第10个epoch：loss大小为2.4249837398529053
第11个epoch：loss大小为2.4010183811187744
第12个epoch：loss大小为2.419232130050659
第13个epoch：loss大小为2.404738187789917
第14个epoch：loss大小为2.3988239765167236
第15个epoch：loss大小为2.382740020751953
第16个epoch：loss大小为2.3913955688476562
第17个epoch：loss大小为2.393268346786499
第18个epoch：loss大小为2.3654017448425293
第19个epoch：loss大小为2.3915183544158936
第20个epoch：loss大小为2.3775522708892822
第21个epoch：loss大小为2.3735265731811523
第22个epoch：loss大小为2.3795180320739746
第23个epoch：loss大小为2.376084566116333
第24个epoch：loss大小为2.365304946899414
第25个epoch：loss大小为2.3563361167907715
第26个epoch：loss大小为2.3596041202545166
第27个epoch：loss大小为2.3677685260772705
第28个epoch：