In [1]:
import random
from collections import defaultdict

import pandas as pd
import torch
import torchvision
import torchvision.transforms as transforms

In [2]:
# Download the CIFAR-10 dataset into download_path and return both splits.
def load_cifar10(download_path):
    """Build the CIFAR-10 train and test datasets rooted at ``download_path``.

    Both datasets are created with ``download=True`` and share one transform:
    ``ToTensor`` followed by ``Normalize`` with mean and std of 0.5 on each of
    the three channels.

    Args:
        download_path: Directory passed to torchvision as the dataset root.

    Returns:
        A ``(train, test)`` tuple of ``torchvision.datasets.CIFAR10`` objects.
    """
    channel_mean = (0.5, 0.5, 0.5)
    channel_std = (0.5, 0.5, 0.5)
    steps = [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]
    transform = transforms.Compose(steps)

    # Train split first, then test split, sharing the same transform.
    splits = tuple(
        torchvision.datasets.CIFAR10(
            root=download_path, train=is_train, download=True, transform=transform
        )
        for is_train in (True, False)
    )
    return splits

class DataSelector:
    """Hold a train/test split in memory and filter, subsample and relabel it.

    Samples are read as indexable pairs: index 0 is the payload and index 1 is
    the label. The filtering methods replace ``self.train`` / ``self.test``
    with new lists instead of editing the existing ones in place.
    """

    def __init__(self, train, test):
        """Materialize both splits as lists.

        Args:
            train: Iterable of training samples.
            test: Iterable of test samples.
        """
        self.train = list(train)
        self.test = list(test)

    def get_dataset(self):
        """Return the current ``(train, test)`` tuple."""
        return (self.train, self.test)

    def print_len(self):
        """Print the number of samples in each split."""
        print("\n------------------------------------------------------")
        print("訓練データ数:{0:>9}".format(len(self.train)))
        print("テストデータ数:{0:>9}".format(len(self.test)))

    def print_len_by_label(self):
        """Print per-label sample counts and their total for each split."""
        texts = ["訓練データ", "テストデータ"]
        for text, split in zip(texts, [self.train, self.test]):
            counts = defaultdict(int)
            for data in split:
                counts[data[1]] += 1
            print("\n{0}数  ----------------------------------------------".format(text))
            total = 0
            for key, value in sorted(counts.items()):
                print("ラベル:  {0}    データ数:  {1}".format(key, value))
                total += value
            print("合計データ数:  {0}".format(total))

    def randomly_select_data_by_label(self, data_num, train=True, shuffle=False, seed=None):
        """Keep at most ``data_num`` samples per label in one split.

        Despite the name, the default keeps the *first* ``data_num`` samples of
        every label in their current order. Pass ``shuffle=True`` for a random
        selection.

        Args:
            data_num: Maximum number of samples kept for each label.
            train: Truthy to subsample the training split, falsy for the test
                split.
            shuffle: When True, the candidates are shuffled before the first
                ``data_num`` per label are taken, so the kept samples (and
                their order) are random.
            seed: Seed for the private random generator used when ``shuffle``
                is True; the same seed gives the same selection.
        """
        candidates = list(self.train if train else self.test)
        if shuffle:
            # A private generator keeps the global ``random`` state untouched.
            random.Random(seed).shuffle(candidates)

        selected = []
        taken = defaultdict(int)
        for data in candidates:
            if taken[data[1]] < data_num:
                selected.append(data)
                taken[data[1]] += 1

        if train:
            self.train = selected
        else:
            self.test = selected

    def __select_data_by_label(self, label):
        """Return ``[train_matches, test_matches]`` for a single label."""
        selected = [[], []]
        for i, split in enumerate((self.train, self.test)):
            for data in split:
                if data[1] == label:
                    selected[i].append(data)
        return selected

    def select_data_by_labels(self, labels):
        """Restrict both splits to samples whose label is in ``labels``.

        The result is grouped by label in the order the labels are given. A
        label that is listed more than once is only applied once, so samples
        are never duplicated.

        Args:
            labels: Iterable of hashable labels to keep.
        """
        selected = [[], []]
        # dict.fromkeys removes repeated labels while preserving their order.
        for label in dict.fromkeys(labels):
            result = self.__select_data_by_label(label)
            selected[0] += result[0]
            selected[1] += result[1]
        self.train, self.test = selected

    def update_labels(self):
        """Renumber the labels to 0, 1, 2, ... and print the mapping.

        Each split is sorted by its original label (stable sort), and new ids
        are handed out in the order labels are first met: training labels in
        ascending order, then any label that only occurs in the test split.
        The mapping is shared by both splits, and the next id is always the
        current mapping size, so a test-only label can never reuse an id that
        was already given to a training label.
        """
        updated = [[], []]
        label_mapping = {}
        for i, split in enumerate([self.train, self.test]):
            for data in sorted(split, key=lambda x: x[1]):
                # First sighting of a label assigns the next unused id.
                new_label = label_mapping.setdefault(data[1], len(label_mapping))
                updated[i].append((data[0], new_label))
        self.train, self.test = updated
        print("\nラベルを変更しました.  ----------------------------------------------")
        for key, value in label_mapping.items():
            print("ラベル  {0} -> {1}".format(key, value))

