In [1]:
import torch
from collections import Counter
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np

In [2]:
# Training hyperparameters shared by the cells below.
BATCH_SIZE = 3000  # passed to the DataLoader as batch_size in train()
MAX_EPOCHS = 2000  # number of passes over the dataset performed by train()
SEQUENCE_LENGTH = 5  # words per input/target window produced by Dataset

In [3]:
# Training corpus: the only text the Dataset below is built from.
text = """
Nepal officially the Federal Democratic Republic of Nepal, is a landlocked country in South Asia. It is mainly situated in the Himalayas, but also includes parts of the Indo-Gangetic Plain. It borders the Tibet Autonomous Region of China to the north, and India to the south, east, and west, while it is narrowly separated from Bangladesh by the Siliguri Corridor, and from Bhutan by the Indian state of Sikkim. Nepal has a diverse geography, including fertile plains, subalpine forested hills, and eight of the world's ten tallest mountains, including Mount Everest, the highest point on Earth. Kathmandu is the nation's capital and the largest city. Nepal is a multi-ethnic, multi-lingual, multi-religious and multi-cultural state, with Nepali as the official language.

The name "Nepal" is first recorded in texts from the Vedic period of the Indian subcontinent, the era in ancient Nepal when Hinduism was founded, the predominant religion of the country. In the middle of the first millennium BC, Gautama Buddha, the founder of Buddhism, was born in Lumbini in southern Nepal. Parts of northern Nepal were intertwined with the culture of Tibet. The centrally located Kathmandu Valley is intertwined with the culture of Indo-Aryans, and was the seat of the prosperous Newar confederacy known as Nepal Mandala. The Himalayan branch of the ancient Silk Road was dominated by the valley's traders. The cosmopolitan region developed distinct traditional art and architecture. By the 18th century, the Gorkha Kingdom achieved the unification of Nepal. The Shah dynasty established the Kingdom of Nepal and later formed an alliance with the British Empire, under its Rana dynasty of premiers. The country was never colonised but served as a buffer state between Imperial China and British India. Parliamentary democracy was introduced in 1951 but was twice suspended by Nepalese monarchs, in 1960 and 2005. The Nepalese Civil War in the 1990s and early 2000s resulted in the establishment of a secular republic in 2008, ending the world's last Hindu monarchy.

The Constitution of Nepal, adopted in 2015, affirms the country as a secular federal parliamentary republic divided into seven provinces. Nepal was admitted to the United Nations in 1955, and friendship treaties were signed with India in 1950 and China in 1960. Nepal hosts the permanent secretariat of the South Asian Association for Regional Cooperation (SAARC), of which it is a founding member. Nepal is also a member of the Non-Aligned Movement and the Bay of Bengal Initiative. The Nepalese Armed Forces are the fifth-largest in South Asia; and are notable for their Gurkha history, particularly during the world wars, and has been a significant contributor to United Nations peacekeeping operations.
"""

In [4]:
class Dataset(torch.utils.data.Dataset):
  """Sliding-window next-word dataset built from a single text string.

  The text is lower-cased and split on whitespace. Every distinct token gets
  an integer id, most frequent token first (ties keep first-occurrence order).
  Item ``i`` is a pair of id tensors: the window ``words[i : i + n]`` and the
  same window shifted one token to the right, where ``n`` is the sequence
  length.

  Args:
    data: Raw training text.
    sequence_length: Tokens per window. When omitted, the notebook-level
      ``SEQUENCE_LENGTH`` constant is used.

  Raises:
    ValueError: If ``sequence_length`` is smaller than 1.
  """

  def __init__(self, data, sequence_length=None):
    if sequence_length is None:
      sequence_length = SEQUENCE_LENGTH
    if sequence_length < 1:
      raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")

    self.sequence_length = sequence_length
    self.data = data

    self.words = self.load_words()
    self.uniq_words = self.get_uniq_words()

    self.index_to_word = dict(enumerate(self.uniq_words))
    self.word_to_index = {word: index for index, word in self.index_to_word.items()}

    self.words_indexes = [self.word_to_index[w] for w in self.words]

  def load_words(self):
    """Return the lower-cased text split on runs of whitespace."""
    return self.data.lower().split()

  def get_uniq_words(self):
    """Return the distinct tokens ordered by descending frequency."""
    word_counts = Counter(self.words)
    # sorted() is stable, so equally frequent tokens keep first-seen order.
    return sorted(word_counts, key=word_counts.get, reverse=True)

  def __len__(self):
    # Each item needs sequence_length inputs plus one extra target token.
    # Clamp at zero: a negative value makes len() raise ValueError.
    return max(0, len(self.words_indexes) - self.sequence_length)

  def __getitem__(self, index):
    """Return ``(inputs, targets)`` id tensors for the window at ``index``.

    Raises:
      IndexError: If ``index`` is outside ``range(-len(self), len(self))``.
        Without this, plain iteration over the dataset would never stop and
        out-of-range indices would silently yield truncated windows.
    """
    size = len(self)
    if index < 0:
      index += size
    if not 0 <= index < size:
      raise IndexError("Dataset index out of range")

    stop = index + self.sequence_length
    return (
        torch.tensor(self.words_indexes[index:stop]),
        torch.tensor(self.words_indexes[index + 1 : stop + 1])
    )

