In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import load_dataset
import matplotlib.pyplot as plt
%matplotlib inline


# Seed PyTorch's global random number generator so runs are repeatable
torch.manual_seed(12046)

In [None]:
# Hyperparameters
learning_rate = 1e-3  # step size handed to the optimizer
eval_iters = 10       # number of batches averaged per loss estimate
batch_size = 1000     # samples per DataLoader batch
sequence_len = 64     # number of tokens in each training window
# Run on the GPU when one is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Load the "Nan-Do/code-search-net-python" dataset through the `datasets` library
raw_datasets = load_dataset("Nan-Do/code-search-net-python")
# Keep only the training rows whose 'repo' field contains 'apache/spark'
datasets = raw_datasets['train'].filter(lambda x: 'apache/spark' in x['repo'])

In [None]:
class CharTokenizer:
    """Character-level tokenizer that maps single characters to integer ids.

    The vocabulary is the sorted set of characters found in ``data`` plus one
    end-of-text marker ``'<|e|>'`` stored under the id ``end_ind``.

    Attributes are plain dicts:
      char2ind: character (or the end marker) -> id
      ind2char: id -> character (or the end marker)
    """

    def __init__(self, data, end_ind=0):
        """Build the vocabulary.

        Args:
            data: iterable of strings (list[str]) whose characters form the vocabulary.
            end_ind: id reserved for the end-of-text marker. With the default
                ``0`` the characters receive ids ``1..n`` in sorted order.
        """
        # Collect every distinct character, in a deterministic (sorted) order
        chars = sorted(set(''.join(data)))
        n_chars = len(chars)
        self.char2ind = {}
        for i, ch in enumerate(chars):
            # The default layout is i + 1 (ids 1..n). If end_ind lies inside
            # 1..n that layout would give one character the same id as the
            # end marker, so characters sorted before the reserved slot shift
            # down by one; ids 0..n then stay unique and contiguous.
            self.char2ind[ch] = i if i < end_ind <= n_chars else i + 1
        self.char2ind['<|e|>'] = end_ind
        self.ind2char = {v: k for k, v in self.char2ind.items()}
        self.end_ind = end_ind

    def encode(self, x):
        """Convert a string into a list of ids, one per character.

        Raises:
            KeyError: if ``x`` contains a character that is not in the vocabulary.
        """
        return [self.char2ind[ch] for ch in x]

    def decode(self, x):
        """Convert ids back to characters.

        Args:
            x: a single ``int`` id, or an iterable of ids.

        Returns:
            A single string for an ``int`` input, otherwise a list of strings
            (the end marker decodes to ``'<|e|>'``).
        """
        if isinstance(x, int):
            return self.ind2char[x]
        return [self.ind2char[i] for i in x]

# Build the character vocabulary from the 'original_string' column of the filtered dataset
tokenizer = CharTokenizer(datasets['original_string'])

In [None]:
# Sanity check: encode a short snippet and print its token ids
test_str = 'def f(x):'
re = tokenizer.encode(test_str)
print(re)
# Decode every id from 0 to vocabulary size - 1 and join them to show the whole vocabulary
''.join(tokenizer.decode(range(len(tokenizer.char2ind))))

In [None]:
class RNN(nn.Module):
  """Minimal recurrent layer unrolled over the time axis.

  At every step the current input slice is concatenated with the previous
  hidden state and passed through one linear layer. No activation function is
  applied inside the recurrence.
  """

  def __init__(self, input_size, hidden_size):
    super().__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    # Maps the concatenation [input ; hidden] to the next hidden state
    self.i2h = nn.Linear(input_size + hidden_size, hidden_size)

  def forward(self, input, hidden=None):
    """Run the recurrence over all time steps.

    Args:
      input: tensor of shape (B, T, C).
      hidden: optional initial hidden state of shape (B, H); zeros when omitted.

    Returns:
      Tensor of shape (B, T, H) holding the hidden state after every step.
    """
    batch, _, _ = input.shape
    state = self.init_hidden(batch, input.device) if hidden is None else hidden
    states = []
    for step in input.unbind(dim=1):  # each step is (B, C)
      state = self.i2h(torch.cat((step, state), dim=-1))  # (B, C+H) -> (B, H)
      states.append(state)
    return torch.stack(states, dim=1)  # (B, T, H)

  def init_hidden(self, B, device):
    """Return an all-zero hidden state of shape (B, hidden_size) on `device`."""
    return torch.zeros(B, self.hidden_size, device=device)

In [None]:
# Shape check: a (5, 2, 3) input through RNN(3, 4) should give a (5, 2, 4) output
r = RNN(3, 4)
x = torch.randn(5, 2, 3)
r(x).shape

