In [1]:
from __future__ import unicode_literals, print_function, division
from io import open
from os import system
from dataloader import *
from VAE import *
from scores import *

import unicodedata
import string
import re
import random
import time
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.ticker as ticker
import matplotlib.pyplot as plt
plt.switch_backend('agg')

In [2]:
# Select the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Bare expression so the notebook displays which device was picked
device

device(type='cuda')

# Prepare data

In [3]:
# load_data is not defined in this notebook; presumably it comes from the
# star-import of `dataloader` -- TODO confirm the structure it returns.
# Later cells index each entry of train_vocab by position (0..3).
train_vocab = load_data('./data/train.txt')
# NOTE(review): test_vocab is not referenced by any cell visible in this notebook
test_vocab = load_data('./data/test.txt')

## Get different tense pairs for (unconditional) VAE training

In [4]:
def get_tense_paris(train_vocab, source_index, target_index):
    """Build (source, target) tuples by picking two positions from every entry.

    Args:
        train_vocab: iterable of indexable entries (presumably one sequence of
            word forms per verb -- confirm against load_data).
        source_index: position of the source form inside each entry.
        target_index: position of the target form inside each entry.

    Returns:
        A list of 2-tuples ``(entry[source_index], entry[target_index])``,
        in the same order as ``train_vocab``.
    """
    return [(entry[source_index], entry[target_index]) for entry in train_vocab]

### Simple Present -> Third Person

In [5]:
# Pair element 0 of every train_vocab entry with element 1
train_st_tp  = get_tense_paris(train_vocab, 0, 1)

### Simple Present -> Present Progressive

In [6]:
# Pair element 0 of every train_vocab entry with element 2
train_st_pp  = get_tense_paris(train_vocab, 0, 2)

### Simple Present -> Past

In [7]:
# Pair element 0 of every train_vocab entry with element 3
train_st_past  = get_tense_paris(train_vocab, 0, 3)

# Train VAE

In [8]:
# Number of distinct token ids: 0 is SOS, 1-26 are the letters 'a'-'z'
# (the mapping used by seqFromPair), and 27 is EOS.
vocab_size = 28
SOS_token = 0
# EOS takes the last id in the vocabulary
EOS_token = vocab_size-1

## Setting hyperparameters

In [9]:
#----------Hyper Parameters----------#
# Width of the zero-initialised hidden state built in train()/val(); also passed to VAE(...)
hidden_size = 256
# NOTE(review): not referenced by any cell visible in this notebook
latent_size = 32
# Per-iteration probability that trainIter enables teacher forcing; also passed to VAE(...)
teacher_forcing_ratio = 0.37
# NOTE(review): not referenced by any cell visible in this notebook
empty_input_ratio = 0.1
# NOTE(review): not referenced by any cell visible in this notebook; presumably
# meant to weight the KL term of the loss -- confirm where (or whether) it is consumed
KLD_weight = 0.0
# Learning rate handed to the SGD optimizer
lr = 0.05

In [10]:
def seqFromPair(pair):
    """Convert a (source word, target word) pair into two lists of token ids.

    Each character ``c`` becomes ``ord(c) - ord('a') + 1``, so 'a' -> 1 and
    'z' -> 26. No SOS/EOS markers are added here.

    Args:
        pair: indexable whose items 0 and 1 are the source and target strings.

    Returns:
        Tuple ``(input_seq, target_seq)`` of two freshly built lists of ints.
    """
    # Subtracting (ord('a') - 1) is the same as "- ord('a') + 1"
    base = ord('a') - 1

    def encode(word):
        return [ord(ch) - base for ch in word]

    source_word, target_word = pair[0], pair[1]
    return encode(source_word), encode(target_word)

