In [None]:
import numpy as np
from glob import glob
from tqdm import tqdm 
from model import VAE
import torch
import cv2
import json 
from pathlib import Path
import os 
from datetime import datetime

# Date stamp in MMDDYYYY form; used as a prefix for the saved output files.
date = datetime.now().strftime('%m%d%Y')

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Latent size handed to the VAE constructor; it also selects the checkpoint file.
latent_dim = 10

checkpoint_path = f"./weights/lat_{latent_dim}.ckpt"
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine.
checkpoint = torch.load(checkpoint_path, map_location=device)

model = VAE(latent_dim)
# Apply the stored weights; without this step the encoder would run with its
# freshly initialised parameters and the saved latents would be meaningless.
# NOTE(review): assumes the file is either a raw state_dict or a dict holding one
# under "state_dict" (the PyTorch Lightning layout) - confirm against the training code.
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
    state_dict = checkpoint["state_dict"]
else:
    state_dict = checkpoint
model.load_state_dict(state_dict)
model = model.eval().to(device)


In [None]:
# Read data from string 
def process_string_mech(dir, toNpy = True):
    """Parse the numbers and the mechanism label encoded in an image file name.

    The file name (without directories and without the ``.j...`` extension) is
    split on whitespace. Every token that parses as a float is collected; the
    token that does not parse is taken as the label. Floats seen before the
    label and floats seen after it are returned separately.

    Example path:
    ./outputs-4bar/-0.001 2.728 5.504 -1.565 -5.632 -2.481 -8.711 9.682 1.320 -5.630 -7.171 3.601 RRRP  0.42 0.026 0.732 -0.026 0.42 2.011 0. 0. 1. .jpg

    Args:
        dir: Path of the file. Both '\\' and '/' separators are accepted, so
            paths produced on Windows, Linux and macOS all work.
        toNpy: When True, the floats before the label are returned as an
            ndarray reshaped to (-1, 2) and the floats after it as an
            ``np.matrix`` reshaped to (3, 3). When False, plain lists are returned.

    Returns:
        Tuple ``(floats_before, letter_string, floats_after)``. ``letter_string``
        is None when every token parsed as a float; if several tokens fail to
        parse, the last one wins.

    Raises:
        ValueError: with ``toNpy=True`` when the float counts do not fit the
            (-1, 2) / (3, 3) reshapes.
    """
    # Normalise separators first so the basename is found regardless of the OS
    # that produced the path (glob on Windows can even return mixed separators).
    file_name = dir.replace('\\', '/').rsplit('/', 1)[-1]
    # Drop the extension: everything from the first '.j' (e.g. '.jpg') onwards.
    input_string = file_name.split('.j')[0]

    floats_before = []
    floats_after = []
    letter_string = None

    for part in input_string.split():
        try:
            num = float(part)
        except ValueError:
            # Not a number, so this token is the label.
            letter_string = part
            continue
        # Numbers go to the first list until the label has been seen.
        if letter_string is None:
            floats_before.append(num)
        else:
            floats_after.append(num)

    if toNpy:
        floats_before = np.array(floats_before).reshape((-1, 2))
        # np.matrix is kept so existing callers see the same return type.
        floats_after = np.matrix(floats_after).reshape((3, 3))

    # Optional sanity check (adapt the expected count to your mechanism), e.g.:
    #   if len(floats_before) != 10: print('unexpected number of values', dir)
    return floats_before, letter_string, floats_after

In [None]:
# Configuration for the encoding run: label dictionary, input folder, sample
# limits, the known mechanism type names and the output folders.

# JSON file whose entries are looked up by mechanism type name below.
file_path = './KV_468.json'

# Open and read the JSON file
with open(file_path, 'r') as file:
    KVdict = json.load(file) 

# Root folder that holds one sub-folder of images per mechanism type.
image_folder = './outputs-8bar/'

setSize = 1000 # upper bound on how many samples are processed for each type
batchSize = 500 # number of images pushed through the encoder at once

four_bar = ['RRRR', 'RRRP', 'RRPR', 'PRPR'] 

# Every name is its own element: a missing comma between two adjacent string
# literals makes Python silently join them into a single bogus entry.
six_bar  = ['Watt1T1A1', 'Watt1T2A1', 'Watt1T3A1', 'Watt1T1A2', 'Watt1T2A2', 'Watt1T3A2', 
            'Watt2T1A1', 'Watt2T2A1', 'Watt2T1A2', 'Watt2T2A2', 'Steph1T1', 'Steph1T2',
            'Steph1T3', 'Steph2T1A1', 'Steph2T2A1', 'Steph2T2A2', 'Steph3T1A1', 'Steph3T2A1', 'Steph3T1A2', 
            'Steph3T2A2', 'Steph2T1A2']

