In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [2]:
# Whitespace-tokenize the toy training sentence into a list of words.
sentence = (
    "Repeat is the best medicine for memory"
    .split()
)
print(sentence)

['Repeat', 'is', 'the', 'best', 'medicine', 'for', 'memory']


In [3]:
# Unique words in sorted order. sorted() accepts any iterable and always
# returns a new list, so the set can be passed to it directly.
vocab = sorted(set(sentence))
print(vocab)

['Repeat', 'best', 'for', 'is', 'medicine', 'memory', 'the']


In [4]:
# Number of distinct words in the sentence.
# NOTE(review): this count does not include the '<unk>' entry that is added to
# word_to_index at index 0, so word indices run 0..vocab_size inclusive.
# Confirm the value is recomputed (e.g. from len(word_to_index)) before it is
# used to size nn.Embedding / the output layer.
vocab_size = len(vocab)
vocab_size

7

In [10]:
# Number the vocabulary words 1..len(vocab) in sorted order; index 0 is kept
# free for the unknown-word token added below.
word_to_index = dict(zip(vocab, range(1, len(vocab) + 1)))
print(word_to_index)

# Reserve index 0 for words that are not in the vocabulary.
word_to_index['<unk>'] = 0
print(word_to_index)

{'Repeat': 1, 'best': 2, 'for': 3, 'is': 4, 'medicine': 5, 'memory': 6, 'the': 7}
{'Repeat': 1, 'best': 2, 'for': 3, 'is': 4, 'medicine': 5, 'memory': 6, 'the': 7, '<unk>': 0}


In [12]:
# Look up a single word, then encode the whole sentence as a list of indices.
print(word_to_index['for'])
encode = list(map(word_to_index.__getitem__, sentence))
print(encode)

3
[1, 4, 7, 2, 5, 3, 6]


In [16]:
def build_data(sentence, word_to_index):
    """Turn a tokenized sentence into next-word-prediction tensors.

    Each token is mapped to its integer index. The input sequence is every
    index except the last, and the label sequence is every index except the
    first, so ``label_seq[0, i]`` is the word that follows ``input_seq[0, i]``.

    Args:
        sentence: iterable of word tokens (strings).
        word_to_index: mapping from word to integer index. If it contains an
            ``'<unk>'`` entry, tokens missing from the mapping are encoded
            with that index.

    Returns:
        ``(input_seq, label_seq)``: two ``torch.long`` tensors of shape
        ``(1, len(sentence) - 1)``; the leading dimension is a batch of one.
        For a sentence with fewer than two tokens both have shape ``(1, 0)``.

    Raises:
        KeyError: if a token is not in ``word_to_index`` and the mapping has
            no ``'<unk>'`` entry to fall back on.
    """
    unk_index = word_to_index.get('<unk>')

    encode = []
    for token in sentence:
        index = word_to_index.get(token, unk_index)
        if index is None:
            # No '<unk>' fallback available: fail the same way a plain
            # dictionary lookup would.
            raise KeyError(token)
        encode.append(index)

    # Build both tensors the same way; unsqueeze(0) adds the batch dimension.
    input_seq = torch.tensor(encode[:-1], dtype=torch.long).unsqueeze(0)
    label_seq = torch.tensor(encode[1:], dtype=torch.long).unsqueeze(0)
    return input_seq, label_seq

In [17]:
# Build the (input, label) pair for the example sentence and show both tensors,
# one per line.
X, Y = build_data(sentence, word_to_index)
print(X, Y, sep="\n")

tensor([[1, 4, 7, 2, 5, 3]])
tensor([[4, 7, 2, 5, 3, 6]])


In [18]:
class Net(nn.Module):
    """Word-level next-token model: embedding -> RNN -> linear layer.

    Args:
        vocab_size: number of rows in the embedding table and number of
            scores produced per time step.
        input_size: embedding dimension, which is also the RNN input size.
        hidden_size: size of the RNN hidden state.
        batch_first: forwarded to ``nn.RNN``; defaults to True.
    """

    def __init__(self, vocab_size, input_size, hidden_size, batch_first=True):
        super().__init__()

        # Index -> dense vector lookup table.
        self.embedding_layer = nn.Embedding(vocab_size, input_size)
        # Single-layer RNN over the embedded sequence.
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=batch_first)
        # Projects every hidden state to one score per vocabulary entry.
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        """Score every vocabulary entry at every position of ``x``.

        Shapes below are for the default ``batch_first=True``.

        Args:
            x: integer index tensor of shape (batch, seq_len).

        Returns:
            Tensor of shape (batch * seq_len, vocab_size).
        """
        # Embedding: (batch, seq_len) -> (batch, seq_len, input_size)
        embedded = self.embedding_layer(x)

        # RNN: (batch, seq_len, input_size) ->
        #   per-step outputs: (batch, seq_len, hidden_size)
        #   final hidden state: (1, batch, hidden_size), not used here
        states, _ = self.rnn(embedded)

        # Linear: (batch, seq_len, hidden_size) -> (batch, seq_len, vocab_size)
        logits = self.fc(states)

        # Flatten: (batch, seq_len, vocab_size) -> (batch * seq_len, vocab_size)
        return logits.view(-1, logits.size(2))