In [15]:
import torch
import torch.utils.data as Data
import numpy as np
import torchvision
import torchvision.transforms as T
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

In [16]:
class MyDatasets(Data.Dataset):
    """Image-classification dataset backed by a plain-text list file.

    Every non-blank line of the list file holds an image path followed by an
    integer class label, separated by whitespace:

        1.jpg 0
        2.jpg 1
        3.jpg 1
    """

    def __init__(self, filepath, transforms=None):
        '''
        Parse the list file into parallel lists of image paths and labels.

        params filepath: path of the list file (format shown on the class)
        params transforms: optional callable applied to each PIL image in
                __getitem__, e.g.
                T.Compose([
                    T.RandomCrop(224),
                    T.RandomHorizontalFlip(),
                    T.Resize(256),
                    T.ToTensor(),
                    T.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
                ])
        raises ValueError: if a non-blank line has no label column or the
                label is not an integer
        '''
        img_list, img_label = [], []
        with open(filepath, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                fields = line.split()
                if not fields:
                    # Blank / trailing lines are common in hand-written lists.
                    continue
                if len(fields) < 2:
                    raise ValueError(
                        "{}:{}: expected '<image path> <label>', got {!r}".format(
                            filepath, line_no, line.strip()))
                img_list.append(fields[0])
                # Labels are stored as ints (not the raw strings) so that the
                # default DataLoader collation produces a tensor of labels.
                img_label.append(int(fields[1]))
        self.imgs_ = img_list
        self.label_ = img_label
        self.transforms_ = transforms
        self.num_ = len(img_list)

    def __getitem__(self, item):
        """Return (image, label) for index `item`.

        The image is decoded as RGB and, when a transform was given, passed
        through it (this is where conversion to a tensor etc. happens).
        """
        label = self.label_[item]
        # Context manager releases the underlying file handle right away;
        # convert() returns an independent, fully loaded image.
        with Image.open(self.imgs_[item]) as raw:
            im = raw.convert('RGB')
        if self.transforms_ is not None:
            im = self.transforms_(im)
        return im, label

    def __len__(self):
        """Return the number of samples listed in the file."""
        return self.num_
        

In [17]:
import torch.utils.data as Data
! tree datasets/

datasets/
├── test
│   ├── 0
│   │   ├── 0.jpg
│   │   └── 1.jpg
│   └── 1
│       ├── 0.jpg
│       └── 1.jpg
├── test.list
├── train
│   ├── 0
│   │   ├── 0.jpg
│   │   └── 1.jpg
│   └── 1
│       ├── 0.jpg
│       └── 1.jpg
├── train.list
├── valid
│   ├── 0
│   │   ├── 0.jpg
│   │   └── 1.jpg
│   └── 1
│       ├── 0.jpg
│       └── 1.jpg
└── valid.list

9 directories, 15 files


In [18]:
# Per-channel (R, G, B) statistics for T.Normalize. The original tuples were
# written as (0,5,0.5,0.5): a comma-for-dot typo that yields four entries, one
# of them a zero std, which cannot be applied to a 3-channel image tensor.
NORM_MEAN = (0.5, 0.5, 0.5)
NORM_STD = (0.5, 0.5, 0.5)

# Training pipeline: random crop + random horizontal flip as augmentation.
# NOTE(review): crop-to-224 followed by resize-to-256 is kept from the
# original; the more common order is Resize -> Crop. Confirm which is intended.
train_normalized = T.Compose([
    T.RandomCrop(224),
    T.RandomHorizontalFlip(),
    T.Resize(256),
    T.ToTensor(),
    T.Normalize(NORM_MEAN, NORM_STD),
])

# Evaluation pipeline: deterministic (center crop, no flip), same output size
# and normalization as the training pipeline.
test_normalized = T.Compose([
    T.CenterCrop(224),
    T.Resize(256),
    T.ToTensor(),
    T.Normalize(NORM_MEAN, NORM_STD),
])

train_dataset = MyDatasets('./datasets/train.list', train_normalized)
# Validation uses the deterministic pipeline: random augmentation here would
# make validation metrics vary from run to run for the same model.
valid_dataset = MyDatasets('./datasets/valid.list', test_normalized)
test_dataset = MyDatasets('./datasets/test.list', test_normalized)

In [19]:
# Training batches: shuffled, large batch size, two worker processes.
train_loader = Data.DataLoader(
    dataset=train_dataset,
    batch_size=4096,
    shuffle=True,
    num_workers=2,
)

# Validation batches: one sample at a time, in list-file order.
valid_loader = Data.DataLoader(
    dataset=valid_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
)

# Test batches: same settings as validation.
test_loader = Data.DataLoader(
    dataset=test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
)

In [20]:
#其他关于样本扩增的操作，可以放在init中进行。或者在准备样本的时候进行扩增。

In [22]:
### cal mean and std
train_txt_path = './datasets/train.list'
def cal_meanStd(train_txt_path, CNum=2000, img_h=32, img_w=32):
    """Estimate per-channel mean and std from a random sample of listed images.

    Reads image paths from the first whitespace-separated field of each
    non-blank line of `train_txt_path`, loads up to `CNum` of them at random,
    resizes each to `img_h` x `img_w`, scales pixel values to [0, 1] and
    computes the mean and standard deviation of every colour channel.

    params train_txt_path: list file, one "<image path> <label>" per line
    params CNum: maximum number of images to sample (fewer are used when the
            list is shorter)
    params img_h, img_w: size every image is resized to before pooling pixels
    returns: (means, stdevs), two lists of three floats in RGB order; the same
            values are also printed
    raises ValueError: if the list contains no paths or no image can be read
    """
    # Local import: `random` is not imported in the notebook's import cell.
    import random

    with open(train_txt_path, 'r') as f:
        img_paths = [line.split()[0] for line in f if line.strip()]
    if not img_paths:
        raise ValueError("no image paths found in {!r}".format(train_txt_path))

    # Imported only once there is something to decode, so the validation above
    # works without OpenCV; the notebook's import cell does not import cv2.
    import cv2

    # Shuffle, then take the first CNum paths: a random subset that is capped
    # at the list length instead of indexing past the end of a short list.
    random.shuffle(img_paths)
    chosen = img_paths[:max(CNum, 0)]

    resized = []
    for img_path in chosen:
        img = cv2.imread(img_path)  # channels come back in BGR order
        if img is None:
            # imread signals an unreadable/missing file with None.
            print("skip unreadable image: {}".format(img_path))
            continue
        resized.append(cv2.resize(img, (img_w, img_h)))  # dsize is (width, height)
    if not resized:
        raise ValueError("no images could be loaded from {!r}".format(train_txt_path))

    # Stack once (no all-zero seed image, no repeated concatenation), scale to
    # [0, 1] and flatten to one row per pixel with the 3 channels as columns.
    pixels = np.stack(resized).astype(np.float32).reshape(-1, 3) / 255.

    # Accumulate in float64 for accuracy; [::-1] turns BGR into RGB.
    means = pixels.mean(axis=0, dtype=np.float64)[::-1].tolist()
    stdevs = pixels.std(axis=0, dtype=np.float64)[::-1].tolist()

    print("normMean = {}".format(means))
    print("normStd = {}".format(stdevs))
    print('transforms.Normalize(normMean = {}, normStd = {})'.format(means, stdevs))
    return means, stdevs
