<a href="https://colab.research.google.com/github/wongdongwook/JSAC_MA-DeepSC/blob/main/MASCN_for_4Digits.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from torch import optim
import pdb  # Debugging tool
from torch import autograd  # For gradient penalty (WGAN-GP)
import torchvision.utils as vutils  # For saving images
import torch.nn.functional as F
from torchvision.transforms.functional import to_pil_image
from torchvision import datasets
from torchvision import transforms
import numpy as np
import os
import math
from tqdm import tqdm  # Progress bar
import datetime
from sklearn.metrics import accuracy_score, confusion_matrix  # Evaluation metrics

In [None]:
# Mount Google Drive (Colab only) so the dataset .pt files and saved models
# under '/content/drive/My Drive' become readable from this runtime.
import google.colab.drive as drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Select the compute device: CUDA when a GPU is visible, otherwise CPU.
# Note: Google TPUs are not supported by this notebook.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using {}'.format(str(device).upper()))

Using CUDA


# Data Loader

In [None]:
# Folder on the mounted Google Drive that holds the per-domain training files.
_DATASET_DIR = '/content/drive/My Drive/StarGAN/dataset'

# Per-domain training files; each .pt file is assumed to contain
# [image_tensor, label_tensor].
MNIST_path = _DATASET_DIR + '/MNIST_train.pt'
MNISTM_path = _DATASET_DIR + '/MNISTM_train.pt'
SynDigits_path = _DATASET_DIR + '/SYN_train.pt'
USPS_path = _DATASET_DIR + '/USPS_train.pt'

class ImgDomainAdaptationData(torch.utils.data.Dataset):
    """
    Multi-domain digit dataset assembled from four pre-saved ``.pt`` files.

    Each file is read with ``torch.load`` and indexed as ``[images, labels]``.
    Images are resized to ``[w, h]`` and normalised with mean 0.5 / std 0.5;
    image stacks with fewer than four dimensions are treated as single-channel
    and repeated to three channels. Every sample is tagged with a domain index:
    MNIST=0, MNISTM=1, SYN=2, USPS=3.

    Items are ``(image, label, domain)``: ``label`` is converted with
    ``.item()`` and ``domain`` is an element of a NumPy integer array.
    """

    def __init__(self, w, h):
        # Resize, then map values with (x - 0.5) / 0.5.
        self.transform = transforms.Compose([
            transforms.Resize([w, h]),
            transforms.Normalize([0.5], [0.5]),
        ])

        # Raw [images, labels] containers, read in domain order.
        self.MNIST_data = torch.load(MNIST_path)
        self.MNISTM_data = torch.load(MNISTM_path)
        self.SynDigits_data = torch.load(SynDigits_path)
        self.USPS_data = torch.load(USPS_path)

        # Per-domain label tensors.
        self.MNIST_label = self.MNIST_data[1]
        self.MNISTM_label = self.MNISTM_data[1]
        self.SynDigits_label = self.SynDigits_data[1]
        self.USPS_label = self.USPS_data[1]

        # Transformed 3-channel image stacks and their domain-index arrays.
        self.MNIST_img, self.MNIST_domain = self.pre_processing(
            self.transform(self.MNIST_data[0]), 0)
        self.MNISTM_img, self.MNISTM_domain = self.pre_processing(
            self.transform(self.MNISTM_data[0]), 1)
        self.SynDigits_img, self.SynDigits_domain = self.pre_processing(
            self.transform(self.SynDigits_data[0]), 2)
        self.USPS_img, self.USPS_domain = self.pre_processing(
            self.transform(self.USPS_data[0]), 3)

        # Merge the four domains along the sample axis (domain order 0..3).
        self.img = torch.vstack([self.MNIST_img, self.MNISTM_img,
                                 self.SynDigits_img, self.USPS_img])
        self.label = torch.hstack([self.MNIST_label, self.MNISTM_label,
                                   self.SynDigits_label, self.USPS_label])
        self.domain = np.hstack([self.MNIST_domain, self.MNISTM_domain,
                                 self.SynDigits_domain, self.USPS_domain])

    def pre_processing(self, img, domain):
        """Return ``img`` as a 4-D, 3-channel stack plus a constant domain array.

        Args:
            img: Image stack; when it has fewer than 4 dims a channel axis is
                inserted at position 1 and repeated three times.
            domain: Domain index assigned to every image of the stack.

        Returns:
            ``(images, domain_labels)`` where ``domain_labels`` is an integer
            NumPy array of length ``img.shape[0]`` filled with ``domain``.
        """
        count = img.shape[0]
        if img.dim() < 4:
            img = img.unsqueeze(1).repeat(1, 3, 1, 1)
        return img, np.full(count, domain, dtype=int)

    def __len__(self):
        return len(self.label)

    def __getitem__(self, index):
        image = self.img[index]
        target = self.label[index].item()
        return image, target, self.domain[index]


# Utils

