## Repurposing Hao L, Shen P, Pan Z and Xu Y (2023) ACGAN algorithm for synthetic data generation for few-shot classification
The main contributions of this paper (doi: 10.3389/fphy.2023.1208781) are as follows:

• The residual module is introduced into the network structure of adversarial learning to contribute to the feature extraction. Moreover, multiple convolutional layers are employed in the model architecture to replace the original classification layer, further boosting the classification performance of the model.

• A multi-level semantic feature extractor (MSFE) which effectively extracts features at different levels is designed, fully capturing diverse semantic information of images to guide the generator in sample generation and improve the quality of generated samples.

• The proposed method can generate high-quality samples to compensate for the deficiencies under few-shot conditions, further improving the classification performance.

> Gerador
- Camada totalmente conectada para mapear o vetor z e o rótulo c para uma dimensão que possa ser redimensionada para uma forma que funcione como entrada para as convoluções transpostas.
- Operações de upsampling antes das duas primeiras camadas convolucionais.
- Camadas convolucionais transpostas para aumentar gradativamente a resolução da imagem gerada.
- Módulos residuais inseridos entre as camadas convolucionais.
- Normalização em lote e Leaky-ReLU após cada convolução.
- Camada de saída para produzir a imagem com a resolução desejada (96x96x3 neste caso).

> Discriminador
- Camadas convolucionais para extrair características das imagens de entrada.
- Módulos residuais após a primeira camada convolucional.
- Dropout após as convoluções para prevenir sobreajuste.
- Leaky-ReLU após cada convolução.
- Uma camada final para classificar as imagens como reais ou falsas, e possivelmente outra camada para a classificação auxiliar (contendo ou não parasitas).

In [2]:
from CNN_Net import Net as CustomNet
import os
from pathlib import Path
import datetime
CUDA_LAUNCH_BLOCKING = "1"
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
import seaborn as sns
import torchvision.models as models


In [3]:
IMG_SIZE = (96, 96)
CHANNELS = 1
Z_DIM = 128  # dimensão do espaço latente
NUM_CLASSES = C_DIM = 2   # dimensão do vetor de condição (rótulos)
NUM_EPOCHS = 10000
LEARNING_RATE = 0.0002
B1, B2 = 0.5, 0.999
BATCH_SIZE = 64
N_CRITIC = 5 # num vezes discriminator é treinado por iteração do generator
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

EXP_NAME = 'v1-acgan'
now = datetime.datetime.now()
TIMESTAMP = now.strftime("%Y-%m-%d")
MODEL_OUT_DIR = Path(f'models/{EXP_NAME}_{TIMESTAMP}')
# MODEL_OUT_DIR.mkdir(parents=True, exist_ok=False)

IMAGE_PATH = Path('../data/all-patches/') # 2139 leish | 2803 no-leish

#### Feature extractor from previous researches (CNN tri-set with different losses)

In [4]:
# class CustomFeatureExtractor(nn.Module):
#     '''
#     Use previously trained CNN feature extractor
#     '''
#     def __init__(self, pretrained_model_path):
#         super(CustomFeatureExtractor, self).__init__()
#         self.pretrained_model = CustomNet(Z_DIM, (1, *IMG_SIZE)).to(DEVICE)
#         self.pretrained_model.load_state_dict(torch.load(pretrained_model_path))

#         for param in self.pretrained_model.parameters():
#             param.requires_grad = False

#         self.pretrained_model.fc1 = nn.Identity() # remove a fc e a operação de achatamento

#     def forward(self, x):
#         # passa x apenas pelas convlayers e de pooling
#         x = self.pretrained_model.conv1(x)
#         x = F.relu(x)
#         x = self.pretrained_model.conv2(x)
#         x = F.relu(x)
#         x = F.max_pool2d(x, 2)
#         x = self.pretrained_model.dropout1(x)
#         return x

#### Class for residual blocks

