In [72]:
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

dataset_path = '../dataset_224/activitynet/m_data.pt'
batch_size = 1

SOS_token = 0  # index reserved for the start-of-sequence token
EOS_token = 1  # index reserved for the end-of-sequence token

# Device for model state and tensors: the first CUDA GPU if available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG the notebook uses so that Restart & Run All is reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# get dataset vocab size
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code — only load trusted data.
dataset = torch.load(dataset_path)

# Collect every whitespace-separated word of every caption.
# Presumably dataset[data_kind] maps video ids to dicts keyed by caption text — TODO confirm.
data_set = set()
for data_kind in dataset:
    for value in dataset[data_kind].values():
        for caption in value.keys():
            data_set.update(caption.split())

dict_size = len(data_set) + 2  # one slot per unique word, plus the SOS and EOS tokens


def one_hot(index, size=dict_size):
    """Return an int vector of length ``size`` that is 1 at ``index`` and 0 elsewhere."""
    vector = np.zeros(size, dtype=int)  # plain `int`: the `np.int` alias was removed in NumPy 1.24
    vector[index] = 1
    return vector


def one_hot_key(vector):
    """Return the full, never-summarized string form of ``vector`` used as an index_to_word key.

    With its default threshold (1000), np.array2string abbreviates longer arrays to
    "[0 0 0 ... 0 0 0]", which would make almost every one-hot key identical.
    """
    return np.array2string(vector, threshold=vector.size)


# create dictionary for tokens in dataset
sos_array = one_hot(SOS_token)  # word_to_index dictionary value
eos_array = one_hot(EOS_token)  # word_to_index dictionary value

sos_array_str = one_hot_key(sos_array)  # index_to_word dictionary key.
eos_array_str = one_hot_key(eos_array)  # index_to_word dictionary key.

word_to_index = {"SOS": sos_array, "EOS": eos_array}
index_to_word = {sos_array_str: "SOS", eos_array_str: "EOS"}

# Words take every index not reserved for SOS/EOS, so the special vectors are never reused.
# Sorting makes the word -> index mapping identical across runs; set order is not,
# because str hashing is randomized per process.
# NOTE(review): a caption word spelled exactly "SOS"/"EOS" would overwrite the special entry.
# NOTE(review): storing a dense one-hot per word costs O(dict_size**2) memory.
word_indices = [idx for idx in range(dict_size) if idx not in (SOS_token, EOS_token)]
for i, token in zip(word_indices, sorted(data_set)):
    one_hot_vector = one_hot(i)
    word_to_index[token] = one_hot_vector
    index_to_word[one_hot_key(one_hot_vector)] = token

In [66]:
class encoder_RNN(nn.Module):
    """Stacked (by default bidirectional) recurrent encoder.

    The recurrent layer is a vanilla ``nn.RNN``, not an LSTM. The attribute keeps the
    name ``BLstm`` so that state_dict keys stay unchanged.

    Args:
        input_size: number of features per input time step.
        hidden_size: hidden units per direction and per layer.
        num_layers: number of stacked recurrent layers (default 2).
        bidirectional: whether a backward direction is also run (default True).
    """

    def __init__(self, input_size, hidden_size, num_layers=2, bidirectional=True):
        # Zero-argument super(): naming a class here (EncoderRNN) that is not this one is a NameError.
        super().__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_layers = num_layers
        self.num_directions = 2 if bidirectional else 1
        self.BLstm = nn.RNN(input_size=self.input_size, hidden_size=self.hidden_size,
                            num_layers=num_layers, bidirectional=bidirectional)

    def forward(self, input, hidden):
        """Run the RNN over ``input`` from initial state ``hidden``; returns ``(output, hidden)`` from nn.RNN."""
        output, hidden = self.BLstm(input, hidden)
        return output, hidden

    def initHidden(self, device=None):
        """Return a zero initial hidden state for a batch of one.

        The shape is ``(num_layers * num_directions, 1, hidden_size)``, which is what nn.RNN
        requires for batched input. A ``(1, 1, hidden_size)`` tensor is rejected by a
        multi-layer or bidirectional RNN.
        If ``device`` is None, the device of the module's own parameters is used.
        """
        if device is None:
            device = next(self.parameters()).device
        return torch.zeros(self.num_layers * self.num_directions, 1, self.hidden_size, device=device)


# Alias under the CamelCase name used by PyTorch-tutorial-style code.
EncoderRNN = encoder_RNN
         

In [67]:
class DecoderRNN(nn.Module):
    """Stacked (by default bidirectional) recurrent decoder with a log-softmax output head.

    Each call returns log-probabilities over ``output_size`` classes, computed from the
    first time step of the RNN output, so it is meant to be driven one token at a time.
    NOTE(review): when stepped one token at a time, a bidirectional decoder's backward
    direction only ever sees that single token — consider ``bidirectional=False``.

    Args:
        input_size: number of features per input step.
        hidden_size: hidden units per direction and per layer.
        output_size: number of output classes (vocabulary size).
        num_layers: number of stacked recurrent layers (default 2).
        bidirectional: whether a backward direction is also run (default True).
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=2, bidirectional=True):
        super(DecoderRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.num_directions = 2 if bidirectional else 1
        self.BLstm = nn.RNN(input_size=self.input_size, hidden_size=self.hidden_size,
                            num_layers=num_layers, bidirectional=bidirectional)
        # The RNN concatenates both directions, so the head receives
        # num_directions * hidden_size features, not hidden_size.
        self.out = nn.Linear(hidden_size * self.num_directions, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Return ``(log_probs, hidden)``; log_probs is computed from time step 0 of the RNN output."""
        output, hidden = self.BLstm(input, hidden)
        # output[0]: first step, (batch, num_directions * hidden_size) for (seq, batch, feature) input.
        output = self.softmax(self.out(output[0]))
        return output, hidden


