### LSTM Model
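长短期记忆网络 (Long Short-Term Memory, LSTM) — a character-level language model trained on the Python functions of the `apache/spark` repository from the CodeSearchNet dataset.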

In [2]:
!pip install datasets



In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
import matplotlib.pyplot as plt
%matplotlib inline


# Seed PyTorch's global RNG (weight init, DataLoader shuffling, multinomial sampling) for reproducibility.
torch.manual_seed(seed=12046)

<torch._C.Generator at 0x1f25353ae90>

In [2]:
# Hyperparameters
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
seq_len = 64          # characters per training window
batch_size = 1000     # windows per mini-batch
eval_iters = 10       # mini-batches averaged by each loss estimate
learning_rate = 1e-3  # Adam learning rate

In [3]:
# Load the CodeSearchNet Python corpus and keep only the train rows whose
# repository_name contains 'apache/spark'.
# NOTE(review): `datasets` here is a Dataset object, not the `datasets` package
# (only `load_dataset` was imported from it).
raw_datasets = load_dataset('code_search_net', 'python')
datasets = raw_datasets['train'].filter(
    lambda example: 'apache/spark' in example['repository_name']
)

In [4]:
class CharTokenizer:
    """Character-level tokenizer with a single end-of-text token ``"<|e|>"``."""

    def __init__(self, data, end_ind=0):
        """Build the vocabulary.

        Args:
            data: iterable of str; every character occurring in it gets an id.
            end_ind: id reserved for the end-of-text token ``"<|e|>"``.
        """
        chars = sorted(set("".join(data)))
        # Characters get ids 1, 2, 3, ... in sorted order, skipping end_ind so
        # that no character shares its id with "<|e|>" (a shared id would be
        # overwritten in ind2char and the character could not be decoded).
        self.char2ind = {}
        next_ind = 1
        for char in chars:
            if next_ind == end_ind:
                next_ind += 1
            self.char2ind[char] = next_ind
            next_ind += 1
        self.char2ind["<|e|>"] = end_ind
        self.ind2char = {i: char for char, i in self.char2ind.items()}
        self.end_ind = end_ind

    def encode(self, text):
        """Map a string to a list of ids, one per character.

        Raises:
            KeyError: if ``text`` contains a character outside the vocabulary.
        """
        return [self.char2ind[char] for char in text]

    def decode(self, inds):
        """Map ids back to token strings.

        Args:
            inds: a single id or a sequence of ids; tensors / numpy values
                (anything exposing ``.tolist()``) are accepted too.

        Returns:
            list[str] with one entry per id (``"<|e|>"`` for ``end_ind``).
        """
        if hasattr(inds, "tolist"):
            # Tensor / numpy input -> plain Python int or list of ints.
            inds = inds.tolist()
        if isinstance(inds, int):
            inds = [inds]
        return [self.ind2char[i] for i in inds]

In [5]:
# Build the character vocabulary from every function body in the filtered corpus.
tokenizer = CharTokenizer(data=datasets["whole_func_string"])

In [None]:
@torch.no_grad()
def generate(model, context, tokenizer, max_tokens=300):
    """Sample tokens autoregressively from ``model``.

    Args:
        model: module mapping (1, T) token ids to (1, T, vocab_size) logits.
        context: (1, T) tensor of prompt ids; assumes batch size 1.
        tokenizer: object exposing ``end_ind``; sampling stops once that id is drawn.
        max_tokens: maximum number of new tokens to sample.

    Returns:
        list[int]: the prompt ids followed by the sampled ids (the end id is
        included if it was drawn).
    """
    out = context.tolist()[0]  # (T,)
    was_training = model.training
    model.eval()  # e.g. disables dropout while sampling
    try:
        for _ in range(max_tokens):
            # The full sequence is re-run each step: the model does not expose its recurrent state.
            logits = model(context)  # (1, T, vs)
            probs = F.softmax(logits[:, -1, :], dim=-1)  # last position only: (1, vs)
            next_token = torch.multinomial(probs, 1)  # (1, 1)
            out.append(next_token.item())
            context = torch.cat([context, next_token], dim=-1)  # (1, T + 1)
            if out[-1] == tokenizer.end_ind:
                break
    finally:
        # Restore the caller's mode rather than forcing train mode.
        model.train(was_training)
    return out

In [7]:
# Prepare data: cut every function into fixed-length next-character windows.
def process(data, tokenizer, seq_len=seq_len, stride=1):
    """Turn a batch of functions into (input, label) windows for next-character prediction.

    Each function is encoded and ``tokenizer.end_ind`` is appended. A window of
    ``seq_len`` ids then starts every ``stride`` positions; its label is the
    same window shifted one position to the right. A function whose encoding
    (including the end id) has ``seq_len`` or fewer ids yields no windows.

    Args:
        data: batch dict with a ``"whole_func_string"`` list of str (as passed
            by ``Dataset.map(..., batched=True)``).
        tokenizer: object with ``encode(str) -> list[int]`` and ``end_ind``.
        seq_len: window length; defaults to the global ``seq_len`` at definition time.
        stride: step between consecutive window starts. The default of 1 keeps
            every overlapping window; larger values produce fewer, less
            redundant examples.

    Returns:
        dict with ``"inputs"`` and ``"labels"``: lists of length-``seq_len`` id lists.

    Raises:
        ValueError: if ``stride`` < 1.
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")
    inputs, labels = [], []
    for text in data["whole_func_string"]:
        enc = tokenizer.encode(text) + [tokenizer.end_ind]
        # The last start index must leave room for the label's final id at i + seq_len.
        for i in range(0, len(enc) - seq_len, stride):
            inputs.append(enc[i : i + seq_len])
            labels.append(enc[i + 1 : i + seq_len + 1])
    return {"inputs": inputs, "labels": labels}

In [8]:
# Split the functions into train / test sets. The split happens before
# windowing, so the windows of any one function all land in the same set.
tokenized = datasets.train_test_split(test_size=0.1, seed=1024, shuffle=True)


def f(batch):
    return process(batch, tokenizer)


tokenized = tokenized.map(f, batched=True, remove_columns=datasets.column_names)
tokenized.set_format(type='torch', device=device)

tuple(
    tokenized[split][column].shape
    for split in ('train', 'test')
    for column in ('inputs', 'labels')
)

(torch.Size([605913, 64]),
 torch.Size([605913, 64]),
 torch.Size([62633, 64]),
 torch.Size([62633, 64]))

In [9]:
# Both loaders reshuffle on every pass. The loss estimate draws only a few
# batches from each, so shuffling makes those batches a random sample.
train_loader = DataLoader(dataset=tokenized['train'], batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=tokenized['test'], batch_size=batch_size, shuffle=True)

In [11]:
criterion = F.cross_entropy

# 计算损失
def estimate_loss(model):
    model.eval()
    re = {}
    with torch.no_grad():
        re['train'] = _loss(model, train_loader)
        re['test'] = _loss(model, test_loader)
    model.train()
    return re

@torch.no_grad()
def _loss(model, dataloader):
    total_loss = []
    data_iter = iter(dataloader)
    for k in range(eval_iters): # 手动控制批次数量
        data = next(data_iter, None)
        if data is None:
            data_iter = iter(dataloader)
            data = next(data_iter, None)
        inputs, labels = data['inputs'], data['labels']
        logits = model(inputs)                               # (B, T, vs)
        loss = criterion(logits.transpose(-2, -1), labels)
        total_loss.append(loss.item())
    return torch.tensor(total_loss).mean().item()

In [12]:
def train(model, optimizer, epochs=10, dataloader=None):
    """Train ``model`` on next-character cross-entropy and report losses per epoch.

    Args:
        model: maps (B, T) ids to (B, T, vs) logits.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of full passes over the training data.
        dataloader: iterable of ``{'inputs', 'labels'}`` batches; defaults to
            the global ``train_loader``.

    Returns:
        list[float]: the training loss of every optimization step.
    """
    if dataloader is None:
        dataloader = train_loader
    lossi = []
    for epoch in range(epochs):
        # Ensure dropout is active even if the model arrives in eval mode
        # (e.g. the caller ran .eval() or a mode-restoring evaluation).
        model.train()
        for data in dataloader:
            inputs, labels = data['inputs'], data['labels']  # (B, T)
            optimizer.zero_grad()
            logits = model(inputs)  # (B, T, vs)
            loss = criterion(logits.transpose(-2, -1), labels)  # (B, vs, T) vs (B, T)
            loss.backward()
            optimizer.step()
            lossi.append(loss.item())
        stats = estimate_loss(model)
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {stats['train']:.4f} | Test Loss: {stats['test']:.4f}")
    return lossi

In [13]:
class LSTMCell(nn.Module):
    """One LSTM time step: (x_t, (h_prev, c_prev)) -> (h_t, c_t).

    Every gate is its own ``nn.Linear`` applied to the concatenation ``[x_t, h_prev]``.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        gate_in = input_size + hidden_size
        # Keep this creation order: it fixes parameter order and the RNG draws used for init.
        self.forget_gate = nn.Linear(gate_in, hidden_size)
        self.in_gate = nn.Linear(gate_in, hidden_size)
        self.new_cell_state = nn.Linear(gate_in, hidden_size)
        self.out_gate = nn.Linear(gate_in, hidden_size)

    def forward(self, input, state=None):
        """Advance one step.

        Args:
            input: (B, input_size) tensor.
            state: optional ``(h, c)`` pair, each (B, hidden_size); zeros when None.

        Returns:
            ``(h, c)``: new hidden and cell states, each (B, hidden_size).
        """
        if state is None:
            state = self.init_state(input.shape[0], input.device)
        h_prev, c_prev = state
        xh = torch.cat((input, h_prev), dim=-1)  # (B, I + H)
        forget = torch.sigmoid(self.forget_gate(xh))
        write = torch.sigmoid(self.in_gate(xh))
        candidate = torch.tanh(self.new_cell_state(xh))
        expose = torch.sigmoid(self.out_gate(xh))
        # Cell state: keep part of the old memory, add part of the candidate.
        c_new = (c_prev * forget) + (candidate * write)
        # Hidden state: gated view of the squashed cell state.
        h_new = torch.tanh(c_new) * expose
        return h_new, c_new

    def init_state(self, B, device):
        """Return zero ``(h, c)`` tensors of shape (B, hidden_size) on ``device``."""
        shape = (B, self.hidden_size)
        return torch.zeros(shape, device=device), torch.zeros(shape, device=device)

In [14]:
# Smoke test: input_size=3, hidden_size=4, batch of 5 -> two (5, 4) state tensors.
l_cell = LSTMCell(input_size=3, hidden_size=4)
x = torch.randn(5, 3)  # (B=5, I=3)
a, b = l_cell(x)       # zero initial state
(a.shape, b.shape)

(torch.Size([5, 4]), torch.Size([5, 4]))

In [15]:
class LSTM(nn.Module):
    """Single-layer, batch-first LSTM: ``LSTMCell`` unrolled over the time axis."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.cell = LSTMCell(input_size, hidden_size)

    def forward(self, input, state=None):
        """Run the cell over every time step.

        Args:
            input: (B, T, C) tensor.
            state: optional initial ``(h, c)``, each (B, H); zeros when None.

        Returns:
            (B, T, H) tensor of the hidden state at every step. The final
            ``(h, c)`` is not returned.
        """
        _batch, steps, _features = input.shape
        hidden_per_step = []
        for t in range(steps):
            state = self.cell(input[:, t], state)
            hidden_per_step.append(state[0])  # keep only h_t, shape (B, H)
        return torch.stack(hidden_per_step, dim=1)

In [16]:
def test_lstm():
    """Check the hand-written LSTM against ``torch.nn.LSTM``.

    Draws random sizes and builds a stacked ``nn.LSTM`` reference. Each
    reference layer's weights are copied into a fresh ``LSTM`` layer, both
    models run on the same inputs and initial states, and the last layer's
    outputs are compared element-wise (absolute tolerance 1e-4).

    Returns:
        ``(tensor(bool), (B, T, input_size, hidden_size, num_layers))``
    """
    # Random model structure.
    B, T, input_size, hidden_size, num_layers = torch.randint(1, 20, (5,)).tolist()
    ref_model = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True)
    # Random inputs and initial states; hs and cs are each (num_layers, B, hidden_size).
    inputs = torch.randn(B, T, input_size)
    hs, cs = torch.randn((2 * num_layers, B, hidden_size)).chunk(2, 0)

    # nn.LSTM stacks gate parameters in the order input, forget, cell, output.
    gate_names = ('in_gate', 'forget_gate', 'new_cell_state', 'out_gate')
    out = inputs
    for layer_index in range(num_layers):
        w_ih, w_hh, b_ih, b_hh = ref_model.all_weights[layer_index]
        layer = LSTM(input_size if layer_index == 0 else hidden_size, hidden_size)
        # Our gates act on cat(input, hidden): join the two weight matrices
        # column-wise and merge the two bias vectors into one.
        weights = torch.cat((w_ih, w_hh), dim=1).chunk(4, 0)
        biases = (b_ih + b_hh).chunk(4, 0)
        for name, weight, bias in zip(gate_names, weights, biases):
            gate = getattr(layer.cell, name)
            gate.weight = nn.Parameter(weight)
            gate.bias = nn.Parameter(bias)
        out = layer(out, (hs[layer_index], cs[layer_index]))
    ref_out, _ = ref_model(inputs, (hs, cs))
    # Compare the last layer's per-step hidden states.
    matches = torch.all(torch.abs(out - ref_out) < 1e-4)
    return matches, (B, T, input_size, hidden_size, num_layers)

test_lstm()

(tensor(True), (5, 17, 1, 7, 15))

In [17]:
class CharLSTM(nn.Module):
    """Character-level language model: embedding -> 3 stacked LSTM layers -> linear head.

    Each LSTM layer's output goes through dropout and then LayerNorm.

    Args:
        vs: vocabulary size (number of token ids).
        emb_size: embedding width.
        hidden_size: hidden width of every LSTM layer.
        dropout: dropout probability applied after every LSTM layer.
    """

    def __init__(self, vs, emb_size=256, hidden_size=128, dropout=0.4):
        super().__init__()
        self.emb_size = emb_size
        self.hidden_size = hidden_size
        # Submodules are created in this fixed order so parameter order and random init are reproducible.
        self.embedding = nn.Embedding(vs, self.emb_size)
        self.dp = nn.Dropout(dropout)  # one module, reused after every LSTM layer
        self.lstm1 = LSTM(self.emb_size, self.hidden_size)
        self.ln1 = nn.LayerNorm(self.hidden_size)
        self.lstm2 = LSTM(self.hidden_size, self.hidden_size)
        self.ln2 = nn.LayerNorm(self.hidden_size)
        self.lstm3 = LSTM(self.hidden_size, self.hidden_size)
        self.ln3 = nn.LayerNorm(self.hidden_size)
        self.lm = nn.Linear(self.hidden_size, vs)

    def forward(self, x):
        """Map (B, T) token ids to (B, T, vs) next-token logits."""
        emb = self.embedding(x)  # (B, T, emb_size)
        h = self.ln1(self.dp(self.lstm1(emb)))  # (B, T, hidden_size)
        h = self.ln2(self.dp(self.lstm2(h)))
        h = self.ln3(self.dp(self.lstm3(h)))
        out = self.lm(h)  # (B, T, vs)
        return out

In [18]:
# Vocabulary size = every id the tokenizer knows (all characters plus "<|e|>").
l_model = CharLSTM(vs=len(tokenizer.ind2char))
l_model = l_model.to(device)
l_model

CharLSTM(
  (embedding): Embedding(98, 256)
  (dp): Dropout(p=0.4, inplace=False)
  (lstm1): LSTM(
    (cell): LSTMCell(
      (forget_gate): Linear(in_features=384, out_features=128, bias=True)
      (in_gate): Linear(in_features=384, out_features=128, bias=True)
      (new_cell_state): Linear(in_features=384, out_features=128, bias=True)
      (out_gate): Linear(in_features=384, out_features=128, bias=True)
    )
  )
  (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (lstm2): LSTM(
    (cell): LSTMCell(
      (forget_gate): Linear(in_features=256, out_features=128, bias=True)
      (in_gate): Linear(in_features=256, out_features=128, bias=True)
      (new_cell_state): Linear(in_features=256, out_features=128, bias=True)
      (out_gate): Linear(in_features=256, out_features=128, bias=True)
    )
  )
  (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (lstm3): LSTM(
    (cell): LSTMCell(
      (forget_gate): Linear(in_features=256, out_features=128, bias=Tru

In [19]:
# Sample a continuation of the prompt 'def' from the current model.
context = torch.tensor([tokenizer.encode('def')], device=device)  # (1, T)
generated = generate(l_model, context, tokenizer)
print(''.join(tokenizer.decode(generated)))

def*$K(h]Rf("YS{BZE G|uw=3:1'_$=z9NN[{KQ=CK|AM:iKca"|+Q3j<sA!!WS$I8tx!q*T3"?u?5Za)'W*5\rm&B"T{Y
cdtM\SDAx1zk<|e|>


In [20]:
# Baseline train/test loss of l_model, evaluated before the training cell below.
estimate_loss(model=l_model)

{'train': 4.749468803405762, 'test': 4.756665229797363}

In [21]:
optimizer = optim.Adam(l_model.parameters(), lr=learning_rate)
l = train(l_model, optimizer)  # list of per-step training losses

Epoch 1/10 | Train Loss: 1.2832 | Test Loss: 1.4375
Epoch 2/10 | Train Loss: 1.1272 | Test Loss: 1.3084
Epoch 3/10 | Train Loss: 1.0798 | Test Loss: 1.2798
Epoch 4/10 | Train Loss: 1.0035 | Test Loss: 1.2374
Epoch 5/10 | Train Loss: 0.9708 | Test Loss: 1.2185
Epoch 6/10 | Train Loss: 0.9526 | Test Loss: 1.2099
Epoch 7/10 | Train Loss: 0.9429 | Test Loss: 1.1949
Epoch 8/10 | Train Loss: 0.9169 | Test Loss: 1.1891
Epoch 9/10 | Train Loss: 0.9071 | Test Loss: 1.1910
Epoch 10/10 | Train Loss: 0.9028 | Test Loss: 1.2064