In [11]:
def train(vae_model, input_seq, target_seq, use_teacher_forcing, optimizer, criterion):
    """Run one optimisation step on a single (input, target) sequence pair.

    Args:
        vae_model: callable invoked as
            ``vae_model(input_seq, hidden, use_teacher_forcing, target_or_None)``
            and returning ``(result, mu, logvar)``.
        input_seq: token ids of the source word.
        target_seq: token ids of the target word, without EOS. This list is
            NOT modified (the EOS-terminated ground truth is a private copy).
        use_teacher_forcing: when true, ``target_seq`` is handed to the model;
            otherwise the model receives ``None`` as target.
        optimizer: optimizer whose ``zero_grad``/``step`` are called.
        criterion: callable ``criterion(hat_y, y, mu, logvar)`` returning the loss.

    Returns:
        Tuple ``(loss_value, hat_y)``: the loss as a Python number and the
        model outputs truncated to the compared length.
    """
    optimizer.zero_grad()

    # Check device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize hidden feature
    hidden = torch.zeros(1, 1, hidden_size, device=device)

    # Run model; the target is only exposed to the model under teacher forcing
    forced_target = target_seq if use_teacher_forcing else None
    result, mu, logvar = vae_model(input_seq, hidden, use_teacher_forcing, forced_target)

    # Ground truth should have EOS in the end. Build a new list instead of
    # appending in place, so a caller that reuses target_seq does not
    # accumulate one extra EOS per call.
    target_with_eos = list(target_seq) + [EOS_token]

    # Calculate loss
    # First, trim both sequences to the length of the shorter one
    min_len = min(len(target_with_eos), len(result))

    # hat_y does not need one-hot encoding
    hat_y = result[:min_len]
    y = torch.tensor(target_with_eos[:min_len], device=device)

    loss = criterion(hat_y, y, mu, logvar)

    loss.backward()
    optimizer.step()

    return loss.item(), hat_y

In [33]:
def trainIter(vae_model, data_pairs, n_iters, print_every=1000, save_every=1000, learning_rate=0.01, teacher_forcing_ratio = 1.0,\
         optimizer = None, criterion = VAE_Loss, date = ''):
    """Train ``vae_model`` for ``n_iters`` single-sample steps.

    Every iteration draws one pair at random from ``data_pairs``, decides at
    random whether to use teacher forcing, and delegates the update to
    ``train``.

    Args:
        vae_model: model to optimise; switched to training mode first.
        data_pairs: sequence of (source word, target word) string pairs.
        n_iters: number of training iterations.
        print_every: print loss and a decoded prediction every this many iterations.
        save_every: save the whole model object every this many iterations.
        learning_rate: only used to build the default SGD optimizer; ignored
            when ``optimizer`` is supplied.
        teacher_forcing_ratio: probability of using teacher forcing in an iteration.
        optimizer: optimizer to use; defaults to SGD over the model parameters.
        criterion: loss callable forwarded to ``train``.
        date: suffix appended to the checkpoint file name.

    Returns:
        List with the loss value of every iteration, in order.
    """
    # Fall back to plain SGD when the caller supplies no optimizer
    if optimizer is None:
        optimizer = optim.SGD(vae_model.parameters(), lr=learning_rate)

    losses = []

    vae_model.train()
    for step in range(1, n_iters + 1):
        # Draw the sample first, then the teacher-forcing coin flip
        # (keeps the order in which the random generator is consumed)
        chosen_pair = random.choice(data_pairs)
        input_seq, target_seq = seqFromPair(chosen_pair)
        use_teacher_forcing = random.random() < teacher_forcing_ratio

        loss, hat_y = train(vae_model, input_seq, target_seq, use_teacher_forcing, optimizer, criterion)
        losses.append(loss)

        if step % print_every == 0:
            print('-----------------')
            print('Iter %d: loss = %.4f' % (step, loss))

            # Arg-max token of every output step, mapped back to a character
            pred_seq = ''.join(chr(output.topk(1)[1] + ord('a') - 1) for output in hat_y)
            print('input_seq = ', chosen_pair[0])
            print('pred_seq = ', pred_seq)
            print('target_seq = ', chosen_pair[1])

        if step % save_every == 0:
            # Checkpoint: pickles the full model object, not just its state_dict
            torch.save(vae_model, './models/vae_' + str(step) + date)

    return losses

In [21]:
# VAE is not defined in this notebook (presumably from the star-import of `VAE`).
# NOTE(review): argument meanings look like (input_size, hidden_size, output_size,
# teacher_forcing_ratio) -- confirm against the VAE class.
my_vae = VAE(vocab_size, hidden_size, vocab_size, teacher_forcing_ratio).to(device)

In [22]:
# Plain SGD over all parameters of my_vae, using the learning rate from the hyperparameter cell
optimizer = optim.SGD(my_vae.parameters(), lr=lr)

