<a href="https://colab.research.google.com/github/wongdongwook/JSAC_MA-DeepSC/blob/main/Evalutation_for_4digits.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torch.autograd import Function
from torchvision import transforms
import torchvision.utils as vutils
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import random
import time
import copy

In [None]:
from google.colab import drive

# Expose the datasets stored under "My Drive" to this Colab runtime.
drive.mount('/content/drive')

# Select the compute device. The rest of the notebook targets a CUDA GPU
# runtime; a Colab TPU runtime is not a CUDA device and is not supported.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using {str(device).upper()}")

Mounted at /content/drive
Using CUDA


# Utils

In [None]:
def generate_imgs(img, n_dms, gen, samples_path, device, step=0):
    """Translate a batch into every target domain, save the grid and show it.

    Each input image is repeated ``n_dms + 1`` times. The first copy in each
    group is kept as the untouched (real) image. The remaining ``n_dms`` copies
    are passed through ``gen`` with target-domain labels ``0 .. n_dms - 1``.
    The grid (one row per input image) is written to
    ``<samples_path>/sample_epoch_<step>.png`` and displayed with matplotlib.

    Args:
        img: batch of input images; dim 0 is the batch size ``m``
            (presumably ``(m, C, H, W)`` -- TODO confirm against callers).
        n_dms: number of target domains.
        gen: conditional generator, called as ``gen(images, domain_labels)``.
        samples_path: directory where the PNG file is written.
        device: device the generator lives on.
        step: index used in the output file name.

    Note:
        ``gen`` is switched to eval mode and is left in eval mode.
    """
    gen.eval()
    m = img.shape[0]

    # Per-row labels: [-1, 0, 1, ..., n_dms - 1]. The -1 slot is the
    # "real image" column; it is remapped to 0 below so the generator only
    # ever receives valid domain indices.
    lbl = torch.arange(start=-1, end=n_dms)
    lbl = lbl.expand(m, n_dms + 1).reshape([-1]).to(device)

    img_ = torch.repeat_interleave(img, n_dms + 1, dim=0).to(device)

    # Flat indices of the first slot of every row (the real-image slots).
    real_idx = torch.arange(start=0, end=m * (n_dms + 1), step=n_dms + 1)
    lbl[real_idx] = 0

    # Visualisation only: no autograd graph is needed.
    with torch.no_grad():
        display_imgs = gen(img_, lbl)
        # Put the original inputs back into the real-image slots. Move them to
        # the output's device/dtype first so the indexed copy cannot fail on a
        # CPU/GPU mismatch.
        display_imgs[real_idx] = img.to(display_imgs.device, display_imgs.dtype)

    display_imgs_ = vutils.make_grid(
        display_imgs, normalize=True, nrow=n_dms + 1, padding=2, pad_value=1)
    vutils.save_image(display_imgs_, os.path.join(
        samples_path, f'sample_epoch_{step}.png'))

    np_image = display_imgs_.cpu().detach().numpy()
    np_image = np.transpose(np_image, (1, 2, 0))  # CHW -> HWC for imshow
    plt.figure(figsize=(8, 8))  # adjust the figure size as needed
    plt.imshow(np_image)
    plt.axis('off')  # hide axes for a cleaner look
    plt.show()

# Data Loader

In [None]:
# Paths to the per-domain digit datasets (one .pt file per domain and split),
# all stored in the same directory on the mounted Google Drive.
_DATASET_DIR = '/content/drive/My Drive/ADJSCC-V_withDA_4digit_classification/dataset'

MNIST_train_path = f'{_DATASET_DIR}/MNIST_train.pt'
MNISTM_train_path = f'{_DATASET_DIR}/MNISTM_train.pt'
SYN_train_path = f'{_DATASET_DIR}/SYN_train.pt'
USPS_train_path = f'{_DATASET_DIR}/USPS_train.pt'

MNIST_test_path = f'{_DATASET_DIR}/MNIST_test.pt'
MNISTM_test_path = f'{_DATASET_DIR}/MNISTM_test.pt'
SYN_test_path = f'{_DATASET_DIR}/SYN_test.pt'
USPS_test_path = f'{_DATASET_DIR}/USPS_test.pt'

# Both lists use the same domain order: MNIST, MNIST-M, SYN, USPS.
train_path = [MNIST_train_path, MNISTM_train_path, SYN_train_path, USPS_train_path]
test_path = [MNIST_test_path, MNISTM_test_path, SYN_test_path, USPS_test_path]


class ImgData(torch.utils.data.Dataset):
    """In-memory dataset for a single image domain.

    The file at ``path`` is read with ``torch.load``. Element ``[0]`` holds the
    images and element ``[1]`` the labels. Images are resized to ``[w, h]``,
    normalised with mean 0.5 / std 0.5 (maps [0, 1] to [-1, 1] if the stored
    pixels are in [0, 1] -- assumption, not checked here). Channel-less batches
    are expanded to three identical channels.
    """

    def __init__(self, path, w, h):
        self.transform = transforms.Compose([
            transforms.Resize([w, h]),
            transforms.Normalize([0.5], [0.5]),
        ])
        self.data = torch.load(path)

        raw_images, labels = self.data[0], self.data[1]
        self.img = self.pre_processing(self.transform(raw_images))
        self.label = labels
        self.len = labels.shape[0]

    def pre_processing(self, img):
        """Give a batch with fewer than 4 dims three identical channels (dim 1)."""
        if img.dim() >= 4:
            return img
        return img.unsqueeze(1).repeat(1, 3, 1, 1)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at ``index``."""
        return self.img[index], self.label[index]


class AllDomainExceptOne(torch.utils.data.Dataset):
    """Concatenation of several single-domain datasets (e.g. all domains but one).

    Every file in ``path_arr`` is read with ``torch.load``. Element ``[0]``
    holds the images and element ``[1]`` the labels. Images are resized to
    ``[w, h]``, normalised with mean/std 0.5 and expanded to three channels
    when needed. Images are stacked along dim 0 with ``torch.vstack`` and
    labels with ``torch.hstack``, in the order of ``path_arr``.
    """

    def __init__(self, path_arr, w, h):
        self.transform = transforms.Compose([
            transforms.Resize([w, h]),
            transforms.Normalize([0.5], [0.5]),
        ])

        self.data = [torch.load(p) for p in path_arr]
        self.img = torch.vstack(
            [self.pre_processing(self.transform(entry[0])) for entry in self.data])
        self.label = torch.hstack([entry[1] for entry in self.data])
        self.len = self.label.shape[0]

    def pre_processing(self, img):
        """Give a batch with fewer than 4 dims three identical channels (dim 1)."""
        if img.dim() >= 4:
            return img
        return img.unsqueeze(1).repeat(1, 3, 1, 1)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at ``index``."""
        return self.img[index], self.label[index]


class AllDomainData(torch.utils.data.Dataset):
    """Union of the MNIST, MNIST-M, SYN and USPS digit datasets.

    Args:
        w, h: target size passed to ``transforms.Resize([w, h])``.
        is_training: load the ``*_train`` files when truthy, else ``*_test``.
        DA_task: when truthy, ``__getitem__`` also returns an integer domain
            label (0 = MNIST, 1 = MNIST-M, 2 = SYN, 3 = USPS).

    Per-domain results stay available as attributes: ``<NAME>_data``,
    ``<NAME>_img``, ``<NAME>_label`` and, with ``DA_task``, ``<NAME>_domain``,
    for NAME in MNIST, MNISTM, SYN, USPS. ``self.img``, ``self.label`` and
    (with ``DA_task``) ``self.domain`` concatenate them in that order.
    """

    def __init__(self, w, h, is_training=True, DA_task=False):
        self.transform = transforms.Compose([transforms.Resize([w, h]),
                                             transforms.Normalize([0.5], [0.5])])
        self.DA_task = DA_task

        # (attribute prefix, train file, test file); the position in this
        # tuple is the domain label used when DA_task is set.
        sources = (
            ('MNIST', MNIST_train_path, MNIST_test_path),
            ('MNISTM', MNISTM_train_path, MNISTM_test_path),
            ('SYN', SYN_train_path, SYN_test_path),
            ('USPS', USPS_train_path, USPS_test_path),
        )

        all_imgs, all_labels, all_domains = [], [], []
        for domain_id, (name, train_file, test_file) in enumerate(sources):
            raw = torch.load(train_file if is_training else test_file)
            images = self.transform(raw[0])
            labels = raw[1]

            if DA_task:
                images, domain_labels = self.pre_processing(images, domain_id)
                setattr(self, f'{name}_domain', domain_labels)
                all_domains.append(domain_labels)
            else:
                images = self.pre_processing(images)

            setattr(self, f'{name}_data', raw)
            setattr(self, f'{name}_img', images)
            setattr(self, f'{name}_label', labels)
            all_imgs.append(images)
            all_labels.append(labels)

        self.img = torch.vstack(all_imgs)
        self.label = torch.hstack(all_labels)
        if DA_task:
            self.domain = np.hstack(all_domains)
        self.len = self.label.shape[0]

    def pre_processing(self, img, domain=None):
        """Expand channel-less batches to 3 channels; optionally build domain labels.

        Returns ``img`` alone when ``domain`` is None, otherwise
        ``(img, labels)`` where ``labels`` is an int numpy array filled with
        ``domain`` (one entry per image).
        """
        if len(img.shape) < 4:
            img = img.unsqueeze(1).repeat(1, 3, 1, 1)

        if domain is None:
            return img
        return img, domain + np.zeros(img.shape[0], dtype=int)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        """Return ``(image, label)`` or ``(image, label, domain)`` with DA_task."""
        item = (self.img[index], self.label[index])
        if self.DA_task:
            item = item + (self.domain[index],)
        return item

# Data Adaptation Model

## CycleGAN based DA Module

In [None]:
def CycleGAN_conv_block(c_in, c_out, k_size=4, stride=2, pad=1, use_bn=True, transpose=False,
                        output_padding=None):
    """Build a (transposed) 2-D convolution, optionally followed by BatchNorm2d.

    Args:
        c_in, c_out: input / output channel counts.
        k_size, stride, pad: convolution kernel size, stride and padding.
        use_bn: append ``nn.BatchNorm2d``; the conv then has no bias.
        transpose: build ``nn.ConvTranspose2d`` instead of ``nn.Conv2d``.
        output_padding: extra rows/cols added to a transposed conv's output.
            Defaults to ``pad`` (the historical behaviour of this helper).
            PyTorch requires it to be smaller than ``stride``, so pass it
            explicitly when ``pad >= stride``. Ignored for regular convs.

    Returns:
        ``nn.Sequential`` with the conv layer (and the BatchNorm layer).
    """
    if output_padding is None:
        output_padding = pad

    module = []
    if transpose:
        module.append(nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad,
                                         output_padding=output_padding, bias=not use_bn))
    else:
        module.append(nn.Conv2d(c_in, c_out, k_size, stride, pad, bias=not use_bn))
    if use_bn:
        module.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*module)


class CycleGAN_ResBlock(nn.Module):
    """Two 3x3 conv+BatchNorm layers (stride 1) with an additive skip.

    NOTE(review): the skip adds the *activated output of conv1*, i.e.
    ``h + conv2(h)`` with ``h = relu(conv1(x))``, not the block input ``x``.
    This is kept unchanged because it defines the model's outputs, and
    presumably those of any trained weights.
    """

    def __init__(self, channels):
        super(CycleGAN_ResBlock, self).__init__()
        self.conv1 = CycleGAN_conv_block(channels, channels, k_size=3, stride=1, pad=1, use_bn=True)
        self.conv2 = CycleGAN_conv_block(channels, channels, k_size=3, stride=1, pad=1, use_bn=True)

    def forward(self, x):
        # Implemented as ``forward`` (not an override of ``__call__``) so that
        # nn.Module machinery such as forward hooks still runs.
        h = F.relu(self.conv1(x))
        return h + self.conv2(h)

class CycleGAN_Generator(nn.Module):
    """CycleGAN-style image-to-image generator.

    conv1 (5x5, stride 1) -> conv2, conv3 (3x3, stride 2) -> one residual
    block -> tconv5, tconv6 (3x3 transposed, stride 2) -> conv7 (5x5) -> tanh.
    Every layer before conv7 is followed by ReLU. Conv and transposed-conv
    weights are drawn from N(0, 0.02); BatchNorm weights from N(1, 0.02) with
    zero bias.
    """

    def __init__(self, in_channels=3, out_channels=3, conv_dim=64):
        super(CycleGAN_Generator, self).__init__()
        d = conv_dim
        self.conv1 = CycleGAN_conv_block(in_channels, d, k_size=5, stride=1, pad=2, use_bn=True)
        self.conv2 = CycleGAN_conv_block(d, d * 2, k_size=3, stride=2, pad=1, use_bn=True)
        self.conv3 = CycleGAN_conv_block(d * 2, d * 4, k_size=3, stride=2, pad=1, use_bn=True)
        self.res4 = CycleGAN_ResBlock(d * 4)
        self.tconv5 = CycleGAN_conv_block(d * 4, d * 2, k_size=3, stride=2, pad=1, use_bn=True, transpose=True)
        self.tconv6 = CycleGAN_conv_block(d * 2, d, k_size=3, stride=2, pad=1, use_bn=True, transpose=True)
        self.conv7 = CycleGAN_conv_block(d, out_channels, k_size=5, stride=1, pad=2, use_bn=False)
        self._init_weights()

    def _init_weights(self):
        """Re-initialise conv / transposed-conv / BatchNorm parameters in place."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(module.weight, 0.0, 0.02)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.normal_(module.weight, 1.0, 0.02)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        hidden_layers = (self.conv1, self.conv2, self.conv3,
                         self.res4, self.tconv5, self.tconv6)
        for layer in hidden_layers:
            x = F.relu(layer(x))
        return torch.tanh(self.conv7(x))

## StarGAN based DA Module

In [None]:
def starGAN_conv_block(c_in, c_out, k_size=4, stride=2, pad=1, use_bn=True, transpose=False):
    """Build a (transposed) 2-D convolution, optionally followed by BatchNorm2d.

    The convolution has a bias only when ``use_bn`` is False. Returns an
    ``nn.Sequential`` holding the conv layer and, if requested, the
    BatchNorm layer.
    """
    layer_cls = nn.ConvTranspose2d if transpose else nn.Conv2d
    layers = [layer_cls(c_in, c_out, k_size, stride, pad, bias=not use_bn)]
    if use_bn:
        layers.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*layers)

class starGAN_ResBlock(nn.Module):
    """Two 3x3 conv+BatchNorm layers (stride 1) with an additive skip.

    NOTE(review): the skip adds the *activated output of conv1*, i.e.
    ``h + conv2(h)`` with ``h = relu(conv1(x))``, not the block input ``x``.
    This is kept unchanged because it defines the model's outputs, and
    presumably those of any trained weights.
    """

    def __init__(self, channels):
        super(starGAN_ResBlock, self).__init__()
        self.conv1 = starGAN_conv_block(channels, channels, k_size=3, stride=1, pad=1, use_bn=True)
        self.conv2 = starGAN_conv_block(channels, channels, k_size=3, stride=1, pad=1, use_bn=True)

    def forward(self, x):
        # Implemented as ``forward`` (not an override of ``__call__``) so that
        # nn.Module machinery such as forward hooks still runs.
        h = F.relu(self.conv1(x))
        return h + self.conv2(h)

# Modified starGAN_Discriminator
class starGAN_Discriminator(nn.Module):
    """StarGAN-style discriminator with an extra digit-class head.

    Four 4x4 stride-2 convs (LeakyReLU, slope 0.01) halve the spatial size
    four times. Three heads then read the final feature map:

    * ``gan``: 3x3 conv -> 1-channel real/fake map,
    * ``cls``: domain logits (``num_domains`` channels),
    * ``label_cls``: class logits (``num_classes`` channels).

    The ``cls`` / ``label_cls`` kernels span ``image_size // 16`` pixels, so
    for an ``image_size`` x ``image_size`` input they yield a 1x1 map, which is
    flattened to ``(batch, channels)``.
    """

    def __init__(self, channels=3, num_domains=5, num_classes=10, image_size=32, conv_dim=64):
        super(starGAN_Discriminator, self).__init__()
        self.conv1 = starGAN_conv_block(channels, conv_dim, use_bn=False)
        self.conv2 = starGAN_conv_block(conv_dim, conv_dim * 2, use_bn=False)
        self.conv3 = starGAN_conv_block(conv_dim * 2, conv_dim * 4, use_bn=False)
        self.conv4 = starGAN_conv_block(conv_dim * 4, conv_dim * 8, use_bn=False)

        self.gan = starGAN_conv_block(conv_dim * 8, 1, k_size=3, stride=1, pad=1, use_bn=False)
        self.cls = starGAN_conv_block(conv_dim * 8, num_domains, k_size=image_size//16, stride=1, pad=0, use_bn=False)
        self.label_cls = starGAN_conv_block(conv_dim * 8, num_classes, k_size=image_size//16, stride=1, pad=0, use_bn=False)

        # Initialization: conv weights ~ N(0, 0.02); BatchNorm weights ~ N(1, 0.02), bias 0.
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return ``(gan_out, domain_logits, class_logits)``."""
        alpha = 0.01
        x = F.leaky_relu(self.conv1(x), alpha)
        x = F.leaky_relu(self.conv2(x), alpha)
        x = F.leaky_relu(self.conv3(x), alpha)
        x = F.leaky_relu(self.conv4(x), alpha)
        gan_out = self.gan(x)
        cls_out = self.cls(x)
        label_cls_out = self.label_cls(x)

        # Drop only the trailing 1x1 spatial dims. A bare ``squeeze()`` would
        # also drop the batch dim for a batch of one image.
        return gan_out, cls_out.squeeze(-1).squeeze(-1), label_cls_out.squeeze(-1).squeeze(-1)

class starGAN_Generator(nn.Module):  # original architecture (conv_dim=64)
    """StarGAN-style conditional generator.

    The target domain index is embedded into an ``image_size x image_size``
    map and concatenated to the input as an extra channel. The result goes
    through a conv encoder (two stride-2 downsamplings), three residual
    blocks, and a transposed-conv decoder that ends in a 5x5 conv + tanh.
    The input's spatial size must equal ``image_size`` for the concatenation
    to work.
    """

    def __init__(self, in_channels=3, num_domains=5, image_size=32, out_channels=3, conv_dim=64):
        super(starGAN_Generator, self).__init__()
        self.image_size = image_size
        self.embed_layer = nn.Embedding(num_domains, image_size**2)

        # Channel counts below are for the default conv_dim=64.
        self.conv1 = starGAN_conv_block(in_channels+1, conv_dim, k_size=5, stride=1, pad=2, use_bn=True)  # 64
        self.conv2 = starGAN_conv_block(conv_dim, conv_dim * 2, k_size=4, stride=2, pad=1, use_bn=True)  # 64*2 = 128
        self.conv3 = starGAN_conv_block(conv_dim * 2, conv_dim * 4, k_size=4, stride=2, pad=1, use_bn=True)  # 64*4 = 256
        self.res4 = starGAN_ResBlock(conv_dim * 4)
        self.res5 = starGAN_ResBlock(conv_dim * 4)
        self.res6 = starGAN_ResBlock(conv_dim * 4)
        self.tconv7 = starGAN_conv_block(conv_dim * 4, conv_dim * 2, k_size=4, stride=2, pad=1, use_bn=True, transpose=True)
        self.tconv8 = starGAN_conv_block(conv_dim * 2, conv_dim, k_size=4, stride=2, pad=1, use_bn=True, transpose=True)
        self.conv9 = starGAN_conv_block(conv_dim, out_channels, k_size=5, stride=1, pad=2, use_bn=False)

        # Initialization: conv weights ~ N(0, 0.02); BatchNorm weights ~ N(1, 0.02), bias 0.
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, target_dm=None):
        """Translate ``x`` to domain ``target_dm`` (one index per sample; default 1)."""
        if target_dm is None:
            target_dm = torch.ones(x.shape[0], dtype=torch.long)
        # Put the indices on the embedding's device. Previously the default
        # indices were always created on the CPU and failed for a GPU model.
        target_dm = target_dm.to(self.embed_layer.weight.device).long()
        embed = self.embed_layer(target_dm).reshape([-1, 1, self.image_size, self.image_size])
        x = torch.cat((x, embed), dim=1)
        # Shapes for the defaults (image_size=32, conv_dim=64):
        # [3, 32, 32] + [1, 32, 32] = [4, 32, 32]
        x = F.relu(self.conv1(x))  # [64, 32, 32]
        x = F.relu(self.conv2(x))  # [128, 16, 16]
        x = F.relu(self.conv3(x))  # [256, 8, 8]
        x = F.relu(self.res4(x))  # [256, 8, 8]
        x = F.relu(self.res5(x))  # [256, 8, 8]
        x = F.relu(self.res6(x))  # [256, 8, 8]
        x = F.relu(self.tconv7(x))  # [128, 16, 16]
        x = F.relu(self.tconv8(x))  # [64, 32, 32]
        x = torch.tanh(self.conv9(x))  # [3, 32, 32]
        return x

# MASCN Transceiver

