In [None]:
import os
from pathlib import Path
from joblib import Parallel, delayed
from tqdm import tqdm
import shutil
import itertools
import numpy as np
import time
import torch
from torch.utils.data import DataLoader
from utils import FaceImageDataset, FaceDataset, PyLModel
from utils import SelectCategory, find_latest_checkpoint_path

# Every input and output of this notebook lives under ./data_select:
#   dataset/       labelled images to score
#   log/           training logs (Lightning checkpoints under log/lightning_logs)
#   check_output/  low-confidence images copied out for manual review
#   update/        per-category output folders (created empty by this notebook)
data_dir = Path.cwd() / "data_select"
dataset_dir, log_dir, check_output_dir, update_dir = (
    data_dir / sub_name for sub_name in ("dataset", "log", "check_output", "update")
)

In [None]:
# Start from a clean slate: remove any previous review/update output (missing
# directories are ignored), then recreate both roots with one sub-folder per
# category.
output_roots = (check_output_dir, update_dir)

for root in output_roots:
    shutil.rmtree(root, ignore_errors=True)

# Short pause, presumably so the filesystem has finished the deletion before the
# same paths are recreated (e.g. on Windows) — TODO confirm it is still needed.
time.sleep(1)

for root in output_roots:
    root.mkdir(exist_ok=True)

for category in SelectCategory:
    for root in output_roots:
        (root / category.name).mkdir(exist_ok=True)

In [None]:
# Inference device: use the GPU when CUDA is available, otherwise fall back to
# the CPU so the notebook still runs end-to-end (much more slowly) on machines
# without a GPU. Overwrite `device` here to force a specific device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Images per DataLoader batch. The inference cell concatenates each batch with
# its flipped copy, so the model actually sees 2 * batch_size images per step;
# lower this if the device runs out of memory.
batch_size = 512

In [None]:
# Locate the newest Lightning checkpoint first, so a missing model fails fast
# (with a useful message) before the potentially slow dataset scan.
checkpoint_root = log_dir / "lightning_logs"
checkpoint_path = find_latest_checkpoint_path(checkpoint_root)
assert checkpoint_path is not None, f"no checkpoint found under {checkpoint_root}"

# Labelled face images, wrapped so each item also yields a flipped copy: the
# inference loop unpacks (image, flipped_image, label) from every batch.
# NOTE(review): flip direction (presumably horizontal) lives in utils.FaceDataset.
face_images = FaceImageDataset(dataset_dir, SelectCategory)
dataset = FaceDataset(face_images, with_flipped=True)

# shuffle=False is required: predictions are later matched back to image files
# purely by position.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=8)

model = PyLModel.load_from_checkpoint(str(checkpoint_path))
print("Load:", checkpoint_path)

In [None]:
# Move the network to the inference device and switch it to evaluation mode
# (both calls return the module itself, so they can be chained).
model = model.to(device).eval()

In [None]:
# Score every image: run the model on each image and its flipped copy, take the
# score of the image's own labelled category in both views, and keep the smaller
# of the two as a conservative confidence.
prob_list = []
num_categories = len(SelectCategory)

with torch.no_grad():
    for images, images_flipped, labels in tqdm(dataloader):
        # Per-batch size under a local name. The original code rebound the global
        # `batch_size` here, so re-running the DataLoader cell afterwards silently
        # picked up the last (possibly smaller) batch size.
        n_images = len(images)

        # One forward pass for both views: rows [0, n) are the originals,
        # rows [n, 2n) the flipped copies.
        inputs = torch.cat([images, images_flipped], dim=0).to(device)
        labels = labels.to(device)

        pred = model(inputs)

        # (2n, C) -> (2, n, C): index 0 = original view, index 1 = flipped view.
        pred = pred.view(2, n_images, num_categories)
        # Select each image's score for its own label in both views -> (2, n).
        index = labels.view(1, n_images, 1).repeat(2, 1, 1)
        label_scores = pred.gather(dim=2, index=index).reshape(2, n_images)
        # NOTE(review): exp() assumes the model outputs log-probabilities
        # (e.g. log_softmax) — confirm against PyLModel.
        label_probs = label_scores.exp()
        prob, _ = label_probs.min(dim=0)

        prob_list.append(prob.cpu().numpy())

# One probability per dataset item, in DataLoader order (shuffle=False).
# An empty dataset yields an empty array instead of np.concatenate raising.
prob_list = (
    np.concatenate(prob_list, axis=0) if prob_list else np.empty(0, dtype=np.float32)
)

In [None]:
# Flag every image whose conservative (min over original/flipped) probability for
# its own label falls below the threshold; those are copied out for review.
check_threshold = 0.9
check_list = prob_list < check_threshold

# NOTE(review): assumes `dataset.dataset.data` is a sequence of (path, label)
# pairs in the same order the DataLoader yielded them — confirm in utils.
image_files = dataset.dataset.data
assert len(check_list) == len(image_files), "predictions and image list are misaligned"

# (source path, destination path) for each flagged image; destinations are
# grouped into one folder per category name.
parameters = [
    (path, check_output_dir / SelectCategory(label).name / path.name)
    for (path, label), flagged in zip(image_files, check_list)
    if flagged
]
assert len(parameters) == int(check_list.sum())

# Destinations are keyed only by category + file name, so two flagged images with
# the same name in one category would silently overwrite each other on copy.
destinations = [dst for _, dst in parameters]
assert len(set(destinations)) == len(destinations), "duplicate destination file names"

In [None]:
# Copy each flagged low-confidence image into its category folder under
# check_output_dir for manual inspection (file contents only, no metadata).
for source_path, target_path in tqdm(parameters):
    shutil.copyfile(src=source_path, dst=target_path)