In [21]:
import tensorflow as tf
import os

# Dataset folders. The archive root defaults to the original machine-specific location but can be
# overridden with the PIXELART_ARCHIVE_DIR environment variable instead of editing the notebook.
ARCHIVE_DIR = os.environ.get('PIXELART_ARCHIVE_DIR', 'C:/Users/Owner/Desktop/archive').rstrip('/\\')
skins_dir = f'{ARCHIVE_DIR}/Skins'      # complete images (training targets)
missing_dir = f'{ARCHIVE_DIR}/Missing'  # images with missing regions (generator input)
masks_dir = f'{ARCHIVE_DIR}/Masks'      # 1-channel masks (generator input)

def load_image(path, channels=4):
    """Read a PNG file and return it as a float32 tensor scaled to [0, 1].

    Args:
        path: file path (Python string or scalar string tensor).
        channels: number of channels to decode (4 for RGBA, 1 for grayscale masks).

    Returns:
        float32 tensor of shape (H, W, channels).
    """
    png_bytes = tf.io.read_file(path)
    pixels = tf.image.decode_png(png_bytes, channels=channels)
    # convert_image_dtype rescales the integer pixel values into [0, 1]
    return tf.image.convert_image_dtype(pixels, tf.float32)

def load_sample(file_name):
    """Build one (input, target) pair from a skin file name.

    Args:
        file_name: scalar string tensor holding only the base file name (no directory).
            The matching files are expected at Missing/missing_<name> and Masks/mask_<name>.

    Returns:
        (input_image, skin): input_image is the missing-region image (4 channels) concatenated
        with its 1-channel mask along the last axis, i.e. (64, 64, 5) for the 64x64 images the
        models expect; skin is the complete 4-channel image.
    """
    def path_in(folder, prefix=''):
        # folder + '/' + prefix + file_name
        return tf.strings.join([folder, tf.strings.join([prefix, file_name])], separator='/')

    skin = load_image(path_in(skins_dir), channels=4)
    missing = load_image(path_in(missing_dir, 'missing_'), channels=4)
    mask = load_image(path_in(masks_dir, 'mask_'), channels=1)  # single-channel mask

    return tf.concat([missing, mask], axis=-1), skin

# Maximum number of samples to use
total = 10000
# Fixed seed so the train/validation split is the same on every run
SPLIT_SEED = 42
BATCH_SIZE = 32

# Collect every skin PNG. os.path.basename accepts both '/' and the native separator, so the
# file-name extraction does not depend on which separator the glob returns on Windows.
skin_paths = sorted(tf.io.gfile.glob(os.path.join(skins_dir, '*.png')))
if not skin_paths:
    raise FileNotFoundError(f'No PNG files found in {skins_dir}')
n_samples = min(total, len(skin_paths))

# Shuffle the file names ONCE. reshuffle_each_iteration=False is essential: take() and skip()
# below iterate this dataset independently, and a source that reshuffles on every iteration
# (the default of both list_files and shuffle) makes the training and validation subsets
# overlap, so the validation loss would partly be measured on training images.
file_names = (
    tf.data.Dataset.from_tensor_slices([os.path.basename(p) for p in skin_paths])
    .shuffle(buffer_size=len(skin_paths), seed=SPLIT_SEED, reshuffle_each_iteration=False)
    .take(n_samples)
)

# Full (unsplit) sample dataset, kept for interactive inspection
dataset = file_names.map(load_sample, num_parallel_calls=tf.data.AUTOTUNE)

# 80% training / 20% validation of the samples actually found. The split is done on file
# names, before decoding, so skip() does not load images that are then thrown away.
train_size = int(n_samples * 0.8)
train_dataset = (file_names.take(train_size)
                 .map(load_sample, num_parallel_calls=tf.data.AUTOTUNE)
                 .batch(BATCH_SIZE)
                 .prefetch(tf.data.AUTOTUNE))
val_dataset = (file_names.skip(train_size)
               .map(load_sample, num_parallel_calls=tf.data.AUTOTUNE)
               .batch(BATCH_SIZE)
               .prefetch(tf.data.AUTOTUNE))

In [22]:
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
import numpy as np

