### データセット読み込みとfeature_table_indexesの初期化

In [35]:
from collections import defaultdict
import json
import random
import sqlite3
from typing import List, Tuple, TypeVar, Dict

import numpy as np
import pandas as pd
import torch

# Type alias used in annotations: the pandas DataFrame class itself.
# (A TypeVar named "pandas.core.frame.DataFrame" is not a valid type variable
# for type checkers and does not describe the type.)
Dataframe = pd.DataFrame

# datasetをsqliteからDataFrame形式で読み込み
def load_dataset(dbpath="./ft.db", tablename="feature_table") -> Dataframe:
    conn=sqlite3.connect(dbpath)
    c = conn.cursor()
    dataset = pd.read_sql('SELECT * FROM ' + tablename, conn)
    return dataset

# Initialise feature_table_indexes (queries are chosen at random per label).
def init_feature_table_indexes(feature_table: Dataframe, queryN=1, seed=0) -> Dict[str, Dict]:
    """Build the initial per-label bookkeeping dict of row ids.

    For each distinct value of the ``label`` column, ``queryN`` values of the
    ``index`` column are sampled without replacement as the initial queries.

    Args:
        feature_table: Table with at least ``label`` and ``index`` columns.
        queryN: Number of queries to sample per label.
        seed: Seed for the sampler; the same seed reproduces the same queries.

    Returns:
        ``{"queries": {label: [index, ...]}, "used_queries": {label: []},
        "train_data": {label: []}}`` with labels visited in ascending order.

    Raises:
        ValueError: If a label has fewer than ``queryN`` rows.
    """
    # A private generator produces the same draws as seeding the module-level
    # RNG, but leaves the interpreter-wide random state untouched.
    rng = random.Random(seed)
    labels = sorted(feature_table["label"].unique())
    groups = feature_table.groupby("label")

    ft_indexes = {"queries": {}, "used_queries": {}, "train_data": {}}
    for label in labels:
        ft_indexes["used_queries"][label] = []
        ft_indexes["train_data"][label] = []
        candidates = groups.get_group(label)["index"].values.tolist()
        ft_indexes["queries"][label] = rng.sample(candidates, queryN)
    return ft_indexes

In [33]:
# Read the feature table shipped under ./assets and show it in the notebook.
ft = load_dataset("./assets/ft.db", tablename="feature_table")
ft

Unnamed: 0,index,feature,image,label
0,0,"[0.30711856484413147, 0.19312363862991333, 0.0...","[[[-0.3176470398902893, -0.29411762952804565, ...",1
1,1,"[0.4214461147785187, 1.198604702949524, 0.9510...","[[[-0.9764705896377563, -0.9686274528503418, -...",1
2,2,"[0.2851516008377075, 0.20933431386947632, 0.07...","[[[0.13725495338439941, 0.13725495338439941, 0...",2
3,3,"[0.6752024292945862, 0.7612708806991577, 0.712...","[[[-0.24705880880355835, -0.27843135595321655,...",2
4,4,"[0.3068203926086426, 0.6951863169670105, 0.444...","[[[0.30980396270751953, 0.30980396270751953, 0...",2
5,5,"[0.709537923336029, 0.9612928032875061, 0.8762...","[[[-0.24705880880355835, -0.3647058606147766, ...",2
6,6,"[0.5953243374824524, 0.7417647242546082, 0.796...","[[[-0.7411764860153198, -0.7647058963775635, -...",1
7,7,"[1.2410345077514648, 0.5029025673866272, 0.416...","[[[0.14509809017181396, 0.12941181659698486, 0...",1
8,8,"[0.23761172592639923, 0.045143790543079376, 0....","[[[-0.38823527097702026, -0.3647058606147766, ...",2
9,9,"[0.3618304133415222, 1.2137598991394043, 1.467...","[[[-0.7176470756530762, -0.615686297416687, -0...",0