In [5]:
# class GeneratorResidualBlock(nn.Module):
#     def __init__(self, channels, feature_extractor):
#         super(GeneratorResidualBlock, self).__init__()
#         self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
#         self.bn1 = nn.BatchNorm2d(channels)
#         self.leaky_relu = nn.LeakyReLU(negative_slope=0.01, inplace=True)

#         self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
#         self.bn2 = nn.BatchNorm2d(channels)

#         self.feature_extractor = feature_extractor
#         self.feature_extractor.eval()

#     def forward(self, x):
#         with torch.no_grad():
#             extracted_features = self.feature_extractor(x)

#         residual = x
#         out = self.leaky_relu(self.bn1(self.conv1(x)))
#         out = self.bn2(self.conv2(out))
#         out += residual + extracted_features
#         out = self.leaky_relu(out)
#         return out

class DiscriminatorResidualBlock(nn.Module):
    def __init__(self, channels):
        super(DiscriminatorResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.01, inplace=True)

        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = self.leaky_relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        out = self.leaky_relu(out)
        return out

#### Class for generator and discriminator

In [6]:
class Generator(nn.Module):
    def __init__(self, z_dim, c_dim, img_size, channels):
        super(Generator, self).__init__()
        self.img_size = img_size
        self.img_channels = channels

        self.fc = nn.Linear(z_dim + c_dim, 128 * (self.img_size[0] // 4) * (self.img_size[1] // 4))
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, channels, kernel_size=3, stride=1, padding=1)
        self.res_blocks = nn.Sequential(
            DiscriminatorResidualBlock(64),
            DiscriminatorResidualBlock(64),
            DiscriminatorResidualBlock(64)
        )
        self.bn1 = nn.BatchNorm2d(128)
        self.bn2 = nn.BatchNorm2d(64)
        self.tanh = nn.Tanh()

    def forward(self, z, c):
        y_one_hot = torch.nn.functional.one_hot(c, num_classes=2).float()  # [batch_size, c_dim]
        y_one_hot = y_one_hot.view(64, 2, 1, 1)
        x = torch.cat([z, y_one_hot], 1)
        x = x.view(64, -1)
        x = self.fc(x)
        x = x.view(-1, 128, self.img_size[0]//4, self.img_size[1]//4)
        x = self.upsample(x)
        x = F.leaky_relu(self.bn1(self.conv1(x)))
        x = self.upsample(x)
        x = F.leaky_relu(self.bn2(self.conv2(x)))
        x = self.res_blocks(x)
        x = self.tanh(self.conv3(x))
        return x

class Discriminator(nn.Module):
    def __init__(self, channels, num_classes):
        super(Discriminator, self).__init__()

        self.conv1 = nn.Conv2d(channels, 64, kernel_size=3, stride=2, padding=1)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.01)
        self.dropout = nn.Dropout(0.3)
        self.res_blocks = nn.Sequential(
            DiscriminatorResidualBlock(64),
            DiscriminatorResidualBlock(64),
            DiscriminatorResidualBlock(64)
        )
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
        self.conv5 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1)
        self.conv6 = nn.Conv2d(1024, 2048, kernel_size=3, stride=2, padding=1)

        self.fc_real_fake = nn.Linear(2048 * 2 * 2, 1)
        self.fc_clf = nn.Linear(2048 * 2 * 2, 1)

    def forward(self, x):
        x = self.dropout(self.leaky_relu(self.conv1(x)))
        x = self.res_blocks(x)
        x = self.dropout(F.leaky_relu(self.conv2(x)))
        x = self.dropout(F.leaky_relu(self.conv3(x)))
        x = self.dropout(F.leaky_relu(self.conv4(x)))
        x = self.dropout(F.leaky_relu(self.conv5(x)))
        x = self.dropout(F.leaky_relu(self.conv6(x)))
        x = x.view(-1, 2048 * 2 * 2)
        real_fake_logits = self.fc_real_fake(x)
        clf_logits = self.fc_clf(x)

        return real_fake_logits, clf_logits

#### Loss functions