## Train with Simple Present -> Present Progressive

In [34]:
# 1,000,000 single-sample iterations on the train_st_pp pairs; progress is printed
# every 10,000 iterations and the model is saved every 100,000 iterations to
# ./models/vae_<iteration>_0812.
# Because an optimizer is passed explicitly, trainIter does not use learning_rate.
loss_list = trainIter(my_vae, train_st_pp, n_iters=1000000, print_every=10000, save_every=100000,\
                      learning_rate=lr,teacher_forcing_ratio=teacher_forcing_ratio, \
                      optimizer= optimizer, criterion = VAE_Loss,date = '_0812')

-----------------
Iter 10000: loss = 0.9617
input_seq =  hurl
pred_seq =  sarling
target_seq =  hurling
-----------------
Iter 20000: loss = 1.1375
input_seq =  save
pred_seq =  eeving
target_seq =  saving
-----------------
Iter 30000: loss = 0.8964
input_seq =  bestir
pred_seq =  aaltrrring
target_seq =  bestirring
-----------------
Iter 40000: loss = 1.1162
input_seq =  grope
pred_seq =  arawing
target_seq =  groping
-----------------
Iter 50000: loss = 1.0625
input_seq =  babble
pred_seq =  clybling
target_seq =  babbling
-----------------
Iter 60000: loss = 0.9414
input_seq =  gush
pred_seq =  colhing
target_seq =  gushing
-----------------
Iter 70000: loss = 0.6900
input_seq =  develop
pred_seq =  sevoloping
target_seq =  developing
-----------------
Iter 80000: loss = 0.7156
input_seq =  contend
pred_seq =  shnsrnding
target_seq =  contending
-----------------
Iter 90000: loss = 0.8301
input_seq =  enact
pred_seq =  sxabting
target_seq =  enacting
-----------------
Iter 100000: l

-----------------
Iter 760000: loss = 0.8684
input_seq =  trouble
pred_seq =  saaubling
target_seq =  troubling
-----------------
Iter 770000: loss = 1.0749
input_seq =  mop
pred_seq =  aauping
target_seq =  mopping
-----------------
Iter 780000: loss = 1.3057
input_seq =  fly
pred_seq =  alaing
target_seq =  flying
-----------------
Iter 790000: loss = 0.7247
input_seq =  identify
pred_seq =  pnentifying
target_seq =  identifying
-----------------
Iter 800000: loss = 1.2483
input_seq =  come
pred_seq =  sonpng
target_seq =  coming
-----------------
Iter 810000: loss = 0.7909
input_seq =  attribute
pred_seq =  sptaabuting
target_seq =  attributing
-----------------
Iter 820000: loss = 0.7379
input_seq =  originate
pred_seq =  pviginating
target_seq =  originating
-----------------
Iter 830000: loss = 0.7895
input_seq =  depend
pred_seq =  cellnding
target_seq =  depending
-----------------
Iter 840000: loss = 1.1720
input_seq =  ask
pred_seq =  sssing
target_seq =  asking
-------------

In [28]:
# Loss of the last training iteration only (one randomly drawn pair, not an average)
loss_list[-1]

1.221312165260315

# Evaluation

