In [2]:
import torch
import torch.nn as nn
from torch import optim

import pickle

from model.BigModel import SubToSeq
from utils.tokenMaker import Lang
from utils.tool import padding, flatMutileLength, Timer, Average
from dataset.readVideo import DramaDataset
# Manual switch: set to False to force CPU execution.
useCuda = True
# Only pick the GPU when it is both requested and actually present, so the
# notebook still runs on CPU-only machines instead of failing on `.to(device)`.
device = torch.device("cuda" if useCuda and torch.cuda.is_available() else "cpu")

In [23]:
import torch.utils.data
# Root folder of the drama clips (machine-specific absolute path).
DataDir = "/home/ball/Videos/BrokeEN"

# Build the dataset over DataDir with the same options as before.
datasets = DramaDataset(basedir=DataDir, maxFrame=0, timeOffset=0.2, useBmp=True)

# Shuffled mini-batches of 5 samples, fetched by 2 worker processes.
loader = torch.utils.data.DataLoader(
    dataset=datasets,
    batch_size=5,
    shuffle=True,
    num_workers=2,
)

Total Drama: 113


In [24]:
# Directory holding the trained checkpoint and its pickled vocabulary.
ModalFile = "SubToSub/model/bken-large/"

# map_location remaps the stored tensors onto the device chosen for this
# session, so a checkpoint saved on another device can still be loaded.
modal = torch.load(ModalFile+"SubSubModel.10.pth", map_location=device)

# NOTE(security): both torch.load and pickle.load unpickle arbitrary objects;
# only point these at files you trust.
with open(ModalFile+"Lang.pkl", 'rb') as f:
    lang = pickle.load(f)

# Report after the file handle has been released.
print("Load lang model: {}. Word size: {}".format(ModalFile, len(lang)))

Load lang model: SubToSub/model/bken-large/. Word size: 15479


In [25]:
def transData(in_sents, target_sents, lang):
    """Turn source/target sentences into padded index batches.

    Args:
        in_sents: iterable of source sentences.
        target_sents: iterable of target sentences.
        lang: vocabulary object providing ``sentenceToVector`` and a
            ``lang["PAD"]`` lookup.

    Returns:
        A 3-tuple ``(in_seqs, in_targets, out_targets)`` where each element is
        the result of ``padding`` applied to:
          * the source sentences encoded without SOS/EOS,
          * the target sentences encoded with a leading SOS only,
          * the target sentences encoded with a trailing EOS only.
    """
    def as_device_tensor(seq):
        # Convert an index sequence into a LongTensor living on `device`.
        return torch.LongTensor(seq).to(device)

    transforms = [as_device_tensor]

    def pad_batch(batch):
        # Pad with the vocabulary's PAD entry, then apply the transforms.
        return padding(batch, lang["PAD"], transforms)

    encoded_inputs = []
    for sentence in in_sents:
        encoded_inputs.append(lang.sentenceToVector(sentence, sos=False, eos=False))
    in_seqs = pad_batch(encoded_inputs)

    decoder_inputs, decoder_labels = [], []
    for sentence in target_sents:
        # Decoder input starts with SOS; the expected output ends with EOS.
        decoder_inputs.append(lang.sentenceToVector(sentence, sos=True, eos=False))
        decoder_labels.append(lang.sentenceToVector(sentence, sos=False, eos=True))
    in_targets = pad_batch(decoder_inputs)
    out_targets = pad_batch(decoder_labels)

    return in_seqs, in_targets, out_targets

In [26]:
def predit(model, lang, in_sents, max_length=50):
    """Greedily decode a response to ``in_sents`` using ``model``.

    Args:
        model: object exposing ``makeContext(in_seq)`` and
            ``decode(inputs, cxt, hidden)``; ``decode`` must return a
            ``(scores, hidden)`` pair.
        lang: vocabulary providing ``sentenceToVector``, ``vectorToSentence``
            and ``lang["SOS"]`` / ``lang["EOS"]`` lookups.
        in_sents: the single source sentence to respond to.
        max_length: maximum number of decoding steps.

    Returns:
        The result of ``lang.vectorToSentence`` on the generated token ids
        (the EOS token itself is never included).
    """
    ans = []
    # Pure inference: disable autograd so no graph is recorded per step.
    with torch.no_grad():
        in_seq = torch.LongTensor(
            lang.sentenceToVector(in_sents, sos=False, eos=False)
        ).unsqueeze(0).to(device)
        inputs = torch.LongTensor([[lang["SOS"]]]).to(device)
        eos = lang["EOS"]  # loop-invariant, look it up once
        hidden = None

        cxt = model.makeContext(in_seq)
        for _ in range(max_length):
            outputs, hidden = model.decode(inputs, cxt, hidden)
            # Greedy choice: keep only the index of the best-scoring entry.
            _, outputs = outputs.topk(1)
            token = outputs.item()

            if token == eos:
                break
            ans.append(token)
            # Feed the chosen token back in as the next decoder input.
            inputs = outputs.squeeze(1)
    return lang.vectorToSentence(ans)

In [27]:
import matplotlib.pyplot as plt
def showVar(data):
    """Draw ``data`` as a bar chart and display it.

    Bar positions are ``0 .. data.size(-1) - 1``; bar heights are the tensor's
    values copied to the CPU.
    """
    positions = range(data.size(-1))
    heights = data.data.cpu()
    plt.bar(positions, heights)
    plt.show()

In [28]:
# Create a batch iterator over the DataLoader; later cells pull batches from it.
it = iter(loader)

In [34]:
# Pull one batch with the builtin next(); the `.next()` method is a Python 2
# idiom and is not part of the Python 3 iterator protocol.
pres, nexs, imgs = next(it)

In [35]:
# Put the network in evaluation mode before generating predictions.
modal.eval()

# One report per (previous line, real next line) pair in the batch, with the
# model's generated line shown underneath for comparison.
report = "Pre  : {}\nNext : {}\nModal: {}\n\n"
for pre, nex in zip(pres, nexs):
    pred = predit(modal, lang, pre)
    print(report.format(pre, nex, pred))

Pre  : That could ve hit us
Next : Now that s unsafe sex
Modal: Now that s unsafe sex


Pre  : No No it s on the house
Next : It s your birthday
Modal: It s your birthday


Pre  : You re pathetic
Next : And that s coming from someone who is homeless
Modal: And that s coming from someone who is homeless


Pre  : And not the kind where I can loot
Next : I needed those fake papers to renew my fake green card
Modal: I needed those fake papers to renew my fake green card


Pre  : You know that right
Next : Well I know I m entitled to my truths and how my truths make me feel
Modal: Well I know I m entitled to my truths and how my truths make me feel


