In [192]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

In [193]:
# RNN model trained to predict the next word in a sentence
class RNNModel(nn.Module):
    def __init__(self, target_sentences):
        super().__init__()
        self.target_sentences = [f"<START> {sentence.strip()} <END>" for sentence in target_sentences]
        self.vocab = list(set(" ".join(self.target_sentences).split(" ")))

        self.hidden_size = 128
        self.embedding_size = 64

        self.last_layer = nn.Linear(self.hidden_size, len(self.vocab))
        self.embedding = nn.Embedding(len(self.vocab), self.embedding_size)
        self.internal_layer = nn.Linear(self.embedding_size + self.hidden_size, self.hidden_size)

        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
    
    def tokenize(self, word):
        return torch.LongTensor([self.vocab.index(word)])

    def logits(self, sentence):

        # Initialize hidden state for sequence
        hidden = torch.zeros(1, self.hidden_size)

        # Process each token in the sequence
        split = sentence.split(" ")
        # Check if the word is in the vocab
        if not all([word in self.vocab for word in split]):
            raise ValueError("Word not in vocab")
        
        # Iterate through each word in the sequence
        for elem in split:
            token = self.tokenize(elem)
            embedding = self.embedding(token)
            input_data = torch.cat((embedding, hidden), dim=1)
            hidden = self.internal_layer(input_data)
            
        # The last hidden state should go through the last layer
        output = self.last_layer(hidden)

        return output
    
    def get_prob_dist_over_next_word(self, sentence: str):
        dist = self.forward(sentence)[0]
        return {self.vocab[i]: f"{dist[i].item():.2f}" for i in range(len(dist))}

    def forward(self, sentence):
        return F.softmax(self.logits(sentence), dim=1)
    
    def predict(self, sentence):
        # sentence = f"<START> {sentence}"
        result = self.vocab[torch.argmax(self.forward(sentence))]
        # if result == "<END>":
        #     return ""
        return result
    
    def complete_sentence(self, sentence, max_len=100):
        result = sentence.strip()
        new_token = ""
        while not new_token == "<END>":
            new_token = self.predict(result)
            if len(result.split(" ")) >= max_len:
                break
            result += " " + new_token
        return result

    def train(self, num_epochs=10):
        for epoch in tqdm(range(num_epochs), unit="epoch"):
            for sentence in self.target_sentences:
                split_sentence = sentence.split(" ")
            
                for i in range(len(split_sentence) - 1):
                    self.optimizer.zero_grad()
                    input_data = split_sentence[:i+1]
                    target = split_sentence[i+1]
                    output = self.forward(" ".join(input_data))
                    loss = self.loss_fn(output, self.tokenize(target))
                    loss.backward()
                    self.optimizer.step()
        
                        

Før vi trener modellen sjekker vi at den **ikke** klarer å skrive setninger som gir mening

In [198]:
target_sentences = ["dette er et kodekurs i maskinlæring", "kodekurset er gøy fordi man lærer mye"]
rnn = RNNModel(target_sentences)
rnn.complete_sentence("<START> ")

'<START> fordi maskinlæring fordi mye et er fordi maskinlæring fordi mye et er fordi maskinlæring fordi mye et er fordi maskinlæring fordi mye et er fordi maskinlæring fordi mye et er fordi maskinlæring fordi mye et er fordi maskinlæring fordi mye et er fordi maskinlæring fordi mye et er fordi maskinlæring fordi mye et er fordi maskinlæring fordi mye et er fordi maskinlæring fordi mye et er fordi maskinlæring fordi mye et er fordi maskinlæring fordi mye et er fordi maskinlæring fordi mye et er fordi maskinlæring fordi mye et er fordi maskinlæring fordi mye et er fordi maskinlæring fordi'

In [200]:
rnn.train(num_epochs=10)
    

100%|██████████| 10/10 [00:00<00:00, 22.13epoch/s]


In [203]:
rnn.complete_sentence("<START> dette")

'<START> dette er et kodekurs i maskinlæring <END>'

In [207]:
rnn.get_prob_dist_over_next_word("<START> kodekurset er")

{'dette': '0.00',
 'lærer': '0.00',
 'kodekurs': '0.00',
 '<END>': '0.00',
 'gøy': '1.00',
 'er': '0.00',
 'kodekurset': '0.00',
 'mye': '0.00',
 'fordi': '0.00',
 'maskinlæring': '0.00',
 '<START>': '0.00',
 'et': '0.00',
 'i': '0.00',
 'man': '0.00'}