In [29]:
def val(vae_model, data_pairs, criterion = VAE_Loss):
    """Evaluate ``vae_model`` on every pair, print each prediction, return the losses.

    The model is put in eval mode and run under ``torch.no_grad()`` without
    a target sequence. For every pair the loss, the input word, the decoded
    arg-max prediction and the (truncated) target word are printed.

    Args:
        vae_model: callable invoked as ``vae_model(input_seq, hidden)`` and
            returning ``(result, mu, logvar)``.
        data_pairs: iterable of (source word, target word) string pairs.
        criterion: callable ``criterion(hat_y, y, mu, logvar)`` returning the loss.

    Returns:
        List with one loss value (Python number) per pair, in input order.
        In a notebook, assign the result or end the call with ``;`` to keep
        the list from being echoed as cell output.
    """
    loss_list = []

    # The device does not change between samples, so look it up once
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vae_model.eval()

    with torch.no_grad():
        for data_pair in data_pairs:
            # Separate pair for input
            input_seq, target_seq = seqFromPair(data_pair)

            # Initialize hidden feature
            hidden = torch.zeros(1, 1, hidden_size, device=device)

            result, mu, logvar = vae_model(input_seq, hidden)

            # Ground truth should have EOS in the end
            # (target_seq is a fresh list from seqFromPair, so appending is safe)
            target_seq.append(EOS_token)

            # Calculate loss
            # First, trim both sequences to the length of the shorter one
            min_len = min(len(target_seq), len(result))

            # hat_y does not need one-hot encoding
            hat_y = result[:min_len]
            y = torch.tensor(target_seq[:min_len], device=device)

            loss = criterion(hat_y, y, mu, logvar)
            # Record the value so callers can aggregate (mean, max, ...)
            loss_list.append(loss.item())

            pred_seq = ''
            for output in hat_y:
                _, c = output.topk(1)
                pred_seq += chr(c+ord('a')-1)
            print('-----------------')
            print('loss = ', loss)
            print('input_seq = ', data_pair[0])
            print('pred_seq = ', pred_seq)
            print('target_seq = ', data_pair[1][:min_len])

    return loss_list

In [30]:
# NOTE(review): this evaluates on the training pairs (train_st_pp), not on held-out data
val(my_vae, train_st_pp, criterion = VAE_Loss)

-----------------
loss =  tensor(8.3039, device='cuda:0')
input_seq =  abandon
pred_seq =  shapingingi
target_seq =  abandoning
-----------------
loss =  tensor(11.2504, device='cuda:0')
input_seq =  abet
pred_seq =  shapingin
target_seq =  abetting
-----------------
loss =  tensor(9.3905, device='cuda:0')
input_seq =  abdicate
pred_seq =  shapingingi
target_seq =  abdicating
-----------------
loss =  tensor(13.2642, device='cuda:0')
input_seq =  abduct
pred_seq =  shapinging
target_seq =  abducting
-----------------
loss =  tensor(12.0394, device='cuda:0')
input_seq =  abound
pred_seq =  shapinging
target_seq =  abounding
-----------------
loss =  tensor(14.8368, device='cuda:0')
input_seq =  absorb
pred_seq =  shapinging
target_seq =  absorbing
-----------------
loss =  tensor(11.5309, device='cuda:0')
input_seq =  accept
pred_seq =  shapinging
target_seq =  accepting
-----------------
loss =  tensor(13.1654, device='cuda:0')
input_seq =  accompany
pred_seq =  shapinginging
target_se

-----------------
loss =  tensor(9.0409, device='cuda:0')
input_seq =  attain
pred_seq =  shapinging
target_seq =  attaining
-----------------
loss =  tensor(8.3896, device='cuda:0')
input_seq =  attempt
pred_seq =  shapingingi
target_seq =  attempting
-----------------
loss =  tensor(12.2223, device='cuda:0')
input_seq =  attend
pred_seq =  shapinging
target_seq =  attending
-----------------
loss =  tensor(11.9339, device='cuda:0')
input_seq =  attest
pred_seq =  shapinging
target_seq =  attesting
-----------------
loss =  tensor(7.4788, device='cuda:0')
input_seq =  attract
pred_seq =  shapingingi
target_seq =  attracting
-----------------
loss =  tensor(11.6928, device='cuda:0')
input_seq =  attribute
pred_seq =  shapingingin
target_seq =  attributing
-----------------
loss =  tensor(6.8000, device='cuda:0')
input_seq =  augment
pred_seq =  shapingingi
target_seq =  augmenting
-----------------
loss =  tensor(12.3705, device='cuda:0')
input_seq =  authorize
pred_seq =  shapingingin

-----------------
loss =  tensor(10.4105, device='cuda:0')
input_seq =  chance
pred_seq =  shapingin
target_seq =  chancing
-----------------
loss =  tensor(10.2478, device='cuda:0')
input_seq =  change
pred_seq =  shapingin
target_seq =  changing
-----------------
loss =  tensor(9.6523, device='cuda:0')
input_seq =  chant
pred_seq =  shapingin
target_seq =  chanting
-----------------
loss =  tensor(12.4399, device='cuda:0')
input_seq =  characterize
pred_seq =  shapingingingin
target_seq =  characterizing
-----------------
loss =  tensor(10.0502, device='cuda:0')
input_seq =  charge
pred_seq =  shapingin
target_seq =  charging
-----------------
loss =  tensor(9.0069, device='cuda:0')
input_seq =  chat
pred_seq =  shapingin
target_seq =  chatting
-----------------
loss =  tensor(10.6897, device='cuda:0')
input_seq =  check
pred_seq =  shapingin
target_seq =  checking
-----------------
loss =  tensor(5.6045, device='cuda:0')
input_seq =  cherish
pred_seq =  shapingingi
target_seq =  che

