# pytorch数据加载

In [52]:
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

import warnings
warnings.filterwarnings("ignore")
plt.ion()

In [12]:
# Read the landmark annotation CSV into a DataFrame.
landmarks_frame = pd.read_csv("/home/derek/Documents/mm/resources/2020/dataset/faces/face_landmarks.csv")
# Directory used as the base for image paths and by the dataset cells below.
root = "/home/derek/Documents/mm/resources/2020/dataset/faces/"

def show_landmarks(image, landmarks):
    """Draw an image and overlay its landmarks as small red dots.

    Args:
        image: Image accepted by ``plt.imshow``.
        landmarks: 2-D array; column 0 is used as x and column 1 as y.
    """
    xs = landmarks[:, 0]
    ys = landmarks[:, 1]
    plt.imshow(image)
    plt.scatter(xs, ys, s=10, marker=".", c="r")
    # Short pause, presumably so the interactive figure gets refreshed.
    plt.pause(0.001)

# Pick one row of the annotation table to inspect.
n = 65
# Column 0 holds the image file name.
name = landmarks_frame.iloc[n, 0]
# The remaining columns are coordinates; regroup them into (x, y) rows.
landmarks = np.asarray(landmarks_frame.iloc[n, 1:], dtype=float)
landmarks = landmarks.reshape(-1, 2)
# show_landmarks(io.imread(os.path.join(root, name)), landmarks)



## Loading dataset with `Dataset`

In [48]:
from collections import namedtuple
class FaceLandmarksDataset(Dataset):
    """Dataset of face images paired with landmark coordinates.

    Annotations come from a CSV file whose first column is the image file
    name and whose remaining columns are the landmark coordinates.

    Args:
        csv_file: Name of the annotation CSV, relative to ``root_dir``.
        root_dir: Directory containing the CSV file and the images.
        transform: Optional callable applied to each sample dict.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        csv_path = os.path.join(root_dir, csv_file)
        self.landmarks_frame = pd.read_csv(csv_path)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of annotated rows."""
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        """Load the sample at ``idx`` as ``{"image", "landmarks"}``.

        The transform, when set, is applied to the sample before returning.
        """
        # The index may arrive as a tensor; convert it to plain Python first.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        frame = self.landmarks_frame
        # NOTE: this lookup may fail when idx is a sequence of indices.
        file_name = frame.iloc[idx, 0]
        image = io.imread(os.path.join(self.root_dir, file_name))
        coords = np.asarray(frame.iloc[idx, 1:], dtype=float)

        # Open question from the original author: could the sample be a
        # namedtuple instead of a dict?
        sample = {"image": image, "landmarks": coords.reshape(-1, 2)}

        if not self.transform:
            return sample
        return self.transform(sample)

In [24]:
# Dataset without any transform: samples are returned as loaded.
face_dataset = FaceLandmarksDataset(csv_file="face_landmarks.csv", root_dir=root)

## Pre-processing dataset using `Transform`

In [31]:
class Rescale(object):
    """Resize the image in a sample and scale its landmarks to match.

    Args:
        output_size (int or tuple): A tuple is used directly as
            ``(height, width)``. An int becomes the length of the shorter
            image side, with the other side scaled by the same ratio.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict with the resized image and landmarks."""
        image = sample["image"]
        landmarks = sample["landmarks"]
        # Only the first two axes are used; a channel axis may follow.
        h, w = image.shape[0:2]

        target = self.output_size
        if not isinstance(target, int):
            new_h, new_w = target
        elif h > w:
            new_h, new_w = target * h / w, target
        else:
            new_h, new_w = target, target * w / h
        new_h = int(new_h)
        new_w = int(new_w)

        resized = transform.resize(image, (new_h, new_w))
        # Landmarks are (x, y), so x scales with width and y with height.
        scaled = landmarks * [new_w / w, new_h / h]
        return {"image": resized, "landmarks": scaled}
# # help(transform.resize)
# sample = face_dataset[0]
# fn = Rescale((300,200))
# # sample = fn(sample)
# plt.subplot(131)
# show_landmarks(sample["image"], sample["landmarks"])
# plt.subplot(132)
# sample = fn(sample)
# show_landmarks(sample["image"], sample["landmarks"])
# plt.subplot(133)
# fn = Rescale((300,100))
# sample = fn(sample)
# show_landmarks(sample["image"], sample["landmarks"])

In [34]:
class RandomCrop(object):
    """Randomly crop the image in a sample and shift its landmarks to match.

    Args:
        output_size (int or tuple): Desired crop size. An int gives a square
            crop; a 2-tuple is used as ``(height, width)``.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict holding the cropped image and landmarks.

        Raises:
            ValueError: If the crop size exceeds the image size.
        """
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        if new_h > h or new_w > w:
            raise ValueError(
                "crop size {} is larger than image size {}".format(
                    (new_h, new_w), (h, w)))

        # randint's upper bound is exclusive, so add 1: this makes the last
        # valid offset reachable and allows a crop equal to the image size.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h,
                      left: left + new_w]

        # Landmarks are (x, y), so subtract (left, top) to re-origin them.
        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}

class ToTensor(object):
    """Convert the numpy arrays in a sample to torch tensors."""

    def __call__(self, sample):
        """Return a new sample dict whose values are tensors.

        The image axes are reordered with ``(2, 0, 1)``, i.e. the last axis
        is moved to the front (assumes an H x W x C image).
        """
        channels_first = np.transpose(sample['image'], (2, 0, 1))
        image_tensor = torch.from_numpy(channels_first)
        landmarks_tensor = torch.from_numpy(sample['landmarks'])
        return {'image': image_tensor, 'landmarks': landmarks_tensor}

In [49]:
# Define a transform pipeline that composes the operations above:
# rescale, then take a random 224x224 crop, then convert to tensors.
transformed_dataset = FaceLandmarksDataset(
    csv_file = "face_landmarks.csv", 
    root_dir=root, 
    transform=transforms.Compose([
        Rescale(256), 
        RandomCrop(224), 
        ToTensor()
    ])
)

## Iterating dataset with `DataLoader`
Iterating over a `Dataset` directly with `enumerate(dataset)` does not provide the following features:
+ Batching the data
+ Shuffling the data
+ Loading the data in parallel using `multiprocessing` workers

So we use `DataLoader` to iterate over the dataset instead, which integrates all of the above features.

In [51]:
# Iterate the transformed dataset in shuffled batches of 4.
# num_workers=0 means no worker subprocesses (loading happens in the main process).
dataloader = DataLoader(transformed_dataset, batch_size=4, shuffle=True, num_workers=0)


## Use `collate_fn` to deal with different data types of dataset
+ If `Dataset.__getitem__()` returns a `dict`, the default `collate_fn` merges the entries key by key and returns a `dict`.
+ If `Dataset.__getitem__()` returns a `namedtuple` or a `list`, the default `collate_fn` concatenates the elements of the entries position by position and returns a sequence of tensors (a `namedtuple` for `namedtuple` samples, a `list` otherwise). The length of that sequence equals the length of one entry.
+ In both cases the default `collate_fn` finally tries to convert the data to tensors. If the data cannot be converted to a tensor (for example, because of inconsistent lengths or dtypes), an error is raised.
+ To handle this, you can pass a customized `collate_fn` to merge the entries of your dataset.