In [None]:
def generate_imgs(img, n_dms, gen, samples_path, channel_type, step=0, is_cuda=True):
    """
    Generate and save a grid of original and domain-translated images.

    Each row of the saved grid shows one input image followed by its
    translation into every target domain ``0 .. n_dms-1``. Every generated
    image uses a random integer SNR in [20, 27] dB.

    The generator's train/eval mode is restored before returning, so the
    function can be called in the middle of a training loop.

    Args:
        img: Input image batch (Tensor) on the same device as ``gen``.
        n_dms: Number of target domains.
        gen: Generator, called as ``gen(images, channel_type, snr, labels)``.
        samples_path: Directory in which the grid image is written.
        channel_type: Channel type passed to the generator (e.g. 'AWGN').
        step: Iteration step used in the file name ``sample_<step>.png``.
        is_cuda: Kept for backward compatibility; all helper tensors are now
            created on ``img.device``.
    """
    was_training = gen.training
    gen.eval()
    dev = img.device
    m = img.shape[0]
    n_cols = n_dms + 1

    # Target-domain label per grid column: -1 for the "real" column, then
    # 0 .. n_dms-1 for the translated columns.
    lbl = torch.arange(start=-1, end=n_dms).expand(m, n_cols).reshape([-1]).to(dev)

    # Repeat every input image once per grid column.
    img_ = torch.repeat_interleave(img, n_cols, dim=0)
    # Random per-sample SNR in dB (simulated channel noise).
    snr = torch.randint(20, 28, (img_.shape[0], 1)).to(dev)
    # The "real" column needs a valid embedding index; its generated output is
    # replaced by the original image below.
    real_idx = torch.arange(start=0, end=m * n_cols, step=n_cols, device=dev)
    lbl[real_idx] = 0

    try:
        # Visualisation only: no autograd graph is needed.
        with torch.no_grad():
            display_imgs = gen(img_, channel_type, snr, lbl)
        display_imgs[real_idx] = img
    finally:
        # Put the generator back into the mode the caller had it in.
        gen.train(was_training)

    # Create grid and save to file
    grid = vutils.make_grid(display_imgs, normalize=True, nrow=n_cols, padding=2, pad_value=1)
    vutils.save_image(grid, os.path.join(samples_path, 'sample_' + str(step) + '.png'))




# Gradient penalty used by the WGAN-GP objective: build images interpolated
# between real and fake samples, then penalise the deviation of the L2 norm of
# the critic's input-gradient from 1.
def gradient_penalty(real, fake, critic, is_cuda=True):
    """
    Compute the gradient penalty term for WGAN-GP.

    The interpolated batch is detached from whatever graph produced ``real``
    and ``fake``, so the penalty only regularises the critic and never
    backpropagates into the generator. This also lets callers pass already
    detached fakes.

    Args:
        real: Real images (Tensor) of shape (B, C, H, W).
        fake: Generated images (Tensor) with the same shape as ``real``.
        critic: Discriminator returning a tuple whose first element is the
            adversarial output.
        is_cuda: Kept for backward compatibility; tensors are created on
            ``real``'s device and dtype.

    Returns:
        Gradient penalty (scalar Tensor).
    """
    m = real.shape[0]
    # One interpolation weight per sample, matching real's device and dtype.
    epsilon = torch.rand(m, 1, 1, 1).to(real)

    # Interpolated images between real and fake, as a fresh leaf that needs grad.
    interpolated_img = (epsilon * real.detach() + (1 - epsilon) * fake.detach()).requires_grad_(True)
    interpolated_out, _, _ = critic(interpolated_img)

    # Gradients of the critic output w.r.t. the interpolated inputs; the graph
    # is kept (create_graph=True) so the penalty can be minimised w.r.t. the critic.
    grads = autograd.grad(outputs=interpolated_out, inputs=interpolated_img,
                          grad_outputs=torch.ones_like(interpolated_out),
                          create_graph=True, retain_graph=True)[0]
    grads = grads.reshape([m, -1])
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()


# Model

In [None]:
def conv_block(c_in, c_out, k_size=4, stride=2, pad=1, use_bn=True, transpose=False):
    """Return a (transposed) convolution, optionally followed by BatchNorm2d.

    The convolution has a bias only when BatchNorm is disabled, because a
    following BatchNorm would cancel the bias anyway.
    """
    conv_cls = nn.ConvTranspose2d if transpose else nn.Conv2d
    layers = [conv_cls(c_in, c_out, k_size, stride, pad, bias=not use_bn)]
    if use_bn:
        layers.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*layers)


class ResBlock(nn.Module):
    """Two 3x3 conv+BatchNorm layers with an additive skip connection.

    The computation is ``h = relu(conv1(x)); return h + conv2(h)``.
    NOTE(review): the skip starts *after* conv1+ReLU rather than at the block
    input, unlike the usual residual block. This is kept unchanged so existing
    checkpoints keep their meaning.
    """

    def __init__(self, channels):
        super(ResBlock, self).__init__()
        self.conv1 = conv_block(channels, channels, k_size=3, stride=1, pad=1, use_bn=True)
        self.conv2 = conv_block(channels, channels, k_size=3, stride=1, pad=1, use_bn=True)

    # Defined as forward() (not __call__) so nn.Module.__call__ still runs
    # registered forward hooks/pre-hooks around it.
    def forward(self, x):
        x = F.relu(self.conv1(x))
        return x + self.conv2(x)

