In [1]:
import os
import sys
import pickle
from contextlib import nullcontext
import torch
import tiktoken 
import configparser

In [2]:
def read_config_file(file_path):
    # 创建 ConfigParser 对象
    config = configparser.ConfigParser()
    # 读取配置文件
    config.read(file_path)
    return config

config_file_path = 'hyper-parameters.ini'
config = read_config_file(config_file_path)

# Hyperparameters
batch_size = config.getint('Hyperparameters', 'batch_size')
context_length = config.getint('Hyperparameters', 'context_length')
max_iters = config.getint('Hyperparameters', 'max_iters')
learning_rate = config.getfloat('Hyperparameters', 'learning_rate')
eval_interval = config.getint('Hyperparameters', 'eval_interval')
eval_iters = config.getint('Hyperparameters', 'eval_iters')

device = 'cuda' if torch.cuda.is_available() else 'cpu'
TORCH_SEED = 1337
torch.manual_seed(TORCH_SEED)


<torch._C.Generator at 0x7f2cf8b4e690>

In [3]:
# 准备训练数据
with open('data/scifi.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [4]:
vocab = sorted(list(set(text)))
vocab_size = max_token_value = len(vocab)
char2idx = {c:i for i,c in enumerate(vocab)}
idx2char = {i:c for i,c in enumerate(vocab)}
encode = lambda x: [char2idx[c] for c in x]
decode = lambda idxs: ''.join([idx2char[i] for i in idxs])
tokenized_text = torch.tensor(encode(text), dtype=torch.long)

In [7]:
# Split train and validation set
train_size = int(0.8 * len(tokenized_text))
train_data = tokenized_text[:train_size]
val_data = tokenized_text[train_size:]

In [8]:
# 初始化模型
from model import Model
model = Model().to(device)

[Hyperparameters]
context_length = 32
d_model = 64
num_blocks = 12
num_heads = 8
dropout = 0.1
batch_size = 6
learning_rate = 0.001
max_iters = 5000
eval_interval = 50
eval_iters = 20


In [9]:
# Get input embedding batch
def get_batch(split: str):
    data = train_data if split == 'train' else val_data
    idxs = torch.randint(low=0, high=len(data) - context_length, size=(batch_size,))
    x = torch.stack([data[idx:idx + context_length] for idx in idxs]).to(device)
    y = torch.stack([data[idx + 1:idx + context_length + 1] for idx in idxs]).to(device)
    return x, y


In [10]:
x_batch, y_batch = get_batch('train')
print(x_batch.shape, y_batch.shape )
print(context_length)

torch.Size([6, 32]) torch.Size([6, 32])
32


In [13]:
# Calculate loss
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'valid']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x_batch, y_batch = get_batch(split)
            logits, loss = model(x_batch, y_batch)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [14]:
# Use AdamW optimizer
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)
tracked_losses = list()
for step in range(max_iters):
    if step % eval_iters == 0 or step == max_iters - 1:
        losses = estimate_loss()
        tracked_losses.append(losses)
        print('Step:', step, 'Training Loss:', round(losses['train'].item(), 3), 'Validation Loss:',
              round(losses['valid'].item(), 3))

    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

Step: 0 Training Loss: 7.402 Validation Loss: 7.458
Step: 20 Training Loss: 6.904 Validation Loss: 7.02
Step: 40 Training Loss: 6.933 Validation Loss: 6.832
Step: 60 Training Loss: 6.536 Validation Loss: 6.719
Step: 80 Training Loss: 6.672 Validation Loss: 6.571
Step: 100 Training Loss: 6.576 Validation Loss: 6.512
Step: 120 Training Loss: 6.285 Validation Loss: 6.387
Step: 140 Training Loss: 6.321 Validation Loss: 6.431
Step: 160 Training Loss: 6.348 Validation Loss: 6.436
Step: 180 Training Loss: 6.349 Validation Loss: 6.253
Step: 200 Training Loss: 6.169 Validation Loss: 6.32
Step: 220 Training Loss: 6.196 Validation Loss: 6.29
Step: 240 Training Loss: 6.237 Validation Loss: 6.198
Step: 260 Training Loss: 6.215 Validation Loss: 6.264
Step: 280 Training Loss: 6.142 Validation Loss: 6.219
Step: 300 Training Loss: 6.042 Validation Loss: 6.13
Step: 320 Training Loss: 6.133 Validation Loss: 6.135
Step: 340 Training Loss: 6.083 Validation Loss: 6.114
Step: 360 Training Loss: 6.014 Validat

In [49]:
# Save the model state dictionary
torch.save(model.state_dict(), 'model/model-ckpt.pt')