<a href="https://colab.research.google.com/github/wongdongwook/JSAC_MA-DeepSC/blob/main/CyclegAN_(DA)_for_4digits.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 4-Digits CycleGAN: Data Domain Adaptation via CycleGAN  
This Colab notebook implements domain-adaptation experiments on four digit datasets—**MNIST, MNIST-M, SYN, USPS**—using a CycleGAN-based approach.  

The implementation is inspired by and aligns with the method described in:  
📄 **Deep Learning-Enabled Semantic Communication Systems with Task-Unaware Transmitter and Dynamic Data**
> [arXiv:2205.00271](https://arxiv.org/abs/2205.00271)

### 🔍 Key Features:
- **Unpaired image-to-image translation** between domains using CycleGAN.
- **Domain combinations**: MNIST ↔ SYN, MNIST ↔ USPS, SYN ↔ USPS.
- **Cycle-consistency loss** to preserve image semantics.
- **PatchGAN Discriminator** for high-frequency detail preservation.
- **Digit classification** model to evaluate semantic preservation after translation.

This notebook is designed for adaptation studies in low-resource or visually distinct digit domains.


# Configuration

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vutils
from torch import optim
from torchvision.transforms.functional import to_pil_image
from torchvision import datasets
from torchvision import transforms
import numpy as np
import os
import math
from tqdm import tqdm
import datetime
from sklearn.metrics import accuracy_score, confusion_matrix
import pandas as pd
from google.colab import drive

In [None]:
# Pick the compute device (GPU when available) and mount Google Drive,
# which holds the datasets and receives the training outputs.
is_cuda = torch.cuda.is_available()
device = torch.device("cuda" if is_cuda else "cpu")
print(f'Using {str(device).upper()}')
drive.mount("/content/drive")

Using CUDA
Mounted at /content/drive


# Data Loader

In [None]:
class ImgDomainAdaptationData(torch.utils.data.Dataset):
    """Index-aligned pairs of samples drawn from two image domains.

    Each file is loaded with ``torch.load`` and indexed as ``data[0]`` (images)
    and ``data[1]`` (labels). All images are resized and normalized once, up
    front, and kept in memory. Item ``i`` pairs sample ``i`` of domain A with
    sample ``i`` of domain B, so the dataset length is the size of the smaller
    domain.

    Args:
        path_A: Path to the serialized data of domain A.
        path_B: Path to the serialized data of domain B.
        id_A: Integer identifier stored for every sample of domain A.
        id_B: Integer identifier stored for every sample of domain B.
        w: Target image width.
        h: Target image height.
    """

    def __init__(self, path_A, path_B, id_A, id_B, w, h):

        # torchvision's Resize expects the size as (height, width).
        self.transform = transforms.Compose([transforms.Resize([h, w]),
                                            transforms.Normalize([0.5], [0.5])])

        self.data_A = torch.load(path_A)
        self.data_B = torch.load(path_B)

        # NOTE(review): Normalize needs floating-point tensors; presumably the
        # stored images are already floats in [0, 1] -- confirm against the
        # script that produced the .pt files.
        self.img_A = self.transform(self.data_A[0])
        self.img_B = self.transform(self.data_B[0])

        self.label_A = self.data_A[1]
        self.label_B = self.data_B[1]

        self.img_A, self.domain_A = self.pre_processing(self.img_A, id_A)
        self.img_B, self.domain_B = self.pre_processing(self.img_B, id_B)

        # Only indices valid in both domains can be served.
        self.len = min(self.label_A.shape[0], self.label_B.shape[0])

    def pre_processing(self, img, domain):
        """Give ``img`` an explicit 3-channel axis and build its domain labels.

        Args:
            img: Image tensor. When it has fewer than four dimensions a channel
                axis is inserted at position 1 and repeated three times.
            domain: Integer domain identifier.

        Returns:
            Tuple ``(img, domain_label)`` where ``domain_label`` is a NumPy
            array with one entry per image, all equal to ``domain``.
        """
        num_img = img.shape[0]

        if len(img.shape) < 4:
            img = img.unsqueeze(1).repeat(1, 3, 1, 1)

        domain_label = np.zeros(num_img, dtype=int) + domain

        return img, domain_label

    def __len__(self):
        """Return the number of (A, B) pairs available."""
        return self.len

    def __getitem__(self, index):
        """Return ``((img_A, label_A), (img_B, label_B))`` for ``index``."""
        return (self.img_A[index], self.label_A[index]), (self.img_B[index], self.label_B[index])

# Utils

In [None]:
def generate_imgs(a, b, ab_gen, ba_gen, samples_path, a_name, b_name, epoch=0):
    """Translate two fixed batches and save real/fake comparison grids.

    Real images and their translations are interleaved (real, fake, real,
    fake, ...) and written as one PNG per translation direction.

    Args:
        a: Batch of domain-A images, indexed as (N, C, H, W).
        b: Batch of domain-B images with the same batch size as ``a``.
        ab_gen: Generator translating A -> B.
        ba_gen: Generator translating B -> A.
        samples_path: Pair ``(dir_for_A_to_B, dir_for_B_to_A)`` of output
            directories.
        a_name: Name of domain A, used in the file names.
        b_name: Name of domain B, used in the file names.
        epoch: Epoch tag appended to the file names.
    """
    generators = (ab_gen, ba_gen)
    # Remember each generator's mode so sampling does not silently leave the
    # models in eval mode while training is still in progress.
    was_training = [g.training for g in generators]
    for g in generators:
        g.eval()

    try:
        # The images are only saved, so no autograd graph is needed.
        with torch.no_grad():
            b_fake = ab_gen(a)
            a_fake = ba_gen(b)
    finally:
        for g, mode in zip(generators, was_training):
            g.train(mode)

    a_imgs = torch.zeros((a.shape[0] * 2, 3, a.shape[2], a.shape[3]))
    b_imgs = torch.zeros((b.shape[0] * 2, 3, b.shape[2], b.shape[3]))

    # Even slots hold the real images, odd slots the translated ones.
    even_idx = torch.arange(start=0, end=a.shape[0] * 2, step=2)
    odd_idx = torch.arange(start=1, end=a.shape[0] * 2, step=2)

    a_imgs[even_idx] = a.cpu()
    a_imgs[odd_idx] = b_fake.cpu()

    b_imgs[even_idx] = b.cpu()
    b_imgs[odd_idx] = a_fake.cpu()

    # Roughly square grid.
    rows = math.ceil((a.shape[0] * 2) ** 0.5)
    a_imgs_ = vutils.make_grid(a_imgs, normalize=True, nrow=rows)
    b_imgs_ = vutils.make_grid(b_imgs, normalize=True, nrow=rows)

    vutils.save_image(a_imgs_, os.path.join(samples_path[0], f'{a_name}_to_{b_name}_ep_{epoch}.png'))
    vutils.save_image(b_imgs_, os.path.join(samples_path[1], f'{b_name}_to_{a_name}_ep_{epoch}.png'))

# CycleGAN Model

In [None]:
def conv_block(c_in, c_out, k_size=4, stride=2, pad=1, use_bn=True, transpose=False):
    """Build a (transposed) convolution, optionally followed by batch norm.

    Args:
        c_in: Number of input channels.
        c_out: Number of output channels.
        k_size: Kernel size.
        stride: Convolution stride.
        pad: Padding; for transposed convolutions it is also used as
            ``output_padding``.
        use_bn: Append ``BatchNorm2d`` and drop the convolution bias.
        transpose: Use ``ConvTranspose2d`` instead of ``Conv2d``.

    Returns:
        ``nn.Sequential`` holding the convolution and, if requested, the
        batch-norm layer.
    """
    # A bias directly in front of batch norm is redundant, so it is disabled.
    with_bias = not use_bn

    if transpose:
        conv = nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad, output_padding=pad, bias=with_bias)
    else:
        conv = nn.Conv2d(c_in, c_out, k_size, stride, pad, bias=with_bias)

    layers = [conv, nn.BatchNorm2d(c_out)] if use_bn else [conv]
    return nn.Sequential(*layers)


class ResBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers joined by an additive skip connection.

    The skip connection adds the activated output of the first convolution,
    not the raw block input, to the output of the second convolution.

    Args:
        channels: Number of input and output channels.
    """

    def __init__(self, channels):
        super(ResBlock, self).__init__()
        self.conv1 = conv_block(channels, channels, k_size=3, stride=1, pad=1, use_bn=True)
        self.conv2 = conv_block(channels, channels, k_size=3, stride=1, pad=1, use_bn=True)

    # Implemented as forward() rather than __call__() so that calls go through
    # nn.Module.__call__, which is what runs registered hooks.
    def forward(self, x):
        """Return ``relu(conv1(x)) + conv2(relu(conv1(x)))``."""
        x = F.relu(self.conv1(x))
        return x + self.conv2(x)


class Discriminator(nn.Module):
    """Convolutional critic that averages its per-patch scores per image.

    Args:
        channels: Number of input image channels.
        conv_dim: Channel count of the first convolution; the next two
            convolutions use two and four times this value.
    """

    def __init__(self, channels=3, conv_dim=64):
        super().__init__()
        self.conv1 = conv_block(channels, conv_dim, use_bn=False)
        self.conv2 = conv_block(conv_dim, 2 * conv_dim)
        self.conv3 = conv_block(2 * conv_dim, 4 * conv_dim)
        # Final layer keeps the spatial size and emits one score per location.
        self.conv4 = conv_block(4 * conv_dim, 1, k_size=3, stride=1, pad=1, use_bn=False)

        # Weight initialization: N(0, 0.02) for convolutions, N(1, 0.02) for
        # batch-norm scales, zero for batch-norm biases.
        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(layer.weight, 0.0, 0.02)
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.normal_(layer.weight.data, 1.0, 0.02)
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Return one score per input image (mean over all patch scores)."""
        negative_slope = 0.2
        for stage in (self.conv1, self.conv2, self.conv3):
            x = F.leaky_relu(stage(x), negative_slope)
        patch_scores = self.conv4(x)
        batch = patch_scores.shape[0]
        return patch_scores.reshape([batch, -1]).mean(1)


class Generator(nn.Module):
    """Encoder / residual block / decoder image-to-image generator.

    Two stride-2 convolutions downsample the input, one residual block works
    at the lowest resolution, and two stride-2 transposed convolutions
    upsample again. The output passes through ``tanh``.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the generated image.
        conv_dim: Base channel width; deeper stages use 2x and 4x this value.
    """

    def __init__(self, in_channels=3, out_channels=3, conv_dim=64):
        super().__init__()
        base, double, quad = conv_dim, conv_dim * 2, conv_dim * 4

        # Encoder
        self.conv1 = conv_block(in_channels, base, k_size=5, stride=1, pad=2, use_bn=True)
        self.conv2 = conv_block(base, double, k_size=3, stride=2, pad=1, use_bn=True)
        self.conv3 = conv_block(double, quad, k_size=3, stride=2, pad=1, use_bn=True)
        # Bottleneck
        self.res4 = ResBlock(quad)
        # Decoder
        self.tconv5 = conv_block(quad, double, k_size=3, stride=2, pad=1, use_bn=True, transpose=True)
        self.tconv6 = conv_block(double, base, k_size=3, stride=2, pad=1, use_bn=True, transpose=True)
        self.conv7 = conv_block(base, out_channels, k_size=5, stride=1, pad=2, use_bn=False)

        # Weight initialization: N(0, 0.02) for (transposed) convolutions,
        # N(1, 0.02) for batch-norm scales, zero for batch-norm biases.
        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(layer.weight, 0.0, 0.02)
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.normal_(layer.weight.data, 1.0, 0.02)
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Translate a batch of images; the output lies in ``[-1, 1]``."""
        relu_stages = (self.conv1, self.conv2, self.conv3,
                       self.res4, self.tconv5, self.tconv6)
        for stage in relu_stages:
            x = F.relu(stage(x))
        return torch.tanh(self.conv7(x))

# Training Settings

In [None]:
# Serialized training sets, one file per digit domain.
MNIST_train_path = '/content/drive/My Drive/CycleGAN/4-digit dataset/MNIST_train.pt'
SYN_train_path = '/content/drive/My Drive/CycleGAN/4-digit dataset/SYN_train.pt'
USPS_train_path = '/content/drive/My Drive/CycleGAN/4-digit dataset/USPS_train.pt'

# Indexed by domain id (see the *_domain_id constants below).
ds_path = [MNIST_train_path, SYN_train_path, USPS_train_path]

MNIST_domain_id = 0
SYN_domain_id = 1
USPS_domain_id = 2

# Human-readable domain names, indexed by domain id.
DS_NAME = ["MNIST", "SYN", "USPS"]


EPOCHS = 50  # 50-300
N_CRITIC = 5
BATCH_SIZE = 128
IMGS_TO_DISPLAY = 32

IMAGE_SIZE = 32
NUM_DOMAINS = 2

GRADIENT_PENALTY = 10
CONV_DIM = 12

# Checkpoints and sample grids go to separate directories. (These were
# previously bound by one chained assignment, which pointed model_path at the
# samples directory.)
model_path = '/content/drive/My Drive/CycleGAN/model'
samples_path = '/content/drive/My Drive/CycleGAN/samples'
os.makedirs(model_path, exist_ok=True)
os.makedirs(samples_path, exist_ok=True)

# Training Loop

In [None]:
class EarlyStopping:
    """Stop training when the cycle-consistency loss stops improving.

    A call counts as an improvement only when the new loss is at most
    ``best_loss / epsilon``. Every improvement saves a checkpoint and resets
    the patience counter; ``patience`` consecutive non-improving calls set
    ``early_stop`` to True.

    Checkpoint file names are built from the module-level names
    ``model_path``, ``DS_NAME``, ``dm_a`` and ``dm_b``, which must be defined
    before the first call.
    """

    def __init__(self, patience=10, verbose=False, epsilon=1.001):
        """
        Args:
            patience (int): Number of consecutive non-improving calls tolerated
                            before ``early_stop`` is set.
                            Default: 10
            verbose (bool): If True, prints a message each time a checkpoint is saved.
                            Default: False
            epsilon (float): Relative improvement factor; a loss must not exceed
                            ``best_loss / epsilon`` to count as an improvement.
                            Default: 1.001
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.cycle_loss_min = np.inf
        self.epsilon = epsilon

    def __call__(self, cycle_loss, gen, dis):
        """Record one epoch's cycle loss and update the stopping state.

        Args:
            cycle_loss: Cycle-consistency loss of the finished epoch.
            gen: Pair ``(gen_ab, gen_ba)`` of generators to checkpoint.
            dis: Pair ``(dis_a, dis_b)`` of discriminators to checkpoint.
        """
        # Scores are negated losses, so a higher score is better.
        score = -cycle_loss

        # First call: nothing to compare against yet.
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(cycle_loss, gen, dis)
            return

        # Not improved by at least the relative margin epsilon.
        if score < self.best_score / self.epsilon:
            self.counter += 1
            print(f'\nCurrent cycle consistency loss {cycle_loss:.6f} > {self.cycle_loss_min:.6f}/{self.epsilon} = {self.cycle_loss_min/self.epsilon:.6f}')
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}\n')
            if self.counter >= self.patience:
                self.early_stop = True
            return

        self.best_score = score
        self.save_checkpoint(cycle_loss, gen, dis)
        self.counter = 0

    def save_checkpoint(self, cycle_loss, gen, dis):
        '''Save the state dicts of both generators and both discriminators.'''

        gen_ab, gen_ba = gen
        dis_a, dis_b = dis

        if self.verbose:
            print(f'\ncycle consistency loss decreased ({self.cycle_loss_min:.6f} --> {cycle_loss:.6f}).  Saving model ...\n')

        # model_path, DS_NAME, dm_a and dm_b are read from module scope.
        torch.save(gen_ab.state_dict(), os.path.join(model_path, f'gen_{DS_NAME[dm_a]}_{DS_NAME[dm_b]}.pkl'))
        torch.save(gen_ba.state_dict(), os.path.join(model_path, f'gen_{DS_NAME[dm_b]}_{DS_NAME[dm_a]}.pkl'))

        torch.save(dis_a.state_dict(), os.path.join(model_path, f'dis_{DS_NAME[dm_a]}.pkl'))
        torch.save(dis_b.state_dict(), os.path.join(model_path, f'dis_{DS_NAME[dm_b]}.pkl'))

        self.cycle_loss_min = cycle_loss

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm

# ab_combinations: each row is [src_domain, tgt_domain]
ab_combinations = np.array([[0, 1],
                            [0, 2],
                            [1, 2]])
# Example domain name list, e.g. ["MNIST", "SYN", "USPS"]
# DS_NAME = [...]
# ds_path = [...]  # training dataset paths for each domain

# Hyperparameters & other settings (adjust as needed).
# These rebind the values of the same names from the settings cell.
CONV_DIM = 64
BATCH_SIZE = 64
EPOCHS = 50
IMAGE_SIZE = 28
IMGS_TO_DISPLAY = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# dm_a / dm_b are deliberately module-level loop variables:
# EarlyStopping.save_checkpoint reads them to build checkpoint file names.
for dm_a, dm_b in ab_combinations:
    # 1) Define and initialize generators & discriminators
    gen_ab = Generator(in_channels=3, out_channels=3, conv_dim=CONV_DIM).to(device).train()
    gen_ba = Generator(in_channels=3, out_channels=3, conv_dim=CONV_DIM).to(device).train()

    dis_a = Discriminator(channels=3).to(device).train()
    dis_b = Discriminator(channels=3).to(device).train()

    # 2) Define optimizers for G & D (one optimizer per network pair)
    g_optim = optim.Adam(
        list(gen_ab.parameters()) + list(gen_ba.parameters()),
        lr=0.0001,
        betas=(0.5, 0.999)
    )
    d_optim = optim.Adam(
        list(dis_a.parameters()) + list(dis_b.parameters()),
        lr=0.0001,
        betas=(0.5, 0.999)
    )

    # 3) Data loader for domain A and B
    # (ImgDomainAdaptationData yields (a_real, a_label), (b_real, b_label))
    data = ImgDomainAdaptationData(ds_path[dm_a], ds_path[dm_b], dm_a, dm_b, IMAGE_SIZE, IMAGE_SIZE)
    ds_loader = torch.utils.data.DataLoader(
        data,
        batch_size=BATCH_SIZE,
        shuffle=True
    )
    iters_per_epoch = len(ds_loader)

    # 4) Prepare fixed samples for visualization: the same images are
    #    translated after every epoch so the grids are comparable.
    loader_iter = iter(ds_loader)
    img_fixed = next(loader_iter)
    a_fixed, b_fixed = img_fixed
    a_fixed, _ = a_fixed
    b_fixed, _ = b_fixed
    a_fixed = a_fixed[:IMGS_TO_DISPLAY].to(device)
    b_fixed = b_fixed[:IMGS_TO_DISPLAY].to(device)

    # 5) Create naming for this experiment (e.g. "SYN_USPS_conv64_batch64")
    current_setting = f'{DS_NAME[dm_a]}_{DS_NAME[dm_b]}_conv{CONV_DIM}_batch{BATCH_SIZE}'
    print(f'Current training setting: {current_setting}')

    # 6) Create output directories
    # NOTE(review): EarlyStopping.save_checkpoint writes into the global
    # model_path, so model_path_exp is created but currently stays empty.
    model_path_exp = os.path.join(model_path, current_setting)
    os.makedirs(model_path_exp, exist_ok=True)

    samples_path_exp = os.path.join(samples_path, current_setting)
    samples_path_ab = os.path.join(samples_path_exp, f'{DS_NAME[dm_a]}_to_{DS_NAME[dm_b]}')
    samples_path_ba = os.path.join(samples_path_exp, f'{DS_NAME[dm_b]}_to_{DS_NAME[dm_a]}')
    os.makedirs(samples_path_exp, exist_ok=True)
    os.makedirs(samples_path_ab, exist_ok=True)
    os.makedirs(samples_path_ba, exist_ok=True)

    # 7) EarlyStopping + lists to track training info
    early_stopping = EarlyStopping(patience=12, verbose=True)

    train_info = []
    # Per-epoch history of generator adversarial loss, cycle loss,
    # total generator loss and discriminator loss.
    g_adv_loss_per_ep = []
    g_cyc_loss_per_ep = []
    g_loss_per_ep = []
    d_loss_per_ep = []

    # 8) CycleGAN Training loop
    for epoch in range(EPOCHS):
        # Make sure both generators are in training mode: the end-of-epoch
        # sampling may leave them in eval mode, which would make BatchNorm use
        # its running statistics for every later epoch.
        gen_ab.train()
        gen_ba.train()

        g_adv_losses = []
        g_cyc_losses = []
        g_losses = []
        d_losses = []

        for batch_idx, batch_data in tqdm(enumerate(ds_loader), total=iters_per_epoch, desc=f'Epoch {epoch+1}'):
            a_data, b_data = batch_data
            a_real, _ = a_data
            b_real, _ = b_data
            a_real, b_real = a_real.to(device), b_real.to(device)

            # -----------------------------
            # (A) Discriminator Training
            # -----------------------------
            b_fake = gen_ab(a_real)
            a_fake = gen_ba(b_real)

            # Least-squares GAN objective: real -> 1, fake -> 0. The fakes are
            # detached so this step does not backpropagate into the generators.
            a_real_out = dis_a(a_real)
            a_fake_out = dis_a(a_fake.detach())
            d_a_loss = (torch.mean((a_real_out - 1) ** 2) + torch.mean(a_fake_out ** 2)) / 2

            b_real_out = dis_b(b_real)
            b_fake_out = dis_b(b_fake.detach())
            d_b_loss = (torch.mean((b_real_out - 1) ** 2) + torch.mean(b_fake_out ** 2)) / 2

            d_optim.zero_grad()
            d_loss = d_a_loss + d_b_loss
            d_loss.backward()
            d_optim.step()

            # -----------------------------
            # (B) Generator Training
            # -----------------------------
            # Re-score the fakes with the freshly updated discriminators.
            a_fake_out = dis_a(a_fake)
            b_fake_out = dis_b(b_fake)

            # Adversarial losses for G: fakes should be scored as real (1)
            g_a_adv_loss = torch.mean((a_fake_out - 1) ** 2)
            g_b_adv_loss = torch.mean((b_fake_out - 1) ** 2)
            g_adv_loss = g_a_adv_loss + g_b_adv_loss

            # Cycle consistency: A -> B -> A and B -> A -> B should return the
            # original image (L1 distance).
            a_recon = gen_ba(b_fake)
            b_recon = gen_ab(a_fake)
            g_a_cyc_loss = (a_real - a_recon).abs().mean()
            g_b_cyc_loss = (b_real - b_recon).abs().mean()
            g_cyc_loss = g_a_cyc_loss + g_b_cyc_loss

            g_optim.zero_grad()
            # The cycle term is weighted 10x relative to the adversarial term.
            g_loss = g_adv_loss + 10.0 * g_cyc_loss
            g_loss.backward()
            g_optim.step()

            # Collect losses
            g_adv_losses.append(g_adv_loss.item())
            g_cyc_losses.append(g_cyc_loss.item())
            g_losses.append(g_loss.item())
            d_losses.append(d_loss.item())

        # -----------------------------
        # End of epoch: Logging & Visualization
        # -----------------------------
        generate_imgs(
            a_fixed, b_fixed,
            gen_ab, gen_ba,
            (samples_path_ab, samples_path_ba),
            DS_NAME[dm_a], DS_NAME[dm_b],
            epoch+1
        )

        avg_g_adv_loss = np.mean(g_adv_losses)
        avg_g_cyc_loss = np.mean(g_cyc_losses)
        avg_g_loss = np.mean(g_losses)
        avg_d_loss = np.mean(d_losses)

        train_info.append([avg_g_adv_loss, avg_g_cyc_loss, avg_g_loss, avg_d_loss])

        print(f"\nEpoch [{epoch+1}/{EPOCHS}] - {current_setting}")
        print(f"Generator:\n  adv_loss: {avg_g_adv_loss:.6f}, cyc_loss: {avg_g_cyc_loss:.6f}, total: {avg_g_loss:.6f}")
        print("Discriminator:")
        print(f"  total_loss: {avg_d_loss:.6f}")
        print("==========================================")

        # Early stopping on cycle loss (or whichever metric you prefer)
        early_stopping(avg_g_cyc_loss, (gen_ab, gen_ba), (dis_a, dis_b))
        if early_stopping.early_stop:
            print(f"Early stopping triggered at epoch {epoch+1}.")
            break

    # Generate final images, tagged with epoch=-1 in the file names
    generate_imgs(
        a_fixed, b_fixed,
        gen_ab, gen_ba,
        (samples_path_ab, samples_path_ba),
        DS_NAME[dm_a], DS_NAME[dm_b],
        -1
    )

    # Save training info to CSV (relative to the current working directory)
    df = pd.DataFrame(train_info, columns=['Gen Adv Loss', 'Gen Cyc Loss', 'Gen Total Loss', 'Dis Total Loss'])
    train_info_path = os.path.join('.', 'train_info', current_setting)
    os.makedirs(train_info_path, exist_ok=True)
    df_path = os.path.join(train_info_path, 'train_info.csv')
    df.to_csv(df_path, index=True)
    print(f'Saved training info to {df_path}')

print("All training complete!")