In [86]:
import os
import torch
import argparse
import numpy as np
import pandas as pd
from torch import nn
from torchvision import transforms as T

In [88]:
import os
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms as T


class CUBDataset(Dataset):
    """Image-classification dataset backed by a CUB_200_2011 directory.

    The directory is expected to contain ``classes.txt``, ``images.txt``,
    ``train_test_split.txt``, ``image_class_labels.txt`` and an ``images/``
    folder. Each annotation file is space-separated with an id in its first
    column; the three per-image files are joined on that id.

    Attributes:
        dataset_dir: Path of the ``CUB_200_2011`` directory that is read.
        id2label: Class names in the order they appear in ``classes.txt``.
        ann: DataFrame with columns ``filename`` and ``label`` (label shifted
            down by one relative to the file) for the chosen split, re-indexed
            from 0.
        transforms: Callable applied to every loaded PIL image.
    """

    def __init__(self, dataset_dir: str, split: str='train', transforms=None) -> None:
        """Load the annotations of one split.

        Args:
            dataset_dir: Either the ``CUB_200_2011`` directory itself or the
                directory that contains it.
            split: ``'train'`` selects rows whose ``is_train`` flag is 1; any
                other value selects rows whose flag is 0.
            transforms: Callable applied to each PIL image. When falsy,
                ``torchvision.transforms.ToTensor()`` is used.
        """
        super().__init__()
        # normpath drops a trailing separator, so 'x/CUB_200_2011/' is also
        # recognised as the dataset root instead of being joined a second time.
        if os.path.normpath(dataset_dir).endswith('CUB_200_2011'):
            self.dataset_dir = dataset_dir
        else:
            self.dataset_dir = os.path.join(dataset_dir, 'CUB_200_2011')

        with open(os.path.join(self.dataset_dir, 'classes.txt')) as fp:
            # Each non-empty line is '<id> <name>'; keep only the name,
            # e.g. ['001.Black_footed_Albatross', '002.Laysan_Albatross', ...]
            self.id2label = [
                line.split(' ')[1]
                for line in (raw.strip() for raw in fp)
                if line
            ]

        filename_df = self._read_table('images.txt', 'filename')
        split_df = self._read_table('train_test_split.txt', 'is_train')
        label_df = self._read_table('image_class_labels.txt', 'label')

        joined = filename_df.join(split_df).join(label_df)
        # Shift the labels down by one so they can index into id2label.
        joined['label'] = joined['label'] - 1
        wanted_flag = 1 if split == 'train' else 0
        split_idxs = joined['is_train'] == wanted_flag
        self.ann = joined[split_idxs].drop(columns=['is_train']).reset_index(drop=True)

        # T.ToTensor() is only constructed when no transform was supplied.
        self.transforms = transforms or T.ToTensor()

    def _read_table(self, filename: str, column: str) -> pd.DataFrame:
        """Read one space-separated annotation file, indexed by its first column."""
        return pd.read_csv(os.path.join(self.dataset_dir, filename),
                           delimiter=' ', index_col=0, names=[column])

    @property
    def num_labels(self) -> int:
        """Number of class names loaded from ``classes.txt``."""
        return len(self.id2label)

    @property
    def num_lables(self) -> int:
        """Misspelled alias of :attr:`num_labels`, kept for existing callers."""
        return self.num_labels

    def __len__(self) -> int:
        """Number of images in the selected split."""
        return len(self.ann)

    def __getitem__(self, idx):
        """Return ``(transformed image, label tensor)`` for row ``idx`` of ``ann``."""
        fn, label = self.ann.iloc[idx]
        # Image.open is lazy and keeps the file open; the context manager
        # closes it once convert() has decoded the pixels.
        with Image.open(os.path.join(self.dataset_dir, 'images', fn)) as img:
            # Force three channels so images stored in another mode
            # (e.g. grayscale) produce tensors of the same channel count.
            img = img.convert('RGB')
        img = self.transforms(img)
        return img, torch.tensor(label)

In [89]:
# Build the training split; CUBDataset appends 'CUB_200_2011' to 'data' to locate the files.
cub = CUBDataset('data', 'train')

In [90]:
# Fetch the first training sample: an (image, label tensor) pair from __getitem__.
cub[0]

