In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
# torchvision is not used for modelling below; it is imported only so this
# version report also runs on a fresh kernel (it was never imported above).
import torchvision
torch.__version__, torchvision.__version__

('2.5.1+cu124', '0.20.1+cu124')

##### NLP: word-level LSTM text generation

In [None]:
''' text processing '''
from collections import Counter

# Read explicitly as UTF-8: the corpus contains non-ASCII characters (the most
# frequent tokens include a curly quote), so the locale default must not be relied on.
with open('guttenberg.txt', 'r', encoding='utf-8') as fd:
  text = fd.read()

# Normalise: lower-case, treat newlines and hyphens as word separators.
text = text.lower()
text = text.replace('\n', ' ')
text = text.replace('-', ' ')
# Surround punctuation with spaces so each mark becomes a token of its own.
# NOTE(review): curly quotes are not in this list, so they stay glued to a
# neighbouring word unless the source already separates them - confirm intent.
for char in ',.:;?!$_@&#*%"':
  text = text.replace(char, f' {char} ')
text = text.split()  # from here on `text` is the list of tokens

# Vocabulary ordered by descending frequency: index 0 is the most common token.
wordscount = Counter(text)
words = sorted(wordscount, key=wordscount.get, reverse=True)
word_to_int = {word: index for index, word in enumerate(words)}
int_to_word = dict(enumerate(words))
# The whole corpus encoded as vocabulary indices.
word_index = [word_to_int[word] for word in text]
text[:5], words[:10], len(words)

(['the', 'project', 'gutenberg', 'ebook', 'of'],
 [',', '.', 'the', 'and', 'to', 'of', 'he', '”', 'a', 'in'],
 15122)

In [None]:
''' work with the batch of training data '''
seqlength = 100

# Encode the corpus once; every window below is a *view* into this tensor, so
# the dataset costs one copy of the corpus instead of ~2*seqlength copies.
word_index_t = torch.tensor(word_index)

# A window starting at n needs tokens n .. n+seqlength (inclusive) so that the
# target can be the input shifted by one; the last valid start is therefore
# len(word_index) - seqlength - 1, i.e. there are len(word_index) - seqlength windows.
n_windows = len(word_index) - seqlength
assert n_windows > 0, 'corpus must contain more than seqlength tokens'
substr = [(word_index_t[n: n+seqlength], word_index_t[n+1: n+seqlength+1])
          for n in range(n_windows)]

torch.manual_seed(40)  # fixes the shuffle order of the loader
batchsize=32
loader = DataLoader(substr, batch_size=batchsize, shuffle=True)
x, y = next(iter(loader))
x.shape, y.shape, x[:1], y[:1]

In [None]:
''' define and train the model '''
class LSTM(nn.Module):
  '''Word-level language model: embedding -> stacked LSTM -> linear layer over the vocabulary.

  Args:
    input_size: size of the vectors fed into the LSTM, i.e. the embedding dimension.
    n_embed: size of the LSTM hidden/cell state (this is what the LSTM is built
      with as `hidden_size`; the name is kept for backward compatibility).
    n_layers: number of stacked LSTM layers.
    drop_prob: dropout applied between LSTM layers (ignored for a single layer).
    vocab_size: number of distinct tokens. Defaults to `len(word_to_int)`, the
      vocabulary built earlier in the notebook.
  '''
  def __init__(self, input_size=128, n_embed=128, n_layers=3, drop_prob=0.2, vocab_size=None):
    super().__init__()
    if vocab_size is None:
      vocab_size = len(word_to_int)
    self.input_size = input_size
    self.drop_prob = drop_prob
    self.n_layers = n_layers
    self.n_embed = n_embed
    self.vocab_size = vocab_size
    # The embedding must produce what the LSTM consumes (input_size) and the
    # output layer must consume what the LSTM produces (n_embed); mixing the two
    # up only works while both happen to be equal.
    self.embedding = nn.Embedding(vocab_size, input_size)
    self.lstm = nn.LSTM(input_size=input_size, hidden_size=n_embed,
          num_layers=n_layers, dropout=drop_prob if n_layers > 1 else 0.0,
          batch_first=True)
    self.fc = nn.Linear(n_embed, vocab_size)
  def forward(self, x, hc=None):
    '''Return (logits, (h, c)) for a batch of token indices.

    Args:
      x: integer tensor of token indices, batch first: (batch, seq).
      hc: tuple (h, c) as returned by `init_hidden` or by a previous call;
        None lets nn.LSTM start from zeros.

    Returns:
      logits of shape (batch, seq, vocab_size) and the updated (h, c).
    '''
    embed = self.embedding(x)
    x, hc = self.lstm(embed, hc)
    x = self.fc(x)
    return x, hc
  def init_hidden(self, n_seqs):
    '''Return zeroed (h, c), each (n_layers, n_seqs, n_embed), on the model's device/dtype.'''
    weight = next(self.parameters())
    shape = (self.n_layers, n_seqs, self.n_embed)
    return weight.new_zeros(shape), weight.new_zeros(shape)

# `device` is not defined in any earlier cell; define it here so the notebook
# survives Restart & Run All. The training and generation cells move their
# tensors to this same device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LSTM().to(device)
model

LSTM(
  (embedding): Embedding(15122, 128)
  (lstm): LSTM(128, 128, num_layers=3, batch_first=True, dropout=0.2)
  (fc): Linear(in_features=128, out_features=15122, bias=True)
)

