In [28]:
import os
import numpy as np
import cv2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

# Folders — every sample is stored under the same filename in all three:
#   Input_2 -> black & white input, Input_1 -> colour input, Output -> colour target
bw_folder = "./training_data/filtered_data/Input_2"
color_input_folder = "./training_data/filtered_data/Input_1"
color_output_folder = "./training_data/filtered_data/Output"

# List files (sorted so the ordering, and therefore the split below, is deterministic)
bw_files = sorted(os.listdir(bw_folder))
color_input_files = sorted(os.listdir(color_input_folder))
color_output_files = sorted(os.listdir(color_output_folder))

# The generators index all three folders with ONE filename list, so the folders
# must contain exactly the same filenames. On failure, report a few offenders.
if not (bw_files == color_input_files == color_output_files):
    _all_names = set(bw_files) | set(color_input_files) | set(color_output_files)
    _common_names = set(bw_files) & set(color_input_files) & set(color_output_files)
    raise AssertionError(
        f"Filenames do not match! Not present in all three folders: {sorted(_all_names - _common_names)[:10]}"
    )

# Split the filenames 80/20 into training and validation sets (seeded, reproducible).
# A single list is enough: train_test_split returns [train, val] for one array.
train_ratio = 0.8
bw_train_files, bw_val_files = train_test_split(
    bw_files, test_size=1 - train_ratio, random_state=42
)

# A file in both splits would leak training data into validation.
assert set(bw_train_files).isdisjoint(bw_val_files), "Overlap between training and validation files!"

# Load a batch of images from a list of filenames and folder
def load_batch_from_folder(folder, files, start_idx, batch_size, color=True):
    """Load ``files[start_idx:start_idx + batch_size]`` from ``folder`` as one stacked array.

    Args:
        folder: Directory containing the image files.
        files: List of filenames (relative to ``folder``).
        start_idx: Index of the first file in the batch.
        batch_size: Maximum number of images to load; the final batch may be smaller.
        color: If True, read 3-channel images in OpenCV's native BGR order (no
            conversion to RGB is done). Otherwise read single-channel grayscale
            and add a trailing channel axis.

    Returns:
        np.ndarray of shape (n, H, W, 3) or (n, H, W, 1), dtype uint8. Stacking
        requires every image in the batch to have the same H and W.

    Raises:
        FileNotFoundError: If a listed file does not exist.
        ValueError: If OpenCV cannot decode a file (``cv2.imread`` returned None).
    """
    images = []
    end_idx = min(start_idx + batch_size, len(files))

    for i in range(start_idx, end_idx):
        img_path = os.path.join(folder, files[i])
        # cv2.imread does not raise on failure; it returns None. Catch that here so
        # the error names the file instead of surfacing later inside Keras.
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")
        if color:
            img = cv2.imread(img_path, cv2.IMREAD_COLOR)  # BGR channel order
        else:
            img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise ValueError(f"OpenCV could not read image: {img_path}")
        if not color:
            img = np.expand_dims(img, axis=-1)  # (H, W) -> (H, W, 1)
        images.append(img)
    return np.array(images)

