In [None]:
import os
from pathlib import Path
from joblib import Parallel, delayed
from tqdm import tqdm
import shutil
import itertools
import numpy as np
import time
import math
import cv2
import torch
from torch.utils.data import DataLoader
from utils import FaceInferenceImageDataset, FaceDataset, PyLModel
from utils import ScaleCategory, find_latest_checkpoint_path, load_inference_dir

# Every input/output location lives under ./data_scale, relative to the notebook's working directory.
data_dir = Path.cwd().joinpath("data_scale")

dataset_dir = data_dir.joinpath("dataset")      # presumably the training dataset; not read in this notebook
log_dir = data_dir.joinpath("log")              # checkpoints are searched under log_dir / "lightning_logs"
inference_dir = data_dir.joinpath("inference")  # images to classify (point at data_dir / "raw" to use the raw set)
output_dir = data_dir.joinpath("output")        # final results; emptied at the start of each run

In [None]:
# Start every run from an empty output_dir so files from a previous run cannot mix in.
shutil.rmtree(path=output_dir, ignore_errors=True)  # silently a no-op when the directory is absent
time.sleep(1.0)  # presumably lets the filesystem (e.g. Windows) finish the delete before recreating — TODO confirm
output_dir.mkdir(parents=False, exist_ok=True)

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU so the notebook
# still runs end-to-end on a CPU-only machine instead of failing at model.to("cuda").
device = "cuda" if torch.cuda.is_available() else "cpu"

# Large batches on the GPU; on the CPU use the smaller batch size the notebook
# previously kept (commented out) for CPU runs.
batch_size = 512 if device == "cuda" else 16

In [None]:
# Dataset/dataloader over the raw inference images. Note: the scale loop later rebuilds
# both from tmp_input_dir on every pass, so these objects are not used for inference.
dataset = FaceInferenceImageDataset(inference_dir)
dataset = FaceDataset(dataset, with_flipped=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=8)

# Load the most recent checkpoint written under the training log directory.
checkpoint_path = find_latest_checkpoint_path(log_dir / "lightning_logs")
if checkpoint_path is None:
    # Explicit error instead of a bare assert: gives a usable message and is not stripped by `python -O`.
    raise FileNotFoundError(f"no checkpoint found under {log_dir / 'lightning_logs'}")

model = PyLModel.load_from_checkpoint(str(checkpoint_path))
print("Load:", checkpoint_path)

In [None]:
# Place the model on the selected device and switch it to evaluation mode in one chained call
# (both Module.to and Module.eval return the module itself).
model = model.to(device).eval()

In [None]:
# Scratch directories for the iterative crop/re-classify loop:
#   tmp_input_dir  - images to classify on the current pass
#   tmp_output_dir - images classified `valid` on the current pass
tmp_input_dir, tmp_output_dir = (data_dir / name for name in ("tmp_input", "tmp_output"))

# Drop leftovers from any previous run, pause, then recreate both directories empty.
shutil.rmtree(tmp_input_dir, ignore_errors=True)
shutil.rmtree(tmp_output_dir, ignore_errors=True)
time.sleep(1)  # presumably a filesystem settle delay after rmtree (e.g. Windows) — TODO confirm
tmp_input_dir.mkdir(exist_ok=True)
tmp_output_dir.mkdir(exist_ok=True)
time.sleep(1)

In [None]:
# Seed the first pass: copy every inference image into tmp_input_dir, keeping its file name.
parameters = [(src, tmp_input_dir / src.name) for src in load_inference_dir(inference_dir)]

print("copy files:", len(parameters))

for src_path, dst_path in tqdm(parameters):
    shutil.copyfile(src_path, dst_path)
time.sleep(1)

In [None]:
scale_factor = 0.1

def resize_process(param):
    src_path, dst_path = param

    image = cv2.imread(str(src_path), cv2.IMREAD_COLOR)
    image_h, image_w, _ = image.shape

    trim_w = math.ceil(image_w * scale_factor * 0.5)
    trim_h = math.ceil(image_h * scale_factor * 0.5)

    crop_image = image[trim_h:image_h-trim_h-1, trim_w:image_w-trim_w-1, :]
    crop_image = cv2.resize(crop_image, (image_w, image_h), interpolation = cv2.INTER_CUBIC)

    cv2.imwrite(str(dst_path), crop_image)

