In [23]:
import numpy as np
import tensorflow as tf
import io
import binvox_rw
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import Sequence
from tensorflow.keras.layers import Input, Conv2D, Conv3DTranspose, Concatenate
from tensorflow.keras.layers import MaxPool2D, UpSampling3D, Reshape
from tensorflow.keras.models import Sequential, Model
import zipfile

In [24]:
class voxel_gen(Sequence):
    def __init__(self, x_set, batch_size=32, dim=(128,128,128)):
        self.x = x_set # path for each dataset : models/models-binvox-solid
        self.batch_size = batch_size
        self.dim = dim

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.x) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        x_set_temp = self.x[index*self.batch_size:(index+1)*self.batch_size]
        X = self.__data_generation(x_set_temp)
        return X
    
    def __data_generation(self, x_set_temp):
        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
        # Initialization
        X = np.empty((self.batch_size, *self.dim, 1))

        for i, x_set in enumerate(x_set_temp):
            # Store sample
            with open(str(x_set), 'rb') as f:
                model = binvox_rw.read_as_3d_array(f)
            X[i,] = np.expand_dims(model.data, axis=-1)
        return X

In [25]:
def customGen(batch_size=32):
    gen = ImageDataGenerator(rescale=1/255.)
    view0 = gen.flow_from_directory(
        "models/models-screenshots/view0",
        target_size=(512, 512),
        shuffle=False,
        batch_size=batch_size,
        class_mode=None)

    view1 = gen.flow_from_directory(
        "models/models-screenshots/view1",
        target_size=(512, 512),
        shuffle=False,
        batch_size=batch_size,
        class_mode=None)

    view2 = gen.flow_from_directory(
        "models/models-screenshots/view2",
        target_size=(512, 512),
        shuffle=False,
        batch_size=batch_size,
        class_mode=None)

    view3 = gen.flow_from_directory(
        "models/models-screenshots/view3",
        target_size=(512, 512),
        shuffle=False,
        batch_size=batch_size,
        class_mode=None)
    
    voxels = Path("models/models-binvox-solid/data")
    fvoxels = [f for f in voxels.iterdir() if f.is_file()]
    
    out = voxel_gen(fvoxels, batch_size)
    while True:
        for x1, x2, x3, x4, y in zip(view0, view1, view2, view3, out):
            yield [x1,x2,x3,x4], y

In [26]:
def encode_decode():
    inview0 = Input(shape=(512, 512, 3), name='view0')
    inview1 = Input(shape=(512, 512, 3), name='view1')
    inview2 = Input(shape=(512, 512, 3), name='view2')
    inview3 = Input(shape=(512, 512, 3), name='view3')
    
    encoder = Sequential([
            Conv2D(32, 7, (2, 2), activation='relu'),
            MaxPool2D((2, 2)),
            Conv2D(32, 5, activation='relu'),
            MaxPool2D((2, 2)),
            Conv2D(64, 3, activation='relu'),
            MaxPool2D((2, 2)),
            Conv2D(64, 3, activation='relu'),
            MaxPool2D((2, 2)),
            Conv2D(128, 3, activation='relu'),
            MaxPool2D((2, 2)),
            Conv2D(32, 5, activation='relu'),
            Reshape((1, 1, 1, -1))
    ], name='encoder')
    
    combined = Concatenate(axis=4)([encoder(inview0), encoder(inview1), encoder(inview2), encoder(inview3)])
    
    decoder = Sequential([
            Conv3DTranspose(128, 3, activation='relu'),
            UpSampling3D(2),
            Conv3DTranspose(64, 3, activation='relu'),
            UpSampling3D(2),
            Conv3DTranspose(32, 3, activation='relu', padding='same'),
            UpSampling3D(2),
            Conv3DTranspose(16, 3, activation='relu', padding='same'),
            UpSampling3D(2),
            Conv3DTranspose(8, 3, activation='relu', padding='same'),
            UpSampling3D(2)
    ], name='decoder')(combined)
    
    out = Conv3DTranspose(1, 3, activation='sigmoid', padding='same')(decoder)
    
    return Model(inputs=[inview0, inview1, inview2, inview3], outputs=out)

In [27]:
tf.keras.backend.clear_session()
model = encode_decode()
model.summary(line_length=118, positions=[.38, .66, .75, 1.], expand_nested=True)

