In [1]:
import torch 
from torch import optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from model.model import SentenceEncoder, SentenceDecoder, ImageEncoder, cnnTransforms
from dataset import VisDialDataset
from utils.token import Lang
jsonFile = "/home/ball/dataset/mscoco/visdialog/visdial_1.0_val.json"
cocoDir = "/home/ball/dataset/mscoco/"
langFile = "dataset/lang.pkl"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
lang = Lang.load(langFile)
dataset = VisDialDataset(dialFile = jsonFile,
                         cocoDir = cocoDir, 
                         sentTransform = torch.LongTensor,
                         imgTransform = cnnTransforms,
                         convertSentence = lang.sentenceToVector)

Load lang model: dataset/lang.pkl. Word size: 43974


Preparing image paths with image_ids: 133351it [00:00, 372527.98it/s]


In [3]:
def collate_fn(batch):
    images = []
    questions = []
    answers = []
    for row in batch:
        images.append(row["image"])
        questions.append(row["questions"][0])
        answers.append(row['answers'][0])
    images = torch.stack(images)
    return {
        "images": images,
        "questions": questions,
        "answers": answers
    }
loader = torch.utils.data.DataLoader(dataset, 
                                     batch_size=4, 
                                     shuffle=True, 
                                     num_workers=4, 
                                     collate_fn=collate_fn)
it = iter(loader)

In [4]:
data = it.next()

In [None]:
def setData(data, lang):
    questions = []
    in_ans = []
    out_ans = []
    
    for ans in data["answers"]:
        in_ans.append(torch.cat([torch.LongTensor([lang["<SOS>"]]), ans]).to(DEVICE))
        out_ans.append(torch.cat([ans, torch.LongTensor([lang["<EOS>"]])]).to(DEVICE))
        
    out_seq = pad_sequence(out_seq, batch_first=True)
    images = data["images"].to(DEVICE)
    questions = torch.stack(data["questions"]).to(DEVICE)
    return images, in_seq, out_seq

In [5]:
data

{'answers': [tensor([ 45, 277,   6, 244]),
  tensor([399]),
  tensor([7665,  407]),
  tensor([6457, 1316, 6081])],
 'images': tensor([[[[ 0.3309,  0.4337,  0.5536,  ...,  0.9132,  0.8618,  0.8789],
           [ 0.6049,  0.6734,  0.7077,  ...,  0.8789,  0.8618,  0.8618],
           [ 0.7077,  0.7077,  0.7077,  ...,  0.9132,  0.8961,  0.8789],
           ...,
           [ 0.8961,  0.9303,  0.9132,  ...,  0.7248,  0.7591,  0.7933],
           [ 0.8789,  0.9474,  0.9474,  ...,  0.8618,  0.8276,  0.8104],
           [ 0.8618,  0.8961,  0.9132,  ...,  0.9303,  0.8789,  0.8618]],
 
          [[ 0.4503,  0.5553,  0.7129,  ...,  0.8354,  0.8529,  0.8704],
           [ 0.7129,  0.8179,  0.8529,  ...,  0.8529,  0.8354,  0.8529],
           [ 0.8179,  0.8529,  0.8354,  ...,  0.8529,  0.8354,  0.8529],
           ...,
           [ 0.9755,  0.9580,  0.9405,  ...,  0.6604,  0.6779,  0.7129],
           [ 0.9755,  0.9755,  0.9580,  ...,  0.8004,  0.7479,  0.7304],
           [ 0.9755,  0.9755,  0.9755