In [184]:
import os
import torchtext
import torch.nn as nn
import torch.utils.data as tud
import torch.nn.functional as F
from collections import Counter

In [208]:
#prepare data
BATCH_SIZE = 32
sequence_len = 50
vocab_size = 30000
train_file, dev_file, test_file = [os.path.join('./data', file) \
                                   for file in ['text8.train', 'text8.dev', 'text8.test']]
train_raw = open(train_file).readlines()[0]
dev_raw = open(dev_file).readlines()[0]
test_raw = open(test_file).readlines()[0]
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def tokenize(text):
    return text.split(' ')

vocab = Counter(train_raw.split(' ')).most_common(vocab_size - 1)
idx_to_word = [item[0] for item in vocab]
idx_to_word.append('UNK')
word_to_idx = {word: i for i, word in enumerate(idx_to_word)}

class LanguageDataset(tud.Dataset):
    def __init__(self, text, sequence_len, idx_to_word, word_to_idx, vocab_size, device):
        super(LanguageDataset, self).__init__()
        self.device = device
        self.vocab_size = vocab_size
        self.idx_to_word = idx_to_word
        self.word_to_idx = word_to_idx
        self.word_encode = [self.word_to_idx.get(word, self.vocab_size - 1) for word in text]
        self.word_encode = torch.LongTensor(self.word_encode).to(device)
        self.sequence_len = sequence_len
    
    def __len__(self):
        return len(self.word_encode) - self.sequence_len
    
    def __getitem__(self, idx):
        x = self.word_encode[idx: min(idx + self.sequence_len, len(self.word_encode) - 1)]
        y = self.word_encode[idx + 1: min(idx + self.sequence_len + 1, len(self.word_encode))]
        return x, y

In [209]:
#data loader
train_data = LanguageDataset(train_raw, sequence_len, idx_to_word, word_to_idx, vocab_size, device)
dev_data = LanguageDataset(dev_raw, sequence_len, idx_to_word, word_to_idx, vocab_size, device)
test_data = LanguageDataset(test_raw, sequence_len, idx_to_word, word_to_idx, vocab_size, device)
train_iter = tud.DataLoader(train_data, batch_size = BATCH_SIZE, shuffle = True)
dev_iter = tud.DataLoader(dev_data, batch_size = BATCH_SIZE, shuffle = True)
test_iter = tud.DataLoader(test_data, batch_size = BATCH_SIZE, shuffle = True)

In [211]:
for i, (x, y) in enumerate(train_iter):
    print(x.shape)
    print(y.shape)
    print(' '.join([idx_to_word[idx] for idx in x[0]]))
    print(' '.join([idx_to_word[idx] for idx in y[0]]))
    break

torch.Size([32, 50])
torch.Size([32, 50])
a n d UNK p i c n i c s UNK a c r o s s UNK a r g e n t i n a UNK v e g e t a b l e s UNK a n d UNK s a l a d s
n d UNK p i c n i c s UNK a c r o s s UNK a r g e n t i n a UNK v e g e t a b l e s UNK a n d UNK s a l a d s UNK


In [241]:
#model
import torch
embed_size, hidden_size = 300, 1000
class LanguageModel(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size):
        super(LanguageModel, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        init_range = 0.5 / vocab_size
        self.embed.weight.data.uniform_(-init_range, init_range)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.hidden_size = hidden_size
    
    def forward(self, x, hidden):
        input_x = self.embed(x)     #batch_size * sequence_len * embedding_size
        output, hidden = self.lstm(input_x, hidden)  #output: batch_size * sequence_len * embedding_size
        output_vocab = self.linear(output)   #output_vocab:  batch_size * sequence_len * vocab_size
        return output_vocab, hidden
    
    def init_hidden(self, bsz, requires_grad=True):
        weight = next(self.parameters())
        return (weight.new_zeros((1, bsz, self.hidden_size), requires_grad=requires_grad),
                    weight.new_zeros((1, bsz, self.hidden_size), requires_grad=requires_grad))
        

In [242]:
#train
model = LanguageModel(vocab_size, embed_size, hidden_size).to(device)
learning_rate = 4e-4
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.5)
requires_grad = False
GRAD_CLIP = 1.0

def repackage_hidden(h):
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)
    
def evaluate(model, data_iter):
    model.eval()
    loss_all = 0.
    count = 0.
    with torch.no_grad():
        hidden = model.init_hidden(BATCH_SIZE, requires_grad=False)
        print(len(data_iter))
        for i, (x, y) in enumerate(data_iter):
            hidden = repackage_hidden(hidden)
            output, hidden = model(x, hidden)
            loss = loss_fn(output.view(-1, vocab_size), y.view(-1))
            loss_all += loss * x.shape[0]
            count += x.shape[0]
    mode.train()
    return loss_all / count
    
