In [1]:
import torch 
from torch import optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

import matplotlib.pyplot as plt

from model.model import SentenceEncoder, SentenceDecoder, ImageEncoder, cnnTransforms
from dataset import VisDialDataset
from utils.token import Lang

from VQA.model import VQAModel
from VQA.utils import collate_fn, setData, trainSentence, predit

jsonFile = "/home/ball/dataset/mscoco/visdialog/visdial_1.0_val.json"
cocoDir = "/home/ball/dataset/mscoco/"
langFile = "dataset/lang.pkl"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

cuda


In [2]:
lang = Lang.load(langFile)
dataset = VisDialDataset(dialFile = jsonFile,
                         cocoDir = cocoDir, 
#                          sentTransform = torch.LongTensor,
#                          imgTransform = cnnTransforms,
#                          convertSentence = lang.sentenceToVector
                        )

Load lang model: dataset/lang.pkl. Word size: 43974


Preparing image paths with image_ids: 133351it [00:00, 370994.24it/s]


In [3]:
model = torch.load("VQA/models/first/VQAmodel.29.pth").to(DEVICE).eval()

In [4]:
data = dataset[966]

In [10]:
beamPredit(model, DEVICE, lang, cnnTransforms(data["image"]), lang.sentenceToVector(data["questions"][0]), 3)

([['<SOS> yes <EOS>', 0.5751819610595703],
  ['<SOS> no <EOS>', 0.3270225524902344],
  ['<SOS> no idea <EOS>', 0.24922943115234375]],
 [<__main__.Beam at 0x7fb48539fda0>,
  <__main__.Beam at 0x7fb48539fcf8>,
  <__main__.Beam at 0x7fb48539fe48>])

In [5]:
class Beam():
    def __init__(self, seq, end, scores=None, state=None):
        self.seq = seq
        self.state = None
        self.scores = scores
        self.end = end
        
    def getInput(self):
        return self.seq[:, -1:], self.state
        
    def addState(self, next_seq, score, state):
        seq = torch.cat([self.seq, next_seq], 1)
        scores = torch.cat([self.scores, score], 1) if self.scores is not None else score
        
        return Beam(seq, 
                    self.end,
                    scores, 
                    state)
    def isEnd(self):
        return self.seq[0, -1] == self.end
    
    def score(self):
        if self.scores is None:
            return -1
        if self.isEnd():
            endsc = 10
        else:
            endsc = 0
        return self.scores.mean() + endsc
    
    def __lt__(self, other):
        return self.score() < other.score()

In [11]:
def beamPredit(model, device, lang, image, question, beamSize, MAX_LEN=20): 
    image_t = image.unsqueeze(0).to(device)
    question_t = torch.LongTensor(question).unsqueeze(0).to(device)
    feature = model.makeContext(image_t, question_t)
    beams = [Beam(torch.LongTensor([[lang["<SOS>"]]]).to(device), lang["<EOS>"])]
    for _ in range(MAX_LEN):
        newBeams = []
        for beam in beams:
            if beam.isEnd():
                newBeams.append(beam)
            else:
                pre_inputs, hidden = beam.getInput()
                next_outputs, hidden = model.decode(pre_inputs, feature, hidden)

                probs, next_outputs = next_outputs.topk(beamSize)
                for i in range(probs.size(2)):
                    newBeams.append(beam.addState(next_outputs[:,:,i].detach(), probs[:,:,i].detach(), hidden))
        newBeams.sort(reverse=True)
        beams = newBeams[:beamSize]
    ans = []
    for beam in beams:
        score = (beam.score() - 10) if beam.isEnd() else beam.score()
        ans.append([lang.vectorToSentence(beam.seq[0].cpu().numpy()), score.item()])
    return ans