eight_bar = ['Type824-0', 'Type822-13', 'Type814-18', 'Type824-9', 'Type812-2', 'Type817-6', 'Type821-5', 'Type812-6', 'Type821-4', 'Type822-19', 'Type816-4', 'Type823-0', 'Type825-5', 'Type824-8', 'Type817-5', 'Type815-5', 'Type831-3', 'Type814-12', 'Type814-3', 'Type817-8', 'Type822-11', 'Type823-14', 'Type823-2', 'Type822-10', 'Type821-8', 'Type824-3', 'Type814-17', 'Type814-19', 'Type824-2', 'Type821-6', 'Type814-7', 'Type823-12', 'Type823-8', 'Type822-15', 'Type824-5', 'Type817-3', 'Type815-4', 'Type825-0', 'Type811-3', 'Type822-18', 'Type814-9', 'Type814-6', 'Type824-12', 'Type817-2', 'Type822-1', 'Type814-1', 'Type823-4', 'Type813-0', 'Type816-6', 'Type817-9', 'Type824-13', 'Type822-14', 'Type811-1', 'Type811-4', 'Type819-0', 'Type819-1', 'Type824-1', 'Type825-4', 'Type812-1', 'Type815-7', 'Type822-5', 'Type824-4', 'Type821-3', 'Type815-8', 'Type817-4', 'Type814-5', 'Type814-13', 'Type818-3', 'Type815-2', 'Type815-6', 'Type832-2', 'Type816-5', 'Type817-1', 'Type814-16', 'Type825-2', 'Type814-15', 'Type824-11', 'Type831-1', 'Type824-15', 'Type823-9', 'Type812-5', 'Type815-0', 'Type821-7', 'Type817-7', 'Type812-7', 'Type821-9', 'Type819-2', 'Type816-10', 'Type821-0', 'Type812-3', 'Type823-15', 'Type814-0', 'Type825-1', 'Type823-13', 'Type816-2', 'Type817-10', 'Type825-3', 'Type813-2', 'Type822-4', 'Type822-17', 'Type823-6', 'Type822-6', 'Type823-5', 'Type814-11', 'Type821-1', 'Type816-8', 'Type818-2', 'Type822-0', 'Type816-0', 'Type814-4', 'Type822-8', 'Type822-9', 'Type823-7', 'Type822-2', 'Type816-9', 'Type815-3', 'Type824-6', 'Type816-7', 'Type815-9', 'Type823-10', 'Type818-0', 'Type823-1', 'Type814-10', 'Type832-3', 'Type812-0', 'Type832-1', 'Type811-2', 'Type814-14', 'Type831-0', 'Type822-16', 'Type822-7', 'Type824-14', 'Type831-2', 'Type813-1', 'Type822-12', 'Type832-4', 'Type816-11', 'Type832-0', 'Type811-0', 'Type822-3', 'Type814-2', 'Type824-10', 'Type812-4', 'Type818-1', 'Type816-3', 'Type824-7', 'Type814-8', 'Type823-3', 'Type817-0', 'Type815-1', 
'Type821-2', 'Type816-1', 'Type823-11']

# Output folders: sampled latent vectors go to z_folder, the parsed file-name
# features go to e_folder.
z_folder = './outputs-z/'
e_folder = './outputs-encoded/'

os.makedirs(z_folder, exist_ok=True)
os.makedirs(e_folder, exist_ok=True)

def _encode_batch(batch_images):
    """Run a list of equally sized grayscale images through the VAE encoder.

    The images are stacked, given a channel axis, encoded, and a latent sample
    is drawn with ``model.reparameterize``. The result is returned as a NumPy
    array on the CPU.
    """
    # Stack to (B, H, W) and insert a channel axis -> (B, 1, H, W).
    images = torch.from_numpy(np.array(batch_images)[:, None]).float().to(device)
    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        x = model.encoder(images)
        # The encoder output is split at latent_dim into mean and log-variance parts.
        mean, logvar = x[:, : model.latent_dim], x[:, model.latent_dim :]
        z = model.reparameterize(mean, logvar)
    return z.cpu().numpy()


# For every mechanism type: read up to setSize images, encode them in batches of
# batchSize, and save the latents plus the features parsed from the file names.
for mechType in eight_bar:
    batchImg = []
    result_zSet = []
    result_featSet = []
    # Look the label up first so an unknown type fails before any encoding work.
    label = str(int(KVdict[mechType]))
    imgStrings = glob(image_folder + mechType + '/*')
    print(mechType, len(imgStrings))
    for i in tqdm(range(min(setSize, len(imgStrings)))): 
        img_path = imgStrings[i]
        floats_before, letter_string, floats_after = process_string_mech(img_path, toNpy = False)
        if len(floats_after) == 6:
            # Only two rows were stored in the name; append the constant last row.
            floats_after = floats_after + [0, 0, 1]
        elif len(floats_after) != 9: 
            # Validate BEFORE the image is queued: a malformed name must not leave
            # an image in the batch without a matching feature row.
            tqdm.write('skipping file with unexpected name: ' + img_path)
            continue

        img = cv2.imread(str(Path(img_path)), cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread signals an unreadable file by returning None, not by raising.
            tqdm.write('skipping unreadable image: ' + img_path)
            continue

        features = np.array(floats_before + [KVdict[letter_string]] + floats_after, dtype= float).flatten().tolist()
        # Image and feature row are appended together so both sets stay aligned.
        batchImg.append(img / 255) # This /255 works better than not doing it 
        result_featSet.append(features)

        if len(batchImg) >= batchSize:
            result_zSet.append(_encode_batch(batchImg))
            batchImg = []

    # Encode whatever is left over from the last, partially filled batch.
    if len(batchImg) > 0:
        result_zSet.append(_encode_batch(batchImg))
        batchImg = []

    if len(result_zSet) > 0:
        result_zSet = np.concatenate(result_zSet)
        batchZname = z_folder + date + '-z-' + label
        batchEname = e_folder + date + '-encoded-' + label
        np.save(batchZname, np.array(result_zSet))
        np.save(batchEname, np.array(result_featSet))