In [12]:
import tensorflow as tf
import os

# Dataset folder paths: complete skins (targets), skins with a missing region, and masks
skins_dir = 'C:/Users/Owner/Desktop/archive/Skins'
missing_dir = 'C:/Users/Owner/Desktop/archive/Missing'
masks_dir = 'C:/Users/Owner/Desktop/archive/Masks'

def load_image(path, channels=4):
    """Read a PNG file and return it as a float32 image tensor.

    Args:
        path: Location of the PNG file (string or scalar string tensor).
        channels: Number of colour channels to decode; defaults to 4.

    Returns:
        The decoded image converted to tf.float32 with
        tf.image.convert_image_dtype.
    """
    raw_bytes = tf.io.read_file(path)
    decoded = tf.image.decode_png(raw_bytes, channels=channels)
    # convert_image_dtype rescales integer pixel values into [0, 1]
    return tf.image.convert_image_dtype(decoded, tf.float32)

def load_sample(file_name):
    """Load one (input, target) training pair for a given skin file name.

    Args:
        file_name: String tensor expected to hold only the bare file name
            (no directory part).

    Returns:
        A tuple ``(input_image, skin)``: the image from ``missing_dir`` with
        the mask from ``masks_dir`` appended as an extra last channel
        (4 + 1 = 5 channels), and the complete image from ``skins_dir``.
    """
    def _prefixed_path(folder, prefix):
        # "<folder>/<prefix><file_name>"
        prefixed_name = tf.strings.join([prefix, file_name])
        return tf.strings.join([folder, prefixed_name], separator='/')

    # The complete skin keeps its original file name
    skin = load_image(
        tf.strings.join([skins_dir, file_name], separator='/'), channels=4)
    # Damaged image and mask share the name, with "missing_" / "mask_" in front
    missing = load_image(_prefixed_path(missing_dir, "missing_"), channels=4)
    mask = load_image(_prefixed_path(masks_dir, "mask_"), channels=1)  # single-channel mask

    # Network input: damaged image and mask stacked along the channel axis
    return tf.concat([missing, mask], axis=-1), skin

# Number of samples drawn from the Skins folder
total = 1000
# Fixed seed so the same files land in the same split on every run
split_seed = 42

# List every skin file in a deterministic order. By default list_files() shuffles and
# draws a new order on every pass over the dataset, which would move files between the
# training and validation splits from one epoch to the next.
file_names = tf.data.Dataset.list_files(os.path.join(skins_dir, '*.png'), shuffle=False)

# Shuffle exactly once over the complete file list (buffer = number of files, so the
# `total` files kept below are a uniform sample) and freeze that order:
# reshuffle_each_iteration=False makes every pass, and both the take() and skip()
# pipelines below, see the identical sequence.
file_names = file_names.shuffle(
    buffer_size=tf.data.experimental.cardinality(file_names),
    seed=split_seed,
    reshuffle_each_iteration=False)
file_names = file_names.take(total)

# Strip the directory part of each path and build the (input, target) samples
dataset = file_names.map(lambda fn: load_sample(tf.strings.split(fn, os.sep)[-1]))

# With 1000 samples: 80% for training, 20% for validation
train_size = int(total * 0.8)

