In [1]:
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np
import json
import random
from PIL import Image
from sklearn.neighbors import NearestNeighbors

In [2]:
# Input locations. video_root is not read by the cells shown here (presumably the source
# videos the library frames were extracted from — verify); data_root holds the library
# images, named {serial}.jpg, that load_image reads.
video_root = './v'
data_root = './data'

# exist_ok avoids the check-then-create race of os.path.exists() followed by os.makedirs().
os.makedirs(data_root, exist_ok=True)

In [3]:
# Library image size in pixels (width, height); the thumbnail and aspect-ratio math use it.
W, H = 1280, 720

# avr_rgb.json maps each library serial to a colour vector (presumably the image's mean
# R, G, B — verify against whatever produced the file).
with open('./avr_rgb.json', 'r', encoding='utf-8') as f:
    avr_RGB_data = json.load(f)

# Parallel lists built from the same key order: lib_RGB[k] is the colour of the library
# image whose serial is lib_serials[k].
lib_serials = [serial for serial in avr_RGB_data]
lib_RGB = [avr_RGB_data[serial] for serial in lib_serials]

In [4]:
# Ball-tree index over the library colours; each kneighbors() query returns the
# N_NEIGHBORS closest library rows (positions into lib_RGB / lib_serials).
N_NEIGHBORS = 5
nbrs = NearestNeighbors(n_neighbors=N_NEIGHBORS, algorithm='ball_tree')
nbrs.fit(lib_RGB)

In [5]:
def get_fitted_target(serial: int):
	"""Load ./target/{serial}.jpg and resize it to the library's W:H aspect ratio.

	The width is kept and the height is set to round(H / W * width), so the image is
	stretched or squashed (not cropped) to the library aspect ratio.

	Args:
		serial: number of the target image file.

	Returns:
		The resized image as returned by cv2 (BGR channel order).

	Raises:
		FileNotFoundError: if the file is missing or cannot be decoded.
	"""
	path = f'./target/{serial}.jpg'
	target = cv2.imread(path)
	if target is None:
		# cv2.imread reports failure by returning None instead of raising.
		raise FileNotFoundError(f'could not read target image: {path}')

	w = target.shape[1]
	h_prime = round(H / W * w)
	return cv2.resize(target, (w, h_prime))

In [6]:
# Number of tiles across (W_SIZE) and down (H_SIZE) the target image.
W_SIZE = 200
H_SIZE = 200

def subdivide(t, w_size=None, h_size=None):
	"""Split image `t` into an h_size x w_size grid of tiles, returned in row-major order.

	Args:
		t: image array indexed as t[row, col, ...] (2-D or with a trailing channel axis).
		w_size: tiles per row; defaults to W_SIZE.
		h_size: tiles per column; defaults to H_SIZE.

	Returns:
		List of h_size * w_size array views; tile k sits at row k // w_size, column
		k % w_size. Tiles share edges exactly, so every pixel belongs to one tile.

	Raises:
		ValueError: if the image is smaller than the grid (some tiles would be empty).
	"""
	if w_size is None:
		w_size = W_SIZE
	if h_size is None:
		h_size = H_SIZE

	height, width = t.shape[:2]
	if width < w_size or height < h_size:
		raise ValueError(
			f'image of size {width}x{height} is smaller than the {w_size}x{h_size} tile grid')

	# Integer edges: tile k spans [k*size//n, (k+1)*size//n). Truncating float offsets
	# (int(k * size/n)) can leave one-pixel gaps or overlaps between neighbouring tiles.
	x_edges = [iw * width // w_size for iw in range(w_size + 1)]
	y_edges = [ih * height // h_size for ih in range(h_size + 1)]

	return [
		t[y_edges[ih]:y_edges[ih + 1], x_edges[iw]:x_edges[iw + 1]]
		for ih in range(h_size)
		for iw in range(w_size)
	]

In [7]:
def get_subdivide_RGB(subs):
	"""Average each tile's colour channels.

	Args:
		subs: sequence of tile images (e.g. the output of subdivide).

	Returns:
		Dict mapping each tile's position in `subs` to a list with one rounded integer
		mean per channel, in the tile's own channel order.
	"""
	return {
		index: [round(np.mean(channel)) for channel in cv2.split(tile)]
		for index, tile in enumerate(subs)
	}

In [8]:
def select_candidates(t_RGB, rng=None):
	"""Pick one library serial per target colour, at random among its nearest neighbours.

	Args:
		t_RGB: sequence of colour vectors, one per tile, in the same layout as lib_RGB.
		rng: optional random.Random used for the pick; defaults to the module-level
			`random` functions. Pass random.Random(seed) for a reproducible mosaic.

	Returns:
		List of library serials (entries of lib_serials), one per row of t_RGB, in order.
	"""
	if rng is None:
		rng = random

	# kneighbors returns (distances, indices); each indices row lists positions of the
	# nearest library entries, which index into lib_serials.
	_, indices = nbrs.kneighbors(t_RGB)

	# Choosing randomly among the k nearest rather than always the closest presumably
	# varies the thumbnails used for similar-coloured tiles.
	return [lib_serials[rng.choice(row.tolist())] for row in indices]

In [9]:
# thumbnail & output shape settings

# The finished mosaic is OUTPUT_SCALE times the library resolution (W x H).
OUTPUT_SCALE = 5

# Exact (possibly fractional) thumbnail size, printed for reference.
thumb_width_exact = W / W_SIZE * OUTPUT_SCALE
thumb_height_exact = H / H_SIZE * OUTPUT_SCALE
print(thumb_width_exact, thumb_height_exact)

# PIL needs integer sizes. Round the thumbnail first and derive the canvas from the rounded
# size, so W_SIZE x H_SIZE pasted tiles fill it exactly (no clipped tiles or black strips
# when W / W_SIZE is not integral).
thumb_width, thumb_height = round(thumb_width_exact), round(thumb_height_exact)
grid_width = W_SIZE * thumb_width
grid_height = H_SIZE * thumb_height
print(grid_width, grid_height)

32.0 18.0
6400 3600


In [10]:
def load_image(serial):
	"""Load library image {data_root}/{serial}.jpg as an RGB thumbnail.

	Args:
		serial: library serial (a key of avr_rgb.json).

	Returns:
		PIL image of size (thumb_width, thumb_height), resized with LANCZOS.

	Raises:
		OSError: if the file is missing or cannot be decoded; the message names the path.
	"""
	img_path = f"{data_root}/{serial}.jpg"

	try:
		# `with` closes the source file handle; convert()/resize() return new in-memory images.
		with Image.open(img_path) as img:
			return img.convert("RGB").resize((thumb_width, thumb_height), Image.Resampling.LANCZOS)
	except OSError as exc:
		# Fail here with the path rather than returning None: a None thumbnail would be
		# cached and later break Image.paste with an unrelated-looking error.
		raise OSError(f"failed to load library image {img_path}") from exc

In [11]:
# Cache of loaded thumbnails keyed by library serial, so each library image is loaded once.
img_buffer = {}

def get_buffer(serial):
	"""Return the thumbnail for `serial`, loading and caching it on first use.

	A cached None (a failed load) is treated as a miss and retried.
	"""
	img = img_buffer.get(serial)
	if img is None:
		img = img_buffer[serial] = load_image(serial)
	return img

def clear_buffer():
	"""Close every cached thumbnail and empty the cache.

	Emptying matters: otherwise get_buffer would keep returning closed images.
	"""
	for img in img_buffer.values():
		if img is not None:
			img.close()
	# Mutate in place (rather than rebinding) so every reference to img_buffer sees the reset.
	img_buffer.clear()

In [12]:
def gen_result(candidates):
	"""Paste one cached thumbnail per candidate onto a new grid_width x grid_height canvas.

	Args:
		candidates: library serials in row-major tile order (W_SIZE per row), as returned
			by select_candidates. Entry i goes to column i % W_SIZE, row i // W_SIZE.

	Returns:
		The composed RGB PIL image; the caller is responsible for closing it.
	"""
	canvas = Image.new("RGB", (grid_width, grid_height))

	for index, serial in enumerate(candidates):
		row, col = divmod(index, W_SIZE)
		left = col * thumb_width
		top = row * thumb_height
		canvas.paste(get_buffer(serial), (round(left), round(top)))

	return canvas

In [13]:
def target_workflow(target):
	"""Build the photomosaic for target image `target` and write it to ./result/{target}.jpg.

	Steps: load and aspect-fit the target, convert OpenCV's BGR to RGB, split it into
	tiles, average each tile's colour, pick a nearby library image per tile, and paste
	the thumbnails onto one canvas.

	Args:
		target: target image number (./target/{target}.jpg).
	"""
	t = get_fitted_target(target)
	# cv2 loads images as BGR; convert so tile means are in RGB order (presumably the
	# order used in avr_rgb.json — verify).
	t = cv2.cvtColor(t, cv2.COLOR_BGR2RGB)

	subs = subdivide(t)
	t_RGB = list(get_subdivide_RGB(subs).values())

	candidates = select_candidates(t_RGB)
	result = gen_result(candidates)

	try:
		os.makedirs('./result', exist_ok=True)
		# Pass the path so PIL opens the output in binary mode itself; a text-mode handle
		# is not a proper destination for binary JPEG data.
		result.save(f'./result/{target}.jpg', "JPEG")
	finally:
		# The canvas can be large; release it even when saving fails.
		result.close()

In [17]:
# Render a mosaic for each numbered target image, ./target/1.jpg through ./target/12.jpg.
N_TARGETS = 12
for target_id in range(1, N_TARGETS + 1):
    print(f"processing target {target_id}")
    target_workflow(target_id)

processing target 1
processing target 2
processing target 3
processing target 4
processing target 5
processing target 6
processing target 7
processing target 8
processing target 9
processing target 10
processing target 11
processing target 12


In [15]:
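# To reset the thumbnail cache, prefer calling clear_buffer(), which closes the cached images;
# rebinding the dict as below only drops the references.
# img_buffer = {}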