In [1]:
# Re-import edited project modules (dataloader, models, utils) automatically
# before each cell runs, so changes to them are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2


import warnings

import numpy as np
import tensorflow as tf
from PIL import Image

from dataloader.data_generator import DataGenerator
from models.projector import save_embeddings_for_tf_projector
from utils.helper import load_config


# Tiny ImageNet training folder; DataGenerator carves the validation subset
# out of this same directory according to VALIDATION_SPLIT.
DATA_DIRPATH = "tiny-imagenet-200/train/"
IMAGE_SIZE = (64, 64)     # target image size; presumably (height, width) as in Keras loaders — confirm in DataGenerator
BATCH_SIZE = 256
VALIDATION_SPLIT = 0.2    # fraction of DATA_DIRPATH held out for validation

In [2]:
# Locations of the experiment config (read for the seed) and of the trained model to visualize.
config_filepath, model_dirpath = "config.yaml", "saved_models/efficientnet/"

In [3]:
# --- Load config ---
config = load_config(config_filepath)

# --- Load data ---
# The same seed drives the train/validation split, so the split is reproducible
# across runs (see the 80000/20000 split reported in the output).
# NOTE(review): with shuffle=True, the validation dataset may yield a different
# order on each pass. If the embeddings passed to save_embeddings_for_tf_projector
# come from a separate pass over the validation data, their order may not match
# the sprite's order. Confirm that DataGenerator.val_raw iterates in a stable order.
data_generator = DataGenerator(
    directory=DATA_DIRPATH,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
    shuffle=True,
    seed=config["seed"],
)

# --- Load trained model ---
# compile=False skips restoring the optimizer/loss configuration; presumably the
# model is only used for inference (embedding extraction) in this notebook.
model = tf.keras.models.load_model(model_dirpath, compile=False)

Found 100000 files belonging to 200 classes.
Using 80000 files for training.
Using 20000 files for validation.


2022-12-28 19:20:06.636025: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2022-12-28 19:20:06.636152: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


Metal device set to: Apple M1


In [4]:
# Validation split as exposed by DataGenerator. Presumably it yields unbatched
# (image, label) pairs without model preprocessing — TODO confirm in DataGenerator.
ds_val = data_generator.val_raw

In [5]:
# Convert every validation image into a PIL image for the projector sprite.
# The comprehension keeps loop variables out of the notebook's global namespace.
# Labels are not needed here, so they are discarded.
# Assumes each element is a single (unbatched) image tensor of shape
# (height, width, channels) with pixel values in [0, 255] — TODO confirm
# against DataGenerator.val_raw. Values outside that range would not survive
# the uint8 cast intact.
images_pil = [
    Image.fromarray(image.numpy().astype(np.uint8))
    for image, _label in ds_val
]

# An empty sprite would be a zero-sized image and fails later when saved.
assert images_pil, "Validation dataset yielded no images; cannot build a sprite."

2022-12-28 19:20:28.164757: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


In [6]:
# Build the sprite sheet for the TensorBoard Embedding Projector.
# Tiles are laid out row-major on a square grid: image i sits at row
# i // one_square_size and column i % one_square_size. This order must match
# the order of the exported embeddings.
if not images_pil:
    raise ValueError("No images to put in the sprite.")

# Take the tile size from the images themselves. PIL reports (width, height),
# while IMAGE_SIZE presumably follows the Keras (height, width) convention.
# Reading it from the data avoids swapping the axes for non-square sizes.
tile_width, tile_height = images_pil[0].size
mismatched = [i for i, img in enumerate(images_pil) if img.size != (tile_width, tile_height)]
if mismatched:
    raise ValueError(
        f"All sprite tiles must share the size {(tile_width, tile_height)}; "
        f"{len(mismatched)} image(s) differ (first at index {mismatched[0]})."
    )

one_square_size = int(np.ceil(np.sqrt(len(images_pil))))
master_width = tile_width * one_square_size
master_height = tile_height * one_square_size

# The TensorBoard projector documents a maximum sprite size of 8192 x 8192 px.
# Larger sprites may not render, so downscale tiles or subsample the images.
if max(master_width, master_height) > 8192:
    warnings.warn(
        f"Sprite is {master_width}x{master_height} px, which exceeds the "
        "8192x8192 px supported by the TensorBoard projector."
    )

# JPEG has no alpha channel, so paint directly onto an opaque black RGB canvas.
# Any unfilled cells in the last row stay black.
spriteimage = Image.new(
    mode="RGB",
    size=(master_width, master_height),
    color=(0, 0, 0),
)

for count, image in enumerate(images_pil):
    row, col = divmod(count, one_square_size)
    spriteimage.paste(image, (col * tile_width, row * tile_height))

spriteimage.save("sprite.jpg")