In [None]:
lr = 0.0001
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_func = nn.CrossEntropyLoss()
model.train()
for epoch in range(1):
  totalloss = 0
  for i, (x,y) in enumerate(loader):
    inputs, targets = x.to(device), y.to(device)
    # The loader shuffles, so consecutive batches are unrelated windows of the
    # corpus: carrying the LSTM state from one batch into the next would feed
    # it meaningless context. Start every batch from a zero state, sized to the
    # batch actually received (which also lets the final, smaller batch train).
    hc = model.init_hidden(inputs.shape[0])
    optimizer.zero_grad()
    output, _ = model(inputs, hc)
    # CrossEntropyLoss wants the class dimension second: (batch, vocab, seq).
    loss = loss_func(output.transpose(1,2), targets)
    loss.backward()
    # In-place variant; the un-suffixed clip_grad_norm is deprecated.
    nn.utils.clip_grad_norm_(model.parameters(), 5)
    optimizer.step()
    totalloss += loss.item()
    if (i+1)%1000 ==0:
      print(f'at epoch {epoch} iteration {i+1} average loss = {totalloss/(i+1)}')

In [None]:
''' generate text by sampling from the full predicted distribution '''
def sample(model, prompt, length=100):
  '''Continue `prompt` word by word until the text is `length` words long.

  Each next word is drawn from the model's softmax distribution over the whole
  vocabulary. Drawing uses torch's RNG, so `torch.manual_seed` makes it repeatable.

  Args:
    model: network exposing `init_hidden(n_seqs)` and `model(x, hc) -> (logits, hc)`.
    prompt: whitespace-separated seed words; lower-cased before lookup.
    length: total number of words in the result, prompt included. If the prompt
      is already that long, it is returned unchanged.

  Returns:
    The prompt followed by the generated words, joined by single spaces.

  Raises:
    ValueError: if the prompt contains no words.
    KeyError: if a prompt word is not in the vocabulary (`word_to_int`).
  '''
  model.eval()
  text = prompt.lower().split()
  if not text:
    raise ValueError('prompt must contain at least one word')
  unknown = [word for word in text if word not in word_to_int]
  if unknown:
    raise KeyError(f'prompt words not in vocabulary: {unknown}')
  with torch.no_grad():
    hc = model.init_hidden(1)
    # Feed the whole prompt once; afterwards the carried state (hc) already
    # summarises everything seen, so only the newest word is fed each step.
    inputs = torch.tensor([[word_to_int[word] for word in text]], device=device)
    for _ in range(length - len(text)):
      output, hc = model(inputs, hc)
      logits = output[0][-1]  # prediction that follows the last fed word
      p = nn.functional.softmax(logits, dim=0)
      # torch.multinomial avoids numpy's strict "probabilities sum to 1" check,
      # which float32 softmax output does not reliably pass.
      idx = torch.multinomial(p, 1).item()
      text.append(int_to_word[idx])
      inputs = torch.tensor([[idx]], device=device)
  return ' '.join(text)
# Prompt words must exist in the vocabulary; the corpus token is 'gutenberg'
# (single t), as shown by text[:5] in the text-processing cell.
sample(model, prompt='project gutenberg')

In [None]:
''' top-K sampling '''
def generate(model, prompt, top_k=None, length=100, temperature=1):
  '''Continue `prompt` word by word with optional top-k and temperature sampling.

  Args:
    model: network exposing `init_hidden(n_seqs)` and `model(x, hc) -> (logits, hc)`.
    prompt: whitespace-separated seed words; lower-cased before lookup.
    top_k: if given, sample only among the `top_k` most likely words
      (renormalised); None samples from the whole vocabulary.
    length: total number of words in the result, prompt included. If the prompt
      is already that long, it is returned unchanged.
    temperature: logits are divided by this before the softmax; values below 1
      sharpen the distribution, values above 1 flatten it. Must be > 0.

  Returns:
    The prompt followed by the generated words, joined by single spaces.

  Raises:
    ValueError: if the prompt is empty, `temperature` <= 0 or `top_k` < 1.
    KeyError: if a prompt word is not in the vocabulary (`word_to_int`).
  '''
  if temperature <= 0:
    raise ValueError('temperature must be positive')
  if top_k is not None and top_k < 1:
    raise ValueError('top_k must be at least 1')
  model.eval()
  text = prompt.lower().split()
  if not text:
    raise ValueError('prompt must contain at least one word')
  unknown = [word for word in text if word not in word_to_int]
  if unknown:
    raise KeyError(f'prompt words not in vocabulary: {unknown}')
  with torch.no_grad():
    hc = model.init_hidden(1)
    # Feed the whole prompt once; afterwards the carried state (hc) already
    # summarises everything seen, so only the newest word is fed each step.
    inputs = torch.tensor([[word_to_int[word] for word in text]], device=device)
    for _ in range(length - len(text)):
      output, hc = model(inputs, hc)
      logits = output[0][-1] / temperature
      if top_k is None:
        p = nn.functional.softmax(logits, dim=0)
        idx = torch.multinomial(p, 1).item()
      else:
        # Softmax over only the k largest logits equals the full softmax
        # restricted to those k words and renormalised.
        top_logits, tops = logits.topk(min(top_k, logits.numel()))
        ps = nn.functional.softmax(top_logits, dim=0)
        idx = tops[torch.multinomial(ps, 1)].item()
      text.append(int_to_word[idx])
      inputs = torch.tensor([[idx]], device=device)
  # Join once, after the loop: `text` must stay a list while words are appended.
  return ' '.join(text)