# Auxiliary Classifier GANs (ACGANs)

Auxiliary Classifier GANs (ACGANs) [1] are a variant of Generative Adversarial Networks (GANs) in which the generator is conditioned on a class label and the discriminator carries an auxiliary classifier that predicts the class of each image in addition to judging whether it is real or fake. This class signal steers the generator toward the requested class, improving the quality and relevance of the generated images, especially when the synthesis targets specific diseases or conditions.

### Key Features and Applications

1. **Enhancing Image Diversity**:
   - ACGANs address limited diversity in generated images through the auxiliary classifier, which pushes the generator to cover every class. Some variants additionally sample the noise vector from other distributions, such as the heavy-tailed Student's t-distribution, to increase the realism and variety of the generated images.

2. **Conditional Image Synthesis**:
   - In medical imaging, ACGANs are used to generate synthetic images conditioned on specific parameters. For example, in synthesizing MR knee images, ACGANs condition the generation process on acquisition parameters like repetition time, echo time, and image orientation.

3. **Broader Applications**:
   - Beyond medical imaging, ACGANs are employed in various fields:
     - **SAR Target Image Generation**: Integrated into multi-task learning methods to combine pose estimation and class information.
     - **Text-to-Image Synthesis**: The auxiliary classifier recovers side information about generated images, such as class labels, enhancing interpretability and relevance.

4. **Addressing Dataset Imbalance**:
   - ACGANs help in classification tasks with imbalanced data. Class-conditional generation can supply additional samples for under-represented classes, which improves classification performance on skewed datasets. A related example is the generation of minority-class samples, such as iris images, with conditional Wasserstein GANs with gradient penalty.

5. **Medical Imaging Applications**:
   - **Lung Cancer Diagnosis and Neuroimaging**: ACGANs differentiate between real and fake samples while performing classification tasks, enhancing diagnostic capabilities.
   - **Breast Cancer Detection**: Used to generate synthetic mammograms, aiding convolutional neural networks in accurately classifying breast cancer cases.
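
The dataset-imbalance use case in item 4 comes down to asking a class-conditional generator for more samples of the under-represented classes. The sketch below only computes which labels to request; the helper name `labels_to_synthesize` and the toy label list are illustrative assumptions, and the resulting list would be passed to a trained generator as its label input.

```python
from collections import Counter


def labels_to_synthesize(labels, n_classes):
    """Return the class labels to request from a conditional generator so that
    every class reaches the size of the largest class (hypothetical helper)."""
    counts = Counter(labels)
    target = max(counts.values())  # size of the largest class
    requested = []
    for c in range(n_classes):
        deficit = target - counts.get(c, 0)  # samples missing for class c
        requested.extend([c] * deficit)
    return requested


# Toy imbalanced label set: class 0 dominates.
train_labels = [0, 0, 0, 0, 1, 2, 2]
gen_labels = labels_to_synthesize(train_labels, n_classes=3)
print(gen_labels)  # [1, 1, 1, 2, 2]
```

Balancing every class up to the largest one is only one possible policy; a real pipeline might cap the number of synthetic samples per class.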

### Conclusion

Auxiliary Classifier GANs play a vital role in conditional image synthesis by incorporating auxiliary classifiers to provide additional class information to the generator. This improves the quality, diversity, and relevance of generated images. From medical imaging to text-to-image synthesis and addressing dataset imbalances, ACGANs have demonstrated versatility and effectiveness in various applications, making them a valuable tool in the field of image synthesis and classification.

### Intuition

The core idea of ACGAN is to add class information to the standard GAN to improve the diversity and quality of generated images. The generator takes a noise vector and a class label as input and produces an image. The discriminator judges, at the same time, whether an image is real or generated and which class it belongs to. This gives both networks a more specific optimization target and helps the generator produce images that are recognizable members of the requested class.
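
The two tasks of the discriminator can be written as two log-likelihood terms. The notation below is the one commonly used for ACGAN and is introduced here only as a sketch: $S$ is the source of an image (real or fake), $C$ is its class, and $X_\text{real}$, $X_\text{fake}$ are real and generated images.

$$
L_S = \mathbb{E}\left[\log P(S = \text{real} \mid X_\text{real})\right] + \mathbb{E}\left[\log P(S = \text{fake} \mid X_\text{fake})\right]
$$

$$
L_C = \mathbb{E}\left[\log P(C = c \mid X_\text{real})\right] + \mathbb{E}\left[\log P(C = c \mid X_\text{fake})\right]
$$

The discriminator is trained to maximize $L_S + L_C$, while the generator is trained to maximize $L_C - L_S$. Both networks therefore want generated images to be classified correctly, but they pull in opposite directions on the real/fake decision.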


### Detailed Explanation

In [None]:
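# Generator
import torch
import torch.nn as nn

# `opt` is assumed to be the run configuration (for example an argparse
# namespace) providing n_classes, latent_dim, img_size and channels.


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        # One learnable vector of length n_classes per class label
        self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)

        # Spatial size before the two 2x upsampling steps
        self.init_size = opt.img_size // 4
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim + opt.n_classes, 128 * self.init_size ** 2))

        # Note: the second positional argument of BatchNorm2d is eps,
        # so the 0.8 below sets eps, not momentum.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),  # pixel values in [-1, 1]
        )

    def forward(self, noise, labels):
        # Concatenate label embedding and noise along the feature dimension
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        out = self.l1(gen_input)
        # Reshape the flat vector into a (128, init_size, init_size) feature map
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img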

1. `self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)`:
   - This layer embeds the class labels into a continuous vector space so that the labels can be used as input to the generator.

2. `self.l1 = nn.Sequential(nn.Linear(opt.latent_dim + opt.n_classes, 128 * self.init_size ** 2))`:
   - The first layer of the generator is a fully connected layer. Its input is the concatenation of the noise vector and the label embedding, and its output is a high-dimensional vector that is then reshaped into a feature map.

3. `self.conv_blocks = nn.Sequential(...)`:
   - These convolution blocks gradually transform the feature map into the final image through upsampling and convolution operations.

4. `def forward(self, noise, labels):`:
   - In the forward pass, the label embedding and the noise vector are concatenated, and the image is then generated by the fully connected layer followed by the convolution blocks.
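
To make the data flow through the generator concrete, the following sketch traces tensor shapes (without the batch dimension) in plain Python, so it runs without PyTorch. The function name `generator_shape_trace` and the example values (32x32 grayscale images, 100-dimensional noise, 10 classes) are assumptions chosen for illustration.

```python
def generator_shape_trace(img_size, channels, latent_dim, n_classes):
    """List the per-sample tensor shape after each stage of the generator above."""
    init_size = img_size // 4
    return [
        ("gen_input", (latent_dim + n_classes,)),
        ("l1 output", (128 * init_size ** 2,)),
        ("reshaped feature map", (128, init_size, init_size)),
        # Upsample doubles height and width; the 3x3 convs with padding=1 keep them
        ("first upsample + conv", (128, init_size * 2, init_size * 2)),
        ("second upsample + conv", (64, init_size * 4, init_size * 4)),
        ("output image", (channels, init_size * 4, init_size * 4)),
    ]


for name, shape in generator_shape_trace(img_size=32, channels=1, latent_dim=100, n_classes=10):
    print(f"{name:24s} {shape}")
```

Because `init_size` uses integer division, the output matches `img_size` only when `img_size` is divisible by 4; with `img_size=30`, for example, the generated image would be 28x28.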

In [None]:
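# Discriminator
import numpy as np
import torch
import torch.nn as nn

# Shape of the images produced by the generator: (channels, height, width).
# Assumes opt.img_size is divisible by 4.
img_shape = (opt.channels, opt.img_size, opt.img_size)


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)

        # Shared trunk: maps (flattened image + label embedding) to 512 features
        self.model = nn.Sequential(
            nn.Linear(opt.n_classes + int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Adversarial head: one real/fake score per image
        self.adv_layer = nn.Linear(512, 1)
        # Auxiliary head: class probabilities (softmax over the class dimension)
        self.aux_layer = nn.Sequential(nn.Linear(512, opt.n_classes), nn.Softmax(dim=1))

    def forward(self, img, labels):
        d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1)
        features = self.model(d_in)
        validity = self.adv_layer(features)
        label = self.aux_layer(features)
        return validity, label


1. `self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)`:
   - Embeds the class label into a continuous vector space so that it can be concatenated with the flattened image.

2. `self.model = nn.Sequential(...)`:
   - The shared trunk of the discriminator is a fully connected network. Its input is the concatenation of the flattened image vector and the label embedding vector, and its output is a 512-dimensional feature vector.

3. `self.adv_layer = nn.Linear(512, 1)`:
   - The adversarial head, which turns the shared features into a single score used to judge whether the image is real or generated.

4. `self.aux_layer = nn.Sequential(nn.Linear(512, opt.n_classes), nn.Softmax(dim=1))`:
   - The auxiliary classifier, which predicts the class label of the image. Its input is the feature vector from the shared trunk, not the raw input, so its `nn.Linear(512, ...)` layer receives the 512 features it expects.

5. `def forward(self, img, labels):`:
   - In the forward pass, the flattened image vector is concatenated with the label embedding vector and passed through the shared trunk. The adversarial head then scores the authenticity of the image and the auxiliary classifier predicts its class label.

Because the label embedding is part of the discriminator input here, the auxiliary classifier can read the label directly from its input. In the original ACGAN formulation the discriminator receives only the image and has to infer the class from it.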
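
As a small worked example of what the auxiliary classifier outputs, the sketch below applies a softmax to one row of class scores and reads off the predicted class. It is written in plain Python so it runs without PyTorch; the helper names and the example scores are assumptions made for illustration.

```python
import math


def softmax(scores):
    """Turn raw class scores into probabilities that sum to 1."""
    m = max(scores)  # subtract the maximum for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def predicted_class(probs):
    """Index of the most probable class."""
    return max(range(len(probs)), key=probs.__getitem__)


def aux_accuracy(batch_probs, labels):
    """Fraction of samples whose predicted class matches the label."""
    hits = sum(predicted_class(p) == y for p, y in zip(batch_probs, labels))
    return hits / len(labels)


# Hypothetical class scores for two images and three classes
batch_scores = [[0.5, 2.0, -1.0], [1.5, 0.0, 0.2]]
batch_probs = [softmax(s) for s in batch_scores]
print([predicted_class(p) for p in batch_probs])  # [1, 0]
print(aux_accuracy(batch_probs, labels=[1, 2]))   # 0.5
```

Tracking this accuracy on real and generated batches is one way to check whether the generator is producing images of the requested class.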

In [None]:
# Training
# Assumes generator, discriminator, optimizer_G, optimizer_D, dataloader,
# adversarial_loss and auxiliary_loss are defined elsewhere.
import torch

# Create all tensors on the same device as the model
device = next(generator.parameters()).device

for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):
        batch_size = imgs.size(0)

        # Adversarial ground truths
        valid = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)

        # Configure input
        real_imgs = imgs.to(device=device, dtype=torch.float32)
        labels = labels.to(device=device, dtype=torch.long)

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise and labels as generator input
        z = torch.randn(batch_size, opt.latent_dim, device=device)
        gen_labels = torch.randint(0, opt.n_classes, (batch_size,), device=device)

        # Generate a batch of images
        gen_imgs = generator(z, gen_labels)

        # Loss measures generator's ability to fool the discriminator
        # and to produce images of the requested class
        validity, pred_label = discriminator(gen_imgs, gen_labels)
        g_loss = 0.5 * adversarial_loss(validity, valid) + 0.5 * auxiliary_loss(pred_label, gen_labels)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Loss for real images
        real_pred, real_aux = discriminator(real_imgs, labels)
        d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2

        # Loss for fake images (detach so no gradient flows into the generator)
        fake_pred, fake_aux = discriminator(gen_imgs.detach(), gen_labels)
        d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

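1. Generator loss (`g_loss`):
   - The generator loss is the equally weighted sum of two terms: the adversarial loss, which is low when the discriminator scores generated images as real, and the auxiliary loss, which is low when generated images are classified as the label they were conditioned on.

2. Discriminator loss (`d_loss`):
   - The discriminator loss averages the loss on real images and the loss on generated images. Each of these combines the adversarial loss (real images scored as real, generated images scored as fake) with the auxiliary classification loss on the corresponding labels.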
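
The weighting used in the training loop can be checked with plain numbers. The sketch below repeats only the arithmetic of `g_loss` and `d_loss`; the function names and the individual loss values are made up for illustration and do not come from a real training run.

```python
def generator_loss(adv, aux):
    """Equal weighting of adversarial and auxiliary terms, as in g_loss."""
    return 0.5 * adv + 0.5 * aux


def discriminator_loss(adv_real, aux_real, adv_fake, aux_fake):
    """Average over real and fake batches, as in d_loss."""
    d_real = (adv_real + aux_real) / 2
    d_fake = (adv_fake + aux_fake) / 2
    return (d_real + d_fake) / 2


# Hypothetical per-batch loss values
g = generator_loss(adv=0.8, aux=1.2)
d = discriminator_loss(adv_real=0.4, aux_real=0.6, adv_fake=0.2, aux_fake=0.8)
print(round(g, 6), round(d, 6))
```

Expanding `discriminator_loss` shows that each of its four terms ends up with weight 0.25, so the real/fake decision and the class prediction contribute equally to the discriminator update.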