dev_loss_list = []
model_path = './best_mode.pth'
for epoch in range(2):
    model.train()
    hidden = model.init_hidden(BATCH_SIZE)
    for i, (x, y) in enumerate(train_iter):
        hidden = repackage_hidden(hidden)
        output, hidden = model(x, hidden)
        loss = loss_fn(output.view(-1, vocab_size), y.view(-1))
        loss.requires_grad_(True)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()
        if i % 100 == 0:
            print("Epoch: {}, iter: {}, train loss: {}".format(epoch, i, loss))
            
#         if i % 1000 == 0:
#             dev_loss = evaluate(model, dev_iter)
#             print("Epoch: {}, iter: {}, dev loss: {}".format(epoch, i, dev_loss))
#             if len(dev_loss_list) == 0 or dev_loss < min(dev_loss_list):
#                 torch.save(model.state_dict(), model_path)
#             else:
#                 scheduler.step()
#             dev_loss_list.append(dev_loss)

Epoch: 0, iter: 0, train loss: 10.30607795715332
Epoch: 0, iter: 100, train loss: 10.30573844909668
Epoch: 0, iter: 200, train loss: 10.304821968078613
Epoch: 0, iter: 300, train loss: 10.305386543273926
Epoch: 0, iter: 400, train loss: 10.305276870727539
Epoch: 0, iter: 500, train loss: 10.305619239807129
Epoch: 0, iter: 1000, train loss: 10.305541038513184
Epoch: 0, iter: 1100, train loss: 10.304949760437012
Epoch: 0, iter: 1200, train loss: 10.305624008178711
Epoch: 0, iter: 1300, train loss: 10.30508041381836
Epoch: 0, iter: 1400, train loss: 10.304722785949707
Epoch: 0, iter: 1500, train loss: 10.305615425109863
Epoch: 0, iter: 1600, train loss: 10.305436134338379
Epoch: 0, iter: 1700, train loss: 10.305097579956055
Epoch: 0, iter: 1800, train loss: 10.305451393127441
Epoch: 0, iter: 1900, train loss: 10.304823875427246
Epoch: 0, iter: 2000, train loss: 10.305272102355957
Epoch: 0, iter: 2100, train loss: 10.305196762084961
Epoch: 0, iter: 2200, train loss: 10.305206298828125
Epoc

Epoch: 0, iter: 15900, train loss: 10.305420875549316
Epoch: 0, iter: 16000, train loss: 10.304740905761719
Epoch: 0, iter: 16100, train loss: 10.304905891418457
Epoch: 0, iter: 16200, train loss: 10.305919647216797
Epoch: 0, iter: 16300, train loss: 10.305276870727539
Epoch: 0, iter: 16400, train loss: 10.305904388427734
Epoch: 0, iter: 16500, train loss: 10.305615425109863
Epoch: 0, iter: 16600, train loss: 10.304744720458984
Epoch: 0, iter: 16700, train loss: 10.304800987243652
Epoch: 0, iter: 16800, train loss: 10.305512428283691
Epoch: 0, iter: 16900, train loss: 10.304861068725586
Epoch: 0, iter: 17000, train loss: 10.304634094238281
Epoch: 0, iter: 17100, train loss: 10.305066108703613
Epoch: 0, iter: 17200, train loss: 10.30462646484375
Epoch: 0, iter: 17300, train loss: 10.306146621704102
Epoch: 0, iter: 17400, train loss: 10.305891990661621
Epoch: 0, iter: 17500, train loss: 10.305103302001953
Epoch: 0, iter: 17600, train loss: 10.305541038513184
Epoch: 0, iter: 17700, train 

Epoch: 0, iter: 31200, train loss: 10.305219650268555
Epoch: 0, iter: 31300, train loss: 10.305341720581055
Epoch: 0, iter: 31400, train loss: 10.305096626281738
Epoch: 0, iter: 31500, train loss: 10.3058443069458
Epoch: 0, iter: 31600, train loss: 10.304805755615234
Epoch: 0, iter: 31700, train loss: 10.304418563842773
Epoch: 0, iter: 31800, train loss: 10.304895401000977
Epoch: 0, iter: 31900, train loss: 10.305886268615723
Epoch: 0, iter: 32000, train loss: 10.304708480834961
Epoch: 0, iter: 32100, train loss: 10.305288314819336
Epoch: 0, iter: 32200, train loss: 10.305402755737305
Epoch: 0, iter: 32300, train loss: 10.305286407470703
Epoch: 0, iter: 32400, train loss: 10.305922508239746
Epoch: 0, iter: 32500, train loss: 10.305919647216797
Epoch: 0, iter: 32600, train loss: 10.304279327392578
Epoch: 0, iter: 32700, train loss: 10.305517196655273
Epoch: 0, iter: 32800, train loss: 10.305620193481445
Epoch: 0, iter: 32900, train loss: 10.305193901062012
Epoch: 0, iter: 33000, train l