In [None]:
class CharRNNBatch(nn.Module):
  """Two-layer character language model built from the RNN layer above.

  Pipeline: embedding -> RNN -> LayerNorm -> ReLU -> dropout
            -> RNN -> LayerNorm -> ReLU -> dropout -> linear head.
  """

  def __init__(self, vs, emb_size=256, hidden_size=128, dropout=0.4):
    """
    Args:
      vs: vocabulary size (number of embedding rows and output logits).
      emb_size: width of the token embeddings.
      hidden_size: width of both recurrent layers.
      dropout: drop probability of the dropout layer.
    """
    super().__init__()
    self.emb = nn.Embedding(vs, emb_size)
    self.rnn1 = RNN(emb_size, hidden_size)
    self.ln1 = nn.LayerNorm(hidden_size)
    self.rnn2 = RNN(hidden_size, hidden_size)
    self.ln2 = nn.LayerNorm(hidden_size)
    self.lm = nn.Linear(hidden_size, vs)
    # A single dropout module is applied after both recurrent layers
    self.dp = nn.Dropout(dropout)

  def forward(self, x):
    """Compute next-token logits.

    Args:
      x: token ids of shape (B, T).

    Returns:
      Logits of shape (B, T, vs).
    """
    embeddings = self.emb(x)  # (B, T, emb_size)
    h = F.relu(self.ln1(self.rnn1(embeddings)))  # (B, T, hidden_size)
    h = self.dp(h)
    h = F.relu(self.ln2(self.rnn2(h)))  # (B, T, hidden_size)
    h = self.dp(h)
    out = self.lm(h)  # (B, T, vs)
    return out

In [None]:
# Instantiate the model with one output per vocabulary entry and move it to the selected device
c_model = CharRNNBatch(len(tokenizer.char2ind)).to(device)
# Display the module structure
c_model

In [None]:
@torch.no_grad()
def generate(model, context, tokenizer, max_new_tokens=300):
  """Sample a continuation of `context` one token at a time.

  Args:
    model: module mapping token ids (1, T) to logits (1, T, vocab).
    context: tensor of token ids with shape (1, T).
    tokenizer: object exposing `end_ind`, the id that stops generation.
    max_new_tokens: upper bound on the number of sampled tokens.

  Returns:
    list[int]: the original context ids followed by the sampled ids
    (including the end id when it was sampled).

  The model is put in eval mode while sampling and is returned to whatever
  mode (train or eval) it was in before the call.
  """
  out = context.tolist()[0]
  # Remember the caller's mode so an eval-mode model is not flipped to training
  was_training = model.training
  model.eval()
  try:
    for _ in range(max_new_tokens):
      logits = model(context)   # (1, T, vocab)
      # Only the distribution at the last position predicts the next token
      probs = F.softmax(logits[:, -1, :], dim=-1) # (1, vocab)
      # Sample the next token at random from that distribution
      ix = torch.multinomial(probs, num_samples=1) # (1, 1)
      # Append the sampled token to the context for the next step
      context = torch.concat((context, ix), dim=-1)
      out.append(ix.item())
      if out[-1] == tokenizer.end_ind:
        break
  finally:
    model.train(was_training)
  return out

In [None]:
# Encode the prompt 'def' as a (1, T) batch and print a sample generated by c_model
context = torch.tensor(tokenizer.encode('def'), device=device).unsqueeze(0)
print(''.join(tokenizer.decode(generate(c_model, context, tokenizer))))

In [None]:
def process(data, tokenizer, sequence_len=sequence_len):
  """Turn a batch of source strings into next-token training windows.

  Every string in data['original_string'] is encoded, terminated with the
  tokenizer's end id, and cut into all overlapping windows of `sequence_len`
  tokens. The label window is the input window shifted right by one token.
  Strings whose encoded length (end id included) does not exceed
  `sequence_len` produce no windows.

  Returns:
    dict with two parallel lists of id lists: 'inputs' and 'labels'.
  """
  inputs, labels = [], []
  for source in data['original_string']:
    ids = tokenizer.encode(source)
    ids += [tokenizer.end_ind]
    n_windows = len(ids) - sequence_len
    inputs.extend(ids[k : k + sequence_len] for k in range(n_windows))
    labels.extend(ids[k + 1 : k + 1 + sequence_len] for k in range(n_windows))
  return {'inputs' : inputs, 'labels' : labels}

In [None]:
# Split the data into a training set and a test set (10% held out, fixed seed)
tokenized = datasets.train_test_split(test_size=0.1, seed=1024, shuffle=True)
# Batched map: each call receives a batch of rows and returns the windowed inputs/labels
f = lambda x : process(x, tokenizer)
# Drop the original columns so only 'inputs' and 'labels' remain
tokenized = tokenized.map(f, batched=True, remove_columns=datasets.column_names)
# Return rows as torch tensors placed on the selected device
tokenized.set_format(type='torch', device=device)


