In [57]:
import os
import torch
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T


class CUBDataset(Dataset):
    """Image-classification dataset backed by a ``CUB_200_2011`` directory.

    All metadata is read from ``<dataset_dir>/CUB_200_2011``:

    * ``images.txt``             -- ``<img_id> <file_path>``
    * ``image_class_labels.txt`` -- ``<img_id> <class_id>`` (1-indexed on disk)
    * ``train_test_split.txt``   -- ``<img_id> <is_train>`` (1 = train, 0 = test)
    * ``classes.txt``            -- ``<class_id> <class_name>`` (1-indexed on disk)

    Class ids are shifted to be 0-indexed, both in the annotation tables and
    in ``class_id2name``.

    Args:
        dataset_dir: Directory that contains the ``CUB_200_2011`` folder.
        split: Key into ``self.annotations``; ``'train'`` or ``'test'``.
        transforms: Optional callable applied to the loaded PIL image.

    Attributes:
        annotations: ``{'train': DataFrame, 'test': DataFrame}``, each with the
            columns ``img_id``, ``file_path``, ``class_id`` and a fresh
            0..n-1 index.
        class_id2name: Maps 0-indexed class id to the class name string.
    """

    def __init__(self, dataset_dir, split='train', transforms=None) -> None:
        super().__init__()
        self.split = split
        self.transforms = transforms
        self.dataset_dir = dataset_dir
        root = os.path.join(dataset_dir, 'CUB_200_2011')

        file_path_df = self._read_table(root, 'images.txt', 'file_path')
        img_class_df = self._read_table(root, 'image_class_labels.txt', 'class_id')
        train_test_split_df = self._read_table(root, 'train_test_split.txt', 'is_train')
        merged_df = file_path_df.merge(img_class_df, on='img_id').merge(train_test_split_df, on='img_id')

        # Make class_id 0-indexed
        merged_df['class_id'] = merged_df['class_id'] - 1

        # Resolve classes.txt relative to dataset_dir (not a fixed 'datasets/'
        # path) and close the file deterministically.
        self.class_id2name = {}
        with open(os.path.join(root, 'classes.txt')) as classes_file:
            for line in classes_file:
                line = line.strip()
                if not line:
                    # Tolerate blank lines instead of failing on the unpack below.
                    continue
                class_id, class_name = line.split(' ', 1)
                self.class_id2name[int(class_id) - 1] = class_name

        train_df = merged_df[merged_df['is_train'] == 1].drop(columns=['is_train']).reset_index(drop=True)
        test_df = merged_df[merged_df['is_train'] == 0].drop(columns=['is_train']).reset_index(drop=True)

        self.annotations = {'train': train_df, 'test': test_df}

    @staticmethod
    def _read_table(root, file_name, value_column):
        """Read one space-separated, header-less ``<img_id> <value>`` file."""
        return pd.read_csv(os.path.join(root, file_name), sep=' ',
                           header=None, names=['img_id', value_column])

    def __len__(self):
        """Return the number of samples in the currently selected split."""
        return len(self.annotations[self.split])

    def __getitem__(self, idx):
        """Load the sample at position ``idx`` of the current split.

        Returns:
            ``(img_id, image, class_id)`` where ``img_id`` and ``class_id`` are
            0-dim tensors and ``image`` is the RGB PIL image, or whatever
            ``self.transforms`` returns for it.
        """
        img_id, file_path, class_id = self.annotations[self.split].iloc[idx]
        img_path = os.path.join(self.dataset_dir, 'CUB_200_2011', 'images', file_path)
        # Image.open is lazy and keeps the file open; convert() forces the read
        # and returns a detached copy, so the handle can be closed right away.
        # Forcing RGB also gives every sample three channels, so grayscale or
        # palette files do not break channel-wise transforms or batching.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transforms is not None:
            image = self.transforms(image)
        return torch.tensor(img_id), image, torch.tensor(class_id)