-----------------
loss =  tensor(7.3246, device='cuda:0')
input_seq =  correspond
pred_seq =  shapingingingi
target_seq =  corresponding
-----------------
loss =  tensor(12.0878, device='cuda:0')
input_seq =  accost
pred_seq =  shapinging
target_seq =  accosting
-----------------
loss =  tensor(10.0817, device='cuda:0')
input_seq =  cough
pred_seq =  shapingin
target_seq =  coughing
-----------------
loss =  tensor(11.3687, device='cuda:0')
input_seq =  counsel
pred_seq =  shapingingin
target_seq =  counselling
-----------------
loss =  tensor(7.5905, device='cuda:0')
input_seq =  counter
pred_seq =  shapingingi
target_seq =  countering
-----------------
loss =  tensor(10.0998, device='cuda:0')
input_seq =  couple
pred_seq =  shapingin
target_seq =  coupling
-----------------
loss =  tensor(12.1623, device='cuda:0')
input_seq =  cover
pred_seq =  shapingin
target_seq =  covering
-----------------
loss =  tensor(10.7851, device='cuda:0')
input_seq =  crack
pred_seq =  shapingin
target_s

loss =  tensor(13.9184, device='cuda:0')
input_seq =  discount
pred_seq =  shapingingin
target_seq =  discounting
-----------------
loss =  tensor(13.8162, device='cuda:0')
input_seq =  discover
pred_seq =  shapingingin
target_seq =  discovering
-----------------
loss =  tensor(8.2346, device='cuda:0')
input_seq =  discuss
pred_seq =  shapingingi
target_seq =  discussing
-----------------
loss =  tensor(12.7829, device='cuda:0')
input_seq =  disfigure
pred_seq =  shapingingin
target_seq =  disfiguring
-----------------
loss =  tensor(8.8947, device='cuda:0')
input_seq =  disguise
pred_seq =  shapingingi
target_seq =  disguising
-----------------
loss =  tensor(11.7641, device='cuda:0')
input_seq =  dislike
pred_seq =  shapinging
target_seq =  disliking
-----------------
loss =  tensor(12.9727, device='cuda:0')
input_seq =  dismember
pred_seq =  shapinginging
target_seq =  dismembering
-----------------
loss =  tensor(7.1707, device='cuda:0')
input_seq =  dismiss
pred_seq =  shapingingi

-----------------
loss =  tensor(3.7964, device='cuda:0')
input_seq =  evoke
pred_seq =  shapingi
target_seq =  evoking
-----------------
loss =  tensor(12.6680, device='cuda:0')
input_seq =  devolve
pred_seq =  shapinging
target_seq =  devolving
-----------------
loss =  tensor(13.2432, device='cuda:0')
input_seq =  exceed
pred_seq =  shapinging
target_seq =  exceeding
-----------------
loss =  tensor(6.0577, device='cuda:0')
input_seq =  exchange
pred_seq =  shapingingi
target_seq =  exchanging
-----------------
loss =  tensor(12.3334, device='cuda:0')
input_seq =  excite
pred_seq =  shapingin
target_seq =  exciting
-----------------
loss =  tensor(9.7970, device='cuda:0')
input_seq =  exclaim
pred_seq =  shapingingi
target_seq =  exclaiming
-----------------
loss =  tensor(12.5255, device='cuda:0')
input_seq =  excuse
pred_seq =  shapingin
target_seq =  excusing
-----------------
loss =  tensor(8.9611, device='cuda:0')
input_seq =  exercise
pred_seq =  shapingingi
target_seq =  exer

