In [43]:
import torch

import torch.nn as nn

import torch.nn.functional as F 

import numpy as np

1. Dataset Preparation

In [44]:
# Tiny corpus

corpus = [
    "deep learning is powerful",
    "learning is fun and powerful",
    "deep learning is fun"
]


In [45]:
# Tokenize

tokens = " ".join(corpus).split()

vocab = sorted(set(tokens))

word_to_idx = {word: i for i, word in enumerate(vocab)}

idx_to_word = {i: word for word, i in word_to_idx.items()}

In [46]:
# Convert to sequences

seq_len = 3

data = []

In [47]:
for sentence in corpus:

    words = sentence.split()

    for i in range(len(words) - seq_len):

        input_seq = words[i: i + seq_len]

        target_word = words[i+seq_len]

        input_idx = [word_to_idx[word] for word in input_seq]

        target_idx = word_to_idx[target_word]

        data.append((input_idx, target_idx))


input_seqs = [torch.tensor(x) for x , _ in data]

target_words = [torch.tensor(y) for _, y in data]

2. Define LSTM Model

In [48]:
class WordLSTM(nn.Module):

    def __init__(self, vocab_size,embedding_dim, hidden_dim):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)

        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):

        x = self.embedding(x) # [batch, seq, embed]

        output, hidden = self.lstm(x, hidden) # [batch, seq, hidden]

        output = self.fc(output[:,-1,:]) # take last time step

        return output, hidden

3. Training Loop

In [49]:
emebedding_dim = 10

hidden_dim = 32

vocab_size = len(vocab)

learning_rate = 0.01

num_epochs = 300

In [50]:
model = WordLSTM(vocab_size, emebedding_dim, hidden_dim)

loss_fn = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


for epoch in range(1, num_epochs + 1):

    total_loss = 0

    for input_tensor,target_tensor in zip(input_seqs,target_words):

        input_tensor = input_tensor.unsqueeze(0)

        target_tensor = target_tensor.unsqueeze(0)


        optimizer.zero_grad()

        output, _ = model(input_tensor)

        loss = loss_fn(output, target_tensor)

        loss.backward()

        optimizer.step()

        total_loss += loss.item()

        if epoch % 50 == 0:
                
                print(f"Epoch {epoch}/{num_epochs}, Loss: {total_loss:.4f}")

Epoch 50/300, Loss: 0.7087
Epoch 50/300, Loss: 0.7117
Epoch 50/300, Loss: 0.7176
Epoch 50/300, Loss: 1.4295
Epoch 100/300, Loss: 0.7064
Epoch 100/300, Loss: 0.7075
Epoch 100/300, Loss: 0.7098
Epoch 100/300, Loss: 1.4175
Epoch 150/300, Loss: 0.7055
Epoch 150/300, Loss: 0.7061
Epoch 150/300, Loss: 0.7074
Epoch 150/300, Loss: 1.4136
Epoch 200/300, Loss: 0.7049
Epoch 200/300, Loss: 0.7053
Epoch 200/300, Loss: 0.7062
Epoch 200/300, Loss: 1.4113
Epoch 250/300, Loss: 0.7043
Epoch 250/300, Loss: 0.7046
Epoch 250/300, Loss: 0.7052
Epoch 250/300, Loss: 1.4095
Epoch 300/300, Loss: 0.7033
Epoch 300/300, Loss: 0.7035
Epoch 300/300, Loss: 0.7040
Epoch 300/300, Loss: 1.4077


4. Text Generation (Sampling with Temperature)

In [51]:
def sample_from_probs(probs, temperature=1.0):

    probs = torch.softmax(probs / temperature, dim=-1).detach().cpu().numpy()

    return np.random.choice(len(probs), p=probs)

def generate_text(model, seed_words,length=6, temperature=1.0):

    model.eval()

    words = seed_words[:]

    hidden = None


    for _ in range(length):

        input_seq = [word_to_idx[w] for w in words[-3:]]

        input_tensor = torch.tensor(input_seq).unsqueeze(0)

        with torch.no_grad():

            output, hidden = model(input_tensor, hidden)

            next_idx = sample_from_probs(output[0], temperature)

            next_word = idx_to_word[next_idx]

            words.append(next_word)

        return " ".join(words)

In [52]:
seed = ['deep', 'learning', 'is']

print("\nGenerated Text (temp=1.0):")

print(generate_text(model, seed_words=seed, length=6, temperature=1.0))


print("\nGenerated Text (temp=0.5):")

print(generate_text(model, seed_words=seed, length=6, temperature=0.5))

print("\nGenerated Text (temp=1.5):")

print(generate_text(model, seed_words=seed, length=6, temperature=1.5))


Generated Text (temp=1.0):
deep learning is fun

Generated Text (temp=0.5):
deep learning is fun

Generated Text (temp=1.5):
deep learning is powerful
