<a href="https://colab.research.google.com/github/ysinjab/latent-glitch/blob/main/Michelangelo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Implementation based and inspired on original paper https://arxiv.org/abs/1707.09557

# Install Dependencies


In [2]:
!pip install voxelfuse
!pip install tensorflow-gan

import os
import scipy
import scipy.io
import numpy as np
import tensorflow as tf
from tensorflow.keras import Sequential, Model
from tensorflow.keras.layers import Input, Conv3D, Conv3DTranspose, Activation, BatchNormalization, LeakyReLU, Flatten, Reshape, Dense, Dropout, UpSampling3D
from tensorflow.keras.optimizers import Adam, SGD, RMSprop
from tensorflow.keras import backend 
from tensorflow.keras.constraints import Constraint
import tensorflow_gan as tfgan
from google.colab import output
from voxelfuse.voxel_model import VoxelModel
from voxelfuse.mesh import Mesh
from voxelfuse.primitives import generateMaterials

output.clear()

# Generator & Discrimniator

In [3]:
def make_Generator(kernel_size=5, strides=2, latent_dim=32):
     model = Sequential()
     model.add(Input(shape=(latent_dim,)))

     model.add(Dense(units=2048, input_shape=(latent_dim,), use_bias=False))
     model.add(BatchNormalization(momentum=0.9))
     model.add(Reshape((2, 2, 2, 256)))
     model.add(Activation('relu'))
     model.add(Dropout(0.4))

     model.add(UpSampling3D())
     model.add(Conv3DTranspose(filters=512, kernel_size=kernel_size, padding='same', use_bias=False))
     model.add(BatchNormalization(momentum=0.9))
     model.add(Activation('relu'))

     model.add(UpSampling3D())
     model.add(Conv3DTranspose(filters=256, kernel_size=kernel_size, padding='same', use_bias=False))
     model.add(BatchNormalization(momentum=0.9))
     model.add(Activation('relu'))

     model.add(UpSampling3D())
     model.add(Conv3DTranspose(filters=128, kernel_size=kernel_size, padding='same', use_bias=False))
     model.add(BatchNormalization(momentum=0.9))
     model.add(Activation('relu'))

     model.add(UpSampling3D())
     model.add(Conv3DTranspose(filters=64, kernel_size=kernel_size, padding='same', use_bias=False))
     model.add(BatchNormalization(momentum=0.9))
     model.add(Activation('relu'))

     model.add(UpSampling3D())
     model.add(Conv3DTranspose(filters=1, kernel_size=kernel_size, padding='same', use_bias=False))
     model.add(Activation('sigmoid'))

     return model

In [4]:
def make_Discriminator(kernel_size=5, strides=2, im_dim=64):
    model = Sequential()
    model.add(Input(shape=(im_dim, im_dim, im_dim, 1)))
    
    model.add(Conv3D(filters=64, kernel_size=kernel_size, strides=strides, padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.35))

    model.add(Conv3D(filters=128,kernel_size=kernel_size, strides=strides, padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.35))

    model.add(Conv3D(filters=256,kernel_size=kernel_size, strides=strides, padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.35))

    model.add(Conv3D(filters=512,kernel_size=kernel_size, strides=strides, padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.35))

    model.add(Flatten())
    model.add(Dense(1))

    return model

# Prepare training dataset

In [5]:
!mkdir -p /content/datasets/meshes
# https://drive.google.com/file/d//view?usp=sharing
!gdown --id 1T9ZG0BKIFkotve7zvxbgPOicRQs1sreh

Downloading...
From: https://drive.google.com/uc?id=1T9ZG0BKIFkotve7zvxbgPOicRQs1sreh
To: /content/64.zip
  0% 0.00/78.9k [00:00<?, ?B/s]100% 78.9k/78.9k [00:00<00:00, 29.6MB/s]


In [6]:
!unzip /content/64.zip -d  /content/datasets/meshes
!rm /content/datasets/meshes/64/.DS_Store
output.clear()

In [7]:
!rm -rf /content/datasets/meshes/64/__MACOSX