# Split first, then shuffle only the training part. Shuffling before take()/skip()
# (with the default per-epoch reshuffle) leaks validation samples into training.
train_dataset = (
    dataset.take(train_size)
    .shuffle(buffer_size=train_size)  # new batch composition every epoch
    .batch(32)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
val_dataset = dataset.skip(train_size).batch(32).prefetch(tf.data.experimental.AUTOTUNE)

In [13]:
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
import numpy as np

# Generator model
def build_generator():
    """Assemble the U-Net style generator.

    The model takes a 64x64x5 tensor (4-channel damaged image plus 1-channel
    mask), halves the resolution five times, then doubles it back up while
    concatenating the matching encoder feature maps, and ends in a
    tanh-activated transposed convolution with 4 output channels.

    Returns:
        A Keras model named 'generator'.
    """
    combined_input = layers.Input(shape=[64, 64, 5], name='input_combined')

    # Encoder: 64 -> 32 -> 16 -> 8 -> 4 -> 2 (no batch norm on the first stage)
    encoder = [downsample(64, 4, apply_batchnorm=False)]
    encoder += [downsample(width, 4) for width in (128, 256, 512, 512)]

    # Decoder: 2 -> 4 -> 8 -> 16 -> 32 (dropout on the two deepest stages)
    decoder = [
        upsample(width, 4, apply_dropout=use_dropout)
        for width, use_dropout in ((512, True), (256, True), (128, False), (64, False))
    ]

    # Output stage: 32 -> 64, four channels squashed by tanh
    output_layer = layers.Conv2DTranspose(4, 4, strides=2, padding='same',
                                         activation='tanh')

    # Run the encoder and remember every intermediate feature map
    encoder_features = []
    x = combined_input
    for stage in encoder:
        x = stage(x)
        encoder_features.append(x)

    # Skip connections: every encoder output except the bottleneck, deepest first
    skip_features = encoder_features[-2::-1]

    for stage, skip in zip(decoder, skip_features):
        upsampled = stage(x)
        x = layers.Concatenate()([upsampled, skip])

    generated = output_layer(x)

    return models.Model(inputs=combined_input, outputs=generated, name='generator')

# Discriminator model
def build_discriminator():
    """Assemble the PatchGAN discriminator.

    Inputs are the 64x64x5 conditional tensor and a 64x64x4 candidate image
    (real or generated); both are concatenated along the channel axis. Three
    stride-2 convolutions reduce the resolution to 8x8, and a final
    stride-1 convolution emits one un-activated score (logit) per patch.

    Returns:
        A Keras model named 'discriminator' taking
        ``[input_conditional, target_image]``.
    """
    inp_conditional = layers.Input(shape=[64, 64, 5], name='input_conditional')
    inp_target = layers.Input(shape=[64, 64, 4], name='target_image')

    features = layers.Concatenate()([inp_conditional, inp_target])

    # (filters, use batch norm): 64 -> 32 -> 16 -> 8 spatially
    stage_specs = ((64, False), (128, True), (256, True))
    for filters, use_batchnorm in stage_specs:
        features = layers.Conv2D(filters, 4, strides=2, padding='same')(features)
        if use_batchnorm:
            features = layers.BatchNormalization()(features)
        features = layers.LeakyReLU(0.2)(features)

    # One logit per patch; resolution is kept (stride 1)
    patch_logits = layers.Conv2D(1, 4, strides=1, padding='same')(features)

    return models.Model(inputs=[inp_conditional, inp_target], outputs=patch_logits,
                        name='discriminator')

# Downsampling layer
def downsample(filters, size, apply_batchnorm=True):
    """Create an encoder stage that halves the spatial resolution.

    Args:
        filters: Number of convolution filters.
        size: Convolution kernel size.
        apply_batchnorm: Insert a BatchNormalization layer after the
            convolution when True.

    Returns:
        A Sequential model: Conv2D (stride 2, no bias) -> optional
        BatchNormalization -> LeakyReLU(0.2).
    """
    weight_init = tf.random_normal_initializer(0., 0.02)

    stage_layers = [
        layers.Conv2D(filters, size, strides=2, padding='same',
                      kernel_initializer=weight_init, use_bias=False)
    ]
    if apply_batchnorm:
        stage_layers.append(layers.BatchNormalization())
    stage_layers.append(layers.LeakyReLU(0.2))

    return models.Sequential(stage_layers)

# Upsampling layer
def upsample(filters, size, apply_dropout=False):
    """Create a decoder stage that doubles the spatial resolution.

    Args:
        filters: Number of transposed-convolution filters.
        size: Kernel size.
        apply_dropout: Insert Dropout(0.5) after batch normalization when True.

    Returns:
        A Sequential model: Conv2DTranspose (stride 2, no bias) ->
        BatchNormalization -> optional Dropout(0.5) -> ReLU.
    """
    weight_init = tf.random_normal_initializer(0., 0.02)

    stage_layers = [
        layers.Conv2DTranspose(filters, size, strides=2, padding='same',
                               kernel_initializer=weight_init, use_bias=False),
        layers.BatchNormalization(),
    ]
    if apply_dropout:
        stage_layers.append(layers.Dropout(0.5))
    stage_layers.append(layers.ReLU())

    return models.Sequential(stage_layers)

In [14]:
# Loss functions
def generator_loss(disc_generated_output, gen_output, target, lambda_l1=100):
    """Compute the generator objective: adversarial term plus weighted L1 term.

    Args:
        disc_generated_output: Discriminator logits for the generated images.
        gen_output: Images produced by the generator.
        target: Ground-truth images.
        lambda_l1: Weight of the L1 term.

    Returns:
        Tuple ``(total_loss, gan_loss, l1_loss)``.
    """
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    # The generator wants its outputs to be classified as real (label 1)
    real_labels = tf.ones_like(disc_generated_output)
    gan_loss = bce(real_labels, disc_generated_output)

    # Mean absolute pixel difference to the ground truth
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

    return gan_loss + (lambda_l1 * l1_loss), gan_loss, l1_loss

def discriminator_loss(disc_real_output, disc_generated_output):
    """Compute the discriminator objective.

    Args:
        disc_real_output: Discriminator logits for real (input, target) pairs.
        disc_generated_output: Discriminator logits for (input, generated) pairs.

    Returns:
        Sum of the cross-entropy of real pairs against label 1 and of
        generated pairs against label 0.
    """
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    loss_on_real = bce(tf.ones_like(disc_real_output), disc_real_output)
    loss_on_fake = bce(tf.zeros_like(disc_generated_output), disc_generated_output)

    return loss_on_real + loss_on_fake

In [22]:
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Pix2Pix GAN model class
class Pix2PixGAN:
    """Pix2pix-style conditional GAN for inpainting 64x64 RGBA images.

    Owns the generator returned by ``build_generator`` and the discriminator
    returned by ``build_discriminator``, one Adam optimizer for each, and the
    weight applied to the L1 term of the generator loss.
    """

    def __init__(self, lambda_l1=100):
        """Build both networks and their optimizers.

        Args:
            lambda_l1: Multiplier of the L1 reconstruction term in the
                generator loss.
        """
        self.generator = build_generator()
        self.discriminator = build_discriminator()

        self.generator_optimizer = optimizers.Adam(2e-4, beta_1=0.5)
        self.discriminator_optimizer = optimizers.Adam(2e-4, beta_1=0.5)
        self.lambda_l1 = lambda_l1

    # Training step
    @tf.function
    def train_step(self, input_conditional, target):
        """Update generator and discriminator once on a single batch.

        Args:
            input_conditional: Batch of 5-channel inputs (damaged image + mask).
            target: Batch of ground-truth images.

        Returns:
            Dict of scalar loss tensors with the keys 'gen_total_loss',
            'gen_gan_loss', 'gen_l1_loss' and 'disc_loss'.
        """
        # Two tapes: each network gets gradients of its own loss only
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            # Generate output
            gen_output = self.generator(input_conditional, training=True)

            # Discriminator predictions
            disc_real_output = self.discriminator([input_conditional, target], training=True)
            disc_generated_output = self.discriminator([input_conditional, gen_output], training=True)

            # Calculate losses
            gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(
                disc_generated_output, gen_output, target, self.lambda_l1)
            disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

        # Calculate gradients
        generator_gradients = gen_tape.gradient(
            gen_total_loss, self.generator.trainable_variables)
        discriminator_gradients = disc_tape.gradient(
            disc_loss, self.discriminator.trainable_variables)

        # Apply gradients
        self.generator_optimizer.apply_gradients(
            zip(generator_gradients, self.generator.trainable_variables))
        self.discriminator_optimizer.apply_gradients(
            zip(discriminator_gradients, self.discriminator.trainable_variables))

        return {
            'gen_total_loss': gen_total_loss,
            'gen_gan_loss': gen_gan_loss,
            'gen_l1_loss': gen_l1_loss,
            'disc_loss': disc_loss
        }

    # Model training
    def fit(self, train_dataset, val_dataset=None, epochs=50):
        """Train for ``epochs`` passes over ``train_dataset``.

        Per epoch: clears the cell output, runs ``train_step`` on every batch,
        prints the mean training losses and, when ``val_dataset`` is given,
        the mean validation losses. Every 10th epoch the generator is saved
        as an .h5 checkpoint and a sample figure is written.

        Args:
            train_dataset: Iterable of ``(input_conditional, target)`` batches.
            val_dataset: Optional iterable of the same structure.
            epochs: Number of epochs to run.
        """
        for epoch in range(epochs):
            clear_output()
            print(f'Epoch {epoch+1}/{epochs}')

            # Track losses for each epoch
            train_losses = {
                'gen_total_loss': [],
                'gen_gan_loss': [],
                'gen_l1_loss': [],
                'disc_loss': []
            }

            # Training loop
            for batch, (input_conditional, target) in enumerate(train_dataset):
                losses = self.train_step(input_conditional, target)

                for k, v in losses.items():
                    train_losses[k].append(v.numpy())

                if batch % 10 == 0:
                    print(f'Batch {batch}: Gen Loss: {losses["gen_total_loss"]:.4f}, Disc Loss: {losses["disc_loss"]:.4f}')

            # Print average losses for the epoch
            print(f'Epoch {epoch+1} Training Losses:')
            for k, v in train_losses.items():
                print(f'  {k}: {np.mean(v):.4f}')

            # Validation (if provided)
            if val_dataset is not None:
                val_losses = {
                    'gen_total_loss': [],
                    'gen_l1_loss': []
                }

                for input_conditional, target in val_dataset:
                    # Generate predictions without updating any weights
                    gen_output = self.generator(input_conditional, training=False)
                    disc_generated_output = self.discriminator(
                        [input_conditional, gen_output], training=False)

                    # Same objective as in training, so 'gen_total_loss' is
                    # actually populated (it used to stay empty and print nan)
                    total_loss, _, l1_loss = generator_loss(
                        disc_generated_output, gen_output, target, self.lambda_l1)
                    val_losses['gen_total_loss'].append(total_loss.numpy())
                    val_losses['gen_l1_loss'].append(l1_loss.numpy())

                # Print validation losses
                print(f'Epoch {epoch+1} Validation Losses:')
                for k, v in val_losses.items():
                    print(f'  {k}: {np.mean(v):.4f}')

            # Save model every 10 epochs
            if (epoch + 1) % 10 == 0:
                checkpoint_path = f'../../models/GAN/checkpoint/pixelart_inpainting_generator_epoch_{epoch+1}.h5'
                # Create the folder first so a missing directory cannot abort the run
                os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
                self.generator.save(checkpoint_path)

                # Generate and save sample images
                if val_dataset is not None:
                    self.generate_samples(val_dataset, epoch + 1)

    # Generate and save sample images
    def generate_samples(self, dataset, epoch, num_samples=4):
        """Write a figure comparing damaged input, mask, prediction and target.

        Uses the first batch of ``dataset`` and saves the figure to
        '../../Output/GAN/samples_epoch_<epoch>.png'.

        Args:
            dataset: tf.data dataset of ``(input_conditional, target)`` batches.
            epoch: Epoch number used in the file name.
            num_samples: Maximum number of rows (samples) to draw.
        """
        # take(1) yields at most one batch
        for input_conditional, target in dataset.take(1):
            # Only use a few samples for visualization
            input_samples = input_conditional[:num_samples]
            target_samples = target[:num_samples]
            # The batch can hold fewer than num_samples items
            rows = int(input_samples.shape[0])

            # Generate predictions
            predicted_samples = self.generator(input_samples, training=False)
            # The data pipeline of this notebook (convert_image_dtype in load_image)
            # feeds targets in [0, 1], so the generator is trained toward that
            # range. Clip for display instead of remapping from [-1, 1], which
            # would wash the image out into [0.5, 1].
            predicted_samples = tf.clip_by_value(predicted_samples, 0.0, 1.0)

            # Extract missing images and masks from the conditional input
            missing_samples = input_samples[:, :, :, :4]  # First 4 channels
            mask_samples = input_samples[:, :, :, 4:5]    # Last channel

            # Create figure
            plt.figure(figsize=(15, 4 * rows))

            for i in range(rows):
                # Missing image
                plt.subplot(rows, 4, i * 4 + 1)
                plt.imshow(missing_samples[i])
                plt.title("Missing")
                plt.axis("off")

                # Mask
                plt.subplot(rows, 4, i * 4 + 2)
                plt.imshow(tf.squeeze(mask_samples[i]), cmap='gray')
                plt.title("Mask")
                plt.axis("off")

                # Generated image
                plt.subplot(rows, 4, i * 4 + 3)
                plt.imshow(predicted_samples[i])
                plt.title("Generated")
                plt.axis("off")

                # Target image
                plt.subplot(rows, 4, i * 4 + 4)
                plt.imshow(target_samples[i])
                plt.title("Target")
                plt.axis("off")

            output_path = f'../../Output/GAN/samples_epoch_{epoch}.png'
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            plt.savefig(output_path)
            plt.close()

    # Perform inpainting on new images
    def inpaint(self, missing_image, mask):
        """Run the generator on a damaged image and its mask.

        Args:
            missing_image: Damaged image, either a single image (rank 3) or a
                batch (rank 4).
            mask: Matching mask, rank 3 or rank 4; it is concatenated to the
                image along the last axis.

        Returns:
            The raw generator output. The batch axis is removed when the batch
            holds exactly one image; larger batches are returned as a batch.
        """
        # Ensure correct shapes and types
        if len(missing_image.shape) == 3:  # Add batch dimension if needed
            missing_image = tf.expand_dims(missing_image, 0)
        if len(mask.shape) == 3:
            mask = tf.expand_dims(mask, 0)

        # Create conditional input
        input_conditional = tf.concat([missing_image, mask], axis=-1)

        # Generate inpainted image
        generated = self.generator(input_conditional, training=False)

        # Remove the batch dimension for a single image. An unconditional
        # tf.squeeze(generated, 0) raises for batches larger than 1.
        if generated.shape[0] == 1:
            generated = tf.squeeze(generated, 0)

        return generated

In [23]:
# Build the GAN with the L1 term weighted by 100
model = Pix2PixGAN(lambda_l1=100)

# Run the training loop for 500 epochs, validating on val_dataset after each one
model.fit(train_dataset, val_dataset=val_dataset, epochs=500)

Epoch 321/500


KeyboardInterrupt: 

In [None]:
# Save the generator's current weights to an .h5 file (the training run above was stopped by a KeyboardInterrupt)
model.generator.save('../../models/GAN/GAN.h5')

