In [46]:
import torchvision
import torchvision.transforms as transforms
from collections import defaultdict
import pandas as pd
import json
import sqlite3

def download_cifar10(savepath):
    """Fetch the CIFAR-10 train and test splits, downloading them if needed.

    Args:
        savepath: Directory used as the dataset root (passed as ``root``).

    Returns:
        A ``(train, test)`` tuple of ``torchvision.datasets.CIFAR10`` objects.
        Both share one transform: ``ToTensor`` followed by a per-channel
        ``Normalize`` with mean 0.5 and std 0.5.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    transform = transforms.Compose([transforms.ToTensor(), normalize])
    # Build the training split first, then the test split, with identical settings.
    train, test = [
        torchvision.datasets.CIFAR10(
            root=savepath, train=is_train, download=True, transform=transform
        )
        for is_train in (True, False)
    ]
    return train, test

def make_data_table(dataset):
    """Convert an iterable of ``(image, label)`` pairs into a DataFrame.

    Args:
        dataset: Iterable yielding ``(image, label)``. ``image`` must support
            ``.cpu().numpy().tolist()`` and ``label`` must be convertible
            with ``int()``.

    Returns:
        A ``pandas.DataFrame`` with two columns: ``image`` (the nested pixel
        list serialized as a JSON string) and ``label`` (a Python ``int``).
        The two columns are present even when ``dataset`` is empty, so the
        table schema does not depend on the input size.
    """
    images = []
    labels = []
    for image, label in dataset:
        # JSON text is used so the nested list can be stored in a single column.
        images.append(json.dumps(image.cpu().numpy().tolist()))
        labels.append(int(label))
    # Explicit keys guarantee the columns exist even when nothing was appended.
    return pd.DataFrame({"image": images, "label": labels})

# Save a DataFrame-format data_table to SQLite.
def save_data_table(data_table, savepath, tablename="data_table", index=True):
    """Write ``data_table`` into a SQLite database file.

    Args:
        data_table: DataFrame (or any object exposing a compatible
            ``to_sql``) to persist.
        savepath: Path of the SQLite database file to open.
        tablename: Name of the destination table. An existing table with the
            same name is replaced.
        index: Forwarded to ``to_sql``. When true (the default, matching the
            previous behavior) the DataFrame index is stored as a column;
            pass ``False`` to store only the data columns.

    Raises:
        Whatever ``sqlite3.connect`` or ``to_sql`` raises; the connection is
        closed before the exception propagates.
    """
    conn = sqlite3.connect(savepath)
    try:
        data_table.to_sql(tablename, conn, if_exists='replace', index=index)
    finally:
        # Close even when the write fails so the database handle is not leaked.
        conn.close()
    
# Load data_table back as a DataFrame.
def load_data_table(dbpath, tablename="data_table"):
    """Read every row of a SQLite table into a DataFrame.

    Args:
        dbpath: Path of the SQLite database file to open.
        tablename: Name of the table to read. It is concatenated into the
            SQL text unescaped, so it must be a trusted identifier.

    Returns:
        A ``pandas.DataFrame`` holding the result of
        ``SELECT * FROM <tablename>``.

    Raises:
        Whatever ``sqlite3.connect`` or ``pandas.read_sql`` raises; the
        connection is closed before the exception propagates.
    """
    conn = sqlite3.connect(dbpath)
    try:
        return pd.read_sql('SELECT * FROM ' + tablename, conn)
    finally:
        # The original never closed the connection; release it on every path.
        conn.close()

In [6]:
from collections import defaultdict

class DataPreprocessor:
    """Filter and relabel a pair of labelled datasets (train and test).

    Each dataset is an iterable of samples in which ``sample[0]`` is the
    payload and ``sample[1]`` is the label. ``show_length`` additionally
    requires ``len()`` support.
    """

    def __init__(self, train, test):
        """Store the two datasets that the other methods operate on."""
        self.train = train
        self.test = test

    def out_dataset(self):
        """Return the current ``(train, test)`` pair."""
        return self.train, self.test

    def show_length(self):
        """Print the number of samples in each dataset."""
        print("\n------------------------------------------------------------")
        print("Number of train data:{0:>9}".format(len(self.train)))
        print("Number of test data:{0:>9}".format(len(self.test)))

    def show_labels(self):
        """Print, for each dataset, the sample count per label and the total."""
        texts = ["Number of train data", "Number of test data"]
        for text, dataset in zip(texts, [self.train, self.test]):
            label_count = defaultdict(int)
            for data in dataset:
                label_count[data[1]] += 1
            print("\n{0}  ----------------------------------------------".format(text))
            # `total` rather than `sum`, so the builtin is not shadowed.
            total = 0
            for label, count in sorted(label_count.items()):
                print("label:  {0}    count:  {1}".format(label, count))
                total += count
            print("total:  {0}".format(total))

    # Keep only the samples whose label is contained in `labels`.
    def select_by_label(self, labels=()):
        """Replace both datasets with lists of the samples labelled in ``labels``.

        Args:
            labels: Container of labels to keep. The default is empty, which
                keeps nothing.
        """
        self.train = [data for data in self.train if data[1] in labels]
        self.test = [data for data in self.test if data[1] in labels]

    # Renumber the ground-truth labels to consecutive integers starting at 0.
    def update_labels(self):
        """Remap labels to ``0..k-1`` using one mapping shared by both datasets.

        Both datasets are replaced by lists of ``(payload, new_label)``
        tuples ordered by their original label. New labels are handed out in
        ascending order of the original labels in ``train``; a label that
        first appears in ``test`` receives the next unused integer, so it can
        never collide with a label that was already assigned. The mapping is
        printed at the end.
        """
        updated = [[], []]
        label_mapping = {}
        for i, dataset in enumerate([self.train, self.test]):
            for data in sorted(dataset, key=lambda x: x[1]):
                old_label = data[1]
                if old_label not in label_mapping:
                    # Continue numbering from the mapping size instead of
                    # restarting at 0 for the second dataset.
                    label_mapping[old_label] = len(label_mapping)
                updated[i].append((data[0], label_mapping[old_label]))
        self.train, self.test = updated
        print("\nChanged the label.  ----------------------------------------------")
        for old, new in label_mapping.items():
            print("label:  {0} -> {1}".format(old, new))

In [42]:
# Directory used as the CIFAR-10 dataset root; both splits are downloaded
# there if needed and loaded from it.
savepath = "../../../prototype/conventional/data"
train, test = download_cifar10(savepath)

Files already downloaded and verified
Files already downloaded and verified


In [43]:
# Restrict both splits to the samples labelled 1, 2 or 8, renumber those
# labels from 0, print the per-label counts, then rebind `train`/`test` to
# the filtered, relabelled lists.
preprocessor = DataPreprocessor(train, test)
preprocessor.select_by_label(labels=[1,2,8])
preprocessor.update_labels()
preprocessor.show_labels()
train, test = preprocessor.out_dataset()


Changed the label.  ----------------------------------------------
label:  1 -> 0
label:  2 -> 1
label:  8 -> 2

Number of train data  ----------------------------------------------
label:  0    count:  5000
label:  1    count:  5000
label:  2    count:  5000
total:  15000

Number of test data  ----------------------------------------------
label:  0    count:  1000
label:  1    count:  1000
label:  2    count:  1000
total:  3000


In [47]:
# Serialize the training split into a DataFrame and persist it to SQLite.
# Only `train` is written here; `test` is not saved by this cell.
# NOTE(review): `savepath` is rebound from the dataset directory to the
# database file path, so re-running the download cell afterwards requires
# re-running its own assignment first.
data_table = make_data_table(train)
savepath = "./assets/data_v1.db"
save_data_table(data_table, savepath)

In [48]:
# Read the table back from the same database file to check the round trip.
# The `index` column in the result is presumably the DataFrame index that
# was stored alongside the data when the table was written.
dbpath = "./assets/data_v1.db"
data_table2 = load_data_table(dbpath)
data_table2

Unnamed: 0,index,image,label
0,0,"[[[0.3333333730697632, 0.3176470994949341, 0.3...",0
1,1,"[[[0.24705886840820312, 0.17647063732147217, 0...",0
2,2,"[[[-0.6078431606292725, -0.6000000238418579, -...",0
3,3,"[[[0.09019613265991211, 0.12941181659698486, 0...",0
4,4,"[[[-0.5764706134796143, -0.5372549295425415, -...",0
5,5,"[[[-0.26274508237838745, -0.43529409170150757,...",0
6,6,"[[[0.5764706134796143, 0.49803924560546875, 0....",0
7,7,"[[[-0.10588234663009644, -0.08235293626785278,...",0
8,8,"[[[-0.4901960492134094, -0.4745097756385803, -...",0
9,9,"[[[0.16078436374664307, 0.035294175148010254, ...",0