# Generator model
def build_generator():
    """Build the U-Net generator.

    Input:  (64, 64, 5) tensor = missing RGBA image (4 channels) concatenated with a 1-channel mask.
    Output: (64, 64, 4) tensor from a tanh activation, i.e. values in [-1, 1].

    NOTE(review): load_image normalizes targets to [0, 1], so only half of the tanh range is
    useful; consider a sigmoid output or [-1, 1]-scaled data.
    """
    inputs = layers.Input(shape=[64, 64, 5], name='input_combined')

    # Encoder: spatial size 64 -> 32 -> 16 -> 8 -> 4 -> 2 (no BatchNorm on the first block)
    encoder_blocks = [downsample(64, 4, apply_batchnorm=False)]
    encoder_blocks += [downsample(n_filters, 4) for n_filters in (512 // 4, 512 // 2, 512, 512)]

    # Decoder: spatial size 2 -> 4 -> 8 -> 16 -> 32 (dropout on the two deepest blocks)
    decoder_blocks = [upsample(n_filters, 4, apply_dropout=True) for n_filters in (512, 256)]
    decoder_blocks += [upsample(n_filters, 4) for n_filters in (128, 64)]

    # Final 32 -> 64 upsampling to 4 output channels
    output_layer = layers.Conv2DTranspose(4, 4, strides=2, padding='same',
                                          activation='tanh')

    features = inputs
    encoder_outputs = []
    for block in encoder_blocks:
        features = block(features)
        encoder_outputs.append(features)

    # The bottleneck (last encoder output) is not used as a skip connection; the remaining
    # encoder outputs are paired deepest-first with the decoder blocks.
    for block, skip in zip(decoder_blocks, encoder_outputs[-2::-1]):
        upsampled = block(features)
        features = layers.Concatenate()([upsampled, skip])

    outputs = output_layer(features)
    return models.Model(inputs=inputs, outputs=outputs, name='generator')

# Discriminator model
def build_discriminator():
    """Build the PatchGAN discriminator.

    Inputs: the (64, 64, 5) conditional input and a (64, 64, 4) real or generated image.
    Output: per-patch real/fake logits; three stride-2 convolutions give an (8, 8, 1) map
    for 64x64 inputs.
    """
    conditional = layers.Input(shape=[64, 64, 5], name='input_conditional')
    candidate = layers.Input(shape=[64, 64, 4], name='target_image')

    features = layers.Concatenate()([conditional, candidate])

    # 64 -> 32 -> 16 -> 8; BatchNorm on every block except the first
    for n_filters, use_batchnorm in ((64, False), (128, True), (256, True)):
        features = layers.Conv2D(n_filters, 4, strides=2, padding='same')(features)
        if use_batchnorm:
            features = layers.BatchNormalization()(features)
        features = layers.LeakyReLU(0.2)(features)

    # One logit per patch (no activation: the losses use from_logits=True)
    logits = layers.Conv2D(1, 4, strides=1, padding='same')(features)

    return models.Model(inputs=[conditional, candidate], outputs=logits, name='discriminator')

# Downsampling layer
def downsample(filters, size, apply_batchnorm=True):
    """Encoder block: stride-2 Conv2D (halves H and W) -> optional BatchNorm -> LeakyReLU(0.2).

    Args:
        filters: number of output channels.
        size: convolution kernel size.
        apply_batchnorm: insert a BatchNormalization layer after the convolution.
    """
    block_layers = [
        layers.Conv2D(filters, size, strides=2, padding='same',
                      kernel_initializer=tf.random_normal_initializer(0., 0.02),
                      use_bias=False),
    ]
    if apply_batchnorm:
        block_layers.append(layers.BatchNormalization())
    block_layers.append(layers.LeakyReLU(0.2))
    return models.Sequential(block_layers)

# Upsampling layer
def upsample(filters, size, apply_dropout=False):
    """Decoder block: stride-2 Conv2DTranspose (doubles H and W) -> BatchNorm -> optional Dropout(0.5) -> ReLU.

    Args:
        filters: number of output channels.
        size: transposed-convolution kernel size.
        apply_dropout: insert Dropout(0.5) after BatchNorm.
    """
    block_layers = [
        layers.Conv2DTranspose(filters, size, strides=2, padding='same',
                               kernel_initializer=tf.random_normal_initializer(0., 0.02),
                               use_bias=False),
        layers.BatchNormalization(),
    ]
    if apply_dropout:
        block_layers.append(layers.Dropout(0.5))
    block_layers.append(layers.ReLU())
    return models.Sequential(block_layers)

In [23]:
import tensorflow as tf
import numpy as np

def l1_loss(y_true, y_pred):
    """Mean absolute error over all elements."""
    abs_error = tf.abs(y_true - y_pred)
    return tf.reduce_mean(abs_error)

def l2_loss(y_true, y_pred):
    """Mean squared error over all elements."""
    squared_error = tf.square(y_true - y_pred)
    return tf.reduce_mean(squared_error)

def edge_loss(y_true, y_pred):
    """Mean absolute difference of Sobel gradients between two image batches.

    tf.image.sobel_edges takes 4-D (batch, H, W, C) input and stacks the two gradient
    directions on a new last axis (per its docs, index 0 is dy and index 1 is dx).
    """
    gradient_gap = tf.abs(tf.image.sobel_edges(y_true) - tf.image.sobel_edges(y_pred))
    # Summing over the direction axis gives |dy diff| + |dx diff| per pixel and channel
    return tf.reduce_mean(tf.reduce_sum(gradient_gap, axis=-1))

def mae_loss(y_true, y_pred):
    """Mean absolute error (numerically the same as l1_loss)."""
    return tf.math.reduce_mean(tf.math.abs(y_true - y_pred))

def mse_loss(y_true, y_pred):
    """Mean squared error (numerically the same as l2_loss)."""
    return tf.math.reduce_mean(tf.math.square(y_true - y_pred))

def ssim_loss(y_true, y_pred, value_range=(-1.0, 1.0)):
    """Return 1 - mean SSIM between two image batches.

    tf.image.ssim is evaluated with max_val=1.0, so both inputs are first mapped linearly
    from value_range to [0, 1].

    Args:
        y_true, y_pred: image batches of the same shape.
        value_range: (low, high) range of the input pixel values. The default (-1, 1) matches
            tanh-scaled images; pass (0.0, 1.0) for images already normalized to [0, 1]
            (e.g. the output of tf.image.convert_image_dtype).

    Raises:
        ValueError: if high <= low.
    """
    low, high = value_range
    if high <= low:
        raise ValueError(f'value_range must satisfy low < high, got {value_range}')
    span = high - low
    y_true = (y_true - low) / span
    y_pred = (y_pred - low) / span
    return 1 - tf.reduce_mean(tf.image.ssim(y_true, y_pred, max_val=1.0))

def laplacian_filter(image):
    """Apply a 3x3 Laplacian filter to every channel of an image independently.

    The previous implementation tiled the kernel to (3, 3, C, 1), so conv2d summed all channels
    into one map. It also broke on already-batched input, and its bare squeeze dropped the
    channel axis. A depthwise convolution keeps the channels separate.

    Args:
        image: float tensor of shape (H, W, C) or (batch, H, W, C).

    Returns:
        Tensor with the same shape as image holding the per-channel Laplacian response
        ("SAME" zero padding at the borders).

    Raises:
        ValueError: if image is not rank 3 or 4.
    """
    image = tf.convert_to_tensor(image)
    rank = image.shape.rank
    if rank not in (3, 4, None):
        raise ValueError(f'laplacian_filter expects a rank-3 or rank-4 image, got rank {rank}')

    single_image = rank == 3
    if single_image:
        image = image[tf.newaxis]  # add a batch axis for the convolution

    kernel = tf.constant([
        [0,  1,  0],
        [1, -4,  1],
        [0,  1,  0],
    ], dtype=image.dtype)

    channels = image.shape[-1]
    if channels is None:
        channels = tf.shape(image)[-1]
    # Depthwise filter of shape (3, 3, C, 1): one Laplacian per input channel, no channel mixing
    depthwise_kernel = tf.tile(kernel[:, :, tf.newaxis, tf.newaxis], [1, 1, channels, 1])

    edges = tf.nn.depthwise_conv2d(image, depthwise_kernel, strides=[1, 1, 1, 1], padding='SAME')
    return edges[0] if single_image else edges

def laplacian_loss(y_true, y_pred):
    """L1 distance between the Laplacian (second-derivative) responses of two images."""
    return tf.reduce_mean(tf.abs(laplacian_filter(y_true) - laplacian_filter(y_pred)))

def total_loss(y_true, y_pred, laplacian_weight=0.2):
    """Reconstruction loss: L1 + laplacian_weight * Laplacian edge loss.

    Args:
        y_true, y_pred: image batches of the same shape.
        laplacian_weight: weight of the Laplacian edge term (default 0.2).

    Returns:
        Scalar loss tensor.
    """
    loss_l1 = l1_loss(y_true, y_pred)
    loss_laplacian = laplacian_loss(y_true, y_pred)
    return loss_l1 + laplacian_weight * loss_laplacian

In [24]:
# Loss functions
def generator_loss(disc_generated_output, gen_output, target, lambda_l1=100):
    """Generator loss: adversarial BCE + lambda_l1 * L1 + SSIM loss.

    Args:
        disc_generated_output: discriminator logits for (conditional input, generated image).
        gen_output: generator output batch.
        target: ground-truth batch of the same shape. Assumed to be in [0, 1], as produced by
            load_image (tf.image.convert_image_dtype).
        lambda_l1: weight of the L1 term.

    Returns:
        (total_loss, gan_loss, l1_loss) scalar tensors.
    """
    # Adversarial term: the generator wants its output labelled as real (1)
    gan_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)(
        tf.ones_like(disc_generated_output), disc_generated_output)

    # Pixel-wise L1 reconstruction error
    l1_term = tf.reduce_mean(tf.abs(target - gen_output))

    # SSIM on the [0, 1] data directly. ssim_loss() is not used here because it assumes
    # [-1, 1] inputs and would squeeze [0, 1] images into [0.5, 1], weakening the SSIM signal.
    ssim_term = 1.0 - tf.reduce_mean(tf.image.ssim(target, gen_output, max_val=1.0))

    total = gan_loss + (lambda_l1 * l1_term) + ssim_term
    return total, gan_loss, l1_term

def discriminator_loss(disc_real_output, disc_generated_output):
    """Discriminator loss: BCE(real -> 1) + BCE(generated -> 0) on logits.

    Args:
        disc_real_output: discriminator logits for real (conditional input, target) pairs.
        disc_generated_output: discriminator logits for generated pairs.

    Returns:
        Scalar loss tensor.
    """
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    real_term = bce(tf.ones_like(disc_real_output), disc_real_output)
    fake_term = bce(tf.zeros_like(disc_generated_output), disc_generated_output)
    return real_term + fake_term

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Pix2Pix GAN model class (improved version)
class Pix2PixGAN:
    """Pix2pix-style conditional GAN that inpaints missing regions of RGBA images.

    The generator maps the conditional input (missing RGBA image concatenated with a 1-channel
    mask) to a full RGBA image. The discriminator returns per-patch real/fake logits for
    (conditional input, image) pairs.

    Args:
        lambda_l1: weight of the L1 reconstruction term in the generator loss.
        checkpoint_dir: folder for the periodic generator checkpoints (created on demand).
        sample_dir: folder for the sample-grid PNGs (created on demand).
    """

    def __init__(self, lambda_l1=100, checkpoint_dir='../../models/GAN/checkpoint',
                 sample_dir='../../Output/GAN'):
        self.generator = build_generator()
        self.discriminator = build_discriminator()

        self.generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
        self.discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
        self.lambda_l1 = lambda_l1
        self.checkpoint_dir = checkpoint_dir
        self.sample_dir = sample_dir

        # Running means of the training losses; reset at the start of every epoch in fit()
        self.gen_total_loss_metric = tf.keras.metrics.Mean(name='gen_total_loss')
        self.gen_gan_loss_metric = tf.keras.metrics.Mean(name='gen_gan_loss')
        self.gen_l1_loss_metric = tf.keras.metrics.Mean(name='gen_l1_loss')
        self.disc_loss_metric = tf.keras.metrics.Mean(name='disc_loss')

    @tf.function
    def train_step(self, input_conditional, target):
        """Run one simultaneous generator/discriminator update on a batch.

        Returns:
            (gen_total_loss, disc_loss) scalar tensors for this batch.
        """
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            gen_output = self.generator(input_conditional, training=True)

            # Same conditional input paired with the real and with the generated image
            disc_real_output = self.discriminator([input_conditional, target], training=True)
            disc_generated_output = self.discriminator([input_conditional, gen_output], training=True)

            gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(
                disc_generated_output, gen_output, target, self.lambda_l1)
            disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

        generator_gradients = gen_tape.gradient(gen_total_loss, self.generator.trainable_variables)
        discriminator_gradients = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)

        self.generator_optimizer.apply_gradients(zip(generator_gradients, self.generator.trainable_variables))
        self.discriminator_optimizer.apply_gradients(zip(discriminator_gradients, self.discriminator.trainable_variables))

        self.gen_total_loss_metric(gen_total_loss)
        self.gen_gan_loss_metric(gen_gan_loss)
        self.gen_l1_loss_metric(gen_l1_loss)
        self.disc_loss_metric(disc_loss)

        return gen_total_loss, disc_loss

    def fit(self, train_dataset, val_dataset=None, epochs=50, save_every=10):
        """Train for `epochs` epochs, printing a summary after each epoch.

        Every `save_every` epochs the generator is saved to checkpoint_dir. When a validation
        set is given, a sample grid is also written to sample_dir.

        Args:
            train_dataset: batched tf.data.Dataset of (input_conditional, target) pairs.
            val_dataset: optional batched dataset with the same structure, used for the
                validation L1 loss and for sample images.
            epochs: number of passes over train_dataset.
            save_every: checkpoint/sample interval in epochs (must be >= 1).

        Raises:
            ValueError: if save_every < 1.
        """
        if save_every < 1:
            raise ValueError(f'save_every must be >= 1, got {save_every}')

        # In-memory cache: later epochs skip file I/O and decoding. cache() replays the first
        # pass in the same order, so any upstream shuffling only takes effect once.
        train_dataset = train_dataset.cache().prefetch(buffer_size=tf.data.AUTOTUNE)
        if val_dataset is not None:
            val_dataset = val_dataset.cache().prefetch(buffer_size=tf.data.AUTOTUNE)

        for epoch in range(epochs):
            log = [f'Epoch {epoch+1}/{epochs} [{self._progress_bar(epoch, epochs)}]']
            self._reset_metrics()

            for batch, (input_conditional, target) in enumerate(train_dataset):
                gen_loss, disc_loss = self.train_step(input_conditional, target)
                if batch % 10 == 0:
                    log.append(f'Batch {batch}: Gen Loss: {gen_loss:.4f}, Disc Loss: {disc_loss:.4f}')

            log.append(f'Epoch {epoch+1} Training Losses:')
            log.append(f'  gen_total_loss: {self.gen_total_loss_metric.result():.4f}')
            log.append(f'  gen_gan_loss  : {self.gen_gan_loss_metric.result():.4f}')
            log.append(f'  gen_l1_loss   : {self.gen_l1_loss_metric.result():.4f}')
            log.append(f'  disc_loss     : {self.disc_loss_metric.result():.4f}')

            if val_dataset is not None:
                val_l1 = self._validation_l1(val_dataset)
                log.append(f'Epoch {epoch+1} Validation Loss: gen_l1_loss: {val_l1:.4f}')

            if (epoch + 1) % save_every == 0:
                # .h5 saving cannot create missing folders, so create them first
                os.makedirs(self.checkpoint_dir, exist_ok=True)
                self.generator.save(
                    f'{self.checkpoint_dir}/pixelart_inpainting_generator_epoch_{epoch+1}.h5')
                if val_dataset is not None:
                    self.generate_samples(val_dataset, epoch + 1)

            clear_output(wait=True)
            print('\n'.join(log) + '\n')

    @staticmethod
    def _progress_bar(epoch, epochs, width=100):
        """Return a fixed-width text bar showing the fraction of epochs finished before `epoch` (0-based)."""
        done = int(epoch / epochs * width)
        return '=' * done + ' ' * (width - done)

    def _reset_metrics(self):
        """Clear the running training-loss means before a new epoch."""
        for metric in (self.gen_total_loss_metric, self.gen_gan_loss_metric,
                       self.gen_l1_loss_metric, self.disc_loss_metric):
            metric.reset_state()

    def _validation_l1(self, val_dataset):
        """Return the mean per-batch L1 error of the generator (inference mode) over val_dataset."""
        val_l1_metric = tf.keras.metrics.Mean(name='val_gen_l1_loss')
        for input_conditional, target in val_dataset:
            gen_output = self.generator(input_conditional, training=False)
            val_l1_metric(tf.reduce_mean(tf.abs(target - gen_output)))
        return val_l1_metric.result()

    def generate_samples(self, dataset, epoch, num_samples=4):
        """Save a Missing / Mask / Generated / Target grid for one randomly chosen batch.

        Args:
            dataset: batched dataset of (input_conditional, target) pairs.
            epoch: epoch number used in the output file name.
            num_samples: maximum number of rows; limited to the size of the chosen batch.
        """
        # Visualization runs eagerly; shuffle() here picks a random batch, not a random image
        for input_conditional, target in dataset.shuffle(buffer_size=1000).take(1):
            n_rows = min(num_samples, int(input_conditional.shape[0]))
            if n_rows == 0:
                return
            inputs = input_conditional[:n_rows]
            predicted = self.generator(inputs, training=False)

            panels = (
                ('Missing', inputs[..., :4], {}),            # first 4 channels: image with holes
                ('Mask', inputs[..., 4], {'cmap': 'gray'}),  # last channel: mask
                # Targets are in [0, 1] (load_image) and the L1 loss pulls predictions into the
                # same range, so show predictions as-is, clipped to imshow's valid float range.
                ('Generated', tf.clip_by_value(predicted, 0.0, 1.0), {}),
                ('Target', target[:n_rows], {}),
            )

            fig, axes = plt.subplots(n_rows, 4, figsize=(15, 4 * n_rows), squeeze=False)
            for row in range(n_rows):
                for col, (title, images, imshow_kwargs) in enumerate(panels):
                    ax = axes[row, col]
                    ax.imshow(np.asarray(images[row]), **imshow_kwargs)
                    ax.set_title(title)
                    ax.axis('off')

            os.makedirs(self.sample_dir, exist_ok=True)
            fig.savefig(f'{self.sample_dir}/samples_epoch_{epoch}.png')
            plt.close(fig)

    def inpaint(self, missing_image, mask):
        """Fill the missing regions of an image (or batch of images) with the generator.

        Args:
            missing_image: (H, W, 4) image or (B, H, W, 4) batch.
            mask: (H, W), (H, W, 1) or (B, H, W, 1) mask matching missing_image.

        Returns:
            (H, W, 4) tensor when a single image (or a batch of one) is given, otherwise the
            (B, H, W, 4) batch.
        """
        # A common dtype so tf.concat does not fail on e.g. float64 image + float32 mask
        missing_image = tf.cast(missing_image, tf.float32)
        mask = tf.cast(mask, tf.float32)

        if mask.shape.rank == 2:
            mask = mask[..., tf.newaxis]  # (H, W) -> (H, W, 1)
        if missing_image.shape.rank == 3:
            missing_image = missing_image[tf.newaxis]
        if mask.shape.rank == 3:
            mask = mask[tf.newaxis]

        input_conditional = tf.concat([missing_image, mask], axis=-1)
        generated = self.generator(input_conditional, training=False)
        # Squeeze only a batch of one (the original behaviour); larger batches are returned whole
        if generated.shape[0] == 1:
            return tf.squeeze(generated, 0)
        return generated