In [5]:
class Model(nn.Module):
  """Embedding -> stacked LSTM -> linear layer producing next-word logits.

  The LSTM is built without ``batch_first``, so per the torch.nn.LSTM
  documentation it treats dimension 0 of its input as time and dimension 1 as
  the batch of independent lanes. A ``(d0, d1)`` tensor of word ids therefore
  needs a state of shape ``(num_layers, d1, lstm_size)``.

  Args:
    dataset: Object exposing ``uniq_words``; its length sets the vocabulary
      size.
    lstm_size: Hidden size of every LSTM layer.
    embedding_dim: Size of the word embeddings fed to the LSTM.
    num_layers: Number of stacked LSTM layers.
    dropout: Dropout applied between LSTM layers.
  """

  def __init__(self, dataset, lstm_size=128, embedding_dim=128, num_layers=3,
               dropout=0.2):
    super().__init__()
    self.lstm_size = lstm_size
    self.embedding_dim = embedding_dim
    self.num_layers = num_layers

    n_vocab = len(dataset.uniq_words)
    self.embedding = nn.Embedding(
        num_embeddings=n_vocab,
        embedding_dim=self.embedding_dim
    )

    self.lstm = nn.LSTM(
        # The LSTM consumes embeddings, so its input width is embedding_dim
        # (it used to be lstm_size, which only worked while both were equal).
        input_size=self.embedding_dim,
        hidden_size=self.lstm_size,
        num_layers=self.num_layers,
        # Inter-layer dropout is meaningless (and warns) with a single layer.
        dropout=dropout if self.num_layers > 1 else 0.0
    )

    self.fc = nn.Linear(self.lstm_size, n_vocab)

  def forward(self, x, prev_state=None):
    """Return ``(logits, state)`` for a tensor of word ids.

    Args:
      x: Integer tensor of word ids, shape ``(d0, d1)``.
      prev_state: ``(h, c)`` tuple as returned by ``init_state(d1)`` or by a
        previous call; ``None`` starts from an all-zero state.

    Returns:
      ``logits`` of shape ``(d0, d1, n_vocab)`` and the updated ``(h, c)``.
    """
    embed = self.embedding(x)
    output, state = self.lstm(embed, prev_state)
    logits = self.fc(output)

    return logits, state

  def init_state(self, sequence_length):
    """Return zeroed ``(h, c)`` tensors for an input whose dim 1 is ``sequence_length``.

    The tensors are allocated on the model's own device and dtype, so callers
    that still call ``.to(device)`` on them perform a no-op.
    """
    ref = self.fc.weight
    shape = (self.num_layers, sequence_length, self.lstm_size)
    return (torch.zeros(shape, device=ref.device, dtype=ref.dtype),
            torch.zeros(shape, device=ref.device, dtype=ref.dtype))

In [6]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [7]:
def train(dataset, model, max_epochs=None, batch_size=None, lr=0.001,
          device=None, log_every=1):
  """Train ``model`` on ``dataset`` with Adam and cross-entropy loss.

  The recurrent state is reset at the start of every epoch and carried
  (detached) from one batch to the next within an epoch.

  Args:
    dataset: Map-style dataset yielding ``(inputs, targets)`` id tensors.
    model: Module exposing ``init_state(n)`` and ``forward(x, state)``; it is
      expected to already live on ``device``.
    max_epochs: Number of epochs; defaults to ``MAX_EPOCHS``.
    batch_size: DataLoader batch size; defaults to ``BATCH_SIZE``.
    lr: Adam learning rate.
    device: Device for the batches and state; defaults to ``DEVICE``.
    log_every: Print the loss of every batch in each ``log_every``-th epoch;
      ``0`` or ``None`` disables logging.

  Returns:
    None. The model is updated in place.
  """
  max_epochs = MAX_EPOCHS if max_epochs is None else max_epochs
  batch_size = BATCH_SIZE if batch_size is None else batch_size
  device = DEVICE if device is None else device

  model.train()

  dataloader = DataLoader(dataset, batch_size=batch_size)

  criterion = nn.CrossEntropyLoss()
  optimizer = optim.Adam(model.parameters(), lr=lr)

  for epoch in range(max_epochs):
    # Created lazily so its width follows the actual batch instead of a global.
    state = None

    for batch, (x, y) in enumerate(dataloader):
      x = x.to(device)
      y = y.to(device)

      if state is None:
        state = tuple(s.to(device) for s in model.init_state(x.size(1)))

      optimizer.zero_grad()

      y_pred, state = model(x, state)
      # CrossEntropyLoss wants the class dimension second: (d0, vocab, d1).
      loss = criterion(y_pred.transpose(1, 2), y)

      loss.backward()
      optimizer.step()

      # Keep the state values but cut the graph so later batches do not
      # backpropagate into earlier ones.
      state = tuple(s.detach() for s in state)

      if log_every and epoch % log_every == 0:
        print(f"EPOCH : {epoch}, BATCH : {batch}, Loss : {loss.item()}")

