In [146]:
import torch
import math
import random
import os
import numpy as np
import PIL
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader

In [2]:
from torchvision.datasets import ImageFolder

In [118]:
# Dataset root directory; the ImageFolder cell below reads `root + '/images'`.
root = './data/CUB_200_2011'
# Alternative dataset root, currently disabled:
#root = 'Stanford_Online_Products'

In [122]:
# Index every image found under <root>/images; the cell displays the image count.
dataset = ImageFolder(f'{root}/images')
len(dataset)

11788

In [104]:
# Class index -> list of image paths for that class (filled in by the next cell).
cls_to_paths = defaultdict(list)

In [105]:
# Rebuild the mapping from scratch so that re-running this cell on its own
# cannot append every image path a second time.
cls_to_paths = defaultdict(list)
for img, label in dataset.imgs:
    # Each entry of `dataset.imgs` is unpacked as (image path, class index).
    cls_to_paths[label].append(img)

In [106]:
# Mean number of images per class. Iterating over the stored values (instead of
# indexing with 0..n-1) does not assume contiguous integer keys and cannot make
# the defaultdict insert empty entries as a side effect of a lookup.
sum(len(paths) for paths in cls_to_paths.values()) / len(cls_to_paths)

58.94

In [108]:
# Take the first class in the mapping as a small sample for the split
# experiment in the next cell, and show its label and image count.
for label in cls_to_paths:
    paths = cls_to_paths[label]
    test = paths
    print(label, len(test))
    break

0 60


In [110]:
labeled_fraction = 0.05
# Round up so a class always contributes at least one labeled image
# (for the 60-image sample class: ceil(0.05 * 60) == 3).
num_labeled = math.ceil(labeled_fraction * len(test))

# Shuffle a *copy* with a fixed seed. `test` is the very list object stored in
# `cls_to_paths`, so an in-place shuffle would silently reorder that entry, and
# an unseeded shuffle makes the displayed result change on every run.
shuffled = random.Random(0).sample(test, len(test))
labeled, unlabeled = shuffled[:num_labeled], shuffled[num_labeled:]
len(labeled), len(unlabeled), labeled[0]

(3,
 57,
 './data/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0008_796083.jpg')

In [152]:
class BaseDataset(Dataset):
    """Labeled subset of an image-folder dataset.

    Reads ``<root>/images`` through ``ImageFolder``, keeps only the classes
    listed in ``labels`` and, for every kept class, retains a randomly chosen
    ``labeled_fraction`` of its images (rounded up). The remaining images of
    each class are discarded.

    Args:
        root: Dataset directory containing an ``images`` sub-directory.
        labels: Class indices to keep. Must support ``len()`` and iteration,
            e.g. ``range(100)``.
        labeled_fraction: Fraction in ``[0, 1]`` of each class to keep.
        transform: Optional callable applied to each loaded image in
            ``__getitem__``. Defaults to ``None`` (no transform).
        seed: Optional seed for the per-class shuffle. With ``None`` (default)
            the global ``random`` module state is used.

    Raises:
        ValueError: If ``labeled_fraction`` lies outside ``[0, 1]``.
    """

    def __init__(self, root, labels, labeled_fraction, transform=None, seed=None):
        # Validate first: an out-of-range fraction would otherwise yield a
        # different number of targets than image paths.
        if not 0 <= labeled_fraction <= 1:
            raise ValueError(
                'labeled_fraction must be in [0, 1], got {!r}'.format(labeled_fraction)
            )

        self.labels = labels
        self.transform = transform

        wanted = set(labels)  # constant-time membership for any iterable
        rng = random if seed is None else random.Random(seed)

        cls_to_paths = defaultdict(list)
        imgfolder = ImageFolder(root=os.path.join(root, 'images'))
        for img, label in imgfolder.imgs:
            if label in wanted:
                cls_to_paths[label].append(img)

        self.img_paths, self.targets = [], []
        for label, paths in cls_to_paths.items():
            labeled, _unlabeled = self._split_labeled(paths, labeled_fraction, rng)
            self.img_paths += labeled
            # Use the real slice length so paths and targets stay aligned.
            self.targets += [label] * len(labeled)

    @staticmethod
    def _split_labeled(paths, labeled_fraction, rng=random):
        """Shuffle ``paths`` in place and split it into (labeled, unlabeled).

        The labeled part holds ``ceil(labeled_fraction * len(paths))`` items,
        capped at ``len(paths)``. ``rng`` is anything with a ``shuffle``
        method (the ``random`` module or a ``random.Random`` instance).
        """
        num_labeled = min(len(paths), math.ceil(labeled_fraction * len(paths)))
        rng.shuffle(paths)
        return paths[:num_labeled], paths[num_labeled:]

    def num_classes(self):
        """Return the number of distinct targets.

        Asserts that it matches ``len(self.labels)``, i.e. that every
        requested class contributed at least one sample.
        """
        n = len(np.unique(self.targets))
        assert n == len(self.labels), (
            'found {} classes but {} labels were requested'.format(n, len(self.labels))
        )
        return n

    def __len__(self):
        """Return the number of labeled samples."""
        return len(self.targets)

    def __getitem__(self, index):
        """Return ``(image, target)`` for ``index``; applies ``transform`` if set."""
        img = self.pil_loader(self.img_paths[index])
        target = self.targets[index]

        # Compare against None so a callable that happens to be falsy
        # (e.g. one defining __len__) is still applied.
        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def pil_loader(self, path):
        """Open the file at ``path`` and return it as an RGB PIL image."""
        # A bare `import PIL` does not load the `Image` submodule; import it
        # explicitly instead of relying on another package having done so.
        from PIL import Image

        with open(path, 'rb') as f:
            img = Image.open(f)
            # convert() is called while the file is still open.
            return img.convert('RGB')



In [156]:
# Keep 5% (rounded up per class) of the images of classes 0-99.
# Note: this rebinds `dataset`, which held the raw ImageFolder until now.
dataset = BaseDataset(root=root, labels=range(100), labeled_fraction=0.05)

In [157]:
# Number of labeled samples kept across the selected classes.
len(dataset)

300

In [160]:
# Tally samples per class by iterating the dataset itself, which also loads
# every image through __getitem__.
counts = defaultdict(int)
for _image, target in dataset:
    counts[target] += 1
# Evaluate `counts` as the last line of the cell to display the tallies.

In [159]:
# Distinct classes present in the labeled subset; the method also asserts that
# this equals the number of requested labels.
dataset.num_classes()

100