In [23]:
# %load siamese_dataloader.py
from torch.utils.data import Dataset
from PIL import Image
import os
import torch
import numpy as np

def pil_loader(path):
    """Read the image stored at ``path`` with PIL and return it converted to RGB."""
    # Hand PIL an already-open binary file object instead of the path itself to
    # avoid a ResourceWarning (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as stream, Image.open(stream) as picture:
        rgb_image = picture.convert('RGB')
    return rgb_image

def accimage_loader(path):
    """Load ``path`` with accimage, falling back to ``pil_loader`` on IOError."""
    # Imported lazily so the module is only required when this loader is used.
    import accimage
    try:
        image = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem; retry the same path through PIL.
        image = pil_loader(path)
    return image

def default_loader(path):
    """Load the image at ``path`` with the default backend (``pil_loader``)."""
    image = pil_loader(path)
    return image

class SiameseData(Dataset):
    """Dataset of labelled image pairs.

    Item ``i`` is built from the ``i``-th entries of ``image_paths0`` and
    ``image_paths1`` (both joined onto ``root``) and from ``labels[i]``.

    Args:
        root: Directory that every image path is joined onto.
        image_paths0: Iterable of paths for the first image of each pair.
        image_paths1: Iterable of paths for the second image of each pair.
        labels: Sized, indexable collection holding one label per pair.
        transform: Optional callable applied to each of the two loaded images.
        loader: Callable taking a full path and returning the loaded image.
    """

    def __init__(self, root, image_paths0, image_paths1, labels, transform=None, loader=default_loader):
        # zip() returns a one-shot iterator on Python 3, which cannot be
        # subscripted; materialise it so __getitem__ can index the pairs.
        self.image_pairs = list(zip(image_paths0, image_paths1))
        self.labels = labels
        self.loader = loader
        self.root = root
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(img0, img1, label)`` for the pair at ``index``.

        The label is a ``torch.long`` tensor holding the label value; a scalar
        label is returned as a one-element 1-D tensor.
        """
        path0, path1 = self.image_pairs[index]
        img0 = self.loader(os.path.join(self.root, path0))
        img1 = self.loader(os.path.join(self.root, path1))
        if self.transform is not None:
            img0 = self.transform(img0)
            img1 = self.transform(img1)
        # torch.LongTensor(n) with an integer n allocates an uninitialised
        # tensor of *size* n (empty for n == 0) instead of storing the value n,
        # so build the tensor from the label's value explicitly.
        target = torch.tensor(self.labels[index], dtype=torch.long)
        if target.dim() == 0:
            target = target.unsqueeze(0)
        return img0, img1, target

    def __len__(self):
        """Return the number of indices for which both a pair and a label exist."""
        # zip() truncates to the shorter path list, so the pair count can be
        # smaller than the label count; never advertise indices without a pair.
        return min(len(self.image_pairs), len(self.labels))



In [3]:
import torch
from torch.utils.data import DataLoader

In [5]:
import torchvision.transforms as transforms


In [6]:
# Read one entry per line; the context manager closes the file handle
# instead of leaving it open until garbage collection.
with open('/tmp/image_list0', 'r') as list_file:
    image_list0 = list_file.read().splitlines()

In [8]:
# Read one entry per line; the context manager closes the file handle
# instead of leaving it open until garbage collection.
with open('/tmp/image_list1', 'r') as list_file:
    image_list1 = list_file.read().splitlines()

In [9]:
import numpy as np

In [10]:
# 50 random binary labels (0 or 1). A dedicated, seeded generator makes the
# labels identical on every Restart & Run All without touching NumPy's
# global random state.
targets = np.random.RandomState(0).randint(0, 2, 50)

In [11]:
# transforms.Scale is deprecated and no longer present in current torchvision
# releases; transforms.Resize is its replacement (an int size resizes the
# smaller image edge to that length).
trans = transforms.Compose([transforms.Resize(160), transforms.ToTensor()])

In [24]:
# Pair dataset built from the two path lists and the labels defined above.
dataset = SiameseData(
    root='/home/wenfahu/faces/MTCNN_LFW/',
    image_paths0=image_list0,
    image_paths1=image_list1,
    labels=targets,
    transform=trans,
)

In [25]:
# Shuffled loader that yields batches of two items from the dataset.
train_loader = DataLoader(
    dataset=dataset,
    batch_size=2,
    shuffle=True,
)