In [157]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import time
import os
#import PIL
import glob

#import imageio
from torch.autograd import Variable
from torchvision.utils import make_grid,save_image
from IPython import display
from torchvision import datasets, transforms
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Image geometry: 28x28 pixels with a single channel.
img_width = 28
img_height = 28
img_channels = 1
# Number of label classes; also the width of the one-hot conditioning vector.
num_classes = 10
# Length of the latent noise vector fed to the generator.
noise_dim = 100
# Bound passed to clip_gradients() for the discriminator parameters.
clip_value = 0.01
# NOTE(review): presumably the number of discriminator updates per generator
# update — confirm against the training loop.
train_ratio = 5
# Mini-batch size used by the DataLoader and when sampling noise.
batch_size = 128
# Learning rate shared by both RMSprop optimizers.
learning_rate = 0.0005

In [147]:
class Generator(nn.Module):
    """Conditional generator: (noise, one-hot label) -> 1x28x28 image in [-1, 1].

    The noise and label vectors are concatenated, projected to a 512x4x4
    feature map and upsampled by three transposed convolutions
    (4x4 -> 7x7 -> 14x14 -> 28x28). The final ``Tanh`` bounds the output.

    Args:
        latent_dim: length of the noise vector. ``None`` (default) uses the
            notebook-level ``noise_dim``.
        n_classes: width of the one-hot label vector. ``None`` (default) uses
            the notebook-level ``num_classes``.
    """

    def __init__(self, latent_dim=None, n_classes=None):
        super(Generator, self).__init__()
        # Fall back to the notebook configuration only when a size is omitted,
        # so the class can also be built without those globals.
        latent_dim = noise_dim if latent_dim is None else latent_dim
        n_classes = num_classes if n_classes is None else n_classes
        self.latent_dim = latent_dim
        self.n_classes = n_classes

        # Project [noise | label] to a flat 512*4*4 vector.
        self.dense1 = nn.Sequential(nn.Linear(latent_dim + n_classes, 4*4*512),
                                    nn.BatchNorm1d(4*4*512), nn.LeakyReLU(0.2))
        # 4x4 -> 7x7: (4 - 1) * 2 - 2 * 1 + 3 = 7
        self.conv_trans1 = nn.Sequential(nn.ConvTranspose2d(512, 256, (3,3), stride=(2,2), padding=1),
                                         nn.BatchNorm2d(256),
                                         nn.LeakyReLU(0.2),)
        # 7x7 -> 14x14: (7 - 1) * 2 - 2 * 1 + 4 = 14
        self.conv_trans2 = nn.Sequential(nn.ConvTranspose2d(256, 128, (4,4), stride=(2,2), padding=1),
                                         nn.BatchNorm2d(128),
                                         nn.LeakyReLU(0.2),)
        # 14x14 -> 28x28, single output channel squashed to [-1, 1].
        self.conv_trans3 = nn.Sequential(nn.ConvTranspose2d(128, 1, (4,4), stride=(2,2), padding=1),
                                         nn.Tanh(),)

    def forward(self, input_noise, labels):
        """Generate images for a batch.

        Args:
            input_noise: tensor of shape (batch, latent_dim).
            labels: one-hot tensor of shape (batch, n_classes); cast to float.

        Returns:
            Tensor of shape (batch, 1, 28, 28).
        """
        labels = labels.float()
        x_ = self.dense1(torch.cat((input_noise, labels), 1))
        # Reshape the flat projection into a 512-channel 4x4 feature map.
        x_ = self.conv_trans1(x_.view(x_.shape[0], 512, 4, 4))
        x_ = self.conv_trans2(x_)
        x_ = self.conv_trans3(x_)
        return x_
        
        
        

