In [1]:
import math
import torch


class Lambda(torch.nn.Module):
    """Adapt an arbitrary callable into an ``nn.Module``.

    Useful for dropping a small tensor transform into ``torch.nn.Sequential``.
    The callable is kept on ``self.lam`` and invoked on the input in
    ``forward``. Note: if ``lam`` is a lambda or closure, the module cannot be
    pickled (e.g. by ``torch.save(model)``).
    """

    def __init__(self, lam):
        super().__init__()
        self.lam = lam

    def forward(self, x):
        """Return ``self.lam(x)``."""
        fn = self.lam
        return fn(x)

class Generator(torch.nn.Module):
    """DCGAN-style generator mapping Gaussian noise to 32x32 images.

    Args:
        resize: Accepted for signature parity with ``Discriminator``; it is not
            used. The transposed-conv stack always produces 32x32 outputs
            (1 -> 4 -> 8 -> 16 -> 32).
        in_channels: Number of channels of the *generated* images.
        hidden_size: Base channel width; the stack narrows from
            ``8 * hidden_size`` down to ``hidden_size`` channels.
        latent_dim: Length of the noise vector fed to the first layer
            (default 100, the value previously hard-coded).
    """

    def __init__(self, resize, in_channels, hidden_size, latent_dim=100):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.features = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, hidden_size * 8),
            # (batch, C) -> (batch, C, 1, 1). A parameterless built-in module in
            # the same Sequential slot as the former lambda wrapper, so
            # state_dict keys are unchanged; unlike a lambda it can be pickled
            # by torch.save(model).
            torch.nn.Unflatten(1, (hidden_size * 8, 1, 1)),
            torch.nn.ConvTranspose2d(hidden_size * 8, hidden_size * 4, 4, 1, 0, bias=False),  # 1 -> 4
            torch.nn.BatchNorm2d(hidden_size * 4),
            torch.nn.ReLU(True),
            torch.nn.ConvTranspose2d(hidden_size * 4, hidden_size * 2, 4, 2, 1, bias=False),  # 4 -> 8
            torch.nn.BatchNorm2d(hidden_size * 2),
            torch.nn.ReLU(True),
            torch.nn.ConvTranspose2d(hidden_size * 2, hidden_size, 4, 2, 1, bias=False),  # 8 -> 16
            torch.nn.BatchNorm2d(hidden_size),
            torch.nn.ReLU(True),
            torch.nn.ConvTranspose2d(hidden_size, in_channels, 4, 2, 1, bias=False),  # 16 -> 32
            torch.nn.Tanh(),
        )

    def forward(self, x):
        """Generate one image per element of ``x``'s batch.

        Only ``x.size(0)`` and ``x.device`` are used; the contents of ``x`` are
        ignored. Fresh N(0, 1) noise is drawn on every call.

        Returns:
            Tensor of shape ``(batch, in_channels, 32, 32)`` with values in
            ``[-1, 1]`` (Tanh output).
        """
        # Sample directly on the target device (no host-to-device copy) and in
        # the generator's parameter dtype so half/double models also work.
        noise = torch.randn(
            x.size(0),
            self.latent_dim,
            device=x.device,
            dtype=self.features[0].weight.dtype,
        )
        return self.features(noise)


