<a href="https://colab.research.google.com/github/tommyEzreal/study_low_level/blob/main/pytorch/%EB%AA%A8%EB%91%90%EB%A5%BC%EC%9C%84%ED%95%9C%EB%94%A5%EB%9F%AC%EB%8B%9D2/RNN_seq2seq.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# difference between vanilla RNN and seq2seq 

# "Today's perfect weather makes me such sad"
# 문장을 끝까지 듣고 답변을 출력하는 encoder - decoder architecture


In [1]:
import random 
import torch
import torch.nn as nn

In [210]:
torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

raw = ["I feel hungry.	나는 배가 고프다.",
       "Pytorch is very easy.	파이토치는 매우 쉽다.",
       "Pytorch is a framework for deep learning.   파이토치는 딥러닝을 위한 프레임워크이다."]

SOS_token = 0
EOS_token = 1

In [35]:
# vocab class
# vocab->index / index -> vocab

class Vocab:
  def __init__(self):
    self.vocab2index = {"<SOS>": SOS_token, "<EOS>": EOS_token}
    self.index2vocab = {SOS_token : "<SOS>", EOS_token:"<EOS>"}
    self.vocab_count = {}
    self.n_vocab = len(self.vocab2index)

  def add_vocab(self, sentence):
    for word in sentence.split(" "):
      if word not in self.vocab2index:
        self.vocab2index[word] = self.n_vocab
        self.vocab_count[word] = 1
        self.index2vocab[self.n_vocab] = word
        self.n_vocab +=1 
      else:
        self.vocab_count[word] +=1 

In [114]:
vocab = Vocab()
vocab.add_vocab("Pytorch is very easy.  파이토치는 너무 어렵습니다.")
print(vocab.vocab2index, vocab.index2vocab)
print(vocab.vocab_count, vocab.n_vocab)

{'<SOS>': 0, '<EOS>': 1, 'Pytorch': 2, 'is': 3, 'very': 4, 'easy.': 5, '': 6, '파이토치는': 7, '너무': 8, '어렵습니다.': 9} {0: '<SOS>', 1: '<EOS>', 2: 'Pytorch', 3: 'is', 4: 'very', 5: 'easy.', 6: '', 7: '파이토치는', 8: '너무', 9: '어렵습니다.'}
{'Pytorch': 1, 'is': 1, 'very': 1, 'easy.': 1, '': 1, '파이토치는': 1, '너무': 1, '어렵습니다.': 1} 10


In [98]:
# source&target maxlen 초과여부 측정 
def filter_pair(pair, source_max_length, target_max_length):
  return len(pair[0].split(" ")) < source_max_length and len(pair[1].split(" ")) < target_max_length

In [115]:
# read and preprocess the corpus data
def preprocess(corpus, source_max_length, target_max_length):
    print("reading corpus...")

    pairs = [] # [[pair],[pair],[pair]]
    for line in corpus:
        pairs.append([s for s in line.strip().lower().split("\t")])
    print("Read {} sentence pairs".format(len(pairs)))

    # filter (maxlength 넘지않는것만 pairs로 저장)
    pairs = [pair for pair in pairs if filter_pair(pair, source_max_length, target_max_length)]
    print("Trimmed to {} sentence pairs".format(len(pairs)))

    source_vocab = Vocab()
    target_vocab = Vocab()
    
    # source target 각각 vocab 추가 
    print("Counting words...")
    for pair in pairs:
        source_vocab.add_vocab(pair[0])
        target_vocab.add_vocab(pair[1])
    print("source vocab size =", source_vocab.n_vocab)
    print("target vocab size =", target_vocab.n_vocab)

    return pairs, source_vocab, target_vocab

In [116]:
# declare max length for sentence
SOURCE_MAX_LENGTH = 10
TARGET_MAX_LENGTH = 12

In [184]:
# encoder decoder 