Epoch: 0, iter: 46500, train loss: 10.304723739624023
Epoch: 0, iter: 46600, train loss: 10.305729866027832
Epoch: 0, iter: 46700, train loss: 10.30596923828125
Epoch: 0, iter: 46800, train loss: 10.305035591125488
Epoch: 0, iter: 46900, train loss: 10.304269790649414
Epoch: 0, iter: 47000, train loss: 10.306767463684082
Epoch: 0, iter: 47100, train loss: 10.305059432983398
Epoch: 0, iter: 47200, train loss: 10.304559707641602
Epoch: 0, iter: 47300, train loss: 10.305115699768066
Epoch: 0, iter: 47400, train loss: 10.304896354675293
Epoch: 0, iter: 47500, train loss: 10.305288314819336
Epoch: 0, iter: 47600, train loss: 10.305190086364746
Epoch: 0, iter: 47700, train loss: 10.304950714111328
Epoch: 0, iter: 47800, train loss: 10.305194854736328
Epoch: 0, iter: 47900, train loss: 10.305739402770996
Epoch: 0, iter: 48000, train loss: 10.30505084991455
Epoch: 0, iter: 48100, train loss: 10.304848670959473
Epoch: 0, iter: 48200, train loss: 10.304980278015137
Epoch: 0, iter: 48300, train l

Epoch: 0, iter: 61800, train loss: 10.305171966552734
Epoch: 0, iter: 61900, train loss: 10.306279182434082
Epoch: 0, iter: 62000, train loss: 10.305185317993164
Epoch: 0, iter: 62100, train loss: 10.305039405822754
Epoch: 0, iter: 62200, train loss: 10.304617881774902
Epoch: 0, iter: 62300, train loss: 10.305520057678223
Epoch: 0, iter: 62400, train loss: 10.305477142333984
Epoch: 0, iter: 62500, train loss: 10.305675506591797
Epoch: 0, iter: 62600, train loss: 10.305644989013672
Epoch: 0, iter: 62700, train loss: 10.305680274963379
Epoch: 0, iter: 62800, train loss: 10.30467414855957
Epoch: 0, iter: 62900, train loss: 10.305951118469238
Epoch: 0, iter: 63000, train loss: 10.305655479431152
Epoch: 0, iter: 63100, train loss: 10.305852890014648
Epoch: 0, iter: 63200, train loss: 10.305624961853027
Epoch: 0, iter: 63300, train loss: 10.305525779724121
Epoch: 0, iter: 63400, train loss: 10.304691314697266
Epoch: 0, iter: 63500, train loss: 10.304559707641602
Epoch: 0, iter: 63600, train 

Epoch: 0, iter: 77100, train loss: 10.305015563964844
Epoch: 0, iter: 77200, train loss: 10.305170059204102
Epoch: 0, iter: 77300, train loss: 10.304656028747559
Epoch: 0, iter: 77400, train loss: 10.3049955368042
Epoch: 0, iter: 77500, train loss: 10.306097030639648
Epoch: 0, iter: 77600, train loss: 10.305622100830078
Epoch: 0, iter: 77700, train loss: 10.305209159851074
Epoch: 0, iter: 77800, train loss: 10.304821968078613
Epoch: 0, iter: 77900, train loss: 10.305119514465332
Epoch: 0, iter: 78000, train loss: 10.304692268371582
Epoch: 0, iter: 78100, train loss: 10.306661605834961
Epoch: 0, iter: 78200, train loss: 10.305310249328613
Epoch: 0, iter: 78300, train loss: 10.304648399353027
Epoch: 0, iter: 78400, train loss: 10.305821418762207
Epoch: 0, iter: 78500, train loss: 10.305100440979004
Epoch: 0, iter: 78600, train loss: 10.30553913116455
Epoch: 0, iter: 78700, train loss: 10.305062294006348
Epoch: 0, iter: 78800, train loss: 10.304965019226074
Epoch: 0, iter: 78900, train lo