In [3]:
# Build a small three-class subset of CIFAR-10 for the cells below.
download_path = "./data/"
train, test = load_cifar10(download_path)
data_selector = DataSelector(train, test)
# Keep only the samples whose original label is 1, 2 or 8.
data_selector.select_data_by_labels([1, 2, 8])
# Cap every label at 300 training samples and 30 test samples.
data_selector.randomly_select_data_by_label(300, train=True)
data_selector.randomly_select_data_by_label(30, train=False)
# Renumber the remaining labels from 0 and print the old -> new mapping.
data_selector.update_labels()
data_selector.print_len_by_label()
# From here on `train` and `test` are the reduced, relabelled lists.
train, test = data_selector.get_dataset()

Files already downloaded and verified
Files already downloaded and verified

ラベルを変更しました.  ----------------------------------------------
ラベル  1 -> 0
ラベル  2 -> 1
ラベル  8 -> 2

訓練データ数  ----------------------------------------------
ラベル:  0    データ数:  300
ラベル:  1    データ数:  300
ラベル:  2    データ数:  300
合計データ数:  900

テストデータ数  ----------------------------------------------
ラベル:  0    データ数:  30
ラベル:  1    データ数:  30
ラベル:  2    データ数:  30
合計データ数:  90


### ------------------- フィーチャ抽出とsqliteへの保存 --------------------

In [4]:
import torch.nn as nn
import torch.nn.functional as F

class PreLeNet(nn.Module):
    """Small two-convolution classifier that also returns a pooled feature.

    The linear layer takes 8*8*32 inputs, which is what two 2x2 max-pools
    leave of a 32x32 image after the 32-channel convolution.
    """

    def __init__(self, out=3):
        """Create the layers.

        Args:
            out: Number of output units of the final linear layer.
        """
        super().__init__()
        # 3x3 kernels with stride 1 and padding 1 keep the spatial size.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        # 8x8 average pooling over the 8x8 map leaves one value per channel.
        self.gap = nn.AvgPool2d(kernel_size=8)
        self.fc1 = nn.Linear(8 * 8 * 32, out)

    def forward(self, x):
        """Run the network.

        Returns:
            A ``(scores, feature)`` tuple: ``scores`` is the ReLU of the linear
            layer output, ``feature`` is the average-pooled activation reshaped
            to 32 columns.
        """
        # conv -> ReLU -> 2x2 max-pool, twice.
        stage1 = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        stage2 = F.max_pool2d(F.relu(self.conv2(stage1)), 2, 2)

        # Feature branch: average-pool, then reshape to 32 columns.
        feature = self.gap(stage2).view(-1, 32)

        # Classifier branch: flatten the map and apply the linear layer + ReLU.
        flat = stage2.view(-1, 8 * 8 * 32)
        scores = F.relu(self.fc1(flat))
        return scores, feature

#### ------------------------- 通常保存------------------------------

In [29]:
import json
import sqlite3

In [44]:
# Run the trained network over the training subset and store label, image and
# pooled feature of every sample as one row of featuretable in features.db.
batch_size = 128
dataloader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=False, num_workers=2)
model_path = "./v3.pth"
net = PreLeNet(3)
# The network is never moved off the CPU, so load the checkpoint onto the CPU too.
net.load_state_dict(torch.load(model_path, map_location="cpu"))
net.eval()
dbname = "features.db"
sql = "insert into featuretable (id, label, image, feature) values (?, ?, ?, ?)"

conn = sqlite3.connect(dbname)
try:
    c = conn.cursor()
    # Recreate the table so that re-running the cell does not fail on an existing one.
    c.execute("drop table if exists featuretable")
    c.execute("create table featuretable (id integer PRIMARY KEY, label integer, image text, feature text)")

    row_id = 1  # running primary key, starting at 1
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs, features = net(inputs)
            rows = []
            for label, image, feature in zip(labels, inputs, features):
                # Tensors are stored as JSON text of their nested-list form.
                rows.append((
                    row_id,
                    int(label),
                    json.dumps(image.cpu().numpy().tolist()),
                    json.dumps(feature.cpu().numpy().tolist()),
                ))
                row_id += 1
            c.executemany(sql, rows)

    conn.commit()
finally:
    # Close the connection even when inference or an insert fails.
    conn.close()

In [53]:
# Read the first row of featuretable back and decode its JSON columns.
dbname = "features.db"
conn = sqlite3.connect(dbname)
try:
    c = conn.cursor()
    # Name the columns so the unpacking below does not depend on table layout.
    sql = "select id, label, image, feature from featuretable"
    c.execute(sql)
    data = c.fetchone()
finally:
    # The row is already fetched; release the database file.
    conn.close()

row_id = data[0]  # not named `id`, which would shadow the builtin
label = data[1]
image = json.loads(data[2])
feature = json.loads(data[3])

#### ------------------------- DataFrame保存 ------------------------------

