In [1]:
import torch
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import confusion_matrix
from torch.optim import Adam, lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import random_split, DataLoader

from config import Config
from caae import *
from load_dataset import *

In [2]:
# Shared TensorBoard logger used by the evaluation code below.
# log_dir=None is the constructor's default, so this is equivalent to SummaryWriter().
writer = SummaryWriter(log_dir=None)

In [3]:
def gradient_penalty(real_samples, g_samples, discriminator):
    """
    Calculate the gradient penalty term of the WGAN-GP loss.

    Points are sampled uniformly on straight lines between paired real and
    generated samples. The penalty is the mean over the batch of
    (||grad_x D(x)||_2 - 1)^2 at those points, where the norm is taken over
    all non-batch dimensions of each sample.

    :param real_samples: Real data samples, shape (batch, ...).
    :param g_samples: Generated samples, same shape as real_samples.
    :param discriminator: Callable mapping a batch of samples to scores.
    :return: Scalar tensor with the gradient penalty. It is built with
        create_graph=True, so it can be added to a loss and backpropagated.
    """
    batch_size = real_samples.size(0)

    # One mixing coefficient per sample, broadcast over every non-batch dimension.
    # Allocated on the inputs' device/dtype so the function also works on GPU.
    alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_samples.device, dtype=real_samples.dtype)

    # Points on the segment between each real sample and its generated counterpart.
    # NOTE(review): g_samples is not detached, so if it is attached to a generator's
    # graph the penalty's gradient also flows back into that generator; detach at the
    # call site if only the discriminator should be penalised.
    interpolates = alpha * real_samples + (1 - alpha) * g_samples
    interpolates.requires_grad_(True)  # we need d(output)/d(interpolates)

    d_interpolates = discriminator(interpolates)

    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),  # same shape/device/dtype as the scores
        create_graph=True,  # keep the graph so the penalty itself is differentiable
        retain_graph=True,
    )[0]

    # Flatten each sample's gradient so the L2 norm covers all of its features
    # (identical to summing over dim=1 for 2-D inputs).
    gradients = gradients.reshape(batch_size, -1)
    slopes = torch.sqrt(torch.sum(gradients ** 2, dim=1))
    return torch.mean((slopes - 1.0) ** 2)

In [4]:
def _safe_div(numerator, denominator):
    """Return numerator / denominator as a float, or NaN when the denominator is 0."""
    return numerator / denominator if denominator else float('nan')


def evaluate(y_true, y_pred, epoch, summary_writer=None, labels=None):
    """
    Print and log binary-classification test metrics for one epoch.

    Counts come from sklearn's 2x2 confusion matrix, whose ``ravel()`` order is
    (tn, fp, fn, tp). The second label in the label order is the positive class.

    :param y_true: Ground-truth labels (array-like).
    :param y_pred: Predicted labels (array-like), same length as y_true.
    :param epoch: Global step used for the TensorBoard scalars.
    :param summary_writer: Object with ``add_scalar(tag, value, step)``.
        Defaults to the notebook-level ``writer``.
    :param labels: Optional ``[negative, positive]`` label order for
        ``confusion_matrix``. If omitted and every observed label is 0 or 1,
        ``[0, 1]`` is used so a batch containing only one class still yields a
        2x2 matrix. Otherwise sklearn infers the labels from the data.
    :return: Dict of the computed metrics. A metric whose denominator is zero
        is NaN.
    :raises ValueError: If the confusion matrix is not 2x2 (not a binary problem).
    """
    if summary_writer is None:
        summary_writer = writer

    if labels is None:
        observed = set(np.unique(np.concatenate([np.ravel(y_true), np.ravel(y_pred)])).tolist())
        if observed <= {0, 1}:
            # Pin both classes so a single-class batch doesn't collapse to a 1x1 matrix.
            labels = [0, 1]

    cm = confusion_matrix(y_true, y_pred, labels=labels)
    if cm.shape != (2, 2):
        raise ValueError(f'evaluate() expects a binary problem, got a {cm.shape} confusion matrix')
    tn, fp, fn, tp = (int(v) for v in cm.ravel())

    fnr = _safe_div(fn, tp + fn)
    err = _safe_div(fn + fp, tp + tn + fp + fn)
    precision = _safe_div(tp, tp + fp)
    recall = 1 - fnr
    f1score = _safe_div(2 * precision * recall, precision + recall)

    print('---Test results-------')

    # Confusion matrix laid out with the positive class first: [[tp, fn], [fp, tn]].
    print(tp, fn)
    print(fp, tn)
    print('False negative rate: ', fnr)
    print('Error rate: ', err)
    print('Precision: ', precision)
    print('Recall: ', recall)
    print('F1 score: ', f1score)

    metrics = {
        'False_negative_rate': fnr,
        'Error_rate': err,
        'Precision': precision,
        'Recall': recall,
        'F1_score': f1score,
    }
    for name, value in metrics.items():
        summary_writer.add_scalar(f'test/{name}', value, epoch)
    return metrics

In [None]:
# Instantiate the CAAE networks (class names presumably come from `from caae import *` — confirm).
# NOTE(review): the meaning of the 'g' / 'c' tags is defined by Discriminator; check its source.
# Construction order (encoder, decoder, 'g', 'c') is kept in case the constructors draw
# from the global RNG for weight initialisation.
encoder, decoder = Encoder(), Decoder()
discriminator_g, discriminator_c = (Discriminator(tag) for tag in ('g', 'c'))