class Discriminator(nn.Module):
    """Conditional critic: scores a 1x28x28 image together with its one-hot label.

    Two strided convolutions reduce the image to a 128x7x7 feature map
    (28 -> 14 -> 7). The flattened features are concatenated with the label
    vector and mapped to a single unbounded score (no sigmoid).

    Args:
        n_classes: width of the one-hot label vector. The default of 10
            reproduces the original hard-coded input width
            6282 = 128 * 7 * 7 + 10.
    """

    def __init__(self, n_classes=10):
        super(Discriminator, self).__init__()
        self.n_classes = n_classes

        # 28x28 -> 14x14
        self.conv1 = nn.Sequential(nn.Conv2d(1, 64, (4,4), stride=(2,2), padding=1),
                                   nn.LeakyReLU(0.2),
                                   nn.Dropout2d(0.3),)
        # 14x14 -> 7x7
        self.conv2 = nn.Sequential(nn.Conv2d(64, 128, (3,3), stride=(2,2), padding=1),
                                   nn.BatchNorm2d(128),
                                   nn.LeakyReLU(0.2),
                                   nn.Dropout(0.3),)
        # Flattened conv features (128 * 7 * 7) plus the label vector.
        self.dense1 = nn.Sequential(nn.Linear(128 * 7 * 7 + n_classes, 256),
                                    nn.LeakyReLU(0.2),
                                    nn.Linear(256, 1),)

    def forward(self, input_images, labels):
        """Score a batch of images.

        Args:
            input_images: tensor of shape (batch, 1, 28, 28).
            labels: one-hot tensor of shape (batch, n_classes); cast to float.

        Returns:
            Tensor of shape (batch, 1) with one score per image.
        """
        labels = labels.float()
        x_ = self.conv1(input_images)
        x_ = self.conv2(x_)
        # Flatten per sample, then append the conditioning label.
        x_ = x_.view(x_.shape[0], -1)
        x_ = torch.cat((x_, labels), 1)
        x_ = self.dense1(x_)
        return x_
        
    
   

In [148]:
def generate_random_noise(noise_size, batch_size, device=None):
    """Sample standard-normal noise of shape (batch_size, noise_size).

    Note the argument order: the noise length comes first, the batch size
    second, while the returned tensor is laid out batch-first.

    Args:
        noise_size: length of each noise vector.
        batch_size: number of noise vectors to draw.
        device: optional device to allocate the tensor on. ``None`` (default)
            keeps the original behaviour of using PyTorch's default device.

    Returns:
        Float tensor of shape (batch_size, noise_size).
    """
    return torch.randn(batch_size, noise_size, device=device)

def wasserstein_loss(pred, truth):
    """Return the mean of ``truth * pred``.

    ``truth`` acts as a sign/weight multiplier (the training loop passes +1
    or -1) that is broadcast against the critic scores in ``pred``.
    """
    signed_scores = truth * pred
    return signed_scores.mean()

def clip_gradients(model, clip_limit):
    """Clamp every parameter of ``model`` in place to [-clip_limit, clip_limit].

    Despite the name, this clips the parameter *values* (weights and biases),
    not their gradients.

    The update runs under ``torch.no_grad()`` instead of mutating ``p.data``,
    so autograd is aware of the in-place change rather than having it hidden.

    Args:
        model: an ``nn.Module`` whose parameters are clamped.
        clip_limit: non-negative bound applied symmetrically.
    """
    with torch.no_grad():
        for p in model.parameters():
            p.clamp_(-clip_limit, clip_limit)
            

