### データセット読み込みとfeature_table_indexesの初期化

In [1]:
from collections import defaultdict
import pandas as pd
import random
import sqlite3
from typing import List, Tuple, TypeVar, Dict

Dataframe = TypeVar("pandas.core.frame.DataFrame")

# datasetをsqliteからDataFrame形式で読み込み
def load_dataset(dbpath="./ft.db", tablename="feature_table") -> pd.DataFrame:
    """Read every row of one SQLite table into a pandas DataFrame.

    Args:
        dbpath: Path of the SQLite database file to open.
        tablename: Name of the table to read. It is concatenated into the
            SQL text as-is (identifiers cannot be bound as parameters), so
            it must come from a trusted source.

    Returns:
        The result of ``SELECT * FROM <tablename>`` as a DataFrame.
    """
    conn = sqlite3.connect(dbpath)
    try:
        return pd.read_sql('SELECT * FROM ' + tablename, conn)
    finally:
        # Always release the database handle, even when the query fails.
        conn.close()

# feature_table_indexesの初期化 (queryはランダムに選択)
def init_feature_table_indexes(feature_table: pd.DataFrame, queryN=1, seed=0) -> Dict[str, Dict]:
    """Build the initial per-label index bookkeeping for a feature table.

    For every distinct value of the ``"label"`` column, ``queryN`` values of
    the ``"index"`` column are sampled at random (without replacement) to
    serve as queries; the "used" and "selected" lists start out empty.

    Args:
        feature_table: Table that has at least ``"label"`` and ``"index"``
            columns.
        queryN: Number of query indexes to sample for each label.
        seed: Seed of the random generator used for sampling.

    Returns:
        A dict with the keys ``"queries"``, ``"used_queries"`` and
        ``"selected_data"``; each maps a label to a list of index values.

    Raises:
        ValueError: If ``queryN`` is larger than the number of rows of some
            label (raised by the underlying ``sample`` call).
    """
    # A private generator seeded with ``seed`` draws exactly the same samples
    # as ``random.seed(seed)`` followed by ``random.sample`` would, but it
    # leaves the module-level RNG state of the process untouched.
    rng = random.Random(seed)

    labels = sorted(feature_table["label"].unique())
    rows_by_label = feature_table.groupby("label")
    ft_indexes = {"queries": {}, "used_queries": {}, "selected_data": {}}

    for label in labels:
        candidates = rows_by_label.get_group(label)["index"].values.tolist()
        ft_indexes["queries"][label] = rng.sample(candidates, queryN)
        ft_indexes["used_queries"][label] = []
        ft_indexes["selected_data"][label] = []

    return ft_indexes

### DataSelector

In [118]:
from collections import defaultdict
import copy
import json
import numpy as np
import pandas as pd
import torch
from typing import List, Tuple, TypeVar, Dict

Dataframe = TypeVar("pandas.core.frame.DataFrame")
Tensor = TypeVar("torch.Tensor")
NpInt64 = TypeVar("numpy.int64")