class Discriminator(torch.nn.Module):
    """Convolutional discriminator with an auxiliary classification head.

    Three stride-2 convolutions downsample the input. A single linear layer
    then emits ``1 + num_classes`` scores, which ``forward`` splits into a
    real/fake probability and a class distribution.

    Args:
        resize: Side length of the (square) input images.
        in_channels: Number of input image channels.
        hidden_size: Channel width of the first conv (doubled at each stage).
        num_classes: Number of classes for the auxiliary classifier.
    """

    def __init__(self, resize, in_channels, hidden_size, num_classes):
        super(Discriminator, self).__init__()
        # Each Conv2d(kernel=3, stride=2, padding=1) maps side s to
        # floor((s + 2 - 3) / 2) + 1 == ceil(s / 2), and there are three of
        # them. The Linear input size must use this downsampled side; using
        # resize**2 makes the Linear expect far more features than Flatten
        # produces (e.g. 32 -> 16 -> 8 -> 4).
        side = resize
        for _ in range(3):
            side = (side + 1) // 2
        proj_size = side ** 2
        self.features = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, hidden_size, 3, 2, 1, bias=False),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Dropout(0.5, inplace=False),
            torch.nn.Conv2d(hidden_size, hidden_size * 2, 3, 2, 1, bias=False),
            torch.nn.BatchNorm2d(hidden_size * 2),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Dropout(0.5, inplace=False),
            torch.nn.Conv2d(hidden_size * 2, hidden_size * 4, 3, 2, 1, bias=False),
            torch.nn.BatchNorm2d(hidden_size * 4),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Dropout(0.5, inplace=False),
            torch.nn.Flatten(),
            torch.nn.Linear(proj_size * hidden_size * 4, 1 + num_classes)
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: Image batch, expected as ``(batch, in_channels, resize, resize)``.

        Returns:
            Tuple ``(gen_disc, cls_disc)``:
            - ``gen_disc``: sigmoid probability of shape ``(batch, 1)``.
            - ``cls_disc``: softmax class distribution of shape
              ``(batch, num_classes)``.

            NOTE(review): these are probabilities, not logits. Confirm that the
            loss used downstream does not re-apply sigmoid/softmax.
        """
        logits = self.features(x)
        gen_logit, cls_logits = logits[:, :1], logits[:, 1:]
        return torch.sigmoid(gen_logit), torch.softmax(cls_logits, dim=-1)

class ACGAN(torch.nn.Module):
    """Bundle a ``Generator`` and an auxiliary-classifier ``Discriminator``.

    ``forward`` serves both training phases, selected by ``for_D``.

    NOTE(review): the generator is not given class labels, so the generated
    images are unconditioned; confirm this is intended for an AC-GAN.
    """

    def __init__(self, resize, in_channels, hidden_size, num_classes):
        super(ACGAN, self).__init__()
        self.generator = Generator(resize, in_channels, hidden_size)
        self.discriminator = Discriminator(resize, in_channels, hidden_size, num_classes)

    def forward(self, x, for_D=True, return_fake=False):
        """Run one discriminator- or generator-phase pass.

        Args:
            x: Batch of real images. Its batch size and device determine the
               fake batch.
            for_D: If True (discriminator phase), score detached fakes and the
               real images in one concatenated discriminator call. If False
               (generator phase), score fakes with gradients flowing back into
               the generator.
            return_fake: If True, also return the generated images (detached).

        Returns:
            ``for_D=True``: ``(disc_fake, disc_real, clf_fake, clf_real, fake)``.
            ``for_D=False``: ``(disc_fake, clf_fake, fake)``.
            In both cases ``fake`` is ``None`` unless ``return_fake``.
        """
        if for_D:
            # Fakes are only ever used detached in this phase, so no gradient
            # can reach the generator. Skipping graph construction saves memory
            # and compute without changing any result.
            with torch.no_grad():
                x_fake = self.generator(x)
            n_fake = x_fake.size(0)
            # Real and fake images share one discriminator batch (and hence
            # one set of BatchNorm batch statistics).
            disc_res, clf_res = self.discriminator(torch.cat([x_fake, x], dim=0))
            disc_fake, disc_real = disc_res[:n_fake], disc_res[n_fake:]
            clf_fake, clf_real = clf_res[:n_fake], clf_res[n_fake:]
            fake_out = x_fake.detach() if return_fake else None
            return disc_fake, disc_real, clf_fake, clf_real, fake_out

        x_fake = self.generator(x)
        disc_fake, clf_fake = self.discriminator(x_fake)
        fake_out = x_fake.detach() if return_fake else None
        return disc_fake, clf_fake, fake_out

In [6]:
# Build a discriminator for 32x32, 3-channel inputs with 10 classes; the cell displays its layer stack.
Discriminator(resize=32, in_channels=3, hidden_size=16, num_classes=10)

In [2]:
# Sanity check: report whether PyTorch can see a CUDA device (recorded output: True).
torch.cuda.is_available()

True

In [11]:
import torch
from pytorch_fid.inception import InceptionV3

In [17]:
from pytorch_fid.fid_score import calculate_frechet_distance, calculate_activation_statistics