In [7]:
def compute_gradient_penalty(discriminator, real_samples, fake_samples):
    '''
    Calcula a penalidade de gradiente para o termo de regularização. Garante que os gradientes não cresçam
    demais.
    '''
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=real_samples.device)
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates, _ = discriminator(interpolates)
    fake = torch.ones(d_interpolates.size(), device=real_samples.device, requires_grad=False)

    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    gradients = gradients.view(gradients.size(0), -1)

    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() # norma L2 dos gradientes

def discriminator_loss(real_images, fake_images, real_labels, discriminator, gp_weight=10):
    real_validity, real_class_logits = discriminator(real_images)
    fake_validity, _ = discriminator(fake_images.detach())

    # Wasserstein = diferença entre as médias das saídas para imagens reais e falsas
    wasserstein_loss = fake_validity.mean() - real_validity.mean()
    gradient_penalty = compute_gradient_penalty(discriminator, real_images.data, fake_images.data)
    clf_loss = F.cross_entropy(real_class_logits, real_labels)

    return wasserstein_loss + gp_weight * gradient_penalty + clf_loss

In [8]:
def generator_loss(discriminator, fake_images, labels, criterion):
    validity, predicted_labels = discriminator(fake_images)
    ls_loss = -torch.mean(validity) # negando para incentivar a maximizar o erro do discrim.
    labels = labels.unsqueeze(1)
    lc_loss = criterion(predicted_labels, labels.float()) # cross entropy loss

    return lc_loss + ls_loss # L(G) = L_c - (- L_s(G)) fórmula 15

#### Data loading and class instanciation

Read data function

In [9]:
def read_data():
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((96,96)),
        transforms.ToTensor(),
    ])
    dataset = datasets.ImageFolder(root=IMAGE_PATH, transform=transform)

    train_size = int(2/3 * len(dataset))
    test_size = len(dataset) - train_size

    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    leish_train = sum(label == 0 for _, label in train_dataset)
    leish_test = sum(label == 0 for _, label in test_dataset)

    print(f'label format = {dataset.class_to_idx}')
    print(f'train test split proportion = train[{len(train_dataset)}], test[{len(test_dataset)}]')
    print(f'leish in training set = {leish_train}')
    print(f'leish in testing set = {leish_test}')

    return train_loader, test_loader

In [10]:
# triplet = CustomFeatureExtractor(
#     pretrained_model_path='./models/v1-ensble_2023-11-07\model_Triplet_v1-ensble_2023-11-07.pth'
# )
# cosface = CustomFeatureExtractor(
#     pretrained_model_path='./models/v1-ensble_2023-11-07\model_CosFace_v1-ensble_2023-11-07.pth'
# )
# multisim = CustomFeatureExtractor(
#     pretrained_model_path='./models/v1-ensble_2023-11-07\model_MultiSimilarity_v1-ensble_2023-11-07.pth'
# )

# feature_extractors = nn.ModuleList([triplet, cosface, multisim])
generator = Generator(Z_DIM, C_DIM, IMG_SIZE, CHANNELS).to(DEVICE)
discriminator = Discriminator(CHANNELS, NUM_CLASSES).to(DEVICE)

GEN_OPTIM = torch.optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(B1, B2))
DISC_OPTIM = torch.optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(B1, B2))
CRITERION = torch.nn.BCELoss()

train_loader, test_loader = read_data()

label format = {'leish': 0, 'no-leish': 1}
train test split proportion = train[3294], test[1648]
leish in training set = 1444
leish in testing set = 695


#### Train function
<ul>
    <li>Iterar sobre as épocas.</li>
    <li>Dentro de cada época, iterar sobre os dados no DataLoader.</li>
    <li>Gerar amostras de ruído, labels, e produzir imagens falsas com o gerador.</li>
    <li>Alimentar imagens reais e falsas no discriminador para obter os scores.</li>
    <li>Calcular a perda do discriminador, incluindo a Wasserstein loss, a penalidade de gradiente e a perda de classificação.</li>
    <li>Atualizar os parâmetros do discriminador.</li>
    <li>Calcular a perda do gerador e atualizar seus parâmetros.</li>
    <li>Registrar as perdas para acompanhamento e salvar as imagens geradas periodicamente.</li>
