In [1]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np
import os
import math

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
plt.switch_backend('agg')
plt.style.use('ggplot')

# Define utility functions and classes
class Prior(object):
    def __init__(self, prior_type):
        self.type = prior_type

    def sample(self, shape):
        if self.type == "uniform":
            return torch.rand(shape) * 2 - 1  # Uniform between -1 and 1
        else:
            return torch.randn(shape)

def conv_out_size_same(size, stride):
    return int(math.ceil(float(size) / float(stride)))

def create_image_grid(x, img_size, tile_shape):
    # x is a tensor of shape [N, C, H, W], values in [0, 1]
    assert (x.size(0) == tile_shape[0] * tile_shape[1])
    x = x.permute(0, 2, 3, 1)  # N, H, W, C
    x = x.cpu().numpy()
    if img_size[2] == 1:
        x = x.squeeze(-1)  # Remove last channel dimension for grayscale
        img = np.zeros((img_size[0] * tile_shape[0] + tile_shape[0] - 1,
                        img_size[1] * tile_shape[1] + tile_shape[1] - 1))
    else:
        img = np.zeros((img_size[0] * tile_shape[0] + tile_shape[0] - 1,
                        img_size[1] * tile_shape[1] + tile_shape[1] - 1,
                        3))

    for t in range(x.shape[0]):
        i, j = t // tile_shape[1], t % tile_shape[1]
        img_i_start = i * img_size[0] + i
        img_i_end = (i + 1) * img_size[0] + i
        img_j_start = j * img_size[1] + j
        img_j_end = (j + 1) * img_size[1] + j
        img[img_i_start:img_i_end, img_j_start:img_j_end] = x[t]

    return img

# Define the Generator class
class Generator(nn.Module):
    def __init__(self, num_z, num_gen_feature_maps, img_size, num_conv_layers, num_gens, g_batch_size):
        super(Generator, self).__init__()
        self.num_z = num_z
        self.num_gen_feature_maps = num_gen_feature_maps
        self.img_size = img_size
        self.num_conv_layers = num_conv_layers
        self.num_gens = num_gens
        self.g_batch_size = g_batch_size

        # Compute the size after the first linear layer
        out_size = [(int(img_size[0] / (2 ** num_conv_layers)),
                     int(img_size[1] / (2 ** num_conv_layers)),
                     num_gen_feature_maps * (2 ** (num_conv_layers - 1)))]
        for i in range(1, num_conv_layers):
            out_size = [(out_size[0][0] * 2,
                         out_size[0][1] * 2,
                         out_size[0][2] // 2)] + out_size

        self.out_size = out_size

        # Create linear layers for each generator
        self.linears = nn.ModuleList()
        for i in range(num_gens):
            self.linears.append(
                nn.Sequential(
                    nn.Linear(num_z, out_size[0][0] * out_size[0][1] * out_size[0][2]),
                    nn.BatchNorm1d(out_size[0][0] * out_size[0][1] * out_size[0][2]),
                    nn.ReLU(True)
                )
            )

        # Deconvolutional layers
        deconv_modules = []
        in_channels = out_size[0][2]
        for i in range(num_conv_layers):
            if i < num_conv_layers - 1:
                out_channels = out_size[i+1][2]
                deconv_modules.append(
                    nn.Sequential(
                        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=5, stride=2, padding=2, output_padding=1),
                        nn.BatchNorm2d(out_channels),
                        nn.ReLU(True)
                    )
                )
                in_channels = out_channels
            else:
                deconv_modules.append(
                    nn.Sequential(
                        nn.ConvTranspose2d(in_channels, img_size[2], kernel_size=5, stride=2, padding=2, output_padding=1),
                        nn.Tanh()
                    )
                )

        self.deconvs = nn.Sequential(*deconv_modules)

    def forward(self, z):
        z_split = torch.split(z, self.g_batch_size, dim=0)
        h0_list = []
        for i in range(self.num_gens):
            h0 = self.linears[i](z_split[i])
            h0 = h0.view(self.g_batch_size, self.out_size[0][2], self.out_size[0][0], self.out_size[0][1])
            h0_list.append(h0)
        h0 = torch.cat(h0_list, dim=0)
        output = self.deconvs(h0)
        return output

# Define the Discriminator class
class Discriminator(nn.Module):
    def __init__(self, num_dis_feature_maps, img_size, num_conv_layers, num_gens):
        super(Discriminator, self).__init__()
        self.num_dis_feature_maps = num_dis_feature_maps
        self.img_size = img_size
        self.num_conv_layers = num_conv_layers
        self.num_gens = num_gens

        # Build the discriminator network
        modules = []
        in_channels = img_size[2]
        feature_map_sizes = []
        for i in range(num_conv_layers):
            out_channels = num_dis_feature_maps * (2 ** i)
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=2, padding=2),
                    nn.BatchNorm2d(out_channels) if i != 0 else nn.Identity(),
                    nn.LeakyReLU(0.2, inplace=True)
                )
            )
            in_channels = out_channels

            # Compute output size
            if i == 0:
                h_out = conv_out_size_same(self.img_size[0], stride=2)
            else:
                h_out = conv_out_size_same(h_out, stride=2)
            feature_map_sizes.append(h_out)

        self.convs = nn.Sequential(*modules)

        # Manually compute the fc_input_dim
        final_feature_map_size = h_out
        self.fc_input_dim = in_channels * final_feature_map_size * final_feature_map_size

        # Output layers
        self.fc = nn.Linear(self.fc_input_dim, 1)
        self.fc_multi = nn.Linear(self.fc_input_dim, num_gens)

    def forward(self, x):
        h = self.convs(x)
        h = h.view(h.size(0), -1)
        d_bin_logits = self.fc(h)
        d_mul_logits = self.fc_multi(h)
        return d_bin_logits, d_mul_logits

