In [10]:
from collections import defaultdict
import json
import time
import pandas as pd
import sqlite3

class FeatureExtractor:
    """Run a model over a dataloader and collect per-sample features in a table.

    ``model`` is called as ``model(inputs)`` and must return an
    ``(outputs, features)`` pair; ``dataloader`` must yield
    ``(inputs, labels)`` batches. The collected table is stored in
    ``self.feature_table``, which stays ``None`` until
    ``make_feature_table`` has been called.
    """

    def __init__(self, model, dataloader):
        self.model = model
        self.dataloader = dataloader
        self.feature_table = None

    def make_feature_table(self):
        """Build ``self.feature_table`` with one row per sample.

        Columns:
            label:   the sample's label converted with ``int()``.
            image:   the input tensor serialised as a JSON nested list.
            feature: the model's feature tensor serialised as a JSON nested list.

        An empty dataloader produces an empty table that still has the three
        columns. The elapsed wall-clock time is printed when done.
        """
        # Imported locally so the class does not rely on another notebook
        # cell having imported torch first.
        import torch

        start = time.time()
        columns = {"label": [], "image": [], "feature": []}
        # Only the forward values are serialised, so skip building the
        # autograd graph for the whole pass.
        with torch.no_grad():
            for inputs, labels in self.dataloader:
                _, features = self.model(inputs)
                for label, image, feature in zip(labels, inputs, features):
                    columns["label"].append(int(label))
                    columns["image"].append(
                        json.dumps(image.detach().cpu().numpy().tolist()))
                    columns["feature"].append(
                        json.dumps(feature.detach().cpu().numpy().tolist()))
        self.feature_table = pd.DataFrame(columns)
        elapsed_time = time.time() - start
        print("elapsed time of feature extraction:  {0}m {1}s".format(elapsed_time//60, elapsed_time%60))

    def save_feature_table(self, savepath, tablename="feature_table"):
        """Write the feature table to an SQLite database file.

        Args:
            savepath: path handed to ``sqlite3.connect``.
            tablename: name of the table to write; an existing table with
                this name is replaced.

        If no table has been built yet, a message is printed and nothing is
        written.
        """
        if self.feature_table is None:
            print("feature table does not exist.")
            return
        conn = sqlite3.connect(savepath)
        try:
            self.feature_table.to_sql(tablename, conn, if_exists='replace')
            # Explicit commit so the rows are persisted before closing, in
            # case to_sql left a transaction open.
            conn.commit()
        finally:
            # Release the database handle even when the write fails.
            conn.close()

In [11]:
# Stubs that build the trainloader and testloader
import torchvision
import torchvision.transforms as transforms

def update_labels(train, test):
    """Remap the labels of two datasets onto one shared, contiguous id range.

    Each dataset is a sequence of samples whose element ``[0]`` is the data
    and element ``[1]`` is the label. Each dataset is sorted by label, then
    every distinct label is assigned the next unused integer (starting at 0)
    the first time it is seen. The mapping is shared between ``train`` and
    ``test``, so the same original label gets the same new id in both.

    A label that appears only in ``test`` receives a fresh id rather than
    restarting from 0, so it can never collide with an id already handed out
    for ``train``.

    Returns:
        ``(train, test)`` as lists of ``(data, new_label)`` tuples, each
        sorted by original label.
    """
    mapping = {}
    remapped = []
    for dataset in (train, test):
        relabelled = []
        for sample in sorted(dataset, key=lambda s: s[1]):
            old_label = sample[1]
            if old_label not in mapping:
                # len(mapping) is always the next unused id, across both splits.
                mapping[old_label] = len(mapping)
            relabelled.append((sample[0], mapping[old_label]))
        remapped.append(relabelled)
    train, test = remapped
    return train, test

def dataset_stub(path="../../prototype/data/", classes=(1, 2, 8)):
    """Load CIFAR10, keep only the requested classes and relabel them.

    Args:
        path: root directory passed to ``torchvision.datasets.CIFAR10``;
            the data is downloaded there if needed (``download=True``).
        classes: original CIFAR10 label ids to keep.

    Every image is converted with ``ToTensor`` and normalised with mean 0.5
    and std 0.5 per channel. The kept samples are then passed through
    ``update_labels`` so their labels are remapped.

    Returns:
        ``(train, test)`` as returned by ``update_labels``.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    train = torchvision.datasets.CIFAR10(root=path, train=True, download=True, transform=transform)
    test = torchvision.datasets.CIFAR10(root=path, train=False, download=True, transform=transform)
    # Iterating the dataset applies the transform, so the kept samples are
    # already-normalised tensors paired with their original label.
    train = [d for d in train if d[1] in classes]
    test = [d for d in test if d[1] in classes]
    train, test = update_labels(train, test)
    return train, test

def loader_stub(batch_size=128, num_workers=2):
    """Build the train and test ``DataLoader`` objects from ``dataset_stub``.

    Args:
        batch_size: number of samples per batch for both loaders.
        num_workers: number of loader worker processes for both loaders.

    The train loader shuffles its samples; the test loader does not.
    Note that this function looks up ``torch`` at call time, so ``torch``
    must have been imported in the notebook before it is called.

    Returns:
        ``(trainloader, testloader)``.
    """
    train, test = dataset_stub()
    trainloader = torch.utils.data.DataLoader(
        train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(
        test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return trainloader, testloader

In [12]:
# Stub for the network
import torch.nn as nn
import torch.nn.functional as F

class PreLeNet(nn.Module):
    """Small two-convolution network that returns class scores and a pooled feature.

    The shape notes below follow the original annotations, which describe a
    3-channel 32x32 input; the flatten to ``8*8*32`` in ``forward`` relies on
    that spatial size.

    Args:
        out: number of outputs of the final linear layer.
    """

    def __init__(self, out=3):
        super().__init__()
        # 3 -> 16 channels, 3x3 kernel, padding 1 keeps the spatial size.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        # 16 -> 32 channels, 3x3 kernel, padding 1 keeps the spatial size.
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        # Average over an 8x8 window: one value per channel for an 8x8 map.
        self.gap = nn.AvgPool2d(kernel_size=8)
        self.fc1 = nn.Linear(in_features=8 * 8 * 32, out_features=out)

    def forward(self, x):
        """Return ``(scores, feature)`` for a batch ``x``.

        ``scores`` is the ReLU of the linear layer applied to the flattened
        second pooled map; ``feature`` is the average-pooled second pooled
        map reshaped to 32 values per row.
        """
        # Stage 1: conv + ReLU, then 2x2 max-pool (32x32x16 -> 16x16x16).
        stage1 = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2, stride=2)
        # Stage 2: conv + ReLU, then 2x2 max-pool (16x16x32 -> 8x8x32).
        stage2 = F.max_pool2d(F.relu(self.conv2(stage1)), kernel_size=2, stride=2)
        # Pooled descriptor: 1x1x32 per sample, reshaped to rows of 32.
        feature = self.gap(stage2).view(-1, 32)
        scores = F.relu(self.fc1(stage2.view(-1, 8 * 8 * 32)))
        return scores, feature

In [13]:
# Create the model and the trainloader
import torch

# Restore the trained weights into a fresh PreLeNet and build the train loader.
model_path = "./assets/v4.pth"
model = PreLeNet(3)
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# a CPU-only machine; the model itself is never moved off the CPU here.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
trainloader, _ = loader_stub()

Files already downloaded and verified
Files already downloaded and verified


In [16]:
# Where the extracted features are persisted (SQLite database file).
savepath = "./assets/features_v1.db"

# Extract features for every training sample, then write them to disk.
feature_extractor = FeatureExtractor(model, trainloader)
feature_extractor.make_feature_table()
feature_extractor.save_feature_table(savepath)

elapsed time of feature extraction:  0.0m 46.41586923599243s


In [17]:
feature_table = feature_extractor.feature_table
# Preview only the first rows instead of rendering the whole table; every row
# carries a full JSON-encoded image.
feature_table.head(10)

Unnamed: 0,feature,image,label
0,"[0.8066024780273438, 0.7404053807258606, 0.143...","[[[-0.10588234663009644, -0.1294117569923401, ...",2
1,"[0.4560019373893738, 0.6529815793037415, 0.077...","[[[0.3490196466445923, 0.30980396270751953, 0....",2
2,"[0.3897155225276947, 0.4236888587474823, 0.134...","[[[0.3176470994949341, 0.3333333730697632, 0.3...",1
3,"[0.7064931392669678, 0.8661195039749146, 0.117...","[[[0.37254905700683594, 0.3176470994949341, 0....",2
4,"[1.0469402074813843, 0.7238845229148865, 0.215...","[[[0.8745098114013672, 0.7882353067398071, 0.7...",0
5,"[0.7571626901626587, 1.229896068572998, 0.0122...","[[[-0.7098039388656616, -0.6941176652908325, -...",0
6,"[0.6395211815834045, 0.6540312767028809, 0.214...","[[[0.615686297416687, 0.6313725709915161, 0.62...",1
7,"[0.8305415511131287, 1.0492537021636963, 0.107...","[[[-0.9764705896377563, -0.9686274528503418, -...",0
8,"[0.8620947599411011, 1.2904132604599, 0.063130...","[[[-0.5215686559677124, -0.5137255191802979, -...",2
9,"[0.7602303624153137, 1.0959500074386597, 0.046...","[[[0.41960787773132324, 0.2705882787704468, -0...",1