# Infinite batch generator for Keras `model.fit`; create one per split
# (training / validation) by passing that split's filename list.
def image_data_generator(files, batch_size):
    """Yield ``([bw_batch, color_input_batch], color_output_batch)`` batches forever.

    Walks ``files`` in order in chunks of ``batch_size`` (the last chunk may be
    smaller), then starts over. Each filename is loaded from the module-level
    ``bw_folder`` (grayscale), ``color_input_folder`` and ``color_output_folder``.

    Pixel values are converted from uint8 [0, 255] to float32 [0, 1] so the
    targets share the range of the model's sigmoid output; with raw 0-255
    targets the sigmoid can never reach them.

    Args:
        files: Filenames present in all three folders.
        batch_size: Number of samples per batch (must be >= 1).

    Raises:
        ValueError: On the first ``next()`` if ``files`` is empty or
            ``batch_size < 1``. Otherwise the ``while True`` loop would spin
            without ever yielding, or fail with an obscure ``range`` error.
    """
    if len(files) == 0:
        raise ValueError("image_data_generator: 'files' is empty; nothing to yield")
    if batch_size < 1:
        raise ValueError(f"image_data_generator: batch_size must be >= 1, got {batch_size}")

    def to_unit_range(batch):
        # uint8 [0, 255] -> float32 [0, 1]
        return batch.astype(np.float32) / 255.0

    total_images = len(files)
    while True:
        for start_idx in range(0, total_images, batch_size):
            bw_batch = load_batch_from_folder(bw_folder, files, start_idx, batch_size, color=False)
            color_input_batch = load_batch_from_folder(color_input_folder, files, start_idx, batch_size)
            color_output_batch = load_batch_from_folder(color_output_folder, files, start_idx, batch_size)
            yield (
                [to_unit_range(bw_batch), to_unit_range(color_input_batch)],
                to_unit_range(color_output_batch),
            )

# Create the training and validation generators.
batch_size = 1
train_generator = image_data_generator(bw_train_files, batch_size)
val_generator = image_data_generator(bw_val_files, batch_size)

# Sanity check on ONE batch from a *fresh* generator, so train_generator and
# val_generator above are not advanced. The generators are infinite and yield
# ([bw_batch, color_input_batch], color_output_batch) — never iterate them with
# a plain `for` loop.
(_bw_sample, _color_in_sample), _color_out_sample = next(image_data_generator(bw_train_files, batch_size))
assert _bw_sample.shape[-1] == 1 and _color_in_sample.shape[-1] == 3 and _color_out_sample.shape[-1] == 3, "Unexpected channel counts"
assert _bw_sample.shape[:3] == _color_in_sample.shape[:3] == _color_out_sample.shape[:3], "Batch/spatial sizes differ between inputs and target"
print('Train sample batch:', _bw_sample.shape, _color_in_sample.shape, _color_out_sample.shape)


In [29]:
# from tensorflow.keras.preprocessing.image import ImageDataGenerator

# batch_size = 32

# # Data Augmentation
# datagen = ImageDataGenerator(
#     rotation_range=20,
#     zoom_range=0.15,
#     width_shift_range=0.2,
#     height_shift_range=0.2,
#     shear_range=0.15,
#     horizontal_flip=True,
#     fill_mode="nearest"
# )

# # Generator for augmented data
# def augmented_data_generator(batch_size):
#     base_generator = image_data_generator(batch_size)
    
#     for bw_batch, color_input_batch, color_output_batch in base_generator:
#         # Augment each batch
#         # Note: We're using the same seed for both black & white and color input images
#         # to ensure they undergo the same transformations.
        
#         # Augmenting BW images
#         bw_gen = datagen.flow(bw_batch, batch_size=batch_size, shuffle=False, seed=42)
        
#         # Augmenting color input images
#         color_input_gen = datagen.flow(color_input_batch, batch_size=batch_size, shuffle=False, seed=42)
        
#         # Augmenting color output images. Since we need to match outputs with inputs, 
#         # we're not shuffling and using a consistent seed.
#         color_output_gen = datagen.flow(color_output_batch, batch_size=batch_size, shuffle=False, seed=42)
        
#         yield [next(bw_gen), next(color_input_gen)], next(color_output_gen)

# # Example of usage
# aug_gen = augmented_data_generator(batch_size)
# for (bw_aug_batch, color_input_aug_batch), color_output_aug_batch in aug_gen:
#     print(bw_aug_batch.shape, color_input_aug_batch.shape, color_output_aug_batch.shape)


In [40]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Concatenate

