In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
import os

In [2]:

class CustomImageDataset(Dataset):
    def __init__(self, csv_file, imd_dir, is_train=True, transform = None, label_map = None):
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.label_map = label_map
        self.imd_dir = imd_dir
        self.is_train = is_train

        if label_map is None and self.is_train:
            self._create_label_map()

    def _create_label_map(self):
        unique_labels = sorted(self.data.iloc[:,1].unique())
        # self.label_map = dict(zip(unique_labels, range(len(unique_labels))))
        self.label_map = {label:idx for idx,label in enumerate(unique_labels)}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_rel_path = self.data.iloc[idx,0]
        img_path = os.path.join(self.imd_dir, img_rel_path)
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        if self.is_train:
            label_text = self.data.iloc[idx,1]
            label_idx = self.label_map[label_text]
            return image, torch.tensor(label_idx, dtype=torch.long)
        else:
            return image


#### 数据预处理

In [3]:
train_transform = transforms.Compose([transforms.ToTensor()])
test_transform = transforms.Compose([transforms.ToTensor()])

In [4]:
train_csv = '../data/classify-leaves/train.csv'
test_csv = '../data/classify-leaves/test.csv'
img_root_dir = '../data/classify-leaves/images'

train_dataset = CustomImageDataset(train_csv, img_root_dir, is_train=True, transform=train_transform)
label_map = train_dataset.label_map
print(f"Using label mapping: {label_map}")

test_dataset = CustomImageDataset(test_csv, img_root_dir, is_train=False, transform=test_transform)

train_iter = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=10, pin_memory=True)
test_iter = DataLoader(test_dataset, batch_size=32, shuffle=True, num_workers=10, pin_memory=True)

# 验证数据加载
print("\nTesting data loaders:")
print(f"\nTesting data loaders:{len(train_iter.dataset)}")
# for images, labels in train_iter:
#     print(f"Train batch - Images shape: {images.shape}, Labels: {labels}")
#     break
#
# for images, labels in test_iter:
#     print(f"Test batch - Images shape: {images.shape}, Labels: {labels}")
#     break

Using label mapping: {'abies_concolor': 0, 'abies_nordmanniana': 1, 'acer_campestre': 2, 'acer_ginnala': 3, 'acer_griseum': 4, 'acer_negundo': 5, 'acer_palmatum': 6, 'acer_pensylvanicum': 7, 'acer_platanoides': 8, 'acer_pseudoplatanus': 9, 'acer_rubrum': 10, 'acer_saccharinum': 11, 'acer_saccharum': 12, 'aesculus_flava': 13, 'aesculus_glabra': 14, 'aesculus_hippocastamon': 15, 'aesculus_pavi': 16, 'ailanthus_altissima': 17, 'albizia_julibrissin': 18, 'amelanchier_arborea': 19, 'amelanchier_canadensis': 20, 'amelanchier_laevis': 21, 'asimina_triloba': 22, 'betula_alleghaniensis': 23, 'betula_jacqemontii': 24, 'betula_lenta': 25, 'betula_nigra': 26, 'betula_populifolia': 27, 'broussonettia_papyrifera': 28, 'carpinus_betulus': 29, 'carpinus_caroliniana': 30, 'carya_cordiformis': 31, 'carya_glabra': 32, 'carya_ovata': 33, 'carya_tomentosa': 34, 'castanea_dentata': 35, 'catalpa_bignonioides': 36, 'catalpa_speciosa': 37, 'cedrus_atlantica': 38, 'cedrus_deodara': 39, 'cedrus_libani': 40, 'cel