# Project - Text Generation by LSTM

In [20]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

## Load data and Embedding

In [51]:
# First, build the word dictionary
file ='./data/ptb.train.txt'

words = []
i = 0
with open(file, encoding ='UTF8') as f:
    for line in f:
        line = "BOS" + line + "EOS"
        for word in line.split():
            words.append(word)

words_unique = list(set(words))

In [62]:
sens = []

with open(file, encoding ='UTF8') as f:
    for line in f:
        line = "BOS" + line + "EOS"
        sen = []
        for word in line.split():
            index = words_unique.index(word)
            sen.append(index)
            
        sens.append(sen)

print(sens[:5])

[[7742, 849, 8715, 5052, 1749, 8421, 5972, 3319, 840, 2724, 4369, 3258, 2331, 8837, 3352, 1928, 3813, 377, 612, 756, 4055, 4756, 8902, 7227, 8366, 5893], [7742, 8371, 5332, 560, 7114, 2377, 3492, 7349, 5644, 4247, 3123, 1052, 2726, 9777, 3001, 560, 5893], [7742, 6627, 5332, 4733, 5324, 4682, 5332, 8780, 5644, 8285, 4502, 4367, 5893], [7742, 9784, 5332, 560, 7114, 2377, 4774, 7529, 5324, 4682, 6120, 6960, 7520, 7518, 3212, 4324, 1052, 2726, 9777, 4682, 7078, 1825, 3861, 397, 5893], [7742, 1052, 5088, 4682, 5844, 5884, 2254, 8042, 7497, 6594, 5544, 4505, 6849, 4133, 1052, 9164, 5467, 4682, 9387, 3922, 8275, 1052, 4367, 4682, 3685, 1657, 8042, 1941, 273, 703, 560, 7114, 7730, 6009, 5445, 5893]]


In [68]:
n = len(words_unique)

embedding = nn.Embedding(n, 10)
embedding_list = []
# Instead of embedding by batch, we do embedding one by one 
for sen in sens:
    sen_tensor = torch.LongTensor(sen)
    sen_embedding = embedding(sen_tensor)
    embedding_list.append(sen_embedding)
    
print(embedding_list[:5])

[tensor([[-0.6558,  1.6667, -0.8407, -0.8869,  1.4735, -0.8498, -1.6394, -1.1189,
          0.1925,  0.0803],
        [-1.5767,  1.6287, -1.5840,  0.3921, -0.6270, -0.2382,  1.0209, -0.3900,
          1.3832, -0.9788],
        [ 0.4930,  0.7241, -1.3401, -0.3061, -0.2785,  0.9104,  0.3124, -0.0653,
         -0.9939,  0.4393],
        [-1.1195,  0.9449, -2.0204,  1.9142,  0.2467,  0.7554, -1.0662,  0.7212,
          1.1440,  1.1643],
        [-1.3421, -0.6325, -1.0215,  0.2067, -1.0000, -0.4908, -0.2701,  0.2356,
         -0.3251,  0.4515],
        [-0.8523, -0.4859, -1.2666,  0.4692, -1.9282, -0.9697, -0.6895, -1.1087,
         -1.6011,  0.6621],
        [ 0.2322, -1.2437, -0.0425, -0.9900, -0.2425,  0.7900,  0.8657,  1.2571,
         -0.3929, -0.5492],
        [ 0.7899, -1.1164, -0.5443, -1.0058,  0.3683,  1.1855, -0.4950, -0.6990,
         -1.0344, -0.5472],
        [ 0.7071,  1.5260, -0.0638, -2.1414,  0.6345,  0.7466, -1.7030,  1.3502,
          0.2942,  1.4578],
        [ 0.8016, 