In [None]:
# Iterative zoom-in classification. Each pass classifies every image in tmp_input_dir:
#   - ScaleCategory.invalid -> copied into output_dir (and thus leaves the loop);
#   - ScaleCategory.valid   -> copied into tmp_output_dir, then cropped by resize_process
#     (border trimmed by scale_factor, resized back to the original size) and written to
#     tmp_input_dir with "_" appended to the stem, to be classified again on the next pass.
# Stops when tmp_input_dir is empty or after 64 passes (`scale` is only the pass counter).
# Images still `valid` after the last pass stay in tmp_input_dir; the next cell copies their
# originals into output_dir.
for scale in range(64):
    dataset = FaceInferenceImageDataset(tmp_input_dir)
    if (len(dataset) == 0):
        break
    
    dataset = FaceDataset(dataset, with_flipped=True)
    # shuffle=False so predictions come back in dataset order (matched to paths below).
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=8)
    
    estimated_list = []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            # Each item yields the image, its flipped copy, and a third value unused here.
            batch, batch_flip, _ = batch
            this_batch_size = len(batch)

            # Run originals and flipped copies through the model in a single forward pass.
            batch = torch.cat([batch, batch_flip], dim=0)
            batch = batch.to(device)

            pred = model(batch)

            # Split back into [originals, flips] (same order as the torch.cat above) and sum
            # the two views' scores before taking the argmax class index.
            # NOTE(review): assumes model returns len(ScaleCategory) scores per input — TODO confirm.
            pred = pred.view(2, this_batch_size, len(ScaleCategory))
            pred = pred.sum(dim=0)
            _, estimated = pred.max(dim=1)

            estimated = estimated.cpu().numpy()
            estimated_list.append(estimated)

    estimated_list = np.concatenate(estimated_list, axis=0)
    
    # Paths of the inner FaceInferenceImageDataset; presumably in the same order the dataloader
    # yielded them (shuffle=False). The assert only checks that the counts agree.
    image_files = dataset.dataset.paths
    assert(len(estimated_list) == len(image_files))
    
    # Destination per predicted class.
    # NOTE(review): looked up with the argmax index (a numpy integer), which relies on
    # ScaleCategory members hashing/comparing equal to ints (e.g. IntEnum) — confirm. Any class
    # other than valid/invalid would raise KeyError here.
    output_dir_selector = {
        ScaleCategory.valid: tmp_output_dir,
        ScaleCategory.invalid: output_dir,
    }
    
    parameters = zip(image_files, estimated_list)
    parameters = list(parameters)
    parameters = [(p, output_dir_selector[l] / p.name) for p, l in parameters]
    
    # Empty tmp_output_dir so it only holds this pass's `valid` images.
    shutil.rmtree(tmp_output_dir, ignore_errors=True)
    time.sleep(1)  # presumably a filesystem settle delay after rmtree (e.g. Windows) — TODO confirm
    tmp_output_dir.mkdir(exist_ok=True)

    # Route each image; `invalid` ones land in output_dir in their current (possibly already
    # cropped) form, under their current, possibly "_"-suffixed, name.
    for src_path, dst_path in tqdm(parameters):
        shutil.copyfile(src_path, dst_path)
    
    # This pass's inputs are done: clear tmp_input_dir so it only receives the cropped images.
    shutil.rmtree(tmp_input_dir, ignore_errors=True)
    time.sleep(1)
    tmp_input_dir.mkdir(exist_ok=True)
    
    # Each crop appends "_" to the stem; the final cell strips these suffixes by keeping only the
    # first four "_"-separated fields of output names.
    parameters = list(tmp_output_dir.glob("*.png"))
    parameters = [(p, tmp_input_dir / f"{p.stem}_.png") for p in parameters]
    
    # Sequential equivalent of the parallel call below (handy for debugging):
    #[resize_process(params) for params in tqdm(parameters)]
    # Crop/resize all `valid` images using every available core (n_jobs=-1).
    Parallel(n_jobs=-1, verbose=10)([delayed(resize_process)(params) for params in parameters])

In [None]:
# Inputs that never reached output_dir during the scale loop (e.g. still classified `valid`
# when it stopped) are copied over unchanged, so every input image is represented there.
output_files = load_inference_dir(output_dir)
output_files = [p.stem.split("_") for p in output_files]
# Each crop pass appends "_" to the stem, so the first four "_"-separated fields give back the
# original stem (assumes original stems have exactly four fields — TODO confirm).
# `all(...)` instead of `min(...)` so an empty output_dir does not raise ValueError.
assert all(len(parts) >= 4 for parts in output_files), "output file name has fewer than 4 '_' fields"
# A set turns the per-input membership test below into O(1) instead of a list scan.
output_files = {"_".join(parts[0:4]) for parts in output_files}
input_files = load_inference_dir(inference_dir)

# Copy from the actual input path (keeping its real name/extension) rather than rebuilding
# it as inference_dir / "<stem>.png".
parameters = [(p, output_dir / p.name) for p in input_files if p.stem not in output_files]

for src_path, dst_path in tqdm(parameters):
    shutil.copyfile(src_path, dst_path)