-----------------
loss =  tensor(11.0849, device='cuda:0')
input_seq =  gather
pred_seq =  shapinging
target_seq =  gathering
-----------------
loss =  tensor(12.1614, device='cuda:0')
input_seq =  gaze
pred_seq =  shaping
target_seq =  gazing
-----------------
loss =  tensor(13.8708, device='cuda:0')
input_seq =  degenerate
pred_seq =  shapinginging
target_seq =  degenerating
-----------------
loss =  tensor(11.5659, device='cuda:0')
input_seq =  gesture
pred_seq =  shapinging
target_seq =  gesturing
-----------------
loss =  tensor(13.0815, device='cuda:0')
input_seq =  giggle
pred_seq =  shapingin
target_seq =  giggling
-----------------
loss =  tensor(10.6527, device='cuda:0')
input_seq =  glance
pred_seq =  shapingin
target_seq =  glancing
-----------------
loss =  tensor(3.2746, device='cuda:0')
input_seq =  glare
pred_seq =  shapingi
target_seq =  glaring
-----------------
loss =  tensor(12.0940, device='cuda:0')
input_seq =  gleam
pred_seq =  shapingin
target_seq =  gleaming
--

-----------------
loss =  tensor(10.6568, device='cuda:0')
input_seq =  disinherit
pred_seq =  shapingingingi
target_seq =  disinheriting
-----------------
loss =  tensor(8.7397, device='cuda:0')
input_seq =  inhibit
pred_seq =  shapingingi
target_seq =  inhibiting
-----------------
loss =  tensor(5.1635, device='cuda:0')
input_seq =  initiate
pred_seq =  shapingingi
target_seq =  initiating
-----------------
loss =  tensor(11.8184, device='cuda:0')
input_seq =  injure
pred_seq =  shapingin
target_seq =  injuring
-----------------
loss =  tensor(10.4761, device='cuda:0')
input_seq =  inquire
pred_seq =  shapinging
target_seq =  inquiring
-----------------
loss =  tensor(13.1964, device='cuda:0')
input_seq =  insert
pred_seq =  shapinging
target_seq =  inserting
-----------------
loss =  tensor(12.5295, device='cuda:0')
input_seq =  insist
pred_seq =  shapinging
target_seq =  insisting
-----------------
loss =  tensor(10.1874, device='cuda:0')
input_seq =  inspire
pred_seq =  shapinging

loss =  tensor(8.0976, device='cuda:0')
input_seq =  mislead
pred_seq =  shapingingi
target_seq =  misleading
-----------------
loss =  tensor(13.6849, device='cuda:0')
input_seq =  mistrust
pred_seq =  shapingingin
target_seq =  mistrusting
-----------------
loss =  tensor(11.4750, device='cuda:0')
input_seq =  bemoan
pred_seq =  shapinging
target_seq =  bemoaning
-----------------
loss =  tensor(4.4610, device='cuda:0')
input_seq =  mop
pred_seq =  shapingi
target_seq =  mopping
-----------------
loss =  tensor(10.1967, device='cuda:0')
input_seq =  mourn
pred_seq =  shapingin
target_seq =  mourning
-----------------
loss =  tensor(10.7688, device='cuda:0')
input_seq =  move
pred_seq =  shaping
target_seq =  moving
-----------------
loss =  tensor(11.5402, device='cuda:0')
input_seq =  multiply
pred_seq =  shapingingin
target_seq =  multiplying
-----------------
loss =  tensor(12.6558, device='cuda:0')
input_seq =  mumble
pred_seq =  shapingin
target_seq =  mumbling
-----------------

-----------------
loss =  tensor(8.8365, device='cuda:0')
input_seq =  prefer
pred_seq =  shapingingi
target_seq =  preferring
-----------------
loss =  tensor(11.3752, device='cuda:0')
input_seq =  prepare
pred_seq =  shapinging
target_seq =  preparing
-----------------
loss =  tensor(14.1634, device='cuda:0')
input_seq =  prescribe
pred_seq =  shapingingin
target_seq =  prescribing
-----------------
loss =  tensor(7.7015, device='cuda:0')
input_seq =  preserve
pred_seq =  shapingingi
target_seq =  preserving
-----------------
loss =  tensor(12.9737, device='cuda:0')
input_seq =  compress
pred_seq =  shapingingin
target_seq =  compressing
-----------------
loss =  tensor(12.8612, device='cuda:0')
input_seq =  presume
pred_seq =  shapinging
target_seq =  presuming
-----------------
loss =  tensor(5.9471, device='cuda:0')
input_seq =  pretend
pred_seq =  shapingingi
target_seq =  pretending
-----------------
loss =  tensor(8.4808, device='cuda:0')
input_seq =  prevail
pred_seq =  shapin

