In [2]:
import torch
import torchvision.transforms as transforms
import torch.utils.data as data
import os
import sys
sys.path.append('../')
import pickle
import numpy as np
import nltk
from PIL import Image
from build_vocab import Vocabulary
from pycocotools.coco import COCO
from collections import Counter

In [3]:
def collate_fn(data):
    """Creates mini-batch tensors from the list of tuples (image, caption).
    
    We should build custom collate_fn rather than using default collate_fn, 
    because merging caption (including padding) is not supported in default.

    Args:
        data: list of tuple (image, caption). 
            - image: torch tensor of shape (3, 256, 256).
            - caption: torch tensor of shape (?); variable length.

    Returns:
        images: torch tensor of shape (batch_size, 3, 256, 256).
        targets: torch tensor of shape (batch_size, padded_length).
        lengths: list; valid length for each padded caption.
    """
    # Sort a data list by caption length (descending order).
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions = zip(*data)

    # Merge images (from tuple of 3D tensor to 4D tensor).
    images = torch.stack(images, 0)

    # Merge captions (from tuple of 1D tensor to 2D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]        
    return images, targets, lengths

In [5]:
class XRDataSet(data.Dataset):
    def __init__(self, config_file, root_data, vocab, transform):
        self.config_file = config_file
        self.root_data = root_data
        self.vocab = vocab
        self.transform = transform
        self.images_list = []
        self.findings_list = []
        with open(config_file) as f:
            for line in f.readlines():
                if line is None or len(line) == 0:
                    continue
                ss = line.split('\t')
                if len(ss) < 2:
                    continue
                for img_name in ss[1:]:
                    img_file = os.path.join(root_data, '{}.npy'.format(img_name))
                    if os.path.exists(img_file):
                        self.images_list.append(img_file)
                        self.findings_list.append(ss[0])
                        break
        
    def __getitem__(self, index):
        img_id = self.images_list[index]
#         image = Image.open(img_id).convert('RGB')
#         if self.transform is not None:
#             image = self.transform(image)
        image = np.load(img_id)
        image = torch.from_numpy(image)
        caption = self.findings_list[index]
        caption = str(caption)
        caption = caption.lower()
        tokens = caption.split()
        caption = []
        caption.append(vocab('<start>'))
        caption.extend([vocab(token) for token in tokens])
        caption.append(vocab('<end>'))
        target = torch.Tensor(caption)
        return image, target
    
    def __len__(self):
        return len(self.images_list)

In [6]:
def get_loader(root, config_file, vocab, transform, batch_size, shuffle, num_workers):
    """Returns torch.utils.data.DataLoader for custom coco dataset."""
    # COCO caption dataset
    
    ds_xr = XRDataSet(config_file, root, vocab, transform)
    
    # Data loader for COCO dataset
    # This will return (images, captions, lengths) for each iteration.
    # images: a tensor of shape (batch_size, 3, 224, 224).
    # captions: a tensor of shape (batch_size, padded_length).
    # lengths: a list indicating valid length for each caption. length is (batch_size).
    data_loader = torch.utils.data.DataLoader(dataset=ds_xr, 
                                              batch_size=batch_size,
                                              shuffle=shuffle,
                                              num_workers=num_workers,
                                              collate_fn=collate_fn)
    return data_loader

In [7]:
transform = transforms.Compose([ 
        transforms.RandomCrop(256),
        transforms.RandomHorizontalFlip(), 
        transforms.ToTensor(), 
        transforms.Normalize((0.485, 0.456, 0.406), 
                             (0.229, 0.224, 0.225))])

In [8]:
with open('./xray_data/../../coco/vocab/vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)

In [9]:
all_words = []
with open('xray_data/findings1.txt') as f:
    for line in f.readlines():
        line = line.strip()
        if line is None or len(line) == 0:
            continue
        finding = line.split('\t')[0]
        all_words += finding.split()
set(all_words)
print(len(set(all_words)))

2182


In [10]:
end_sig_words = [i[-1] for i in all_words]
bg_sig_words = [i[0] for i in all_words]

In [11]:
# begin ['.',',','[','(']
# end [';', '?', ':', '.', ',', ')', '/']
print(set(bg_sig_words))
print(set(end_sig_words))

{'g', '.', 'T', 'i', 'B', 't', ',', 'd', 'k', 'E', 'f', 'Q', ';', '2', '5', 'a', 'G', 'L', '9', '0', 'e', '4', 'X', 'c', ')', 'S', 'C', 'M', 'A', 'j', '(', 'r', 'p', 'P', '7', '/', ':', 'N', 'n', 'q', '8', 'x', 'O', 'K', 'b', 'D', '[', 'I', 'z', 'U', 'y', '?', 's', 'v', 'm', 'u', 'h', '<', '1', 'H', 'o', 'F', 'w', 'R', 'W', '6', '3', 'l', 'V'}
{'g', '.', 'T', 'i', 'B', 't', ',', 'd', 'k', 'E', 'f', ';', '2', ']', '5', 'a', 'G', '9', '0', 'e', ')', 'X', 'c', '4', 'C', 'S', 'A', '(', 'r', 'p', 'P', '7', '/', ':', 'N', 'n', 'J', 'x', '8', 'K', 'b', 'D', '[', 'y', 'I', '?', 's', 'm', 'h', '1', 'o', 'H', 'F', 'w', '6', '3', 'l', 'V'}


In [13]:
data_loader = get_loader('xray_data/features', 'xray_data/findings1.txt', vocab, 
                             transform, 256,
                             shuffle=True, num_workers=4) 

In [15]:
for enum, (images, captions, lengths) in enumerate(data_loader):
    print(images.shape)
    print(captions.shape)
    print(lengths)
    break

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/zhangwd/.conda/envs/py36/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/zhangwd/.conda/envs/py36/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "<ipython-input-3-b2e18b264bf7>", line 22, in collate_fn
    images = torch.stack(images, 0)
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 16 and 19 in dimension 2 at /opt/conda/conda-bld/pytorch_1573049304260/work/aten/src/TH/generic/THTensor.cpp:689


In [19]:
len(data_loader)

1