In [3]:
import sys
# Put the parent directory on the import path. The `dataset`, `utils` and
# `model` imports in later cells presumably resolve from there — TODO confirm.
sys.path.append("..")

In [11]:
import torch 
from torch import optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

import matplotlib.pyplot as plt

from dataset import VisDialDataset
from utils.token import Lang

# Input locations used by the cells below.
# NOTE(review): jsonFile and cocoDir are absolute, machine-specific paths;
# adjust them before running this notebook on another machine.
jsonFile = "/home/ball/dataset/mscoco/visdialog/visdial_1.0_val.json"
cocoDir = "/home/ball/dataset/mscoco/"
langFile = "../dataset/lang.pkl"
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

cuda


In [6]:
def trainSentence(sents, SOS, EOS, device):
    """Build shifted decoder input/target sequences for a batch of sentences.

    Args:
        sents: iterable of 1-D index tensors, one per sentence.
        SOS: index of the start-of-sentence token.
        EOS: index of the end-of-sentence token.
        device: device the resulting tensors are moved to.

    Returns:
        dict with two lists aligned with ``sents``:
        ``"in"``  -- each sentence with ``SOS`` prepended,
        ``"out"`` -- each sentence with ``EOS`` appended.
    """
    train_sents = {"in": [], "out": []}

    for sent in sents:
        # new_tensor creates the marker with the same dtype and device as
        # `sent`, so torch.cat never mixes types. It replaces the legacy
        # Tensor.new([...]) constructor.
        sos = sent.new_tensor([SOS])
        eos = sent.new_tensor([EOS])
        train_sents["in"].append(torch.cat([sos, sent]).to(device))
        train_sents["out"].append(torch.cat([sent, eos]).to(device))
    return train_sents

In [77]:
import torch
import torch.nn.functional as F
from model.model import SentenceEncoder, SentenceDecoder

class AutoModel(torch.nn.Module):
    """Sentence autoencoder: a SentenceEncoder feeding a SentenceDecoder.

    The encoder output is wrapped in a one-element list and handed to the
    decoder as its context.
    """

    def __init__(self, encoder_setting, decoder_setting, padding_idx=0):
        """Build the encoder and decoder from keyword-argument dicts.

        Args:
            encoder_setting: kwargs for ``SentenceEncoder``; must contain
                the key ``"output_size"``.
            decoder_setting: kwargs for ``SentenceDecoder``. ``feature_size``
                is always set to the encoder's ``output_size``; the dict
                passed in is left untouched.
            padding_idx: currently unused; kept so existing callers that
                pass it keep working.
        """
        super(AutoModel, self).__init__()

        self.encodeModel = SentenceEncoder(**encoder_setting)
        # Build the decoder kwargs on a copy so the caller's dict is not
        # modified as a side effect of constructing the model.
        decoder_kwargs = dict(decoder_setting,
                              feature_size=encoder_setting["output_size"])
        self.decoderModel = SentenceDecoder(**decoder_kwargs)

    def forward(self, encode_seqs, input_seqs, hidden=None):
        """Encode ``encode_seqs`` and run the decoder on ``input_seqs``.

        Returns:
            The ``(output, hidden)`` pair produced by the decoder. No softmax
            is applied here (``decode`` applies one).
        """
        context = self.makeContext(encode_seqs)
        output, hidden = self.decoderModel(input_seqs, context, hidden)
        return output, hidden

    def makeContext(self, encode_seqs):
        """Return the decoder context: the encoder output in a one-item list.

        Only the first value returned by the encoder is used; the second is
        discarded.
        """
        en_out, _ = self.encodeModel(encode_seqs)
        return [en_out]

    def decode(self, input_seq, context, hidden=None):
        """Run the decoder on a precomputed context.

        Returns:
            ``(output, hidden)`` where ``output`` is the decoder output after
            a softmax over dimension 2.
        """
        output, hidden = self.decoderModel(input_seq, context, hidden)
        output = F.softmax(output, dim=2)
        return output, hidden

In [13]:
# Load the vocabulary first: the dataset needs its sentence converter.
lang = Lang.load(langFile)

# Sentences are presumably converted by lang.sentenceToVector and then wrapped
# by torch.LongTensor — confirm in VisDialDataset.
dataset = VisDialDataset(
    dialFile=jsonFile,
    cocoDir=cocoDir,
    sentTransform=torch.LongTensor,
    convertSentence=lang.sentenceToVector,
)

