In [None]:
class Upsample(nn.Module):
    """Module wrapper around ``F.interpolate`` with bilinear resampling.

    Args:
        scale_factor: Value forwarded unchanged as the ``scale_factor``
            argument of ``F.interpolate``.
    """

    def __init__(self, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        """Return ``x`` resized by ``self.scale_factor`` (bilinear, no corner alignment)."""
        resize_options = dict(
            scale_factor=self.scale_factor,
            mode='bilinear',
            align_corners=False,
        )
        return F.interpolate(x, **resize_options)

In [None]:
class Generator(nn.Module):
    """Generator mapping a (noise, label, code) vector to an image.

    The three inputs are concatenated, projected by a linear stem to a
    ``128 x (img_size // 4) x (img_size // 4)`` feature map, upsampled twice
    by a factor of 2 and passed through ``tanh``. The output is therefore
    ``img_size x img_size`` only when ``img_size`` is divisible by 4.

    Args:
        classes: Width of the label vector concatenated to the input.
        channels: Number of channels of the generated image.
        img_size: Target side length of the generated image.
        latent_dim: Width of the noise vector.
        code_dim: Width of the code vector.
    """

    def __init__(self, classes, channels, img_size, latent_dim, code_dim):
        super().__init__()
        self.classes = classes
        self.channels = channels
        self.img_size = img_size
        self.img_init_size = self.img_size // 4
        self.latent_dim = latent_dim
        self.code_dim = code_dim
        self.img_init_shape = (128, self.img_init_size, self.img_init_size)
        self.img_shape = (self.channels, self.img_size, self.img_size)

        # The stem must emit exactly as many features as `forward` reshapes
        # into `img_init_shape`; sizing it from the scalar `img_init_size`
        # makes the later `view` fail.
        stem_features = int(np.prod(self.img_init_shape))
        self.stem_linear = nn.Sequential(
            nn.Linear(latent_dim + classes + code_dim, stem_features)
        )
        self.model = nn.Sequential(
            nn.BatchNorm2d(128),
            *self._create_deconv_layer(128, 128, upsample=True),
            *self._create_deconv_layer(128, 64, upsample=True),
            *self._create_deconv_layer(64, self.channels, upsample=False, normalize=False),
            nn.Tanh()
        )

    def _create_deconv_layer(self, size_in, size_out, upsample=True, normalize=True):
        """Build the layer list for one generator stage.

        Args:
            size_in: Input channels of the 3x3 convolution.
            size_out: Output channels of the 3x3 convolution.
            upsample: Prepend a 2x ``Upsample`` when true.
            normalize: Append batch norm and LeakyReLU(0.2) when true.

        Returns:
            A list of modules, in application order.
        """
        layers = []
        if upsample:
            layers.append(Upsample(scale_factor=2))
        # kernel 3 / stride 1 / padding 1 keeps the spatial size unchanged.
        layers.append(nn.Conv2d(size_in, size_out, 3, stride=1, padding=1))
        if normalize:
            # NOTE(review): the second positional argument of BatchNorm2d is
            # `eps`, so 0.8 is an epsilon here, not a momentum. Behaviour is
            # kept as written; switch to `momentum=0.8` if that was the intent.
            layers.append(nn.BatchNorm2d(size_out, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    def forward(self, noise, labels, code):
        """Generate images from the concatenation of the three inputs.

        Args:
            noise: Tensor whose last dimension is ``latent_dim``.
            labels: Tensor whose last dimension is ``classes``.
            code: Tensor whose last dimension is ``code_dim``.

        Returns:
            The ``tanh`` output of the convolutional stack, one image per
            row of the input.
        """
        z = torch.cat((noise, labels, code), -1)
        z_vec = self.stem_linear(z)
        z_img = z_vec.view(z_vec.shape[0], *self.img_init_shape)
        return self.model(z_img)

In [None]:
class Discriminator(nn.Module):
    """Discriminator with three heads: adversarial score, class and code.

    A stack of four stride-2 convolutions produces a 128-channel feature
    map that is flattened and fed to three independent linear heads. The
    loss criteria used with those heads are stored on the module as
    ``adv_loss``, ``class_loss`` and ``style_loss``.

    Args:
        classes: Number of outputs of the class head.
        channels: Number of channels of the input image.
        img_size: Side length of the (square) input image.
        latent_dim: Stored on the module; not used by the network itself.
        code_dim: Number of outputs of the code head.
    """

    # Number of stride-2 convolutions in ``self.model``.
    _NUM_DOWNSAMPLES = 4

    def __init__(self, classes, channels, img_size, latent_dim, code_dim):
        super().__init__()
        self.classes = classes
        self.channels = channels
        self.img_size = img_size
        self.latent_dim = latent_dim
        self.code_dim = code_dim
        self.img_shape = (self.channels, self.img_size, self.img_size)
        self.model = nn.Sequential(
            *self._create_conv_layer(self.channels, 16, drop_out=True, normalize=False),
            *self._create_conv_layer(16, 32, drop_out=True, normalize=True),
            *self._create_conv_layer(32, 64, drop_out=True, normalize=True),
            *self._create_conv_layer(64, 128, drop_out=True, normalize=True),
        )

        # Each Conv2d(kernel=3, stride=2, padding=1) maps a side of n to
        # floor((n - 1) / 2) + 1 == ceil(n / 2). Using `img_size // 16`
        # is only right for multiples of 16 (e.g. 28 -> 2, not 1), so the
        # size is derived layer by layer instead.
        feature_side = self.img_size
        for _ in range(self._NUM_DOWNSAMPLES):
            feature_side = (feature_side + 1) // 2
        out_linear_dim = 128 * feature_side * feature_side

        self.adv_linear = nn.Linear(out_linear_dim, 1)
        self.class_linear = nn.Sequential(
            nn.Linear(out_linear_dim, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, self.classes)
        )
        self.code_linear = nn.Sequential(
            nn.Linear(out_linear_dim, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, self.code_dim)
        )
        self.adv_loss = torch.nn.MSELoss()
        self.class_loss = torch.nn.CrossEntropyLoss()
        self.style_loss = torch.nn.MSELoss()

    def _create_conv_layer(self, size_in, size_out, drop_out=True, normalize=True):
        """Build the layer list for one downsampling stage.

        Args:
            size_in: Input channels of the convolution.
            size_out: Output channels of the convolution.
            drop_out: Append LeakyReLU(0.2) and Dropout(0.4) when true
                (the flag controls both, not just the dropout).
            normalize: Append batch norm when true.

        Returns:
            A list of modules, in application order.
        """
        layers = [nn.Conv2d(size_in, size_out, 3, 2, 1)]
        if drop_out:
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Dropout(0.4))
        if normalize:
            # NOTE(review): the second positional argument of BatchNorm2d is
            # `eps`, so 0.8 is an epsilon here, not a momentum. Behaviour is
            # kept as written; switch to `momentum=0.8` if that was the intent.
            layers.append(nn.BatchNorm2d(size_out, eps=0.8))
        return layers

    def forward(self, image):
        """Score a batch of images.

        Args:
            image: Batch of images accepted by the convolutional stack.

        Returns:
            A tuple ``(y, label, code)``: the raw output of the adversarial
            head, the softmax over the class head, and the raw output of
            the code head.
        """
        y_img = self.model(image)
        y_vec = y_img.view(y_img.shape[0], -1)
        y = self.adv_linear(y_vec)
        # NOTE(review): `class_loss` is CrossEntropyLoss, which expects
        # unnormalised logits; feeding it this softmax output applies the
        # normalisation twice. Kept because callers may rely on
        # probabilities being returned -- confirm and fix at the call site.
        label = F.softmax(self.class_linear(y_vec), dim=1)
        code = self.code_linear(y_vec)
        return y, label, code

In [None]:
def _weights_init(m):
    """Re-initialise ``m`` in place, dispatching on its class name.

    Classes whose name contains ``Conv`` get weights drawn from N(0, 0.02).
    Classes whose name contains ``BatchNorm`` get weights drawn from
    N(1, 0.02) and a zero bias. Any other module is left untouched.

    Args:
        m: The module to initialise.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        torch.nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        torch.nn.init.constant_(m.bias.data, val=0.0)

In [None]:
import itertools

In [None]:
# Per batch: one generator step, one discriminator step, then one joint
# step on the label and style predictions made for fresh fake images.
# NOTE(review): this cell reads `self` and `epochs` from the surrounding
# scope; it presumably belongs inside a trainer method -- confirm.
for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(self.data_loader):
        data, target = data.to(self.device), target.to(self.device)
        batch_size = data.size(0)
        # Adversarial targets, one column per sample, compared below
        # against the first output of netD.
        real_label = torch.full((batch_size, 1), 1., device=self.device)
        fake_label = torch.full((batch_size, 1), 0., device=self.device)

        # train G: push netD's score for generated images towards "real".
        self.netG.zero_grad()
        z_noise = torch.randn(batch_size, self.latent_dim, device=self.device)
        x_fake_labels = torch.randint(0, self.classes, (batch_size,), device=self.device)
        labels_onehot = self._to_onehot(x_fake_labels, dim=self.classes)
        # Tensors only provide the in-place `normal_()`; calling `.normal()`
        # raises AttributeError, so the style code is drawn with randn.
        z_style = torch.randn(batch_size, self.style_dim, device=self.device)
        x_fake = self.netG(z_noise, labels_onehot, z_style)
        y_fake_g, _, _ = self.netD(x_fake)
        g_loss = self.netD.adv_loss(y_fake_g, real_label)
        g_loss.backward()
        self.optim_G.step()

        # train D: real images towards "real", detached fakes towards "fake".
        # zero_grad also discards the gradients the G step left on netD.
        self.netD.zero_grad()
        y_real, _, _ = self.netD(data)
        d_real_loss = self.netD.adv_loss(y_real, real_label)
        y_fake_d, _, _ = self.netD(x_fake.detach())
        d_fake_loss = self.netD.adv_loss(y_fake_d, fake_label)
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        self.optim_D.step()

        # update mutual information: resample every input in place, then
        # ask netD to recover the label and style that produced the image.
        self.optim_info.zero_grad()
        z_noise.normal_()
        x_fake_labels = torch.randint(0, self.classes, (batch_size,), device=self.device)
        labels_onehot = self._to_onehot(x_fake_labels, dim=self.classes)
        z_style.normal_()
        x_fake = self.netG(z_noise, labels_onehot, z_style)
        _, label_fake, style_fake = self.netD(x_fake)
        # NOTE(review): if `class_loss` is a cross-entropy on logits and
        # netD already returns softmax probabilities, the normalisation is
        # applied twice here -- confirm against the Discriminator.
        info_loss = (
            self.netD.class_loss(label_fake, x_fake_labels)
            + self.netD.style_loss(style_fake, z_style)
        )
        info_loss.backward()
        self.optim_info.step()