</ul>

In [11]:
def train(
        generator,
        discriminator,
        data_loader,
        gen_optim=GEN_OPTIM,
        disc_optim=DISC_OPTIM,
        criterion=CRITERION,
        z_dim=Z_DIM,
        c_dim=C_DIM,
        num_epochs=NUM_EPOCHS,
        device=DEVICE,
        save_image_interval=1000):
    gen_losses, disc_losses = [], []
    size_dl = len(data_loader)
    generator.train()
    discriminator.train()

    for epoch in tqdm(range(num_epochs), desc='EPOCHS'):
        for i, (real_images, labels) in enumerate(data_loader):
            real_images, labels = real_images.to(device), labels.to(device)
            batch_size = real_images.size(0)

            z = torch.randn(batch_size, z_dim, 1, 1, device=device)
            fake_labels = torch.randint(0, c_dim, (batch_size,), device=device)

            fake_images = generator(z, fake_labels)

            disc_optim.zero_grad()
            d_loss = discriminator_loss(real_images, fake_images, labels, discriminator)
            d_loss.backward()
            disc_optim.step()
            disc_losses.append(d_loss.item())

            gen_optim.zero_grad()
            g_loss = generator_loss(discriminator, fake_images, fake_labels, criterion)
            g_loss.backward()
            gen_optim.step()
            gen_losses.append(g_loss.item())

            if i % 50 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}] : Step [{i+1}/{size_dl}] '
                      f'>> discriminator loss = {d_loss.item()}'
                      f'| generator loss = {g_loss.item()}')

            if epoch % save_image_interval==0 and i == size_dl-1:
                with torch.no_grad():
                    generator.eval()
                    test_z = torch.randn(5, z_dim, 1, 1, device=device)
                    test_labels = torch.randint(0, c_dim, (5,), device=device)
                    test_images = generator(test_z, test_labels)
                    test_pred_labels = discriminator(test_images)[1]
                    generator.train()

                    for j, image in enumerate(test_images):
                        predicted_classes = torch.argmax(test_pred_labels, dim=1)
                        save_image(image, f'./generated_images/label_{predicted_classes}_epoch_{epoch}_image_{j}.png', normalize=True)

    torch.save(generator.state_dict(), MODEL_OUT_DIR.joinpath(f'generator_{EXP_NAME}_{TIMESTAMP}.pth'))
    torch.save(discriminator.state_dict(), MODEL_OUT_DIR.joinpath(f'discriminator_{EXP_NAME}_{TIMESTAMP}.pth'))

    return gen_losses, disc_losses

Start trainning

In [12]:
gen_losses, disc_losses = train(generator, discriminator, train_loader)

EPOCHS:   0%|          | 0/10000 [00:06<?, ?it/s]


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


#### Evaluation

Generated image quality: visual evaluation, analize generated_images folder

Classification quality:

In [None]:
def evaluate_acgan(discriminator, dataloader, device):
    discriminator.eval()
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            _, class_logits = discriminator(images)
            preds = torch.argmax(class_logits, dim=1)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    # accuracy = accuracy_score(all_labels, all_preds)
    # precision = precision_score(all_labels, all_preds)
    # recall = recall_score(all_labels, all_preds)
    # f1 = f1_score(all_labels, all_preds)
    conf_matrix = confusion_matrix(all_labels, all_preds)
    class_report = classification_report(all_labels, all_preds, target_names=['Leish', 'No-leish'])

    # print(f"Acurácia: {accuracy}")
    # print(f"Precisão: {precision}")
    # print(f"Recall: {recall}")
    # print(f"F1-Score: {f1}")
    print(class_report)

    sns.heatmap(conf_matrix, annot=True, fmt='d')
    plt.xlabel('Predições')
    plt.ylabel('Verdadeiros')
    plt.show()

    discriminator.train()

In [None]:
evaluate_acgan(discriminator, dataloader, DEVICE)