In [104]:
%matplotlib inline
import math
import torch
from torch import nn
from torch.nn import functional as F
import d2l.torch as d2l
import collections
import re

In [105]:
def read_file():
    with open('./The Echo of a Dying Star.txt','r') as f:
        lines = f.readlines()
    return [re.sub('[^A-Za-z]+',' ',line).strip().lower() for line in lines]

    
def tokenize(lines,token='word'):
    if token == 'word':
        return [line.split() for line in lines]
    elif token=='char':
        return [list(line) for line in lines]
    else:
        print('错误，未知词元类型:'+token)

#统计词元频率
def count_conpus(tokens):
    if len(tokens) == 0 or isinstance(tokens[0],list):
        #把词元列表展平成使用词元填充的一个列表
        tokens = [token for line in tokens for token in line]
    return collections.Counter(tokens)

In [106]:
# 整合所有的功能
def load_corpus_file(max_tokens=-1):
    """返回时光机器文本数据集中的词元索引和词汇表"""
    lines = read_file()
    tokens = tokenize(lines,'char')
    vocab = Vocab(tokens)
    #把所有文本行展平到一个列表
    corpus = [vocab[token] for line in tokens for token in line]
    if max_tokens > 0:
        corpus = corpus[:max_tokens]
    return corpus,vocab

In [107]:
#词汇表类
class Vocab:
    def __init__(self,tokens=None,min_freq=0,reserved_token=None):
        if tokens is None:
            tokens = []
        if reserved_token is None:
            reserved_token = []
        #按照出现的频率进行排序
        counter = count_conpus(tokens)
        self.token_freqs = sorted(counter.items(),key=lambda x: x[1],reverse=True)
        #未知的词元索引为0
        self.unk,uniq_tokens = 0,['<unk>'] + reserved_token
        uniq_tokens += [token for token,freq in self.token_freqs if freq >= min_freq and tokens not in uniq_tokens]
        self.idx_to_token,self.token_to_idx = [],dict()
        for token in uniq_tokens:
            self.idx_to_token.append(token)
            self.token_to_idx[token]=len(self.idx_to_token) - 1
    def __len__(self):
        return len(self.idx_to_token)
    def __getitem__(self,tokens):
        if not isinstance(tokens,(list,tuple)):
            return self.token_to_idx.get(tokens,self.unk)
        return [self.__getitem__(token) for token in tokens]
    def to_tokens(self,indices):
        if not isinstance(indices,(list,tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]

In [108]:
import random
import torch

#随机采样
def seq_data_iter_random(corpus,batch_size,num_steps):
    #考虑标签，所以-1
    num_subseqs = (len(corpus) -1) // num_steps
    #序列的起始索引
    initial_indices = list(range(random.randint(0,5),num_subseqs * num_steps,num_steps))
    print(initial_indices)
    #为了随机的效果，打乱initial_indices
    random.shuffle(initial_indices)
    def data(pos):
        return corpus[pos:pos+num_steps]
    num_batches = num_subseqs // batch_size
    for i in range(0,batch_size * num_batches,batch_size):
        initial_indices_per_batch = initial_indices[i:i+batch_size]
        #取数据
        x = [data(j) for j in initial_indices_per_batch]
        y = [data(j+1) for j in initial_indices_per_batch]
        yield torch.tensor(x),torch.tensor(y)

#顺序采样
def seq_data_iter_sequential(corpus,batch_size,num_steps):
    #有效tokens长度
    index = random.randint(0,num_steps)
    num_tokens = ((len(corpus) - index -1) // batch_size) * batch_size
    xs = torch.tensor(corpus[index:index + num_tokens])
    ys = torch.tensor(corpus[index + 1:index + num_tokens + 1])
    #print(xs,ys)
    xs,ys = xs.reshape(batch_size,-1),ys.reshape(batch_size,-1)
    
    num_batches = xs.shape[1] // num_steps
    for i in range(0,num_steps * num_batches,num_steps):
        x = xs[:,i:i+num_steps]
        y = ys[:,i:i+num_steps]
        yield x,y

In [109]:
#把两个采样函数包装到类中，方便后续使用
class SeqDataLoader:
    def __init__(self,batch_size,num_steps,use_random_iter,max_tokens):
        if use_random_iter:
            self.data_iter_fn = seq_data_iter_random
        else:
            self.data_iter_fn = seq_data_iter_sequential
        self.corpus,self.vocab = load_corpus_file(max_tokens)
        self.batch_size,self.num_steps = batch_size,num_steps
    def __iter__(self):
        return self.data_iter_fn(self.corpus,self.batch_size,self.num_steps)

In [110]:
def load_data_file(batch_size,num_steps,use_random_iter=False,max_tokens=10000):
    data_iter = SeqDataLoader(batch_size,num_steps,use_random_iter,max_tokens)
    return data_iter,data_iter.vocab

In [111]:
### 加载file数据
batch_size,num_steps = 32,35
train_iter,vocab = load_data_file(batch_size=batch_size,num_steps=num_steps)

In [140]:
#包装成类
class RNNModel(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1):
        super(RNNModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        # 词嵌入层
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # RNN层 (可以选择 LSTM 或 GRU)
        self.rnn = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        # 输出层
        self.fc = nn.Linear(hidden_size, vocab_size)
        
    def forward(self, x, state=None):
        # x shape: (batch_size, seq_len)
        batch_size = x.size(0)
        
        # 嵌入层
        embedded = self.embedding(x)  # (batch_size, seq_len, embed_size)
        
        # RNN前向传播
        if state is None:
            # 初始化隐藏状态
            h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(x.device)
            c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(x.device)
            state = (h0, c0)
        
        output, state = self.rnn(embedded, state)  # output: (batch_size, seq_len, hidden_size)
        display(output.shape,state.shape)
        
        # 全连接层
        output = self.fc(output)  # (batch_size, seq_len, vocab_size)
        
        # 重塑输出以便计算损失
        output = output.reshape(-1, output.shape[2])  # (batch_size * seq_len, vocab_size)
        
        return output, state
    
    def begin_state(self, batch_size, device):
        # 初始化隐藏状态
        return (
            torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device),
            torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)
        )

In [113]:
import time

In [134]:
def train_model(model, train_loader, vocab_size, num_epochs, lr, device):
    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    # 将模型移到设备
    model.to(device)
    
    # 记录训练历史
    history = {'loss': [], 'perplexity': [], 'time': []}
    
    for epoch in range(num_epochs):
        model.train()  # 设置训练模式
        total_loss = 0
        total_tokens = 0
        epoch_start_time = time.time()
        
        for batch_idx, (X, Y) in enumerate(train_loader):
            # 移动到设备
            X, Y = X.to(device), Y.to(device)
            
            # 前向传播
            outputs, _ = model(X)
            
            # 计算损失
            loss = criterion(outputs, Y.view(-1))
            
            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            
            # 梯度裁剪（防止梯度爆炸）
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            # 更新参数
            optimizer.step()
            
            # 记录统计信息
            total_loss += loss.item() * Y.numel()
            total_tokens += Y.numel()
        
        # 计算epoch指标
        epoch_time = time.time() - epoch_start_time
        avg_loss = total_loss / total_tokens
        #display(avg_loss)
        perplexity = math.exp(avg_loss)
        tokens_per_sec = total_tokens / epoch_time
        
        # 记录历史
        history['loss'].append(avg_loss)
        history['perplexity'].append(perplexity)
        history['time'].append(epoch_time)
        
        # 打印进度
        if (epoch + 1) % 1 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}] | '
                  f'Loss: {avg_loss:.4f} | '
                  f'Perplexity: {perplexity:.2f} | '
                  f'Speed: {tokens_per_sec:.1f} tokens/sec')
    
    return history