In [36]:
# Draw 5 initial queries per label with the default seed and show the result.
ft_indexes = init_feature_table_indexes(ft, queryN=5, seed=0)
ft_indexes

{'queries': {0: [9518, 10404, 943, 6402, 12667],
  1: [11889, 9860, 7408, 11641, 8721],
  2: [14315, 5320, 12298, 3385, 6860]},
 'train_data': {0: [], 1: [], 2: []},
 'used_queries': {0: [], 1: [], 2: []}}

In [None]:
class DataSelector:
    """Select rows from a dataset table and materialise them as (tensor, label) pairs.

    Rows are identified by the value in their ``index`` column, not by the
    DataFrame's own row labels. ``dataset_table_indexes["selected"]`` holds the
    ids chosen so far; it is the caller's list and is extended in place.
    """

    def __init__(self, dataset_table: Dataframe, dataset_table_indexes: Dict[str, List[int]]):
        self.dt = dataset_table
        self.dt_indexes = dataset_table_indexes
        self.labels = sorted(dataset_table["label"].unique())
        # Copy of dataset_table with the already-selected rows removed.
        self.dropped_dt = self.__drop_selected_data(dataset_table, dataset_table_indexes)

    def __drop_selected_data(self, dataset_table: Dataframe, dataset_table_indexes: Dict[str, List[int]]) -> Dataframe:
        """Return *dataset_table* without rows whose ``index`` value is in ``selected``."""
        # Match on the ``index`` column -- the same id the lookup methods use.
        # DataFrame.drop(index=...) would match row labels instead, raising
        # KeyError (or dropping the wrong rows) when the two differ.
        keep = ~dataset_table["index"].isin(dataset_table_indexes["selected"])
        return dataset_table[keep].reset_index(drop=True)

    def randomly_add(self, dataN: int, seed=None) -> Tuple[Dict[str, List[int]], List[int]]:
        """Randomly pick ``dataN`` unselected rows per label and mark them selected.

        Args:
            dataN: Number of rows to draw for each label.
            seed: ``random_state`` passed to ``DataFrame.sample`` for every label.

        Returns:
            ``(dt_indexes, newly_selected_ids)``; ``dt_indexes["selected"]`` has
            been extended with the new ids.

        Raises:
            ValueError: If a label has fewer than ``dataN`` unselected rows.
            KeyError: If a label has no unselected rows left.
        """
        selected_indexes = []
        dt_labelby = self.dropped_dt.groupby("label")
        for label in self.labels:
            df = dt_labelby.get_group(label).sample(n=dataN, random_state=seed)
            picked = df["index"].tolist()
            selected_indexes += picked
            self.dt_indexes["selected"] += picked
        # Keep dropped_dt in sync so a later call cannot re-select these rows.
        self.dropped_dt = self.__drop_selected_data(self.dt, self.dt_indexes)
        return self.dt_indexes, selected_indexes

    def out_dataset_table_indexes(self) -> Dict[str, List[int]]:
        """Return the (shared) bookkeeping dict of selected ids."""
        return self.dt_indexes

    def _load_sample(self, index: int) -> Tuple:
        """Return ``(image_tensor, label)`` for the first row whose ``index`` equals *index*."""
        irow = self.dt[self.dt["index"] == index]
        # The image column holds JSON-encoded nested lists of numbers.
        image = np.array(json.loads(irow["image"].iloc[0]))
        image = torch.from_numpy(image.astype(np.float32)).clone()
        label = irow["label"].iloc[0]
        return image, label

    def out_selected_dataset(self) -> List[Tuple]:
        """Materialise every currently selected row as ``(image_tensor, label)``."""
        return self.out_dataset(self.dt_indexes["selected"])

    def out_dataset(self, dt_indexes: List[int]) -> List[Tuple]:
        """Materialise the rows with the given ``index`` ids as ``(image_tensor, label)``."""
        return [self._load_sample(index) for index in dt_indexes]