<a href="https://colab.research.google.com/github/sourcecode369/transformers-tutorials/blob/master/lstm/Bi_LSTM_Predict_Next_Word_in_Long_Sentence.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
# Toy training corpus: five unpunctuated sentences. Every word, including the
# last one, is followed by exactly one space.
sentence = " ".join([
    "Artificial Intelligence involves using computers to do things that traditionally require human intelligence",
    "This means creating algorithms to classify analyze and draw predictions from data",
    "It also involves acting on data learning from new data and improving over time",
    "Just like a tiny human child growing up into a smarter human adult",
    "And like humans AI is not perfect",
]) + " "

In [None]:
# Sort the vocabulary so word ids are the same on every run: iterating a set of
# strings directly depends on Python's per-process hash randomisation.
word_dict = {w: i for i, w in enumerate(sorted(set(sentence.split())))}
# Invert the same mapping instead of enumerating a second set, so the two
# dictionaries can never disagree.
number_dict = {i: w for w, i in word_dict.items()}
n_class = len(word_dict)          # vocabulary size: one-hot width and number of output classes
n_hidden = 10                     # hidden units per LSTM direction
max_len = len(sentence.split())   # inputs are padded to the full sentence length

In [None]:
class BiLSTM(nn.Module):
    """Bidirectional LSTM that scores every vocabulary word as the next word.

    Both sizes default to the notebook-level ``n_class`` / ``n_hidden`` values,
    so ``BiLSTM()`` keeps working; pass them explicitly to reuse the model with
    a different vocabulary or width.

    Args:
        num_classes: one-hot input width and number of output classes.
        num_hidden: hidden units per LSTM direction.
    """

    def __init__(self, num_classes=None, num_hidden=None):
        super(BiLSTM, self).__init__()

        # Fall back to the notebook globals only when a size is not supplied.
        if num_classes is None:
            num_classes = n_class
        if num_hidden is None:
            num_hidden = n_hidden

        self.lstm = nn.LSTM(input_size=num_classes, hidden_size=num_hidden, bidirectional=True)
        # Forward and backward hidden states are concatenated, hence the factor 2.
        self.W = nn.Linear(num_hidden * 2, num_classes, bias=False)
        self.b = nn.Parameter(torch.ones([num_classes]))

    def forward(self, X):
        """Return next-word logits for a batch of one-hot encoded sequences.

        Args:
            X: float tensor of shape [batch_size, n_step, n_class].

        Returns:
            Tensor of shape [batch_size, n_class] holding unnormalised scores.
        """
        # nn.LSTM is sequence-first by default: [n_step, batch_size, n_class].
        seq_first = X.transpose(0, 1)

        # No explicit (h_0, c_0) is passed: nn.LSTM then starts from zero states
        # that match the input's device and dtype, so the model also works after
        # .to(device) / .double() instead of failing on CPU float32 zeros.
        outputs, _ = self.lstm(seq_first)
        last_step = outputs[-1]  # [batch_size, n_hidden * 2]
        return self.W(last_step) + self.b  # [batch_size, n_class]

In [None]:
def make_batch():
    """Build (prefix, next word) training pairs from the notebook-level ``sentence``.

    For every position ``i`` the input is the ids of the first ``i + 1`` words,
    right-padded with id 0 up to ``max_len`` and one-hot encoded; the target is
    the id of word ``i + 1``.

    Returns:
        input_batch: list of float arrays, each of shape [max_len, n_class].
        target_batch: list of integer word ids, one per input.
    """
    words = sentence.split()
    word_ids = [word_dict[w] for w in words]  # look every word up once, not once per prefix
    one_hot = np.eye(n_class)                 # loop-invariant lookup table, built once

    input_batch = []
    target_batch = []
    for i in range(len(words) - 1):
        prefix = word_ids[:i + 1]
        # NOTE(review): pad id 0 is also the id of a real word in word_dict, so
        # padding is indistinguishable from that word in the one-hot input.
        padded = prefix + [0] * (max_len - len(prefix))
        input_batch.append(one_hot[padded])   # list indexing returns a fresh array per sample
        target_batch.append(word_ids[i + 1])

    return input_batch, target_batch

In [None]:
input_batch, target_batch = make_batch()
# Stack the per-sample arrays into a single ndarray first: building a tensor
# directly from a list of ndarrays is very slow and raises a PyTorch UserWarning.
input_batch = torch.from_numpy(np.array(input_batch, dtype=np.float32))  # [n_samples, max_len, n_class]
target_batch = torch.LongTensor(target_batch)                            # [n_samples]

In [None]:
# Seed PyTorch's global RNG so the weights are initialised identically on every run.
torch.manual_seed(0)

model = BiLSTM()
criterion = nn.CrossEntropyLoss()  # expects raw logits and integer class targets
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Full-batch training: each epoch is one optimiser step over every prefix.
for epoch in range(5000):
    optimizer.zero_grad()
    outputs = model(input_batch)
    loss = criterion(outputs, target_batch)
    loss.backward()
    optimizer.step()

    # Every 500 epochs, report the loss that was computed before this update.
    if (epoch + 1) % 500 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '%.6f' % loss.item())

Epoch: 0500 cost = 3.096175
Epoch: 1000 cost = 2.469000
Epoch: 1500 cost = 2.287508
Epoch: 2000 cost = 1.783772
Epoch: 2500 cost = 1.940998
Epoch: 3000 cost = 1.487094
Epoch: 3500 cost = 1.375398
Epoch: 4000 cost = 1.300185
Epoch: 4500 cost = 1.247313
Epoch: 5000 cost = 1.199682


In [None]:
# Inference only: switch autograd off rather than building a graph and then
# stripping it with `.data`.
with torch.no_grad():
    predict = model(input_batch).max(1, keepdim=True)[1]  # [n_samples, 1] highest-scoring word ids
print(sentence)
# squeeze(1) removes only the keepdim axis, so a single-sample batch stays iterable.
print([number_dict[n.item()] for n in predict.squeeze(1)])

Artificial Intelligence involves using computers to do things that traditionally require human intelligence This means creating algorithms to classify analyze and draw predictions from data It also involves acting on data learning from new data and improving over time Just like a tiny human child growing up into a smarter human adult And like humans AI is not perfect 
['to', 'to', 'to', 'new', 'to', 'draw', 'that', 'that', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'data', 'that', 'from', 'from', 'data', 'from', 'also', 'involves', 'also', 'data', 'data', 'from', 'from', 'new', 'data', 'data', 'improving', 'over', 'time', 'time', 'like', 'a', 'tiny', 'human', 'child', 'growing', 'up', 'into', 'a', 'smarter', 'human', 'adult', 'And', 'like', 'humans', 'AI', 'is', 'not', 'perfect']
