In [1]:
%matplotlib inline
import os
import pandas as pd
import torch
import torchvision
from d2l import torch as d2l

#@save
# Register the banana-detection archive in d2l's download registry as a
# (url, expected-hash) pair, keyed by the name that
# d2l.download_extract('banana-detection') is later called with.
d2l.DATA_HUB['banana-detection'] = (d2l.DATA_URL + 'banana-detection.zip',
                                    '5de26c8fce5ccdea9f91267273464dc968d20d72')

In [2]:
#@save
def read_data_bananas(is_train=True, edge_size=256):
    """Read the images and labels of the banana detection dataset.

    Args:
        is_train: If true, read the ``bananas_train`` split; otherwise read
            ``bananas_val``.
        edge_size: Divisor used to normalize the box coordinates. It defaults
            to 256, the value that used to be hard-coded here.

    Returns:
        A tuple ``(images, labels)``.
        ``images`` is a list of tensors as returned by
        ``torchvision.io.read_image``, one per CSV row.
        ``labels`` is a float tensor of shape ``(num_images, 1, num_cols)``
        holding each row's label columns. Column 0 (the class index) is left
        unscaled. The remaining coordinate columns are divided by
        ``edge_size``.
    """
    data_dir = d2l.download_extract('banana-detection')
    # Resolve the split directory once instead of re-evaluating the
    # train/val conditional for the CSV path and again for every image.
    split_dir = os.path.join(data_dir,
                             'bananas_train' if is_train else 'bananas_val')
    csv_data = pd.read_csv(os.path.join(split_dir, 'label.csv'))
    csv_data = csv_data.set_index('img_name')
    images, targets = [], []
    for img_name, target in csv_data.iterrows():
        images.append(torchvision.io.read_image(
            os.path.join(split_dir, 'images', str(img_name))))
        # Each target row holds (class, upper-left x, upper-left y,
        # lower-right x, lower-right y); every image has the same banana
        # class (index 0).
        targets.append(list(target))
    # unsqueeze(1) adds an axis of size 1 after the image axis, so each image
    # carries a single label row.
    labels = torch.tensor(targets, dtype=torch.get_default_dtype()).unsqueeze(1)
    # Normalize only the box coordinates; dividing the class index as well
    # (as the previous version did) would corrupt any non-zero class id.
    labels[..., 1:] /= edge_size
    return images, labels

In [6]:
#@save
class BananasDataset(torch.utils.data.Dataset):
    """A custom dataset for loading the banana detection dataset.

    Each item is an ``(image, label)`` pair. The image is converted with
    ``.float()`` (pixel values are not rescaled). The label is the matching
    entry of the label tensor produced by ``read_data_bananas``.
    """

    def __init__(self, is_train):
        """Load the training split if ``is_train`` is true, else the validation split."""
        self.features, self.labels = read_data_bananas(is_train)
        split_name = 'training' if is_train else 'validation'
        # A single f-string replaces str() concatenation and the previous
        # f-prefixed literals that had no placeholders; the message is unchanged.
        print(f'read {len(self.features)} {split_name} examples')

    def __getitem__(self, idx):
        """Return the ``idx``-th image as a float tensor together with its label."""
        return (self.features[idx].float(), self.labels[idx])

    def __len__(self):
        """Return the number of images in this split."""
        return len(self.features)

In [3]:
#@save
def load_data_bananas(batch_size):
    """Build the banana detection data loaders.

    Returns a ``(train_iter, val_iter)`` pair of DataLoaders with the given
    ``batch_size``. The training split is constructed first, and only the
    training loader shuffles its examples.
    """
    def make_loader(is_train):
        # Shuffle the training set only; validation order stays fixed.
        return torch.utils.data.DataLoader(
            BananasDataset(is_train=is_train), batch_size, shuffle=is_train)

    return make_loader(True), make_loader(False)

In [7]:
batch_size, edge_size = 32, 256  # edge_size: assumed side length of the square images
# Index the returned (train, val) pair instead of unpacking into `_`:
# assigning `_` shadows IPython's output-history variable and stops it
# from updating `_`/`__`/`___`.
train_iter = load_data_bananas(batch_size)[0]
batch = next(iter(train_iter))
# Shapes of the image batch and the label batch.
batch[0].shape, batch[1].shape

正在从http://d2l-data.s3-accelerate.amazonaws.com/banana-detection.zip下载..\data\banana-detection.zip...
read 1000 training examples
read 100 validation examples


(torch.Size([32, 3, 256, 256]), torch.Size([32, 1, 5]))

In [8]:
# Build a DataFrame from a list of row dicts: each dict becomes one row and
# its keys become the column names. (pandas is already imported in cell 1.)
lst = [{"clm0": 1, "clm1": 2, "clm2": 3},
       {"clm0": 4, "clm1": 5, "clm2": 6},
       {"clm0": 7, "clm1": 8, "clm2": 9}]
df = pd.DataFrame(lst)
# Last expression → rich (HTML) display instead of plain-text print().
df

   clm0  clm1  clm2
0     1     2     3
1     4     5     6
2     7     8     9


In [10]:
# Use a seeded, local generator so the demo tensor is reproducible on re-run
# without changing PyTorch's global RNG state.
a = torch.randn((3, 4), generator=torch.Generator().manual_seed(0))
# unsqueeze(1) inserts a new axis of size 1 at position 1: (3, 4) -> (3, 1, 4).
a, a.unsqueeze(1), a.shape, a.unsqueeze(1).shape

(tensor([[ 0.0188,  0.1810,  1.5003, -0.3060],
         [ 1.3744,  2.1947,  1.0230,  0.8207],
         [ 1.6068, -0.8535,  0.5080,  0.5537]]),
 tensor([[[ 0.0188,  0.1810,  1.5003, -0.3060]],
 
         [[ 1.3744,  2.1947,  1.0230,  0.8207]],
 
         [[ 1.6068, -0.8535,  0.5080,  0.5537]]]),
 torch.Size([3, 4]),
 torch.Size([3, 1, 4]))

In [11]:
# A 1-D tensor of five zeros, with the shape given explicitly as a tuple.
torch.zeros((5,))

tensor([0., 0., 0., 0., 0.])