<a href="https://colab.research.google.com/github/wongdongwook/JSAC_MA-DeepSC/blob/main/MDAN_for_CelebA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!git clone https://github.com/yunjey/StarGAN.git

In [None]:
!bash /content/StarGAN/download.sh celeba

In [None]:
!pip install Logger

# MDAN Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class ResidualBlock(nn.Module):
    """Residual block: two 3x3 conv + instance-norm stages around an identity skip.

    Args:
        dim_in: number of input channels.
        dim_out: number of output channels (must equal ``dim_in`` for the
            skip addition in ``forward`` to be shape-compatible).
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()

        def conv3x3(n_in, n_out):
            # Bias-free 3x3 convolution that keeps the spatial size.
            return nn.Conv2d(n_in, n_out, kernel_size=3, stride=1, padding=1, bias=False)

        def inst_norm(n_ch):
            return nn.InstanceNorm2d(n_ch, affine=True, track_running_stats=True)

        stages = [
            conv3x3(dim_in, dim_out),
            inst_norm(dim_out),
            nn.ReLU(inplace=True),
            conv3x3(dim_out, dim_out),
            inst_norm(dim_out),
        ]
        # Attribute name and layer order are kept so checkpoint keys do not change.
        self.main = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.main(x)
        return x + residual


class Generator(nn.Module):
    """Generator network.

    Layout: 7x7 stem conv -> two stride-2 down-sampling convs ->
    ``repeat_num`` residual blocks -> two stride-2 transposed convs ->
    7x7 conv to 3 channels followed by ``tanh``.

    Args:
        conv_dim: number of filters in the first convolution.
        c_dim: length of the domain label vector concatenated to the image.
        repeat_num: number of residual blocks in the bottleneck.
    """

    def __init__(self, conv_dim=64, c_dim=5, repeat_num=6):
        super().__init__()

        def norm_relu(channels):
            # Instance norm + ReLU pair that follows every conv except the last.
            return [nn.InstanceNorm2d(channels, affine=True, track_running_stats=True),
                    nn.ReLU(inplace=True)]

        # Stem: 3 image channels plus one constant plane per label dimension.
        modules = [nn.Conv2d(3 + c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False)]
        modules += norm_relu(conv_dim)

        # Down-sampling: each step halves the resolution and doubles the width.
        width = conv_dim
        for _ in range(2):
            modules.append(nn.Conv2d(width, width * 2, kernel_size=4, stride=2, padding=1, bias=False))
            modules += norm_relu(width * 2)
            width = width * 2

        # Bottleneck: residual blocks at constant width (attach the AFB here).
        modules += [ResidualBlock(dim_in=width, dim_out=width) for _ in range(repeat_num)]

        # Up-sampling: each step doubles the resolution and halves the width.
        for _ in range(2):
            modules.append(nn.ConvTranspose2d(width, width // 2, kernel_size=4, stride=2, padding=1, bias=False))
            modules += norm_relu(width // 2)
            width = width // 2

        modules.append(nn.Conv2d(width, 3, kernel_size=7, stride=1, padding=3, bias=False))
        modules.append(nn.Tanh())
        # Attribute name and layer order are kept so checkpoint keys do not change.
        self.main = nn.Sequential(*modules)

    def forward(self, x, c):
        """Translate image batch ``x`` conditioned on label batch ``c``.

        ``c`` is reshaped to (batch, c_dim, 1, 1), tiled over the spatial size
        of ``x`` and concatenated to ``x`` along the channel axis.
        """
        # Replicate spatially and concatenate domain information.
        # Note that this type of label conditioning does not work at all if we use reflection padding in Conv2d.
        # This is because instance normalization ignores the shifting (or bias) effect.
        height, width = x.size(2), x.size(3)
        label_planes = c.view(c.size(0), c.size(1), 1, 1).repeat(1, 1, height, width)
        conditioned = torch.cat([x, label_planes], dim=1)
        return self.main(conditioned)


class Discriminator(nn.Module):
    """Discriminator network with PatchGAN.

    A stack of ``repeat_num`` stride-2 convolutions feeds two heads:
    ``conv1`` yields a single-channel real/fake map and ``conv2`` yields
    ``c_dim`` domain-classification logits per image.

    Args:
        image_size: side length of the input images.
        conv_dim: number of filters in the first convolution.
        c_dim: number of domain labels predicted by the classifier head.
        repeat_num: number of stride-2 convolutions in the trunk.
    """

    def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
        super().__init__()

        trunk = [nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1),
                 nn.LeakyReLU(0.01)]
        channels = conv_dim
        # The remaining repeat_num - 1 stages double the channel count each time.
        for _ in range(repeat_num - 1):
            trunk += [nn.Conv2d(channels, channels * 2, kernel_size=4, stride=2, padding=1),
                      nn.LeakyReLU(0.01)]
            channels = channels * 2

        # Spatial size left after repeat_num halvings; the classifier kernel
        # covers the whole remaining feature map.
        head_kernel = int(image_size / (2 ** repeat_num))

        self.main = nn.Sequential(*trunk)
        self.conv1 = nn.Conv2d(channels, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(channels, c_dim, kernel_size=head_kernel, bias=False)

    def forward(self, x):
        """Return ``(out_src, out_cls)`` with ``out_cls`` reshaped to (batch, c_dim)."""
        features = self.main(x)
        src_map = self.conv1(features)
        cls_logits = self.conv2(features)
        return src_map, cls_logits.view(cls_logits.size(0), cls_logits.size(1))

# Solver

In [None]:
from torch.autograd import Variable
from torchvision.utils import save_image
import torch
import torch.nn.functional as F
import numpy as np
import os
import time
import datetime

# Model configuration.
c_dim = 5  # dimension of domain labels (1st dataset)
c2_dim = 8  # dimension of domain labels (2nd dataset); only used when dataset == 'Both'
celeba_crop_size = 178  # crop size for the CelebA dataset
rafd_crop_size = 256  # crop size for the RaFD dataset
image_size = 128  # image resolution
g_conv_dim = 64  # number of conv filters in the first layer of G
d_conv_dim = 64  # number of conv filters in the first layer of D
g_repeat_num = 6  # number of residual blocks in G
d_repeat_num = 6  # number of strided conv layers in D
lambda_cls = 1.0  # weight for domain classification loss
lambda_rec = 10.0  # weight for reconstruction loss
lambda_gp = 10.0  # weight for gradient penalty


# Training configuration
dataset = 'CelebA'  # Dataset type: 'CelebA', 'RaFD', or 'Both'
batch_size = 16  # Mini-batch size
num_iters = 200000  # Number of total iterations for training D
num_iters_decay = 100000  # Number of iterations for decaying learning rate
g_lr = 0.0001  # Learning rate for G
d_lr = 0.0001  # Learning rate for D
n_critic = 5  # Number of D updates per each G update
beta1 = 0.5  # Beta1 for Adam optimizer
beta2 = 0.999  # Beta2 for Adam optimizer
resume_iters = None  # Resume training from this step (None/0 starts from scratch)
selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']  # Selected attributes for the CelebA dataset

# Test configuration
test_iters = 200000  # Test model from this step

# Miscellaneous
num_workers = 1  # Passed to get_loader as the DataLoader worker count
mode = 'train'  # Mode: 'train' or 'test'
use_tensorboard = True  # Whether to use TensorBoard

# Directories
celeba_image_dir = 'data/celeba/images'  # CelebA images
attr_path = 'data/celeba/list_attr_celeba.txt'  # CelebA attribute annotation file
rafd_image_dir = 'data/RaFD/train'  # RaFD images (loaded through ImageFolder)
log_dir = 'stargan/logs'
model_save_dir = 'stargan/models'  # '<step>-G.ckpt' / '<step>-D.ckpt' checkpoints
sample_dir = 'stargan/samples'  # '<step>-images.jpg' samples written during training
result_dir = 'stargan/results'  # '<batch>-images.jpg' translations written at test time

# Step size (all measured in training iterations)
log_step = 10  # print the loss values every log_step iterations
sample_step = 1000  # save sample images every sample_step iterations
model_save_step = 10000  # save G/D checkpoints every model_save_step iterations
lr_update_step = 1000  # during the decay phase, lower the learning rates every lr_update_step iterations

# Build model

In [None]:
# Use CUDA when it is available, otherwise fall back to the CPU.
# Original author's note: this code does not work on a Google TPU runtime,
# because a TPU is not a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Using ' + str(device).upper())

def build_model(dataset, device, model_save_dir, g_conv_dim, c_dim, g_repeat_num, image_size, d_conv_dim, d_repeat_num, g_lr, beta1, beta2, d_lr):
    """Create the generator, the discriminator and their Adam optimizers.

    Args:
        dataset: 'CelebA', 'RaFD' or 'Both'. For 'Both' the label dimension is
            extended with the module-level ``c2_dim`` (plus a 2-element mask
            vector on the generator side).
        device: device both networks are moved to.
        model_save_dir: unused here; kept for interface compatibility.
        g_conv_dim, g_repeat_num: generator width / number of residual blocks.
        d_conv_dim, d_repeat_num: discriminator width / number of conv stages.
        c_dim: dimension of the domain labels of the first dataset.
        image_size: input image resolution.
        g_lr, d_lr: learning rates of the two optimizers.
        beta1, beta2: Adam beta coefficients shared by both optimizers.

    Returns:
        Tuple ``(G, D, g_optimizer, d_optimizer)``.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    if dataset in ['CelebA', 'RaFD']:
        G = Generator(g_conv_dim, c_dim, g_repeat_num)
        D = Discriminator(image_size, d_conv_dim, c_dim, d_repeat_num)
    elif dataset in ['Both']:
        G = Generator(g_conv_dim, c_dim + c2_dim + 2, g_repeat_num)  # 2 for mask vector.
        D = Discriminator(image_size, d_conv_dim, c_dim + c2_dim, d_repeat_num)
    else:
        # Fail early with a clear message instead of an AttributeError on None below.
        raise ValueError(
            "Unsupported dataset {!r}; expected 'CelebA', 'RaFD' or 'Both'.".format(dataset))

    print_network(G, 'G', device)
    print_network(D, 'D', device)

    # Move the networks first so the optimizers are built on the final parameters.
    G.to(device)
    D.to(device)

    g_optimizer = torch.optim.Adam(G.parameters(), g_lr, [beta1, beta2])
    d_optimizer = torch.optim.Adam(D.parameters(), d_lr, [beta1, beta2])
    return G, D, g_optimizer, d_optimizer


def print_network(model, name, device):
    """Print the module structure of ``model`` and its total parameter count.

    Args:
        model: network whose ``parameters()`` are counted.
        name: label prefixed to the parameter-count line.
        device: accepted for interface compatibility; not used.
    """
    total = sum(param.numel() for param in model.parameters())
    print(model)
    print(name + " The number of parameters: {}".format(total))


def restore_model(G, D, resume_iters, model_save_dir, device):
    """Load generator and discriminator weights saved at step ``resume_iters``.

    Reads ``<resume_iters>-G.ckpt`` and ``<resume_iters>-D.ckpt`` from
    ``model_save_dir`` and copies them into ``G`` and ``D`` in place.

    Args:
        G, D: networks whose state dicts are overwritten.
        resume_iters: step number embedded in the checkpoint file names.
        model_save_dir: directory holding the checkpoint files.
        device: accepted for interface compatibility; not used.
    """
    print('Loading the trained models from step {}...'.format(resume_iters))

    def keep_storage(storage, loc):
        # Return each deserialized storage unchanged instead of remapping it.
        return storage

    targets = (
        (G, os.path.join(model_save_dir, '{}-G.ckpt'.format(resume_iters))),
        (D, os.path.join(model_save_dir, '{}-D.ckpt'.format(resume_iters))),
    )
    for network, ckpt_path in targets:
        network.load_state_dict(torch.load(ckpt_path, map_location=keep_storage))