# Build the vocabulary and windows from the corpus, place the model on DEVICE,
# then run the full training loop (prints one loss line per batch).
dataset = Dataset(text)
model = Model(dataset).to(DEVICE)

train(dataset, model)

EPOCH : 0, 'BATCH : 0, Loss : 5.501698970794678
EPOCH : 1, 'BATCH : 0, Loss : 5.491706371307373
EPOCH : 2, 'BATCH : 0, Loss : 5.4815778732299805
EPOCH : 3, 'BATCH : 0, Loss : 5.470121383666992
EPOCH : 4, 'BATCH : 0, Loss : 5.456357955932617
EPOCH : 5, 'BATCH : 0, Loss : 5.438433647155762
EPOCH : 6, 'BATCH : 0, Loss : 5.412837982177734
EPOCH : 7, 'BATCH : 0, Loss : 5.373276233673096
EPOCH : 8, 'BATCH : 0, Loss : 5.310094356536865
EPOCH : 9, 'BATCH : 0, Loss : 5.213853359222412
EPOCH : 10, 'BATCH : 0, Loss : 5.09924840927124
EPOCH : 11, 'BATCH : 0, Loss : 5.005402565002441
EPOCH : 12, 'BATCH : 0, Loss : 4.945343017578125
EPOCH : 13, 'BATCH : 0, Loss : 4.906871795654297
EPOCH : 14, 'BATCH : 0, Loss : 4.8858466148376465
EPOCH : 15, 'BATCH : 0, Loss : 4.876251697540283
EPOCH : 16, 'BATCH : 0, Loss : 4.871213912963867
EPOCH : 17, 'BATCH : 0, Loss : 4.86708927154541
EPOCH : 18, 'BATCH : 0, Loss : 4.8601531982421875
EPOCH : 19, 'BATCH : 0, Loss : 4.851619720458984
EPOCH : 20, 'BATCH : 0, Loss 

In [9]:
def predict(dataset, model, text, next_words=100, device=None):
  """Sample ``next_words`` words that continue the prompt ``text``.

  The prompt is lower-cased and split on whitespace, matching how the
  training text is tokenised. The model's LSTM is time-major (dimension 0 is
  time), so the prompt is fed as a column of shape ``(len(prompt), 1)``: every
  prompt word updates the same recurrent state. Afterwards each sampled word
  is fed back as a ``(1, 1)`` tensor. For a one-word prompt this matches the
  previous behaviour exactly.

  Args:
    dataset: Object exposing ``word_to_index`` and ``index_to_word``.
    model: Module exposing ``init_state(n)`` and ``forward(x, state)``.
    text: Prompt; every word must be in the training vocabulary.
    next_words: Number of words to generate.
    device: Device to run on; defaults to ``DEVICE``.

  Returns:
    The lower-cased prompt followed by the generated words, space-separated.

  Raises:
    ValueError: If ``text`` contains no words.
    KeyError: If a prompt word is not in ``dataset.word_to_index``.
  """
  device = DEVICE if device is None else device

  # split() without arguments drops empty tokens from repeated whitespace.
  words = text.lower().split()
  if not words:
    raise ValueError("text must contain at least one word")

  unknown = [w for w in words if w not in dataset.word_to_index]
  if unknown:
    raise KeyError(f"words not in the training vocabulary: {unknown}")

  model.eval()
  model.to(device)

  # One lane, so the state's middle dimension is 1.
  state = tuple(s.to(device) for s in model.init_state(1))
  # Column vector: one time step per prompt word.
  x = torch.tensor([[dataset.word_to_index[w]] for w in words]).to(device)

  with torch.no_grad():  # inference only, no autograd graph needed
    for _ in range(next_words):
      y_pred, state = model(x, state)
      last_word_logits = y_pred[-1][0]  # last time step of the single lane
      p = torch.nn.functional.softmax(last_word_logits, dim=0).cpu().numpy()
      word_index = int(np.random.choice(len(last_word_logits), p=p))
      words.append(dataset.index_to_word[word_index])
      # Only the new word is fed next; the state already holds the history.
      x = torch.tensor([[word_index]]).to(device)

  return " ".join(words)

# Generate 100 words continuing the prompt "Nepal". Words are drawn with
# np.random.choice and no seed is set in the visible cells, so the output can
# differ between runs.
predict(dataset, model, "Nepal", 100)

"nepal officially the federal democratic republic of nepal, is a landlocked country in south asia. it is mainly situated in the himalayas, but also includes parts of the indo-gangetic plain. it borders the tibet autonomous region of china to the north, and india to the south, east, and west, while it is narrowly separated from bangladesh by the siliguri corridor, and from bhutan by the indian state of sikkim. nepal has a diverse geography, including fertile plains, subalpine forested hills, and eight of the world's ten tallest mountains, including mount everest, the highest point on earth. kathmandu is the nation's capital"