In [21]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [22]:
class customDataset(Dataset):
  def __init__(self, text, word2idx, seq_length):
    self.text=text
    self.word2idx=word2idx
    self.seq_length=seq_length

  def __len__(self):
    return len(self.text) - self.seq_length

  def __getitem__(self, index):
    sequence=[self.word2idx[word] for word in self.text[index:index+self.seq_length]]
    target=self.word2idx[self.text[index+self.seq_length]]
    return torch.tensor(sequence),torch.tensor(target)

In [68]:
text = "Climate change is a pressing global issue caused by various factors, including carbon emissions, volcanic eruptions, and solar radiation. Human activities like fossil fuel combustion, deforestation, and vehicular pollution play a significant role. These changes have led to more frequent and severe extreme weather various volcanic led Climate radiation."

In [69]:
word2idx={word: i for i,word in enumerate(set(text.split()))}

In [70]:
idx2word={i: word for word,i in word2idx.items()}

In [71]:
dataset= customDataset(text.split(), word2idx, seq_length=10)

In [72]:
dataloader=DataLoader(dataset, batch_size=32, shuffle=True)

In [73]:
class VanillaRNN(nn.Module):
  def __init__(self, vocab_size, embed_size, hidden_size):
    super(VanillaRNN, self).__init__()
    self.embed=nn.Embedding(vocab_size, embed_size)
    self.rnn=nn.RNN(embed_size, hidden_size, batch_first=True)
    self.fc=nn.Linear(hidden_size, vocab_size)

  def forward(self, x, h0):
    embed=self.embed(x)
    out,h=self.rnn(embed, h0)
    output=self.fc(out[:,-1,:])
    return output,h

In [74]:
model=VanillaRNN(len(word2idx), embed_size=128, hidden_size=256)

In [75]:
import torch.optim as optim

In [76]:
criterion=nn.CrossEntropyLoss()
optimizer=optim.SGD(model.parameters(),lr=0.001)

In [77]:
for epoch in range(10):
  for input, label in dataloader:
    optimizer.zero_grad()
    h0=torch.zeros(1, input.size(0),256)
    outputs,_ = model(input,h0)
    loss=criterion(outputs,label)
    loss.backward()
    optimizer.step()
  print(f'epoch {epoch} : loss : {loss.item()}')

input_seq=torch.tensor([word2idx[word] for word in text.split()[-10:]]).unsqueeze(0)
h0=torch.zeros(1,1,256)
output,_ = model(input_seq,h0)
predicted_word = idx2word[output.argmax().item()]
print(f"predicted next word : {predicted_word}")

epoch 0 : loss : 3.5717241764068604
epoch 1 : loss : 3.810302495956421
epoch 2 : loss : 3.6805479526519775
epoch 3 : loss : 3.698279619216919
epoch 4 : loss : 3.635925531387329
epoch 5 : loss : 3.461418867111206
epoch 6 : loss : 3.5363142490386963
epoch 7 : loss : 3.458329439163208
epoch 8 : loss : 3.701122999191284
epoch 9 : loss : 3.6665284633636475
predicted next word : radiation.