(tensor([[[0.0000, 0.0000, 0.0157,  ..., 0.0157, 0.0039, 0.0078],
          [0.0039, 0.0039, 0.0196,  ..., 0.0000, 0.0000, 0.0000],
          [0.0196, 0.0000, 0.2941,  ..., 0.2941, 0.0039, 0.0078],
          ...,
          [0.0157, 0.0000, 0.2980,  ..., 0.2902, 0.0000, 0.0000],
          [0.0000, 0.0078, 0.0039,  ..., 0.0000, 0.0118, 0.0000],
          [0.0000, 0.0039, 0.0078,  ..., 0.0078, 0.0039, 0.0000]],
 
         [[0.0000, 0.0000, 0.0157,  ..., 0.0157, 0.0039, 0.0078],
          [0.0039, 0.0039, 0.0196,  ..., 0.0000, 0.0000, 0.0000],
          [0.0196, 0.0000, 0.2941,  ..., 0.2941, 0.0039, 0.0078],
          ...,
          [0.0157, 0.0000, 0.2980,  ..., 0.2902, 0.0000, 0.0000],
          [0.0000, 0.0039, 0.0000,  ..., 0.0000, 0.0118, 0.0000],
          [0.0000, 0.0039, 0.0078,  ..., 0.0078, 0.0039, 0.0000]],
 
         [[0.0000, 0.0000, 0.0157,  ..., 0.0157, 0.0039, 0.0078],
          [0.0039, 0.0039, 0.0196,  ..., 0.0000, 0.0000, 0.0000],
          [0.0118, 0.0000, 0.2863,  ...,

In [82]:
# Query the dataset's label count through its `num_lables` property.
cub.num_lables

200

In [62]:
# Load the three space-separated annotation files; the first column of each
# file becomes the index, and the remaining column gets the given name.
images = pd.read_csv(
    'data/CUB_200_2011/images.txt',
    names=['filename'],
    index_col=0,
    delimiter=' ',
)
train_test_split = pd.read_csv(
    'data/CUB_200_2011/train_test_split.txt',
    names=['is_train'],
    index_col=0,
    delimiter=' ',
)
labels = pd.read_csv(
    'data/CUB_200_2011/image_class_labels.txt',
    names=['label'],
    index_col=0,
    delimiter=' ',
)

In [68]:
# Combine filename, split flag and label into one frame on the shared index.
joined = images.join(train_test_split).join(labels)
# Shift the labels down by one (the output below shows them starting at 0).
joined['label'] = joined['label'].sub(1)

# Rows flagged is_train == 1 form the training set, rows flagged 0 the test set;
# the flag column is dropped and each subset is re-indexed from 0.
train_split = (
    joined.loc[joined['is_train'] == 1]
    .drop(columns=['is_train'])
    .reset_index(drop=True)
)
test_split = (
    joined.loc[joined['is_train'] == 0]
    .drop(columns=['is_train'])
    .reset_index(drop=True)
)
train_split

Unnamed: 0,filename,label
0,001.Black_footed_Albatross/Black_Footed_Albatr...,0
1,001.Black_footed_Albatross/Black_Footed_Albatr...,0
2,001.Black_footed_Albatross/Black_Footed_Albatr...,0
3,001.Black_footed_Albatross/Black_Footed_Albatr...,0
4,001.Black_footed_Albatross/Black_Footed_Albatr...,0
...,...,...
5989,200.Common_Yellowthroat/Common_Yellowthroat_00...,199
5990,200.Common_Yellowthroat/Common_Yellowthroat_00...,199
5991,200.Common_Yellowthroat/Common_Yellowthroat_00...,199
5992,200.Common_Yellowthroat/Common_Yellowthroat_00...,199


In [69]:
# Unpack the first row of test_split into its two values and display them.
fn, label = test_split.iloc[0]
fn, label

('001.Black_footed_Albatross/Black_Footed_Albatross_0046_18.jpg', 0)

In [48]:
# Distinct directory prefixes (the part of each filename before the first '/'),
# in order of first appearance, as a plain Python list.
labels = (
    joined['filename']
    .str.split('/')
    .str.get(0)
    .unique()
    .tolist()
)
labels

['001.Black_footed_Albatross',
 '002.Laysan_Albatross',
 '003.Sooty_Albatross',
 '004.Groove_billed_Ani',
 '005.Crested_Auklet',
 '006.Least_Auklet',
 '007.Parakeet_Auklet',
 '008.Rhinoceros_Auklet',
 '009.Brewer_Blackbird',
 '010.Red_winged_Blackbird',
 '011.Rusty_Blackbird',
 '012.Yellow_headed_Blackbird',
 '013.Bobolink',
 '014.Indigo_Bunting',
 '015.Lazuli_Bunting',
 '016.Painted_Bunting',
 '017.Cardinal',
 '018.Spotted_Catbird',
 '019.Gray_Catbird',
 '020.Yellow_breasted_Chat',
 '021.Eastern_Towhee',
 '022.Chuck_will_Widow',
 '023.Brandt_Cormorant',
 '024.Red_faced_Cormorant',
 '025.Pelagic_Cormorant',
 '026.Bronzed_Cowbird',
 '027.Shiny_Cowbird',
 '028.Brown_Creeper',
 '029.American_Crow',
 '030.Fish_Crow',
 '031.Black_billed_Cuckoo',
 '032.Mangrove_Cuckoo',
 '033.Yellow_billed_Cuckoo',
 '034.Gray_crowned_Rosy_Finch',
 '035.Purple_Finch',
 '036.Northern_Flicker',
 '037.Acadian_Flycatcher',
 '038.Great_Crested_Flycatcher',
 '039.Least_Flycatcher',
 '040.Olive_sided_Flycatcher',
 '

In [36]:
# Show only the filename and label columns. The label column is named 'label'
# where `joined` is built, so selecting 'labels' raises KeyError on a fresh run.
# Selecting with a list of columns already yields a DataFrame.
joined[['filename', 'label']]

Unnamed: 0,filename,labels
1,001.Black_footed_Albatross/Black_Footed_Albatr...,0
2,001.Black_footed_Albatross/Black_Footed_Albatr...,0
3,001.Black_footed_Albatross/Black_Footed_Albatr...,0
4,001.Black_footed_Albatross/Black_Footed_Albatr...,0
5,001.Black_footed_Albatross/Black_Footed_Albatr...,0
...,...,...
11784,200.Common_Yellowthroat/Common_Yellowthroat_00...,199
11785,200.Common_Yellowthroat/Common_Yellowthroat_00...,199
11786,200.Common_Yellowthroat/Common_Yellowthroat_00...,199
11787,200.Common_Yellowthroat/Common_Yellowthroat_00...,199


In [26]:
# Distinct values of the filename column, in order of first appearance.
labels = joined.loc[:, 'filename'].unique()
labels

array(['001.Black_footed_Albatross/Black_Footed_Albatross_0046_18.jpg',
       '001.Black_footed_Albatross/Black_Footed_Albatross_0009_34.jpg',
       '001.Black_footed_Albatross/Black_Footed_Albatross_0002_55.jpg',
       ..., '200.Common_Yellowthroat/Common_Yellowthroat_0008_190703.jpg',
       '200.Common_Yellowthroat/Common_Yellowthroat_0049_190708.jpg',
       '200.Common_Yellowthroat/Common_Yellowthroat_0055_190967.jpg'],
      dtype=object)

In [24]:
# Keep the rows whose is_train flag is 0, retaining every column and the original index.
test_split = joined.loc[joined['is_train'].eq(0)]
test_split

Unnamed: 0,filename,is_train,labels
1,001.Black_footed_Albatross/Black_Footed_Albatr...,0,1
3,001.Black_footed_Albatross/Black_Footed_Albatr...,0,1
6,001.Black_footed_Albatross/Black_Footed_Albatr...,0,1
10,001.Black_footed_Albatross/Black_Footed_Albatr...,0,1
12,001.Black_footed_Albatross/Black_Footed_Albatr...,0,1
...,...,...,...
11780,200.Common_Yellowthroat/Common_Yellowthroat_00...,0,200
11783,200.Common_Yellowthroat/Common_Yellowthroat_00...,0,200
11785,200.Common_Yellowthroat/Common_Yellowthroat_00...,0,200
11786,200.Common_Yellowthroat/Common_Yellowthroat_00...,0,200


In [20]:
# Display the filename table read from images.txt.
images

Unnamed: 0,filename
1,001.Black_footed_Albatross/Black_Footed_Albatr...
2,001.Black_footed_Albatross/Black_Footed_Albatr...
3,001.Black_footed_Albatross/Black_Footed_Albatr...
4,001.Black_footed_Albatross/Black_Footed_Albatr...
5,001.Black_footed_Albatross/Black_Footed_Albatr...
...,...
11784,200.Common_Yellowthroat/Common_Yellowthroat_00...
11785,200.Common_Yellowthroat/Common_Yellowthroat_00...
11786,200.Common_Yellowthroat/Common_Yellowthroat_00...
11787,200.Common_Yellowthroat/Common_Yellowthroat_00...


In [19]:
# Display `labels`.
# NOTE(review): this name is bound to a DataFrame, a list and an ndarray in
# different cells of this notebook; the output below shows the DataFrame form.
labels

Unnamed: 0,labels
1,1
2,1
3,1
4,1
5,1
...,...
11784,200
11785,200
11786,200
11787,200


In [21]:
# Display the per-image is_train flag table read from train_test_split.txt.
train_test_split

Unnamed: 0,is_train
1,0
2,1
3,0
4,1
5,1
...,...
11784,1
11785,0
11786,0
11787,1