class DataSelector:
    """Select training rows per label while tracking what was already used.

    Rows are identified by the value of the table's ``"index"`` column (not
    by the DataFrame's own row labels). ``self.dt`` is the pool of rows that
    have not been selected yet; ``self.default_dt`` is the untouched table.
    """

    def __init__(self, dataset_table: pd.DataFrame, dataset_table_indexes: Dict[str, Dict[int, List]]):
        """Create a selector.

        Args:
            dataset_table: Table with ``"index"``, ``"label"`` and
                ``"image"`` columns.
            dataset_table_indexes: Bookkeeping dict; its ``"selected_data"``
                entry maps each label to the index values already selected.
                It is deep-copied, so the caller's object is never mutated.
        """
        self.default_dt = dataset_table
        self.dt_indexes = copy.deepcopy(dataset_table_indexes)
        self.labels = sorted(dataset_table["label"].unique())
        # dataset_table with the already-selected rows removed
        self.dt = self.__init_dt(dataset_table, dataset_table_indexes)

    def __init_dt(self, dataset_table: pd.DataFrame, dataset_table_indexes: Dict[str, Dict[int, List]]) -> pd.DataFrame:
        """Return ``dataset_table`` without the rows listed as selected."""
        drop_indexes = []
        for indexes in dataset_table_indexes["selected_data"].values():
            drop_indexes += indexes
        # Match on the "index" column: that is where the stored identifiers
        # come from. Dropping by row label only works when the row labels
        # happen to coincide with that column.
        return dataset_table[~dataset_table["index"].isin(drop_indexes)]

    def __drop_data(self, indexes: List):
        """Remove rows whose ``"index"`` value is in ``indexes`` from the pool."""
        self.dt = self.dt[~self.dt["index"].isin(indexes)]

    def __convert_to_torch(self, json_image) -> torch.Tensor:
        """Decode a JSON-encoded (nested list) image into a float32 tensor."""
        image = json.loads(json_image)
        image = np.array(image)
        image = torch.from_numpy(image.astype(np.float32)).clone()
        return image

    def get_dt_indexes(self) -> Dict[str, Dict[int, List]]:
        """Return the selector's bookkeeping dict (the live object, not a copy)."""
        return self.dt_indexes

    def get_dataset(self, indexes_labelby: Dict[int, List]) -> List[Tuple[torch.Tensor, np.int64]]:
        """Materialise ``(image_tensor, label)`` pairs for the given indexes.

        Args:
            indexes_labelby: Maps every label of the table to the ``"index"``
                values to fetch for that label.

        Returns:
            One ``(tensor, label)`` tuple per matching row, grouped by label
            in ascending label order. Rows are looked up in the full table,
            so already-selected rows can be fetched as well.

        Raises:
            KeyError: If ``indexes_labelby`` lacks an entry for some label.
        """
        dataset = []
        dt_labelby = self.default_dt.groupby("label")
        for label in self.labels:
            group = dt_labelby.get_group(label)
            rows = group[group["index"].isin(indexes_labelby[label])]
            for json_image, row_label in zip(rows["image"].values, rows["label"].values):
                dataset.append((self.__convert_to_torch(json_image), row_label))
        return dataset

    def randomly_select_indexes(self, dataN: int, seed=0) -> Dict[int, List]:
        """Draw ``dataN`` not-yet-selected rows per label and mark them selected.

        The drawn index values are appended to the ``"selected_data"``
        bookkeeping and removed from the pool, so later calls cannot return
        them again.

        Args:
            dataN: Number of rows to draw for each label.
            seed: ``random_state`` passed to ``DataFrame.sample``.

        Returns:
            Maps each label to the list of ``"index"`` values drawn for it.
        """
        indexes_labelby = {}
        # Grouping is done once, before any removal; each label's group is
        # unaffected by the rows removed for the other labels.
        dt_labelby = self.dt.groupby("label")
        for label in self.labels:
            sampled = dt_labelby.get_group(label).sample(n=dataN, random_state=seed)
            selected_indexes = list(sampled["index"].values)
            indexes_labelby[label] = selected_indexes
            self.dt_indexes["selected_data"][label] += selected_indexes
            self.__drop_data(selected_indexes)
        return indexes_labelby

In [60]:
# Load the feature table from the SQLite file and report its row count.
ft = load_dataset(dbpath="./assets/ft.db")
print("データ数:  {0}".format(len(ft)))
# Last expression of the cell: displays the first five rows.
ft[:5]

データ数:  15000


Unnamed: 0,index,feature,image,label
0,0,"[0.30711856484413147, 0.19312363862991333, 0.0...","[[[-0.3176470398902893, -0.29411762952804565, ...",1
1,1,"[0.4214461147785187, 1.198604702949524, 0.9510...","[[[-0.9764705896377563, -0.9686274528503418, -...",1
2,2,"[0.2851516008377075, 0.20933431386947632, 0.07...","[[[0.13725495338439941, 0.13725495338439941, 0...",2
3,3,"[0.6752024292945862, 0.7612708806991577, 0.712...","[[[-0.24705880880355835, -0.27843135595321655,...",2
4,4,"[0.3068203926086426, 0.6951863169670105, 0.444...","[[[0.30980396270751953, 0.30980396270751953, 0...",2


In [125]:
# Initial bookkeeping: 5 randomly sampled query indexes per label (seed 2),
# with empty "used_queries" and "selected_data" lists.
ft_indexes1 = init_feature_table_indexes(feature_table=ft, queryN=5, seed=2)
# Last expression of the cell: displays the resulting dict.
ft_indexes1

