In [1]:
import json
from tqdm import tqdm

In [2]:
with open('w2id+.json', 'r') as f:
    w2id = json.load(f)
with open('id2w+.json', 'r') as f:
    id2w = json.load(f)

In [3]:
data_list = []
with open('data_splited+.jl', 'r') as f:
    for l in f:
        data_list.append(json.loads(l))

In [4]:
import collections
c = collections.Counter()
for d in data_list:
    c[len(d[0])] += 1
c

Counter({7: 605095, 5: 659730, 9: 922, 8: 633, 6: 9593})

In [5]:
dlx = [[] for _ in range(5)]
for d in data_list:
    dlx[len(d[0]) - 5].append(d)

In [6]:
import torch
class MyDataSet(torch.utils.data.Dataset):
    def __init__(self, examples):
        self.examples = examples
    def __len__(self):
        return len(self.examples)
    def __getitem__(self, index):
        example = self.examples[index]
        s1 = example[0]
        s2 = example[1]
        return s1, s2, index

In [7]:
def str2id(s):
    ids = []
    for ch in s:
        if ch in w2id:
            ids.append(w2id[ch])
        else:
            ids.append(0)
    return ids

def the_collate_fn(batch):
    s1x = []
    s2x = []
    for b in batch:
        s1 = str2id(b[0])
        s2 = str2id(b[1])
        s1x.append(s1)
        s2x.append(s2)
    indexs = [b[2] for b in batch]
    s1 = torch.LongTensor(s1x)
    s2 = torch.LongTensor(s2x)
    return s1, s2, indexs

In [8]:
batch_size = 256
data_workers = 4

dldx = []
for d in dlx:
    ds = MyDataSet(d)
    dld = torch.utils.data.DataLoader(
        ds,
        batch_size=batch_size,
        shuffle = True,
        num_workers=data_workers,
        collate_fn=the_collate_fn,
    )
    dldx.append(dld)

In [9]:
len(dldx)

5

In [10]:
import torch.nn as nn
import torch.nn.functional as F
class LSTMModel(nn.Module):
    def __init__(self, device, word_size, embedding_dim=256, hidden_dim=256):
        super(LSTMModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.device = device
        self.embedding = nn.Embedding(word_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=3, bidirectional=True, batch_first=True)
        self.out = nn.Linear(hidden_dim*2, word_size)
        
    def forward(self, s1, s2=None):
        batch_size, length = s1.shape[:2]
        b = self.embedding(s1)
        l = self.lstm(b)[0]
        r = self.out(l)
        r = F.log_softmax(r, dim=1)
        if s2 is not None:
            criterion = nn.NLLLoss()
            loss = criterion(r.view(batch_size*length, -1), s2.view(batch_size*length))
            return loss
        return r

In [11]:
device = torch.device('cuda')
model = LSTMModel(device, len(w2id))
model.to(device)

LSTMModel(
  (embedding): Embedding(7042, 256)
  (lstm): LSTM(256, 256, num_layers=3, batch_first=True, bidirectional=True)
  (out): Linear(in_features=512, out_features=7042, bias=True)
)

In [12]:
import torch.optim as optim
learning_rate = 0.05
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
for e in range(10):
    print(e)
    loss_sum = 0
    c = 0
    xxx = [x.__iter__() for x in dldx]
    j = 0
    for i in tqdm(range((len(data_list)//batch_size) + 5)):
        if len(xxx) == 0:
            print('Done')
            break
        j = j % len(xxx)
        try:
            batch = xxx[j].__next__()
        except StopIteration:
            xxx.pop(j)
            continue
        j += 1
        s1, s2, index = batch
        s1 = s1.to(device)
        s2 = s2.to(device)
        loss = model(s1, s2)
        loss_sum += loss.item()
        c += 1
        loss.backward()
        optimizer.step()
    print(loss_sum / c)

0


100%|██████████| 4989/4989 [02:23<00:00, 34.69it/s]

1.4867789164708634
1



100%|██████████| 4989/4989 [02:24<00:00, 34.49it/s]

1.5007754878749102
2



100%|██████████| 4989/4989 [02:25<00:00, 34.36it/s]

1.6694258602117464
3



100%|██████████| 4989/4989 [02:25<00:00, 34.32it/s]

1.713665906644035
4



100%|██████████| 4989/4989 [02:25<00:00, 34.23it/s]

1.7680301611019353
5



100%|██████████| 4989/4989 [02:25<00:00, 34.27it/s]

1.8241329058003402
6



100%|██████████| 4989/4989 [02:25<00:00, 34.30it/s]

1.8315474072813152
7



100%|██████████| 4989/4989 [02:26<00:00, 34.16it/s]

1.8654739562582707
8



100%|██████████| 4989/4989 [02:25<00:00, 34.36it/s]

1.901343455764212
9



100%|██████████| 4989/4989 [02:25<00:00, 34.23it/s]

1.9380225642872908





In [13]:
def t2s(t):
    l = t.cpu().tolist()
    print(l)
    r = [id2w[x] for x in l[0]]
    return ''.join(r)

def get_next(s):
    ids = torch.LongTensor(str2id(s))
    print(ids)
    ids = ids.unsqueeze(0).to(device)
    r = model(ids)
    r = r.argmax(dim=2)
    return t2s(r)
    
        

In [14]:
get_next('我从高山走来')

tensor([ 92, 551, 257,   6,  60,  65])
[[6706, 2319, 1777, 1291, 6653, 3326]]


'脗宙羸互竛湄'

In [97]:
get_next('好好学习')

tensor([ 962,  962, 1432, 1983])
[[6706, 3502, 742, 6135]]


'脗鹧而暹'

In [98]:
get_next('白日依山尽')

tensor([   1,  106, 1510,    6,   84])
[[6706, 854, 463, 742, 6135]]


'脗遣兮而暹'