In [216]:
class Encoder(torch.nn.Module):
  def __init__(self, input_size, hidden_size):
    super(Encoder, self).__init__()
    self.hidden_size = hidden_size
    self.embedding = torch.nn.Embedding(input_size, hidden_size)
    self.gru = torch.nn.GRU(hidden_size, hidden_size)
  
  def forward(self,x,hidden):
    x = self.embedding(x).view(1,1,-1)
    x, hidden = self.gru(x,hidden)
    return x, hidden

In [214]:
from prompt_toolkit import output
class Decoder(torch.nn.Module):
  def __init__(self, hidden_size, output_size):
    super(Decoder, self).__init__()
    self.hidden_size = hidden_size
    self.embedding = torch.nn.Embedding(output_size, hidden_size)
    self.gru = torch.nn.GRU(hidden_size, hidden_size)
    self.out = nn.Linear(hidden_size, output_size)
    self.softmax = nn.LogSoftmax(dim=1)

  def forward(self, x, hidden):
    x = self.embedding(x).view(1,1,-1)
    x, hidden = self.gru(x, hidden)
    x = self.softmax(self.out(x[0]))
    return x, hidden 

In [187]:
# sentence to index tensor 
def tensorize(vocab, sentence):
  indices = [vocab.vocab2index[word] for word in sentence.split(" ")]
  indices.append(vocab.vocab2index["<EOS>"])
  return torch.Tensor(indices).long().to(device).view(-1,1)

In [198]:
preprocessed_pairs, source_vocab, target_vocab = preprocess(raw, 10,12)

print("paris:",preprocessed_pairs)
print(source_vocab, target_vocab)

reading corpus...
Read 3 sentence pairs
Trimmed to 2 sentence pairs
Counting words...
source vocab size = 9
target vocab size = 8
paris: [['i feel hungry.', '나는 배가 고프다.'], ['pytorch is very easy.', '파이토치는 매우 쉽다.']]
<__main__.Vocab object at 0x7f6c7c2686d0> <__main__.Vocab object at 0x7f6c7c268f70>


In [199]:
tensorize(source_vocab, preprocessed_pairs[0][0]), preprocessed_pairs[0][0]

(tensor([[2],
         [3],
         [4],
         [1]], device='cuda:0'), 'i feel hungry.')

In [204]:
x = 'dd'
[x for _ in range(9)]

['dd', 'dd', 'dd', 'dd', 'dd', 'dd', 'dd', 'dd', 'dd']

In [205]:
def train(pairs, source_vocab, target_vocab,
          encoder, decoder, epochs, print_every = 1000,
          learning_rate = 0.01):
  loss_total = 0

  # define optim 
  encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=learning_rate)
  decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=learning_rate)

  # loss 
  criterion = nn.NLLLoss()

  # batch manually ([rand_pair,..])
  training_batch = [random.choice(pairs) for _ in range(epochs)]
  
  # batch에서 pair꺼내서 vocab 2 index
  training_source = [tensorize(source_vocab, pair[0]) for pair in training_batch]
  training_target = [tensorize(target_vocab, pair[1]) for pair in training_batch]

  # training loop
  for epoch in range(1, epochs+1):
    source_tensor = training_source[epoch-1]
    target_tensor = training_target[epoch-1]

    encoder_hidden = torch.zeros([1,1, encoder.hidden_size]).to(device)

    encoder_optimizer.zero_grad
    decoder_optimizer.zero_grad

    source_len = source_tensor.size(0)
    target_len = target_tensor.size(0)

    loss=0
    
    for enc_input in range(source_len):
      _, encoder_hidden = encoder(source_tensor[enc_input], encoder_hidden)
    
    decoder_input = torch.Tensor([[SOS_token]]).long().to(device)
    # encoder output -> decoder input 
    decoder_hidden = encoder_hidden

    for di in range(target_len):
      decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
      loss += criterion(decoder_output, target_tensor[di])
      decoder_input = target_tensor[di]

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()
    
    loss_iter = loss.item() / target_len
    loss_total += loss_iter

    if epoch % print_every ==0:
      loss_avg = loss_total / print_every
      loss_total = 0
      print("[{} - {}%] loss = {:05.4f}".format(epoch, epoch / epochs * 100, loss_avg))

