<a href="https://colab.research.google.com/github/wongdongwook/JSAC_MA-DeepSC/blob/main/MASCN_and_MDAN_for_CelebA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook integrates two data adaptation architectures into a single, unified implementation:

* MASCN (Multidomain Adaptive Semantic Communication Network)
* MDAN (Multidomain Data Adaptation Network)

The implementation is structured to support simultaneous execution of all architectures in a streamlined flow.

⚙️ Notes on Structure
* Components shared across the different models are marked with the suffix (S) (or "Shared") for clarity (e.g., shared encoders, loss modules).
* Components or steps labeled with Original are specific to the original architecture (i.e., non-shared or model-specific elements).
* The framework is designed to execute both adaptation strategies together, allowing easy comparison and switching during experimentation.

# Download Image (Shared)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!git clone https://github.com/yunjey/StarGAN.git

In [None]:
!bash /content/StarGAN/download.sh celeba

# Channel Model (S)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
## Power normalization before transmission
# Note: each row is scaled to total power P*z_dim, so with P = 1 every complex symbol (a pair of real entries) has power 2.
# For unit power per complex symbol, set P = 1/2.
def Power_norm(z, P = 1):
    """Scale every row of ``z`` to a total power of ``P * z_dim``.

    Args:
        z: 2-D real tensor of shape (batch, z_dim).
        P: Target average power per real entry.

    Returns:
        Tensor shaped like ``z`` whose rows have squared norm ``P * z_dim``.
        A row that is entirely zero has zero norm and therefore yields NaNs.
    """
    _, z_dim = z.shape
    # Keep the per-row norm as a column so it broadcasts over the features.
    row_norm = torch.sqrt(torch.sum(z**2, 1, keepdim=True))
    return np.sqrt(P*z_dim)*z/row_norm

def Power_norm_complex(z, P = 1):
    """Power-normalize ``z`` while treating it as interleaved complex symbols.

    Even columns hold real parts and odd columns hold imaginary parts. Each
    row is scaled so that the summed squared magnitude of its complex symbols
    equals ``P * z_dim``.

    Args:
        z: 2-D real tensor of shape (batch, z_dim); ``z_dim`` must be even so
            the real and imaginary slices have the same length.
        P: Target average power per real entry.

    Returns:
        Real tensor of shape (batch, z_dim), interleaved like the input and
        located on the same device as ``z``.
    """
    batch_size, z_dim = z.shape
    z_com = torch.complex(z[:, 0:z_dim:2], z[:, 1:z_dim:2])
    # |z|^2 summed per row; keepdim lets the norm broadcast over the symbols.
    z_power = torch.sum(z_com*z_com.conj(), 1, keepdim=True).real
    z_nlz = np.sqrt(P*z_dim)*z_com/torch.sqrt(z_power)
    # view_as_real yields (batch, K, 2); flattening re-interleaves re/im and
    # keeps the result on z's device instead of forcing a CUDA allocation.
    return torch.view_as_real(z_nlz).reshape(batch_size, z_dim)

# The (real) AWGN channel
def AWGN_channel(x, snr, P = 2):
    """Add real white Gaussian noise to ``x``.

    The noise has variance ``P / 10**(snr/10)`` per entry and is created on
    the same device and with the same dtype as ``x``.

    Args:
        x: 2-D real tensor of shape (batch, length).
        snr: Signal-to-noise ratio in dB. A Python number or any tensor that
            broadcasts against ``x`` (e.g. shape (batch, 1)).
        P: Signal power used to derive the noise variance.

    Returns:
        Tensor shaped like ``x`` containing ``x`` plus noise.
    """
    # as_tensor accepts plain numbers as well as tensors, and aligns the
    # device/dtype with x so the function also works without CUDA.
    snr = torch.as_tensor(snr, dtype=x.dtype, device=x.device)
    gamma = 10 ** (snr / 10.0)
    noise = torch.sqrt(P/gamma)*torch.randn_like(x)
    return x+noise

def AWGN_complex(x, snr, Ps = 1):
    """Add complex white Gaussian noise to ``x``.

    Two independent noise components are drawn, each with variance
    ``Ps / 10**(snr/10)``, and combined into one complex noise tensor.

    Args:
        x: 2-D tensor of shape (batch, length); the result is complex because
            complex noise is added.
        snr: Signal-to-noise ratio in dB, a Python number or a tensor that
            broadcasts against ``x``.
        Ps: Signal power used to derive the per-component noise variance.

    Returns:
        Complex tensor of shape (batch, length) on the device of ``x``.
    """
    batch_size, length = x.shape
    # Noise components are real, so use the real counterpart of x's dtype.
    real_dtype = x.real.dtype if x.is_complex() else x.dtype
    snr = torch.as_tensor(snr, dtype=real_dtype, device=x.device)
    sigma = torch.sqrt(Ps / (10 ** (snr / 10.0)))
    n_I = sigma*torch.randn(batch_size, length, dtype=real_dtype, device=x.device)
    n_R = sigma*torch.randn(batch_size, length, dtype=real_dtype, device=x.device)
    return x + torch.complex(n_I, n_R)

# Please set the symbol power if it is not a default value
def Fading_channel(x, snr, P = 2):
    """Pass ``x`` through a complex fading channel followed by equalization.

    ``x`` is read as interleaved complex symbols (even columns real, odd
    columns imaginary). Every symbol is multiplied by an independent complex
    Gaussian coefficient ``h``, complex Gaussian noise is added, and the
    result is divided by the same ``h``. The output is therefore ``x`` plus
    noise divided by ``h``.

    Args:
        x: 2-D real tensor of shape (batch, feature_length); feature_length
            must be even so the real and imaginary slices match.
        snr: Signal-to-noise ratio in dB, a Python number or a tensor that
            broadcasts against (batch, feature_length // 2).
        P: Symbol power used to derive the noise variance.

    Returns:
        Real tensor of shape (batch, feature_length), interleaved like the
        input and located on the device of ``x``.
    """
    batch_size, feature_length = x.shape
    K = feature_length//2
    snr = torch.as_tensor(snr, dtype=x.dtype, device=x.device)
    sigma = torch.sqrt(P / (10 ** (snr / 10.0)))

    def _randn():
        # Draw on x's device so no CUDA runtime is required.
        return torch.randn(batch_size, K, dtype=x.dtype, device=x.device)

    h_com = torch.complex(_randn(), _randn())
    x_com = torch.complex(x[:, 0:feature_length:2], x[:, 1:feature_length:2])
    noise = torch.complex(sigma*_randn(), sigma*_randn())

    # Equalize with the known channel coefficient.
    y = (h_com*x_com + noise)/h_com

    # (batch, K, 2) -> (batch, 2K) restores the re/im interleaving.
    return torch.view_as_real(y).reshape(batch_size, feature_length)

# Note: each row is scaled to total power Kv*P with Kv = ceil(z_dim*cr), so with P = 1 that is power 2 per pair of real entries.
# For unit power per complex symbol, set P = 1/2.
def Power_norm_VLC(z, cr, P = 1):
    """Power-normalize ``z`` for a per-sample compression ratio ``cr``.

    Each row is scaled to a total power of ``Kv * P`` where
    ``Kv = ceil(z_dim * cr)``.

    Args:
        z: 2-D real tensor of shape (batch, z_dim).
        cr: Tensor of compression ratios that broadcasts against ``z``
            (e.g. shape (batch, 1)).
        P: Target average power per retained entry.

    Returns:
        Tensor shaped like ``z`` on the device of ``z``. A row that is
        entirely zero has zero norm and therefore yields NaNs.
    """
    _, z_dim = z.shape
    Kv = torch.ceil(z_dim*cr.to(z.device)).int()
    # Column-shaped norm broadcasts over the features on whatever device z
    # lives on, so no CUDA runtime is required.
    row_norm = torch.sqrt(torch.sum(z**2, 1, keepdim=True))
    return torch.sqrt(Kv*P)*z/row_norm

def AWGN_channel_VLC(x, snr, cr, P = 2):
    """Add real Gaussian noise to the entries kept by the compression mask.

    The mask from ``mask_gen(length, cr)`` zeroes the noise on entries that
    are not transmitted, so those entries of ``x`` pass through unchanged.

    Args:
        x: 2-D real tensor of shape (batch, length).
        snr: Signal-to-noise ratio in dB, a Python number or a tensor that
            broadcasts against ``x``.
        cr: Per-sample compression ratios, forwarded to ``mask_gen``.
        P: Signal power used to derive the noise variance.

    Returns:
        Tensor shaped like ``x`` on the device of ``x``.
    """
    _, length = x.shape
    snr = torch.as_tensor(snr, dtype=x.dtype, device=x.device)
    gamma = 10 ** (snr / 10.0)
    mask = mask_gen(length, cr).to(x.device)
    # Draw an independent realization per sample, as the other channel
    # functions do; a (1, length) draw would reuse one noise vector for the
    # whole batch.
    noise = torch.sqrt(P/gamma)*torch.randn_like(x)
    return x + noise*mask

def Fading_channel_VLC(x, snr, cr, P = 2):
    """Fading channel with equalization for variable-length codes.

    Works like a per-symbol complex fading channel, except that the complex
    noise is multiplied by ``mask_gen(K, cr)`` so that symbols outside the
    mask receive no noise.

    Args:
        x: 2-D real tensor of shape (batch, feature_length) holding
            interleaved complex symbols; feature_length must be even.
        snr: Signal-to-noise ratio in dB, a Python number or a tensor that
            broadcasts against (batch, feature_length // 2).
        cr: Per-sample compression ratios, forwarded to ``mask_gen``.
        P: Symbol power used to derive the noise variance.

    Returns:
        Real tensor of shape (batch, feature_length) on the device of ``x``.
    """
    batch_size, feature_length = x.shape
    K = feature_length//2
    snr = torch.as_tensor(snr, dtype=x.dtype, device=x.device)
    sigma = torch.sqrt(P / (10 ** (snr / 10.0)))

    def _randn():
        # Draw on x's device so no CUDA runtime is required.
        return torch.randn(batch_size, K, dtype=x.dtype, device=x.device)

    mask = mask_gen(K, cr).to(x.device)
    h_com = torch.complex(_randn(), _randn())
    x_com = torch.complex(x[:, 0:feature_length:2], x[:, 1:feature_length:2])
    noise = torch.complex(sigma*_randn(), sigma*_randn())*mask

    # Equalize with the known channel coefficient.
    y = (h_com*x_com + noise)/h_com

    # (batch, K, 2) -> (batch, 2K) restores the re/im interleaving.
    return torch.view_as_real(y).reshape(batch_size, feature_length)

def Channel(z, snr, channel_type = 'AWGN'):
    """Power-normalize ``z`` and send it through the selected channel.

    Args:
        z: 2-D tensor of shape (batch, features).
        snr: Signal-to-noise ratio in dB, forwarded to the channel function.
        channel_type: ``'AWGN'`` or ``'Fading'``. Any other value returns the
            normalized signal without channel noise.

    Returns:
        The channel output, shaped like ``z``.
    """
    normalized = Power_norm(z)
    if channel_type == 'AWGN':
        return AWGN_channel(normalized, snr)
    if channel_type == 'Fading':
        return Fading_channel(normalized, snr)
    return normalized

def Channel_VLC(z, snr, cr, channel_type = 'AWGN'):
    """Variable-length variant of ``Channel``.

    Args:
        z: 2-D tensor of shape (batch, features).
        snr: Signal-to-noise ratio in dB, forwarded to the channel function.
        cr: Per-sample compression ratios, forwarded to the normalizer and
            the channel function.
        channel_type: ``'AWGN'`` or ``'Fading'``. Any other value returns the
            normalized signal without channel noise.

    Returns:
        The channel output, shaped like ``z``.
    """
    normalized = Power_norm_VLC(z, cr)
    if channel_type == 'AWGN':
        return AWGN_channel_VLC(normalized, snr, cr)
    if channel_type == 'Fading':
        return Fading_channel_VLC(normalized, snr, cr)
    return normalized

def mask_gen(N, cr, ch_max = 48):  # ch_max depends on the number of channels (e.g. 48, 8, 8)
    """Build a per-sample prefix mask for variable-length transmission.

    Row ``i`` has ones in its first ``(N // ch_max) * round(ch_max * cr[i])``
    positions and zeros elsewhere. The active length is clamped to
    ``[0, N]``, so out-of-range ratios yield an empty or a full row.

    Args:
        N: Mask width (number of features per sample).
        cr: Tensor with one compression ratio per sample, shape (batch,) or
            (batch, 1).
        ch_max: Number of groups the ``N`` features are divided into.

    Returns:
        int32 tensor of shape (batch, N) on the CPU.

    Raises:
        ValueError: If ``cr`` does not hold exactly one ratio per sample.
    """
    ratios = cr.detach().reshape(-1)
    if ratios.numel() != cr.shape[0]:
        raise ValueError('cr must contain exactly one ratio per sample')
    nc = N//ch_max
    # Compute all lengths at once and move them to the CPU in a single
    # transfer instead of synchronizing once per sample.
    lengths = (nc*torch.round(ch_max*ratios).int()).clamp(0, N).cpu()
    positions = torch.arange(N).unsqueeze(0)
    return (positions < lengths.unsqueeze(1)).int()

class ADJSCC(nn.Module):
    """Encoder -> channel -> decoder pipeline conditioned on the SNR.

    NOTE(review): the three constructor arguments are forwarded positionally
    to ``Encoder`` and ``Decoder``; confirm they match the parameters of the
    ``Encoder``/``Decoder`` classes actually in scope when this is used.
    """

    def __init__(self, enc_shape, Kernel_sz, Nc):
        super().__init__()
        self.encoder = Encoder(enc_shape, Kernel_sz, Nc)
        self.decoder = Decoder(enc_shape, Kernel_sz, Nc)

    def forward(self, x, snr, channel_type = 'AWGN'):
        """Encode ``x``, pass the code through ``Channel`` and decode it."""
        latent = self.encoder(x, snr)
        received = Channel(latent, snr, channel_type)
        return self.decoder(received, snr)

# The DeepJSCC_V model, also called ADJSCC_V
class ADJSCC_V(nn.Module):
    """Variable-length encoder -> channel -> decoder pipeline.

    The encoder output is masked with ``mask_gen`` according to the
    per-sample compression ratio before it is sent through ``Channel_VLC``.

    NOTE(review): the three constructor arguments are forwarded positionally
    to ``Encoder`` and ``Decoder``; confirm they match the parameters of the
    ``Encoder``/``Decoder`` classes actually in scope when this is used.
    """

    def __init__(self, enc_shape, Kernel_sz, Nc):
        super().__init__()
        self.encoder = Encoder(enc_shape, Kernel_sz, Nc)
        self.decoder = Decoder(enc_shape, Kernel_sz, Nc)

    def forward(self, x, snr, cr, channel_type = 'AWGN'):
        """Encode, mask, transmit and decode ``x``.

        Args:
            x: Input batch passed to the encoder.
            snr: Per-sample SNR in dB.
            cr: Per-sample compression ratios.
            channel_type: Channel name forwarded to ``Channel_VLC``.
        """
        z = self.encoder(x, snr)
        # Put the mask on the encoder output's device rather than forcing
        # CUDA, so the model also runs on CPU.
        z = z*mask_gen(z.shape[1], cr).to(z.device)
        z = Channel_VLC(z, snr, cr, channel_type)
        return self.decoder(z, snr)


def conv_block(c_in, c_out, k_size=4, stride=2, pad=1, use_bn=True, transpose=False):
    """Build a (transposed) convolution, optionally followed by batch norm.

    Args:
        c_in: Number of input channels.
        c_out: Number of output channels.
        k_size: Kernel size.
        stride: Convolution stride.
        pad: Padding.
        use_bn: Append ``BatchNorm2d`` and drop the convolution bias.
        transpose: Use ``ConvTranspose2d`` instead of ``Conv2d``.

    Returns:
        ``nn.Sequential`` holding one or two layers.
    """
    conv_cls = nn.ConvTranspose2d if transpose else nn.Conv2d
    # The bias is redundant when batch norm follows the convolution.
    stages = [conv_cls(c_in, c_out, k_size, stride, pad, bias=not use_bn)]
    if use_bn:
        stages.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*stages)


class ResBlock(nn.Module):
    """Two 3x3 conv+batch-norm layers with a skip connection.

    The skip connection adds the output of the second convolution to the
    ReLU-activated output of the first one.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = conv_block(channels, channels, k_size=3, stride=1, pad=1, use_bn=True)
        self.conv2 = conv_block(channels, channels, k_size=3, stride=1, pad=1, use_bn=True)

    # Implemented as forward() rather than __call__ so that nn.Module's own
    # __call__ still runs registered hooks; instances stay callable.
    def forward(self, x):
        x = F.relu(self.conv1(x))
        return x + self.conv2(x)


# Discriminator (S)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class ResidualBlock(nn.Module):
    """Residual block: conv -> instance norm -> ReLU -> conv -> instance norm."""

    def __init__(self, dim_in, dim_out):
        super().__init__()

        def make_norm():
            return nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True)

        # Layer order fixes the Sequential indices used in checkpoints.
        stages = [
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            make_norm(),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            make_norm(),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.main(x)
        return x + residual

class Discriminator(nn.Module):
    """PatchGAN-style discriminator with a source head and a class head.

    Args:
        image_size: Side length of the square input images.
        conv_dim: Channels produced by the first convolution.
        c_dim: Number of domain labels predicted by the class head.
        repeat_num: Number of stride-2 convolutions in the trunk.
    """

    def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
        super().__init__()
        trunk = [nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1),
                 nn.LeakyReLU(0.01)]

        width = conv_dim
        for _ in range(1, repeat_num):
            trunk += [nn.Conv2d(width, width*2, kernel_size=4, stride=2, padding=1),
                      nn.LeakyReLU(0.01)]
            width *= 2

        self.main = nn.Sequential(*trunk)
        self.conv1 = nn.Conv2d(width, 1, kernel_size=3, stride=1, padding=1, bias=False)
        # The class head's kernel spans the whole remaining feature map.
        head_kernel = int(image_size / np.power(2, repeat_num))
        self.conv2 = nn.Conv2d(width, c_dim, kernel_size=head_kernel, bias=False)

    def forward(self, x):
        """Return ``(out_src, out_cls)`` with ``out_cls`` of shape (batch, c_dim)."""
        features = self.main(x)
        src_map = self.conv1(features)
        cls_logits = self.conv2(features)
        return src_map, cls_logits.view(cls_logits.size(0), cls_logits.size(1))

# Model for MDAN (DA)

In [None]:
class Generator(nn.Module):
    """MDAN generator: one sequential network with a narrow bottleneck.

    The image is concatenated with its spatially replicated domain label,
    down-sampled twice, squeezed to ``int(48 * CR)`` channels, expanded back
    and up-sampled to a 3-channel image in ``[-1, 1]``.

    Note: a second class named ``Generator`` is defined later in this file;
    whichever definition was executed last is the one in scope.

    Args:
        conv_dim: Channels produced by the first convolution.
        c_dim: Number of domain-label channels appended to the image.
        repeat_num: Total residual blocks, split evenly around the bottleneck.
        CR: Compression ratio; sets the bottleneck width to ``int(48 * CR)``.
    """

    def __init__(self, conv_dim=64, c_dim=5, repeat_num=6, CR= 1):
        super().__init__()
        bottleneck_dim = int(48 * CR)

        def norm_relu(channels):
            return [nn.InstanceNorm2d(channels, affine=True, track_running_stats=True),
                    nn.ReLU(inplace=True)]

        # Stem: (3 + c_dim) channels -> conv_dim channels, same resolution.
        stack = [nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False)]
        stack += norm_relu(conv_dim)

        # Down-sampling: two stride-2 convolutions, doubling the width each time.
        width = conv_dim
        for _ in range(2):
            stack.append(nn.Conv2d(width, width*2, kernel_size=4, stride=2, padding=1, bias=False))
            stack += norm_relu(width*2)
            width *= 2

        # First half of the residual bottleneck.
        stack += [ResidualBlock(dim_in=width, dim_out=width) for _ in range(repeat_num//2)]

        # Semantic compression layer 1: width -> bottleneck_dim channels.
        stack.append(nn.Conv2d(width, bottleneck_dim, kernel_size=3, stride=1, padding=1, bias=False))
        stack += norm_relu(bottleneck_dim)

        # Semantic compression layer 2: bottleneck_dim -> width channels.
        stack.append(nn.Conv2d(bottleneck_dim, width, kernel_size=3, stride=1, padding=1, bias=False))
        stack += norm_relu(width)

        # Second half of the residual bottleneck (an AF block could be attached here).
        stack += [ResidualBlock(dim_in=width, dim_out=width) for _ in range(repeat_num//2)]

        # Up-sampling: two stride-2 transposed convolutions, halving the width.
        for _ in range(2):
            stack.append(nn.ConvTranspose2d(width, width//2, kernel_size=4, stride=2, padding=1, bias=False))
            stack += norm_relu(width//2)
            width //= 2

        # Output head: back to 3 channels, squashed by tanh.
        stack.append(nn.Conv2d(width, 3, kernel_size=7, stride=1, padding=3, bias=False))
        stack.append(nn.Tanh())
        self.main = nn.Sequential(*stack)

    def forward(self, x, c):
        """Translate ``x`` towards the domain described by label tensor ``c``."""
        # Replicate the label over the spatial grid and append it as channels.
        # Note that this type of label conditioning does not work at all if we
        # use reflection padding in Conv2d, because instance normalization
        # ignores the shifting (or bias) effect.
        label_map = c.view(c.size(0), c.size(1), 1, 1).repeat(1, 1, x.size(2), x.size(3))
        return self.main(torch.cat([x, label_map], dim=1))

# Model for MASCN (SCN)


In [None]:
# Original Existing Works
class AF_block(nn.Module):
    """Attention-feature block: channel gating conditioned on the SNR.

    The per-channel spatial mean of the input is concatenated with the SNR,
    passed through two linear layers ending in a sigmoid, and the resulting
    per-channel gate multiplies the input.

    Args:
        Nin: Number of input channels.
        Nh: Hidden width of the gating MLP.
        No: Number of gate outputs; must broadcast against the input channels.
    """

    def __init__(self, Nin, Nh, No):
        super().__init__()
        self.fc1 = nn.Linear(Nin+1, Nh)
        self.fc2 = nn.Linear(Nh, No)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, snr):
        """Gate ``x`` channel-wise.

        Args:
            x: Feature map of shape (batch, Nin, H, W).
            snr: SNR tensor holding one value per sample (shape (batch,),
                (batch, 1), ...) or a single value shared by the batch.

        Returns:
            Tensor shaped like ``x``.
        """
        mu = torch.mean(x, (2, 3))
        # Always bring the SNR to a (batch, 1) column. The previous
        # squeeze/unsqueeze sequence produced a 3-D tensor for batch size 1.
        snr = snr.reshape(-1, 1).to(dtype=mu.dtype, device=mu.device)
        if snr.shape[0] == 1 and mu.shape[0] > 1:
            snr = snr.expand(mu.shape[0], 1)
        gate = self.fc1(torch.cat((mu, snr), 1))
        gate = self.fc2(self.relu(gate))
        gate = self.sigmoid(gate)
        return gate.unsqueeze(2).unsqueeze(3)*x

class Encoder(nn.Module):
    """SNR-aware encoder: stem, two down-sampling stages, gated residual
    blocks and a final projection to ``int(48 * CR)`` channels.

    Args:
        conv_dim: Channels produced by the first convolution.
        c_dim: Number of domain-label channels appended to the image.
        repeat_num: Number of (ResidualBlock, AF_block) pairs.
        CR: Compression ratio; sets the output width to ``int(48 * CR)``.
    """

    def __init__(self, conv_dim=64, c_dim=5, repeat_num=3, CR=1):
        super().__init__()
        out_channels = int(48 * CR)

        def norm_relu(channels):
            return [nn.InstanceNorm2d(channels, affine=True, track_running_stats=True),
                    nn.ReLU(inplace=True)]

        # The position of each module in the list is its checkpoint index,
        # so the order below must not change.
        stages = [nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False),
                  *norm_relu(conv_dim)]

        width = conv_dim
        for _ in range(2):
            stages += [nn.Conv2d(width, width*2, kernel_size=4, stride=2, padding=1, bias=False),
                       *norm_relu(width*2)]
            width *= 2

        for _ in range(repeat_num):
            stages += [ResidualBlock(dim_in=width, dim_out=width),
                       AF_block(width, width //2, width)]

        stages += [nn.Conv2d(width, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                   *norm_relu(out_channels)]
        self.layers = nn.ModuleList(stages)

    def forward(self, x, snr):
        """Run ``x`` through all layers; only ``AF_block`` layers receive ``snr``."""
        for layer in self.layers:
            x = layer(x, snr) if isinstance(layer, AF_block) else layer(x)
        return x

class Decoder(nn.Module):
    """SNR-aware decoder: expansion, gated residual blocks, two up-sampling
    stages and a tanh output head producing 3 channels.

    Args:
        conv_dim2: Number of channels of the received code.
        curr_dim: Working width of the residual stages.
        repeat_num: Number of (AF_block, ResidualBlock) pairs.
    """

    def __init__(self, conv_dim2=48, curr_dim=256, repeat_num=3):
        super().__init__()

        def norm_relu(channels):
            return [nn.InstanceNorm2d(channels, affine=True, track_running_stats=True),
                    nn.ReLU(inplace=True)]

        # The position of each module in the list is its checkpoint index,
        # so the order below must not change.
        width = curr_dim
        stages = [nn.Conv2d(conv_dim2, width, kernel_size=3, stride=1, padding=1, bias=False),
                  *norm_relu(width)]

        for _ in range(repeat_num):
            stages += [AF_block(width, width //2, width),
                       ResidualBlock(dim_in=width, dim_out=width)]

        for _ in range(2):
            stages += [nn.ConvTranspose2d(width, width//2, kernel_size=4, stride=2, padding=1, bias=False),
                       *norm_relu(width//2)]
            width //= 2

        stages += [nn.Conv2d(width, 3, kernel_size=7, stride=1, padding=3, bias=False),
                   nn.Tanh()]
        self.layers = nn.ModuleList(stages)

    def forward(self, x, snr):
        """Run ``x`` through all layers; only ``AF_block`` layers receive ``snr``."""
        for layer in self.layers:
            x = layer(x, snr) if isinstance(layer, AF_block) else layer(x)
        return x


class Generator(nn.Module):
    """MASCN generator: SNR-aware encoder, simulated channel, SNR-aware decoder.

    Note: this definition replaces any ``Generator`` class defined earlier
    in the file once it has been executed.

    Args:
        conv_dim: Channels produced by the encoder's first convolution.
        c_dim: Number of domain-label channels appended to the image.
        repeat_num: Total residual blocks, split evenly between encoder and
            decoder.
        CR: Compression ratio; the transmitted code has ``int(48 * CR)``
            channels.
    """

    def __init__(self, conv_dim=64, c_dim=5, repeat_num=6, CR=1):
        super().__init__()
        self.CR = CR
        self.encoder = Encoder(conv_dim, c_dim, repeat_num//2, CR)
        self.decoder = Decoder(int(48 * CR), conv_dim*4, repeat_num//2)
        self.flatten = nn.Flatten()

    def forward(self, x, channel_type, snr, c):
        """Translate ``x`` towards label ``c`` through a noisy channel.

        Args:
            x: Image batch of shape (batch, 3, H, W).
            channel_type: Channel name forwarded to ``Channel``.
            snr: Per-sample SNR in dB, used by the encoder, the channel and
                the decoder.
            c: Domain labels of shape (batch, c_dim).

        Returns:
            The decoder output for the received code.
        """
        # Replicate spatially and concatenate domain information.
        c = c.view(c.size(0), c.size(1), 1, 1)
        c = c.repeat(1, 1, x.size(2), x.size(3))
        x = torch.cat([x, c], dim=1)

        x = self.encoder(x, snr)

        # Remember the code's shape so it can be restored after the channel.
        # A fixed 32x32 grid would only be correct for 128x128 inputs.
        code_shape = x.shape
        x = Channel(self.flatten(x), snr, channel_type)
        x = x.reshape(code_shape)

        return self.decoder(x, snr)


# Solver Configuration (S)

In [None]:
from torch.autograd import Variable
from torchvision.utils import save_image
import torch
import torch.nn.functional as F
import numpy as np
import os
import time
import datetime

# Model configuration.
# These module-level names are read by the model-building, data-loading and
# training cells further down.
c_dim = 5  # dimension of domain labels (1st dataset)
c2_dim = 8  # dimension of domain labels (2nd dataset)
celeba_crop_size = 178  # crop size for the CelebA dataset
rafd_crop_size = 256  # crop size for the RaFD dataset
image_size = 128  # image resolution
g_conv_dim = 64  # number of conv filters in the first layer of G
d_conv_dim = 64  # number of conv filters in the first layer of D
g_repeat_num = 6  # number of residual blocks in G
d_repeat_num = 6  # number of strided conv layers in D
lambda_cls = 1.0  # weight for domain classification loss
lambda_rec = 10.0  # weight for reconstruction loss
lambda_gp = 10.0  # weight for gradient penalty


# Training configuration
dataset = 'CelebA'  # Dataset type: 'CelebA', 'RaFD', or 'Both'
batch_size = 16  # Mini-batch size
num_iters = 200000  # Number of total iterations for training D
num_iters_decay = 100000  # Number of iterations for decaying learning rate
g_lr = 0.0001  # Learning rate for G
d_lr = 0.0001  # Learning rate for D
n_critic = 5  # Number of D updates per each G update
beta1 = 0.5  # Beta1 for Adam optimizer
beta2 = 0.999  # Beta2 for Adam optimizer
resume_iters = None  # Resume training from this step
selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']  # Selected attributes for the CelebA dataset

# Test configuration
test_iters = 200000  # Test model from this step

# Miscellaneous
num_workers = 1
mode = 'train'  # Mode: 'train' or 'test'
use_tensorboard = True  # Whether to use TensorBoard

# Directories
# The dataset paths are relative to the current working directory.
celeba_image_dir = 'data/celeba/images'
attr_path = 'data/celeba/list_attr_celeba.txt'
rafd_image_dir = 'data/RaFD/train'
# Outputs are written to Google Drive.
# NOTE(review): the folder is named 'StarGAN_256' while image_size is 128 -- confirm the name is intended.
log_dir = '/content/drive/MyDrive/StarGAN/StarGAN_256/logs'
model_save_dir = '/content/drive/MyDrive/StarGAN/StarGAN_256/models'
sample_dir = '/content/drive/MyDrive/StarGAN/StarGAN_256/samples'
result_dir = '/content/drive/MyDrive/StarGAN/StarGAN_256/results'

# Step size (all counted in training iterations)
log_step = 10
sample_step = 1000
model_save_step = 10000
lr_update_step = 1000

# Create directories if they do not exist
os.makedirs(log_dir, exist_ok=True)
os.makedirs(model_save_dir, exist_ok=True)
os.makedirs(sample_dir, exist_ok=True)
os.makedirs(result_dir, exist_ok=True)

print("Directories created or verified:")
print(f"Log directory: {log_dir}")
print(f"Model save directory: {model_save_dir}")
print(f"Sample directory: {sample_dir}")
print(f"Result directory: {result_dir}")

# NOTE: a Colab TPU runtime is not supported (a TPU is not a CUDA GPU);
# without CUDA the device falls back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Using ' + str(device).upper())

# Build Model

## Build model for MDAN (DA)

In [None]:
# Build the generator/discriminator pair and their Adam optimizers for MDAN.
# NOTE: two classes named Generator exist in this file; `Generator` here
# resolves to whichever definition cell was executed last, so run the MDAN
# model cell before this one.
G, D, g_optimizer, d_optimizer = None, None, None, None
if dataset in ['CelebA', 'RaFD']:
    G = Generator(g_conv_dim, c_dim, g_repeat_num)
    D = Discriminator(image_size, d_conv_dim, c_dim, d_repeat_num)
elif dataset in ['Both']:
    G = Generator(g_conv_dim, c_dim + c2_dim + 2, g_repeat_num)  # 2 for mask vector.
    D = Discriminator(image_size, d_conv_dim, c_dim + c2_dim, d_repeat_num)

# For any other `dataset` value G and D are still None and the lines below fail.
g_optimizer = torch.optim.Adam(G.parameters(), g_lr, [beta1, beta2])
d_optimizer = torch.optim.Adam(D.parameters(), d_lr, [beta1, beta2])
# NOTE(review): print_network is not defined in the part of the file shown here -- confirm it is available.
print_network(G, 'G', device)
print_network(D, 'D', device)

# Move both networks to the selected device.
G.to(device)
D.to(device)


## Build model for MASCN (SCN)

In [None]:
# Build the generator/discriminator pair and their Adam optimizers for MASCN.
# NOTE: two classes named Generator exist in this file; `Generator` here
# resolves to whichever definition cell was executed last, so run the MASCN
# model cell before this one.
G, D, g_optimizer, d_optimizer = None, None, None, None

if dataset in ['CelebA', 'RaFD']:
    G = Generator(g_conv_dim, c_dim, g_repeat_num, CR= 0.3)
    D = Discriminator(image_size, d_conv_dim, c_dim, d_repeat_num)
elif dataset in ['Both']:
    # NOTE(review): this branch does not pass CR, so the generator uses its default CR -- confirm that is intended.
    G = Generator(g_conv_dim, c_dim + c2_dim + 2, g_repeat_num)  # 2 for mask vector.
    D = Discriminator(image_size, d_conv_dim, c_dim + c2_dim, d_repeat_num)

# For any other `dataset` value G and D are still None and the lines below fail.
g_optimizer = torch.optim.Adam(G.parameters(), g_lr, [beta1, beta2])
d_optimizer = torch.optim.Adam(D.parameters(), d_lr, [beta1, beta2])
# NOTE(review): print_network is not defined in the part of the file shown here -- confirm it is available.
print_network(G, 'G', device)
print_network(D, 'D', device)

# Move both networks to the selected device.
G.to(device)
D.to(device)


# Dataset Preparation (S)

In [None]:
from torch.utils import data
from torchvision import transforms as T
from torchvision.datasets import ImageFolder
from PIL import Image
import torch
import os
import random


class CelebA(data.Dataset):
    """Dataset class for the CelebA dataset.

    Args:
        image_dir: Directory that contains the image files.
        attr_path: Path of the attribute list file. Its second line holds
            the attribute names; every following line holds a file name and
            one value per attribute.
        selected_attrs: Attribute names to expose as labels.
        transform: Callable applied to every loaded image.
        mode: ``'train'`` selects the training split; any other value
            selects the test split.
    """

    def __init__(self, image_dir, attr_path, selected_attrs, transform, mode):
        """Initialize and preprocess the CelebA dataset."""
        self.image_dir = image_dir
        self.attr_path = attr_path
        self.selected_attrs = selected_attrs
        self.transform = transform
        self.mode = mode
        self.train_dataset = []
        self.test_dataset = []
        self.attr2idx = {}
        self.idx2attr = {}
        self.preprocess()

        if mode == 'train':
            self.num_images = len(self.train_dataset)
        else:
            self.num_images = len(self.test_dataset)

    def preprocess(self):
        """Parse the attribute file and fill the train/test splits."""
        # Read inside a context manager so the file handle is closed.
        with open(self.attr_path, 'r') as attr_file:
            lines = [line.rstrip() for line in attr_file]

        for i, attr_name in enumerate(lines[1].split()):
            self.attr2idx[attr_name] = i
            self.idx2attr[i] = attr_name

        lines = lines[2:]
        # Fixed seed keeps the train/test split reproducible across runs.
        random.seed(1234)
        random.shuffle(lines)
        for i, line in enumerate(lines):
            filename, *values = line.split()
            label = [values[self.attr2idx[attr_name]] == '1'
                     for attr_name in self.selected_attrs]

            # The first 1999 shuffled entries form the test split.
            if (i+1) < 2000:
                self.test_dataset.append([filename, label])
            else:
                self.train_dataset.append([filename, label])

        print('Finished preprocessing the CelebA dataset...')

    def __getitem__(self, index):
        """Return one image and its corresponding attribute label."""
        dataset = self.train_dataset if self.mode == 'train' else self.test_dataset
        filename, label = dataset[index]
        image = Image.open(os.path.join(self.image_dir, filename))
        return self.transform(image), torch.FloatTensor(label)

    def __len__(self):
        """Return the number of images."""
        return self.num_images


def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128,
               batch_size=16, dataset='CelebA', mode='train', num_workers=1):
    """Build and return a data loader.

    Args:
        image_dir: Directory with the images.
        attr_path: CelebA attribute file (used for ``dataset='CelebA'`` only).
        selected_attrs: CelebA attribute names (``dataset='CelebA'`` only).
        crop_size: Side length of the center crop.
        image_size: Size passed to the resize transform.
        batch_size: Mini-batch size.
        dataset: ``'CelebA'`` or ``'RaFD'``.
        mode: ``'train'`` enables random horizontal flips and shuffling.
        num_workers: Number of loader worker processes.

    Returns:
        A ``DataLoader`` that drops the last incomplete batch.

    Raises:
        ValueError: If ``dataset`` is neither ``'CelebA'`` nor ``'RaFD'``.
    """
    # Reject unknown names up front; otherwise the name string itself would
    # be handed to DataLoader as if it were a dataset.
    if dataset not in ('CelebA', 'RaFD'):
        raise ValueError(
            "Unknown dataset {!r}; expected 'CelebA' or 'RaFD'.".format(dataset))

    steps = [T.RandomHorizontalFlip()] if mode == 'train' else []
    steps += [
        T.CenterCrop(crop_size),
        T.Resize(image_size),
        T.ToTensor(),
        # Maps pixel values from [0, 1] to [-1, 1].
        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
    transform = T.Compose(steps)

    if dataset == 'CelebA':
        source = CelebA(image_dir, attr_path, selected_attrs, transform, mode)
    else:
        source = ImageFolder(image_dir, transform)

    return data.DataLoader(dataset=source,
                           batch_size=batch_size,
                           shuffle=(mode=='train'),
                           num_workers=num_workers,
                           drop_last=True)

In [None]:
# Data loaders: build only what the selected dataset needs; the other stays None.
celeba_loader = (
    get_loader(image_dir=celeba_image_dir, attr_path=attr_path,
               selected_attrs=selected_attrs, crop_size=celeba_crop_size,
               image_size=image_size, batch_size=batch_size,
               dataset='CelebA', mode=mode, num_workers=num_workers)
    if dataset in ('CelebA', 'Both') else None
)

rafd_loader = (
    get_loader(image_dir=rafd_image_dir, attr_path=None,
               selected_attrs=None, crop_size=rafd_crop_size,
               image_size=image_size, batch_size=batch_size,
               dataset='RaFD', mode=mode, num_workers=num_workers)
    if dataset in ('RaFD', 'Both') else None
)



# Training

## MDAN (DA)

In [None]:
# Training reads its batches from the CelebA loader.
data_loader = celeba_loader
data_iter = iter(data_loader)

# Keep the first batch as a fixed input for the periodic sample images.
x_fixed, c_org = next(data_iter)
x_fixed = x_fixed.to(device)
c_fixed_list = create_labels(c_org, c_dim, dataset, selected_attrs, device)

# g_lr and d_lr from the configuration cell serve directly as the running
# learning rates that the training loop decays.

# Start from scratch, or restore a checkpoint when resume_iters is set.
start_iters = resume_iters or 0
if resume_iters:
    restore_model(resume_iters)


In [None]:
print('Start training...')
start_time = time.time()

# Define the frequency and scope of validation
validation_step = 10  # Adjust this value as needed
validation_batches = 100 # Number of batches to use for validation

for i in range(start_iters, num_iters):
    # Fetch real images and labels; restart the loader when it is exhausted.
    try:
        x_real, label_org = next(data_iter)
    except StopIteration:
        data_iter = iter(data_loader)
        x_real, label_org = next(data_iter)

    # Target labels are the batch's own labels in shuffled order.
    rand_idx = torch.randperm(label_org.size(0))
    label_trg = label_org[rand_idx]

    if dataset == 'CelebA':
        c_org = label_org.clone()
        c_trg = label_trg.clone()
    elif dataset == 'RaFD':
        c_org = label2onehot(label_org, c_dim)
        c_trg = label2onehot(label_trg, c_dim)

    x_real = x_real.to(device)           # Input images.
    c_org = c_org.to(device)             # Original domain labels.
    c_trg = c_trg.to(device)             # Target domain labels.
    label_org = label_org.to(device)     # Labels for computing classification loss.
    label_trg = label_trg.to(device)     # Labels for computing classification loss.

    # ------------------------------------------------------------------
    # Train the discriminator
    # ------------------------------------------------------------------
    out_src, out_cls = D(x_real)
    d_loss_real = -torch.mean(out_src)
    d_loss_cls = classification_loss(out_cls, label_org, dataset)

    # Compute loss with fake images (detached so G receives no gradient).
    x_fake = G(x_real, c_trg)
    out_src, out_cls = D(x_fake.detach())
    d_loss_fake = torch.mean(out_src)

    # Compute loss for gradient penalty on random real/fake interpolations.
    alpha = torch.rand(x_real.size(0), 1, 1, 1).to(device)
    x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
    out_src, _ = D(x_hat)
    d_loss_gp = gradient_penalty(out_src, x_hat, device)

    # Backward and optimize.
    d_loss = d_loss_real + d_loss_fake + lambda_cls * d_loss_cls + lambda_gp * d_loss_gp
    reset_grad(g_optimizer, d_optimizer)
    d_loss.backward()
    d_optimizer.step()

    # Logging.
    loss = {}
    loss['D/loss_real'] = d_loss_real.item()
    loss['D/loss_fake'] = d_loss_fake.item()
    loss['D/loss_cls'] = d_loss_cls.item()
    loss['D/loss_gp'] = d_loss_gp.item()

    # ------------------------------------------------------------------
    # Train the generator once every n_critic discriminator updates
    # ------------------------------------------------------------------
    if (i + 1) % n_critic == 0:
        # Original-to-target domain.
        x_fake = G(x_real, c_trg)
        out_src, out_cls = D(x_fake)
        g_loss_fake = -torch.mean(out_src)
        g_loss_cls = classification_loss(out_cls, label_trg, dataset)

        # Target-to-original domain (cycle reconstruction).
        x_reconst = G(x_fake, c_org)
        g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

        g_loss = g_loss_fake + lambda_rec * g_loss_rec + lambda_cls * g_loss_cls
        reset_grad(g_optimizer, d_optimizer)
        g_loss.backward()
        g_optimizer.step()

        # Logging.
        loss['G/loss_fake'] = g_loss_fake.item()
        loss['G/loss_rec'] = g_loss_rec.item()
        loss['G/loss_cls'] = g_loss_cls.item()

    # Print out training information and log
    if (i + 1) % log_step == 0:
        et = time.time() - start_time
        et = str(datetime.timedelta(seconds=et))[:-7]
        log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, num_iters)
        for tag, value in loss.items():
            log += ", {}: {:.4f}".format(tag, value)
        print(log)

    # Evaluate classification accuracy periodically.
    # The check draws from a separate pass over data_loader and uses its own
    # variable names, so the training iterator and the current training
    # batch are left untouched. No gradients are needed here.
    if (i + 1) % validation_step == 0:
        total_correct = 0.0
        total_num = 0
        with torch.no_grad():
            # zip() stops after validation_batches, or earlier if the loader
            # runs out of batches.
            for _, (x_val, label_val) in zip(range(validation_batches), data_loader):
                x_val = x_val.to(device)
                label_val = label_val.to(device)
                _, out_cls = D(x_val)
                preds = (torch.sigmoid(out_cls) > 0.5).float()
                total_correct += (preds == label_val).float().sum().item()
                total_num += label_val.numel()

        if total_num > 0:
            accuracy = total_correct / total_num
            print(f"Validation at Iteration {i+1}: Classification Accuracy = {accuracy:.4f}")
        else:
            print(f"Validation at Iteration {i+1}: No data to validate.")

    # Translate fixed images for debugging
    if (i + 1) % sample_step == 0:
        with torch.no_grad():
            x_fake_list = [x_fixed]
            for c_fixed in c_fixed_list:
                x_fake_list.append(G(x_fixed, c_fixed))
            x_concat = torch.cat(x_fake_list, dim=3)
            sample_path = os.path.join(sample_dir, '{}-images.jpg'.format(i+1))
            save_image(denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
            print('Saved real and fake images into {}...'.format(sample_path))

    # Save model checkpoints
    if (i + 1) % model_save_step == 0:
        G_path = os.path.join(model_save_dir, '{}-G.ckpt'.format(i+1))
        D_path = os.path.join(model_save_dir, '{}-D.ckpt'.format(i+1))
        torch.save(G.state_dict(), G_path)
        torch.save(D.state_dict(), D_path)
        print('Saved model checkpoints into {}...'.format(model_save_dir))

    # Decay learning rates during the last num_iters_decay iterations.
    # NOTE(review): each update subtracts only lr / num_iters_decay and runs
    # once per lr_update_step iterations -- confirm this schedule is intended.
    if (i + 1) % lr_update_step == 0 and (i + 1) > (num_iters - num_iters_decay):
        g_lr -= (g_lr / float(num_iters_decay))
        d_lr -= (d_lr / float(num_iters_decay))
        update_lr(g_lr, d_lr)
        print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

## MASCN (SCN)

e.g., CR=0.1

In [None]:
import torch
import os
import datetime
import time
from torch.nn.functional import binary_cross_entropy_with_logits as bce_logits_loss

# Initialize data loader
data_loader = celeba_loader  # `celeba_loader` must already be defined by an earlier cell

# Fetch one fixed batch; it is re-translated periodically so samples are comparable
data_iter = iter(data_loader)
x_fixed, c_org = next(data_iter)
x_fixed = x_fixed.to(device)
c_fixed_list = create_labels(c_org, c_dim, dataset, selected_attrs, device)

# g_lr / d_lr (defined by an earlier cell) are decayed in place at the end of
# the loop, so they double as the learning-rate cache.

start_iters = 0
if resume_iters:
    start_iters = resume_iters
    # NOTE(review): the testing cell calls
    # restore_model(G, D, resume_iters, model_save_dir, device); confirm which
    # signature the current restore_model definition expects.
    restore_model(resume_iters)

print('Start training...')
start_time = time.time()
ch = 'AWGN'
# Define the frequency and scope of validation
validation_step = 10  # Run the accuracy check every `validation_step` iterations
validation_batches = 100  # Number of batches to use for validation

for i in range(start_iters, num_iters):
    # Fetch real images and labels, restarting the iterator when it is exhausted
    try:
        x_real, label_org = next(data_iter)
    except StopIteration:
        data_iter = iter(data_loader)
        x_real, label_org = next(data_iter)

    # Prepare labels for domain transfer: target labels are a random
    # permutation of the batch's own labels
    rand_idx = torch.randperm(label_org.size(0))
    label_trg = label_org[rand_idx]
    # One integer SNR value in [0, 28) per sample; moved with .to(device)
    # instead of .cuda() so the cell follows the configured device
    snr = torch.randint(0, 28, (x_real.shape[0], 1)).to(device)

    if dataset == 'CelebA':
        c_org = label_org.clone()
        c_trg = label_trg.clone()
    elif dataset == 'RaFD':
        c_org = label2onehot(label_org, c_dim)
        c_trg = label2onehot(label_trg, c_dim)

    x_real = x_real.to(device)           # Input images.
    c_org = c_org.to(device)             # Original domain labels.
    c_trg = c_trg.to(device)             # Target domain labels.
    label_org = label_org.to(device)     # Labels for computing classification loss.
    label_trg = label_trg.to(device)     # Labels for computing classification loss.

    # ------------------------------------------------------------------
    # Train the discriminator
    # ------------------------------------------------------------------
    # Loss on real images.
    out_src, out_cls, out_bin = D(x_real)
    d_loss_real = -torch.mean(out_src)
    d_loss_cls = classification_loss(out_cls, label_org, dataset)
    d_loss_bin_r = bce_logits_loss(out_bin, label_org)

    # Loss on fake images (detached so no gradient reaches G here).
    x_fake = G(x_real, ch, snr, c_trg)
    out_src, out_cls, out_bin = D(x_fake.detach())
    d_loss_fake = torch.mean(out_src)
    d_loss_bin_f = bce_logits_loss(out_bin, label_trg)

    # Gradient penalty on random interpolates between real and fake images.
    alpha = torch.rand(x_real.size(0), 1, 1, 1).to(device)
    x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
    out_src, _, _ = D(x_hat)
    d_loss_gp = gradient_penalty(out_src, x_hat, device)
    d_loss_prg = d_loss_bin_f + d_loss_bin_r

    # Backward and optimize.
    d_loss = (d_loss_real + d_loss_fake + lambda_cls * d_loss_cls
              + lambda_gp * d_loss_gp + d_loss_prg)
    reset_grad(g_optimizer, d_optimizer)
    d_loss.backward()
    d_optimizer.step()

    # Logging.
    loss = {}
    loss['D/loss_real'] = d_loss_real.item()
    loss['D/loss_fake'] = d_loss_fake.item()
    loss['D/loss_cls'] = d_loss_cls.item()
    loss['D/loss_gp'] = d_loss_gp.item()
    loss['D/loss_prg'] = d_loss_prg.item()

    # ------------------------------------------------------------------
    # Train the generator once every n_critic discriminator updates
    # ------------------------------------------------------------------
    if (i + 1) % n_critic == 0:
        # Original-to-target domain.
        x_fake = G(x_real, ch, snr, c_trg)
        out_src, out_cls, out_bin = D(x_fake)
        g_loss_fake = -torch.mean(out_src)
        g_loss_cls = classification_loss(out_cls, label_trg, dataset)
        g_loss_bin_trg = bce_logits_loss(out_bin, label_trg)

        # Target-to-original domain (reconstruction).
        x_reconst = G(x_fake, ch, snr, c_org)
        g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))
        # NOTE(review): out_bin here is D's output for the *fake* image, yet it
        # is also scored against the original labels; confirm this is intended.
        g_loss_bin_org = bce_logits_loss(out_bin, label_org)
        g_loss_prg = g_loss_bin_trg + g_loss_bin_org

        g_loss = g_loss_fake + lambda_rec * g_loss_rec + lambda_cls * g_loss_cls + g_loss_prg
        reset_grad(g_optimizer, d_optimizer)
        g_loss.backward()
        g_optimizer.step()

        # Logging. The generator term gets its own key so it no longer
        # overwrites the discriminator's 'D/loss_prg' entry.
        loss['G/loss_fake'] = g_loss_fake.item()
        loss['G/loss_rec'] = g_loss_rec.item()
        loss['G/loss_cls'] = g_loss_cls.item()
        loss['G/loss_prg'] = g_loss_prg.item()

    # Print out training information and log
    if (i + 1) % log_step == 0:
        et = time.time() - start_time
        et = str(datetime.timedelta(seconds=et))[:-7]  # drop fractional seconds
        log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, num_iters)
        for tag, value in loss.items():
            log += ", {}: {:.4f}".format(tag, value)
        print(log)

    # Evaluate binary classification accuracy periodically.
    # Batches are drawn from the same iterator as training, so this is not a
    # held-out evaluation and it advances the training data stream.
    if (i + 1) % validation_step == 0:
        total_correct = 0
        total_images = 0
        # No gradients are needed for evaluation; skipping autograd bookkeeping
        # avoids building a graph for every validation batch.
        with torch.no_grad():
            for _ in range(validation_batches):
                try:
                    x_val, labels_val = next(data_iter)
                except StopIteration:
                    data_iter = iter(data_loader)
                    x_val, labels_val = next(data_iter)

                x_val = x_val.to(device)
                labels_val = labels_val.to(device)
                # Assumes labels_val has the same layout as out_bin (one binary
                # label per attribute) -- TODO confirm against D's definition.
                _, _, out_bin_val = D(x_val)
                preds = torch.sigmoid(out_bin_val) > 0.5  # Convert probabilities to binary predictions
                # An image counts as correct only if every attribute matches.
                correct = (preds == labels_val).all(dim=1).float().sum().item()
                total_correct += correct
                total_images += x_val.size(0)

        if total_images > 0:
            accuracy = total_correct / total_images
            print(f"Validation at Iteration {i+1}: Binary Classification Accuracy = {accuracy:.4f} for batch size {total_images}")
        else:
            print(f"Validation at Iteration {i+1}: No data to validate.")

    # Translate fixed images for debugging
    if (i + 1) % 300 == 0:
        with torch.no_grad():
            # Draw SNR values sized for the fixed batch; the training `snr` is
            # sized for the current batch, which may be shorter at epoch end.
            snr_fixed = torch.randint(0, 28, (x_fixed.size(0), 1)).to(device)
            x_fake_list = [x_fixed]
            for c_fixed in c_fixed_list:
                x_fake_list.append(G(x_fixed, ch, snr_fixed, c_fixed))
            x_concat = torch.cat(x_fake_list, dim=3)
            sample_path = os.path.join(sample_dir, '{}-images.jpg'.format(i+1))
            save_image(denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
            print('Saved real and fake images into {}...'.format(sample_path))

    # Save model checkpoints
    if (i + 1) % model_save_step == 0:
        G_path = os.path.join(model_save_dir, '{}-G.ckpt'.format(i+1))
        D_path = os.path.join(model_save_dir, '{}-D.ckpt'.format(i+1))
        torch.save(G.state_dict(), G_path)
        torch.save(D.state_dict(), D_path)
        print('Saved model checkpoints into {}...'.format(model_save_dir))

    # Decay learning rates during the final num_iters_decay iterations
    if (i + 1) % lr_update_step == 0 and (i + 1) > (num_iters - num_iters_decay):
        g_lr -= (g_lr / float(num_iters_decay))
        d_lr -= (d_lr / float(num_iters_decay))
        update_lr(g_lr, d_lr)
        print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

# Testing

In [None]:
# Checkpoint to evaluate: iteration number and the directory holding it
resume_iters = 90000
start_iters = resume_iters
model_save_dir = '/content/drive/MyDrive/StarGAN/CelebA_SCN_StarGAN_0.7/models'

# Drop the reference to any previously built generator before creating the
# new one (presumably so its memory can be reclaimed first).
G = None
G = Generator(g_conv_dim, c_dim, g_repeat_num, CR=0.7)
G.to(device)

# Load the saved weights for G and D from the checkpoint directory
restore_model(G, D, resume_iters, model_save_dir, device)


In [None]:
from torch.autograd import Variable
from torchvision.utils import save_image

# Translate the fixed batch into every target domain at SNR = 3 and save the
# real and generated images side by side.
with torch.no_grad():
    # Move the fixed batch once and use that same tensor both for the output
    # grid and as the generator input, so the two always share a device.
    x_fixed_dev = x_fixed.to(device)
    # NOTE: this rebinds the notebook-level name `batch_size`.
    batch_size = x_fixed_dev.size(0)
    # The SNR tensor is identical for every target domain, so build it once.
    snr = 3 * torch.ones(batch_size, 1, device=device)

    x_fake_list = [x_fixed_dev]
    for c_fixed in c_fixed_list:
        c_fixed = c_fixed.to(device)    # Move c_fixed to the device
        x_fake_list.append(G(x_fixed_dev, 'AWGN', snr, c_fixed))

    x_concat = torch.cat(x_fake_list, dim=3)  # Concatenate along the width axis (dim=3)
    sample_path = os.path.join(sample_dir, 'test_images.jpg')  # Define path for saving the image
    save_image(denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)  # Denormalize and save to disk
    print('Saved real and fake images into {}...'.format(sample_path))  # Print save location


In [None]:
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt

# List of target domain attributes (used for visualization)
attr_list = ['Black Hair', 'Blond Hair', 'Gender', 'Age']

def plot_domain_adaptation_results(src_img, c_fixed_list, CR_test, SNR_test, dst_domains, device):
    """Plot the first source image next to its translation into each target domain.

    Args:
        src_img: Image tensor; a 3-D tensor gets a leading batch dimension.
            Assumed to be normalized to [-1, 1] -- TODO confirm against the loader.
        c_fixed_list: Sequence of target-domain label tensors, one per panel.
        CR_test: Value shown in the input panel title only; it is not passed
            to the generator.
        SNR_test: SNR value broadcast to every sample and passed to the generator.
        dst_domains: Target-domain names; together with ``c_fixed_list`` it
            bounds the number of panels and supplies titles beyond ``attr_list``.
        device: Device on which the generator is evaluated.

    Side effects:
        Puts the notebook-level generator ``G`` into eval mode and shows a
        matplotlib figure.
    """
    if src_img.dim() == 3:
        src_img = src_img.unsqueeze(0)  # Add batch dimension if necessary

    src_img = src_img.to(device)
    SNR = (SNR_test * torch.ones(src_img.shape[0], 1)).to(device)  # Generate SNR tensor

    # Plot only as many targets as have both a label tensor and a name, so a
    # length mismatch between the two lists cannot index past the axes.
    n_targets = min(len(c_fixed_list), len(dst_domains))

    # squeeze=False keeps `axs` indexable even when there is a single column.
    fig, axs = plt.subplots(1, n_targets + 1, figsize=(18, 6), squeeze=False)
    axs = axs[0]

    # Show the input image in the first column, denormalized like the outputs
    axs[0].imshow(to_pil_image(denorm(src_img[0].cpu())))
    axs[0].set_title(f'Input (SNR = {SNR_test}, CR = {CR_test})')
    axs[0].axis('off')

    # Use the notebook-level generator in eval mode
    MSCN = G
    MSCN.eval()

    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        for i, c_fixed in enumerate(c_fixed_list[:n_targets]):
            c_fixed = c_fixed.to(device)  # Move c_fixed to the device
            DA_img_starGAN = MSCN(src_img, 'AWGN', SNR, c_fixed)

            axs[i + 1].imshow(to_pil_image(denorm(DA_img_starGAN[0].data.cpu())))  # Display the adapted image
            # Prefer the display name; fall back to the caller's domain name.
            title = attr_list[i] if i < len(attr_list) else dst_domains[i]
            axs[i + 1].set_title(f'{title}')
            axs[i + 1].axis('off')

    plt.tight_layout()
    plt.show()

# Example function call
plot_domain_adaptation_results(x_fixed, c_fixed_list, 0.6, 3, dst_domains=['Black_Hair', 'Blond_Hair', 'Gender', 'Age'], device=device)


In [None]:
from torch.autograd import Variable
from torchvision.utils import save_image
import torch
import os

# Define denormalization function: convert image range from [-1, 1] to [0, 1]
def denorm(x):
    """Map a tensor from [-1, 1] to [0, 1], clamping anything outside that range.

    The input tensor is left unmodified; a new tensor is returned.
    """
    return torch.clamp((x + 1) / 2, 0, 1)

# Generate and save fake images
with torch.no_grad():
    # Move the fixed batch once and feed that same tensor to the generator, so
    # the real images and the generator input always share a device.
    x_fixed_dev = x_fixed.to(device)
    # NOTE: this rebinds the notebook-level `batch_size`, which the
    # get_loader call later in this cell reads.
    batch_size = x_fixed_dev.size(0)
    # Constant SNR = 3 for every sample; identical across target domains, so
    # it is built once outside the loop.
    snr = 3 * torch.ones(batch_size, 1, device=device)

    x_fake_list = [x_fixed_dev]
    for c_fixed in c_fixed_list:
        c_fixed = c_fixed.to(device)    # Move c_fixed to device
        x_fake = G(x_fixed_dev, 'AWGN', snr, c_fixed)  # Generate fake image
        x_fake_list.append(x_fake)

    # Concatenate real and fake images along the width axis, then denormalize
    x_concat = torch.cat(x_fake_list, dim=3)
    x_concat = denorm(x_concat.data.cpu())  # Move to CPU before saving

    # Save image to file
    sample_path = os.path.join(sample_dir, 'test_images.jpg')
    save_image(x_concat, sample_path, nrow=1, padding=0)
    print('Saved real and fake images into {}...'.format(sample_path))

# Load CelebA dataset
celeba_loader = get_loader(celeba_image_dir, attr_path, selected_attrs,
                           celeba_crop_size, image_size, batch_size,
                           'CelebA', mode, num_workers)

# Skip ahead to batch index `i_stop` and keep that batch as the fixed input
i_stop = 20
for i, (x_fixed, c_org) in enumerate(celeba_loader):
    if i >= i_stop:
        break
# Move to the device after the loop so this also happens when the loader
# yields no batch past index i_stop; the last batch seen is then used.
x_fixed = x_fixed.to(device)
c_org = c_org.to(device)

# Create modified target domain labels from original attributes
c_fixed_list = create_labels(c_org, c_dim, dataset, selected_attrs, device)

print(c_org[0])        # Print original domain label for the first sample
c_fixed_list.pop(2)    # Drop the target-label entry at index 2 before visualization
print()
