In [3]:
from collections import defaultdict
import json
import time
import pandas as pd
import sqlite3
from typing import TypeVar

# Plain alias for the pandas DataFrame class, used in annotations below.
# (A TypeVar is a generic placeholder, not a reference to a concrete type, and
# its name argument must match the variable it is assigned to.)
Dataframe = pd.DataFrame

class FeatureExtractor:
    """Run a model over a dataloader and collect label / image / feature rows.

    The model is called as ``model(inputs)`` and must return a pair
    ``(outputs, features)``; only ``features`` is stored. The dataloader must
    yield ``(inputs, labels)`` batches.
    """

    def __init__(self, model, dataloader):
        """Store the model and dataloader; no extraction happens here.

        Args:
            model: callable returning ``(outputs, features)`` for a batch.
            dataloader: iterable of ``(inputs, labels)`` batches.
        """
        self.model = model
        self.dataloader = dataloader
        # Filled in by make_feature_table(); stays None until then.
        self.feature_table: Dataframe = None

    def make_feature_table(self):
        """Build ``self.feature_table`` with one row per sample.

        Columns:
            label:   the sample label converted with ``int()``.
            image:   the input tensor serialised as a JSON nested list.
            feature: the feature tensor serialised as a JSON nested list.

        An empty dataloader yields an empty table that still has these three
        columns. The elapsed wall-clock time is printed when done.
        """
        # Local import: the cell defining this class does not import torch.
        import torch

        start = time.time()
        columns = {"label": [], "image": [], "feature": []}
        # Pure inference: disable autograd so no graph is kept per batch.
        with torch.no_grad():
            for inputs, labels in self.dataloader:
                _, features = self.model(inputs)
                for label, image, feature in zip(labels, inputs, features):
                    columns["label"].append(int(label))
                    columns["image"].append(
                        json.dumps(image.detach().cpu().numpy().tolist())
                    )
                    columns["feature"].append(
                        json.dumps(feature.detach().cpu().numpy().tolist())
                    )
        self.feature_table = pd.DataFrame(columns)
        elapsed_time = time.time() - start
        print("elapsed time of feature extraction:  {0}m {1}s".format(elapsed_time//60, elapsed_time%60))

    def save_feature_table(self, savepath="./ft.db", tablename="feature_table"):
        """Write the feature table to a SQLite database file.

        Args:
            savepath: path of the SQLite file to create or open.
            tablename: table to write; an existing table of that name is replaced.

        If make_feature_table() has not been run yet, a message is printed and
        nothing is written.
        """
        if self.feature_table is None:
            print("feature table does not exist.")
            return
        conn = sqlite3.connect(savepath)
        try:
            self.feature_table.to_sql(tablename, conn, if_exists='replace')
            conn.commit()
        finally:
            # Release the file handle even when the write fails.
            conn.close()

In [4]:
# Stubs that build the trainloader and testloader
import torchvision
import torchvision.transforms as transforms

def update_labels(train, test):
    """Remap the labels of both splits onto consecutive integers starting at 0.

    Each split is a sequence of ``(data, label)`` pairs. Both splits share one
    mapping, so the same original label gets the same new label in train and
    test. New labels are handed out in the order labels are first met: train
    first, then test, each visited in ascending label order.

    Args:
        train: sequence of ``(data, label)`` pairs.
        test: sequence of ``(data, label)`` pairs.

    Returns:
        ``(train, test)`` as lists of ``(data, new_label)`` pairs, each sorted
        by original label (the sort is stable within a label).
    """
    mapping = {}
    remapped = []
    for split in (train, test):
        relabelled = []
        for sample in sorted(split, key=lambda pair: pair[1]):
            old_label = sample[1]
            if old_label not in mapping:
                # Use the mapping size as the next id so a label that first
                # appears in the test split cannot reuse an id already taken.
                mapping[old_label] = len(mapping)
            relabelled.append((sample[0], mapping[old_label]))
        remapped.append(relabelled)
    train, test = remapped
    return train, test

def dataset_stub(path="../../../prototype/proposal/data/", classes=(1, 2, 8)):
    """Load CIFAR10, keep only the requested classes and relabel them from 0.

    Args:
        path: root directory for the CIFAR10 files; they are downloaded there
            if missing.
        classes: original CIFAR10 class indices to keep.

    Returns:
        ``(train, test)`` lists of ``(image_tensor, new_label)`` pairs, as
        produced by ``update_labels``.
    """
    wanted = frozenset(classes)
    # Convert to tensors, then normalise each channel with mean 0.5 and std 0.5.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    train = torchvision.datasets.CIFAR10(root=path, train=True, download=True, transform=transform)
    test = torchvision.datasets.CIFAR10(root=path, train=False, download=True, transform=transform)
    # Iterating the dataset applies the transform, so the kept samples are
    # already (tensor, label) pairs.
    train = [sample for sample in train if sample[1] in wanted]
    test = [sample for sample in test if sample[1] in wanted]
    train, test = update_labels(train, test)
    return train, test

def loader_stub(batch_size=128, num_workers=2):
    """Build DataLoaders over the datasets returned by ``dataset_stub``.

    Args:
        batch_size: samples per batch for both loaders.
        num_workers: worker processes used by each loader.

    Returns:
        ``(trainloader, testloader)``; the train loader shuffles, the test
        loader does not.
    """
    # Local import: this cell does not import torch itself, so do not rely on
    # a later cell having done it.
    import torch.utils.data

    train, test = dataset_stub()
    trainloader = torch.utils.data.DataLoader(
        train, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    testloader = torch.utils.data.DataLoader(
        test, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    return trainloader, testloader

In [5]:
# Stub for the network
import torch.nn as nn
import torch.nn.functional as F

class PreLeNet(nn.Module):
    """Small two-convolution network returning class scores and a pooled feature.

    The shape notes below follow a 3x32x32 input; the fully connected layer is
    sized for 8*8*32 activations.
    """

    def __init__(self, out=3):
        """Create the layers; ``out`` is the number of output scores."""
        super().__init__()
        # 32x32x3 -> 32x32x16
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        # 16x16x16 -> 16x16x32
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        # Average over the 8x8 map to get one value per channel.
        self.gap = nn.AvgPool2d(kernel_size=8)
        self.fc1 = nn.Linear(8 * 8 * 32, out)

    def forward(self, x):
        """Return ``(scores, feature)`` for a batch ``x``.

        ``scores`` is the ReLU of the fully connected layer output; ``feature``
        is the average-pooled conv activation reshaped to rows of 32 values.
        """
        # conv -> ReLU -> 2x2 max pool: 32x32x16 -> 16x16x16
        hidden = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2, stride=2)
        # conv -> ReLU -> 2x2 max pool: 16x16x32 -> 8x8x32
        hidden = F.max_pool2d(F.relu(self.conv2(hidden)), kernel_size=2, stride=2)
        # 8x8x32 -> 1x1x32 -> rows of 32
        feature = self.gap(hidden).view(-1, 32)
        flat = hidden.view(-1, 8 * 8 * 32)
        scores = F.relu(self.fc1(flat))
        return scores, feature

In [8]:
# Stub for loading the feature_table back from the SQLite file
def load_feature_table(dbpath, tablename="feature_table"):
    """Read a whole table from a SQLite file into a DataFrame.

    Args:
        dbpath: path of the SQLite database file.
        tablename: name of the table to read.

    Returns:
        A DataFrame holding every row and column of ``tablename``.
    """
    # A table name cannot be bound as a SQL parameter, so quote it as an
    # identifier instead; embedded double quotes are doubled.
    quoted = '"{}"'.format(tablename.replace('"', '""'))
    conn = sqlite3.connect(dbpath)
    try:
        return pd.read_sql('SELECT * FROM ' + quoted, conn)
    finally:
        # Always release the database file, even if the query fails.
        conn.close()

In [6]:
# Build the model and the trainloader
import torch

model_path = "./assets/v4.pth"
model = PreLeNet(3)
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
# map_location="cpu" lets a checkpoint saved on a GPU machine load on a
# CPU-only one; the visible cells never move the model to a GPU.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
# The model is only used for feature extraction here, so switch to eval mode.
model.eval()
trainloader, _ = loader_stub()

Files already downloaded and verified
Files already downloaded and verified


In [7]:
# Extract features for the training set and persist them to SQLite.
savepath = "./assets/ft.db"
feature_extractor = FeatureExtractor(model=model, dataloader=trainloader)
feature_extractor.make_feature_table()
feature_extractor.save_feature_table(savepath=savepath)

elapsed time of feature extraction:  0.0m 49.205907344818115s


In [10]:
dbpath = "./assets/ft.db"
feature_table = load_feature_table(dbpath)
# Preview only: every row holds a JSON-encoded image, so show the first rows
# instead of rendering the whole table.
feature_table.head(10)

Unnamed: 0,index,feature,image,label
0,0,"[0.44359123706817627, 0.5560004711151123, 0.17...","[[[-0.10588234663009644, -0.1294117569923401, ...",1
1,1,"[0.8936551213264465, 0.8017301559448242, 0.027...","[[[0.6392157077789307, 0.843137264251709, 0.81...",1
2,2,"[0.49301645159721375, 0.4722248911857605, 0.14...","[[[0.5764706134796143, 0.5764706134796143, 0.5...",2
3,3,"[0.5521446466445923, 1.1740447282791138, 0.025...","[[[-0.6313725709915161, -0.6549019813537598, -...",1
4,4,"[0.7862429618835449, 0.8334169387817383, 0.042...","[[[-0.24705880880355835, -0.26274508237838745,...",1
5,5,"[0.42934268712997437, 0.5402464866638184, 0.20...","[[[0.8823529481887817, 0.8039215803146362, 0.8...",1
6,6,"[0.6640405058860779, 0.8305144906044006, 0.096...","[[[-0.03529411554336548, -0.003921568393707275...",2
7,7,"[0.48203176259994507, 0.5245118141174316, 0.03...","[[[0.2705882787704468, 0.24705886840820312, 0....",1
8,8,"[0.3360908627510071, 0.940671443939209, 0.0016...","[[[-0.3490195870399475, -0.35686272382736206, ...",1
9,9,"[0.7106035351753235, 1.010310411453247, 0.0402...","[[[0.2705882787704468, 0.050980448722839355, -...",0