# Original Existing Works
class AF_block(nn.Module):
    """Attention-feature (AF) module that rescales channels conditioned on SNR.

    The per-channel spatial means of ``x`` are concatenated with the SNR and
    passed through ``Linear -> ReLU -> Linear -> Sigmoid``. The resulting
    weights in (0, 1) multiply ``x`` channel-wise.

    Args:
        Nin: Channels of ``x`` (the MLP input width is ``Nin + 1``).
        Nh: Hidden width of the MLP.
        No: Number of output weights; must equal the channel count of ``x``.
    """

    def __init__(self, Nin, Nh, No):
        super(AF_block, self).__init__()
        self.fc1 = nn.Linear(Nin+1, Nh)
        self.fc2 = nn.Linear(Nh, No)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, snr):
        """Return ``x`` scaled channel-wise by SNR-conditioned weights.

        Args:
            x: Feature map of shape (B, Nin, H, W).
            snr: SNR per sample with shape (B,) or (B, 1), or a single value
                shared by the whole batch (tensor or Python number).
        """
        mu = torch.mean(x, (2, 3))  # (B, Nin) channel means
        # Reshape (rather than squeeze) to a (B, 1) column in mu's dtype and
        # device, so batches of one and a single shared SNR both work.
        snr = torch.as_tensor(snr, dtype=mu.dtype, device=mu.device).reshape(-1, 1)
        if snr.shape[0] == 1:
            snr = snr.expand(mu.shape[0], 1)
        out = torch.cat((mu, snr), 1)
        out = self.sigmoid(self.fc2(self.relu(self.fc1(out))))
        return out[:, :, None, None] * x


## Power normalization before transmission
# Note: with P = 1 every real entry has unit average power, so each complex
# symbol (a pair of real entries) has power 2. For unit symbol power use P = 0.5.
def Power_norm(z, P = 1):
    """Scale each row of ``z`` so its squared L2 norm equals ``P * z_dim``.

    Args:
        z: Real tensor of shape (batch, z_dim).
        P: Average power per real entry.

    Returns:
        Tensor with the same shape as ``z``. An all-zero row stays all-zero
        instead of turning into NaN.
    """
    _, z_dim = z.shape
    # Per-row norm, kept as (batch, 1) so it broadcasts over the columns.
    norm = torch.linalg.vector_norm(z, ord=2, dim=1, keepdim=True)
    norm = norm.clamp_min(torch.finfo(norm.dtype).tiny)
    return math.sqrt(P * z_dim) * z / norm

def Power_norm_complex(z, P = 1):
    """Power-normalise ``z`` viewed as interleaved complex symbols.

    Column pairs ``(z[:, 2k], z[:, 2k+1])`` are the real and imaginary parts
    of one complex symbol. Each row is scaled so that the total symbol energy
    ``sum_k |s_k|^2`` equals ``P * z_dim``. Since ``|s_k|^2`` is the sum of the
    squares of its two real parts, this is a single real scale factor per row
    applied directly to the interleaved layout. No complex round-trip and no
    fixed device are needed.

    Args:
        z: Real tensor of shape (batch, z_dim); z_dim is expected to be even.
        P: Power scale (see ``Power_norm``).

    Returns:
        Real tensor with the same shape, dtype and device as ``z``. An all-zero
        row stays all-zero.
    """
    _, z_dim = z.shape
    norm = torch.linalg.vector_norm(z, ord=2, dim=1, keepdim=True)
    norm = norm.clamp_min(torch.finfo(norm.dtype).tiny)
    return z * (math.sqrt(P * z_dim) / norm)

# The (real) AWGN channel
def AWGN_channel(x, snr, P = 2):
    """Add real white Gaussian noise with variance ``P / 10**(snr / 10)``.

    Args:
        x: Transmitted real signal of shape (batch, length).
        snr: SNR in dB; a tensor broadcastable against ``x`` (e.g. (batch, 1))
            or a Python number.
        P: Numerator of the per-entry noise variance ``P / gamma``.

    Returns:
        ``x`` plus noise, on ``x``'s device and in ``x``'s dtype.
    """
    batch_size, length = x.shape
    gamma = 10 ** (torch.as_tensor(snr, dtype=x.dtype, device=x.device) / 10.0)
    noise = torch.sqrt(P / gamma) * torch.randn(batch_size, length, device=x.device, dtype=x.dtype)
    return x + noise

def AWGN_complex(x, snr, Ps = 1):
    """Add circular complex Gaussian noise to a complex signal.

    The real and imaginary noise components each have variance
    ``Ps / 10**(snr / 10)``.

    Args:
        x: Transmitted signal of shape (batch, length), typically complex.
        snr: SNR in dB; a tensor broadcastable against ``x`` or a Python number.
        Ps: Numerator of the per-component noise variance.

    Returns:
        Complex tensor ``x + noise`` on ``x``'s device.
    """
    batch_size, length = x.shape
    gamma = 10 ** (torch.as_tensor(snr, device=x.device) / 10.0)
    std = torch.sqrt(Ps / gamma)
    # The first draw forms the real part, the second the imaginary part.
    n_re = std * torch.randn(batch_size, length, device=x.device)
    n_im = std * torch.randn(batch_size, length, device=x.device)
    return x + torch.complex(n_re, n_im)

