In [132]:
import torch 


In [133]:
from torch.utils.data import Dataset

In [141]:
class MyDataset(Dataset):
    """Map-style dataset pairing each row of ``data_tensor`` with ``target_tensor``.

    Sample ``i`` is the tuple ``(data_tensor[i], target_tensor[i])``.

    Args:
        data_tensor: Tensor whose first dimension indexes samples.
        target_tensor: Tensor of targets; its first dimension must have the
            same size as ``data_tensor``'s.

    Raises:
        ValueError: If the two tensors differ in size along dimension 0.
    """

    def __init__(self, data_tensor, target_tensor):
        # Fail fast on mismatched lengths: __len__ only looks at data_tensor, so a
        # shorter target_tensor would otherwise surface later as an IndexError
        # (or a longer one would silently have its extra rows ignored).
        if data_tensor.size(0) != target_tensor.size(0):
            raise ValueError(
                f"Size mismatch between tensors: data_tensor has "
                f"{data_tensor.size(0)} samples, target_tensor has "
                f"{target_tensor.size(0)}"
            )
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    def __len__(self):
        """Return the number of samples (size of dimension 0 of ``data_tensor``)."""
        return self.data_tensor.size(0)

    def __getitem__(self, index):
        """Return the ``(data, target)`` pair at position ``index``."""
        return self.data_tensor[index], self.target_tensor[index]

In [145]:
# Fix the global torch seed so the sample printed below (and, on a clean
# Restart & Run All, the shuffled batches) are reproducible.
torch.manual_seed(0)
data_tensor = torch.randn(10, 3)          # 10 samples, 3 values each
target_tensor = torch.randint(2, (10,))   # 10 integer targets drawn from {0, 1}

In [148]:
# Wrap the raw tensors in the custom Dataset, then inspect its size and one sample.
my_dataset = MyDataset(data_tensor, target_tensor)
dataset_size = len(my_dataset)
first_sample = my_dataset[0]  # (data, target) pair returned by __getitem__
print('dataset size :', dataset_size)
print('tensor_data[0]:', first_sample)

dataset size : 10
tensor_data[0]: (tensor([ 0.0957,  0.9271, -0.0389]), tensor(1))


# DataLoader 类
在实际项目中，如果数据量很大，考虑到内存有限、I/O 速度等问题，在训练过程中不可能一次性地将所有数据全部加载到内存中，也不能只用一个进程去加载，所以就需要多进程、迭代加载，而 DataLoader 就是基于这些需要被设计出来的。DataLoader 是一个可迭代对象（iterable），最基本的使用方法就是传入一个 Dataset 对象，每次迭代时它会根据参数 batch_size 的值返回一个 batch 的数据；在节省内存的同时，它还可以实现多进程加载（num_workers）、数据打乱（shuffle）等处理。

In [149]:
from torch.utils.data import DataLoader
# A dedicated, seeded generator makes the shuffle order reproducible on re-run,
# independent of whatever other cells have consumed from the global torch RNG.
shuffle_generator = torch.Generator().manual_seed(0)
tensor_dataloader = DataLoader(dataset=my_dataset,
                               batch_size=2,
                               shuffle=True,
                               num_workers=0,  # load in the main process
                               generator=shuffle_generator)

# Each iteration yields one batch: up to batch_size samples collated along
# dim 0, plus the matching targets.
for data, target in tensor_dataloader:
    print(data, target)

tensor([[-0.9752, -1.4995,  0.1578],
        [ 0.0957,  0.9271, -0.0389]]) tensor([0, 1])
tensor([[ 2.3736,  0.3833,  2.3885],
        [-0.2964,  1.1723, -0.4042]]) tensor([1, 0])
tensor([[ 0.9604, -1.1306, -0.7801],
        [-0.9741,  0.1602, -1.8740]]) tensor([1, 1])
tensor([[ 1.5239,  1.5347, -0.1928],
        [ 1.1754,  2.2938,  0.4071]]) tensor([0, 1])
tensor([[ 1.2060, -0.5984,  0.1999],
        [-1.1016,  0.9386, -0.4546]]) tensor([1, 0])


# 利用 Torchvision 读取数据

`torchvision.datasets` 这个包并不包含数据集文件本身。它的工作方式是：（在 `download=True` 时）先从网络上把数据集下载到用户指定的目录（`root`），然后再用它的加载器把数据集加载到内存中，最后把加载后的数据集作为对象返回给用户。如果该目录中已经存在完整的数据文件，则不会重复下载。

In [2]:
# Take MNIST as an example.
import torchvision
# download=True is needed for this cell to succeed on a fresh machine: with
# download=False, torchvision raises a "Dataset not found" error when the files
# under ./data are missing. Files that are already present are not downloaded again.
mnist_dataset = torchvision.datasets.MNIST(root='./data',
                                           train=True,
                                           transform=None,
                                           target_transform=None,
                                           download=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100.0%


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100.0%


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100.0%


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100.0%

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw




