In [1]:
from configuration import Config
from modeling import CasualLM
import torch
from torch.utils.data import TensorDataset, DataLoader

### 1 定义模型

In [2]:
# Seed PyTorch's global RNG before the first random operation so that a fresh
# "Restart Kernel -> Run All" reproduces the same run. Everything below that
# draws from torch's global RNG (torch.randint data, DataLoader shuffling,
# and presumably the model's weight init and dropout) becomes deterministic.
torch.manual_seed(0)

# Deliberately tiny configuration: this notebook is a smoke test of the model
# code, not a real training run. The exact meaning of each field is defined
# by configuration.Config (not shown here).
config = Config(
    vocab_size=10,             # token ids used below are drawn from 0..9
    num_hiddens=8,
    num_layers=6,
    num_heads=4,
    num_mlp_intermediate=16,
    max_context_length=1024,
    dropout=0.1
)

# Build the language model from the configuration above.
model = CasualLM(config)

### 2 测试推理

In [3]:
# Two identical rows, each the token ids 0..9 repeated twice -> shape (2, 20).
x = torch.arange(0, 10).repeat(2, 2)
print(x)

# This is an inference check, so run it in eval mode (the config sets
# dropout=0.1, which would otherwise randomly perturb the output) and without
# recording an autograd graph.
# Assumes CasualLM is a torch.nn.Module (it exposes .parameters() below).
was_training = model.training
model.eval()
with torch.no_grad():
    y = model(x)
# Restore the previous mode so the training cells further down are unaffected.
model.train(was_training)
print(y)

tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
tensor([[[-2.6927e-01, -2.1824e-01,  3.5512e-01,  7.7438e-01, -2.3734e-01,
          -6.0176e-01,  4.5596e-01, -1.3173e+00, -8.4155e-02,  2.5243e-01],
         [ 1.9288e-01,  5.9345e-01,  1.6909e-02,  7.7641e-01, -4.9430e-01,
           1.2356e-01, -1.1156e-02, -7.9102e-01,  6.6643e-01, -3.9936e-02],
         [-5.6503e-01,  3.1372e-01, -7.3653e-02,  7.5184e-01, -6.1256e-01,
          -4.4261e-01,  6.0736e-01, -1.5883e-01,  4.5473e-01, -4.4825e-01],
         [ 4.1224e-02,  2.0152e-01,  5.8485e-01,  5.9259e-01, -6.1998e-01,
          -8.6237e-01,  1.2582e+00, -2.9968e-01, -3.5849e-01,  6.1480e-01],
         [-9.4048e-01, -6.3689e-01,  1.7298e-01,  3.1853e-01,  2.5150e-01,
          -5.7013e-01,  9.2062e-02, -1.5063e+00, -3.4878e-01,  3.9488e-01],
         [-4.5945e-01, -1.9124e-01,  5.7957e-01,  7.6237e-01, -1.7704e-01,
          -7.8193e-01,  7.985

### 3 打印参数量

In [4]:
# Total number of scalar entries across all model parameters.
print(sum(p.numel() for p in model.parameters()))

3594


### 4 测试训练

#### 4.1 定义超参数

In [5]:
# Training hyperparameters.
batch_size = 4      # samples per mini-batch handed to the DataLoader
num_epochs = 100    # number of full passes over the dataset

# Plain stochastic gradient descent over every model parameter.
trainer = torch.optim.SGD(params=model.parameters(), lr=0.001)

#### 4.2 构造训练数据

In [6]:
# Synthetic data: 100 sequences of 15 random token ids, each drawn from [0, 10).
features = torch.randint(low=0, high=10, size=(100, 15))
labels = torch.randint(low=0, high=10, size=(100, 15))

# Overwrite the first 10 label positions of every sequence with -100, the
# default ignore_index of torch.nn.CrossEntropyLoss, so only the last 5
# positions contribute to the loss.
labels[:, :10].fill_(-100)

# Pair inputs with targets and serve them in shuffled mini-batches.
dataset = TensorDataset(features, labels)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

#### 4.3 定义损失

In [7]:
# Cross-entropy with default settings (mean reduction, ignore_index=-100).
lossfn = torch.nn.CrossEntropyLoss()

# Sanity check: loss of the untrained model on the full dataset.
# CrossEntropyLoss expects the class dimension second, so the last two axes of
# the model output (presumably (batch, seq, vocab)) are swapped first.
lossfn(model(features).permute(0, 2, 1), labels)

tensor(2.4384, grad_fn=<NllLoss2DBackward0>)

#### 4.4 开始训练

In [8]:
for i in range(num_epochs):
    # --- optimisation pass: training mode, so dropout is active ---
    model.train()
    for feature, label in dataloader:
        # Swap the last two axes so the class dimension comes second, as
        # CrossEntropyLoss requires.
        loss = lossfn(model(feature).transpose(1, 2), label)
        trainer.zero_grad()   # clear gradients left over from the previous step
        loss.backward()
        trainer.step()

    # --- monitoring pass: loss over the whole dataset ---
    # Switch to eval mode so dropout does not inject noise into the reported
    # number; no_grad alone only disables autograd, it does not turn dropout off.
    model.eval()
    with torch.no_grad():
        loss = lossfn(model(features).transpose(1, 2), labels)
        print(f"第{i}个epoch：loss大小为{loss}")

# Return to training mode, the state a freshly built nn.Module starts in.
model.train()

第0个epoch：loss大小为2.4214675426483154
第1个epoch：loss大小为2.4215474128723145
第2个epoch：loss大小为2.4265780448913574
第3个epoch：loss大小为2.4157650470733643
第4个epoch：loss大小为2.4138569831848145
第5个epoch：loss大小为2.415598154067993
第6个epoch：loss大小为2.4056010246276855
第7个epoch：loss大小为2.4035487174987793
第8个epoch：loss大小为2.410526990890503
第9个epoch：loss大小为2.3908309936523438
第10个epoch：loss大小为2.3870975971221924
第11个epoch：loss大小为2.3977158069610596
第12个epoch：loss大小为2.390331506729126
第13个epoch：loss大小为2.378121852874756
第14个epoch：loss大小为2.3779795169830322
第15个epoch：loss大小为2.3768489360809326
第16个epoch：loss大小为2.3670997619628906
第17个epoch：loss大小为2.3767571449279785
第18个epoch：loss大小为2.3610267639160156
第19个epoch：loss大小为2.370131015777588
第20个epoch：loss大小为2.364672899246216
第21个epoch：loss大小为2.363600730895996
第22个epoch：loss大小为2.356736898422241
第23个epoch：loss大小为2.3586337566375732
第24个epoch：loss大小为2.360290765762329
第25个epoch：loss大小为2.3628933429718018
第26个epoch：loss大小为2.356078863143921
第27个epoch：loss大小为2.3739938735961914
第28个epoch：lo