# Please set the symbol power if it is not a default value
def Fading_channel(x, snr, P = 2):
    """Rayleigh fading channel with zero-forcing equalisation.

    ``x`` holds interleaved real/imaginary parts of ``feature_length // 2``
    complex symbols per row. Each symbol is multiplied by an independent
    complex gain with standard-normal real and imaginary parts. Complex
    Gaussian noise with per-component variance ``P / 10**(snr / 10)`` is
    added, and the result is divided by the gain. The output uses the same
    interleaved real layout as ``x``.

    Args:
        x: Real tensor of shape (batch, feature_length); feature_length even.
        snr: SNR in dB; a tensor broadcastable to (batch, K) or a Python number.
        P: Numerator of the per-component noise variance.

    Returns:
        Real tensor of shape (batch, feature_length) on ``x``'s device.
    """
    dev = x.device
    gamma = 10 ** (torch.as_tensor(snr, device=dev) / 10.0)
    [batch_size, feature_length] = x.shape
    K = feature_length//2

    h_I = torch.randn(batch_size, K, device=dev)
    h_R = torch.randn(batch_size, K, device=dev)
    h_com = torch.complex(h_I, h_R)
    x_com = torch.complex(x[:, 0:feature_length:2], x[:, 1:feature_length:2])
    y_com = h_com*x_com

    n_I = torch.sqrt(P/gamma)*torch.randn(batch_size, K, device=dev)
    n_R = torch.sqrt(P/gamma)*torch.randn(batch_size, K, device=dev)
    noise = torch.complex(n_I, n_R)

    # Zero-forcing equalisation: undo the (known) channel gain.
    y = (y_com + noise)/h_com

    y_out = torch.zeros(batch_size, feature_length, device=dev, dtype=y.real.dtype)
    y_out[:, 0:feature_length:2] = y.real
    y_out[:, 1:feature_length:2] = y.imag
    return y_out

# Note: with P = 1 the average power per active real entry is about 1, so each
# complex symbol has power about 2. For unit symbol power use P = 0.5.
def Power_norm_VLC(z, cr, P = 1):
    """Variable-length power normalisation.

    Each row is scaled so that its squared L2 norm equals ``Kv * P`` with
    ``Kv = ceil(z_dim * cr)``, i.e. the power budget follows the compression
    ratio instead of the full width.

    Args:
        z: Real tensor of shape (batch, z_dim).
        cr: Compression ratio(s); a tensor broadcastable to (batch, 1)
            (e.g. (batch, 1)) or a Python number. Moved to ``z``'s device.
        P: Average power per active real entry.

    Returns:
        Tensor with the same shape as ``z``. An all-zero (fully masked) row
        stays all-zero instead of turning into NaN.
    """
    _, z_dim = z.shape
    cr = torch.as_tensor(cr, device=z.device)
    Kv = torch.ceil(z_dim*cr).int()
    norm = torch.linalg.vector_norm(z, ord=2, dim=1, keepdim=True)
    norm = norm.clamp_min(torch.finfo(norm.dtype).tiny)
    return torch.sqrt(Kv*P)*z/norm

def AWGN_channel_VLC(x, snr, cr, P = 2):
    """AWGN channel for variable-length codes.

    Adds real Gaussian noise with variance ``P / 10**(snr / 10)`` only to the
    entries kept by ``mask_gen(length, cr)``; masked-out entries stay noise-free.

    Args:
        x: Transmitted real signal of shape (batch, length).
        snr: SNR in dB; a tensor broadcastable to (batch, length) or a number.
        cr: Per-sample compression ratios passed to ``mask_gen``.
        P: Numerator of the per-entry noise variance.

    Returns:
        ``x`` plus masked noise, on ``x``'s device.
    """
    batch_size, length = x.shape
    gamma = 10 ** (torch.as_tensor(snr, device=x.device) / 10.0)
    mask = mask_gen(length, cr).to(x.device)
    # Draw an independent noise realisation per sample, as AWGN_channel does;
    # a single (1, length) draw would give every sample in the batch the same
    # noise pattern, differing only by SNR scaling.
    noise = torch.sqrt(P/gamma)*torch.randn(batch_size, length, device=x.device)
    return x + noise*mask

def Fading_channel_VLC(x, snr, cr, P = 2):
    """Rayleigh fading channel with zero-forcing equalisation for variable-length codes.

    ``x`` holds interleaved real/imaginary parts of ``K = feature_length // 2``
    complex symbols. Each symbol is multiplied by an independent complex gain
    with standard-normal parts. Complex Gaussian noise (per-component
    variance ``P / 10**(snr / 10)``) is added only on the symbols kept by
    ``mask_gen(K, cr)``, and the result is divided by the gain.

    Args:
        x: Real tensor of shape (batch, feature_length); feature_length even.
        snr: SNR in dB; a tensor broadcastable to (batch, K) or a number.
        cr: Per-sample compression ratios passed to ``mask_gen``.
        P: Numerator of the per-component noise variance.

    Returns:
        Real tensor of shape (batch, feature_length) on ``x``'s device.
    """
    dev = x.device
    gamma = 10 ** (torch.as_tensor(snr, device=dev) / 10.0)
    [batch_size, feature_length] = x.shape
    K = feature_length//2

    mask = mask_gen(K, cr).to(dev)
    h_I = torch.randn(batch_size, K, device=dev)
    h_R = torch.randn(batch_size, K, device=dev)
    h_com = torch.complex(h_I, h_R)
    x_com = torch.complex(x[:, 0:feature_length:2], x[:, 1:feature_length:2])
    y_com = h_com*x_com

    n_I = torch.sqrt(P/gamma)*torch.randn(batch_size, K, device=dev)
    n_R = torch.sqrt(P/gamma)*torch.randn(batch_size, K, device=dev)
    noise = torch.complex(n_I, n_R)*mask

    # Zero-forcing equalisation: undo the (known) channel gain.
    y = (y_com + noise)/h_com

    y_out = torch.zeros(batch_size, feature_length, device=dev, dtype=y.real.dtype)
    y_out[:, 0:feature_length:2] = y.real
    y_out[:, 1:feature_length:2] = y.imag
    return y_out

