In [94]:
# general imports
import porespy as ps
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
from tqdm import tqdm

# torch imports
import torch
from torch import nn
import import_ipynb
from model import VariationalAutoEncoder
from torchvision import transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, RandomSampler
from torchvision.utils import save_image 

# GUI imports
import tkinter as tk
from tkinter import *
from PIL import Image, ImageTk

### Loading the model

In [11]:
# configuration

# Architecture hyper-parameters. INIT_DIM, LATENT_DIM and KERNEL_SIZE are the
# values handed to the VariationalAutoEncoder constructor when the model is built.
INIT_DIM = 8
LATENT_DIM = 3
KERNEL_SIZE = 4
INPUT_DIM = 256  # presumably the side length of the square input images -- TODO confirm

# Training-style settings; not referenced by the cells visible in this notebook.
BATCH_SIZE = 1
LR_RATE = 3e-4

# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

In [12]:
# Build the network; its architecture must match the checkpoint, because
# load_state_dict() is strict about parameter names and shapes by default.
model = VariationalAutoEncoder(init_dim=INIT_DIM, latent_dim=LATENT_DIM, kernel_size=KERNEL_SIZE)
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; decoding in this notebook is done with CPU tensors.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load('models/model', map_location='cpu'))
# Switch to inference mode; as the last expression this also displays the model repr.
model.eval()

VariationalAutoEncoder(
  (enc1): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (enc2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (enc3): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (enc4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
  (fc1): Linear(in_features=14400, out_features=128, bias=True)
  (fc_mu): Linear(in_features=128, out_features=3, bias=True)
  (fc_log_var): Linear(in_features=128, out_features=3, bias=True)
  (fc2): Linear(in_features=3, out_features=256, bias=True)
  (dec1): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(1, 1))
  (dec2): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (dec3): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (dec4): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (dec5): ConvTranspose2d(32, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (dec6): ConvTranspose2d

In [36]:
def decode_sample(model, sample):
    """Decode one latent-space point into a 256x256 image array.

    Args:
        model: Object exposing ``decode(tensor)``; its output must be reshapeable
            to ``(-1, 1, 256, 256)``.
        sample: Sequence (or tensor) of latent coordinates. Any length is accepted
            here; the model decides which latent size is valid.

    Returns:
        A 2-D ``numpy`` array of shape ``(256, 256)``: the first channel of the
        first image in the decoded batch.
    """
    latent = torch.as_tensor(sample, dtype=torch.float32)
    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        out = model.decode(latent)
    out = out.view(-1, 1, 256, 256)
    # .cpu() makes the conversion work even if the decoder output lives on another device.
    return out.detach().cpu().numpy()[0][0]

In [130]:
class VAE_GUI:
    """Small Tk front-end for exploring the VAE latent space with the mouse.

    The black canvas at the top is a 2-D picker: clicking or dragging on it maps
    the pointer position to the first two latent coordinates (each in [-3, 3]),
    decodes that point with ``decode_sample`` and shows the resulting image in
    the label underneath. All remaining latent coordinates are held at zero.

    The class reads the notebook-level ``model`` and calls the notebook-level
    ``decode_sample``; both must be defined before it is instantiated.
    """

    def __init__(self, root):
        """Create the widgets inside ``root`` and bind the mouse handlers.

        Args:
            root: The Tk window that hosts the GUI.
        """
        self.model = model  # notebook-level model loaded in an earlier cell
        self.canvas_size = 256  # side of the picker canvas and of the shown image, in pixels
        self.latent_dim = 3

        self.root = root
        self.root.geometry("552x552")  # Set window size
        self.root.resizable(False, False)  # Disable resizing

        # Picker canvas: left-click and left-drag both trigger a redraw.
        self.canvas = tk.Canvas(root, width=self.canvas_size, height=self.canvas_size, bg='black')
        self.canvas.pack(pady=10)
        self.canvas.bind("<Button-1>", self.on_click)
        self.canvas.bind("<B1-Motion>", self.on_drag)

        # Output area for the decoded image.
        self.image_label = tk.Label(root, width=self.canvas_size, height=self.canvas_size)
        self.image_label.pack(pady=10)

    def on_click(self, event):
        """Handle a left-button press on the canvas."""
        self.update_image(event)

    def on_drag(self, event):
        """Handle pointer motion with the left button held down."""
        self.update_image(event)

    def update_image(self, event):
        """Decode the latent point under the pointer and display it.

        Args:
            event: Tk mouse event; only its ``x`` and ``y`` attributes are read.
        """
        size = self.canvas_size
        # Drag events keep arriving after the pointer leaves the canvas, so clamp
        # the position to the canvas before mapping it to latent coordinates.
        px = min(max(event.x, 0), size)
        py = min(max(event.y, 0), size)

        # Normalize pointer coordinates to [-3, 3]
        x = 6 * (px / size - 0.5)
        y = 6 * (py / size - 0.5)

        # Only two latent axes are driven by the pointer; the rest stay at zero.
        point = [x, y] + [0.0] * (self.latent_dim - 2)

        decoded_img = decode_sample(self.model, point)
        # Scale to 8-bit greyscale; clip first so values outside [0, 1]
        # (should the decoder produce any) cannot wrap around in uint8.
        pixels = (np.clip(decoded_img, 0.0, 1.0) * 255).astype(np.uint8)
        img = Image.fromarray(pixels)
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        img = img.resize((size, size), Image.LANCZOS)

        img_tk = ImageTk.PhotoImage(img)
        self.image_label.configure(image=img_tk)
        # Keep a Python reference, otherwise the image is garbage-collected and the label goes blank.
        self.image_label.image = img_tk


In [None]:
# Launch the latent-space explorer window.
root = tk.Tk()
try:
    gui = VAE_GUI(root)
except Exception:
    # Do not leave an orphaned, unresponsive Tk window behind if setup fails.
    root.destroy()
    raise
# Blocks this cell until the window is closed.
root.mainloop()