def get_sample_image(G, n_noise=100):
    """Build a 280x280 grid of 100 generated samples (does not save anything).

    Row ``j`` of the grid holds ten 28x28 samples conditioned on class ``j``,
    laid out side by side.

    Inputs are created on the device of ``G``'s first parameter; a callable
    without parameters gets CPU inputs. Sampling runs under
    ``torch.no_grad()`` so no autograd graph is kept. ``G`` is called in
    whatever train/eval mode it is currently in.

    Args:
        G: callable ``G(z, c)`` taking noise of shape (10, n_noise) and a
            one-hot label tensor of shape (10, 10), returning 10 images that
            can be viewed as (10, 28, 28).
        n_noise: length of the noise vector expected by ``G``.

    Returns:
        ``numpy.ndarray`` of shape (280, 280).
    """
    # Put the inputs wherever the generator's weights live.
    try:
        target = next(G.parameters()).device
    except (AttributeError, StopIteration):
        target = torch.device('cpu')

    img = np.zeros([280, 280])
    with torch.no_grad():
        for j in range(10):
            # One-hot labels: all ten samples in this row share class j.
            c = torch.zeros([10, 10]).to(target)
            c[:, j] = 1
            z = torch.randn(10, n_noise).to(target)
            y_hat = G(z, c).view(10, 28, 28)
            result = y_hat.cpu().numpy()
            # Concatenate the ten 28x28 samples horizontally into a 28x280 strip.
            img[j*28:(j+1)*28] = np.concatenate(list(result), axis=-1)
    return img

In [160]:
# Build both networks on the selected device.
discriminator = Discriminator().to(device)
generator = Generator().to(device)

discriminator_optimizer = torch.optim.RMSprop(discriminator.parameters(), lr=learning_rate)
generator_optimizer = torch.optim.RMSprop(generator.parameters(), lr=learning_rate)

# ToTensor gives [0, 1]; Normalize(0.5, 0.5) rescales to [-1, 1] to match the
# generator's Tanh output range.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,)),
                                ])
trainset = datasets.MNIST(root='../data/', train=True, download=True, transform=transform)
# drop_last keeps every batch at exactly batch_size, matching the noise batches.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, drop_last=True)

# Sign multipliers for wasserstein_loss: +1 for real images, -1 for fakes.
# Tensor.to() returns a new tensor, so the result must be kept.
true_image = torch.FloatTensor([1]).to(device)
fake_image = true_image * -1

# save_image does not create missing directories.
os.makedirs("./images", exist_ok=True)

epochs = 10
for epoch in range(epochs):
    for index, (images_mini_batch, labels) in enumerate(trainloader):
        images_mini_batch = images_mini_batch.to(device)
        labels = F.one_hot(labels, num_classes).to(device)
        if index == 0:
            print(labels[0].shape)

        # ---- Discriminator (critic) update ----
        noise = generate_random_noise(noise_dim, batch_size).to(device)
        # The generator is not updated in this step, so skip building its graph.
        with torch.no_grad():
            fake_images = generator(noise, labels)

        discriminator_optimizer.zero_grad()
        disc_real_outputs = discriminator(images_mini_batch, labels)
        disc_fake_outputs = discriminator(fake_images, labels)
        # mean(D(real)) - mean(D(fake)); minimised by the critic.
        total_disc_loss = (wasserstein_loss(disc_real_outputs, true_image)
                           + wasserstein_loss(disc_fake_outputs, fake_image))
        total_disc_loss.backward()
        discriminator_optimizer.step()
        # Weight clipping after each critic step.
        clip_gradients(discriminator, clip_value)

        # ---- Generator update, once every train_ratio critic updates ----
        if index % train_ratio == 0:
            gen_noise = generate_random_noise(noise_dim, batch_size).to(device)
            generator_optimizer.zero_grad()
            gen_outputs = generator(gen_noise, labels)
            # Same sign as the "real" term: the generator pushes its samples
            # toward the scores the critic assigns to real images.
            gen_loss = wasserstein_loss(discriminator(gen_outputs, labels), true_image)
            gen_loss.backward()
            generator_optimizer.step()

    # Losses of the last critic / generator step of the epoch.
    print("Discriminator loss : {} , Generator loss : {}".format(total_disc_loss.item(), gen_loss.item()))
    save_image(gen_outputs.detach()[0], "./images/%d.png" % epoch, nrow=1, normalize=True)


        
    