In [8]:
training_data = []
data_path = '/content/datasets/meshes/64'

for filename in os.listdir(data_path):
  voxel_model = np.load(f'{data_path}/{filename}', allow_pickle=True)['arr_0']
  training_data.append(voxel_model)

training_data = np.array(training_data).astype(np.float32)

In [11]:
len(training_data)

96

In [12]:
training_data[0].shape

(64, 64, 64)

# Define functions


In [13]:
def save_voxel(voxel_model, path):
    model = VoxelModel(voxel_model)  
    print('model = VoxelModel(voxel_model)  ')
    mesh = Mesh.fromVoxelModel(model)
    print('mesh = Mesh.fromVoxelModel(model)')
    mesh.export(path)

def gradient_penalty(real, fake, epsilon): 
    global discriminator
    #mixed_images = real * epsilon + fake * (1 - epsilon)
    mixed_images = fake + epsilon * (real - fake)
    with tf.GradientTape() as tape:
        tape.watch(mixed_images) 
        mixed_scores = discriminator(mixed_images)
        

    gradient = tape.gradient(mixed_scores, mixed_images)[0]
    
    gradient_norm = tf.norm(gradient)
    penalty = tf.math.reduce_mean((gradient_norm - 1)**2)
    return penalty

def generator_loss(fake_output):
    gen_loss = -1. * tf.math.reduce_mean(fake_output)
    return gen_loss

def discriminator_loss(real_output, fake_output, gradient_penalty):
    c_lambda = 10
    loss = tf.math.reduce_mean(fake_output) - tf.math.reduce_mean(real_output) + c_lambda * gradient_penalty
    return loss

# Training

In [14]:
discriminator = make_Discriminator()
generator = make_Generator()

In [None]:
generator.summary()

Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_3 (Dense)             (None, 2048)              65536     
                                                                 
 batch_normalization_5 (Batc  (None, 2048)             8192      
 hNormalization)                                                 
                                                                 
 reshape_1 (Reshape)         (None, 2, 2, 2, 256)      0         
                                                                 
 activation_6 (Activation)   (None, 2, 2, 2, 256)      0         
                                                                 
 dropout_9 (Dropout)         (None, 2, 2, 2, 256)      0         
                                                                 
 up_sampling3d_5 (UpSampling  (None, 4, 4, 4, 256)     0         
 3D)                                                  

In [None]:
discriminator.summary()

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv3d_4 (Conv3D)           (None, 32, 32, 32, 64)    8064      
                                                                 
 leaky_re_lu_4 (LeakyReLU)   (None, 32, 32, 32, 64)    0         
                                                                 
 dropout_5 (Dropout)         (None, 32, 32, 32, 64)    0         
                                                                 
 conv3d_5 (Conv3D)           (None, 16, 16, 16, 128)   1024128   
                                                                 
 leaky_re_lu_5 (LeakyReLU)   (None, 16, 16, 16, 128)   0         
                                                                 
 dropout_6 (Dropout)         (None, 16, 16, 16, 128)   0         
                                                                 
 conv3d_6 (Conv3D)           (None, 8, 8, 8, 256)     

In [15]:

def train(sample_path, checkpoints_path, num_epochs=1000, batch_size=8, latent_dim=32, restore_D=None, restore_G=None, restore_epoch=0):
    global discriminator
    global generator

    sample_epoch = 500
    save_epoch = 500

    dis_optim = RMSprop(lr=0.0002, decay=6e-8)
    gen_optim = RMSprop(lr=0.0001, decay=3e-8)

    if restore_D!=None:
      discriminator=tf.keras.models.load_model(restore_D)

    if restore_G!=None:
      generator=tf.keras.models.load_model(restore_G)

    generator.compile(optimizer=gen_optim)
    discriminator.compile(optimizer=dis_optim)

    dl, gl = [],[]
    for epoch in range(restore_epoch, num_epochs):
        #sample a random batch
        idx = np.random.randint(len(training_data), size=batch_size)
        real = training_data[idx]
        real = real.reshape(real.shape+(1,))

        noise = tf.random.normal([batch_size, latent_dim])
        for i in range(3):

          with tf.GradientTape() as disc_tape:
        
            generated_images = generator(noise, training=True)

            real_output = discriminator(real, training=True)
            fake_output = discriminator(generated_images, training=True)
        
            epsilon = tf.random.normal([batch_size, 1, 1, 1, 1], 0.0, 1.0)
        
            gp = gradient_penalty(real, generated_images, epsilon)
        
            disc_loss = discriminator_loss(real_output, fake_output, gp)

    
          gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
          dis_optim.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

        with tf.GradientTape() as gen_tape:
          generated_images = generator(noise, training=True)
          fake_output = discriminator(generated_images, training=True)
          gen_loss = generator_loss(fake_output)
        gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
        gen_optim.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))

        print('Training epoch {}/{}, d_loss: {},  g_loss: {}'.format(epoch+1, num_epochs, disc_loss, gen_loss))

        # sampling
        if epoch % sample_epoch == 0:
            if not os.path.exists(sample_path):
                os.makedirs(sample_path)
            print('Sampling...')
            sample_noise = np.random.uniform(-1.0, 1.0, size=[1, latent_dim]).astype(np.float64)
            voxel_model = generator.predict(sample_noise, verbose=1)
            voxel_model = voxel_model.reshape(voxel_model[0].shape[:-1])
            voxel_model = np.rint(voxel_model)
            try:
              save_voxel(voxel_model, sample_path + f'/epoch_{epoch+1}.obj')
            except Exception as ex:
              print('Could not create voxel model... Continuing training')
              print(ex)

        # save weights
        if epoch % save_epoch == 0:
            if not os.path.exists(checkpoints_path):
                os.makedirs(checkpoints_path)
            generator.save(checkpoints_path + '/generator_epoch_' + str(epoch+1))
            discriminator.save(checkpoints_path + '/discriminator_epoch_' + str(epoch+1))

In [16]:
# from google.colab import drive
# drive.mount('/content/gdrive')

In [17]:
!mkdir -p /content/training/samplings
!mkdir -p /content/training/model_checkpoints

In [18]:
sample_path = '/content/training/samplings'
checkpoints_path = '/content/training/model_checkpoints'

train(sample_path, checkpoints_path, num_epochs=1000000)

  super(RMSprop, self).__init__(name, **kwargs)


Training epoch 1/1000000, d_loss: -778.9308471679688,  g_loss: 7162.4765625
Sampling...
model = VoxelModel(voxel_model)  


Finding exterior voxels: 100%|██████████| 48/48 [00:00<00:00, 468.67it/s]
Meshing: 100%|██████████| 540/540 [00:05<00:00, 100.91it/s]


mesh = Mesh.fromVoxelModel(model)
INFO:tensorflow:Assets written to: /content/training/model_checkpoints/generator_epoch_1/assets
INFO:tensorflow:Assets written to: /content/training/model_checkpoints/discriminator_epoch_1/assets
Training epoch 2/1000000, d_loss: -707.9207763671875,  g_loss: 1071.9080810546875
Training epoch 3/1000000, d_loss: -547.8465576171875,  g_loss: 943.881103515625
Training epoch 4/1000000, d_loss: -596.46728515625,  g_loss: 1101.9921875
Training epoch 5/1000000, d_loss: -252.6448974609375,  g_loss: 387.1937255859375
Training epoch 6/1000000, d_loss: -183.7472686767578,  g_loss: 300.0902099609375
Training epoch 7/1000000, d_loss: -167.5020294189453,  g_loss: 270.99420166015625
Training epoch 8/1000000, d_loss: -168.0524139404297,  g_loss: 252.30172729492188
Training epoch 9/1000000, d_loss: -228.27256774902344,  g_loss: 370.7388916015625
Training epoch 10/1000000, d_loss: -233.05300903320312,  g_loss: 337.6827392578125
Training epoch 11/1000000, d_loss: -131.996

KeyboardInterrupt: ignored