In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
from PIL import Image
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D
from tensorflow.keras.models import Model, load_model

In [None]:
def load_images(directory, size=(640, 240)):
    """Load every image in ``directory`` as a normalized RGB float array.

    Files are read in sorted filename order, so two directories with matching
    file names yield index-aligned arrays. Hidden entries (``.DS_Store``,
    ``.ipynb_checkpoints``, ...) and anything that is not a regular file are
    skipped. Each image is converted to RGB, cropped to its lower half,
    resized to ``size`` and scaled from 0-255 to [0, 1].

    Args:
        directory: Path of the directory to read.
        size: Target ``(width, height)`` passed to ``PIL.Image.resize``.
            The default ``(640, 240)`` matches the ``(240, 640, 3)`` model
            input shape used in this notebook.

    Returns:
        float32 array of shape ``(N, height, width, 3)``; ``N`` may be 0.
    """
    width_out, height_out = size
    images = []

    for filename in sorted(os.listdir(directory)):
        path = os.path.join(directory, filename)
        # Skip hidden files/folders and non-files; Image.open would raise on
        # e.g. a ".ipynb_checkpoints" directory.
        if filename.startswith('.') or filename.endswith('.DS_Store'):
            continue
        if not os.path.isfile(path):
            continue

        # `with` closes the file handle deterministically; convert() loads the
        # pixel data, so the converted image no longer needs the file.
        with Image.open(path) as raw:
            # Force exactly 3 channels: RGBA / grayscale / palette files would
            # otherwise not stack with the rest or match the model input.
            img = raw.convert('RGB')

        width, height = img.size
        # Keep only the lower half of the frame.
        img = img.crop((0, height // 2, width, height))
        img = img.resize(size)

        # float32 halves memory vs. float64 and is what Keras trains in anyway.
        images.append(np.asarray(img, dtype=np.float32) / 255.0)

    if not images:
        # Keep the 4-D shape so downstream `.shape[0]` / shape checks still work.
        return np.empty((0, height_out, width_out, 3), dtype=np.float32)
    return np.stack(images)

def black_out(image, red_threshold=220, gb_threshold=50):
    """Reduce an RGB image to a two-color red/black mask.

    A pixel becomes pure red ``(255, 0, 0)`` when its red value is
    ``>= red_threshold`` and both its green and blue values are
    ``< gb_threshold``; every other pixel becomes black ``(0, 0, 0)``.
    The input is not modified.

    Args:
        image: Image with exactly 3 channels in H x W x C layout (a PIL RGB
            image or an array); the default thresholds assume 0-255 values.
        red_threshold: Inclusive lower bound on the red channel.
        gb_threshold: Exclusive upper bound on the green and blue channels.

    Returns:
        A new array (``np.copy`` of the input, so uint8 for a PIL RGB image)
        containing only pure-red and black pixels.
    """
    image_copy = np.copy(image)
    r_channel = image_copy[:, :, 0]
    g_channel = image_copy[:, :, 1]
    b_channel = image_copy[:, :, 2]

    # The channel views alias image_copy, so build the mask before any writes.
    red_pixels = (
        (r_channel >= red_threshold)
        & (g_channel < gb_threshold)
        & (b_channel < gb_threshold)
    )

    image_copy[red_pixels] = [255, 0, 0]
    image_copy[~red_pixels] = [0, 0, 0]

    return image_copy

def load_mask_images(directory, size=(640, 240)):
    """Load every mask image in ``directory`` as a normalized red/black array.

    Files are read in sorted filename order (so they line up with originals
    that share the same file names). Hidden entries and anything that is not
    a regular file are skipped. Each image is converted to RGB, cropped to
    its lower half, resized to ``size``, passed through ``black_out`` (which
    keeps strong-red pixels as pure red and turns everything else black) and
    scaled from 0-255 to [0, 1].

    Args:
        directory: Path of the directory to read.
        size: Target ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        float32 array of shape ``(N, height, width, 3)``; ``N`` may be 0.
    """
    width_out, height_out = size
    masks = []

    for filename in sorted(os.listdir(directory)):
        path = os.path.join(directory, filename)
        # Skip hidden files/folders and non-files (e.g. ".ipynb_checkpoints").
        if filename.startswith('.') or filename.endswith('.DS_Store'):
            continue
        if not os.path.isfile(path):
            continue

        with Image.open(path) as raw:
            # black_out indexes channels 0-2 and assigns 3-element colors, so
            # RGBA / palette / grayscale masks must be converted first.
            img = raw.convert('RGB')

        width, height = img.size
        # Keep only the lower half of the frame.
        img = img.crop((0, height // 2, width, height))
        # NOTE(review): resize uses Pillow's default resampling filter, which
        # blends red edges before black_out thresholds them; consider
        # thresholding first or resample=Image.NEAREST if mask edges matter.
        img = img.resize(size)

        mask = black_out(img)
        masks.append(np.asarray(mask, dtype=np.float32) / 255.0)

    if not masks:
        # Keep the 4-D shape so it still matches an (empty) originals array.
        return np.empty((0, height_out, width_out, 3), dtype=np.float32)
    return np.stack(masks)

In [None]:
# Load data.
# Originals and masks are paired purely by position in their sorted directory
# listings, so both folders are expected to hold the same (sorted) file names.
original_dir = '/images/original'
mask_dir = '/images/mask'

original_images = load_images(original_dir)
mask_images = load_mask_images(mask_dir)

print(original_images.shape)
print(mask_images.shape)

# Fail fast here rather than with an opaque shape/cardinality error in fit().
assert original_images.shape == mask_images.shape, (
    f"original/mask shape mismatch: {original_images.shape} vs {mask_images.shape}"
)

# Preview a reproducible random sample, capped at the number of images
# available (choice(..., replace=False) raises if asked for more).
rng = np.random.default_rng(0)
n_preview = min(20, original_images.shape[0])
random_indices = rng.choice(original_images.shape[0], n_preview, replace=False)

for i in random_indices:
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    axes[0].imshow(original_images[i])
    axes[0].set_title(f"Original Image {i+1}")
    axes[0].axis('off')

    axes[1].imshow(mask_images[i])
    axes[1].set_title(f"Mask Image {i+1}")
    axes[1].axis('off')

    plt.show()
    # Release each figure so the loop doesn't accumulate open figures.
    plt.close(fig)

In [None]:
# Two-stage convolutional autoencoder (16 -> 8 filters), trained to map the
# original frames to their red/black masks and saved as autoencoder_16.h5.
input_shape = (240, 640, 3)

# Filters per encoder stage; the decoder mirrors them in reverse order.
encoder_filters = (16, 8)

input_img = Input(shape=input_shape)

# Encoder: conv + 2x2 max-pool per stage (240x640 -> 60x160 after two stages).
x = input_img
for n_filters in encoder_filters:
    x = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(x)
    x = MaxPooling2D((2, 2), padding='same')(x)
encoded = x

# Decoder: conv + 2x upsampling per stage, restoring the input resolution.
for n_filters in reversed(encoder_filters):
    x = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(x)
    x = UpSampling2D((2, 2))(x)

# Per-pixel sigmoid output in [0, 1], matching the normalized mask targets.
decoded = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x)

autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
autoencoder.fit(original_images, mask_images, epochs=50, batch_size=16, shuffle=True)
autoencoder.save('autoencoder_16.h5')

In [None]:
# Three-stage convolutional autoencoder (32 -> 16 -> 8 filters), trained to map
# the original frames to their red/black masks and saved as autoencoder_32.h5.
input_shape = (240, 640, 3)

# Filters per encoder stage; the decoder mirrors them in reverse order.
encoder_filters = (32, 16, 8)

input_img = Input(shape=input_shape)

# Encoder: conv + 2x2 max-pool per stage (240x640 -> 30x80 after three stages).
x = input_img
for n_filters in encoder_filters:
    x = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(x)
    x = MaxPooling2D((2, 2), padding='same')(x)
encoded = x

# Decoder: conv + 2x upsampling per stage, restoring the input resolution.
for n_filters in reversed(encoder_filters):
    x = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(x)
    x = UpSampling2D((2, 2))(x)

# Per-pixel sigmoid output in [0, 1], matching the normalized mask targets.
decoded = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x)

autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
autoencoder.fit(original_images, mask_images, epochs=50, batch_size=16, shuffle=True)
autoencoder.save('autoencoder_32.h5')

In [None]:
# Four-stage convolutional autoencoder (64 -> 32 -> 16 -> 8 filters), trained to
# map the original frames to their red/black masks and saved as autoencoder_64.h5.
input_shape = (240, 640, 3)

# Filters per encoder stage; the decoder mirrors them in reverse order.
encoder_filters = (64, 32, 16, 8)

input_img = Input(shape=input_shape)

# Encoder: conv + 2x2 max-pool per stage (240x640 -> 15x40 after four stages).
x = input_img
for n_filters in encoder_filters:
    x = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(x)
    x = MaxPooling2D((2, 2), padding='same')(x)
encoded = x

# Decoder: conv + 2x upsampling per stage, restoring the input resolution.
for n_filters in reversed(encoder_filters):
    x = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(x)
    x = UpSampling2D((2, 2))(x)

# Per-pixel sigmoid output in [0, 1], matching the normalized mask targets.
decoded = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x)

autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
autoencoder.fit(original_images, mask_images, epochs=50, batch_size=16, shuffle=True)
autoencoder.save('autoencoder_64.h5')

In [None]:
# Qualitative comparison of the three saved autoencoders on the test images.
test_dir = '/images/test'
test_images = load_images(test_dir)
# predict() on an empty array fails with an unhelpful error; say why instead.
assert len(test_images) > 0, f"no test images found in {test_dir}"

autoencoder16 = load_model('autoencoder_16.h5')
autoencoder32 = load_model('autoencoder_32.h5')
autoencoder64 = load_model('autoencoder_64.h5')

predict_images1 = autoencoder16.predict(test_images)
predict_images2 = autoencoder32.predict(test_images)
predict_images3 = autoencoder64.predict(test_images)

# (panel title, predictions) for each model, shown left to right after the input.
panels = [
    ("autoencoder_16", predict_images1),
    ("autoencoder_32", predict_images2),
    ("autoencoder_64", predict_images3),
]

# Show at most 5 examples; a bare range(5) raises IndexError on smaller test sets.
for i in range(min(5, len(test_images))):
    fig, axes = plt.subplots(1, 1 + len(panels), figsize=(32, 8))

    axes[0].imshow(test_images[i])
    axes[0].set_title(f"Test Image {i+1}")
    axes[0].axis('off')

    for ax, (title, predictions) in zip(axes[1:], panels):
        ax.imshow(predictions[i])
        ax.set_title(title)
        ax.axis('off')

    plt.show()
    # Release each figure so the loop doesn't accumulate open figures.
    plt.close(fig)