In [26]:
# Build the GAN (L1 weight 100 in the generator loss) and train it.
EPOCHS = 500
model = Pix2PixGAN(lambda_l1=100)
model.fit(train_dataset, val_dataset, epochs=EPOCHS)

Batch 0: Gen Loss: 7.1000, Disc Loss: 0.2270
Batch 10: Gen Loss: 6.9916, Disc Loss: 0.1258
Batch 20: Gen Loss: 5.2306, Disc Loss: 0.3209
Batch 30: Gen Loss: 6.8139, Disc Loss: 0.5388
Batch 40: Gen Loss: 5.5575, Disc Loss: 0.3737
Batch 50: Gen Loss: 6.5359, Disc Loss: 0.2660
Batch 60: Gen Loss: 8.4132, Disc Loss: 0.2437
Batch 70: Gen Loss: 7.0274, Disc Loss: 0.2767
Batch 80: Gen Loss: 6.3262, Disc Loss: 0.4009
Batch 90: Gen Loss: 6.1261, Disc Loss: 0.2728
Batch 100: Gen Loss: 7.0177, Disc Loss: 0.1692
Batch 110: Gen Loss: 5.7119, Disc Loss: 0.2356
Batch 120: Gen Loss: 10.3580, Disc Loss: 0.2850
Batch 130: Gen Loss: 7.0180, Disc Loss: 0.1949
Batch 140: Gen Loss: 5.5990, Disc Loss: 0.4377
Batch 150: Gen Loss: 8.6369, Disc Loss: 0.5072
Batch 160: Gen Loss: 5.9492, Disc Loss: 0.2317
Batch 170: Gen Loss: 6.2289, Disc Loss: 0.2392
Batch 180: Gen Loss: 7.1347, Disc Loss: 0.1768
Batch 190: Gen Loss: 4.5477, Disc Loss: 0.6108
Batch 200: Gen Loss: 6.8387, Disc Loss: 0.1804
Batch 210: Gen Loss: 7.

In [27]:
# Save the final generator. Create the folder first: .h5 saving does not create missing directories.
FINAL_GENERATOR_PATH = '../../models/GAN/GAN.h5'
os.makedirs(os.path.dirname(FINAL_GENERATOR_PATH), exist_ok=True)
model.generator.save(FINAL_GENERATOR_PATH)