def SimpleDualInput(height=2400, width=1400):
    """Build a minimal two-input colourisation model.

    The black & white image (H, W, 1) and the colour input (H, W, 3) are
    concatenated along the channel axis (4 channels). A single 1x1 convolution
    with a sigmoid then maps them to a 3-channel image, so every output value
    lies in (0, 1).

    Args:
        height: Input height in pixels (default 2400, the original hard-coded
            value). Pass None to accept any height; the model uses only 1x1
            convolutions, so it does not depend on spatial size.
        width: Input width in pixels (default 1400). None accepts any width.

    Returns:
        An uncompiled ``tf.keras.Model`` with inputs ``[bw_input, color_input]``.
    """
    # Black & white input
    bw_input = Input(shape=(height, width, 1))
    # Coloured reference input
    color_input = Input(shape=(height, width, 3))

    # Merge inputs -> 4 channels
    merge_layer = Concatenate()([bw_input, color_input])

    # Output layer producing a 3-channel image
    outputs = Conv2D(3, (1, 1), activation='sigmoid')(merge_layer)

    return tf.keras.Model(inputs=[bw_input, color_input], outputs=outputs)

# Build and compile the model. Training is a pixel-wise regression (MSE on colour
# values), so report mean absolute error. Keras maps the string 'accuracy' to a
# classification-accuracy metric, which is meaningless for continuous targets.
model = SimpleDualInput()
model.compile(optimizer='adam', loss='mse', metrics=['mae'])


In [41]:
import tensorflow as tf

# List the GPUs visible to TensorFlow, using the stable (non-experimental)
# tf.config API.
gpus = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpus))


Num GPUs Available:  1


In [43]:
import math

# One full pass over each split per epoch. This matches the generators exactly:
# they yield ceil(len(files) / batch_size) batches per pass, the last one possibly partial.
steps_per_epoch = math.ceil(len(bw_train_files) / batch_size)
validation_steps = math.ceil(len(bw_val_files) / batch_size)

# Train the model. TensorFlow places ops on the GPU by default when one is
# available, so no explicit tf.device('/GPU:0') scope is needed.
# NOTE(review): the saved OOM traceback for this cell was raised in
# 'model_8/conv2d_51/Relu' on a [1, 64, 2048, 1400] tensor. SimpleDualInput
# (a single sigmoid 1x1 conv) contains no such layer, so the kernel was likely
# holding a model from an earlier or deleted cell. Restart the kernel and run
# all cells before retraining.
history = model.fit(
    train_generator,
    validation_data=val_generator,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    epochs=1,
    verbose=1,
)

ResourceExhaustedError: Graph execution error:

Detected at node 'model_8/conv2d_51/Relu' defined at (most recent call last):
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\runpy.py", line 197, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\traitlets\config\application.py", line 992, in launch_instance
      app.start()
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\ipykernel\kernelapp.py", line 736, in start
      self.io_loop.start()
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\tornado\platform\asyncio.py", line 195, in start
      self.asyncio_loop.run_forever()
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\asyncio\base_events.py", line 601, in run_forever
      self._run_once()
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\asyncio\base_events.py", line 1905, in _run_once
      handle._run()
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\asyncio\events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\ipykernel\kernelbase.py", line 516, in dispatch_queue
      await self.process_one()
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\ipykernel\kernelbase.py", line 505, in process_one
      await dispatch(*args)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\ipykernel\kernelbase.py", line 412, in dispatch_shell
      await result
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\ipykernel\kernelbase.py", line 740, in execute_request
      reply_content = await reply_content
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\ipykernel\ipkernel.py", line 422, in do_execute
      res = shell.run_cell(
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\ipykernel\zmqshell.py", line 546, in run_cell
      return super().run_cell(*args, **kwargs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\IPython\core\interactiveshell.py", line 3009, in run_cell
      result = self._run_cell(
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\IPython\core\interactiveshell.py", line 3064, in _run_cell
      result = runner(coro)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\IPython\core\async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\IPython\core\interactiveshell.py", line 3269, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\IPython\core\interactiveshell.py", line 3448, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\IPython\core\interactiveshell.py", line 3508, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "C:\Users\Persephone\AppData\Local\Temp\ipykernel_9496\3903460651.py", line 7, in <module>
      history = model.fit(
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\engine\training.py", line 1564, in fit
      tmp_logs = self.train_function(iterator)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\engine\training.py", line 1160, in train_function
      return step_function(self, iterator)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\engine\training.py", line 1146, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\engine\training.py", line 1135, in run_step
      outputs = model.train_step(data)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\engine\training.py", line 993, in train_step
      y_pred = self(x, training=True)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\engine\training.py", line 557, in __call__
      return super().__call__(*args, **kwargs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\engine\functional.py", line 510, in call
      return self._run_internal_graph(inputs, training=training, mask=mask)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\engine\functional.py", line 667, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\layers\convolutional\base_conv.py", line 314, in call
      return self.activation(outputs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\activations.py", line 317, in relu
      return backend.relu(
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\backend.py", line 5366, in relu
      x = tf.nn.relu(x)
Node: 'model_8/conv2d_51/Relu'
Detected at node 'model_8/conv2d_51/Relu' defined at (most recent call last):
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\runpy.py", line 197, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\traitlets\config\application.py", line 992, in launch_instance
      app.start()
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\ipykernel\kernelapp.py", line 736, in start
      self.io_loop.start()
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\tornado\platform\asyncio.py", line 195, in start
      self.asyncio_loop.run_forever()
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\asyncio\base_events.py", line 601, in run_forever
      self._run_once()
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\asyncio\base_events.py", line 1905, in _run_once
      handle._run()
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\asyncio\events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\ipykernel\kernelbase.py", line 516, in dispatch_queue
      await self.process_one()
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\ipykernel\kernelbase.py", line 505, in process_one
      await dispatch(*args)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\ipykernel\kernelbase.py", line 412, in dispatch_shell
      await result
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\ipykernel\kernelbase.py", line 740, in execute_request
      reply_content = await reply_content
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\ipykernel\ipkernel.py", line 422, in do_execute
      res = shell.run_cell(
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\ipykernel\zmqshell.py", line 546, in run_cell
      return super().run_cell(*args, **kwargs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\IPython\core\interactiveshell.py", line 3009, in run_cell
      result = self._run_cell(
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\IPython\core\interactiveshell.py", line 3064, in _run_cell
      result = runner(coro)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\IPython\core\async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\IPython\core\interactiveshell.py", line 3269, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\IPython\core\interactiveshell.py", line 3448, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\IPython\core\interactiveshell.py", line 3508, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "C:\Users\Persephone\AppData\Local\Temp\ipykernel_9496\3903460651.py", line 7, in <module>
      history = model.fit(
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\engine\training.py", line 1564, in fit
      tmp_logs = self.train_function(iterator)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\engine\training.py", line 1160, in train_function
      return step_function(self, iterator)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\engine\training.py", line 1146, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\engine\training.py", line 1135, in run_step
      outputs = model.train_step(data)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\engine\training.py", line 993, in train_step
      y_pred = self(x, training=True)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\engine\training.py", line 557, in __call__
      return super().__call__(*args, **kwargs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\engine\functional.py", line 510, in call
      return self._run_internal_graph(inputs, training=training, mask=mask)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\engine\functional.py", line 667, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\layers\convolutional\base_conv.py", line 314, in call
      return self.activation(outputs)
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\activations.py", line 317, in relu
      return backend.relu(
    File "C:\Python\anaconda3\envs\pipWindows_ml_color\lib\site-packages\keras\backend.py", line 5366, in relu
      x = tf.nn.relu(x)
Node: 'model_8/conv2d_51/Relu'
2 root error(s) found.
  (0) RESOURCE_EXHAUSTED:  OOM when allocating tensor with shape[1,64,2048,1400] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[{{node model_8/conv2d_51/Relu}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.

	 [[Cast_3/_30]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.

  (1) RESOURCE_EXHAUSTED:  OOM when allocating tensor with shape[1,64,2048,1400] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[{{node model_8/conv2d_51/Relu}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.

0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_8247]