-----------------
loss =  tensor(5.1476, device='cuda:0')
input_seq =  rely
pred_seq =  shapingi
target_seq =  relying
-----------------
loss =  tensor(10.1878, device='cuda:0')
input_seq =  remain
pred_seq =  shapinging
target_seq =  remaining
-----------------
loss =  tensor(13.4707, device='cuda:0')
input_seq =  remark
pred_seq =  shapinging
target_seq =  remarking
-----------------
loss =  tensor(14.9886, device='cuda:0')
input_seq =  remember
pred_seq =  shapingingin
target_seq =  remembering
-----------------
loss =  tensor(12.7275, device='cuda:0')
input_seq =  remind
pred_seq =  shapinging
target_seq =  reminding
-----------------
loss =  tensor(9.8812, device='cuda:0')
input_seq =  remonstrate
pred_seq =  shapingingingi
target_seq =  remonstrating
-----------------
loss =  tensor(12.8188, device='cuda:0')
input_seq =  remove
pred_seq =  shapingin
target_seq =  removing
-----------------
loss =  tensor(11.7360, device='cuda:0')
input_seq =  render
pred_seq =  shapinging
target_

-----------------
loss =  tensor(8.0065, device='cuda:0')
input_seq =  conserve
pred_seq =  shapingingi
target_seq =  conserving
-----------------
loss =  tensor(8.9792, device='cuda:0')
input_seq =  resettle
pred_seq =  shapingingi
target_seq =  resettling
-----------------
loss =  tensor(2.3513, device='cuda:0')
input_seq =  shape
pred_seq =  shapingi
target_seq =  shaping
-----------------
loss =  tensor(6.6719, device='cuda:0')
input_seq =  shatter
pred_seq =  shapingingi
target_seq =  shattering
-----------------
loss =  tensor(2.4472, device='cuda:0')
input_seq =  shave
pred_seq =  shapingi
target_seq =  shaving
-----------------
loss =  tensor(9.2344, device='cuda:0')
input_seq =  shiver
pred_seq =  shapinging
target_seq =  shivering
-----------------
loss =  tensor(9.8454, device='cuda:0')
input_seq =  shock
pred_seq =  shapingin
target_seq =  shocking
-----------------
loss =  tensor(9.7068, device='cuda:0')
input_seq =  shout
pred_seq =  shapingin
target_seq =  shouting
-----

-----------------
loss =  tensor(8.7264, device='cuda:0')
input_seq =  straighten
pred_seq =  shapingingingi
target_seq =  straightening
-----------------
loss =  tensor(11.9136, device='cuda:0')
input_seq =  streak
pred_seq =  shapinging
target_seq =  streaking
-----------------
loss =  tensor(12.9873, device='cuda:0')
input_seq =  distress
pred_seq =  shapingingin
target_seq =  distressing
-----------------
loss =  tensor(7.4490, device='cuda:0')
input_seq =  stretch
pred_seq =  shapingingi
target_seq =  stretching
-----------------
loss =  tensor(11.3550, device='cuda:0')
input_seq =  stroll
pred_seq =  shapinging
target_seq =  strolling
-----------------
loss =  tensor(7.5968, device='cuda:0')
input_seq =  struggle
pred_seq =  shapingingi
target_seq =  struggling
-----------------
loss =  tensor(10.0838, device='cuda:0')
input_seq =  strut
pred_seq =  shapinging
target_seq =  strutting
-----------------
loss =  tensor(10.3387, device='cuda:0')
input_seq =  study
pred_seq =  shaping

