In [91]:
import copy
import json

def init_dataset_table_indexes():
    """Create a fresh, empty selection-state dict.

    Returns:
        ``{"selected": []}`` -- a new dict (and a new list) on every call, so
        separate selections never share state.
    """
    return {"selected": []}
        
def load_dataset_table_indexes(path):
    """Load a selection-state dict (``dt_indexes``) from a JSON file.

    Args:
        path: path of the JSON file to read.

    Returns:
        The decoded JSON object (for files written by
        ``DataSelector.save_dataset_table_indexes``, a dict with a ``"selected"`` list).
    """
    with open(path, "r") as fp:
        return json.load(fp)

In [98]:
import json
import numpy as np
import torch

class DataSelector:
    """Incrementally select samples from a dataset table, label by label.

    The selection state is a dict ``{"selected": [...]}`` whose list holds values
    of the table's ``"index"`` column. Rows already selected are never drawn again.

    Args:
        dataset_table: pandas DataFrame with at least the columns ``"index"``,
            ``"label"`` and ``"image"`` (``"image"`` holds JSON text that decodes
            to a numeric array).
        dataset_table_indexes: selection-state dict with a ``"selected"`` list.
            It is copied, so the caller's dict is never modified.
    """

    def __init__(self, dataset_table, dataset_table_indexes):
        self.dt = dataset_table
        # Work on a private copy: mutating the caller's dict/list in place made a
        # previously returned state (e.g. dt_indexes2) silently grow when it was
        # reused to build the next selector.
        self.dt_indexes = dict(dataset_table_indexes)
        self.dt_indexes["selected"] = list(dataset_table_indexes["selected"])
        self.labels = sorted(dataset_table["label"].unique())
        # Candidate pool: dataset_table with the already-selected rows removed.
        self.dropped_dt = self.__drop_selected_data(dataset_table, self.dt_indexes)

    def __drop_selected_data(self, dataset_table, dataset_table_indexes):
        """Return ``dataset_table`` without rows whose ``"index"`` value is selected.

        Rows are matched on the ``"index"`` column -- the same key that
        ``randomly_add`` records and ``out_selected_dataset`` looks up -- not on the
        DataFrame's row labels, which need not match that column.
        """
        is_selected = dataset_table["index"].isin(dataset_table_indexes["selected"])
        return dataset_table[~is_selected].reset_index(drop=True)

    def randomly_add(self, dataN, seed=None):
        """Randomly add ``dataN`` not-yet-selected samples per label to the selection.

        Args:
            dataN: number of samples drawn from each label.
            seed: ``random_state`` forwarded to ``DataFrame.sample``.

        Returns:
            A copy of the updated selection-state dict.

        Raises:
            KeyError: if a label has no unselected rows left.
            ValueError: if a label has fewer than ``dataN`` unselected rows left.
        """
        by_label = self.dropped_dt.groupby("label")
        for label in self.labels:
            sampled = by_label.get_group(label).sample(n=dataN, random_state=seed)
            # tolist() yields plain Python scalars, keeping the state JSON-friendly.
            self.dt_indexes["selected"].extend(sampled["index"].tolist())
        # Refresh the candidate pool so another call on this same selector cannot
        # draw rows that were just selected.
        self.dropped_dt = self.__drop_selected_data(self.dt, self.dt_indexes)
        return self.out_dataset_table_indexes()

    def out_dataset_table_indexes(self):
        """Return a copy of the current selection-state dict."""
        indexes = dict(self.dt_indexes)
        indexes["selected"] = list(self.dt_indexes["selected"])
        return indexes

    # Save the selection state (dt_indexes) as JSON.
    def save_dataset_table_indexes(self, savepath="./dt_indexes_v1.json"):
        """Write the selection state to ``savepath`` as indented JSON.

        Indexes are coerced with ``int()`` because ``json`` cannot serialise
        numpy integer scalars.
        """
        self.dt_indexes["selected"] = [int(index) for index in self.dt_indexes["selected"]]
        with open(savepath, "w") as f:
            json.dump(self.dt_indexes, f, indent=4)

    def out_selected_dataset(self):
        """Build ``[(image, label), ...]`` for the selected indexes, in selection order.

        ``image`` is JSON-decoded and converted to a float32 ``torch`` tensor.

        Raises:
            KeyError: if a selected index is missing from the ``"index"`` column.
        """
        # Index the table by "index" once instead of scanning every row for each
        # selected index; keep the first row per index, as the row filter did.
        lookup = self.dt.drop_duplicates(subset="index").set_index("index")
        selected_dataset = []
        for index in self.dt_indexes["selected"]:
            image = np.array(json.loads(lookup.at[index, "image"]))
            image = torch.from_numpy(image.astype(np.float32)).clone()
            selected_dataset.append((image, lookup.at[index, "label"]))
        return selected_dataset

In [49]:
# dataset_tableを読み込むためのstub
import pandas as pd
import sqlite3

