In [5]:
import torchvision
import torchvision.transforms as transforms

def download_cifar10(savepath):
    """Fetch the CIFAR-10 train and test splits into ``savepath``.

    Both splits share one transform pipeline: ``ToTensor`` followed by a
    per-channel ``Normalize`` with mean 0.5 and std 0.5 on each of the three
    channels.

    Args:
        savepath: Directory passed to torchvision as the dataset ``root``.

    Returns:
        A ``(train, test)`` tuple of ``torchvision.datasets.CIFAR10`` objects.
    """
    channel_mean = (0.5, 0.5, 0.5)
    channel_std = (0.5, 0.5, 0.5)
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    # The train split is requested first, then the test split.
    splits = []
    for is_train in (True, False):
        splits.append(
            torchvision.datasets.CIFAR10(
                root=savepath,
                train=is_train,
                download=True,
                transform=pipeline,
            )
        )
    train_split, test_split = splits
    return train_split, test_split

In [51]:
from collections import defaultdict

class DataPreprocessor:
    """Filter and relabel a pair of (train, test) classification datasets.

    Each dataset must be iterable and support ``len()``; every item is read as
    ``data[0]`` (the sample) and ``data[1]`` (its label). Labels are used as
    dictionary keys, so they must be hashable.

    After ``select_by_label`` or ``update_labels`` the stored datasets are
    plain Python lists of items rather than the objects originally passed in.
    """

    def __init__(self, train, test):
        """Store the two datasets that the other methods operate on."""
        self.train = train
        self.test = test

    def out_dataset(self):
        """Return the current ``(train, test)`` pair."""
        return self.train, self.test

    def show_length(self):
        """Print the number of items in the train and test datasets."""
        print("\n------------------------------------------------------------")
        print("Number of train data:{0:>9}".format(len(self.train)))
        print("Number of test data:{0:>9}".format(len(self.test)))

    def show_labels(self):
        """Print, for each dataset, the item count per label and the total."""
        texts = ["Number of train data", "Number of test data"]
        for text, dataset in zip(texts, (self.train, self.test)):
            label_count = defaultdict(int)
            for data in dataset:
                label_count[data[1]] += 1
            print("\n{0}  ----------------------------------------------".format(text))
            # Named `total` so the builtin `sum` is not shadowed.
            total = 0
            for label, count in sorted(label_count.items()):
                print("label:  {0}    count:  {1}".format(label, count))
                total += count
            print("total:  {0}".format(total))

    def select_by_label(self, labels=()):
        """Keep only the items whose label is contained in ``labels``.

        Args:
            labels: Iterable of labels to keep. The default is empty, which
                selects nothing (both datasets become empty lists).
        """
        # Materialise once: a one-shot iterable (e.g. a generator) would
        # otherwise be exhausted by the first membership test.
        wanted = tuple(labels)
        self.train = [data for data in self.train if data[1] in wanted]
        self.test = [data for data in self.test if data[1] in wanted]

    def update_labels(self):
        """Renumber the labels as consecutive integers starting from 0.

        Each dataset is replaced by a list of ``(sample, new_label)`` tuples,
        sorted by original label. One mapping is shared by both datasets:
        labels are numbered in ascending order as first seen in the train
        dataset, and labels that occur only in the test dataset continue the
        numbering, so no two old labels ever share a new one. The mapping is
        printed at the end.
        """
        label_mapping = {}
        updated = []
        for dataset in (self.train, self.test):
            relabelled = []
            for data in sorted(dataset, key=lambda item: item[1]):
                old_label = data[1]
                if old_label not in label_mapping:
                    # The next free number is the count of labels mapped so
                    # far, across both datasets.
                    label_mapping[old_label] = len(label_mapping)
                relabelled.append((data[0], label_mapping[old_label]))
            updated.append(relabelled)
        self.train, self.test = updated
        print("\nChanged the label.  ----------------------------------------------")
        for old, new in label_mapping.items():
            print("label:  {0} -> {1}".format(old, new))

"""
    # dataN(1クラスのデータ数)分のデータをランダムに選択
    def select_by_dataN(self, dataN, train=True):
        selected = []
        label_count = defaultdict(int)
        dataset = self.train if train else self.test
        for data in dataset:
            if label_count[data[1]] < dataN:
                selected.append(data)
                label_count[data[1]] += 1
        if train : self.train = selected
        if not train: self.test = selected
"""

In [48]:
# Dataset root handed to torchvision; a relative path, so it is resolved
# against the notebook's current working directory.
savepath = "../../prototype/data"
# `train` and `test` are notebook-level names consumed by the preprocessing
# cell that follows.
train, test = download_cifar10(savepath)

Files already downloaded and verified
Files already downloaded and verified


In [49]:
# NOTE: this cell rebinds `train`/`test` at the end, so re-running it without
# first re-running the download cell feeds the already filtered and relabelled
# lists back into DataPreprocessor.
data_preprocessor = DataPreprocessor(train, test)
# Keep only the items whose label is 1, 2 or 8.
data_preprocessor.select_by_label(labels=[1,2,8])
# Disabled: `select_by_dataN` currently exists only inside a string literal in
# the class cell, so re-enabling these calls as-is would raise AttributeError.
# NOTE(review): the saved output reports 50 train / 30 test items per label,
# which matches these dataN values; it appears to come from a run in which
# the calls were active. Confirm and re-run to refresh the output.
#data_preprocessor.select_by_dataN(dataN=50, train=True)
#data_preprocessor.select_by_dataN(dataN=30, train=False)
# Renumber the remaining labels from 0 and print the old -> new mapping.
data_preprocessor.update_labels()
# Print the per-label item counts for both datasets.
data_preprocessor.show_labels()
# Rebind the notebook-level names to the processed lists.
train, test = data_preprocessor.out_dataset()


Changed the label.  ----------------------------------------------
label:  1 -> 0
label:  2 -> 1
label:  8 -> 2

Number of train data  ----------------------------------------------
label:  0    count:  50
label:  1    count:  50
label:  2    count:  50
total:  150

Number of test data  ----------------------------------------------
label:  0    count:  30
label:  1    count:  30
label:  2    count:  30
total:  90