In [58]:
# Training pipeline: resize to 256x256, then a random 224x224 crop and a
# random horizontal flip, then tensor conversion and per-channel normalization.
# NOTE(review): mean/std look like the usual ImageNet statistics -- confirm
# they match the backbone this feeds.
augs_train = T.Compose(
    [
        T.Resize(size=(256, 256), interpolation=Image.BILINEAR),
        T.RandomCrop(size=(224, 224)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Evaluation pipeline: same resize and normalization, but a fixed center crop
# and no flip.
augs_test = T.Compose(
    [
        T.Resize(size=(256, 256), interpolation=Image.BILINEAR),
        T.CenterCrop(size=(224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [63]:
# Training split with the training augmentations, served in shuffled batches of 64.
dataset_train = CUBDataset(dataset_dir='datasets', split='train', transforms=augs_train)
dataloader_train = DataLoader(dataset=dataset_train, batch_size=64, shuffle=True)

In [65]:
# Pull a single batch from the loader as an end-to-end check of the pipeline.
first_batch = next(iter(dataloader_train))
img_ids, batch_imgs, batch_labels = first_batch

In [66]:
# Shape of the image batch drawn above.
batch_imgs.shape

torch.Size([64, 3, 224, 224])

In [31]:
# Load the three space-separated, header-less metadata tables; each one is
# keyed by img_id.
file_path_df = pd.read_csv(
    'datasets/CUB_200_2011/images.txt',
    sep=' ',
    header=None,
    names=['img_id', 'file_path'],
)
img_class_df = pd.read_csv(
    'datasets/CUB_200_2011/image_class_labels.txt',
    sep=' ',
    header=None,
    names=['img_id', 'class_id'],
)
train_test_split_df = pd.read_csv(
    'datasets/CUB_200_2011/train_test_split.txt',
    sep=' ',
    header=None,
    names=['img_id', 'is_train'],
)
train_test_split_df

Unnamed: 0,img_id,is_train
0,1,0
1,2,1
2,3,0
3,4,1
4,5,1
...,...,...
11783,11784,1
11784,11785,0
11785,11786,0
11786,11787,1


In [42]:
# Join the three tables on img_id and shift class ids so they start at 0.
merged_df = (
    file_path_df
    .merge(img_class_df, on='img_id')
    .merge(train_test_split_df, on='img_id')
    .assign(class_id=lambda frame: frame['class_id'] - 1)
)

# Split on the is_train flag, drop the flag, and renumber each part from 0.
train_df = (
    merged_df.loc[merged_df['is_train'].eq(1)]
    .drop(columns=['is_train'])
    .reset_index(drop=True)
)
test_df = (
    merged_df.loc[merged_df['is_train'].eq(0)]
    .drop(columns=['is_train'])
    .reset_index(drop=True)
)
test_df

Unnamed: 0,img_id,file_path,class_id
0,1,001.Black_footed_Albatross/Black_Footed_Albatr...,0
1,3,001.Black_footed_Albatross/Black_Footed_Albatr...,0
2,6,001.Black_footed_Albatross/Black_Footed_Albatr...,0
3,10,001.Black_footed_Albatross/Black_Footed_Albatr...,0
4,12,001.Black_footed_Albatross/Black_Footed_Albatr...,0
...,...,...,...
5789,11780,200.Common_Yellowthroat/Common_Yellowthroat_00...,199
5790,11783,200.Common_Yellowthroat/Common_Yellowthroat_00...,199
5791,11785,200.Common_Yellowthroat/Common_Yellowthroat_00...,199
5792,11786,200.Common_Yellowthroat/Common_Yellowthroat_00...,199


In [48]:
# Number of rows in the training split.
train_df.shape[0]

5994

In [45]:
# Unpack one training row (position 100) into its three columns.
sample_row = train_df.iloc[100]
img_id, file_path, class_id = sample_row

In [46]:
# Display the unpacked fields of the sampled row.
(img_id, file_path, class_id)

(202, '004.Groove_billed_Ani/Groove_Billed_Ani_0051_1650.jpg', 3)