## 数据处理技巧  

In [1]:
import numpy as np
import torch
from torchvision import transforms
from PIL import Image

### 图像数据的置换操作（permute）

In [11]:
## Build a random image in HWC (height, width, channel) layout
img_n = np.random.randint(low=0, high=255, size=(32, 32, 3))
## Move the channel axis to the front: HWC -> CHW
img_perm_bynp = img_n.transpose(2, 0, 1)
print(img_n.shape, img_perm_bynp.shape, sep="\n")

(32, 32, 3)
(3, 32, 32)


In [14]:
## Same axis reordering with torch: wrap the ndarray, then permute HWC -> CHW
img_t = torch.from_numpy(img_n)
img_perm_bytorch = img_t.permute(2, 0, 1)
print(img_t.shape, img_perm_bytorch.shape, sep="\n")

torch.Size([32, 32, 3])
torch.Size([3, 32, 32])


### 图像数据的重组操作（reshape）

In [15]:
## Flat buffer of 3072 random values (3072 = 32 * 32 * 3)
raw_data = np.random.randint(low=0, high=255, size=3072)
## numpy: reshape regroups the flat buffer into 32 x 32 x 3
img_n = raw_data.reshape(32, 32, 3)
print(img_n.shape)
## pytorch: view plays the role of reshape on a tensor
raw_data = torch.from_numpy(raw_data)
img_t = raw_data.view((32, 32, 3))
print(img_t.shape)

(32, 32, 3)
torch.Size([32, 32, 3])


### 图像数据的统一预处理（torchvision.transforms）

In [3]:
## Shared preprocessing pipeline; the steps run in the order listed.
transfer = transforms.Compose(
    [
        # Bring every image to a fixed 224 x 224 size.
        transforms.Resize(size=(224, 224)),
        # Mirror left-right with probability 0.5.
        transforms.RandomHorizontalFlip(p=0.5),
        # Turn the image into a tensor.
        transforms.ToTensor(),
        # Per-channel normalisation (values presumably the ImageNet statistics).
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

### 图像数据类型转换

In [None]:
## Cheat sheet of image type conversions, written as one chain that runs
## top to bottom: every step consumes the `img` produced by the step above.
img = np.random.randint(0, 255, (224, 224, 3))
## ndarray --> PIL Image (values are cast to uint8 first; layout is HWC)
img = Image.fromarray(np.uint8(img))
## PIL Image --> ndarray
## np.array copies the pixels, so the result is writable; np.asarray on a PIL
## image may return a read-only array, which torch.from_numpy warns about.
img = np.array(img)
## ndarray --> torch
## from_numpy is a function of the torch module, not a method of the array.
## This step comes before "torch --> ndarray" because img is an ndarray here.
img = torch.from_numpy(img)
## torch --> ndarray
img = img.numpy()
## torch --> PIL Image
## ToPILImage is a transform class: build an instance, then call it on the
## data. A tensor input must be laid out CHW, hence the permute from HWC.
img = transforms.ToPILImage()(torch.from_numpy(img).permute(2, 0, 1))
## PIL Image --> torch
## ToTensor is a class as well; the result is a CHW float tensor in [0, 1].
img = transforms.ToTensor()(img)
## opencv --> numpy
## OpenCV's Python API already hands images back as ndarrays, so np.asarray
## is all that is needed (it also accepts the CPU tensor left over from above).
img = np.asarray(img)
## numpy --> opencv
## There is no cv2.fromnumpy: cv2 functions take ndarrays directly and expect
## HWC layout, uint8 values and BGR channel order.
img = np.round(img.transpose(1, 2, 0) * 255).astype(np.uint8)  # CHW float -> HWC uint8
img = np.ascontiguousarray(img[..., ::-1])  # RGB -> BGR

### 统一批数据加载接口

In [None]:
from torch.utils.data import Dataset, DataLoader
from glob import glob
import os.path as osp

class MyData(Dataset):
    """Dataset over every ``*.jpg`` file found directly inside one directory.

    Each item is a ``(data, label)`` pair: the image loaded as RGB (and passed
    through ``transfer`` when one is given) plus an integer label parsed from
    the file name.
    """

    def __init__(self, data_dir, transfer=None):
        """Collect the image paths.

        Args:
            data_dir: Directory searched (non-recursively) for ``*.jpg`` files.
            transfer: Optional callable applied to each loaded PIL image.
        """
        # glob returns paths in arbitrary order; sorting keeps the
        # index -> sample mapping stable from run to run.
        self.data_paths = sorted(glob(osp.join(data_dir, '*.jpg')))
        self.transfer = transfer

    def __len__(self):
        """Return the number of images found."""
        return len(self.data_paths)

    def __getitem__(self, index):
        """Load sample ``index`` and return it as ``(data, label)``.

        Raises:
            IndexError: If ``index`` is out of range.
            ValueError: If the file name cannot be parsed as an integer label.
        """
        path = self.data_paths[index]
        # The context manager closes the file handle; convert() returns a
        # separate image that stays usable afterwards.
        with Image.open(path) as raw:
            data = raw.convert('RGB')
        # NOTE(review): the original called int() on an undefined name. This
        # assumes files are named "<integer label>.jpg" -- confirm the real
        # naming scheme and adapt the parsing if labels are encoded differently.
        label = int(osp.splitext(osp.basename(path))[0])
        if self.transfer is not None:
            # Use the transform given to this instance, not a module global.
            data = self.transfer(data)
        return (data, label)

def get_loader(batch_size=32, shuffle=True, data_dir="xxxx"):
    """Build a ``DataLoader`` over a ``MyData`` dataset.

    Args:
        batch_size: Number of samples per batch.
        shuffle: Whether the loader shuffles the samples.
        data_dir: Directory handed to ``MyData``. The default ``"xxxx"`` is
            the placeholder the original hard-coded; pass the real image
            directory.

    Returns:
        A ``DataLoader`` wrapping ``MyData(data_dir, transfer=transfer)``,
        where ``transfer`` is the module-level preprocessing pipeline.
    """
    dataset = MyData(data_dir, transfer=transfer)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)