In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datasets import load_dataset
import matplotlib.pyplot as plt
%matplotlib inline


# Seed torch's global RNG so weight initialisation, the random demo inputs and
# the multinomial sampling in `generate` are reproducible across Restart & Run All.
torch.manual_seed(12046)

In [None]:
# Hyperparameters / run configuration.
learning_rate = 0.001  # Adam step size used by the training loop

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

In [None]:
# Load the CodeSearchNet Python corpus through Hugging Face `datasets`.
raw_datasets = load_dataset('Nan-Do/code-search-net-python')
# Keep only training examples whose `repo` field contains 'apache/spark'.
# NOTE(review): this is a substring test, so any repo name that merely contains
# 'apache/spark' also passes — confirm that is intended.
# NOTE(review): this variable shadows the name of the `datasets` package
# (only `load_dataset` is imported from it, so nothing breaks here).
datasets = raw_datasets['train'].filter(lambda x: 'apache/spark' in x['repo'])

In [None]:
class RNNCell(nn.Module):
  """Single-step vanilla RNN cell with a ReLU non-linearity.

  Computes ``h_t = relu(W [x_t ; h_{t-1}] + b)`` with one ``nn.Linear`` applied
  to the concatenation of the current input and the previous hidden state.

  Args:
    input_size: size I of each input vector (in NLP, the embedding size C).
    hidden_size: size H of the hidden state.
  """

  def __init__(self, input_size, hidden_size):
    super().__init__()
    self.input_size = input_size
    # Old (misspelled) attribute name, kept so existing readers keep working.
    self.inputs_size = input_size
    self.hidden_size = hidden_size
    self.i2h = nn.Linear(input_size + hidden_size, hidden_size)

  def forward(self, input, hidden=None):
    """Advance the cell by one time step.

    Args:
      input: (B, I) input for the current step (B is 1 everywhere in this notebook).
      hidden: (B, H) previous hidden state, or None to start from zeros.

    Returns:
      (B, H) new hidden state.
    """
    if hidden is None:
      # Size the initial state to the input's batch so batched inputs also work.
      hidden = self.init_hidden(input.device, batch_size=input.shape[0])
    combined = torch.cat((input, hidden), dim=-1)  # (B, I + H)
    return F.relu(self.i2h(combined))             # (B, H)

  def init_hidden(self, device, batch_size=1):
    """Return an all-zero hidden state of shape (batch_size, H) on ``device``."""
    return torch.zeros((batch_size, self.hidden_size), device=device)

In [None]:
r_model = RNNCell(2, 3)
# One toy sequence of 4 time steps; the input at each step has shape (1, 2).
data = torch.randn(4, 1, 2)
hidden = None

# An RNN is evaluated step by step (each step needs the previous hidden state),
# so the time steps are fed through the cell in an explicit Python loop.
for step_input in data:
  hidden = r_model(step_input, hidden)
  print(hidden)

In [None]:
class CharRNN(nn.Module):
  """Character-level language model: embedding -> RNNCell -> linear head.

  Each call processes one token and returns next-token logits plus the updated
  hidden state, so the caller drives the loop over time steps.

  Args:
    vs: vocabulary size (number of distinct token ids).
    emb_size: embedding dimension fed into the RNN cell (default 30).
    hidden_size: size of the RNN hidden state (default 50).
  """

  def __init__(self, vs, emb_size=30, hidden_size=50):
    super().__init__()
    self.emb = nn.Embedding(vs, emb_size)
    self.rnn = RNNCell(emb_size, hidden_size)
    self.lm = nn.Linear(hidden_size, vs)

  def forward(self, x, hidden=None):
    """Run one time step.

    Args:
      x: (1,) tensor holding the current token id.
      hidden: (1, hidden_size) previous hidden state, or None to let the
        RNN cell start from zeros.

    Returns:
      (logits, hidden): next-token logits of shape (1, vs) and the new
      (1, hidden_size) hidden state.
    """
    embeddings = self.emb(x)               # (1, emb_size)
    hidden = self.rnn(embeddings, hidden)  # (1, hidden_size)
    logits = self.lm(hidden)               # (1, vs)
    return logits, hidden



In [None]:
class CharTokenizer:
  """Character-level tokenizer with a single end-of-text token ``'<|e|>'``.

  Each distinct character in ``data`` gets its own integer id, and the end token
  gets ``end_ind``. Character ids are handed out in sorted-character order from 0
  upwards, skipping ``end_ind``, so no character can share an id with the end
  token. With the default ``end_ind=0`` the characters get ids ``1..n``.

  Args:
    data: iterable of strings used to build the vocabulary.
    end_ind: id reserved for the end token. Keep it in ``[0, n_chars]`` so the
      ids are contiguous and ``len(char2ind)`` is a valid vocabulary size.
  """

  def __init__(self, data, end_ind=0):
    chars = sorted(set(''.join(data)))
    self.char2ind = {}
    next_id = 0
    for ch in chars:
      if next_id == end_ind:
        next_id += 1  # this id belongs to '<|e|>'
      self.char2ind[ch] = next_id
      next_id += 1
    self.char2ind['<|e|>'] = end_ind
    self.ind2char = {v: k for k, v in self.char2ind.items()}
    self.end_ind = end_ind

  def encode(self, x):
    """Map a string to a list of ids (KeyError for characters not seen in ``data``)."""
    return [self.char2ind[c] for c in x]

  def decode(self, x):
    """Map one id to its string, or a list of ids to a list of strings."""
    if isinstance(x, int):
      return self.ind2char[x]
    return [self.ind2char[i] for i in x]

