In [None]:
import numpy as np
from glob import glob
from tqdm import tqdm 
from vae import VAE
import torch
import cv2
import json 
import os
from datetime import datetime

# Get the current date in monthdayyear format (used as a prefix for output file names)
date = datetime.now().strftime('%m%d%Y')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

latent_dim = 10

checkpoint_path = f"./weights/lat_{latent_dim}.ckpt"
# map_location lets a checkpoint that was saved on a GPU be loaded on a CPU-only machine.
checkpoint = torch.load(checkpoint_path, map_location=device)

model = VAE(latent_dim)
# Apply the trained weights; without this the model keeps its random initialisation.
# NOTE(review): assumes the checkpoint is either a dict holding the weights under
# "state_dict" or a bare state dict — confirm against how the checkpoint was saved.
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
    state_dict = checkpoint["state_dict"]
else:
    state_dict = checkpoint
model.load_state_dict(state_dict)
model = model.eval().to(device)


In [None]:
# Read data from string 
def process_string_mech(dir, toNpy = True):
    """Parse the mechanism description encoded in an image file name.

    The file name (directory and extension removed) is split on whitespace.
    Tokens that parse as floats are collected into two lists: those seen
    before the first non-numeric token and those seen after it. Example path::

        ./outputs-4bar/-0.001 2.728 5.504 -1.565 RRRP  0.42 0.026 0.732 -0.026 0.42 2.011 0. 0. 1. .jpg

    Args:
        dir: Path of the image file. Forward-slash and backslash directory
            separators are both accepted, so paths produced on Windows and
            on Linux/macOS parse the same way.
        toNpy: When True, convert the two float lists to NumPy objects
            (see Returns). When False, return plain Python lists.

    Returns:
        A tuple ``(floats_before, letter_string, floats_after)``:
        ``floats_before`` is a list of floats, or an ``(N, 2)`` ndarray when
        ``toNpy`` is True; ``letter_string`` is the non-numeric token (the
        last one if there are several, ``None`` if there is none);
        ``floats_after`` is a list of floats, or a ``3x3`` ``np.matrix`` when
        ``toNpy`` is True.

    Raises:
        ValueError: If ``toNpy`` is True and the float counts cannot be
            reshaped to ``(-1, 2)`` / ``(3, 3)``.
    """
    # Normalise Windows separators so the last path component is the file name on any OS.
    file_name = dir.replace('\\', '/').split('/')[-1]
    # Cut at the first ".j" (e.g. ".jpg"). os.path.splitext is deliberately not used:
    # the name is full of decimal points and would be split at the wrong dot.
    input_string = file_name.split('.j')[0]

    # Split the string by whitespace
    parts = input_string.split()

    # Floats found before / after the letter string
    floats_before = []
    floats_after = []
    letter_string = None

    # Iterate over parts to separate floats and the letter string
    for part in parts:
        try:
            num = float(part)
        except ValueError:
            # Not a number: this part is the letter string
            letter_string = part
            continue
        # Numbers belong to floats_before until the letter string has been seen
        if letter_string is None:
            floats_before.append(num)
        else:
            floats_after.append(num)

    if toNpy:
        floats_before = np.array(floats_before).reshape((-1, 2))
        # np.matrix is kept so existing callers get the same return type as before.
        floats_after = np.matrix(floats_after).reshape((3, 3))

    return floats_before, letter_string, floats_after

In [None]:
# Encode the six-bar mechanism images with the VAE encoder and save, per mechanism
# type, the sampled latent vectors and the feature rows parsed from the file names.
file_path = './KV_468.json'

# Open and read the JSON file. KVdict is looked up below both by mechanism-type
# name and by the letter string parsed from each image file name.
with open(file_path, 'r') as file:
    KVdict = json.load(file)

image_folder = './outputs-6bar/'

six_bar  = ['Watt1T1A1', 'Watt1T2A1', 'Watt1T3A1', 'Watt1T1A2', 'Watt1T2A2', 'Watt1T3A2', 
            'Watt2T1A1', 'Watt2T2A1', 'Watt2T1A2', 'Watt2T2A2', 'Steph1T1', 'Steph1T2',
            'Steph1T3', 'Steph2T1A1', 'Steph2T2A1', 'Steph2T2A2', 'Steph3T1A1', 'Steph3T2A1', 
            'Steph3T1A2', 'Steph3T2A2', 'Steph2T1A2']

z_folder = './outputs-z/'
e_folder = './outputs-encoded/'

os.makedirs(z_folder, exist_ok=True)
os.makedirs(e_folder, exist_ok=True)

setSize = 3000 # maximum number of samples processed for each mechanism type
batchSize = 1000 # number of images sent through the encoder at once


def _encode_batch(batch_images):
    """Run a list of 2-D grayscale images through the encoder and return sampled latents as an ndarray."""
    # (1, B, H, W) -> (B, 1, H, W): one channel per image
    images = torch.from_numpy(np.array([batch_images])).swapaxes(0, 1).float().to(device)
    # Inference only: no autograd graph is needed
    with torch.no_grad():
        x = model.encoder(images)
        # The encoder output holds the mean in the first latent_dim columns and logvar in the rest
        mean, logvar = x[:, : model.latent_dim], x[:, model.latent_dim :]
        z = model.reparameterize(mean, logvar)
    return z.detach().cpu().numpy()


for mechType in six_bar:
    # Look the type up first so a missing key fails before any image is processed
    mech_id = int(KVdict[mechType])
    batchImg = []
    result_zSet = []
    result_featSet = []
    imgStrings = glob(image_folder + mechType + '/*')
    for i in tqdm(range(min(setSize, len(imgStrings)))): 
        # Validate the file name BEFORE queuing the image, so every encoded image
        # has a matching feature row (latent and feature arrays stay aligned).
        floats_before, letter_string, floats_after = process_string_mech(imgStrings[i], toNpy = False)
        if len(floats_after) == 6:
            # Pad six trailing values to nine by appending 0, 0, 1
            floats_after = floats_after + [0, 0, 1]
        elif len(floats_after) != 9: 
            # Unexpected file-name layout: stop processing this mechanism type
            break

        image = cv2.imread(imgStrings[i], cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread signals an unreadable file by returning None instead of raising
            raise IOError(f"Could not read image: {imgStrings[i]}")
        batchImg.append(image / 255) # This /255 works better than not doing it 
        result_featSet.append(np.array(floats_before + [KVdict[letter_string]] + floats_after, dtype= float).flatten().tolist())

        if len(batchImg) >= batchSize:
            result_zSet.append(_encode_batch(batchImg))
            batchImg = []

    # Encode whatever is left over after the last full batch
    if len(batchImg) > 0:
        result_zSet.append(_encode_batch(batchImg))
        batchImg = []

    if len(result_zSet) == 0:
        # np.concatenate raises on an empty list; nothing to save for this type
        print(f"No usable samples found for {mechType}; skipping.")
        continue

    result_zSet = np.concatenate(result_zSet)

    batchZname = z_folder + date + '-z-' + str(mech_id)
    batchEname = e_folder + date + '-encoded-' + str(mech_id)
    np.save(batchZname, result_zSet)
    np.save(batchEname, np.array(result_featSet))