In [221]:
source_max_len = 10
target_max_len = 12

processed_pairs, source_vocab, target_vocab = preprocess(raw, source_max_len, target_max_len)
print("raw:",raw)
print(processed_pairs)

reading corpus...
Read 3 sentence pairs
Trimmed to 2 sentence pairs
Counting words...
source vocab size = 9
target vocab size = 8
raw: ['I feel hungry.\t나는 배가 고프다.', 'Pytorch is very easy.\t파이토치는 매우 쉽다.', 'Pytorch is a framework for deep learning.   파이토치는 딥러닝을 위한 프레임워크이다.']
[['i feel hungry.', '나는 배가 고프다.'], ['pytorch is very easy.', '파이토치는 매우 쉽다.']]


In [222]:
enc_hidden_size = 16
dec_hidden_size = enc_hidden_size

# Encoder(input_size, hidden_size)
enc = Encoder(source_vocab.n_vocab, enc_hidden_size).to(device)

# Decoder(hidden_size, output_size)
dec = Decoder(dec_hidden_size, target_vocab.n_vocab).to(device)

In [223]:
# train
# def train(pairs, source_vocab, target_vocab,
#          encoder, decoder, epochs, print_every = 1000,
#          learning_rate = 0.01)

train(pairs = processed_pairs,
      source_vocab = source_vocab,
      target_vocab = target_vocab,
      encoder = enc,
      decoder = dec,
      epochs = 100,
      print_every = 10,
      learning_rate = 0.01)

[10 - 10.0%] loss = 1.8248
[20 - 20.0%] loss = 1.0362
[30 - 30.0%] loss = 0.4296
[40 - 40.0%] loss = 0.1664
[50 - 50.0%] loss = 0.0480
[60 - 60.0%] loss = 0.0118
[70 - 70.0%] loss = 0.0089
[80 - 80.0%] loss = 0.0040
[90 - 90.0%] loss = 0.0025
[100 - 100.0%] loss = 0.0016


In [230]:
def evaluate(pairs, source_vocab, target_vocab,
             encoder, decoder, target_max_len):
  for pair in pairs:
    print(">", pair[0])
    print("TRANS", pair[1])

    source_tensor = tensorize(source_vocab, pair[0])
    source_len = source_tensor.size()[0]
    encoder_hidden = torch.zeros([1,1, encoder.hidden_size]).to(device)

    for i in range(source_len):
      _, encoder_hidden = encoder(source_tensor[i], encoder_hidden)

      decoder_input = torch.Tensor([[SOS_token]]).long().to(device)
      decoder_hidden = encoder_hidden
      
      decoded_words = []
      for di in range(target_max_len):
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
        _, top_index = decoder_output.data.topk(1)
        
        if top_index.item() == EOS_token:
          decoded_words.append("<EOS>") # EOS token 보이면 end 
          break
        else:
          decoded_words.append(target_vocab.index2vocab[top_index.item()])
        
        decoder_input = top_index.squeeze().detach()

      #prediction
      predict_words = decoded_words
      predict_sentence = "".join(predict_words)
      print("<", predict_sentence)
      print()


In [231]:
evaluate(pairs = processed_pairs,
         source_vocab = source_vocab,
         target_vocab = target_vocab,
         encoder = enc,
         decoder = dec, 
         target_max_len = target_max_len)

> i feel hungry.
TRANS 나는 배가 고프다.
< 나는배가고프다.<EOS>

< 나는배가고프다.<EOS>

< 나는배가고프다.<EOS>

< 나는배가고프다.<EOS>

> pytorch is very easy.
TRANS 파이토치는 매우 쉽다.
< 나는배가고프다.<EOS>

< 나는배가고프다.<EOS>

< 파이토치는매우쉽다.<EOS>

< 파이토치는매우쉽다.<EOS>

< 파이토치는매우쉽다.<EOS>