{'queries': {0: [1347, 2238, 2070, 8954, 4165],
  1: [7522, 6188, 14867, 5218, 14887],
  2: [916, 14267, 3882, 10627, 9697]},
 'selected_data': {0: [], 1: [], 2: []},
 'used_queries': {0: [], 1: [], 2: []}}

#### 1反復目

In [126]:
# Iteration 1: draw 200 rows per label twice from the same selector.
# Both draws use the same seed; the selector removes drawn rows from its pool
# after each draw, so the second call samples from the remaining rows.
selector = DataSelector(ft, ft_indexes1)
indexes_labelby1 = selector.randomly_select_indexes(dataN=200, seed=2)
indexes_labelby2 = selector.randomly_select_indexes(dataN=200, seed=2)
# Bookkeeping after both draws; used to build the selector of iteration 2.
ft_indexes2 = selector.get_dt_indexes()

In [131]:
# Sanity check: number of distinct selected indexes per label.
# range(3) hard-codes the labels 0, 1 and 2.
for i in range(3):
    print("ラベル{0}  の重複なしデータ数:  {1}".format(i, len(set(ft_indexes2["selected_data"][i]))))

ラベル0  の重複なしデータ数:  400
ラベル1  の重複なしデータ数:  400
ラベル2  の重複なしデータ数:  400


#### 2反復目

In [128]:
# Iteration 2: a new selector seeded with the bookkeeping of iteration 1, so
# the rows selected there are excluded from its pool before drawing again.
selector = DataSelector(ft, ft_indexes2)
indexes_labelby3 = selector.randomly_select_indexes(dataN=200, seed=0)
indexes_labelby4 = selector.randomly_select_indexes(dataN=200, seed=0)
# Bookkeeping accumulated over both iterations.
ft_indexes3 = selector.get_dt_indexes()

In [132]:
# Sanity check: number of distinct selected indexes per label after iteration 2.
# range(3) hard-codes the labels 0, 1 and 2.
for i in range(3):
    print("ラベル{0}  の重複なしデータ数:  {1}".format(i, len(set(ft_indexes3["selected_data"][i]))))

ラベル0  の重複なしデータ数:  800
ラベル1  の重複なしデータ数:  800
ラベル2  の重複なしデータ数:  800


In [133]:
# Materialise (image tensor, label) pairs for every index selected so far,
# then print the number of pairs and the first pair.
dataset = selector.get_dataset(ft_indexes3["selected_data"])
print("データ数:  {0}".format(len(dataset)))
print(dataset[0])

データ数:  2400
(tensor([[[-0.6471, -0.6000, -0.6078,  ..., -0.6392, -0.6000, -0.7412],
         [-0.7020, -0.6078, -0.5137,  ..., -0.7255, -0.6235, -0.6627],
         [-0.6863, -0.5451, -0.5608,  ..., -0.7098, -0.5216, -0.5216],
         ...,
         [ 0.5373,  0.4980,  0.4510,  ...,  0.8431,  0.9373,  0.7569],
         [ 0.6157,  0.6784,  0.6863,  ...,  0.8745,  0.6863,  0.4275],
         [ 0.4431,  0.4510,  0.5216,  ...,  0.6235,  0.4039,  0.3647]],

        [[-0.6314, -0.5529, -0.6078,  ..., -0.6235, -0.6157, -0.6627],
         [-0.6314, -0.5373, -0.5294,  ..., -0.7020, -0.6471, -0.6549],
         [-0.6235, -0.4745, -0.5686,  ..., -0.6863, -0.5373, -0.5373],
         ...,
         [-0.8667, -0.8118, -0.7804,  ...,  0.3804,  0.2314, -0.5451],
         [-0.9529, -0.9216, -0.7569,  ...,  0.8431,  0.1373, -0.6549],
         [-0.7255, -0.7882, -0.5451,  ...,  0.3961, -0.0588, -0.1529]],

        [[-0.6627, -0.6314, -0.7020,  ..., -0.6000, -0.6471, -0.6706],
         [-0.7255, -0.6235, -0.5