In [69]:
# Training the model
teacher_forcing_ratio = 0.5  # probability that a whole target sequence is decoded with teacher forcing
MAX_LENGTH = 10


def _token_to_decoder_input(token, input_size, device, dtype):
    """Return the token index ``token`` as a one-hot tensor of shape (1, 1, input_size).

    The decoder has no embedding layer, so token indices are one-hot encoded before being
    fed to its RNN. Assumes decoder.input_size equals the vocabulary size — TODO confirm.
    """
    index = torch.as_tensor(token, device=device).long().reshape(1, 1)
    return torch.nn.functional.one_hot(index, num_classes=input_size).to(dtype)


def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):
    """Run one optimisation step of ``encoder``/``decoder`` on a single (input, target) pair.

    Args:
        input_tensor: encoder input sequence; assumed (seq_len, input_size) or
            (seq_len, 1, input_size) — TODO confirm against tensorsFromPair.
        target_tensor: target token indices along dim 0; assumed (target_len, 1) long,
            because ``criterion`` receives one class index per step — TODO confirm.
        encoder, decoder: modules with the encoder_RNN / DecoderRNN call conventions. The
            encoder's final hidden state initialises the decoder, so their layer count,
            directions and hidden size must match.
        encoder_optimizer, decoder_optimizer: optimizers over the respective parameters.
        criterion: loss applied to (log-probabilities, target index) at every step.
        max_length: kept for interface compatibility; unused, because no attention buffer
            of encoder outputs is built.

    Returns:
        Mean loss per scored target step as a Python float (0.0 for an empty target).
    """
    target_length = target_tensor.size(0)
    if target_length == 0:
        # Nothing to score; otherwise backward() would be called on the plain int 0.
        return 0.0

    encoder_param = next(encoder.parameters())
    device = encoder_param.device
    decoder_dtype = next(decoder.parameters()).dtype

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    if input_tensor.dim() == 2:
        input_tensor = input_tensor.unsqueeze(1)  # add a batch dimension of 1
    input_tensor = input_tensor.to(device=device, dtype=encoder_param.dtype)

    # Encode the whole sequence in one call. Feeding one step at a time would stop a
    # bidirectional encoder's backward direction from ever seeing the sequence reversed.
    encoder_hidden = encoder.initHidden().to(device)
    _, encoder_hidden = encoder(input_tensor, encoder_hidden)

    decoder_hidden = encoder_hidden
    target_tensor = target_tensor.to(device)
    use_teacher_forcing = random.random() < teacher_forcing_ratio

    decoder_input = _token_to_decoder_input(SOS_token, decoder.input_size, device, decoder_dtype)
    loss = 0
    scored_steps = 0
    for di in range(target_length):
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
        loss = loss + criterion(decoder_output, target_tensor[di])
        scored_steps += 1
        if use_teacher_forcing:
            next_token = target_tensor[di]  # teacher forcing: feed the ground-truth token
        else:
            _, topi = decoder_output.topk(1)
            next_token = topi.squeeze().detach()  # feed the own prediction, detached from history
            if next_token.item() == EOS_token:
                break
        decoder_input = _token_to_decoder_input(next_token, decoder.input_size, device, decoder_dtype)

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    # Average over the steps actually scored: free-running decoding may stop early at EOS.
    return loss.item() / scored_steps

In [None]:
import time
import math


def asMinutes(s):
    """Render a duration ``s`` (in seconds) as "<minutes>m <seconds>s".

    Minutes are floored. The leftover seconds are truncated toward zero by %d formatting.
    """
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - 60 * whole_minutes
    return '%dm %ds' % (whole_minutes, leftover_seconds)


def timeSince(since, percent):
    """Return "<elapsed> (- <estimated remaining>)" for a job that started at ``since``.

    Args:
        since: start timestamp as returned by time.time().
        percent: fraction of the work completed so far, expected in (0, 1].

    Returns:
        Elapsed time and the linearly projected remaining time, each formatted by
        asMinutes. If ``percent`` is not positive, no projection is possible, so the
        remaining part is shown as "?" instead of raising ZeroDivisionError.
    """
    elapsed = time.time() - since
    if percent <= 0:
        return '%s (- ?)' % asMinutes(elapsed)
    remaining = elapsed / percent - elapsed
    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))

In [None]:
def trainIters(encoder, decoder, n_iters, print_every=1000, plot_every=100, learning_rate=0.01):
    """Train ``encoder``/``decoder`` with SGD for ``n_iters`` single-pair steps.

    Each iteration draws one pair (with replacement) from the global ``pairs``, converts
    it with ``tensorsFromPair`` and runs one ``train`` step with NLLLoss. The running
    average loss is printed every ``print_every`` iterations and recorded every
    ``plot_every`` iterations. The recorded losses are then plotted with ``showPlot``.

    Returns:
        The list of recorded ``plot_every``-iteration average losses.
    """
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
    criterion = nn.NLLLoss()

    # `iteration` rather than `iter`, so the builtin is not shadowed.
    for iteration in range(1, n_iters + 1):
        # Sample and convert lazily; pre-building all n_iters tensor pairs would keep
        # every one of them in memory for the whole run.
        training_pair = tensorsFromPair(random.choice(pairs))
        input_tensor = training_pair[0]
        target_tensor = training_pair[1]

        loss = train(input_tensor, target_tensor, encoder,
                     decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        if iteration % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, iteration / n_iters),
                                         iteration, iteration / n_iters * 100, print_loss_avg))

        if iteration % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

    showPlot(plot_losses)
    return plot_losses