## DeepJSCC-V Model

In [None]:
class LowerBound(Function):
    """Element-wise ``max(inputs, bound)`` whose gradient can leave the bound.

    Forward clamps ``inputs`` from below at ``bound``. Backward lets the
    gradient through where the input was already at or above the bound. It
    also lets it through wherever the incoming gradient is negative: a
    gradient-descent step then increases the value, so an entry pinned at the
    bound can still move upwards. No gradient is returned for ``bound``.
    """

    @staticmethod
    def forward(ctx, inputs, bound):
        floor = torch.ones_like(inputs) * bound
        ctx.save_for_backward(inputs, floor)
        return torch.max(inputs, floor)

    @staticmethod
    def backward(ctx, grad_output):
        inputs, floor = ctx.saved_tensors
        keep = torch.logical_or(inputs >= floor, grad_output < 0)
        return grad_output * keep.to(grad_output.dtype), None

class GDN(nn.Module):
    """Generalized divisive normalization (GDN) layer.

    At every spatial position:

        y[i] = x[i] / sqrt(beta[i] + sum_j gamma[i, j] * x[j]**2)

    With ``inverse=True`` (IGDN) the square root multiplies instead of
    divides. ``gamma[i, j]`` is the weight of input channel ``j`` for output
    channel ``i``, matching the ``conv2d`` weight layout used in ``forward``.

    ``beta`` and ``gamma`` are stored reparameterised: square roots with a
    small pedestal ``reparam_offset**2`` added. They are lower-bounded with
    ``LowerBound``, so after reparameterisation ``beta >= beta_min`` and
    ``gamma >= 0``.

    Accepts 4-D ``(B, C, H, W)`` input. 5-D input ``(B, C, D, W, H)`` is
    folded to 4-D for the computation and reshaped back afterwards.
    """

    def __init__(self,
                 ch,
                 inverse=False,
                 beta_min=1e-6,
                 gamma_init=0.1,
                 reparam_offset=2**-18):
        super(GDN, self).__init__()
        self.inverse = inverse
        self.beta_min = beta_min
        self.gamma_init = gamma_init
        self.reparam_offset = reparam_offset

        self.build(ch)

    def build(self, ch):
        """Create the reparameterised ``beta`` (ones) and ``gamma`` (gamma_init * I)."""
        self.pedestal = self.reparam_offset**2
        self.beta_bound = ((self.beta_min + self.reparam_offset**2)**0.5)
        self.gamma_bound = self.reparam_offset

        # Create beta param: stored as sqrt(1 + pedestal) so the effective beta starts at 1.
        beta = torch.sqrt(torch.ones(ch)+self.pedestal)
        self.beta = nn.Parameter(beta)

        # Create gamma param: stored as sqrt(gamma_init * I + pedestal).
        eye = torch.eye(ch)
        g = self.gamma_init*eye
        g = g + self.pedestal
        gamma = torch.sqrt(g)

        self.gamma = nn.Parameter(gamma)

    def forward(self, inputs):
        unfold = False
        if inputs.dim() == 5:
            unfold = True
            bs, ch, d, w, h = inputs.size()
            inputs = inputs.view(bs, ch, d*w, h)

        _, ch, _, _ = inputs.size()

        # Beta bound and reparam
        beta = LowerBound.apply(self.beta, self.beta_bound)
        beta = beta**2 - self.pedestal

        # Gamma bound and reparam
        gamma = LowerBound.apply(self.gamma, self.gamma_bound)
        gamma = gamma**2 - self.pedestal
        gamma = gamma.view(ch, ch, 1, 1)

        # Norm pool: 1x1 conv over squared inputs, beta acting as the bias.
        norm_ = nn.functional.conv2d(inputs**2, gamma, beta)
        norm_ = torch.sqrt(norm_)

        # Apply norm
        if self.inverse:
            outputs = inputs * norm_
        else:
            outputs = inputs / norm_

        if unfold:
            outputs = outputs.view(bs, ch, d, w, h)
        return outputs

def conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
    """Return a bias-free ``nn.Conv2d`` with the given geometry."""
    geometry = dict(kernel_size=kernel_size, stride=stride, padding=padding)
    return nn.Conv2d(in_channels, out_channels, bias=False, **geometry)

def deconv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, output_padding = 0):
    """Return a bias-free ``nn.ConvTranspose2d`` with the given geometry."""
    geometry = dict(kernel_size=kernel_size, stride=stride,
                    padding=padding, output_padding=output_padding)
    return nn.ConvTranspose2d(in_channels, out_channels, bias=False, **geometry)


class conv_block(nn.Module):
    """Bias-free Conv2d -> GDN -> PReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(conv_block, self).__init__()
        self.conv = conv(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        # ``torch.nn`` has no GDN layer (``nn.GDN`` raised AttributeError);
        # use the GDN class defined in this file.
        self.gdn = GDN(out_channels)
        self.prelu = nn.PReLU()

    def forward(self, x):
        out = self.conv(x)
        out = self.gdn(out)
        out = self.prelu(out)
        return out

class deconv_block(nn.Module):
    """Bias-free ConvTranspose2d -> GDN -> activation.

    ``forward(x, activate_func)`` applies PReLU for ``'prelu'`` (default),
    Sigmoid for ``'sigmoid'``, and no activation for any other value.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, output_padding = 0):
        super(deconv_block, self).__init__()
        self.deconv = deconv(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,  output_padding = output_padding)
        # ``torch.nn`` has no GDN layer (``nn.GDN`` raised AttributeError);
        # use the GDN class defined in this file.
        self.gdn = GDN(out_channels)
        self.prelu = nn.PReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, activate_func='prelu'):
        out = self.deconv(x)
        out = self.gdn(out)
        if activate_func=='prelu':
            out = self.prelu(out)
        elif activate_func=='sigmoid':
            out = self.sigmoid(out)
        return out

