In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
import os

## Torch自带数据集加载

#### FashionMNIST
以 FashionMNIST 为例；`torchvision.datasets` 中其他内置数据集的加载方式与此类似。

In [97]:
def offical_exist_data(data_name='FashionMNIST', batch_size=256, path='../../../Datasets/FashionMNIST'):
    """Build train/test ``DataLoader`` objects for a dataset shipped with torchvision.

    Args:
        data_name: Name of the built-in dataset. Only ``'FashionMNIST'`` is
            supported at the moment.
        batch_size: Mini-batch size used by both loaders.
        path: Root directory of the dataset; it is downloaded there if absent
            (``download=True``).

    Returns:
        ``(train_iter, test_iter)``: two ``torch.utils.data.DataLoader`` objects
        over the training and test splits. Both are created with
        ``shuffle=True``.

    Raises:
        ValueError: If ``data_name`` is not a supported dataset name.
    """
    # Fail early with a clear message. Previously an unsupported name fell
    # through to the DataLoader construction and crashed on unbound locals.
    if data_name != 'FashionMNIST':
        raise ValueError(
            "Unsupported data_name {!r}; supported datasets: 'FashionMNIST'".format(data_name)
        )

    # To support more datasets, add further branches here that assign
    # `train` and `test` in the same way.
    train = torchvision.datasets.FashionMNIST(root=path, train=True, download=True,
                                              transform=transforms.ToTensor())
    test = torchvision.datasets.FashionMNIST(root=path, train=False, download=True,
                                             transform=transforms.ToTensor())

    train_iter = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
    test_iter = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=True)
    return train_iter, test_iter

In [98]:
# Build both loaders with the default FashionMNIST settings.
train_iter, test_iter = offical_exist_data()
# Fetch only the first mini-batch so its shape can be inspected.
x, y = next(iter(train_iter))
x.shape

torch.Size([256, 1, 28, 28])

## 自定义数据集

In [35]:
class Datasets(torch.utils.data.Dataset):
    """Map-style dataset that pairs an indexable feature collection with labels.

    Args:
        features: Indexable collection (list, NumPy array, tensor, ...) where
            ``features[i]`` is the i-th sample.
        labels: Indexable collection where ``labels[i]`` is the label of the
            i-th sample. Entries may be scalars or arrays.

    Indexing returns ``(feature, label)`` where ``feature`` has the current
    default floating dtype and ``label`` has dtype ``torch.long``.
    """

    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __getitem__(self, index):
        # torch.as_tensor always treats its argument as data. The legacy
        # torch.Tensor(...) / torch.LongTensor(...) constructors interpret a
        # bare integer as a *size*, so a scalar label such as 3 would become an
        # uninitialised tensor with three elements instead of the value 3.
        feature = torch.as_tensor(self.features[index], dtype=torch.get_default_dtype())
        label = torch.as_tensor(self.labels[index], dtype=torch.long)
        return feature, label

    def __len__(self):
        # The feature collection defines the number of samples.
        return len(self.features)

#### 加载内存中的数据集

In [46]:
# Synthetic in-memory data: 10 samples with 10 random float features in [0, 1).
# No random seed is set, so the values differ on every run.
feature = np.random.random((10,10))
# One integer class id in [0, 10) per sample, stored as a (10, 1) column.
label = np.random.randint(0,10,(10,1))

In [47]:
# Wrap the NumPy arrays in the custom Dataset so they can be indexed per sample.
data = Datasets(feature,label)

In [48]:
# Show the (feature, label) pair of the first sample.
print(data[0])

(tensor([0.6762, 0.2933, 0.2260, 0.3678, 0.7459, 0.5180, 0.0798, 0.8472, 0.9283,
        0.2913]), tensor([2]))


#### 加载本地图像数据集

In [163]:
def label_map(id_labels=None, cls_labels=None, data_path='../Datasets/data1'):
    '''
    Map class ids to class names, or class names to class ids.

    The classes are the sub-directories of ``data_path``, sorted by name; a
    class id is the position of the class name in that sorted list. Sorting
    keeps the mapping stable, because ``os.listdir`` returns entries in
    arbitrary order.

    Args:
        id_labels: sequence of class ids to translate into class names.
        cls_labels: sequence of class names to translate into class ids.
        data_path: dataset root that contains one sub-directory per class.

    Returns:
        A list of class names when only ``id_labels`` is given, a list of
        integer ids when only ``cls_labels`` is given, and ``[]`` when both
        are empty or omitted.

    Raises:
        ValueError: if both ``id_labels`` and ``cls_labels`` are non-empty, or
            if a name in ``cls_labels`` is not a class directory.
    '''
    has_ids = id_labels is not None and len(id_labels) > 0
    has_cls = cls_labels is not None and len(cls_labels) > 0
    if has_ids and has_cls:
        # The direction of the mapping would be ambiguous.
        raise ValueError('pass either id_labels or cls_labels, not both')

    cls_name = sorted(
        entry for entry in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, entry))
    )

    if has_cls:
        # Dict lookup instead of list.index: one pass rather than a scan per label.
        name_to_id = {name: idx for idx, name in enumerate(cls_name)}
        unknown = [name for name in cls_labels if name not in name_to_id]
        if unknown:
            raise ValueError(
                '{!r} is not a class directory of {!r}'.format(unknown[0], data_path)
            )
        return [name_to_id[name] for name in cls_labels]

    # int() lets ids be given as tensor or NumPy scalars as well as plain ints.
    return [cls_name[int(i)] for i in (id_labels if has_ids else [])]