Epoch: 0, iter: 92400, train loss: 10.305139541625977
Epoch: 0, iter: 92500, train loss: 10.305498123168945
Epoch: 0, iter: 92600, train loss: 10.306621551513672
Epoch: 0, iter: 92700, train loss: 10.30533218383789
Epoch: 0, iter: 92800, train loss: 10.305970191955566
Epoch: 0, iter: 92900, train loss: 10.305500030517578
Epoch: 0, iter: 93000, train loss: 10.305643081665039
Epoch: 0, iter: 93100, train loss: 10.305586814880371
Epoch: 0, iter: 93200, train loss: 10.305874824523926
Epoch: 0, iter: 93300, train loss: 10.30545425415039
Epoch: 0, iter: 93400, train loss: 10.304940223693848
Epoch: 0, iter: 93500, train loss: 10.305245399475098
Epoch: 0, iter: 93600, train loss: 10.305953025817871
Epoch: 0, iter: 93700, train loss: 10.305839538574219
Epoch: 0, iter: 93800, train loss: 10.305904388427734
Epoch: 0, iter: 93900, train loss: 10.305938720703125
Epoch: 0, iter: 94000, train loss: 10.305733680725098
Epoch: 0, iter: 94100, train loss: 10.305489540100098
Epoch: 0, iter: 94200, train l

Epoch: 0, iter: 107500, train loss: 10.305224418640137
Epoch: 0, iter: 107600, train loss: 10.305394172668457
Epoch: 0, iter: 107700, train loss: 10.305127143859863
Epoch: 0, iter: 107800, train loss: 10.306052207946777
Epoch: 0, iter: 107900, train loss: 10.30586051940918
Epoch: 0, iter: 108000, train loss: 10.305109024047852
Epoch: 0, iter: 108100, train loss: 10.305293083190918
Epoch: 0, iter: 108200, train loss: 10.305773735046387
Epoch: 0, iter: 108300, train loss: 10.30512809753418
Epoch: 0, iter: 108400, train loss: 10.304883003234863
Epoch: 0, iter: 108500, train loss: 10.305641174316406
Epoch: 0, iter: 108600, train loss: 10.305787086486816
Epoch: 0, iter: 108700, train loss: 10.30542278289795
Epoch: 0, iter: 108800, train loss: 10.304143905639648
Epoch: 0, iter: 108900, train loss: 10.305015563964844
Epoch: 0, iter: 109000, train loss: 10.305386543273926
Epoch: 0, iter: 109100, train loss: 10.305811882019043
Epoch: 0, iter: 109200, train loss: 10.306024551391602
Epoch: 0, ite

Epoch: 0, iter: 122500, train loss: 10.304662704467773
Epoch: 0, iter: 122600, train loss: 10.305166244506836
Epoch: 0, iter: 122700, train loss: 10.304632186889648
Epoch: 0, iter: 122800, train loss: 10.305739402770996
Epoch: 0, iter: 122900, train loss: 10.305034637451172
Epoch: 0, iter: 123000, train loss: 10.305333137512207
Epoch: 0, iter: 123100, train loss: 10.305584907531738
Epoch: 0, iter: 123200, train loss: 10.305283546447754
Epoch: 0, iter: 123300, train loss: 10.30451774597168
Epoch: 0, iter: 123400, train loss: 10.306100845336914
Epoch: 0, iter: 123500, train loss: 10.304823875427246
Epoch: 0, iter: 123600, train loss: 10.305303573608398
Epoch: 0, iter: 123700, train loss: 10.305248260498047
Epoch: 0, iter: 123800, train loss: 10.30434799194336
Epoch: 0, iter: 123900, train loss: 10.305508613586426
Epoch: 0, iter: 124000, train loss: 10.305286407470703
Epoch: 0, iter: 124100, train loss: 10.304095268249512
Epoch: 0, iter: 124200, train loss: 10.304803848266602
Epoch: 0, it

KeyboardInterrupt: 

In [None]:
test_model = LanguageModel(vocab_size, embed_size, hidden_size)
test_model.load_state_dict(torch.load(model_path))
words_list = []
input_x = torch.randint(vocab_size, (1, 1), dtype=torch.long)
hidden = test_model.init_hidden(1)
for i in range(100):
    output, hidden = model(input_x, hidden)
    y = torch.argmax(output.view(-1))
    input_x.fill_(y)
    word = idx_to_word[y]
    words_list.append(word)
print(' '.join(word_list))

In [92]:
import torch
rnn = nn.LSTM(10, 20, 2) #embedding_size, hidden_size, num_layer
input = torch.randn(5, 3, 10)   #sequence_len, batch_size, embedding_size
h0 = torch.randn(2, 3, 20)   #num_layer, batch_size, hidden_size
c0 = torch.randn(2, 3, 20)   #num_layer, batch_size, hidden_size
output, (hn, cn) = rnn(input, (h0, c0))
#output: sequence_len, batch_size, embedding_size
weights = next(rnn.parameters())