class conv_ResBlock(nn.Module):
    """Residual unit: conv -> GDN -> PReLU -> 1x1 conv -> GDN, plus a skip path.

    When ``use_conv1x1`` is True, the skip path is a 1x1 conv with the same
    stride (to match channels / resolution). Otherwise the input is added
    unchanged. The same PReLU module is applied again after the sum.
    """

    def __init__(self, in_channels, out_channels, use_conv1x1=False, kernel_size=3, stride=1, padding=1):
        super(conv_ResBlock, self).__init__()
        self.conv1 = conv(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv2 = conv(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.gdn1 = GDN(out_channels)
        self.gdn2 = GDN(out_channels)
        self.prelu = nn.PReLU()
        self.use_conv1x1 = use_conv1x1
        projected_skip = use_conv1x1 == True
        if projected_skip:
            self.conv3 = conv(in_channels, out_channels, kernel_size=1, stride=stride, padding=0)

    def forward(self, x):
        main = self.gdn2(self.conv2(self.prelu(self.gdn1(self.conv1(x)))))
        shortcut = self.conv3(x) if self.use_conv1x1 == True else x
        return self.prelu(main + shortcut)

class deconv_ResBlock(nn.Module):
    """Transposed-conv residual unit, mirror of ``conv_ResBlock``.

    Main path: deconv -> GDN -> PReLU -> 1x1 deconv -> GDN. The skip path is a
    strided 1x1 deconv when ``use_deconv1x1`` is True, else the identity.
    After the sum, ``activate_func`` selects PReLU (``'prelu'``), Sigmoid
    (``'sigmoid'``) or no activation (any other value).
    """

    def __init__(self, in_channels, out_channels, use_deconv1x1=False, kernel_size=3, stride=1, padding=1, output_padding=0):
        super(deconv_ResBlock, self).__init__()
        self.deconv1 = deconv(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding)
        self.deconv2 = deconv(out_channels, out_channels, kernel_size=1, stride=1, padding=0, output_padding=0)
        self.gdn1 = GDN(out_channels)
        self.gdn2 = GDN(out_channels)
        self.prelu = nn.PReLU()
        self.sigmoid = nn.Sigmoid()
        self.use_deconv1x1 = use_deconv1x1
        projected_skip = use_deconv1x1 == True
        if projected_skip:
            self.deconv3 = deconv(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, output_padding=output_padding)

    def forward(self, x, activate_func='prelu'):
        main = self.gdn2(self.deconv2(self.prelu(self.gdn1(self.deconv1(x)))))
        shortcut = self.deconv3(x) if self.use_deconv1x1 == True else x
        merged = main + shortcut
        if activate_func == 'prelu':
            return self.prelu(merged)
        if activate_func == 'sigmoid':
            return self.sigmoid(merged)
        return merged

# Original Existing Works
class AF_block(nn.Module):
    """SNR-conditioned channel attention ("attention feature" block).

    The per-channel spatial mean of ``x`` (dims 2 and 3) is concatenated with
    the SNR and passed through FC(Nin+1 -> Nh) -> ReLU -> FC(Nh -> No) ->
    sigmoid. The resulting weights rescale ``x`` channel-wise, so ``No``
    should match the channel count of ``x``.
    """

    def __init__(self, Nin, Nh, No):
        super(AF_block, self).__init__()
        self.fc1 = nn.Linear(Nin+1, Nh)
        self.fc2 = nn.Linear(Nh, No)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, snr):
        """Scale ``x`` by SNR-dependent per-channel weights.

        Args:
            x: feature map; dims 2 and 3 are averaged (assumed (B, C, H, W)).
            snr: a scalar (Python number or tensor) applied to every sample,
                or one value per sample with shape ``(B,)`` or ``(B, 1)``.
        """
        mu = torch.mean(x, (2, 3))  # (B, C)
        # Normalise the SNR to a (B, 1) column with mu's dtype/device.
        # Previously a (1, 1) SNR became 3-D and broke ``cat``, a scalar was not
        # broadcast over the batch, and a float64 SNR broke the Linear layers.
        snr = torch.as_tensor(snr, dtype=mu.dtype, device=mu.device).reshape(-1, 1)
        if snr.shape[0] == 1 and mu.shape[0] != 1:
            snr = snr.expand(mu.shape[0], 1)
        out = torch.cat((mu, snr), 1)
        out = self.fc1(out)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.sigmoid(out)
        out = out.unsqueeze(2).unsqueeze(3)  # (B, No, 1, 1) for broadcasting
        return out * x

# The Encoder model with attention feature blocks
class Encoder(nn.Module):
    """ADJSCC encoder: five residual conv stages, each followed by an AF block.

    Stages 1-2 use stride 2 and stages 3-5 use stride 1. The last stage maps
    to ``enc_shape[0]`` channels. Every stage output is rescaled by an
    SNR-conditioned ``AF_block``, and the final map is flattened per sample.
    """

    def __init__(self, enc_shape, kernel_sz, Nc_conv):
        super(Encoder, self).__init__()
        latent_ch = enc_shape[0]
        af_hidden = Nc_conv // 2
        pad = (kernel_sz - 1) // 2  # "same"-style padding for odd kernel sizes
        # Input channel count is 3 (RGB); change it to 1 for single-channel datasets.
        self.conv1 = conv_ResBlock(3, Nc_conv, use_conv1x1=True, kernel_size=kernel_sz, stride=2, padding=pad)
        self.conv2 = conv_ResBlock(Nc_conv, Nc_conv, use_conv1x1=True, kernel_size=kernel_sz, stride=2, padding=pad)
        self.conv3 = conv_ResBlock(Nc_conv, Nc_conv, kernel_size=kernel_sz, stride=1, padding=pad)
        self.conv4 = conv_ResBlock(Nc_conv, Nc_conv, kernel_size=kernel_sz, stride=1, padding=pad)
        self.conv5 = conv_ResBlock(Nc_conv, latent_ch, use_conv1x1=True, kernel_size=kernel_sz, stride=1, padding=pad)
        self.AF1 = AF_block(Nc_conv, af_hidden, Nc_conv)
        self.AF2 = AF_block(Nc_conv, af_hidden, Nc_conv)
        self.AF3 = AF_block(Nc_conv, af_hidden, Nc_conv)
        self.AF4 = AF_block(Nc_conv, af_hidden, Nc_conv)
        self.AF5 = AF_block(latent_ch, latent_ch // 2, latent_ch)
        self.flatten = nn.Flatten()

    def forward(self, x, snr):
        stages = (
            (self.conv1, self.AF1),
            (self.conv2, self.AF2),
            (self.conv3, self.AF3),
            (self.conv4, self.AF4),
            (self.conv5, self.AF5),
        )
        out = x
        for conv_stage, attention in stages:
            out = attention(conv_stage(out), snr)
        return self.flatten(out)

# The Decoder model with attention feature blocks
class Decoder(nn.Module):
    """ADJSCC decoder: AF block before each of five residual deconv stages.

    The flat input is reshaped to ``(-1, *enc_shape[:3])``. Stages 1-2 use
    stride 2 with ``output_padding=1``; stages 3-5 keep the resolution. The
    last stage outputs 3 channels through a sigmoid.
    """

    def __init__(self, enc_shape, kernel_sz, Nc_deconv):
        super(Decoder, self).__init__()
        self.enc_shape = enc_shape
        latent_ch = self.enc_shape[0]
        af_hidden_first = latent_ch // 2
        af_hidden = Nc_deconv // 2
        pad = (kernel_sz - 1) // 2
        self.deconv1 = deconv_ResBlock(latent_ch, Nc_deconv, use_deconv1x1=True, kernel_size=kernel_sz, stride=2, padding=pad, output_padding=1)
        self.deconv2 = deconv_ResBlock(Nc_deconv, Nc_deconv, use_deconv1x1=True, kernel_size=kernel_sz, stride=2, padding=pad, output_padding=1)
        self.deconv3 = deconv_ResBlock(Nc_deconv, Nc_deconv, kernel_size=kernel_sz, stride=1, padding=pad)
        self.deconv4 = deconv_ResBlock(Nc_deconv, Nc_deconv, kernel_size=kernel_sz, stride=1, padding=pad)
        self.deconv5 = deconv_ResBlock(Nc_deconv, 3, use_deconv1x1=True, kernel_size=kernel_sz, stride=1, padding=pad)

        self.AF1 = AF_block(latent_ch, af_hidden_first, latent_ch)
        self.AF2 = AF_block(Nc_deconv, af_hidden, Nc_deconv)
        self.AF3 = AF_block(Nc_deconv, af_hidden, Nc_deconv)
        self.AF4 = AF_block(Nc_deconv, af_hidden, Nc_deconv)
        self.AF5 = AF_block(Nc_deconv, af_hidden, Nc_deconv)

    def forward(self, x, snr):
        channels, height, width = self.enc_shape[0], self.enc_shape[1], self.enc_shape[2]
        out = x.view(-1, channels, height, width)
        stages = (
            (self.AF1, self.deconv1),
            (self.AF2, self.deconv2),
            (self.AF3, self.deconv3),
            (self.AF4, self.deconv4),
        )
        for attention, stage in stages:
            out = stage(attention(out, snr))
        return self.deconv5(self.AF5(out, snr), 'sigmoid')

# Power normalization before transmission.
# Note: with P = 1 the power per complex symbol (pair of real entries) is 2;
# set P = 1/np.sqrt(2) if you want an average symbol power of 1.
def Power_norm(z, P = 1):
    """Scale each row of ``z`` (shape (batch, z_dim)) to energy ``P * z_dim``.

    Every row is divided by its L2 norm and multiplied by ``sqrt(P * z_dim)``.
    """
    _, dim = z.shape
    gain = np.sqrt(P * dim)
    row_norm = torch.sqrt(torch.sum(z ** 2, 1)).unsqueeze(1)
    return gain * z / row_norm

def Power_norm_complex(z, P = 1):
    """Power-normalise ``z`` viewed as complex symbols.

    Even columns are real parts and odd columns imaginary parts. Each row is
    scaled so its symbols carry total energy ``P * z_dim``. The output has the
    same interleaved layout, dtype and device as ``z``.

    Args:
        z: real tensor of shape (batch, z_dim) with even ``z_dim``.
        P: power factor.
    """
    batch_size, z_dim = z.shape
    z_com = torch.complex(z[:, 0:z_dim:2], z[:, 1:z_dim:2])
    # Sum of |s|^2 over the symbols of each row.
    z_power = (z_com * z_com.conj()).real.sum(1)
    z_nlz = np.sqrt(P * z_dim) * z_com / torch.sqrt(z_power).unsqueeze(1)
    # Allocate on z's device/dtype instead of forcing CUDA.
    z_out = torch.zeros(batch_size, z_dim, device=z.device, dtype=z.dtype)
    z_out[:, 0:z_dim:2] = z_nlz.real
    z_out[:, 1:z_dim:2] = z_nlz.imag
    return z_out

# The (real) AWGN channel
def AWGN_channel(x, snr, P = 2):
    """Add real white Gaussian noise to ``x`` at the given SNR (in dB).

    The noise standard deviation is ``sqrt(P / 10**(snr / 10))``.

    Args:
        x: transmitted signal, shape (batch, length).
        snr: SNR in dB. Either a scalar (Python number or tensor) or one value
            per sample, shape (batch,) or (batch, 1).
        P: signal power used to derive the noise variance.
    """
    batch_size, length = x.shape
    snr = torch.as_tensor(snr, dtype=x.dtype, device=x.device)
    if snr.dim() == 1 and snr.numel() == batch_size:
        snr = snr.unsqueeze(1)  # one SNR per row, broadcast along the signal
    gamma = 10 ** (snr / 10.0)
    noise = torch.sqrt(P / gamma) * torch.randn(batch_size, length, device=x.device, dtype=x.dtype)
    y = x + noise
    return y

def AWGN_complex(x, snr, Ps = 1):
    """Add circularly-symmetric complex Gaussian noise to ``x`` at SNR ``snr`` (dB).

    The real and imaginary noise parts each have variance ``Ps / 10**(snr/10)``.

    Args:
        x: signal of shape (batch, length), complex or real.
        snr: SNR in dB, as a scalar or one value per row ((batch,) or (batch, 1)).
        Ps: signal power used to derive the noise variance.
    """
    batch_size, length = x.shape
    real_dtype = x.real.dtype if x.is_complex() else x.dtype
    snr = torch.as_tensor(snr, dtype=real_dtype, device=x.device)
    if snr.dim() == 1 and snr.numel() == batch_size:
        snr = snr.unsqueeze(1)
    gamma = 10 ** (snr / 10.0)
    std = torch.sqrt(Ps / gamma)
    # First draw -> real part, second draw -> imaginary part.
    n_re = std * torch.randn(batch_size, length, device=x.device, dtype=real_dtype)
    n_im = std * torch.randn(batch_size, length, device=x.device, dtype=real_dtype)
    noise = torch.complex(n_re, n_im)
    y = x + noise
    return y

# Please set the symbol power if it is not a default value
def Fading_channel(x, snr, P = 2):
    """Per-symbol complex Gaussian fading with zero-forcing by the true channel.

    Column pairs ``(2k, 2k+1)`` of ``x`` form complex symbols ``s``. The
    channel computes ``(h * s + n) / h`` with ``h`` drawn per symbol (real and
    imaginary parts ~ N(0, 1)) and complex noise ``n`` whose parts have
    variance ``P / 10**(snr/10)``. The result is returned in the same
    interleaved real layout, dtype and device as ``x``.

    Args:
        x: real tensor of shape (batch, feature_length) with even length.
        snr: SNR in dB, as a scalar or one value per row ((batch,) or (batch, 1)).
        P: symbol power used to derive the noise variance.
    """
    batch_size, feature_length = x.shape
    K = feature_length // 2
    snr = torch.as_tensor(snr, dtype=x.dtype, device=x.device)
    if snr.dim() == 1 and snr.numel() == batch_size:
        snr = snr.unsqueeze(1)
    gamma = 10 ** (snr / 10.0)

    h_I = torch.randn(batch_size, K, device=x.device, dtype=x.dtype)
    h_R = torch.randn(batch_size, K, device=x.device, dtype=x.dtype)
    h_com = torch.complex(h_I, h_R)
    x_com = torch.complex(x[:, 0:feature_length:2], x[:, 1:feature_length:2])
    y_com = h_com * x_com

    noise_std = torch.sqrt(P / gamma)
    n_I = noise_std * torch.randn(batch_size, K, device=x.device, dtype=x.dtype)
    n_R = noise_std * torch.randn(batch_size, K, device=x.device, dtype=x.dtype)
    noise = torch.complex(n_I, n_R)

    y_add = y_com + noise
    y = y_add / h_com  # zero-forcing equalisation with the true channel

    y_out = torch.zeros(batch_size, feature_length, device=x.device, dtype=x.dtype)
    y_out[:, 0:feature_length:2] = y.real
    y_out[:, 1:feature_length:2] = y.imag
    return y_out

# Note: if P = 1, the symbol power is 2
# If you want to set the average power as 1, please change P as P=1/np.sqrt(2)
def Power_norm_VLC(z, cr, P = 1):
    """Rate-dependent power normalisation for variable-length transmission.

    Row ``i`` of ``z`` is scaled to total energy ``Kv[i] * P``, where
    ``Kv = ceil(z_dim * cr)`` is the number of entries its rate ``cr``
    allows.

    Args:
        z: tensor of shape (batch, z_dim).
        cr: compression ratio(s). A scalar, or one per row with shape
            (batch,) or (batch, 1).
        P: power per entry.
    """
    batch_size, z_dim = z.shape
    cr = torch.as_tensor(cr, device=z.device)
    # One budget per row as a column, so it broadcasts along the features
    # (a 1-D per-sample cr previously broadcast along the wrong axis).
    Kv = torch.ceil(z_dim * cr).int().reshape(-1, 1)
    z_power = torch.sqrt(torch.sum(z ** 2, 1)).unsqueeze(1)
    return torch.sqrt(Kv * P) * z / z_power

def AWGN_channel_VLC(x, snr, cr, P = 2):
    """AWGN channel where noise only hits the entries kept by ``mask_gen(length, cr)``.

    Args:
        x: signal of shape (batch, length).
        snr: SNR in dB, as a scalar or one value per row ((batch,) or (batch, 1)).
        cr: per-sample compression ratios forwarded to ``mask_gen``.
        P: signal power used to derive the noise variance.
    """
    batch_size, length = x.shape
    snr = torch.as_tensor(snr, dtype=x.dtype, device=x.device)
    if snr.dim() == 1 and snr.numel() == batch_size:
        snr = snr.unsqueeze(1)
    gamma = 10 ** (snr / 10.0)
    mask = mask_gen(length, cr).to(x.device)
    # Independent noise for every sample. Previously a single (1, length) draw
    # was shared by the whole batch, unlike AWGN_channel / Fading_channel_VLC.
    noise = torch.sqrt(P / gamma) * torch.randn(batch_size, length, device=x.device, dtype=x.dtype)
    noise = noise * mask
    y = x + noise
    return y

def Fading_channel_VLC(x, snr, cr, P = 2):
    """Per-symbol complex fading (as in ``Fading_channel``) with rate masking.

    Column pairs of ``x`` form complex symbols. Noise is applied only to the
    symbols kept by ``mask_gen(K, cr)``. The received symbols are divided by
    the true channel and returned interleaved as real values on ``x``'s
    device/dtype.

    Args:
        x: real tensor of shape (batch, feature_length) with even length.
        snr: SNR in dB, as a scalar or one value per row ((batch,) or (batch, 1)).
        cr: per-sample compression ratios forwarded to ``mask_gen``.
        P: symbol power used to derive the noise variance.
    """
    batch_size, feature_length = x.shape
    K = feature_length // 2
    snr = torch.as_tensor(snr, dtype=x.dtype, device=x.device)
    if snr.dim() == 1 and snr.numel() == batch_size:
        snr = snr.unsqueeze(1)
    gamma = 10 ** (snr / 10.0)

    mask = mask_gen(K, cr).to(x.device)
    h_I = torch.randn(batch_size, K, device=x.device, dtype=x.dtype)
    h_R = torch.randn(batch_size, K, device=x.device, dtype=x.dtype)
    h_com = torch.complex(h_I, h_R)
    x_com = torch.complex(x[:, 0:feature_length:2], x[:, 1:feature_length:2])
    y_com = h_com * x_com

    noise_std = torch.sqrt(P / gamma)
    n_I = noise_std * torch.randn(batch_size, K, device=x.device, dtype=x.dtype)
    n_R = noise_std * torch.randn(batch_size, K, device=x.device, dtype=x.dtype)
    noise = torch.complex(n_I, n_R) * mask

    y_add = y_com + noise
    y = y_add / h_com  # zero-forcing equalisation with the true channel

    y_out = torch.zeros(batch_size, feature_length, device=x.device, dtype=x.dtype)
    y_out[:, 0:feature_length:2] = y.real
    y_out[:, 1:feature_length:2] = y.imag
    return y_out

def Channel(z, snr, channel_type = 'AWGN'):
    """Power-normalise ``z``, then pass it through the selected channel.

    ``channel_type`` is ``'AWGN'`` or ``'Fading'``. Any other value returns
    the power-normalised signal without adding noise.
    """
    normalized = Power_norm(z)
    if channel_type == 'AWGN':
        return AWGN_channel(normalized, snr)
    if channel_type == 'Fading':
        return Fading_channel(normalized, snr)
    return normalized

def Channel_VLC(z, snr, cr, channel_type = 'AWGN'):
    """Variable-length channel: rate-dependent normalisation, then AWGN or fading.

    ``channel_type`` is ``'AWGN'`` or ``'Fading'``. Any other value returns
    the normalised signal without adding noise.
    """
    normalized = Power_norm_VLC(z, cr)
    if channel_type == 'AWGN':
        return AWGN_channel_VLC(normalized, snr, cr)
    if channel_type == 'Fading':
        return Fading_channel_VLC(normalized, snr, cr)
    return normalized

def mask_gen(N, cr, ch_max = 48):
    """Build an int32 mask of shape (len(cr), N) that keeps a prefix of each row.

    The N positions are split into groups of ``N // ch_max`` entries. Row
    ``i`` keeps the first ``round(ch_max * cr[i])`` groups as ones; the rest
    are zeros.
    """
    group_size = N // ch_max
    mask = torch.zeros(cr.shape[0], N).int()
    for row, ratio in enumerate(cr):
        kept = group_size * torch.round(ch_max * ratio).int()
        mask[row, 0:kept] = 1
    return mask

class ADJSCC(nn.Module):
    """Attention-based deep JSCC at a fixed rate: encoder -> channel -> decoder.

    The encoder and decoder are both conditioned on ``snr``, and the channel
    is selected by ``channel_type`` (see ``Channel``).
    """

    def __init__(self, enc_shape, Kernel_sz, Nc):
        super(ADJSCC, self).__init__()
        self.encoder = Encoder(enc_shape, Kernel_sz, Nc)
        self.decoder = Decoder(enc_shape, Kernel_sz, Nc)

    def forward(self, x, snr, channel_type = 'AWGN'):
        received = Channel(self.encoder(x, snr), snr, channel_type)
        return self.decoder(received, snr)

# The DeepJSCC_V model, also called ADJSCC_V
class ADJSCC_V(nn.Module):
    """Variable-rate ADJSCC (DeepJSCC-V).

    The encoded features are truncated per sample with
    ``mask_gen(num_features, cr)``, sent through ``Channel_VLC`` and decoded.
    Both the encoder and the decoder are conditioned on ``snr``.
    """

    def __init__(self, enc_shape, Kernel_sz, Nc):
        super(ADJSCC_V, self).__init__()
        self.encoder = Encoder(enc_shape, Kernel_sz, Nc)
        self.decoder = Decoder(enc_shape, Kernel_sz, Nc)

    def forward(self, x, snr, cr, channel_type = 'AWGN'):
        z = self.encoder(x, snr)
        # Zero the features beyond each sample's rate budget. The mask is built
        # on z's device rather than forced to CUDA.
        z = z * mask_gen(z.shape[1], cr).to(z.device)
        z = Channel_VLC(z, snr, cr, channel_type)
        out = self.decoder(z, snr)
        return out

## MASCN Transceiver

In [None]:
class Generator(nn.Module): # MASCN_generative DeepSC
    """MASCN generative semantic transceiver: encoder -> ``Channel`` -> decoder.

    The target domain is embedded as an extra ``image_size x image_size`` input
    plane, so one network translates into any of ``num_domains`` domains while
    transmitting over the channel. The bottleneck has ``int(conv_dim * 4 * CR)``
    channels, so ``CR`` scales the number of values passed through ``Channel``.

    Args:
        in_channels: channels of the input image.
        num_domains: size of the target-domain embedding table.
        image_size: input height/width. conv2 and conv3 use stride 2, so the
            bottleneck grid is ``image_size // 4`` on each side.
        out_channels: channels of the output image.
        conv_dim: base number of convolution filters.
        CR: compression ratio scaling the bottleneck width.
    """
    def __init__(self, in_channels=3, num_domains=4, image_size=32, out_channels=3, conv_dim=64, CR=1):
        super(Generator, self).__init__()
        self.image_size = image_size
        # One learned image_size x image_size plane per domain, concatenated to the input.
        self.embed_layer = nn.Embedding(num_domains, image_size**2)
        self.CR = CR
        conv_dim_btl = int(conv_dim * 4 * CR)  # bottleneck channels sent over the channel
        self.conv_dim = conv_dim

        # Encoder
        self.conv1 = conv_block(in_channels + 1, conv_dim, k_size=5, stride=1, pad=2, use_bn=True)
        self.conv2 = conv_block(conv_dim, conv_dim * 2, k_size=4, stride=2, pad=1, use_bn=True)
        # conv_dim * 4 channels (256 for conv_dim=64); the decoder below mirrors this width.
        self.conv3 = conv_block(conv_dim * 2, conv_dim * 4, k_size=4, stride=2, pad=1, use_bn=True)
        self.res4 = ResBlock(conv_dim * 4)
        self.AF4 = AF_block(conv_dim * 4, conv_dim * 4 // 2, conv_dim * 4)
        self.res5 = ResBlock(conv_dim * 4)
        self.AF5 = AF_block(conv_dim * 4, conv_dim * 4 // 2, conv_dim * 4)
        self.conv5 = conv_block(conv_dim * 4, conv_dim_btl, k_size=3, stride=1, pad=1, use_bn=True)

        # Decoder
        self.conv6 = conv_block(conv_dim_btl, conv_dim * 4, k_size=3, stride=1, pad=1, use_bn=True)
        self.AF6 = AF_block(conv_dim * 4, conv_dim * 4 // 2, conv_dim * 4)
        self.res6 = ResBlock(conv_dim * 4)
        self.AF7 = AF_block(conv_dim * 4, conv_dim * 4 // 2, conv_dim * 4)
        self.res7 = ResBlock(conv_dim * 4)
        self.tconv7 = conv_block(conv_dim * 4, conv_dim * 2, k_size=4, stride=2, pad=1, use_bn=True, transpose=True)
        self.tconv8 = conv_block(conv_dim * 2, conv_dim, k_size=4, stride=2, pad=1, use_bn=True, transpose=True)
        self.conv9 = conv_block(conv_dim, out_channels, k_size=5, stride=1, pad=2, use_bn=False)
        self.flatten = nn.Flatten()

        # Initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, channel_type, snr,  target_dm=None):
        """Translate ``x`` to ``target_dm`` while transmitting it over ``Channel``.

        Args:
            x: image batch with ``in_channels`` channels and ``image_size`` sides.
            channel_type: forwarded to ``Channel`` (``'AWGN'`` or ``'Fading'``).
            snr: SNR, used by the AF blocks and by ``Channel``.
            target_dm: per-sample target-domain indices; defaults to domain 1
                for every sample.

        Returns:
            tanh-bounded output image (values in [-1, 1]).
        """
        if target_dm is None:
            # Allocate on x's device so the embedding lookup also works when the
            # model lives on the GPU (a CPU tensor caused a device mismatch).
            target_dm = torch.ones(x.shape[0], device=x.device)
        target_dm = target_dm.long()
        embed = self.embed_layer(target_dm).reshape([-1, 1, self.image_size, self.image_size])
        x = torch.cat((x, embed), dim=1)

        # Encoder
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))  # stride 2
        x = F.relu(self.conv3(x))  # stride 2
        x = F.relu(self.res4(x))
        x = F.relu(self.AF4(x, snr))
        x = F.relu(self.res5(x))
        x = F.relu(self.AF5(x, snr))
        x = F.relu(self.conv5(x))
        x = self.flatten(x)

        # Channel. Rate is controlled by the bottleneck width (conv_dim * 4 * CR)
        # rather than by a binary mask, which the authors noted did not work well.
        x = Channel(x, snr, channel_type)

        # Restore the bottleneck feature map. The grid side is image_size // 4
        # (previously hard-coded to 8, which only matched image_size=32).
        out_size = int(self.conv_dim * 4 * self.CR)
        latent_hw = self.image_size // 4
        x = x.view(-1, out_size, latent_hw, latent_hw)

        # Decoder
        x = F.relu(self.conv6(x))
        x = F.relu(self.AF6(x, snr))
        x = F.relu(self.res6(x))
        x = F.relu(self.AF7(x, snr))
        x = F.relu(self.res7(x))
        x = F.relu(self.tconv7(x))  # transposed conv, stride 2
        x = F.relu(self.tconv8(x))  # transposed conv, stride 2
        x = torch.tanh(self.conv9(x))
        return x


class Discriminator(nn.Module):
    """Discriminator with an adversarial head and two classification heads.

    Four ``conv_block`` layers (using conv_block's default kernel/stride) extract
    features. ``gan`` yields a 1-channel map, ``cls`` gives ``num_domains`` domain
    logits and ``label_cls`` gives ``num_classes`` class logits.
    """
    def __init__(self, channels=3, num_domains=4, num_classes=10, image_size=32, conv_dim=64):
        super(Discriminator, self).__init__()
        self.conv1 = conv_block(channels, conv_dim, use_bn=False)
        self.conv2 = conv_block(conv_dim, conv_dim * 2, use_bn=False)
        self.conv3 = conv_block(conv_dim * 2, conv_dim * 4, use_bn=False)
        self.conv4 = conv_block(conv_dim * 4, conv_dim * 8, use_bn=False)

        self.gan = conv_block(conv_dim * 8, 1, k_size=3, stride=1, pad=1, use_bn=False)
        # Kernel spans the remaining feature map (image_size // 16) to produce one logit vector.
        self.cls = conv_block(conv_dim * 8, num_domains, k_size=image_size//16, stride=1, pad=0, use_bn=False)
        self.label_cls = conv_block(conv_dim * 8, num_classes, k_size=image_size//16, stride=1, pad=0, use_bn=False)

        # Initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return ``(gan_out, cls_out, label_cls_out)``.

        Only the trailing size-1 spatial dims of the classification outputs are
        removed (``squeeze(3).squeeze(2)``). A bare ``squeeze()`` also dropped the
        batch dimension when the batch contained a single image.
        """
        alpha = 0.01
        x = F.leaky_relu(self.conv1(x), alpha)
        x = F.leaky_relu(self.conv2(x), alpha)
        x = F.leaky_relu(self.conv3(x), alpha)
        x = F.leaky_relu(self.conv4(x), alpha)
        gan_out = self.gan(x)
        cls_out = self.cls(x)
        label_cls_out = self.label_cls(x)

        return gan_out, cls_out.squeeze(3).squeeze(2), label_cls_out.squeeze(3).squeeze(2)

# Evaluation

## General Settings

In [None]:
# Project folder on the mounted Google Drive and the model sub-folders
root = '/content/drive/My Drive/ADJSCC-V_withDA_4digit_classification'
model_path = os.path.join(root, 'model')
classifier_path = os.path.join(model_path, 'classifier')

# Domain names; list position doubles as the domain id used by the generators
DS_NAME = ["MNIST", "MNISTM", "SYN", "USPS"]

# Data / visualisation
BATCH_SIZE = 128
IMGS_TO_DISPLAY = 25
IMAGE_SIZE = 32

# Channel and ADJSCC-V architecture
CHANNEL = 'AWGN'  # Choose AWGN or Fading
N_CHANNELS = 256
KERNEL_SIZE = 5

# Encoder output: 48 feature maps on an (IMAGE_SIZE // 4) x (IMAGE_SIZE // 4) grid
enc_out_shape = [48] + [IMAGE_SIZE // 4] * 2

## Upper Bound (DeepJSCC-V retrained on the target domain)


In [None]:
# --- Upper Bound Setting ---
# DeepJSCC-V retrained with full supervision on the target domain (e.g., SYN).
# This simulates the optimal case where domain-specific knowledge is fully available.
# Used to define the upper bound for performance comparison.
#
# Each domain's test images go through that domain's own ADJSCC-V model for every
# (SNR, CR) pair. Inference runs under torch.no_grad(). Each checkpoint is loaded only
# once, because it does not depend on SNR or CR.

import csv
import os
import torch
import numpy as np
from tqdm import tqdm

# Evaluation settings
CR_test_range = np.arange(0.1, 1, 0.1)  # 0.1, 0.2, ..., 0.9
SNR_test_range = np.array([3, 10, 18])

dataset = []
ds_loader = []

# Load the test split of each domain: <root>/dataset/<name>_test.pt
for domain_id in range(len(DS_NAME)):
    dataset.append(
        ImgData(os.path.join(root, 'dataset', f'{DS_NAME[domain_id]}_test.pt'),
                IMAGE_SIZE, IMAGE_SIZE)
    )
    ds_loader.append(torch.utils.data.DataLoader(dataset[domain_id],
                                                 batch_size=BATCH_SIZE,
                                                 shuffle=False))

# domain id -> ADJSCC-V model, loaded on first use
ADJSCCV_models = {}

total_iter = 0
with torch.no_grad():  # evaluation only: no autograd graph needed
    for SNR_test in SNR_test_range:
        for CR_test in CR_test_range:
            for domain_id in range(len(DS_NAME)):

                current_setting = f'upper bound eval with SNR = {SNR_test} and CR = {CR_test} for {DS_NAME[domain_id]}'
                print(f'\nCurrent eval setting: {current_setting}')
                print("\n=================================================================")

                # ADJSCC-V model trained on this domain in the preparation stage
                ADJSCCV = ADJSCCV_models.get(domain_id)
                if ADJSCCV is None:
                    ADJSCCV_path = os.path.join(model_path, 'ADJSCCV', f'{DS_NAME[domain_id]}')
                    ADJSCCV_name = f'JSCC-V_{DS_NAME[domain_id]}.pt'
                    ADJSCCV = ADJSCC_V(enc_out_shape, KERNEL_SIZE, N_CHANNELS)
                    # map_location lets GPU-saved checkpoints load on a CPU runtime too
                    ADJSCCV.load_state_dict(torch.load(os.path.join(ADJSCCV_path, ADJSCCV_name),
                                                       map_location=device))
                    ADJSCCV.to(device).eval()
                    ADJSCCV_models[domain_id] = ADJSCCV

                reconstructed_images = []

                for src_img, target_label in tqdm(ds_loader[domain_id], total=len(ds_loader[domain_id])):
                    total_iter += 1

                    src_img = src_img.to(device)

                    SNR = (SNR_test * torch.ones(src_img.shape[0], 1)).to(device)
                    CR = (CR_test * torch.ones(src_img.shape[0], 1)).to(device)

                    # Reconstruct image using semantic encoder-decoder
                    recon_img = (src_img + 1) / 2  # [-1,1] -> [0,1]
                    recon_img = ADJSCCV(recon_img, SNR, CR, CHANNEL)  # [0,1]
                    recon_img = (recon_img - 0.5) / 0.5  # back to [-1,1]

                    reconstructed_images.append(recon_img.cpu())

                # Optional: merge all tensors for saving or post-processing
                reconstructed_images = torch.cat(reconstructed_images, dim=0)

                # Optional: save the first 25 images as a grid (disabled)
                # torchvision.utils.save_image(reconstructed_images[:25], f'recon_{DS_NAME[domain_id]}_snr{SNR_test}_cr{CR_test}.png', nrow=5, normalize=True)

print('All done! (Classifier removed version)')


## Lower Bound (DeepJSCC-V without retraining)


In [None]:
# --- Lower Bound Setting ---
# DeepJSCC-V trained only on the source domain (e.g., MNIST).
# No domain adaptation is applied to the target domain (e.g., SYN).
# This represents the lower bound performance baseline.

import csv
import os
import torch
import numpy as np
from tqdm import tqdm

# === Lower Bound Evaluation ===
# For each domain d, the ADJSCC-V model trained on d receives images from every
# *other* domain (AllDomainExceptOne over train_path without entry d). No domain
# adaptation is applied in between. Only the evaluated settings are logged.

CR_test_range = np.arange(0.1, 1, 0.1)
SNR_test_range = np.array([3, 10, 18])

# CSV log of evaluated settings. Create its folder if needed, and open with
# newline='' as recommended by the csv module.
lower_bound_dir = os.path.join(root, 'eval_info', 'Lower_bound')
os.makedirs(lower_bound_dir, exist_ok=True)
lower_bound_csv = os.path.join(lower_bound_dir, 'lower_bound.csv')
with open(lower_bound_csv, 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['Domain', 'SNR', 'CR'])

dataset = []
ds_loader = []

# Dataset for domain d = images from all domains except d
for domain_id in range(len(train_path)):
    subset = [path for i, path in enumerate(train_path) if i != domain_id]
    dataset.append(AllDomainExceptOne(subset, IMAGE_SIZE, IMAGE_SIZE))
    ds_loader.append(torch.utils.data.DataLoader(dataset[domain_id],
                                                 batch_size=BATCH_SIZE,
                                                 shuffle=False))

# domain id -> ADJSCC-V model, loaded on first use (independent of SNR / CR)
ADJSCCV_models = {}

total_iter = 0
with torch.no_grad():  # evaluation only: no autograd graph needed
    for SNR_test in SNR_test_range:
        for CR_test in CR_test_range:
            for domain_id in range(len(DS_NAME)):

                current_setting = f'lower bound eval with SNR = {SNR_test} and CR = {CR_test} for {DS_NAME[domain_id]}'
                print(f'\nCurrent eval setting: {current_setting}')
                print("\n=================================================================")

                # ADJSCC-V model trained on this domain in the preparation stage
                ADJSCCV = ADJSCCV_models.get(domain_id)
                if ADJSCCV is None:
                    ADJSCCV_path = os.path.join(model_path, 'ADJSCCV', f'{DS_NAME[domain_id]}')
                    ADJSCCV_name = f'JSCC-V_{DS_NAME[domain_id]}.pt'
                    ADJSCCV = ADJSCC_V(enc_out_shape, KERNEL_SIZE, N_CHANNELS)
                    ADJSCCV.load_state_dict(torch.load(os.path.join(ADJSCCV_path, ADJSCCV_name),
                                                       map_location=device))
                    ADJSCCV.to(device).eval()
                    ADJSCCV_models[domain_id] = ADJSCCV

                reconstructed_images = []

                for src_img, target_label in tqdm(ds_loader[domain_id], total=len(ds_loader[domain_id])):
                    total_iter += 1
                    src_img = src_img.to(device)

                    SNR = (SNR_test * torch.ones(src_img.shape[0], 1)).to(device)
                    CR = (CR_test * torch.ones(src_img.shape[0], 1)).to(device)

                    # Image reconstruction via the semantic transceiver
                    recon_img = (src_img + 1) / 2  # [-1,1] -> [0,1]
                    recon_img = ADJSCCV(recon_img, SNR, CR, CHANNEL)  # [0,1]
                    recon_img = (recon_img - 0.5) / 0.5  # [0,1] -> [-1,1]

                    reconstructed_images.append(recon_img.cpu())

                reconstructed_images = torch.cat(reconstructed_images, dim=0)

                # Optional: save images
                # torchvision.utils.save_image(reconstructed_images[:25], f'recon_LB_{DS_NAME[domain_id]}_snr{SNR_test}_cr{CR_test}.png', nrow=5, normalize=True)

                # Log the evaluation configuration (no accuracy logged); appended per
                # setting so progress is visible in the file while the loop runs.
                with open(lower_bound_csv, 'a', newline='') as f:
                    writer = csv.writer(f)
                    writer.writerow([f'{DS_NAME[domain_id]}',
                                     SNR_test,
                                     CR_test])

print('All done! (Lower Bound Classifier Removed)')



## MDAN Model

In [None]:
import csv
import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm

# === Path Configuration ===
starGAN_path = os.path.join(model_path, 'starGAN_w_clf')
samples_path = os.path.join(root, 'samples', 'starGAN_w_clf')
os.makedirs(samples_path, exist_ok=True)

# === Evaluation Settings ===
CR_test_range = np.arange(0.1, 1, 0.1)
SNR_test_range = np.array([3, 10, 18])

# The StarGAN generator and the evaluation data depend on neither SNR, CR nor the
# target domain. Build them once instead of once per setting.
starGAN_name = 'gen_100.pkl'
starGAN = starGAN_Generator(num_domains=len(DS_NAME), image_size=IMAGE_SIZE)
starGAN.load_state_dict(torch.load(os.path.join(starGAN_path, starGAN_name), map_location=device))
starGAN.to(device).eval()

# Evaluation dataset (without domain labels)
dataset = AllDomainData(IMAGE_SIZE, IMAGE_SIZE, is_training=False, DA_task=False)
ds_loader = torch.utils.data.DataLoader(dataset,
                                        batch_size=BATCH_SIZE,
                                        shuffle=True)

# target domain id -> ADJSCC-V model, loaded on first use
ADJSCCV_models = {}

total_iter = 0
with torch.no_grad():  # evaluation only: no autograd graph needed
    for SNR_test in SNR_test_range:
        for CR_test in CR_test_range:
            for domain_id in range(len(DS_NAME)):

                current_setting = f'starGAN eval with SNR = {SNR_test} and CR = {CR_test} for {DS_NAME[domain_id]}'
                print(f'\nCurrent eval setting: {current_setting}')
                print("\n=================================================================")

                # ADJSCC-V for semantic communication in the target domain
                ADJSCCV = ADJSCCV_models.get(domain_id)
                if ADJSCCV is None:
                    ADJSCCV_path = os.path.join(model_path, 'ADJSCCV', f'{DS_NAME[domain_id]}')
                    ADJSCCV_name = f'JSCC-V_{DS_NAME[domain_id]}.pt'
                    ADJSCCV = ADJSCC_V(enc_out_shape, KERNEL_SIZE, N_CHANNELS)
                    ADJSCCV.load_state_dict(torch.load(os.path.join(ADJSCCV_path, ADJSCCV_name),
                                                       map_location=device))
                    ADJSCCV.to(device).eval()
                    ADJSCCV_models[domain_id] = ADJSCCV

                for src_img, _ in tqdm(ds_loader, total=len(ds_loader)):
                    total_iter += 1

                    src_img = src_img.to(device)
                    target_domain = torch.full((src_img.shape[0],), domain_id, dtype=torch.long, device=device)

                    # StarGAN domain adaptation
                    DA_img = starGAN(src_img, target_domain)  # [-1, 1]

                    # Semantic transmission (ADJSCC-V). SNR/CR are float32 tensors, as in
                    # the bound evaluations (torch.full with a numpy int would give int64).
                    SNR = torch.full((DA_img.shape[0], 1), float(SNR_test), device=device)
                    CR = torch.full((DA_img.shape[0], 1), float(CR_test), device=device)

                    DA_img_rec = (DA_img + 1) / 2   # [-1,1] -> [0,1]
                    DA_img_rec = ADJSCCV(DA_img_rec, SNR, CR, CHANNEL)  # [0,1]
                    DA_img_rec = (DA_img_rec - 0.5) / 0.5  # [0,1] -> [-1,1]

                    # Save visualization every 100 steps
                    if total_iter % 100 == 0:
                        generate_imgs(src_img[:IMGS_TO_DISPLAY], len(DS_NAME), starGAN, samples_path, device, total_iter)

print('StarGAN evaluation complete (classifier removed).')


## CycleGAN based DA Model

In [None]:
import os
import torch
import numpy as np
from tqdm import tqdm

# === Path Configuration ===
cycleGAN_path = os.path.join(model_path, 'cycleGAN')

# === Evaluation Settings ===
CR_test_range = np.arange(0.1, 1, 0.1)
SNR_test_range = np.array([3, 10, 18])

# Checkpoints and test sets do not depend on SNR or CR: load each once, reuse after.
ADJSCCV_models = {}   # dst domain id -> ADJSCC-V model
cycleGAN_models = {}  # (src domain id, dst domain id) -> CycleGAN generator
src_loaders = {}      # src domain id -> test DataLoader

with torch.no_grad():  # evaluation only: no autograd graph needed
    for SNR_test in SNR_test_range:
        for CR_test in CR_test_range:
            for dst_domain in range(len(DS_NAME)):

                current_setting = f'cycleGAN eval with SNR = {SNR_test} and CR = {CR_test} for {DS_NAME[dst_domain]}'
                print(f'\nCurrent eval setting: {current_setting}')
                print("\n=================================================================")

                # ADJSCC-V model of the destination domain
                ADJSCCV = ADJSCCV_models.get(dst_domain)
                if ADJSCCV is None:
                    ADJSCCV_path = os.path.join(model_path, 'ADJSCCV', f'{DS_NAME[dst_domain]}')
                    ADJSCCV_name = f'JSCC-V_{DS_NAME[dst_domain]}.pt'
                    ADJSCCV = ADJSCC_V(enc_out_shape, KERNEL_SIZE, N_CHANNELS)
                    ADJSCCV.load_state_dict(torch.load(os.path.join(ADJSCCV_path, ADJSCCV_name),
                                                       map_location=device))
                    ADJSCCV.to(device).eval()
                    ADJSCCV_models[dst_domain] = ADJSCCV

                # Loop over all source domains except the current target
                for src_domain in [i for i in range(len(DS_NAME)) if i != dst_domain]:

                    print(f'Source domain: {DS_NAME[src_domain]} --> Destination domain: {DS_NAME[dst_domain]}')

                    # Pairwise CycleGAN generator src -> dst
                    cycleGAN = cycleGAN_models.get((src_domain, dst_domain))
                    if cycleGAN is None:
                        cycleGAN_name = f'gen_{DS_NAME[src_domain]}_{DS_NAME[dst_domain]}.pkl'
                        cycleGAN = CycleGAN_Generator(in_channels=3, out_channels=3, conv_dim=12)
                        cycleGAN.load_state_dict(torch.load(os.path.join(cycleGAN_path, cycleGAN_name),
                                                            map_location=device))
                        cycleGAN.to(device).eval()
                        cycleGAN_models[(src_domain, dst_domain)] = cycleGAN

                    # Test set of the source domain
                    ds_loader = src_loaders.get(src_domain)
                    if ds_loader is None:
                        src_dataset = ImgData(os.path.join(root, 'dataset', f'{DS_NAME[src_domain]}_test.pt'),
                                              IMAGE_SIZE, IMAGE_SIZE)
                        ds_loader = torch.utils.data.DataLoader(src_dataset,
                                                                batch_size=BATCH_SIZE,
                                                                shuffle=False)
                        src_loaders[src_domain] = ds_loader

                    for src_img, _ in tqdm(ds_loader, total=len(ds_loader)):
                        src_img = src_img.to(device)

                        # Domain adaptation (CycleGAN)
                        DA_img = cycleGAN(src_img)

                        # Semantic reconstruction (ADJSCC-V). Float32 SNR/CR, as in the
                        # bound evaluations.
                        SNR = torch.full((DA_img.shape[0], 1), float(SNR_test), device=device)
                        CR = torch.full((DA_img.shape[0], 1), float(CR_test), device=device)

                        DA_img = (DA_img + 1) / 2  # [-1,1] -> [0,1]
                        DA_img_rec = ADJSCCV(DA_img, SNR, CR, CHANNEL)
                        DA_img_rec = (DA_img_rec - 0.5) / 0.5  # [0,1] -> [-1,1]

                        # At this point, reconstructed images (DA_img_rec) are available
                        # You can save or evaluate them here
                        # Example: torchvision.utils.save_image(DA_img_rec, ...)

print('CycleGAN evaluation complete (classifier removed).')


## MASCN Transceiver

Because all models in this notebook run one after another in a single Colab session, parts of this code conflict with definitions from earlier cells and do not run correctly.  
This section will be split out and uploaded later as an independent module.

# Making Test Dataset and Evaluation

In [None]:
# Project folder on the mounted Google Drive and the model sub-folders
root = '/content/drive/My Drive/ADJSCC-V_withDA_4digit_classification'
model_path = os.path.join(root, 'model')
classifier_path = os.path.join(model_path, 'classifier')

DS_NAME = ["MNIST", "MNISTM", "SYN", "USPS"]

# Small batches: only a couple of images per domain are needed for the figures
BATCH_SIZE = 2
IMGS_TO_DISPLAY = 25
IMAGE_SIZE = 32

CHANNEL = 'AWGN'  # Choose AWGN or Fading
N_CHANNELS = 256
KERNEL_SIZE = 5

# Encoder output: 48 feature maps on an (IMAGE_SIZE // 4) x (IMAGE_SIZE // 4) grid
enc_out_shape = [48] + [IMAGE_SIZE // 4] * 2


def _make_test_loader(domain_name):
    """Return ``(dataset, shuffled DataLoader)`` for ``<root>/dataset/<domain_name>_test.pt``."""
    ds = ImgData(os.path.join(root, 'dataset', f'{domain_name}_test.pt'), IMAGE_SIZE, IMAGE_SIZE)
    return ds, torch.utils.data.DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)


# Two test images per batch for each domain: MNIST, MNISTM, USPS, SYN
MNIST_ds, MNIST_ds_loader = _make_test_loader('MNIST')
MNISTM_ds, MNISTM_ds_loader = _make_test_loader('MNISTM')
USPS_ds, USPS_ds_loader = _make_test_loader('USPS')
SYN_ds, SYN_ds_loader = _make_test_loader('SYN')


In [None]:
from torchvision.transforms.functional import to_pil_image

# Domain name -> domain index used by the StarGAN / SCN generators
DS_NAME_dict = {
    'MNIST': 0,
    'MNISTM': 1,
    'SYN': 2,
    'USPS': 3
}

def plot_domain_adaptation_results(src_img, CR_test, SNR_test, src_domain, dst_domains, model_path, device):
    """Show CycleGAN, StarGAN and SCN results for one input and several target domains.

    The figure has three rows (CycleGAN / StarGAN / SCN). Column 0 holds the method
    name, column 1 the input image, and column ``i + 2`` the result for
    ``dst_domains[i]``. CycleGAN and StarGAN outputs are transmitted with the
    destination domain's ADJSCC-V model over an AWGN channel. SCN translates and
    transmits in a single model. Only the first image of the batch is drawn.

    Args:
        src_img: image tensor ``(C, H, W)`` or batch ``(B, C, H, W)``. It is mapped
            with ``(x + 1) / 2`` for display, so presumably in [-1, 1].
        CR_test: compression ratio (float). It also selects the SCN checkpoint
            ``model_256_AWGN_CR_{CR_test:.1f}_AFB``.
        SNR_test: channel SNR given to the models.
        src_domain: source domain name (used in titles and the CycleGAN file name).
        dst_domains: destination domain names (keys of ``DS_NAME_dict``).
        model_path: folder containing the ``cycleGAN``, ``starGAN_w_clf`` and
            ``ADJSCCV`` checkpoints.
        device: torch device to run on.
    """
    if src_img.dim() == 3:
        src_img = src_img.unsqueeze(0)

    src_img = src_img.to(device)
    SNR = (SNR_test * torch.ones(src_img.shape[0], 1)).to(device)
    CR = (CR_test * torch.ones(src_img.shape[0], 1)).to(device)

    def _show(ax, img, title):
        # Draw the first image of a [0, 1] batch on ``ax`` without axes.
        ax.imshow(to_pil_image(img[0].cpu().squeeze()))
        ax.set_title(title)
        ax.axis('off')

    # Extra left-most column for the algorithm names
    fig, axs = plt.subplots(3, len(dst_domains) + 2, figsize=(18, 6))

    # Method name (column 0) and the input image (column 1) on every row
    src_img_display = (src_img + 1) / 2
    input_title = f'Input ({src_domain}, SNR = {SNR_test}, CR = {CR_test})'
    for row, algo_name in enumerate(['CycleGAN', 'StarGAN', 'SCN']):
        axs[row, 0].text(0.5, 0.5, algo_name, horizontalalignment='center', verticalalignment='center',
                         fontsize=24, transform=axs[row, 0].transAxes)
        axs[row, 0].axis('off')
        _show(axs[row, 1], src_img_display, input_title)

    # StarGAN and SCN take the target domain as an input, so one instance of each
    # serves every destination column. Load them once, not once per column.
    starGAN = starGAN_Generator(num_domains=len(DS_NAME_dict), image_size=IMAGE_SIZE)
    starGAN.load_state_dict(torch.load(os.path.join(model_path, 'starGAN_w_clf', 'gen_100.pkl'),
                                       map_location=device))
    starGAN.to(device).eval()

    SCN_model_path = os.path.join('/content/drive/My Drive/StarGAN/', f'model_256_AWGN_CR_{CR_test:.1f}_AFB')
    SCN = Generator(in_channels=3, num_domains=len(DS_NAME_dict), image_size=IMAGE_SIZE, out_channels=3, conv_dim=64, CR=CR_test)
    SCN.load_state_dict(torch.load(os.path.join(SCN_model_path, f'gen_AFB_100_{CR_test:.1f}.pkl'),
                                   map_location=device))
    SCN.to(device).eval()

    with torch.no_grad():  # visualisation only
        for i, dst_domain in enumerate(dst_domains):
            col = i + 2
            title = f'{src_domain} to {dst_domain}'
            target_domain = torch.full((src_img.shape[0],), DS_NAME_dict[dst_domain],
                                       dtype=torch.long, device=device)

            # Pairwise CycleGAN generator src_domain -> dst_domain
            cycleGAN = CycleGAN_Generator(in_channels=3, out_channels=3, conv_dim=12)
            cycleGAN.load_state_dict(torch.load(os.path.join(model_path, 'cycleGAN', f'gen_{src_domain}_{dst_domain}.pkl'),
                                                map_location=device))
            cycleGAN.to(device).eval()

            # ADJSCC-V transceiver of the destination domain
            ADJSCCV = ADJSCC_V(enc_out_shape, KERNEL_SIZE, N_CHANNELS)
            ADJSCCV.load_state_dict(torch.load(os.path.join(model_path, 'ADJSCCV', dst_domain, f'JSCC-V_{dst_domain}.pt'),
                                               map_location=device))
            ADJSCCV.to(device).eval()

            # Row 0: CycleGAN translation, rescaled with (x + 1) / 2, then ADJSCC-V transmission
            DA_img_cycleGAN = cycleGAN(src_img)
            _show(axs[0, col], ADJSCCV((DA_img_cycleGAN + 1) / 2, SNR, CR, 'AWGN'), title)

            # Row 1: StarGAN translation, rescaled with (x + 1) / 2, then ADJSCC-V transmission
            DA_img_starGAN = starGAN(src_img, target_domain)
            _show(axs[1, col], ADJSCCV((DA_img_starGAN + 1) / 2, SNR, CR, 'AWGN'), title)

            # Row 2: SCN translates and transmits in one pass; output in [-1, 1]
            DA_img_rec_SCN = SCN(src_img, 'AWGN', SNR, target_domain)
            _show(axs[2, col], (DA_img_rec_SCN + 1) / 2, title)

    plt.tight_layout()
    plt.show()

In [None]:
SNR_test = 3
CR_test = 0.2

# First test batch of each source domain. This uses next(iter(...)) instead of a
# tqdm loop that breaks on its first iteration.
src_img, _ = next(iter(MNIST_ds_loader))
src_img2, _ = next(iter(MNISTM_ds_loader))
src_img3, _ = next(iter(SYN_ds_loader))

# (input batch, source domain, destination domains) for each figure
for _imgs, _src, _dsts in [
    (src_img, 'MNIST', ['MNISTM', 'SYN', 'USPS']),
    (src_img2, 'MNISTM', ['MNIST', 'SYN', 'USPS']),
    (src_img3, 'SYN', ['MNIST', 'MNISTM', 'USPS']),
]:
    plot_domain_adaptation_results(src_img=_imgs,
                                   CR_test=CR_test,
                                   SNR_test=SNR_test,
                                   src_domain=_src,
                                   dst_domains=_dsts,
                                   model_path=model_path,
                                   device=device)

In [None]:
import os
import torchvision.transforms.functional as TF

def save_images(images, labels, dataset_name, root_path, value_range=None):
    """Save each image as ``image_<idx>_label_<label>.png``.

    Files go to ``<root_path>/qualitative_comparison_test_data/<dataset_name>/``,
    which is created if missing.

    Args:
        images: iterable of image tensors, each ``(C, H, W)``, ``(H, W, 3)`` or
            ``(H, W)``. 1xHxW is squeezed to HxW and HxWx3 is converted to 3xHxW.
        labels: iterable of labels, zipped with ``images``. One-element tensors are
            written as their Python value (``label_3`` rather than ``label_tensor(3)``).
        dataset_name: sub-folder name.
        root_path: base directory.
        value_range: optional ``(low, high)`` range of the input values. When given,
            images are mapped linearly to [0, 1] and clamped before conversion.
            The default ``None`` passes tensors to ``to_pil_image`` unchanged.
            NOTE(review): the loaders in this notebook appear to produce [-1, 1]
            images while ``to_pil_image`` expects [0, 1] floats. Pass
            ``value_range=(-1, 1)`` for such data.
    """
    save_path = os.path.join(root_path, 'qualitative_comparison_test_data', dataset_name)
    os.makedirs(save_path, exist_ok=True)

    for idx, (img, lbl) in enumerate(zip(images, labels)):
        img = img.detach()
        # 3xHxW is already in the layout to_pil_image expects. (The old code permuted it
        # to HxWxC and straight back, a no-op.) Otherwise drop a singleton channel
        # dim and convert a channels-last HxWx3 image to CxHxW.
        if img.dim() == 3 and img.shape[0] != 3:
            img = img.squeeze(0)
            if img.dim() == 3 and img.shape[2] == 3:
                img = img.permute(2, 0, 1)

        if value_range is not None:
            low, high = value_range
            img = ((img - low) / (high - low)).clamp(0, 1)

        # Write tensor labels as plain values in the file name.
        if torch.is_tensor(lbl) and lbl.numel() == 1:
            lbl = lbl.item()

        img_pil = TF.to_pil_image(img)
        img_path = os.path.join(save_path, f'image_{idx}_label_{lbl}.png')
        img_pil.save(img_path)
        print(f'Saved {img_path}')


# Dataset names, in the same order as the entries of sample_images / sample_labels.
# NOTE(review): sample_images and sample_labels are not defined in the visible part
# of this notebook; confirm the cell that builds them runs before this one.
dataset_names = ['MNIST', 'MNISTM', 'USPS', 'SYN']

# Save the samples of each dataset into its own sub-folder under root
for images, labels, name in zip(sample_images, sample_labels, dataset_names):
    save_images(images, labels, name, root)