In [84]:
# Same extraction as the plain sqlite version, but the rows are collected in a
# DataFrame and written to features_dataframe.db with DataFrame.to_sql.
batch_size = 128
dataloader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=False, num_workers=2)
model_path = "./v3.pth"
net = PreLeNet(3)
# The network is never moved off the CPU, so load the checkpoint onto the CPU too.
net.load_state_dict(torch.load(model_path, map_location="cpu"))
net.eval()
dbname = "features_dataframe.db"

d = defaultdict(list)
# Inference only: no autograd graph is needed.
with torch.no_grad():
    for inputs, labels in dataloader:
        outputs, features = net(inputs)
        for label, image, feature in zip(labels, inputs, features):
            d["label"].append(int(label))
            # Tensors are stored as JSON text of their nested-list form.
            d["image"].append(json.dumps(image.cpu().numpy().tolist()))
            d["feature"].append(json.dumps(feature.cpu().numpy().tolist()))

# A dict of equal-length lists converts directly; columns keep insertion order.
df = pd.DataFrame(d)

conn = sqlite3.connect(dbname)
try:
    df.to_sql('feature_table', conn, if_exists='replace')
finally:
    # Close the connection even when the write fails.
    conn.close()

In [89]:
# Load the whole feature_table back into a DataFrame.
dbname = "features_dataframe.db"
conn = sqlite3.connect(dbname)
try:
    df = pd.read_sql('SELECT * FROM feature_table', conn)
finally:
    # Close the connection even when the query fails.
    conn.close()

### ----------------- sqlite 保存形式の確認 ----------------------

In [74]:
import sqlite3
import numpy as np
import io
import json

In [80]:
# Check that a NumPy array can be stored in sqlite as JSON text.
dbname = "test.db"
conn = sqlite3.connect(dbname)
try:
    c = conn.cursor()
    # Recreate the table so that re-running the cell does not fail on an existing one.
    c.execute("drop table if exists test")
    c.execute("create table test (id integer, array text)")

    sql = 'insert into test (id, array) values (?,?)'
    x = np.arange(12).reshape(2,6)
    # The array is stored as the JSON form of its nested list.
    data = (1, json.dumps(x.tolist()))
    c.execute(sql, data)

    conn.commit()
finally:
    # Close the connection even when a statement fails.
    conn.close()

In [85]:
# Read the stored array back and decode the JSON text into nested lists.
dbname = "test.db"
conn = sqlite3.connect(dbname)
try:
    c = conn.cursor()
    sql = "select * from test"
    c.execute(sql)
    data = c.fetchall()
finally:
    # All rows are already fetched; release the database file.
    conn.close()

# Column 1 of the first row holds the JSON text.
result_list = json.loads(data[0][1])
print(type(result_list))
print(result_list)

<class 'list'>
[[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]]


### ------------------ sqlite3 test ----------------------

In [25]:
# Minimal sqlite3 round trip: create a table and insert one row.
dbname = 'test.db'

conn = sqlite3.connect(dbname)
try:
    c = conn.cursor()

    # Recreate the table so that re-running the cell does not fail on an existing one.
    c.execute('drop table if exists sample')

    # SQL statements are run with the execute method.
    create_table = 'create table sample (id integer,name text)'
    c.execute(create_table)

    # To put values into a statement, do not build the string with Python's
    # format method or similar: write ? where a value goes and pass the values
    # as a tuple in the second argument of execute.
    sql = 'insert into sample (id, name) values (?,?)'
    text = "test test test"
    query = (1, text)
    c.execute(sql, query)

    conn.commit()
finally:
    # Close the connection even when a statement fails.
    conn.close()

In [26]:
# Read every row of the sample table back and print it.
dbname = 'test.db'

conn = sqlite3.connect(dbname)
try:
    c = conn.cursor()
    sql = 'select * from sample'
    c.execute(sql)
    result = c.fetchall()
finally:
    # All rows are already fetched; release the database file.
    conn.close()
print(result)

[(1, 'test test test')]


### ------------- DataFrame をsqliteに保存 --------------

In [55]:
import pandas as pd

In [61]:
# Write a small DataFrame to sqlite with DataFrame.to_sql.
dbname = "pdtest.db"

d = {0: "aaa", 1: "bbb", 2: "ccc"}

# One row per (key, value) pair of the dict.
df = pd.DataFrame(list(d.items()), columns=["id", "text"])

conn = sqlite3.connect(dbname)
try:
    # if_exists='replace' overwrites the table, so the cell can be re-run.
    df.to_sql('sample', conn, if_exists='replace')
finally:
    # Close the connection even when the write fails.
    conn.close()

In [68]:
# Read the table written by to_sql back into a DataFrame and print it.
dbname = "pdtest.db"
conn = sqlite3.connect(dbname)
try:
    df = pd.read_sql('SELECT * FROM sample', conn)
finally:
    # Close the connection even when the query fails.
    conn.close()
print(df)

   index  id text
0      0   0  aaa
1      1   1  bbb
2      2   2  ccc