def update_lr(g_optimizer, d_optimizer, g_lr, d_lr):
    """Set the learning rate of every parameter group of both optimizers.

    Args:
        g_optimizer, d_optimizer: optimizers whose ``param_groups`` are edited in place.
        g_lr, d_lr: new learning rates for the generator / discriminator optimizer.
    """
    for optimizer, new_lr in ((g_optimizer, g_lr), (d_optimizer, d_lr)):
        for group in optimizer.param_groups:
            group['lr'] = new_lr

def reset_grad(g_optimizer, d_optimizer):
    """Clear the accumulated gradients of both optimizers (generator first)."""
    for optimizer in (g_optimizer, d_optimizer):
        optimizer.zero_grad()

def denorm(x):
    """Map values from [-1, 1] to [0, 1] and clamp the result to that range.

    ``x`` itself is not modified; the clamp is applied in place to the new
    tensor produced by the affine rescaling.
    """
    rescaled = (x + 1) / 2
    rescaled.clamp_(0, 1)
    return rescaled

def gradient_penalty(y, x, device):
    """Compute the gradient penalty ``mean((||dy/dx||_2 - 1) ** 2)``.

    Args:
        y: output tensor computed from ``x``.
        x: input tensor with ``requires_grad`` enabled.
        device: device on which the all-ones ``grad_outputs`` weight is created.

    Returns:
        Scalar tensor; the graph is kept (``create_graph=True``) so the
        penalty itself can be back-propagated.
    """
    ones = torch.ones(y.size(), device=device)
    (grads,) = torch.autograd.grad(
        outputs=y,
        inputs=x,
        grad_outputs=ones,
        retain_graph=True,
        create_graph=True,
    )

    # One L2 norm per sample, taken over all remaining dimensions.
    per_sample = grads.view(grads.size(0), -1)
    l2_norm = torch.sqrt(torch.sum(per_sample ** 2, dim=1))
    return torch.mean((l2_norm - 1) ** 2)


def label2onehot(labels, dim):
    """Convert a 1-D batch of label indices into float one-hot row vectors.

    Args:
        labels: tensor of class indices (cast to long before indexing).
        dim: number of classes, i.e. the width of the result.

    Returns:
        Tensor of shape (len(labels), dim) with a single 1 per row.
    """
    num_rows = labels.size(0)
    onehot = torch.zeros(num_rows, dim)
    row_idx = torch.arange(num_rows)
    col_idx = labels.long()
    onehot[row_idx, col_idx] = 1
    return onehot

def create_labels(c_org, c_dim, dataset, selected_attrs, device):
    """Build one target label batch per domain, for debugging and testing.

    CelebA: for attribute ``i`` a copy of ``c_org`` is made; a hair-colour
    attribute is switched on while the other hair colours are switched off,
    any other attribute is inverted. RaFD: the ``i``-th entry is a one-hot
    batch selecting class ``i``. Any other ``dataset`` yields an empty list.

    Args:
        c_org: original label batch, shape (batch, c_dim).
        c_dim: number of target domains to generate.
        dataset: 'CelebA' or 'RaFD'.
        selected_attrs: attribute names matching the columns of ``c_org``
            (only read for CelebA).
        device: device the generated label tensors are moved to.

    Returns:
        List of label tensors on ``device``.
    """
    targets = []
    if dataset == 'CelebA':
        hair_names = ('Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair')
        hair_cols = [col for col, attr in enumerate(selected_attrs) if attr in hair_names]
        for col in range(c_dim):
            labels = c_org.clone()
            if col in hair_cols:
                # Hair colours are mutually exclusive: clear them all, then set this one.
                labels[:, hair_cols] = 0
                labels[:, col] = 1
            else:
                # Reverse attribute value.
                labels[:, col] = (labels[:, col] == 0).to(labels.dtype)
            targets.append(labels.to(device))
    elif dataset == 'RaFD':
        for cls in range(c_dim):
            class_ids = torch.ones(c_org.size(0)) * cls
            targets.append(label2onehot(class_ids, c_dim).to(device))
    return targets

def classification_loss(logit, target, dataset):
    """Compute the domain classification loss.

    Args:
        logit: raw classifier outputs, shape (batch, num_classes).
        target: multi-hot float labels for 'CelebA'; class indices for 'RaFD'.
        dataset: 'CelebA' (binary cross-entropy summed over all elements and
            divided by the batch size) or 'RaFD' (softmax cross-entropy).

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if ``dataset`` is neither 'CelebA' nor 'RaFD'.
    """
    if dataset == 'CelebA':
        # reduction='sum' is the current spelling of the deprecated size_average=False.
        summed = F.binary_cross_entropy_with_logits(logit, target, reduction='sum')
        return summed / logit.size(0)
    if dataset == 'RaFD':
        return F.cross_entropy(logit, target)
    # Previously fell through and returned None, which only failed later in the loss sum.
    raise ValueError("Unsupported dataset {!r}; expected 'CelebA' or 'RaFD'.".format(dataset))


