In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [35]:
# RNN model trained to predict the next word in a sentence
class RNNModel(nn.Module):
    """Minimal word-level Elman-style RNN that predicts the next word of a sentence.

    Every training sentence is wrapped in ``<START>`` / ``<END>`` marker tokens and the
    vocabulary is built from the space-separated words of those wrapped sentences.
    The model owns its loss function and Adam optimizer, so one call to
    :meth:`run_epoch` performs one full training pass.

    Args:
        target_sentences: Sentences (words separated by single spaces) to learn.
        hidden_size: Size of the recurrent hidden state.
        embedding_size: Size of each word embedding.
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, target_sentences, hidden_size=128, embedding_size=64, lr=0.001):
        super().__init__()
        self.target_sentences = [f"<START> {sentence} <END>" for sentence in target_sentences]
        # Sorted so the word -> index mapping is identical across interpreter runs;
        # the iteration order of a plain set of strings depends on hash randomisation.
        self.vocab = sorted(set(" ".join(self.target_sentences).split(" ")))
        # O(1) lookups instead of list.index (O(len(vocab))) for every token.
        self.word_to_index = {word: index for index, word in enumerate(self.vocab)}

        self.hidden_size = hidden_size
        self.embedding_size = embedding_size

        self.last_layer = nn.Linear(self.hidden_size, len(self.vocab))
        self.embedding = nn.Embedding(len(self.vocab), self.embedding_size)
        self.internal_layer = nn.Linear(self.embedding_size + self.hidden_size, self.hidden_size)

        # CrossEntropyLoss expects raw (unnormalised) logits -- see run_epoch.
        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

    def tokenize(self, word):
        """Return a 1-element long tensor holding the vocabulary index of ``word``.

        Raises:
            ValueError: if ``word`` is not in the vocabulary.
        """
        try:
            index = self.word_to_index[word]
        except KeyError:
            raise ValueError(f"{word!r} is not in vocab") from None
        return torch.tensor([index], dtype=torch.long)

    def logits(self, sentence):
        """Run the RNN over ``sentence`` and return unnormalised next-word scores.

        Args:
            sentence: Space-separated words; every word must be in ``self.vocab``.

        Returns:
            Tensor of shape ``(1, len(self.vocab))`` with raw scores for the next word.

        Raises:
            ValueError: if any word of ``sentence`` is not in the vocabulary.
        """
        words = sentence.split(" ")
        unknown = [word for word in words if word not in self.word_to_index]
        if unknown:
            raise ValueError(f"Word not in vocab: {unknown}")

        # The hidden state starts at zero for every sequence (batch of one).
        hidden = torch.zeros(1, self.hidden_size)
        for word in words:
            embedding = self.embedding(self.tokenize(word))  # (1, embedding_size)
            step_input = torch.cat((embedding, hidden), dim=1)
            # tanh keeps the recurrence non-linear and bounded; without it the whole
            # network is one linear map and the hidden state can grow without limit.
            hidden = torch.tanh(self.internal_layer(step_input))

        # Only the final hidden state is used to score the next word.
        return self.last_layer(hidden)

    def forward(self, sentence):
        """Return next-word probabilities (softmax over the vocabulary), shape (1, V)."""
        return F.softmax(self.logits(sentence), dim=1)

    def predict(self, sentence):
        """Return the single most probable next word after ``sentence`` (greedy)."""
        # Inference only: no need to build an autograd graph.
        with torch.no_grad():
            probabilities = self.forward(sentence)
        return self.vocab[int(torch.argmax(probabilities))]

    def complete_sentence(self, sentence, max_length=100):
        """Greedily extend ``sentence`` word by word until ``<END>`` is predicted.

        Args:
            sentence: Prompt words, without the ``<START>`` marker (added here).
                An empty string starts generation from ``<START>`` alone.
            max_length: Maximum number of words to generate. This guards against an
                infinite loop when the model never predicts ``<END>`` (e.g. untrained).

        Returns:
            ``"<START> "`` + the prompt + the generated words; ``<END>`` is not included.
        """
        sentence = f"<START> {sentence}" if sentence else "<START>"
        for _ in range(max_length):
            new_token = self.predict(sentence)
            if new_token == "<END>":
                break
            sentence += " " + new_token
        return sentence

    def run_epoch(self, do_test=False):
        """Train for one pass over ``self.target_sentences``.

        For every sentence and every prefix ``w_0 .. w_i`` the model predicts
        ``w_{i+1}``, and one optimizer step is taken per prediction.

        Args:
            do_test: Currently has no effect; kept so existing calls keep working.

        Returns:
            Mean cross-entropy loss over all prediction steps of this epoch
            (``0.0`` when there are no steps).
        """
        total_loss = 0.0
        num_steps = 0
        for sentence in self.target_sentences:
            words = sentence.split(" ")
            for i in range(1, len(words)):
                self.optimizer.zero_grad()
                # CrossEntropyLoss applies log-softmax itself, so it must receive raw
                # logits. Passing forward()'s probabilities would apply softmax twice
                # and squash the gradients.
                scores = self.logits(" ".join(words[:i]))
                loss = self.loss_fn(scores, self.tokenize(words[i]))
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
                num_steps += 1
        return total_loss / num_steps if num_steps else 0.0


In [43]:
# Seed torch so the weight initialisation (and every output below) is reproducible.
torch.manual_seed(0)
target_sentences = ["dette er et kodekurs i maskinlæring", "kodekurset er gøy"]
rnn = RNNModel(target_sentences)

Før vi trener modellen, sjekker vi at den ikke klarer å skrive setninger som gir mening.

In [44]:
# Every training sequence begins with <START>, so the prompt must as well.
rnn.predict("<START> dette er et kodekurs i")

'kodekurset'

In [45]:
from tqdm.auto import tqdm  # notebook-aware progress bar

num_epochs = 100

for _ in tqdm(range(num_epochs), desc="training"):
    # run_epoch's do_test flag has no effect, so it is not passed.
    rnn.run_epoch()
    

100%|██████████| 100/100 [00:04<00:00, 24.57it/s]


In [47]:
rnn.complete_sentence(sentence="dette")

'<START> dette er et kodekurs i maskinlæring'