def Channel(z, snr, channel_type = 'AWGN'):
    """Power-normalise ``z`` with ``Power_norm`` and pass it through a channel.

    ``channel_type`` selects 'AWGN' or 'Fading'. Any other value returns the
    power-normalised signal unchanged (no noise).
    """
    normalized = Power_norm(z)
    if channel_type == 'AWGN':
        return AWGN_channel(normalized, snr)
    if channel_type == 'Fading':
        return Fading_channel(normalized, snr)
    return normalized

def Channel_VLC(z, snr, cr, channel_type = 'AWGN'):
    """Variable-length counterpart of ``Channel``.

    Normalises with ``Power_norm_VLC`` and then applies the masked 'AWGN' or
    'Fading' channel. Any other ``channel_type`` returns the normalised signal
    unchanged.
    """
    normalized = Power_norm_VLC(z, cr)
    if channel_type == 'AWGN':
        return AWGN_channel_VLC(normalized, snr, cr)
    if channel_type == 'Fading':
        return Fading_channel_VLC(normalized, snr, cr)
    return normalized

def mask_gen(N, cr, ch_max = 48): # depends on the number of channels (e.g. 48, 8, 8)
    """Build per-sample prefix masks that keep a fraction ``cr`` of ``N`` entries.

    The ``N`` positions are grouped into ``ch_max`` chunks of
    ``nc = N // ch_max`` entries. Sample ``i`` keeps its first
    ``nc * round(ch_max * cr[i])`` positions (clamped to ``[0, N]``), and the
    rest are zero.

    Args:
        N: Mask length.
        cr: Compression ratios, shape (B,) or (B, 1).
        ch_max: Number of chunks the mask length is quantised into.

    Returns:
        ``int32`` tensor of shape (B, N) on ``cr``'s device.
    """
    cr = torch.as_tensor(cr)
    nc = N // ch_max
    keep = nc * torch.round(ch_max * cr.reshape(-1)).int()
    # Clamp so a negative ratio keeps nothing (a negative slice end would
    # otherwise wrap around) and ratios above 1 keep everything.
    keep = keep.clamp(0, N)
    positions = torch.arange(N, device=cr.device)
    return (positions.unsqueeze(0) < keep.unsqueeze(1)).int()

class ADJSCC(nn.Module):
    """Fixed-rate deep JSCC autoencoder: encoder -> ``Channel`` -> decoder.

    Both the encoder and the decoder receive the SNR as side information.
    NOTE(review): ``Encoder`` and ``Decoder`` are not defined in this notebook;
    presumably they come from the original ADJSCC code. Verify before use.
    """

    def __init__(self, enc_shape, Kernel_sz, Nc):
        super(ADJSCC, self).__init__()
        self.encoder = Encoder(enc_shape, Kernel_sz, Nc)
        self.decoder = Decoder(enc_shape, Kernel_sz, Nc)

    def forward(self, x, snr, channel_type = 'AWGN'):
        latent = self.encoder(x, snr)
        received = Channel(latent, snr, channel_type)
        return self.decoder(received, snr)

# The DeepJSCC_V model, also called ADJSCC_V
class ADJSCC_V(nn.Module):
    """Variable-rate deep JSCC: encoder -> cr-dependent mask -> ``Channel_VLC`` -> decoder.

    NOTE(review): ``Encoder`` and ``Decoder`` are not defined in this notebook;
    presumably they come from the original ADJSCC code. Verify before use.
    """

    def __init__(self, enc_shape, Kernel_sz, Nc):
        super(ADJSCC_V, self).__init__()
        self.encoder = Encoder(enc_shape, Kernel_sz, Nc)
        self.decoder = Decoder(enc_shape, Kernel_sz, Nc)

    def forward(self, x, snr, cr, channel_type = 'AWGN'):
        # Assumed shapes (per the original notes): x (128, 3, 32, 32);
        # snr and cr (128, 1).
        z = self.encoder(x, snr)
        # Zero each latent beyond its cr-dependent length; the mask follows z's device.
        z = z*mask_gen(z.shape[1], cr).to(z.device)
        z = Channel_VLC(z, snr, cr, channel_type)
        return self.decoder(z, snr)



## Discriminator Model (Common)
MDAN and MASCN are assumed to share the same discriminator model.