Load lang model: ../dataset/lang.pkl. Word size: 43974


Preparing image paths with image_ids: 133351it [00:00, 381268.91it/s]


In [47]:
loader = torch.utils.data.DataLoader(dataset.getAllSentences(),
                                     batch_size=5, 
                                     shuffle=True, 
                                     num_workers=4,
                                     collate_fn=collate_fn)

In [48]:
def collate_fn(batch):
    return batch

In [59]:
def setData(data, lang, device):
    """Convert a batch of index sequences into padded model inputs.

    Args:
        data: iterable of index sequences, one per sentence.
        lang: mapping that yields the ``"<SOS>"`` and ``"<EOS>"`` indices.
        device: device the tensors are placed on.

    Returns:
        ``(encode_t, decode_t)`` where ``encode_t`` is the batch-first padded
        tensor of the raw sequences and ``decode_t`` is a dict whose ``"in"``
        and ``"out"`` entries are the batch-first padded outputs of
        ``trainSentence``.
    """
    seqs_t = [torch.LongTensor(seq).to(device) for seq in data]

    decode_t = trainSentence(seqs_t, lang["<SOS>"], lang["<EOS>"], device)
    # Pad both decoder lists into (batch, max_len) tensors.
    for key in ("in", "out"):
        decode_t[key] = pad_sequence(decode_t[key], batch_first=True)

    encode_t = pad_sequence(seqs_t, batch_first=True)
    return encode_t, decode_t

In [49]:
# Create an iterator over the loader so single batches can be pulled by hand.
it = iter(loader)

In [51]:
# Use the built-in next(): the Python 3 iterator protocol only guarantees
# __next__, so a Python-2-style .next() method may not exist.
seqs = next(it)
# Pad the raw batch into encoder input and decoder input/target tensors.
en_seq, de_seq = setData(seqs, lang, DEVICE)

In [78]:
# Encoder and decoder both take the vocabulary size; the encoder additionally
# fixes the size of the output that AutoModel passes on to the decoder.
encoder_setting = dict(word_size=len(lang), output_size=512)
decoder_setting = dict(word_size=len(lang))

model = AutoModel(encoder_setting, decoder_setting)
model = model.to(DEVICE)

In [73]:
# Compute the decoder context for the sample batch (a one-element list holding
# the encoder output, per AutoModel.makeContext).
cxt = model.makeContext(en_seq)

In [79]:
# Smoke-test the full forward pass on the sample batch; the cell displays the
# (output, hidden) pair returned by AutoModel.forward.
# NOTE(review): this prints the whole tensor — consider showing only its shape.
model(en_seq, de_seq["in"])

(tensor([[[-0.0000,  0.1270, -0.1571,  ...,  0.2824, -0.0007,  0.0038],
          [-0.0710,  0.2599, -0.0645,  ...,  0.0468, -0.0819, -0.2804],
          [-0.0000,  0.0000, -0.0685,  ..., -0.0306,  0.0762, -0.1800],
          ...,
          [ 0.0468,  0.0208,  0.0000,  ..., -0.0444, -0.0156,  0.0273],
          [ 0.0475,  0.0299,  0.0391,  ..., -0.0478, -0.0111,  0.0285],
          [ 0.0431,  0.0318,  0.0367,  ..., -0.0474, -0.0098,  0.0354]],
 
         [[-0.1626,  0.0980, -0.1199,  ...,  0.3976,  0.1012,  0.1444],
          [ 0.1954,  0.0000, -0.3049,  ...,  0.2141, -0.0295, -0.2370],
          [ 0.0656,  0.0119,  0.0740,  ..., -0.0846, -0.2322, -0.4034],
          ...,
          [-0.1533, -0.1284,  0.1295,  ...,  0.5990,  0.1152, -0.3269],
          [ 0.0848, -0.0416,  0.0778,  ...,  0.6078,  0.1296, -0.0667],
          [-0.0000, -0.0000,  0.2237,  ...,  0.4336,  0.0138, -0.0989]],
 
         [[-0.1438,  0.0815, -0.1442,  ...,  0.0000,  0.1145,  0.1583],
          [-0.0379,  0.1623,