torch.Size([10])
Discriminator loss : -0.3004552721977234 , Generator loss : -0.00905001163482666
torch.Size([10])
Discriminator loss : -0.19726793467998505 , Generator loss : 0.13129985332489014
torch.Size([10])
Discriminator loss : -0.15004754066467285 , Generator loss : 0.27197736501693726
torch.Size([10])
Discriminator loss : -0.09402923285961151 , Generator loss : 0.035656217485666275
torch.Size([10])
Discriminator loss : -0.04907269775867462 , Generator loss : -0.05336248129606247
torch.Size([10])
Discriminator loss : -0.059706300497055054 , Generator loss : -0.02963740937411785
torch.Size([10])
Discriminator loss : -0.03410840034484863 , Generator loss : 0.03056771121919155
torch.Size([10])
Discriminator loss : -0.050747767090797424 , Generator loss : -0.18606796860694885
torch.Size([10])
Discriminator loss : -0.00853574275970459 , Generator loss : 0.17838357388973236
torch.Size([10])
Discriminator loss : -0.035808928310871124 , Generator loss : -0.06547760963439941


In [None]:
 def wasserstein_loss(self, pred, truth):
        """Return the backend mean of ``truth * pred``."""
        weighted = truth * pred
        return K.mean(weighted)
    
    def clip_gradients(self, model, clip_value):
        """Clip every weight array of ``model`` to [-clip_value, clip_value].

        Despite the name, the layer weights are clipped, not the gradients.
        The bound is taken from the ``clip_value`` argument; previously the
        argument was ignored in favour of ``self.clip_value``.

        Args:
            model: object exposing ``layers``, each with ``get_weights()`` and
                ``set_weights()``.
            clip_value: non-negative bound applied symmetrically.

        Returns:
            The same ``model`` object, with its weights clipped in place.
        """
        for layer in model.layers:
            clipped = [np.clip(w, -clip_value, clip_value) for w in layer.get_weights()]
            layer.set_weights(clipped)
        return model
    
    # Draws noise from a normal distribution; the "uniform" in the name is
    # misleading but kept because callers use it.
    def generate_uniform_noise(self, batch_size) :
        """Return a float32 tensor of shape [batch_size, self.noise_dim]."""
        shape = [batch_size, self.noise_dim]
        return tf.random.normal(shape, dtype=tf.dtypes.float32)
    
    def one_hot_encode(self, y):
        """One-hot encode ``y`` and reshape the result to (1, self.num_classes)."""
        encoded = tf.one_hot(y, self.num_classes)
        return tf.reshape(encoded, (1, self.num_classes))
    
    def generate_img(self, input_noise, input_label, epoch) :
        """Generate images for one label, save them as a PNG and display them.

        Each prediction is drawn in its own cell of a 4x4 subplot grid; the
        figure is written to ``model1_img_at_epoch_<epoch>.png``.
        """
        one_hot_label = self.one_hot_encode(input_label)
        generated = self.generator.predict([input_noise, one_hot_label])
        plt.figure(figsize=(4, 4))

        for position, sample in enumerate(generated, start=1):
            plt.subplot(4, 4, position)
            # Map generator output from [-1, 1] back to pixel intensities.
            plt.imshow(sample[:, :, 0] * 127.5 + 127.5, cmap='gray')
            plt.axis('off')
        out_path = "model1_img_at_epoch_{}.png".format(epoch)
        plt.savefig(out_path)
        plt.show()
        
    def generate_sample_labels(self, batch_size):
        """Sample ``batch_size`` random class labels and return them one-hot encoded.

        Returns:
            float32 tensor of shape (batch_size, self.num_classes).
        """
        sampled_labels = np.random.randint(0, self.num_classes, batch_size)
        # Encode the whole batch at once instead of one label at a time
        # followed by a numpy round trip and reshape.
        return tf.one_hot(sampled_labels, self.num_classes)
    
    def train(self, images, labels, epochs, batch_size) :
        """Train the conditional WGAN for ``epochs`` passes over ``images``.

        For every mini-batch the discriminator is trained
        ``self.disc_train_count`` times (real batch with target +1, generated
        batch with target -1, then weight clipping), after which the combined
        ``self.gan_model`` is trained once with target +1. After each epoch the
        last losses are printed and a sample image is rendered.

        Args:
            images: array reshaped below to (N, 28, 28, 1).
            labels: class labels, one-hot encoded with ``to_categorical``.
            epochs: number of passes over the dataset.
            batch_size: mini-batch size; incomplete batches are dropped.
        """
        # Shuffle buffer covers the whole dataset.
        buffer_size = images.shape[0]
        # One noise vector, reused every epoch for the preview image.
        random_fixed_noise = self.generate_uniform_noise(1)
        # Reshape to account for greyscales and normalize RGB to [-1,1] as per GoodFellow 2016
        x_train = images.reshape(images.shape[0], 28, 28, 1).astype('float32')
        # Assumes pixel values in [0, 255] — TODO confirm against the caller.
        x_train = (x_train - 127.5)/127.5
        y_train = to_categorical(labels, self.num_classes)
        x_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(buffer_size).batch(batch_size, drop_remainder=True)
        for t in range(epochs):
            # NOTE(review): `start` is never read in the visible code.
            start = time.time()
       #     .take(self.disc_train_count)
            for images_mini_batch, labels in x_train:
                # The same real mini-batch is reused for every discriminator
                # iteration; only the noise is resampled.
                for _ in range(self.disc_train_count):
                    noise = self.generate_uniform_noise(batch_size)
                    generated_images = self.generator.predict([noise, labels])

                    # Targets: +1 for real images, -1 for generated images.
                    discr_real_loss = self.discriminator.train_on_batch([images_mini_batch, labels], np.ones((batch_size,1)))
                    discr_fake_loss = self.discriminator.train_on_batch([generated_images, labels], -1 * np.ones((batch_size,1)))
                    # Clip the discriminator weights after each update.
                    self.discriminator = self.clip_gradients(self.discriminator, self.clip_value)

                # Generator update: reuses the `noise` from the last
                # discriminator iteration, paired with freshly sampled labels.
                random_labels = self.generate_sample_labels(batch_size)
                gen_loss = self.gan_model.train_on_batch([noise, random_labels], np.ones((batch_size,1)))
            # Reports the losses of the final mini-batch of the epoch only.
            print("Discriminator loss : {} , Generator loss : {}".format(discr_real_loss + discr_fake_loss, gen_loss))
            # Render a preview for label 1 using the fixed noise vector.
            self.generate_img(random_fixed_noise, 1, t)

            
    
                
