In [1]:
import torch
import torch.nn as nn
from torch import optim


import pickle

from model.BigModel import SubToSeq
from utils.tokenMaker import Lang
from utils.tool import padding, flatMutileLength, Timer, Average
from dataset.readVideo import DramaDataset
# `useCuda` expresses the preference for running on the GPU; the device falls
# back to the CPU when CUDA is not actually available, so the `.to(device)`
# calls in the notebook do not fail on a CPU-only machine.
useCuda = True
device = torch.device("cuda" if useCuda and torch.cuda.is_available() else "cpu")

In [2]:
import os
import torch.utils.data

# Root directory of the dataset. Defaults to the original location; set the
# DRAMA_DATA_DIR environment variable to use another path without editing
# the notebook.
DataDir = os.environ.get("DRAMA_DATA_DIR", "/home/ball/Videos/Broke")
datasets = DramaDataset(basedir=DataDir,
                        maxFrame=0,
                        timeOffset=0.2,
                        useBmp=True
                        )
# Shuffled batches of 5 samples, fetched by 2 worker processes.
loader = torch.utils.data.DataLoader(datasets, batch_size=5, shuffle=True, num_workers=2)

Total Drama: 134


In [3]:
ModalFile = "SubToSub/models/BK_CH_FIX_200/"
# map_location remaps the checkpoint's tensors onto `device`, so the model ends
# up on the same device as the inputs regardless of where it was saved from.
modal = torch.load(ModalFile+"SubSubModel.10.pth", map_location=device)
# NOTE(review): both torch.load (of a pickled model) and pickle.load can execute
# arbitrary code while unpickling; only load files from a trusted source.
with open(ModalFile+"Lang.pkl", 'rb') as f:
    lang = pickle.load(f)
# Report after the file handle has been closed.
print("Load lang model: {}. Word size: {}".format(ModalFile, len(lang)))

Load lang model: SubToSub/models/BK_CH_FIX_200/. Word size: 3703


In [4]:
def transData(in_sents, target_sents, lang):
    """Encode input and target sentences into padded index batches.

    Args:
        in_sents: iterable of input sentences.
        target_sents: iterable of target sentences.
        lang: vocabulary object providing ``sentenceToVector`` and a
            ``lang["PAD"]`` lookup.

    Returns:
        A tuple ``(in_seqs, in_targets, out_targets)``:
        the inputs encoded with ``sos=False, eos=False``, the targets encoded
        with ``sos=True, eos=False``, and the targets encoded with
        ``sos=False, eos=True``. Each is whatever ``padding`` returns for the
        encoded list, the ``lang["PAD"]`` value and the tensor transform.
    """
    def to_device_tensor(rows):
        # Build a LongTensor from the rows and move it onto `device`.
        return torch.LongTensor(rows).to(device)

    vectorTransforms = [to_device_tensor]

    # Input side: encoded without SOS/EOS flags, then padded.
    encoded_inputs = [
        lang.sentenceToVector(sentence, sos=False, eos=False)
        for sentence in in_sents
    ]
    in_seqs = padding(encoded_inputs, lang["PAD"], vectorTransforms)

    # Target side: each sentence is encoded twice, once with sos=True and once
    # with eos=True, in that order.
    decoder_inputs, decoder_outputs = [], []
    for sentence in target_sents:
        decoder_inputs.append(lang.sentenceToVector(sentence, sos=True, eos=False))
        decoder_outputs.append(lang.sentenceToVector(sentence, sos=False, eos=True))

    in_targets = padding(decoder_inputs, lang["PAD"], vectorTransforms)
    out_targets = padding(decoder_outputs, lang["PAD"], vectorTransforms)
    return in_seqs, in_targets, out_targets

In [5]:
def predit(model, lang, in_sents, max_length=50):
    """Greedily decode a reply for a single input sentence.

    The sentence is encoded with ``lang.sentenceToVector`` and turned into a
    batch of one, ``model.makeContext`` builds the context, and
    ``model.decode`` is then called step by step starting from
    ``lang["SOS"]``. At each step the top-1 index is taken; decoding stops
    when that index equals ``lang["EOS"]`` or after ``max_length`` steps.

    Args:
        model: object providing ``makeContext(in_seq)`` and
            ``decode(inputs, context, hidden) -> (outputs, hidden)``.
        lang: vocabulary object providing ``sentenceToVector``,
            ``vectorToSentence`` and ``lang["SOS"]`` / ``lang["EOS"]``.
        in_sents: the input sentence to respond to.
        max_length: maximum number of decoding steps.

    Returns:
        ``lang.vectorToSentence`` applied to the decoded indices (the EOS
        index is not included).
    """
    ans = []
    # Inference only: without no_grad every decode step would record an
    # autograd graph that is never used.
    with torch.no_grad():
        in_seq = torch.LongTensor(
            lang.sentenceToVector(in_sents, sos=False, eos=False)
        ).unsqueeze(0).to(device)
        inputs = torch.LongTensor([[lang["SOS"]]]).to(device)
        hidden = None

        cxt = model.makeContext(in_seq)
        eos = lang["EOS"]
        for _ in range(max_length):
            outputs, hidden = model.decode(inputs, cxt, hidden)
            _, best = outputs.topk(1)

            # Read the index once; .item() forces a device-to-host copy.
            token = best.item()
            if token == eos:
                break
            ans.append(token)
            # Feed the chosen index back in as the next decoder input
            # (no .detach() needed: nothing is tracked under no_grad).
            inputs = best.squeeze(1)
    return lang.vectorToSentence(ans)

In [6]:
import matplotlib.pyplot as plt
def showVar(data):
    """Show ``data`` as a bar chart.

    One bar is drawn per index along the last dimension of ``data``
    (positions ``0 .. data.size(-1) - 1``), with the values of ``data``
    copied to the CPU as bar heights. The figure is displayed immediately.
    """
    bar_count = data.size(-1)
    positions = range(bar_count)
    heights = data.data.cpu()
    plt.bar(positions, heights)
    plt.show()

In [7]:
# Iterator over the DataLoader; a later cell pulls a single batch from it.
it = iter(loader)

In [8]:
# Pull one batch with the built-in next(): the `.next()` method is a Python 2
# era alias that newer DataLoader iterators no longer provide.
pres, nexs, imgs = next(it)

In [9]:
# Put the loaded model in evaluation mode before sampling predictions.
modal.eval()
# Inference only: disable autograd so decoding does not record a graph.
with torch.no_grad():
    for pre, nex in zip(pres, nexs):
        pred = predit(modal, lang, pre)
        print("Pre  : {}\nNext : {}\nModal: {}\n\n".format(pre, nex, pred))

Pre  : 阿憨 你這樣真的太多管閑事了
Next : 真是太沒禮貌了 因為多管閑事是我的工作
Modal: 比如好消息的情況下冰門一


Pre  : 我的最愛
Next : 腰果除外
Modal: 腰果除外


Pre  : 而降低了膽固醇水平的人啊
Next : 還有那些因爲我 纔不吃Klamitra的人
Modal: 還有那些因爲我 纔不吃Klamitra的人


Pre  : 嘿 我又沒喝酒 干嘛要付那酒錢
Next : 啊 這可是我抓酸黃瓜的手
Modal: 我用的 我要用那筆錢給我買了一包酒店的酒


Pre  : 我們家族和Shecter家族私交甚好
Next : 其實 他家少爺David還在這實習呢
Modal: 其實 他家少爺David還在這實習呢