# Instantiate the networks and optimizers through build_model() rather than
# repeating its body inline. build_model() also prints both networks and
# moves them to `device`.
G, D, g_optimizer, d_optimizer = build_model(
    dataset, device, model_save_dir,
    g_conv_dim, c_dim, g_repeat_num,
    image_size, d_conv_dim, d_repeat_num,
    g_lr, beta1, beta2, d_lr)


In [None]:
import logging
logger = logging.getLogger('example_logger')
logger.setLevel(logging.DEBUG)  # Capture DEBUG and every higher-severity log record

# Dataset preparation

In [None]:
from torch.utils import data
from torchvision import transforms as T
from torchvision.datasets import ImageFolder
from PIL import Image
import torch
import os
import random


class CelebA(data.Dataset):
    """Dataset class for the CelebA dataset.

    Parses the attribute annotation file once at construction time, shuffles
    the entries with a fixed seed and splits them into a test list (the first
    1999 shuffled entries) and a train list (everything else).
    """

    def __init__(self, image_dir, attr_path, selected_attrs, transform, mode):
        """Initialize and preprocess the CelebA dataset.

        Args:
            image_dir: directory containing the image files.
            attr_path: path of the attribute annotation text file.
            selected_attrs: attribute names that make up each label vector.
            transform: callable applied to every loaded PIL image.
            mode: 'train' selects the train split; anything else the test split.
        """
        self.image_dir = image_dir
        self.attr_path = attr_path
        self.selected_attrs = selected_attrs
        self.transform = transform
        self.mode = mode
        self.train_dataset = []
        self.test_dataset = []
        self.attr2idx = {}
        self.idx2attr = {}
        self.preprocess()

        if mode == 'train':
            self.num_images = len(self.train_dataset)
        else:
            self.num_images = len(self.test_dataset)

    def preprocess(self):
        """Preprocess the CelebA attribute file.

        Line 0 of the file is skipped, line 1 holds the attribute names and
        every following line is ``<filename> <value> <value> ...`` where the
        value '1' marks an attribute as present.
        """
        # Read through a context manager so the file handle is always closed.
        with open(self.attr_path, 'r') as attr_file:
            lines = [line.rstrip() for line in attr_file]

        all_attr_names = lines[1].split()
        for i, attr_name in enumerate(all_attr_names):
            self.attr2idx[attr_name] = i
            self.idx2attr[i] = attr_name

        lines = lines[2:]
        # Fixed seed keeps the train/test split identical across runs.
        # Note that this reseeds the global `random` module.
        random.seed(1234)
        random.shuffle(lines)
        for i, line in enumerate(lines):
            split = line.split()
            filename = split[0]
            values = split[1:]

            label = [values[self.attr2idx[attr_name]] == '1'
                     for attr_name in self.selected_attrs]

            if (i+1) < 2000:
                self.test_dataset.append([filename, label])
            else:
                self.train_dataset.append([filename, label])

        print('Finished preprocessing the CelebA dataset...')

    def __getitem__(self, index):
        """Return one image and its corresponding attribute label."""
        dataset = self.train_dataset if self.mode == 'train' else self.test_dataset
        filename, label = dataset[index]
        image = Image.open(os.path.join(self.image_dir, filename))
        return self.transform(image), torch.FloatTensor(label)

    def __len__(self):
        """Return the number of images."""
        return self.num_images