In [None]:
tokenized

In [None]:
# Both loaders shuffle; estimate_loss draws its first eval_iters batches from each of them
train_loader = DataLoader(tokenized['train'], batch_size=batch_size, shuffle=True)
test_loader = DataLoader(tokenized['test'], batch_size=batch_size, shuffle=True)

In [None]:
# Inspect one training batch
next(iter(train_loader))

In [None]:
def estimate_loss(model):
  """Estimate the model's loss on the training and test loaders.

  Uses the module-level `train_loader` and `test_loader`.

  Returns:
    dict with keys 'train' and 'test', each a float loss estimate.

  The model is evaluated in eval mode and afterwards returned to whatever
  mode it was in before the call, even if evaluation raises.
  """
  stats = {}
  # Remember the caller's mode instead of unconditionally switching to train
  was_training = model.training
  # Switch to evaluation mode
  model.eval()
  try:
    stats['train'] = _loss(model, train_loader)
    stats['test'] = _loss(model, test_loader)
  finally:
    # Restore the previous mode
    model.train(was_training)
  return stats


def _loss(model, data_loader):
  """Average the cross-entropy loss over `eval_iters` batches of `data_loader`.

  Args:
    model: module mapping token ids (B, T) to logits (B, T, vs).
    data_loader: iterable of batches, each a mapping with 'inputs' and
      'labels' entries of shape (B, T).

  Returns:
    float: mean of the per-batch losses.

  Raises:
    ValueError: if `data_loader` yields no batches.

  `eval_iters` is read from module scope. When the loader holds fewer batches
  than that, iteration restarts from the beginning.
  """
  losses = []
  data_iter = iter(data_loader)

  # No gradients are needed for evaluation; skip building the autograd graph
  with torch.no_grad():
    # Evaluate the model on several batches
    for _ in range(eval_iters):
      # next() raises StopIteration on exhaustion unless given a default,
      # so ask for None explicitly to make the restart branch reachable
      data = next(data_iter, None)
      if data is None:
        data_iter = iter(data_loader)
        data = next(data_iter, None)
        if data is None:
          raise ValueError('data_loader is empty; cannot estimate loss')
      inputs, labels = data['inputs'], data['labels'] # (B, T)
      logits = model(inputs) # (B, T, vs)
      # cross_entropy expects the class dimension second: (B, vs, T)
      losses.append(F.cross_entropy(logits.transpose(-2, -1), labels).item())
  return torch.tensor(losses).mean().item()

In [None]:
# Loss estimate for c_model before train_model is called
estimate_loss(c_model)

In [None]:
def train_model(model, optimizer, epochs=10):
  """Train `model` on the module-level `train_loader`.

  Args:
    model: module mapping token ids (B, T) to logits (B, T, vs).
    optimizer: optimizer already bound to the model's parameters.
    epochs: number of full passes over `train_loader`.

  Returns:
    list[float]: the training loss of every optimisation step, in order.

  After each epoch the estimated train/test losses are printed.
  """
  history = []
  for epoch in range(epochs):
    for batch in train_loader:
      optimizer.zero_grad()
      logits = model(batch['inputs'])  # (B, T, vs)
      # cross_entropy expects the class dimension second: (B, vs, T)
      loss = F.cross_entropy(logits.transpose(-2, -1), batch['labels'])
      history.append(loss.item())
      loss.backward()
      optimizer.step()
    # Evaluate the model and report the result
    stats = estimate_loss(model)
    train_loss = f'train loss {stats["train"]:.4f}'
    test_loss = f'test loss {stats["test"]:.4f}'
    print(f'epoch {epoch:>2}: {train_loss}, {test_loss}')
  return history


In [None]:
# Train with Adam at the configured learning rate; `l` holds the per-step training losses
l = train_model(c_model, optim.Adam(c_model.parameters(), lr=learning_rate))

In [None]:
# Sample again from the same 'def' prompt, now with the trained model
context = torch.tensor(tokenizer.encode('def'), device=device).unsqueeze(0)
print(''.join(tokenizer.decode(generate(c_model, context, tokenizer))))

In [None]:
# Smooth the per-step training loss by averaging consecutive groups of 10 steps.
# view(-1, 10) requires a length divisible by 10, so drop the incomplete
# trailing group first instead of relying on the step count being a multiple of 10.
loss_history = torch.tensor(l)
n_full = len(loss_history) // 10 * 10
plt.plot(loss_history[:n_full].view(-1, 10).mean(dim=-1))