# Define the MGAN class
class MGAN(object):
    """Mixture Generative Adversarial Nets implemented in PyTorch"""
    def __init__(self,
                 num_z=128,
                 beta=1.0,
                 num_gens=4,
                 d_batch_size=64,
                 g_batch_size=32,
                 z_prior="uniform",
                 learning_rate=0.0002,
                 img_size=(32, 32, 1),  # Changed to (32, 32, 1) for MNIST
                 num_conv_layers=3,
                 num_gen_feature_maps=128,
                 num_dis_feature_maps=128,
                 num_epochs=25000,
                 sample_fp=None,
                 sample_by_gen_fp=None,
                 random_seed=6789):

        self.beta = beta
        self.num_z = num_z
        self.num_gens = num_gens
        self.d_batch_size = d_batch_size
        self.g_batch_size = g_batch_size
        self.z_prior = Prior(z_prior)
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        self.img_size = img_size
        self.num_conv_layers = num_conv_layers
        self.num_gen_feature_maps = num_gen_feature_maps
        self.num_dis_feature_maps = num_dis_feature_maps
        self.sample_fp = sample_fp
        self.sample_by_gen_fp = sample_by_gen_fp
        self.random_seed = random_seed

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self._init()

    def _init(self):
        self.epoch = 0

        # Initialize Generator and Discriminator
        self.G = Generator(self.num_z, self.num_gen_feature_maps, self.img_size,
                           self.num_conv_layers, self.num_gens, self.g_batch_size).to(self.device)
        self.D = Discriminator(self.num_dis_feature_maps, self.img_size,
                               self.num_conv_layers, self.num_gens).to(self.device)

        # Optimizers
        self.optimizerD = optim.Adam(self.D.parameters(), lr=self.learning_rate, betas=(0.5, 0.999))
        self.optimizerG = optim.Adam(self.G.parameters(), lr=self.learning_rate, betas=(0.5, 0.999))

        # Loss functions
        self.criterion = nn.BCEWithLogitsLoss()
        self.criterion_multi = nn.CrossEntropyLoss()

    def fit(self, trainloader):
        real_label = 1
        fake_label = 0

        for epoch in range(self.num_epochs):
            for i, data in enumerate(trainloader, 0):
                ############################
                # (1) Update D network
                ###########################
                self.D.zero_grad()
                real_cpu = data[0].to(self.device)
                batch_size = real_cpu.size(0)
                label = torch.full((batch_size,), real_label, dtype=torch.float, device=self.device)

                d_bin_x_logits, d_mul_x_logits = self.D(real_cpu)
                d_bin_x_loss = self.criterion(d_bin_x_logits.view(-1), label)

                # Generate fake data
                if self.z_prior.type == "uniform":
                    z = torch.rand(self.g_batch_size * self.num_gens, self.num_z, device=self.device) * 2 - 1
                else:
                    z = torch.randn(self.g_batch_size * self.num_gens, self.num_z, device=self.device)

                fake = self.G(z)
                label_fake = torch.full((fake.size(0),), fake_label, dtype=torch.float, device=self.device)
                d_bin_g_logits, d_mul_g_logits = self.D(fake.detach())
                d_bin_g_loss = self.criterion(d_bin_g_logits.view(-1), label_fake)

                # Binary loss
                d_bin_loss = d_bin_x_loss + d_bin_g_loss

                # Multiclass loss
                arr = np.array([i // self.g_batch_size for i in range(self.g_batch_size * self.num_gens)])
                d_mul_labels = torch.from_numpy(arr).long().to(self.device)
                d_mul_loss = self.criterion_multi(d_mul_g_logits, d_mul_labels)

                # Total discriminator loss
                d_loss = d_bin_loss + d_mul_loss

                d_loss.backward()
                self.optimizerD.step()

                ############################
                # (2) Update G network
                ###########################
                self.G.zero_grad()
                label.fill_(real_label)
                d_bin_g_logits, d_mul_g_logits = self.D(fake)
                g_bin_loss = self.criterion(d_bin_g_logits.view(-1), label_fake.fill_(real_label))
                g_mul_loss = self.beta * self.criterion_multi(d_mul_g_logits, d_mul_labels)

                g_loss = g_bin_loss + g_mul_loss
                g_loss.backward()
                self.optimizerG.step()

            print('[%d/%d] D_loss: %.4f G_loss: %.4f' % (epoch+1, self.num_epochs, d_loss.item(), g_loss.item()))

            # Save samples
            if (epoch+1) % 10 == 0:
                self._samples(self.sample_fp.format(epoch=epoch+1))
                self._samples_by_gen(self.sample_by_gen_fp.format(epoch=epoch+1))

    def _generate(self, num_samples=100):
        with torch.no_grad():
            batch_size = self.g_batch_size * self.num_gens
            num = ((num_samples - 1) // batch_size + 1) * batch_size
            if self.z_prior.type == "uniform":
                z = torch.rand(num, self.num_z, device=self.device) * 2 - 1
            else:
                z = torch.randn(num, self.num_z, device=self.device)
            x = []
            for i in range(0, num, batch_size):
                z_batch = z[i:i+batch_size]
                x_batch = self.G(z_batch).cpu()
                x.append(x_batch)
            x = torch.cat(x, dim=0)
            idx = np.random.permutation(num)[:num_samples]
            x = x[idx]
            x = (x + 1) / 2
            return x

    def _samples(self, filepath, tile_shape=(10, 10)):
        if not os.path.exists(os.path.dirname(filepath)):
            os.makedirs(os.path.dirname(filepath))

        num_samples = tile_shape[0] * tile_shape[1]
        x = self._generate(num_samples)
        imgs = create_image_grid(x, img_size=self.img_size, tile_shape=tile_shape)
        if self.img_size[2] == 1:
            plt.imsave(filepath, imgs, cmap='gray')
        else:
            plt.imsave(filepath, imgs)

    def _samples_by_gen(self, filepath):
        if not os.path.exists(os.path.dirname(filepath)):
            os.makedirs(os.path.dirname(filepath))

        num_samples = self.num_gens * 10
        tile_shape = (self.num_gens, 10)

        x = []
        for _ in range(10):
            if self.z_prior.type == "uniform":
                z = torch.rand(self.g_batch_size * self.num_gens, self.num_z, device=self.device) * 2 - 1
            else:
                z = torch.randn(self.g_batch_size * self.num_z, device=self.device)
            x_batch = self.G(z).cpu()
            x.append(x_batch)
        x = torch.cat(x, dim=0)
        x = (x + 1) / 2

        imgs = create_image_grid(x, img_size=self.img_size, tile_shape=tile_shape)
        if self.img_size[2] == 1:
            plt.imsave(filepath, imgs, cmap='gray')
        else:
            plt.imsave(filepath, imgs)

# Main function to run the MGAN model
if __name__ == '__main__':
    # Set parameters
    num_z = 100
    beta = 0.01
    num_gens = 10
    d_batch_size = 64
    g_batch_size = 12
    z_prior = "uniform"
    learning_rate = 0.0002
    num_conv_layers = 3
    num_gen_feature_maps = 128
    num_dis_feature_maps = 128
    num_epochs = 50  # Reduced for testing purposes
    sample_fp = "samples/samples_{epoch:04d}.png"
    sample_by_gen_fp = "samples_by_gen/samples_{epoch:04d}.png"

    # Load data
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                          download=True, transform=transform)

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=d_batch_size,
                                              shuffle=True, num_workers=2)

    # Initialize model
    model = MGAN(
        num_z=num_z,
        beta=beta,
        num_gens=num_gens,
        d_batch_size=d_batch_size,
        g_batch_size=g_batch_size,
        z_prior=z_prior,
        learning_rate=learning_rate,
        img_size=(32, 32, 1),  # Updated for MNIST
        num_conv_layers=num_conv_layers,
        num_gen_feature_maps=num_gen_feature_maps,
        num_dis_feature_maps=num_dis_feature_maps,
        num_epochs=num_epochs,
        sample_fp=sample_fp,
        sample_by_gen_fp=sample_by_gen_fp,
        random_seed=6789)

    # Fit model
    model.fit(trainloader)


RuntimeError: mat1 and mat2 shapes cannot be multiplied (120x131072 and 8192x1)