def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128,
               batch_size=16, dataset='CelebA', mode='train', num_workers=1):
    """Build and return a data loader.

    Images are centre-cropped to ``crop_size``, resized to ``image_size``,
    converted to tensors and normalised with mean 0.5 / std 0.5 per channel.
    A random horizontal flip is prepended, and batches are shuffled, only
    when ``mode == 'train'``.

    Args:
        image_dir: directory containing the images.
        attr_path: CelebA attribute file (only used for 'CelebA').
        selected_attrs: CelebA attribute names (only used for 'CelebA').
        crop_size: side length of the centre crop.
        image_size: target size passed to the resize transform.
        batch_size: number of samples per batch.
        dataset: 'CelebA' or 'RaFD'.
        mode: 'train' or 'test'.
        num_workers: number of DataLoader worker processes.
    """
    is_training = (mode == 'train')

    steps = [T.RandomHorizontalFlip()] if is_training else []
    steps += [
        T.CenterCrop(crop_size),
        T.Resize(image_size),
        T.ToTensor(),
        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
    pipeline = T.Compose(steps)

    if dataset == 'CelebA':
        dataset = CelebA(image_dir, attr_path, selected_attrs, pipeline, mode)
    elif dataset == 'RaFD':
        dataset = ImageFolder(image_dir, pipeline)

    return data.DataLoader(dataset=dataset,
                           batch_size=batch_size,
                           shuffle=is_training,
                           num_workers=num_workers)


# Data loaders: one per dataset selected by `dataset`; the other stays None.
celeba_loader = (
    get_loader(image_dir=celeba_image_dir, attr_path=attr_path,
               selected_attrs=selected_attrs, crop_size=celeba_crop_size,
               image_size=image_size, batch_size=batch_size,
               dataset='CelebA', mode=mode, num_workers=num_workers)
    if dataset in ['CelebA', 'Both'] else None
)

rafd_loader = (
    get_loader(image_dir=rafd_image_dir, attr_path=None,
               selected_attrs=None, crop_size=rafd_crop_size,
               image_size=image_size, batch_size=batch_size,
               dataset='RaFD', mode=mode, num_workers=num_workers)
    if dataset in ['RaFD', 'Both'] else None
)



# Training

In [None]:
data_loader = celeba_loader  # Initialize data loader for CelebA dataset (use rafd_loader for RaFD dataset)

# save_image / torch.save do not create missing directories, so make sure the
# output locations exist before the first sample or checkpoint is written.
os.makedirs(sample_dir, exist_ok=True)
os.makedirs(model_save_dir, exist_ok=True)

# Fetch fixed inputs for debugging
data_iter = iter(data_loader)  # Create an iterator from the data loader
x_fixed, c_org = next(data_iter)  # Get the first batch from the data loader
x_fixed = x_fixed.to(device)
c_fixed_list = create_labels(c_org, c_dim, dataset, selected_attrs, device)  # Generate target domain labels from original labels

# Start training from scratch or resume
start_iters = 0  # Initialize starting iteration
if resume_iters:
    start_iters = resume_iters  # Resume from saved iteration if available
    restore_model(G, D, resume_iters, model_save_dir, device)  # Load model checkpoint if resuming

# Start training
print('Start training...')
start_time = time.time()
for i in range(start_iters, num_iters):  # Main training loop

    # ========================================== #
    #             1. Preprocess Input Data        #
    # ========================================== #
    try:
        x_real, label_org = next(data_iter)  # Fetch a batch
    except StopIteration:
        # Only iterator exhaustion restarts the epoch; real loading errors propagate.
        data_iter = iter(data_loader)
        x_real, label_org = next(data_iter)

    # Generate random target domain labels by shuffling the batch's own labels
    rand_idx = torch.randperm(label_org.size(0))
    label_trg = label_org[rand_idx]

    if dataset == 'CelebA':
        c_org = label_org.clone()
        c_trg = label_trg.clone()
    elif dataset == 'RaFD':
        c_org = label2onehot(label_org, c_dim)
        c_trg = label2onehot(label_trg, c_dim)

    x_real = x_real.to(device)
    c_org = c_org.to(device)
    c_trg = c_trg.to(device)
    label_org = label_org.to(device)
    label_trg = label_trg.to(device)

    # ========================================== #
    #         2. Train the Discriminator          #
    # ========================================== #
    # Loss on real images.
    out_src, out_cls = D(x_real)
    d_loss_real = - torch.mean(out_src)
    d_loss_cls = classification_loss(out_cls, label_org, dataset)

    # Loss on fake images; detach() keeps this step from back-propagating into G.
    x_fake = G(x_real, c_trg)
    out_src, out_cls = D(x_fake.detach())
    d_loss_fake = torch.mean(out_src)

    # Gradient penalty on random interpolations between real and fake images.
    alpha = torch.rand(x_real.size(0), 1, 1, 1).to(device)
    x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
    out_src, _ = D(x_hat)
    d_loss_gp = gradient_penalty(out_src, x_hat, device)

    d_loss = d_loss_real + d_loss_fake + lambda_cls * d_loss_cls + lambda_gp * d_loss_gp
    reset_grad(g_optimizer, d_optimizer)
    d_loss.backward()
    d_optimizer.step()

    loss = {}
    loss['D/loss_real'] = d_loss_real.item()
    loss['D/loss_fake'] = d_loss_fake.item()
    loss['D/loss_cls'] = d_loss_cls.item()
    loss['D/loss_gp'] = d_loss_gp.item()

    # ========================================== #
    #            3. Train the Generator           #
    # ========================================== #
    # G is updated once every n_critic discriminator updates.
    if (i+1) % n_critic == 0:
        # Original-to-target domain.
        x_fake = G(x_real, c_trg)
        out_src, out_cls = D(x_fake)
        g_loss_fake = - torch.mean(out_src)
        g_loss_cls = classification_loss(out_cls, label_trg, dataset)

        # Target-to-original domain (L1 reconstruction loss).
        x_reconst = G(x_fake, c_org)
        g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

        g_loss = g_loss_fake + lambda_rec * g_loss_rec + lambda_cls * g_loss_cls
        reset_grad(g_optimizer, d_optimizer)
        g_loss.backward()
        g_optimizer.step()

        loss['G/loss_fake'] = g_loss_fake.item()
        loss['G/loss_rec'] = g_loss_rec.item()
        loss['G/loss_cls'] = g_loss_cls.item()

    # ========================================== #
    #               4. Miscellaneous              #
    # ========================================== #

    # Print training status
    if (i+1) % log_step == 0:
        et = time.time() - start_time
        et = str(datetime.timedelta(seconds=et))[:-7]
        log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, num_iters)
        for tag, value in loss.items():
            log += ", {}: {:.4f}".format(tag, value)
        print(log)

    # Save sample images for debugging
    if (i+1) % sample_step == 0:
        with torch.no_grad():
            x_fake_list = [x_fixed]
            for c_fixed in c_fixed_list:
                x_fake_list.append(G(x_fixed, c_fixed))
            # Concatenate along the width: input first, then one translation per target label.
            x_concat = torch.cat(x_fake_list, dim=3)
            sample_path = os.path.join(sample_dir, '{}-images.jpg'.format(i+1))
            save_image(denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
            print('Saved real and fake images into {}...'.format(sample_path))

    # Save model checkpoints
    if (i+1) % model_save_step == 0:
        G_path = os.path.join(model_save_dir, '{}-G.ckpt'.format(i+1))
        D_path = os.path.join(model_save_dir, '{}-D.ckpt'.format(i+1))
        torch.save(G.state_dict(), G_path)
        torch.save(D.state_dict(), D_path)
        print('Saved model checkpoints into {}...'.format(model_save_dir))

    # Decay learning rate
    if (i+1) % lr_update_step == 0 and (i+1) > (num_iters - num_iters_decay):
        # NOTE(review): each update subtracts only lr / num_iters_decay from the
        # current rate, so the decay is very small -- confirm this is intended.
        g_lr -= (g_lr / float(num_iters_decay))
        d_lr -= (d_lr / float(num_iters_decay))
        update_lr(g_optimizer, d_optimizer, g_lr, d_lr)
        print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))


# Printing Image

In [None]:
# Load the trained generator and discriminator from the checkpoint selected by
# test_iters (resume_iters is a training option and is None by default).
restore_model(G, D, test_iters, model_save_dir, device)

# save_image does not create missing directories.
os.makedirs(result_dir, exist_ok=True)

# NOTE(review): this iterates whichever `data_loader` was defined by the
# training cell; confirm that it was built with mode='test' when evaluating.
with torch.no_grad():
    for i, (x_real, c_org) in enumerate(data_loader):

        # Prepare input images and target domain labels.
        x_real = x_real.to(device)
        c_trg_list = create_labels(c_org, c_dim, dataset, selected_attrs, device)

        # Translate images: the input first, then one output per target label.
        x_fake_list = [x_real]
        for c_trg in c_trg_list:
            x_fake_list.append(G(x_real, c_trg))

        # Save the translated images, concatenated along the width.
        x_concat = torch.cat(x_fake_list, dim=3)
        result_path = os.path.join(result_dir, '{}-images.jpg'.format(i+1))
        save_image(denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
        print('Saved real and fake images into {}...'.format(result_path))