# Load MNIST through tf.keras and train the Keras model.
# NOTE(review): `tf` and `W_Cond_GAN` are not imported or defined in the
# visible part of this notebook — confirm where they come from.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Rebinds the img_width / img_height / num_classes names set in the config cell.
img_width, img_height = x_train[0].shape
num_classes = 10 
w_cond_gan = W_Cond_GAN(img_width, img_height, 1, num_classes)
# Positional arguments are epochs=100 and batch_size=512.
w_cond_gan.train(x_train, y_train, 100, 512)   
    

In [1]:
# Generate numbers from the trained generator now
# Requires `w_cond_gan` from an earlier cell; the recorded output shows a
# NameError because that cell had not been run in this kernel.
# The label to generate is passed as `input_label`; 0 is used as the epoch
# tag in the saved file name.
num_to_generate = 9
random_fixed_noise = w_cond_gan.generate_uniform_noise(1)
w_cond_gan.generate_img(random_fixed_noise,num_to_generate, 0)

NameError: name 'w_cond_gan' is not defined

In [44]:
# Scratch check: compare the shape of the raw label array with a batch of
# 128 randomly sampled labels after one-hot encoding.
# Relies on `w_cond_gan` and `y_train` already existing in the kernel.
#(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
sampled_labels = np.random.randint(0, 10, 128)
#y_train = to_categorical(y_train)
# Each one_hot_encode call returns shape (1, 10); reshape flattens to (128, 10).
labels_sampled = np.array([w_cond_gan.one_hot_encode(x) for x in sampled_labels]).reshape(-1,10)
print( y_train.shape, labels_sampled.shape)