In [None]:
tokenizer = CharTokenizer(data=datasets['original_string'])
test_str = 'def f(x):'
# Round-trip a sample string: characters -> ids -> characters.
encode_result = tokenizer.encode(test_str)
decode_result = tokenizer.decode(encode_result)
print(f"{encode_result} {''.join(decode_result)}")

In [None]:
# One output logit per vocabulary entry (every character plus the end token).
c_model = CharRNN(vs=len(tokenizer.char2ind)).to(device=device)

In [None]:
# Display the model's module tree (embedding, RNN cell, output head).
c_model

In [None]:
# Smoke-test one forward step on the character 'd', starting from a fresh hidden state.
inputs = torch.tensor(tokenizer.encode('d'), device=device)
out, hidden = c_model(inputs, None)
(out.shape, hidden.shape)

In [None]:
@torch.no_grad()
def generate(model, idx, tokenizer, max_new_tokens=300):
  """Sample a continuation from ``model`` one token at a time.

  Args:
    model: step-wise LM called as ``model(idx, hidden) -> (logits, hidden)``,
      where logits has shape (1, vocab_size).
    idx: (1,) tensor holding the start token id.
    tokenizer: object exposing ``end_ind``; sampling stops once that id is drawn.
    max_new_tokens: maximum number of tokens to sample.

  Returns:
    list[int]: the start id followed by the sampled ids (including the end id
    if it was drawn).

  The model is put in eval mode while sampling. Afterwards it goes back to the
  train/eval mode it had before the call, even if sampling raises.
  """
  out = idx.tolist()
  hidden = None
  was_training = model.training
  model.eval()
  try:
    for _ in range(max_new_tokens):
      logits, hidden = model(idx, hidden)
      probs = F.softmax(logits, dim=-1)             # (1, vocab_size)
      # Sample (rather than argmax) so the generated text is random.
      ix = torch.multinomial(probs, num_samples=1)  # (1, 1)
      out.append(ix.item())
      idx = ix.squeeze(0)                           # (1,) input for the next step
      if out[-1] == tokenizer.end_ind:
        break
  finally:
    model.train(was_training)
  return out

In [None]:
# Sample from the model before any training, starting from the character 'd'.
inputs = torch.tensor(tokenizer.encode('d'), device=device)
''.join(tokenizer.decode(generate(model=c_model, idx=inputs, tokenizer=tokenizer)))

In [None]:
def process(text, tokenizer):
  """Turn one string into (inputs, labels) for next-character prediction.

  ``labels[t]`` is the id of the character that follows ``inputs[t]``, and the
  last input is labelled with ``tokenizer.end_ind``. Both results are int64
  tensors of equal length on the notebook-level ``device``. An empty string
  gives two empty tensors rather than zero inputs paired with one label.
  """
  enc = tokenizer.encode(text)
  labels = enc[1:] + [tokenizer.end_ind] if enc else []
  # Explicit dtype so an empty list still becomes an integer (not float) tensor.
  return (torch.tensor(enc, dtype=torch.long, device=device),
          torch.tensor(labels, dtype=torch.long, device=device))

In [None]:
# Show the (inputs, labels) pair for the sample string: labels are inputs shifted left by one.
process(text=test_str, tokenizer=tokenizer)

In [None]:
lossi = []
epochs = 1
optimizer = optim.Adam(c_model.parameters(), lr=learning_rate)

for e in range(epochs):
  # `sample` (not `data`) so the demo tensor defined earlier is not clobbered.
  for sample in datasets:
    inputs, labels = process(sample['original_string'], tokenizer)
    lens = len(inputs)
    if lens == 0:
      # Nothing to predict; the loss would be a Python float with no .backward().
      continue
    # Unroll the RNN over the whole string, carrying the hidden state forward,
    # so gradients flow back through every step of the sequence.
    hidden = None
    step_logits = []
    for i in range(lens):
      logits, hidden = c_model(inputs[i].unsqueeze(0), hidden)  # (1, vs)
      step_logits.append(logits)
    # Mean next-character cross-entropy over the sequence (same as summing the
    # per-step losses and dividing by lens).
    _loss = F.cross_entropy(torch.cat(step_logits), labels)
    lossi.append(_loss.item())
    optimizer.zero_grad()
    _loss.backward()
    # NOTE(review): a ReLU RNN unrolled over long strings may see exploding
    # gradients; consider torch.nn.utils.clip_grad_norm_ here.
    optimizer.step()

In [None]:
# Sample again after training, starting from 'd', and print the decoded text.
inputs = torch.tensor(tokenizer.encode('d'), device=device)
generated_text = ''.join(tokenizer.decode(generate(c_model, inputs, tokenizer)))
print(generated_text)

In [None]:
# Per-step training loss (one optimizer step per source string).
fig, ax = plt.subplots()
ax.plot(lossi)
ax.set(title='CharRNN training loss',
       xlabel='training step (one source string per step)',
       ylabel='mean cross-entropy per character');

In [None]:
# Inspect the raw source text of one example (index 1) from the filtered split.
datasets[1]['original_string']