In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import inspect
import os
import sys
sys.path.insert(0, os.path.abspath('../../.'))
import time

import torch
import torch.nn as nn

import numpy as np
# from sklearn import svm
from thundersvm import *
from tqdm import tqdm

from src.model.SparseNet import SparseNet
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from src.model.FolderDataset import FolderPatchDataset
from src.utils.cmd_line import parse_args
from src.scripts.plotting import plot_rf

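## 1. Sparse coding inference

Encode patches from the OK images with the trained sparse-coding model and collect the resulting sparse codes `R`.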

In [2]:
# Experiment configuration: device, paths, sparse-coding hyperparameters and OC-SVM settings.

# Fall back to CPU so the notebook still runs on a machine without CUDA.
device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Seed torch and numpy up front so re-runs are reproducible.
# NOTE(review): it is not visible here whether FolderPatchDataset samples
# patches randomly; seeding is harmless if it does not.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Dataset root; set the FOCUSIGHT_ROOT environment variable instead of editing
# this machine-specific absolute path.
root_path = os.environ.get(
    'FOCUSIGHT_ROOT',
    '/home/xd/data/defect_detection/data/focusight1_round1_train_part1',
)
ok_subpath = 'OK_Images'
ok_path = os.path.join(root_path, ok_subpath)

# Checkpoint path following the training-run naming scheme:
# <model_path>/<train_name>_<train_id>/ckpt_<epoch>.pth
model_path = './trained_models'
train_name = 'sparse-net'
train_id = 'p1-04'
checkpoint_epoch = 10
default_checkpoint_path = os.path.join(
    model_path,
    '{}_{}'.format(train_name, train_id),
    'ckpt_{}.pth'.format(checkpoint_epoch),
)
# TODO: temporary override; the checkpoint actually used lives outside the
# naming scheme above. Set to None to fall back to `default_checkpoint_path`.
checkpoint_override = '../../trained_models/ckpt-990.pth'
checkpoint_path = checkpoint_override or default_checkpoint_path

# Sparse-coding model / data settings.
N = 2000                 # passed to FolderPatchDataset as N
patch_size = 10          # patches are patch_size x patch_size pixels
n_neuron = 400           # number of sparse-code units (dictionary size)
r_learning_rate = 1e-2   # SparseNet R_lr
reg = 5e-3               # SparseNet lmda (sparsity regularization)

# One-class SVM settings. NU upper-bounds the fraction of training samples
# treated as outliers.
NU = 0.2
KERNEL = 'rbf'
GAMMA = 0.1

In [3]:
# Patch dataset over the OK image folder. The two positional arguments are
# presumably patch height and width (square patches here); confirm against
# FolderPatchDataset.
# NOTE(review): training=True is used even though this notebook only runs
# inference; confirm it does not enable random sampling/augmentation.
dataset = FolderPatchDataset(patch_size, patch_size, N=N, folder=ok_path, training=True)

# Keep the dataset order (no shuffling). NOTE(review): the recorded outputs show
# 1000 batches and 144000 codes in total, i.e. 144 rows per batch, so
# `dataset.N` is apparently not the N=2000 passed above; confirm what it holds.
loader_options = dict(shuffle=False, batch_size=dataset.N, num_workers=8)
dataloader = DataLoader(dataset, **loader_options)

100%|██████████| 1000/1000 [00:00<00:00, 1146.39it/s]


In [4]:
# Load the trained sparse-coding model and move it to `device`.
device = torch.device(device_name)

# The checkpoint is a whole pickled SparseNet module (its repr is shown below).
# Unpickling runs arbitrary code, so only load trusted checkpoints.
# torch>=2.6 defaults to weights_only=True, which rejects pickled modules, so
# opt out explicitly wherever torch.load accepts the flag.
load_kwargs = {'map_location': 'cpu'}
if 'weights_only' in inspect.signature(torch.load).parameters:
    load_kwargs['weights_only'] = False
checkpoint = torch.load(checkpoint_path, **load_kwargs)

if isinstance(checkpoint, nn.Module):
    # A pickled module carries its own settings: the n_neuron / r_learning_rate /
    # reg values from the config cell are NOT applied to it.
    model = checkpoint
else:
    # Otherwise treat the checkpoint as a state_dict for a freshly built model.
    model = SparseNet(
        n_neuron,
        patch_size,
        R_lr=r_learning_rate,
        lmda=reg,
        device=device
    )
    model.load_state_dict(checkpoint)

# Sanity check: the decoder U maps n_neuron code units to patch_size**2 pixels.
# nn.Linear stores its weight as (out_features, in_features).
assert tuple(model.U.weight.shape) == (patch_size ** 2, n_neuron), model.U.weight.shape

# NOTE(review): SparseNet takes `device` at construction; if it keeps that value
# internally, a loaded module may still reference its training device. Confirm.
model.to(device)



SparseNet(
  (U): Linear(in_features=400, out_features=100, bias=False)
)

In [5]:
# Run sparse-coding inference on every batch and collect the codes left in `model.R`.
# The forward pass's return value (the reconstruction) is not needed here.
# Do NOT wrap this loop in torch.no_grad(): the model takes a learning rate for R
# (R_lr), so the codes are presumably fitted with gradient steps inside forward.
model.eval()

resps = []

for img_batch in tqdm(dataloader, total=len(dataloader), file=sys.stdout):
    # Flatten each patch into a row vector: (batch, ...) -> (batch, features).
    flat_batch = img_batch.reshape(img_batch.shape[0], -1).to(device)

    model(flat_batch)

    # Detach before copying so the host copy carries no autograd history.
    resps.append(model.R.detach().cpu())

100%|██████████| 1000/1000 [02:23<00:00,  6.98it/s]


In [6]:
# Stack the per-batch codes into one (n_patches, n_neuron) array for the SVM.
if not resps:
    raise RuntimeError('no sparse codes collected; is the dataloader empty?')

np_resps = torch.cat(resps).numpy()
assert np_resps.shape[1] == n_neuron, np_resps.shape

np_resps.shape

(144000, 400)


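## 2. Train the one-class SVM (OC-SVM)

Fit a one-class SVM on the sparse codes of the OK patches and save it to disk.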

In [7]:
# Fit a one-class SVM on the sparse codes of the OK patches.
# thundersvm's OneClassSVM follows scikit-learn's API, so
# sklearn.svm.OneClassSVM with the same arguments is a drop-in CPU alternative.
clf = OneClassSVM(nu=NU, kernel=KERNEL, gamma=GAMMA, verbose=True)

# perf_counter is monotonic and high-resolution; time.time can jump with clock changes.
fit_start = time.perf_counter()
clf.fit(np_resps)
fit_seconds = time.perf_counter() - fit_start

print('fit duration: {:.2f}s'.format(fit_seconds))

fit duration: 22.511326551437378s


In [8]:
# Persist the fitted one-class SVM with thundersvm's own serializer.
# The file is presumably named "svdd" because, with an RBF kernel, the one-class
# SVM and SVDD (support vector data description) give equivalent decision functions.
svm_ckpt_path = '../../svdd.ckpt'
clf.save_to_file(svm_ckpt_path)