In [164]:
def load_local_data(data_path = '../Datasets/data1', img_size=(28,28), img_type='L'):
    '''
    Load an image-folder dataset from disk into tensors.

    data folder tree:
        data1:
            class(0):
                image1
                image2
                ...
            class(1):
                image1
                image2
                ...
            ...
            class(n):
                ...

    Args:
        data_path: dataset root; every sub-directory is one class.
        img_size: target size as (H, W) ints.
        img_type:
            mode passed to PIL.Image.convert()
            options:
                'L': gray
                'RGB': 3-channel image

    Returns:
        features: tensor with all images stacked along a new first dimension,
            produced by transforms.ToTensor() (values scaled to [0, 1]).
        id_label: LongTensor with one class id per image (see label_map).
        cls_label: list with one class-folder name per image, in the same
            order as ``features``.
    '''
    # Sorted so the sample order is reproducible (os.listdir order is
    # arbitrary); plain files next to the class folders are ignored.
    cls_list = sorted(
        name for name in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, name))
    )

    # PIL's resize() expects (width, height) while img_size is documented as
    # HxW, so swap the two before resizing.
    target_h, target_w = img_size
    to_tensor = transforms.ToTensor()

    features = []
    cls_label = []
    for cls_name in cls_list:
        cls_dir = os.path.join(data_path, cls_name)
        for file_name in sorted(os.listdir(cls_dir)):
            # The context manager closes the image file as soon as the
            # converted copy has been made.
            with Image.open(os.path.join(cls_dir, file_name)) as raw:
                # LANCZOS is the highest-quality filter; it is the current name
                # of ANTIALIAS, which was removed in Pillow 10.
                img = raw.convert(img_type).resize((target_w, target_h), Image.LANCZOS)
            # ToTensor: PIL image -> float tensor in [0, 1] with layout CxHxW
            features.append(to_tensor(img))
            cls_label.append(cls_name)
    features = torch.stack(features, 0)

    # Resolve ids against the same folder that was just read (not label_map's
    # default path). Labels are int64 (LongTensor), as in the original code.
    id_label = torch.tensor(label_map(cls_labels=cls_label, data_path=data_path),
                            dtype=torch.long)

    return features, id_label, cls_label

In [165]:
# Read every image under ../Datasets/data1, resized to 28x28 (default mode 'L').
features, id_label, cls_label = load_local_data(data_path = '../Datasets/data1', img_size=(28,28))
# Wrap the image tensor and the integer class ids in the custom Dataset.
data = Datasets(features,id_label)

In [166]:
# Check the stacked layout: (number of images, channels, height, width).
features.shape

torch.Size([11, 1, 28, 28])

#### 打乱数据并分割训练集验证集

In [74]:
def train_test_split(feature,label,split_scale=0.7):
    """Randomly split features and labels into a training and a test part.

    Args:
        feature: array or tensor whose first dimension indexes the samples.
        label: array or tensor with one entry per sample, aligned with ``feature``.
        split_scale: fraction of the samples assigned to the training part;
            ``int(n_samples * split_scale)`` samples go to training and the
            rest to test.

    Returns:
        ``(train_feature, test_feature, train_label, test_label)``.

    Raises:
        ValueError: if ``feature`` and ``label`` differ in length.
    """
    # Use the argument itself. The previous code read the notebook-global
    # `features`, which breaks whenever another array is passed in.
    n_samples = len(feature)
    if len(label) != n_samples:
        raise ValueError(
            'feature and label must have the same length, got {} and {}'.format(
                n_samples, len(label))
        )

    # Shuffle the sample indices once and apply the same permutation to both
    # features and labels so the pairs stay aligned.
    index = np.arange(n_samples)
    np.random.shuffle(index)
    cut = int(n_samples * split_scale)
    train_index = index[:cut]
    test_index = index[cut:]

    train_feature = feature[train_index]
    test_feature = feature[test_index]
    train_label = label[train_index]
    test_label = label[test_index]
    return train_feature,test_feature,train_label,test_label

In [182]:
# split_scale=1 puts every sample into the training part; the test part is empty.
train_feature,test_feature,train_label,test_label = train_test_split(features,id_label,1)
train_data = Datasets(train_feature, train_label)
test_data = Datasets(test_feature, test_label)
batch_size = 256
train_iter = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
# The test loader is left disabled here; with split_scale=1 the test set has no samples.
#test_iter = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True)

In [185]:
# Inspect the test features produced by the split above.
test_feature

tensor([], size=(0, 1, 28, 28))

In [180]:
# Display the first 11 image tensors to eyeball the pixel values.
features[:11]

tensor([[[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]]],


        [[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]]],


        [[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..