Model: "model"
______________________________________________________________________________________________________________________
 Layer (type)                               Output Shape                     Param #    Connected to                  
 view0 (InputLayer)                         [(None, 512, 512, 3)]            0          []                            
                                                                                                                      
 view1 (InputLayer)                         [(None, 512, 512, 3)]            0          []                            
                                                                                                                      
 view2 (InputLayer)                         [(None, 512, 512, 3)]            0          []                            
                                                                                                                      
 view3 (InputLayer)              

Total params: 998,673
Trainable params: 998,673
Non-trainable params: 0
______________________________________________________________________________________________________________________


In [28]:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
loss = tf.keras.losses.BinaryCrossentropy()

In [29]:
model.compile(optimizer, loss)

In [8]:
model.fit(customGen(batch_size=32),epochs=10, batch_size=8)

Found 11694 images belonging to 1 classes.
Found 11694 images belonging to 1 classes.
Found 11694 images belonging to 1 classes.
Found 11694 images belonging to 1 classes.
Epoch 1/10


ResourceExhaustedError: Graph execution error:

Detected at node 'model/encoder/conv2d_1/Conv2D_2' defined at (most recent call last):
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\runpy.py", line 196, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\runpy.py", line 86, in _run_code
      exec(code, run_globals)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\ipykernel_launcher.py", line 16, in <module>
      app.launch_new_instance()
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\traitlets\config\application.py", line 846, in launch_instance
      app.start()
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\ipykernel\kernelapp.py", line 677, in start
      self.io_loop.start()
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\tornado\platform\asyncio.py", line 199, in start
      self.asyncio_loop.run_forever()
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\asyncio\base_events.py", line 595, in run_forever
      self._run_once()
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\asyncio\base_events.py", line 1881, in _run_once
      handle._run()
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\asyncio\events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\ipykernel\kernelbase.py", line 471, in dispatch_queue
      await self.process_one()
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\ipykernel\kernelbase.py", line 460, in process_one
      await dispatch(*args)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\ipykernel\kernelbase.py", line 367, in dispatch_shell
      await result
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\ipykernel\kernelbase.py", line 662, in execute_request
      reply_content = await reply_content
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\ipykernel\ipkernel.py", line 360, in do_execute
      res = shell.run_cell(code, store_history=store_history, silent=silent)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\ipykernel\zmqshell.py", line 532, in run_cell
      return super().run_cell(*args, **kwargs)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\IPython\core\interactiveshell.py", line 2833, in run_cell
      result = self._run_cell(
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\IPython\core\interactiveshell.py", line 2879, in _run_cell
      return runner(coro)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\IPython\core\async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\IPython\core\interactiveshell.py", line 3077, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\IPython\core\interactiveshell.py", line 3280, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\IPython\core\interactiveshell.py", line 3340, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "C:\Users\prade\AppData\Local\Temp\ipykernel_10724\2723389152.py", line 1, in <cell line: 1>
      model.fit(customGen(batch_size=32),epochs=10, batch_size=8)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\utils\traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\engine\training.py", line 1384, in fit
      tmp_logs = self.train_function(iterator)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\engine\training.py", line 1021, in train_function
      return step_function(self, iterator)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\engine\training.py", line 1010, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\engine\training.py", line 1000, in run_step
      outputs = model.train_step(data)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\engine\training.py", line 859, in train_step
      y_pred = self(x, training=True)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\utils\traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\engine\base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\utils\traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\engine\functional.py", line 451, in call
      return self._run_internal_graph(
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\engine\functional.py", line 589, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\utils\traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\engine\base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\utils\traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\engine\sequential.py", line 374, in call
      return super(Sequential, self).call(inputs, training=training, mask=mask)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\engine\functional.py", line 451, in call
      return self._run_internal_graph(
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\engine\functional.py", line 589, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\utils\traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\engine\base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\utils\traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\layers\convolutional.py", line 248, in call
      outputs = self.convolution_op(inputs, self.kernel)
    File "C:\Users\prade\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\layers\convolutional.py", line 233, in convolution_op
      return tf.nn.convolution(
Node: 'model/encoder/conv2d_1/Conv2D_2'
OOM when allocating tensor with shape[32,32,122,122] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[{{node model/encoder/conv2d_1/Conv2D_2}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.
 [Op:__inference_train_function_8858]