In [10]:
from configuration import Config
from modeling import CasualLM
import torch
from torch.utils.data import TensorDataset, DataLoader

### 1 定义模型

In [2]:
config = Config(
    vocab_size=10,
    num_hiddens=8,
    num_layers=6,
    num_heads=4,
    num_mlp_intermediate=16,
    dropout=0.1
)

model = CasualLM(config)

### 2 测试推理

In [3]:
x = torch.arange(0,10).repeat(2, 2)
print(x)
y = model(x)
print(y)

tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
tensor([[[ 0.3573,  0.3219, -0.9565, -0.7045,  0.9719, -0.5493,  0.5671,
          -0.0624,  0.1499, -0.2596],
         [-0.1416, -1.0830, -0.6425, -0.0328,  0.3968, -0.4209,  0.6192,
          -0.4647, -0.3847, -0.7722],
         [ 0.8495, -1.0861, -0.3949, -0.0258,  0.6572, -0.0767,  0.3021,
           0.9272,  0.2468, -0.3968],
         [ 0.5241, -0.5333, -0.9193, -0.8196,  0.8877, -0.9154,  0.5109,
          -0.0172, -0.2684, -0.4070],
         [-0.7473,  0.1931,  0.0084, -0.4060, -0.5590, -0.5153, -0.3671,
          -1.8173, -1.1206,  0.2517],
         [ 0.3043, -0.7791, -1.0441, -0.3981,  0.5118, -1.1790,  1.1013,
          -0.6823, -0.0833, -0.7182],
         [ 0.4172,  0.1654, -0.3159,  0.2294,  0.8430, -0.1531,  1.2662,
          -0.1286,  0.1848, -1.0226],
         [ 0.2108, -0.3202, -0.2809,  0.0977,  0.3838, -0.8936,  1.3252,
         

### 3 打印参数量

In [4]:
print(sum([param.nelement() for param in model.parameters()]))

3594


### 4 测试训练

#### 4.1 定义超参数

In [35]:
num_epochs = 100
batch_size = 4
trainer = torch.optim.SGD(model.parameters(), lr=1e-3)

#### 4.2 构造训练数据

In [36]:
features = torch.randint(0, 10, (100, 15))
labels = torch.randint(0, 10, (100, 15))
labels[:, :10] = -100
dataset = TensorDataset(features, labels)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

#### 4.3 定义损失

In [37]:
lossfn = torch.nn.CrossEntropyLoss()
lossfn(model(features).transpose(1, 2), labels)

tensor(2.4492, grad_fn=<NllLoss2DBackward0>)

#### 4.4 开始训练

In [39]:
for i in range(num_epochs):
    for feature, label in dataloader:
        loss = lossfn(model(feature).transpose(1, 2), label)
        trainer.zero_grad()
        loss.backward()
        trainer.step()
    with torch.no_grad():
        loss = lossfn(model(features).transpose(1, 2), labels)
        print(f"第{i}个epoch：loss大小为{loss}")

第0个epoch：loss大小为2.3111133575439453
第1个epoch：loss大小为2.3021719455718994
第2个epoch：loss大小为2.3124024868011475
第3个epoch：loss大小为2.3148953914642334
第4个epoch：loss大小为2.307684898376465
第5个epoch：loss大小为2.3119096755981445
第6个epoch：loss大小为2.312882423400879
第7个epoch：loss大小为2.314678430557251
第8个epoch：loss大小为2.31310772895813
第9个epoch：loss大小为2.3101611137390137
第10个epoch：loss大小为2.3020505905151367
第11个epoch：loss大小为2.311980962753296
第12个epoch：loss大小为2.2991394996643066
第13个epoch：loss大小为2.3056933879852295
第14个epoch：loss大小为2.30820631980896
第15个epoch：loss大小为2.306088447570801
第16个epoch：loss大小为2.3056869506835938
第17个epoch：loss大小为2.3023319244384766
第18个epoch：loss大小为2.3036677837371826
第19个epoch：loss大小为2.3019161224365234
第20个epoch：loss大小为2.2932984828948975
第21个epoch：loss大小为2.3017568588256836
第22个epoch：loss大小为2.303422212600708
第23个epoch：loss大小为2.302946090698242
第24个epoch：loss大小为2.3010942935943604
第25个epoch：loss大小为2.302445411682129
第26个epoch：loss大小为2.3081696033477783
第27个epoch：loss大小为2.3089253902435303
第28个epoch：loss