-----------------
loss =  tensor(8.4998, device='cuda:0')
input_seq =  unscrew
pred_seq =  shapingingi
target_seq =  unscrewing
-----------------
loss =  tensor(3.7558, device='cuda:0')
input_seq =  purge
pred_seq =  shapingi
target_seq =  purging
-----------------
loss =  tensor(5.4924, device='cuda:0')
input_seq =  abuse
pred_seq =  shapingi
target_seq =  abusing
-----------------
loss =  tensor(9.7193, device='cuda:0')
input_seq =  butter
pred_seq =  shapinging
target_seq =  buttering
-----------------
loss =  tensor(13.3048, device='cuda:0')
input_seq =  vanish
pred_seq =  shapinging
target_seq =  vanishing
-----------------
loss =  tensor(4.5177, device='cuda:0')
input_seq =  veer
pred_seq =  shapingi
target_seq =  veering
-----------------
loss =  tensor(11.0655, device='cuda:0')
input_seq =  violate
pred_seq =  shapinging
target_seq =  violating
-----------------
loss =  tensor(8.0759, device='cuda:0')
input_seq =  revisit
pred_seq =  shapingingi
target_seq =  revisiting
-------

-----------------
loss =  tensor(9.4867, device='cuda:0')
input_seq =  overcome
pred_seq =  shapingingi
target_seq =  overcoming
-----------------
loss =  tensor(12.5255, device='cuda:0')
input_seq =  overhear
pred_seq =  shapingingin
target_seq =  overhearing
-----------------
loss =  tensor(7.3680, device='cuda:0')
input_seq =  overtake
pred_seq =  shapingingi
target_seq =  overtaking
-----------------
loss =  tensor(7.2430, device='cuda:0')
input_seq =  overpay
pred_seq =  shapingingi
target_seq =  overpaying
-----------------
loss =  tensor(6.5705, device='cuda:0')
input_seq =  bespeak
pred_seq =  shapingingi
target_seq =  bespeaking
-----------------
loss =  tensor(5.9373, device='cuda:0')
input_seq =  outrun
pred_seq =  shapingingi
target_seq =  outrunning
-----------------
loss =  tensor(3.4784, device='cuda:0')
input_seq =  seek
pred_seq =  shapingi
target_seq =  seeking
-----------------
loss =  tensor(7.4434, device='cuda:0')
input_seq =  outsell
pred_seq =  shapingingi
targe

-----------------
loss =  tensor(10.9731, device='cuda:0')
input_seq =  sleep
pred_seq =  shapingin
target_seq =  sleeping
-----------------
loss =  tensor(11.2693, device='cuda:0')
input_seq =  speak
pred_seq =  shapingin
target_seq =  speaking
-----------------
loss =  tensor(11.2277, device='cuda:0')
input_seq =  spend
pred_seq =  shapingin
target_seq =  spending
-----------------
loss =  tensor(10.2314, device='cuda:0')
input_seq =  stand
pred_seq =  shapingin
target_seq =  standing
-----------------
loss =  tensor(11.4621, device='cuda:0')
input_seq =  take
pred_seq =  shaping
target_seq =  taking
-----------------
loss =  tensor(5.0877, device='cuda:0')
input_seq =  tell
pred_seq =  shapingi
target_seq =  telling
-----------------
loss =  tensor(10.1992, device='cuda:0')
input_seq =  think
pred_seq =  shapingin
target_seq =  thinking
-----------------
loss =  tensor(11.5922, device='cuda:0')
input_seq =  throw
pred_seq =  shapingin
target_seq =  throwing
-----------------
loss = 

In [67]:
# Scratch experiment: check tensor shapes when a condition embedding is involved.
# Embedding table with 28 rows of width 256+8 = 264
a = nn.Embedding(28,256+8)
# Embedding table with 4 rows of width 8
condition_embedding = nn.Embedding(4, 8)
# Index 1 is looked up in both tables
input_t = torch.tensor([[1]])
input_cond = torch.tensor([[1]])
# Reshape both lookups to (1, 1, embedding width)
output = a(input_t).view(1, 1, -1)
c = condition_embedding(input_cond).view(1, 1, -1)

In [68]:
#output

In [69]:
# Re-binds the module-level hidden_size (same value as the hyperparameter cell)
hidden_size = 256
# Zero hidden state of shape (1, 1, 256), created on the default device
hidden = torch.zeros(1, 1, hidden_size)

In [72]:
# Concatenate along the last axis: (1, 1, 256) with (1, 1, 8) gives (1, 1, 264)
torch.cat((hidden, c),2)

tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.00

In [62]:
# Shape of the reshaped condition embedding
c.size()

torch.Size([1, 1, 8])