In [None]:
class Discriminator(nn.Module):
    """StarGAN-style critic with adversarial, domain and digit-class heads.

    Four stride-2 4x4 convolutions (LeakyReLU, slope 0.01) downsample the
    input by 16. Three heads follow:

    * ``gan``: 3x3 conv giving patch scores of shape (B, 1, S/16, S/16).
    * ``cls``: conv with kernel ``image_size // 16`` giving domain logits
      of shape (B, num_domains).
    * ``label_cls``: the same kind of conv giving digit logits of shape
      (B, num_classes).
    """

    def __init__(self, channels=3, num_domains=5, num_classes=10, image_size=32, conv_dim=64):
        super(Discriminator, self).__init__()
        self.conv1 = conv_block(channels, conv_dim, use_bn=False)
        self.conv2 = conv_block(conv_dim, conv_dim * 2, use_bn=False)
        self.conv3 = conv_block(conv_dim * 2, conv_dim * 4, use_bn=False)
        self.conv4 = conv_block(conv_dim * 4, conv_dim * 8, use_bn=False)

        self.gan = conv_block(conv_dim * 8, 1, k_size=3, stride=1, pad=1, use_bn=False)
        self.cls = conv_block(conv_dim * 8, num_domains, k_size=image_size//16, stride=1, pad=0, use_bn=False)
        self.label_cls = conv_block(conv_dim * 8, num_classes, k_size=image_size//16, stride=1, pad=0, use_bn=False)

        # Initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        alpha = 0.01
        x = F.leaky_relu(self.conv1(x), alpha)
        x = F.leaky_relu(self.conv2(x), alpha)
        x = F.leaky_relu(self.conv3(x), alpha)
        x = F.leaky_relu(self.conv4(x), alpha)
        gan_out = self.gan(x)
        # flatten(1) instead of squeeze(): keeps the batch axis when B == 1,
        # so the logits are always (B, classes) as CrossEntropyLoss expects.
        cls_out = self.cls(x).flatten(1)
        label_cls_out = self.label_cls(x).flatten(1)

        return gan_out, cls_out, label_cls_out

## MDAN (Multi-domain data adaptation network)
Original StarGAN code.

In [None]:
class Generator(nn.Module):  # Original StarGAN generator (placed in front of the encoder).
    """Original StarGAN generator conditioned on a target domain.

    A learned ``image_size x image_size`` embedding plane for the target domain
    is concatenated to the input as an extra channel. The result goes through
    two stride-2 downsamplings, three residual blocks and two transposed-conv
    upsamplings, and ends in a tanh (outputs in [-1, 1]).
    """

    def __init__(self, in_channels=3, num_domains=5, image_size=32, out_channels=3, conv_dim=64):
        super(Generator, self).__init__()
        self.image_size = image_size
        self.embed_layer = nn.Embedding(num_domains, image_size**2)

        self.conv1 = conv_block(in_channels+1, conv_dim, k_size=5, stride=1, pad=2, use_bn=True)
        self.conv2 = conv_block(conv_dim, conv_dim * 2, k_size=4, stride=2, pad=1, use_bn=True)
        self.conv3 = conv_block(conv_dim * 2, conv_dim * 4, k_size=4, stride=2, pad=1, use_bn=True)
        self.res4 = ResBlock(conv_dim * 4)
        self.res5 = ResBlock(conv_dim * 4)
        self.res6 = ResBlock(conv_dim * 4)
        self.tconv7 = conv_block(conv_dim * 4, conv_dim * 2, k_size=4, stride=2, pad=1, use_bn=True, transpose=True)
        self.tconv8 = conv_block(conv_dim * 2, conv_dim, k_size=4, stride=2, pad=1, use_bn=True, transpose=True)
        self.conv9 = conv_block(conv_dim, out_channels, k_size=5, stride=1, pad=2, use_bn=False)

        # Initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, target_dm=None):
        """Translate ``x`` towards ``target_dm`` (defaults to domain 1 for every sample)."""
        if target_dm is None:
            # Create the default indices on x's device so the embedding lookup
            # works on GPU too.
            target_dm = torch.ones(x.shape[0], dtype=torch.long, device=x.device)
        target_dm = target_dm.long()
        embed = self.embed_layer(target_dm).reshape([-1, 1, self.image_size, self.image_size])
        x = torch.cat((x, embed), dim=1)  # (B, in+1, S, S)
        x = F.relu(self.conv1(x))         # (B, C, S, S)
        x = F.relu(self.conv2(x))         # (B, 2C, S/2, S/2)
        x = F.relu(self.conv3(x))         # (B, 4C, S/4, S/4)
        x = F.relu(self.res4(x))
        x = F.relu(self.res5(x))
        x = F.relu(self.res6(x))
        x = F.relu(self.tconv7(x))        # (B, 2C, S/2, S/2)
        x = F.relu(self.tconv8(x))        # (B, C, S, S)
        x = torch.tanh(self.conv9(x))     # (B, out, S, S)
        return x

## MASCN (Multi domain adaptive semantic coding network)

In [None]:
class Generator(nn.Module):  # StarGAN based Generative Semantic Transceiver
    """StarGAN-based generative semantic transceiver (MASCN).

    Encoder:
        1. The image is concatenated with a learned per-domain embedding plane.
        2. It is downsampled twice (stride-2 convs).
        3. It is refined by residual blocks and SNR-conditioned AF blocks.
        4. It is reduced to a bottleneck with ``int(conv_dim * 4 * CR)``
           channels at ``image_size // 4`` resolution.

    The flattened bottleneck is sent through ``Channel`` (power normalisation
    plus channel noise). The decoder mirrors the encoder and ends in a tanh,
    so outputs lie in [-1, 1].

    Args:
        in_channels: Channels of the input image.
        num_domains: Number of target domains (size of the embedding table).
        image_size: Height/width of the square input; expected divisible by 4.
        out_channels: Channels of the output image.
        conv_dim: Base channel width.
        CR: Compression ratio controlling the bottleneck width.
    """

    def __init__(self, in_channels=3, num_domains=5, image_size=32, out_channels=3, conv_dim=64, CR=0.2):
        super(Generator, self).__init__()
        self.image_size = image_size
        self.embed_layer = nn.Embedding(num_domains, image_size**2)
        self.CR = CR
        conv_dim_btl = int(conv_dim * 4 * CR)
        self.conv_dim = conv_dim

        # Encoder
        self.conv1 = conv_block(in_channels+1, conv_dim, k_size=5, stride=1, pad=2, use_bn=True)
        self.conv2 = conv_block(conv_dim, conv_dim * 2, k_size=4, stride=2, pad=1, use_bn=True)
        self.conv3 = conv_block(conv_dim * 2, conv_dim * 4, k_size=4, stride=2, pad=1, use_bn=True)
        self.res4 = ResBlock(conv_dim * 4)
        self.AF4 = AF_block(conv_dim * 4, conv_dim*4 //2, conv_dim * 4)
        self.res5 = ResBlock(conv_dim * 4)
        self.AF5 = AF_block(conv_dim * 4, conv_dim*4 //2, conv_dim * 4)
        self.conv5 = conv_block(conv_dim * 4, conv_dim_btl, k_size=3, stride=1, pad=1, use_bn=True)

        # Decoder
        self.conv6 = conv_block(conv_dim_btl, conv_dim * 4, k_size=3, stride=1, pad=1, use_bn=True)
        self.AF6 = AF_block(conv_dim * 4, conv_dim*4 //2, conv_dim * 4)
        self.res6 = ResBlock(conv_dim * 4)
        self.AF7 = AF_block(conv_dim * 4, conv_dim*4 //2, conv_dim * 4)
        self.res7 = ResBlock(conv_dim * 4)
        self.tconv7 = conv_block(conv_dim * 4, conv_dim * 2, k_size=4, stride=2, pad=1, use_bn=True, transpose=True)
        self.tconv8 = conv_block(conv_dim * 2, conv_dim, k_size=4, stride=2, pad=1, use_bn=True, transpose=True)
        self.conv9 = conv_block(conv_dim, out_channels, k_size=5, stride=1, pad=2, use_bn=False)
        self.flatten = nn.Flatten()

        # Initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, channel_type, snr,  target_dm=None):
        """Encode ``x``, transmit it over the channel, and decode towards ``target_dm``.

        Args:
            x: Image batch of shape (B, in_channels, image_size, image_size).
            channel_type: Forwarded to ``Channel`` (e.g. 'AWGN', 'Fading').
            snr: Per-sample SNR in dB, used by the AF blocks and the channel.
            target_dm: Target-domain indices of shape (B,); defaults to domain 1.
        """
        # Encoder
        if target_dm is None:
            # Create the default indices on x's device so the embedding lookup
            # works on GPU too.
            target_dm = torch.ones(x.shape[0], dtype=torch.long, device=x.device)
        target_dm = target_dm.long()
        embed = self.embed_layer(target_dm).reshape([-1, 1, self.image_size, self.image_size])
        x = torch.cat((x, embed), dim=1)  # (B, in+1, S, S)
        x = F.relu(self.conv1(x))         # (B, C, S, S)
        x = F.relu(self.conv2(x))         # (B, 2C, S/2, S/2)
        x = F.relu(self.conv3(x))         # (B, 4C, S/4, S/4)
        x = F.relu(self.res4(x))
        x = F.relu(self.AF4(x, snr))
        x = F.relu(self.res5(x))
        x = F.relu(self.AF5(x, snr))
        x = F.relu(self.conv5(x))         # (B, btl, S/4, S/4)
        x = self.flatten(x)

        # Channel (power normalisation + noise)
        x = Channel(x, snr, channel_type)

        # Decoder: restore the bottleneck map. Its spatial size follows
        # image_size (two stride-2 downsamplings) instead of a fixed 8x8.
        out_size = int(self.conv_dim * 4 * self.CR)
        feat = self.image_size // 4
        x = x.view(-1, out_size, feat, feat)
        x = F.relu(self.conv6(x))         # (B, 4C, S/4, S/4)
        x = F.relu(self.AF6(x, snr))
        x = F.relu(self.res6(x))
        x = F.relu(self.AF7(x, snr))
        x = F.relu(self.res7(x))
        x = F.relu(self.tconv7(x))        # (B, 2C, S/2, S/2)
        x = F.relu(self.tconv8(x))        # (B, C, S, S)
        x = torch.tanh(self.conv9(x))     # (B, out, S, S)
        return x

# Configuration

In [None]:
# Hyperparameters
EPOCHS = 100  # total number of training epochs (typically tuned within 50-300)
BATCH_SIZE = 128  # images per batch
IMGS_TO_DISPLAY = 25  # number of fixed images shown in the sample grids

IMAGE_SIZE = 32  # input image size (32x32)
NUM_DOMAINS = 4  # number of domains: MNIST, MNIST-M, SYN, USPS

N_CRITIC = 5  # discriminator updates per generator update (WGAN-GP)
GRADIENT_PENALTY = 10  # WGAN-GP lambda
# Compression ratio of the MASCN bottleneck. It drives both the model and the
# output directory names below, so the two cannot drift apart.
CR = 0.1

# Output directories (created if missing)
model_path = f'/content/drive/My Drive/StarGAN/model_256_AWGN_CR_{CR}_AFB'
os.makedirs(model_path, exist_ok=True)
samples_path = f'/content/drive/My Drive/StarGAN/samples_256_AWGN_CR_{CR}_AFB'
os.makedirs(samples_path, exist_ok=True)

LOAD_MODEL = False

# Build the generator (the MASCN Generator, i.e. the last definition above)
# and the discriminator, optionally restoring saved weights.
gen = Generator(num_domains=NUM_DOMAINS, image_size=IMAGE_SIZE, CR=CR)
dis = Discriminator(num_domains=NUM_DOMAINS, image_size=IMAGE_SIZE)

if LOAD_MODEL:
    # Same file names that the training cell writes after training;
    # map_location lets GPU-saved weights load on any device.
    gen.load_state_dict(torch.load(os.path.join(model_path, 'gen_AFB_100_0.1.pkl'), map_location=device))
    dis.load_state_dict(torch.load(os.path.join(model_path, 'dis_AFB_100_0.1.pkl'), map_location=device))

# Move the models to the compute device before creating their optimizers.
gen, dis = gen.to(device), dis.to(device)
is_cuda = torch.cuda.is_available()

# Define Optimizers
g_opt = torch.optim.Adam(gen.parameters(), lr=0.0001, betas=(0.5, 0.999))
d_opt = torch.optim.Adam(dis.parameters(), lr=0.0001, betas=(0.5, 0.999))

# Define Loss
ce = nn.CrossEntropyLoss()

# Data loaders
data1 = ImgDomainAdaptationData(IMAGE_SIZE, IMAGE_SIZE)
ds_loader = torch.utils.data.DataLoader(data1,
                                        batch_size=BATCH_SIZE,
                                        shuffle=True,
                                        num_workers=4,
                                        drop_last=True)
iters_per_epoch = len(ds_loader)

# Fixed images for visualisation (first batch of a throw-away iterator)
img_fixed = next(iter(ds_loader))[0][:IMGS_TO_DISPLAY].to(device)




# Training Loop for MASCN
e.g., compression ratio (CR) = 0.1

In [None]:
total_iter = 0
# Generator losses are only refreshed every N_CRITIC iterations; seed them so
# the periodic log can print them before the first generator update.
g_label_loss = g_gan_loss = g_clf_loss = g_rec_loss = torch.Tensor([0])
EPOCHS = 100
for epoch in tqdm(range(EPOCHS)):
    gen.train()
    dis.train()

    for i, data in enumerate(ds_loader):
        total_iter += 1

        # Loading data
        real, label, dm = data
        # Per-sample channel SNR in dB (integers in [0, 27]).
        snr = torch.randint(0, 28, (real.shape[0], 1)).to(device)
        real, label, dm = real.to(device), label.to(device), dm.long().to(device)

        # Target domains: a random permutation of the batch's source domains.
        target_dm = dm[torch.randperm(dm.size(0))]

        # Fake images: translated towards target_dm through the channel.
        fake = gen(real, 'AWGN', snr, target_dm)

        # Training discriminator: WGAN-GP + domain and digit-label classification
        real_gan_out, real_cls_out, real_label_out = dis(real)
        fake_gan_out, fake_cls_out, _ = dis(fake.detach())

        d_gan_loss = (-(real_gan_out.mean() - fake_gan_out.mean())
                      + gradient_penalty(real, fake, dis, is_cuda) * GRADIENT_PENALTY)
        d_clf_loss = ce(real_cls_out, dm)
        d_label_loss = ce(real_label_out, label)  # Label classification loss for discriminator

        d_opt.zero_grad()
        d_loss = d_gan_loss + d_clf_loss + d_label_loss
        d_loss.backward()
        d_opt.step()

        # Training Generator (once every N_CRITIC discriminator steps)
        if total_iter % N_CRITIC == 0:
            fake = gen(real, 'AWGN', snr, target_dm)
            fake_gan_out, fake_cls_out, fake_label_out = dis(fake)

            g_gan_loss = - fake_gan_out.mean()
            g_clf_loss = ce(fake_cls_out, target_dm)
            g_label_loss = ce(fake_label_out, label)  # Label classification loss for generator
            # Cycle reconstruction: translate the fake back to its source domain.
            g_rec_loss = (real - gen(fake, 'AWGN', snr, dm)).abs().mean()

            g_opt.zero_grad()
            g_loss = g_gan_loss + g_clf_loss + g_rec_loss + g_label_loss
            g_loss.backward()
            g_opt.step()

        if i % 50 == 0:
            print("\nEpoch: " + str(epoch + 1) + "/" + str(EPOCHS)
                  + " iter: " + str(i+1) + "/" + str(iters_per_epoch)
                  + " total_iters: " + str(total_iter)
                  + "\td_gan_loss:" + str(round(d_gan_loss.item(), 4))
                  + "\td_clf_loss:" + str(round(d_clf_loss.item(), 4))
                  + "\td_label_loss:" + str(round(d_label_loss.item(), 4))
                  + "\tg_gan_loss:" + str(round(g_gan_loss.item(), 4))
                  + "\tg_clf_loss:" + str(round(g_clf_loss.item(), 4))
                  + "\tg_rec_loss:" + str(round(g_rec_loss.item(), 4))
                  + "\tg_label_loss:" + str(round(g_label_loss.item(), 4)))

        if total_iter % 100 == 0:
            generate_imgs(img_fixed, NUM_DOMAINS, gen, samples_path, 'AWGN', total_iter, is_cuda)
            # generate_imgs() calls gen.eval(); switch back so BatchNorm keeps
            # training on batch statistics for the rest of the epoch.
            gen.train()

    # Optional per-epoch checkpoints:
    # torch.save(gen.state_dict(), os.path.join(model_path, 'gen.pkl'))
    # torch.save(dis.state_dict(), os.path.join(model_path, 'dis.pkl'))

torch.save(gen.state_dict(), os.path.join(model_path, 'gen_AFB_100_0.1.pkl'))
torch.save(dis.state_dict(), os.path.join(model_path, 'dis_AFB_100_0.1.pkl'))