# 辅助函数：分离隐藏状态
def detach_state(state):
    if isinstance(state, tuple):
        return (state[0].detach(), state[1].detach())
    else:
        return state.detach()

In [115]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'使用设备: {device}')

# 准备数据
train_loader, vocab_size = train_iter,len(vocab)
# 模型参数
embed_size = 128
hidden_size = 256
num_layers = 2

# 创建模型
model = RNNModel(vocab_size, embed_size, hidden_size, num_layers)
print(f'模型参数量: {sum(p.numel() for p in model.parameters()):,}')

使用设备: cuda
模型参数量: 932,380


In [141]:
# 训练参数
num_epochs = 1
learning_rate = 0.2
batch_size = 32

# 开始训练
print('开始训练...')
history = train_model(model, train_loader, vocab_size, num_epochs, learning_rate, device)

print('训练完成!')
print(f'最终困惑度: {history["perplexity"][-1]:.2f}')

开始训练...
Epoch [1/1] | Loss: 2.8540 | Perplexity: 17.36 | Speed: 24419.2 tokens/sec
训练完成!
最终困惑度: 17.36


In [None]:
# 预测
def predict(prefix,num_preds,net,vocab,device):
    state = net.begin_state(batch_size=1,device=device)
    outputs = [vocab[prefix[0]]]

    get_input = lambda : torch.tensor([outputs[-1]],device=device).reshape((1,1))
    #预热
    for u in prefix[1:]:
        #print(get_input())
        _,state = net(get_input(),state)
        outputs.append(vocab[u])
    #真正预测
    for _ in range(num_preds):
        y,state = net(get_input(),state)
        #print('真正预测:',y)
        outputs.append(int(y.argmax(dim=1).reshape(1)))
    return ''.join([vocab.idx_to_token[i] for i in outputs])

In [131]:
pred = lambda prefix:predict(prefix,50,model,vocab,device)
pred("It wasn't a sound")

'<unk>t wasn<unk>t a sound the the the the the the the the the the the the t'