In [1]:
import torch
import numpy as np

用 LSTM 学习 ABCDE……ZABC…… 这一循环的字母顺序：给出一个字母，模型即可说出它的下一个字母（Z 的下一个字母是 A）。

In [2]:

# The alphabet to learn; the order of the characters defines "the next letter".
text = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
# A set collapses repeated characters, so its size is the vocabulary size.
tset = set(text)
# Character -> index lookup, taken from each character's position in `text`.
c2i = dict(zip(text, range(len(text))))
VOCAB_N = len(tset)
print(VOCAB_N, c2i)

26 {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4, 'F': 5, 'G': 6, 'H': 7, 'I': 8, 'J': 9, 'K': 10, 'L': 11, 'M': 12, 'N': 13, 'O': 14, 'P': 15, 'Q': 16, 'R': 17, 'S': 18, 'T': 19, 'U': 20, 'V': 21, 'W': 22, 'X': 23, 'Y': 24, 'Z': 25}


In [3]:
# Every tensor built by the helpers below is placed on the CPU.
device = torch.device("cpu")


def _make_tensor(x, dtype):
    """Copy `x` into a new tensor of the given dtype on `device`."""
    return torch.tensor(x, dtype=dtype, device=device)


def FloatTensor(x):
    """Return `x` as a new float32 tensor on `device`."""
    return _make_tensor(x, torch.float32)


def LongTensor(x):
    """Return `x` as a new int64 tensor on `device`."""
    return _make_tensor(x, torch.int64)

class Network(torch.nn.Module):
    """Embedding -> single-layer LSTM -> linear head -> ReLU.

    Args:
        vocab_num: number of rows in the embedding table.
        input: accepted for compatibility; not used by this class.
        output: number of features produced by the linear head.
        hidden: width shared by the embedding vectors and the LSTM state.
    """

    def __init__(self, vocab_num = VOCAB_N, input=1, output=1, hidden=32):
        super().__init__()
        # Index -> dense vector of size `hidden`.
        self.eb = torch.nn.Embedding(vocab_num, hidden)
        # One LSTM layer; batch_first=True means inputs are (batch, seq, feature).
        self.l1 = torch.nn.LSTM(
            input_size=hidden,
            hidden_size=hidden,
            num_layers=1,
            batch_first=True,
        )
        # Projects every LSTM output vector down to `output` features.
        self.l2 = torch.nn.Linear(hidden, output)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Map a tensor of character indices to non-negative predictions."""
        embedded = self.eb(x)
        # The LSTM also returns its final (hidden, cell) state; it is discarded.
        sequence_out, _ = self.l1(embedded)
        projected = self.l2(sequence_out)
        # ReLU clamps the result so that every value is >= 0.
        return self.relu(projected)


# Seed the global RNG before the weights are created so that a fresh
# "Restart & Run All" starts from the same initial weights every time.
SEED = 0
torch.manual_seed(SEED)

m = Network()
# Mean-squared error between the predicted and the true next-letter index.
loss_f = torch.nn.MSELoss()
opt = torch.optim.Adam(m.parameters(), lr=0.001)

In [4]:
# Character indices twice: int64 for the embedding lookup, float32 as the
# regression target.
tensor_text = LongTensor([c2i[c] for c in text])
tensor_textn = FloatTensor([c2i[c] for c in text])
m.train()

N_STEPS = 8000   # one (letter -> next letter) pair is trained per step
LOG_EVERY = 50   # report the mean loss of the steps since the last report

losses = []
for i in range(N_STEPS + 1):
    # Walk through the alphabet cyclically, one character per step.
    ix = i % VOCAB_N
    # character
    thist = tensor_text[ix]
    # next character; after the last one, wrap around to the first
    inx = (ix + 1) % VOCAB_N
    # predicted value, from a (batch=1, seq=1) index tensor
    pn = m(LongTensor([[thist]]))
    # true value, expanded to the prediction's shape so that the loss does not
    # rely on implicit broadcasting (which makes MSELoss emit a warning)
    nextt = tensor_textn[inx].expand_as(pn)
    # learn: MSELoss takes (input, target) in that order
    loss = loss_f(pn, nextt)
    opt.zero_grad()
    loss.backward()
    opt.step()

    # report the average loss since the previous report, then start over
    losses.append(loss.item())
    if i % LOG_EVERY == 0:
        eva_loss = np.average(losses)
        print('epoch {}, loss:{}'.format(i, eva_loss))
        losses = []
    


epoch 0, loss:1.0
epoch 50, loss:220.72453033447266
epoch 100, loss:194.67828525543212
epoch 150, loss:192.6551647886634
epoch 200, loss:184.24432769969107
epoch 250, loss:170.74763990892097
epoch 300, loss:150.59333532601596
epoch 350, loss:128.01101777556556
epoch 400, loss:108.24093252718448
epoch 450, loss:88.46506323248148
epoch 500, loss:70.82479397444054
epoch 550, loss:56.856736540347335
epoch 600, loss:46.03324370814487
epoch 650, loss:37.86420858413796
epoch 700, loss:31.479401640412753
epoch 750, loss:22.025321021117747
epoch 800, loss:20.78806176149752
epoch 850, loss:18.947297332054003
epoch 900, loss:17.487916020080448
epoch 950, loss:15.285903479396366
epoch 1000, loss:13.323001085188007
epoch 1050, loss:11.446214171423271
epoch 1100, loss:9.842094578262477
epoch 1150, loss:8.670699112225556
epoch 1200, loss:7.580323329973035
epoch 1250, loss:6.64620992921009
epoch 1300, loss:5.876570167769805
epoch 1350, loss:5.219863151721402
epoch 1400, loss:3.323397108706704
epoch 14

In [11]:
# try it: start from 'A' and keep feeding the predicted letter back in.
m.eval()
N_PREDICTIONS = 51
chars = ['A']
# Inference only, so autograd bookkeeping is switched off.
with torch.no_grad():
    for i in range(N_PREDICTIONS):
        char_idx = c2i[chars[-1]]
        pn = m(LongTensor([[char_idx]]))
        # The model outputs the next index as a float; round it to an integer.
        ti = round(pn.item())
        # A prediction past the end of the alphabet falls back to its first letter.
        nt = text[ti] if ti < VOCAB_N else text[0]
        chars.append(nt)

print(''.join(chars))

ABCDEFGHIJKLMNOPQRSTUVWXYZABCDEFGHIJKLMNOPQRSTUVWXYZ