def load(dbpath, tablename="dataset"):
    """Read an entire SQLite table into a pandas DataFrame.

    Args:
        dbpath: path to the SQLite database file.
        tablename: name of the table to read (default ``"dataset"``).

    Returns:
        DataFrame containing every row and column of ``tablename``.
    """
    # Identifiers cannot be bound as SQL parameters, so quote the table name
    # (doubling embedded quotes) rather than concatenating it raw into the query.
    quoted_table = '"' + tablename.replace('"', '""') + '"'
    conn = sqlite3.connect(dbpath)
    try:
        return pd.read_sql("SELECT * FROM " + quoted_table, conn)
    finally:
        # Always release the connection, even if the query fails.
        conn.close()

In [50]:
dbpath = "./assets/data_v1.db"  # SQLite file holding the "dataset" table
dt = load(dbpath, tablename="dataset")

### 1反復目のテスト

In [94]:
# Iteration 1: start from an empty selection, then draw 10 samples per label.
dt_indexes = init_dataset_table_indexes()
selector = DataSelector(dt, dt_indexes)
dt_indexes2 = selector.randomly_add(seed=1, dataN=10)
selected_dataset2 = selector.out_selected_dataset()

In [95]:
# Report the size and contents of the iteration-1 selection, one item per line.
print("インデックス数:", len(dt_indexes2["selected"]), "インデックス:", dt_indexes2["selected"], sep="\n")

インデックス数:
30
インデックス:
[2764, 4767, 3814, 3499, 2735, 3922, 2701, 1179, 932, 792, 7764, 9767, 8814, 8499, 7735, 8922, 7701, 6179, 5932, 5792, 12764, 14767, 13814, 13499, 12735, 13922, 12701, 11179, 10932, 10792]


In [97]:
savepath = "./assets/dt_indexes_v1.json"  # selection state after iteration 1
selector.save_dataset_table_indexes(savepath=savepath)

In [99]:
import torch
train = selected_dataset2
trainloader = torch.utils.data.DataLoader(
    train,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

### 2反復目のテスト

In [100]:
# Iteration 2: continue from iteration 1's selection. Pass a deep copy so the
# selector cannot mutate dt_indexes2 in place; otherwise dt_indexes2 would grow
# along with dt_indexes3 and re-running the cells above would report stale sizes.
selector = DataSelector(dt, copy.deepcopy(dt_indexes2))
dt_indexes3 = selector.randomly_add(dataN=10, seed=1)
selected_dataset3 = selector.out_selected_dataset()

In [101]:
# Report the size and contents of the iteration-2 selection, one item per line.
print("インデックス数:", len(dt_indexes3["selected"]), "インデックス:", dt_indexes3["selected"], sep="\n")

インデックス数:
60
インデックス:
[2764, 4767, 3814, 3499, 2735, 3922, 2701, 1179, 932, 792, 7764, 9767, 8814, 8499, 7735, 8922, 7701, 6179, 5932, 5792, 12764, 14767, 13814, 13499, 12735, 13922, 12701, 11179, 10932, 10792, 2800, 2988, 4134, 4288, 4260, 4186, 2120, 1743, 1383, 4063, 7800, 7988, 9134, 9288, 9260, 9186, 7120, 6743, 6383, 9063, 12800, 12988, 14134, 14288, 14260, 14186, 12120, 11743, 11383, 14063]


In [102]:
savepath = "./assets/dt_indexes_v2.json"  # selection state after iteration 2
selector.save_dataset_table_indexes(savepath=savepath)

In [103]:
import torch
# Iteration 2 must train on its own, enlarged selection (selected_dataset3);
# reusing selected_dataset2 here would silently ignore the newly added samples.
train = selected_dataset3
trainloader = torch.utils.data.DataLoader(train, batch_size=128, shuffle=True, num_workers=2)

In [106]:
# Round-trip check: reload the file saved above under a new name and confirm it
# matches the in-memory selection, instead of silently overwriting dt_indexes3.
path = "./assets/dt_indexes_v2.json"
dt_indexes3_loaded = load_dataset_table_indexes(path)
assert dt_indexes3_loaded["selected"] == list(dt_indexes3["selected"]), "saved indexes differ from the in-memory selection"
print(dt_indexes3_loaded["selected"])

[2764, 4767, 3814, 3499, 2735, 3922, 2701, 1179, 932, 792, 7764, 9767, 8814, 8499, 7735, 8922, 7701, 6179, 5932, 5792, 12764, 14767, 13814, 13499, 12735, 13922, 12701, 11179, 10932, 10792, 2800, 2988, 4134, 4288, 4260, 4186, 2120, 1743, 1383, 4063, 7800, 7988, 9134, 9288, 9260, 9186, 7120, 6743, 6383, 9063, 12800, 12988, 14134, 14288, 14260, 14186, 12120, 11743, 11383, 14063]
