In [None]:
import itertools
import os
import pickle
import shutil
import time
from pathlib import Path

import numpy as np
import torch
from joblib import Parallel, delayed
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import FaceInferenceImageDataset, FaceDataset, PyLModel
from utils import FilterCategory, find_latest_checkpoint_path

# Every input and output of this notebook lives under ./data_filter,
# resolved relative to the current working directory.
data_dir = Path.cwd() / "data_filter"

dataset_dir, log_dir, inference_dir, output_dir = (
    data_dir / "dataset",    # training dataset
    data_dir / "log",        # Lightning logs / checkpoints
    data_dir / "inference",  # images to classify (alternative: data_dir / "raw")
    data_dir / "output",     # per-category copies produced at the end
)

In [None]:
# Start from a clean output tree: discard any previous results, then recreate
# one sub-folder per FilterCategory member.
shutil.rmtree(output_dir, ignore_errors=True)
# Short pause before recreating the folder; presumably lets the filesystem
# finish the deletion (e.g. on Windows). TODO confirm it is still needed.
time.sleep(1)
output_dir.mkdir(exist_ok=True)

category_dirs = [output_dir / member.name for member in FilterCategory]
for category_dir in category_dirs:
    category_dir.mkdir(exist_ok=True)

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (more slowly) on machines without CUDA instead of failing at `.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Images per DataLoader batch. At inference each batch is concatenated with its
# flipped copy, so the model sees up to 2 * batch_size images per forward pass.
batch_size = 512

In [None]:
# Inference data: wrap the raw image folder so every sample is also yielded with
# a flipped copy (with_flipped=True), used for test-time augmentation.
image_dataset = FaceInferenceImageDataset(inference_dir)
dataset = FaceDataset(image_dataset, with_flipped=True)
# shuffle=False matters: predictions are matched back to file paths by position.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=8)

# Restore the checkpoint chosen by find_latest_checkpoint_path (presumably the newest).
checkpoint_root = log_dir / "lightning_logs"
checkpoint_path = find_latest_checkpoint_path(checkpoint_root)
if checkpoint_path is None:
    # Explicit error instead of a bare assert: it carries a useful message and
    # is not stripped when Python runs with -O.
    raise FileNotFoundError(f"No checkpoint found under {checkpoint_root}")

model = PyLModel.load_from_checkpoint(str(checkpoint_path))
print("Load:", checkpoint_path)

In [None]:
# Move the weights to the selected device and switch to inference mode
# (eval() disables training-only behaviour such as dropout); both calls return the module.
model = model.to(device).eval()

In [None]:
# Test-time augmentation: score each image together with its flipped copy in a
# single forward pass, add the two score vectors, and keep the best class index.
prediction_chunks = []

with torch.no_grad():
    for images, images_flipped, _ in tqdm(dataloader):
        # Deliberately NOT named `batch_size`: that would overwrite the config
        # value with the size of the last (possibly partial) batch, and any
        # DataLoader rebuilt afterwards would silently use the wrong size.
        n_images = len(images)

        inputs = torch.cat([images, images_flipped], dim=0).to(device)
        scores = model(inputs)

        # Rows [0, n) are the originals and rows [n, 2n) the flipped copies,
        # matching the concatenation order above.
        scores = scores.view(2, n_images, len(FilterCategory)).sum(dim=0)
        prediction_chunks.append(scores.argmax(dim=1).cpu().numpy())

# One predicted class index per image, in DataLoader order. np.concatenate
# rejects an empty list, so an empty inference folder yields an empty array.
estimated_list = (
    np.concatenate(prediction_chunks, axis=0)
    if prediction_chunks
    else np.empty(0, dtype=np.int64)
)

In [None]:
# Pair every inference image with the folder of the category predicted for it.
# NOTE(review): relies on `dataset.dataset.paths` listing files in the same order
# the non-shuffled DataLoader yielded them; confirm in FaceInferenceImageDataset.
image_files = dataset.dataset.paths
assert len(estimated_list) == len(image_files), (
    f"{len(estimated_list)} predictions for {len(image_files)} images"
)

parameters = []
for image_path, label in zip(image_files, estimated_list):
    # int() turns the numpy integer into a plain int before the enum lookup.
    category_name = FilterCategory(int(label)).name
    parameters.append((image_path, output_dir / category_name / image_path.name))

# Destinations are keyed only by file name, so two sources that share a name
# would silently overwrite each other in the copy step. Fail loudly instead.
destinations = [dst for _, dst in parameters]
assert len(set(destinations)) == len(destinations), (
    "duplicate image file names would overwrite each other in output_dir"
)

In [None]:
# Persist the (source, destination) pairs so the copy step can be rerun in a
# fresh kernel without repeating inference. `pickle` is imported in the top cell.
with open(data_dir / "tmp_estimated.pkl", "wb") as f:
    pickle.dump(parameters, f)

In [None]:
# Resume point: reload the pairs written by the previous cell, e.g. after a
# kernel restart. Only unpickle files this notebook produced itself, because
# loading untrusted pickle data can execute arbitrary code.
with open(data_dir / "tmp_estimated.pkl", "rb") as f:
    parameters = pickle.load(f)

In [None]:
# Copy each inference image into the sub-folder of its predicted category.
for source_path, destination_path in tqdm(parameters):
    shutil.copyfile(src=source_path, dst=destination_path)