(60000,) (128, 10)


In [28]:
# Scratch check: compare the class's one_hot_encode with to_categorical on
# the first training label.
# Relies on `tf`, `to_categorical` and `w_cond_gan` existing in the kernel.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
print(y_train[0])
print(w_cond_gan.one_hot_encode(y_train[0]))
# Rebinds y_train from integer labels to one-hot rows.
y_train = to_categorical(y_train)
print(y_train[0])


5
tf.Tensor([[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]], shape=(1, 10), dtype=float32)
[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]


In [32]:
# Scratch check: show one row (index 100) from a batch of 128 sampled labels.
print(w_cond_gan.generate_sample_labels(128)[100])

[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]


In [137]:
# Sign multipliers used with wasserstein_loss: +1 and -1, each of shape (1,).
true_image = torch.tensor([1.0], dtype=torch.float32)
fake_image = -1 * true_image
print(true_image.shape, fake_image.shape)

torch.Size([1]) torch.Size([1])


In [145]:
# Scratch check: concatenating a (9, 100) noise batch with a (9, 9) label
# block along dim 1 gives a (9, 109) tensor.
z = torch.randn(9, 100)
print(z.shape)
# Nine identical rows 0..8, built directly as a float32 tensor instead of
# converting a Python list of numpy arrays.
labels = torch.arange(9, dtype=torch.float32).repeat(9, 1)
print(labels.shape)
print(torch.cat((z, labels),1))

torch.Size([9, 100])
torch.Size([9, 9])
tensor([[ 1.1845e+00, -1.0773e+00, -8.7932e-01,  1.5088e+00,  9.9296e-01,
          8.6638e-01, -1.3979e+00, -5.4746e-01,  1.4257e+00, -1.1504e+00,
         -4.4333e-01, -4.2188e-01,  2.0185e-01, -5.9029e-01, -1.0240e+00,
          1.4768e+00, -1.1574e+00,  1.2716e+00, -7.6203e-01,  1.0776e+00,
          1.0588e-01,  4.7398e-01, -4.0842e-01, -1.2114e+00,  9.7274e-01,
         -1.3814e+00,  1.3243e-01,  1.1240e+00,  1.0353e-01, -7.9471e-01,
          5.3642e-01,  1.1856e+00,  2.4557e+00,  7.1689e-01,  1.4755e+00,
          1.9607e+00, -2.1436e-01, -9.4985e-02, -1.8335e+00, -1.0516e+00,
         -1.5023e+00,  1.1358e+00,  1.6882e+00, -1.8248e-01,  1.9478e+00,
         -1.9916e-01, -1.8321e+00,  4.8991e-01,  1.3902e+00, -1.3419e+00,
         -1.1729e+00, -2.9572e-01, -6.0050e-01,  4.2851e-01, -4.7968e-01,
          1.0157e+00, -5.4433e-01,  1.6908e-01,  1.9551e+00, -1.2990e+00,
          6.1643e-01,  7.9264e-01,  1.3209e+00, -1.2800e+00,  2.0515e+00

In [153]:
# NOTE(review): np.random.rand(10, 0, 1) has a zero-length middle dimension,
# so this tensor holds no elements (size (10, 0, 1)) — confirm whether a
# different shape or np.random.randint was intended.
labels = torch.FloatTensor(np.random.rand(10,0,1))

In [154]:
# Inspect the tensor bound to `labels` by the previous cell.
print(labels)

tensor([], size=(10, 0, 1))
