In [86]:
import sys
sys.path
sys.path.insert(1, 'C:\\Users\\scj14\\BYU\\research\\image_encoder\\github\\experiment1')

from process_data import save_maps, load_maps, index2sentence, subsequent_mask
from encoderdecoder import EncoderDecoder, save_model, load_model
from dataset import TranslationDataset, AutoencoderDataset, padding_collate_fn, Batch
from image import tensor2image

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
import time


PAD = 0
SOS = 1
EOS = 2

In [27]:
base_dir = "../text_corpora/prepared/"
data_dir = "eng/"
filename = "toy_data.txt"
map_dir = '../outputs/maps/'
model_dir = '../outputs/models/'
train_name = 'toy_data_txt-2023-02-13-11-11'

In [45]:
file_path = base_dir + data_dir + filename
dataset = AutoencoderDataset(file_path, min_freq_vocab=5)
word2index, index2word = load_maps(train_name, map_dir=map_dir)
dataset.init_using_existing_maps(None, word2index, index2word)
vocab_size = len(dataset.word2index)
dataloader = DataLoader(
    dataset, 
    batch_size=32, 
    pin_memory=True, 
    collate_fn=padding_collate_fn, 
    shuffle=False
)
model = load_model(train_name, model_dir=model_dir)

Converting lines to indices...


In [218]:
def greedy_decode(model, src, src_mask, max_len=50):
    src = src.unsqueeze(0)
    src_mask = src_mask.unsqueeze(0)
    x = model.encode(src, src_mask)
    memory, image = model.extract_features(x)
    ys = torch.ones(1, 1).fill_(SOS).type_as(src.data)
    for i in range(max_len-1):
        ys = Variable(ys)
        ys_mask = Variable(subsequent_mask(ys.size(1)).type_as(src.data))
        out = model.decode(ys, memory, ys_mask)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim = 1)
        next_word = next_word.data[0]
        ys = torch.cat([ys, 
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
        if next_word.item() == EOS:
            break
    return ys[0], image

def beam_search(model, src, src_mask, k=5, max_len=50):
    src = src.unsqueeze(0)
    src_mask = src_mask.unsqueeze(0)
    x = model.encode(src, src_mask)
    memory, image = model.extract_features(x)
    ys = torch.ones(1, 1).fill_(SOS).type_as(src.data)
    top_sequences = [(ys, 0, False)]
    for i in range(50-1):
        new_sequences = []
        for seq, score, eos_bool in top_sequences:
            if eos_bool:
                new_sequences.append((seq, score, eos_bool))
                continue
            seq = Variable(seq)
            mask = Variable(subsequent_mask(seq.size(1)).type_as(src.data))
            out = model.decode(seq, memory, mask)
            dist = model.generator(out[:, -1])
            probs, words = torch.topk(dist, k)
            for prob, word in zip(probs[0], words[0]):
                new_seq = torch.cat([seq, torch.ones(1, 1).type_as(src.data).fill_(word)], dim=1)
                new_score = score + prob
                new_eos_bool = True if word.item() == EOS else False
                new_sequences.append((new_seq, new_score, new_eos_bool))
        top_sequences = sorted(new_sequences, key=lambda val: val[1], reverse=True)
        top_sequences = top_sequences[:k]
        if sum(b for _, _, b in top_sequences) == k:
            break
    return top_sequences, image

def calc_src_prob(model, src, src_mask):
    src = src.unsqueeze(0)
    src_mask = src_mask.unsqueeze(0)
    x = model.encode(src, src_mask)
    memory, image = model.extract_features(x)
    prob = 0
    print(src)
    for i in range(len(src[0])-1):
        ys = src[0, :i+1].unsqueeze(0)
        ys_mask = Variable(subsequent_mask(ys.size(1)).type_as(src.data))
        out = model.decode(ys, memory, ys_mask)
        dist = model.generator(out[:, -1])
        prob += dist[0, src[0,i+1]].item()
        print(prob)
    out = model.decode(src, memory, src_mask)
    dist = model.generator(out[:,-1])
    prob += dist[0, EOS].item()
    return prob

In [236]:
batch = Batch(dataset[4:10])

In [237]:
calc_src_prob(model, batch.src[0], batch.src_pad_mask[0])

tensor([[ 1, 15, 28, 17, 29, 30, 31, 32, 19, 33, 34, 26, 19, 35,  8]],
       device='cuda:0')
-0.09190216660499573
-0.766737312078476
-0.878957487642765
-2.959669329226017
-3.71846554428339
-9.085223890841007
-11.248544670641422
-11.287339769303799
-18.501816354691982
-20.546979032456875
-20.869659088551998
-21.58021166175604
-21.765391387045383
-21.848864771425724


-21.926146119832993

In [255]:
model.eval()
greedy, image = greedy_decode(model, batch.src[0], batch.src_pad_mask[0])
beam, _ = beam_search(model, batch.src[0], batch.src_pad_mask[0], k=20)

In [256]:
src = batch.src[0]
src_len = (src != 0).sum().item()
inp = index2sentence(src[1:src_len].tolist(), index2word)
greedy_out = index2sentence(greedy[1:-1].tolist(), index2word)

In [257]:
inp

'tom could not help but notice all the beautiful women on the beach .'

In [258]:
greedy_out

'tom could not help but cry on the other side .'

In [259]:
for seq, _, _ in beam:
    print(index2sentence(seq[0][1:-1].tolist(), index2word))

tom could not eat the food he was in .
tom could not eat the food he was in my eyes .
tom could not eat the food in the sky .
tom and mary could not feel at the back of the world .
tom could not help but cry on the way all .
tom could not eat the food he was in the bathroom .
tom and mary could not feel at the back of the town .
tom could not eat the food he was in their .
tom could not eat the food he was in the woods .
tom and mary were not even in the back of the food .
tom could not eat the food he was in the .
tom and mary could not find the same food in town .
tom and mary were not even in the back of the meeting .
tom and mary could not feel at the back of the city .
tom and mary were not even in the back of the election .
tom and mary could not find all the same food in her .
tom and mary were not even in the back of the speech .
tom could not help but cry on the way all the other .
tom and mary could not find all the same food in his bedroom .
tom and mary could not find all t

In [254]:
for _, p, _ in beam:
    print(p.item())

-8.869412422180176
-9.547994613647461
-9.809532165527344
-9.952197074890137
-11.857104301452637
