In [1]:
# Colab-only step: opens a browser file picker so that kaggle.json (the Kaggle
# API token) can be uploaded into the runtime's working directory.
# NOTE(review): files.upload() is the last expression of the cell, so its return
# value (the uploaded file contents, i.e. the API credentials) is echoed into the
# saved cell output -- clear that output before sharing the notebook.
from google.colab import files
files.upload()


Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"<redacted>","key":"<redacted>"}'}

In [2]:
# STEP 1: Setup and imports
# Install the Kaggle command-line client used by the download step below.
!pip install -q kaggle
import os
import numpy as np
import matplotlib.pyplot as plt
from zipfile import ZipFile
from tqdm import tqdm
# NOTE(review): none of the TensorFlow/Keras names imported below are referenced
# by the cells visible in this notebook (the models are written in PyTorch);
# confirm they are still needed.
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, Embedding, Concatenate
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D, Conv2DTranspose, Conv2D, LeakyReLU
from tensorflow.keras.optimizers import Adam
import tensorflow as tf
from tensorflow.keras.preprocessing.image import load_img, img_to_array

# Copy the uploaded kaggle.json into ~/.kaggle/ and make it readable and writable
# by the owner only (mode 600).
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# STEP 2: Download the CelebA archive from Kaggle and unpack it quietly into ./celeba
!kaggle datasets download -d jessicali9530/celeba-dataset
!unzip -qq celeba-dataset.zip -d celeba


Dataset URL: https://www.kaggle.com/datasets/jessicali9530/celeba-dataset
License(s): other
Downloading celeba-dataset.zip to /content
 99% 1.31G/1.33G [00:14<00:00, 147MB/s]
100% 1.33G/1.33G [00:15<00:00, 95.0MB/s]


In [3]:

from __future__ import print_function
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.nn import utils
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
from torch.utils.tensorboard import SummaryWriter

In [4]:

def getDataLoader(input_dir, batch_size, image_size, num_workers, labels_number, csv_file=None):
    """Build a shuffled DataLoader over ``MyCustomDataset``.

    Args:
        input_dir: Directory containing the image files.
        batch_size: Number of samples per batch.
        image_size: Target size handed to the dataset's resize/crop transforms.
        num_workers: Number of DataLoader worker processes.
        labels_number: Column positions (in the attribute CSV) of the labels to load.
        csv_file: Path of the attribute CSV. When omitted, the notebook-level
            global ``csv_name`` is used, which is what this function always
            did before the parameter existed.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(image, label)`` batches.
        ``drop_last=True`` means every batch has exactly ``batch_size`` samples.
    """
    # Prefer the explicit argument; fall back to the global for existing callers.
    attributes_csv = csv_file if csv_file is not None else csv_name
    dataset = MyCustomDataset(input_dir, attributes_csv, image_size, labels_number)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )

In [5]:
class MyCustomDataset(torch.utils.data.Dataset):
    """Image dataset driven by an attribute CSV.

    Column 0 of the CSV holds the image file name; the columns listed in
    ``labels_number`` hold the attribute labels.

    Args:
        input_dir: Directory containing the image files.
        csv_name: Path of the attribute CSV. It is combined with ``input_dir``
            via ``os.path.join``, so an absolute path is used as-is.
        image_size: Size passed to ``Resize`` and ``CenterCrop``.
        labels_number: Iterable of CSV column positions to use as labels.
    """

    def __init__(self, input_dir, csv_name, image_size, labels_number):
        # Keep the directory on the instance: __getitem__ must not depend on a
        # notebook-level global of the same name.
        self.input_dir = input_dir
        self.transform_to_apply = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            # Maps each channel from [0, 1] to [-1, 1].
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        self.data_info = pd.read_csv(os.path.join(input_dir, csv_name))
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        self.label_arr = self._load_labels(self.data_info, labels_number)
        self.data_len = len(self.data_info.index)

    @staticmethod
    def _load_labels(data_info, labels_number):
        """Return a float64 ``(n_rows, n_labels)`` array with negatives mapped to 0.

        Every row is read from the CSV. The previous implementation copied
        rows ``0:-1`` into a randomly initialised array, which left the last
        sample with random label values.
        """
        selected = np.array(data_info.iloc[:, list(labels_number)], dtype=np.float64)
        # Presumably the CSV encodes absent attributes as -1; store them as 0.
        return np.where(selected < 0, 0.0, selected)

    def __getitem__(self, index):
        """Return ``(image_tensor, label_row)`` for sample ``index``."""
        image_path = os.path.join(self.input_dir, self.image_arr[index])
        # The context manager closes the file handle once the pixels are loaded.
        # convert("RGB") guarantees the three channels Normalize expects.
        with Image.open(image_path) as img_as_img:
            img = self.transform_to_apply(img_as_img.convert("RGB"))
        return (img, self.label_arr[index])

    def __len__(self):
        """Number of rows in the attribute CSV."""
        return self.data_len


In [6]:
class CategoricalConditionalBatchNorm(torch.nn.Module):
    """Batch normalisation whose affine scale/shift is chosen per sample by category.

    The input is normalised with ``torch.nn.functional.batch_norm`` without an
    affine term. When ``affine`` is true, each sample is then scaled and shifted
    by the row of ``weight`` / ``bias`` selected by its category index.

    Args:
        num_features: Number of channels of the input.
        num_cats: Number of categories (rows of the weight/bias tables).
        eps: Value passed to ``batch_norm`` as its ``eps`` argument.
        momentum: Running-statistics momentum; ``None`` selects a cumulative
            average (factor ``1 / num_batches_tracked``).
        affine: Whether to learn per-category scale and shift.
        track_running_stats: Whether to keep running mean/variance buffers.
    """

    def __init__(self, num_features, num_cats, eps=2e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super().__init__()
        self.num_features = num_features
        self.num_cats = num_cats
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        if not self.affine:
            for param_name in ('weight', 'bias'):
                self.register_parameter(param_name, None)
        else:
            # One (scale, shift) row per category; filled in by reset_parameters().
            self.weight = torch.nn.Parameter(torch.Tensor(num_cats, num_features))
            self.bias = torch.nn.Parameter(torch.Tensor(num_cats, num_features))

        if not self.track_running_stats:
            for stat_name in ('running_mean', 'running_var', 'num_batches_tracked'):
                self.register_parameter(stat_name, None)
        else:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

        self.reset_parameters()

    def reset_running_stats(self):
        """Reset running mean to 0, running variance to 1 and the batch counter to 0."""
        if not self.track_running_stats:
            return
        self.running_mean.zero_()
        self.running_var.fill_(1)
        self.num_batches_tracked.zero_()

    def reset_parameters(self):
        """Reset running statistics and set every scale to 1 and every shift to 0."""
        self.reset_running_stats()
        if not self.affine:
            return
        self.weight.data.fill_(1.0)
        self.bias.data.zero_()

    def forward(self, input, cats):
        """Normalise ``input`` and apply the affine row selected by ``cats``.

        ``cats`` is used as the index argument of ``index_select`` along
        dimension 0 of the weight/bias tables, and the selected rows are
        reshaped to ``(input.size(0), num_features, 1, ...)``; it therefore has
        to supply exactly one category index per sample.
        """
        averaging_factor = 0.0
        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is not None:
                averaging_factor = self.momentum
            else:
                # Cumulative moving average over all batches seen so far.
                averaging_factor = 1.0 / self.num_batches_tracked.item()

        use_batch_stats = self.training or not self.track_running_stats
        normalized = torch.nn.functional.batch_norm(
            input, self.running_mean, self.running_var, None, None,
            use_batch_stats, averaging_factor, self.eps)

        if not self.affine:
            return normalized

        # (batch, channels, 1, 1, ...) so the per-sample rows broadcast over
        # the remaining dimensions of the input.
        broadcast_shape = [input.size(0), self.num_features] + [1] * (input.dim() - 2)
        scale = self.weight.index_select(0, cats).view(broadcast_shape)
        shift = self.bias.index_select(0, cats).view(broadcast_shape)
        return normalized * scale + shift

    def extra_repr(self):
        """Summary of the constructor arguments shown in the module's repr."""
        template = ('{num_features}, num_cats={num_cats}, eps={eps}, momentum={momentum}, affine={affine}, '
                    'track_running_stats={track_running_stats}')
        return template.format(**vars(self))

In [7]:
class Generator(nn.Module):
    """Conditional generator: latent noise plus one-hot label map to a 3-channel image.

    Args:
        num_classes: Channels of the one-hot label input (the training loop
            passes ``2 * n_labels``: two channels per binary attribute).
        d: Base channel width.
        z_dim: Channels of the latent noise input (previously hard-coded to 100).
        num_cats: Number of rows in the conditional batch-norm tables. Defaults
            to ``max(num_classes, 2 ** (num_classes // 2))``. Category ids
            produced by ``getGeneratorCategories`` enumerate every on/off
            combination of the attributes, i.e. ``2 ** n_labels`` values; the
            old size of ``num_classes`` rows was too small from three
            attributes on (6 rows for ids up to 7). For one or two attributes
            the size is unchanged.
    """

    def __init__(self, num_classes, d=128, z_dim=100, num_cats=None):
        super(Generator, self).__init__()
        if num_cats is None:
            # Never smaller than the historical size, so parameter shapes of
            # previously working configurations stay the same.
            num_cats = max(num_classes, 2 ** (num_classes // 2))

        # Noise branch: (B, z_dim, 1, 1) -> (B, 4d, 4, 4)
        self.deconv1_1 = nn.ConvTranspose2d(z_dim, d*4, 4, 1, 0)
        self.deconv1_1_bn = nn.BatchNorm2d(d*4)

        # Label branch: (B, num_classes, 1, 1) -> (B, 4d, 4, 4)
        self.deconv00_2 = nn.Conv2d(num_classes, int(d/4), 1, 1, 0)
        self.deconv00_2_bn = nn.BatchNorm2d(int(d/4))
        self.deconv0_2 = nn.ConvTranspose2d(int(d/4), d, 4, 1, 0)
        self.deconv0_2_bn = nn.BatchNorm2d(d)
        self.deconv1_2 = nn.ConvTranspose2d(d, d*4, 3, 1, 1)
        self.deconv1_2_bn = nn.BatchNorm2d(d*4)

        # Shared trunk on the concatenated branches (8d channels); every
        # stride-2 transposed convolution doubles the spatial size.
        self.deconv2 = nn.ConvTranspose2d(d*8, d*4, 4, 2, 1)
        self.deconv2_bn = CategoricalConditionalBatchNorm(d*4, num_cats)
        self.deconv3 = nn.ConvTranspose2d(d*4, d*2, 4, 2, 1)
        self.deconv3_bn = CategoricalConditionalBatchNorm(d*2, num_cats)
        self.deconv4 = nn.ConvTranspose2d(d*2, d, 4, 2, 1)
        self.deconv4_bn = CategoricalConditionalBatchNorm(d, num_cats)
        self.deconv5 = nn.ConvTranspose2d(d, 3, 4, 2, 1)

    def weight_init(self, mean, std):
        """Apply ``normal_init(module, mean, std)`` to every direct sub-module."""
        for module in self._modules.values():
            normal_init(module, mean, std)

    def forward(self, input, label, cat):
        """Generate images.

        Args:
            input: Latent noise of shape ``(B, z_dim, 1, 1)``.
            label: One-hot label tensor of shape ``(B, num_classes, 1, 1)``.
            cat: Category index per sample, forwarded to the conditional
                batch-norm layers.

        Returns:
            Tensor of shape ``(B, 3, 64, 64)`` with values in ``[-1, 1]`` (tanh).
        """
        x = F.leaky_relu(self.deconv1_1_bn(self.deconv1_1(input)), 0.2)

        y = F.leaky_relu(self.deconv00_2_bn(self.deconv00_2(label)), 0.2)
        y = F.leaky_relu(self.deconv0_2_bn(self.deconv0_2(y)), 0.2)
        y = F.leaky_relu(self.deconv1_2_bn(self.deconv1_2(y)), 0.2)

        # Fuse noise and label features along the channel axis.
        x = torch.cat([x, y], 1)
        x = F.leaky_relu(self.deconv2_bn(self.deconv2(x), cat), 0.2)
        x = F.leaky_relu(self.deconv3_bn(self.deconv3(x), cat), 0.2)
        x = F.leaky_relu(self.deconv4_bn(self.deconv4(x), cat), 0.2)
        return torch.tanh(self.deconv5(x))


In [8]:
class Discriminator(nn.Module):
    """Conditional discriminator scoring an image together with its label planes.

    Args:
        d: Base channel width.
        num_classes: Channels of the label-plane input. When omitted it
            defaults to ``2 * n_labels`` read from the notebook-level global,
            which is what the constructor always did before this parameter
            existed.
    """

    def __init__(self, d=128, num_classes=None):
        super(Discriminator, self).__init__()
        if num_classes is None:
            # Backward-compatible fallback to the notebook configuration.
            num_classes = 2 * n_labels

        # Image branch and label branch each end with d/2 channels at half
        # the input resolution; together they form the d-channel trunk input.
        self.conv1_1 = nn.Conv2d(3, int(d/2), 4, 2, 1)
        self.conv0_2 = nn.Conv2d(num_classes, int(d/4), 1, 1, 0)
        self.conv1_2 = nn.Conv2d(int(d/4), int(d/2), 4, 2, 1)

        self.conv2 = utils.spectral_norm(nn.Conv2d(d, d*2, 4, 2, 1))
        self.conv3 = utils.spectral_norm(nn.Conv2d(d*2, d*4, 4, 2, 1))
        self.conv4 = utils.spectral_norm(nn.Conv2d(d*4, d*8, 4, 2, 1))
        self.conv5 = nn.Conv2d(d*8, 1, 4, 1, 0)

    def weight_init(self, mean, std):
        """Apply ``normal_init(module, mean, std)`` to every direct sub-module."""
        for module in self._modules.values():
            normal_init(module, mean, std)

    def forward(self, input, label):
        """Score a batch of images.

        Args:
            input: Images of shape ``(B, 3, H, W)``.
            label: Label planes of shape ``(B, num_classes, H, W)``; the
                spatial size must match ``input`` because both branches are
                concatenated after one stride-2 convolution.

        Returns:
            Sigmoid probabilities; for 64x64 inputs the shape is ``(B, 1, 1, 1)``.
        """
        x = F.leaky_relu(self.conv1_1(input), 0.2)

        y = F.leaky_relu(self.conv0_2(label), 0.2)
        y = F.leaky_relu(self.conv1_2(y), 0.2)

        x = torch.cat([x, y], 1)

        x = F.leaky_relu(self.conv2(x), 0.2)
        x = F.leaky_relu(self.conv3(x), 0.2)
        x = F.leaky_relu(self.conv4(x), 0.2)
        return torch.sigmoid(self.conv5(x))

In [9]:
def normal_init(m, mean, std):
    """Initialise a (transposed) 2-D convolution with normal weights and zero bias.

    Modules of any other type are left untouched.

    Args:
        m: Module to initialise.
        mean: Mean of the normal distribution used for the weights.
        std: Standard deviation of the normal distribution.
    """
    if not isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        return
    # Layers wrapped with torch.nn.utils.spectral_norm keep their trainable
    # tensor in ``weight_orig``; ``weight`` is recomputed from it before every
    # forward pass, so writing to ``weight`` there would have no lasting effect.
    target = getattr(m, "weight_orig", None)
    if target is None:
        target = m.weight
    target.data.normal_(mean, std)
    # Convolutions built with bias=False expose ``bias`` as None.
    if m.bias is not None:
        m.bias.data.zero_()

In [10]:
def getDiscriminatorLabels(lbl, batch_size, image_size):
    """Expand binary attribute labels into constant label planes for the discriminator.

    Attribute ``k`` owns channels ``2k`` and ``2k + 1``: channel ``2k`` is
    filled with ones when the attribute equals 1, otherwise channel ``2k + 1``
    is. Any number of attributes is supported (earlier versions handled at
    most three and silently ignored the rest).

    Args:
        lbl: ``(rows, n_labels)`` NumPy array or torch tensor; only the first
            ``batch_size`` rows are used.
        batch_size: Number of samples to encode.
        image_size: Height and width of each label plane.

    Returns:
        CPU tensor of shape ``(batch_size, 2 * n_labels, image_size, image_size)``.

    Raises:
        IndexError: If ``lbl`` has fewer than ``batch_size`` rows.
    """
    is_on = (torch.as_tensor(lbl)[:batch_size] == 1).cpu()
    if is_on.shape[0] < batch_size:
        raise IndexError(
            "lbl has %d rows but batch_size is %d" % (is_on.shape[0], batch_size))
    n_attrs = is_on.shape[1]

    one_hot = torch.zeros([batch_size, 2 * n_attrs])
    one_hot[:, 0::2] = is_on.to(one_hot.dtype)      # "attribute present" channels
    one_hot[:, 1::2] = (~is_on).to(one_hot.dtype)   # "attribute absent" channels
    # Broadcast every channel value over the whole image plane.
    return one_hot[:, :, None, None].repeat(1, 1, image_size, image_size)


In [11]:
def getGeneratorLabels(lbl, batch_size):
    """One-hot encode binary attribute labels for the generator's label input.

    Attribute ``k`` owns channels ``2k`` and ``2k + 1``: channel ``2k + 1`` is
    set when the attribute equals 1, otherwise channel ``2k`` is. Note that
    this ordering is the opposite of ``getDiscriminatorLabels``. Any number
    of attributes is supported (earlier versions handled at most three).

    Args:
        lbl: ``(rows, n_labels)`` NumPy array or torch tensor; only the first
            ``batch_size`` rows are used.
        batch_size: Number of samples to encode.

    Returns:
        CPU tensor of shape ``(batch_size, 2 * n_labels, 1, 1)``.

    Raises:
        IndexError: If ``lbl`` has fewer than ``batch_size`` rows.
    """
    is_on = (torch.as_tensor(lbl)[:batch_size] == 1).cpu()
    if is_on.shape[0] < batch_size:
        raise IndexError(
            "lbl has %d rows but batch_size is %d" % (is_on.shape[0], batch_size))
    n_attrs = is_on.shape[1]

    one_hot = torch.zeros([batch_size, 2 * n_attrs])
    one_hot[:, 1::2] = is_on.to(one_hot.dtype)      # "attribute present" channels
    one_hot[:, 0::2] = (~is_on).to(one_hot.dtype)   # "attribute absent" channels
    return one_hot.view(batch_size, 2 * n_attrs, 1, 1)




In [12]:
def getGeneratorCategories(lbl, batch_size):
    """Map each row of binary attribute labels to a single category index.

    The attributes are read as the bits of a binary number with column 0 as
    the most significant bit, e.g. for two attributes
    ``(0,0)->0, (0,1)->1, (1,0)->2, (1,1)->3``. Any number of attributes is
    supported (earlier versions enumerated at most three by hand).

    A value that is not exactly 1 counts as a 0 bit. This matches how the
    one-hot label helpers treat such values. The old two- and three-attribute
    branches instead sent every row containing a non-binary value to the
    highest category.

    Args:
        lbl: ``(rows, n_labels)`` NumPy array or torch tensor; only the first
            ``batch_size`` rows are used.
        batch_size: Number of samples to encode.

    Returns:
        CPU ``int64`` tensor of shape ``(batch_size,)`` with values in
        ``[0, 2 ** n_labels)``.

    Raises:
        IndexError: If ``lbl`` has fewer than ``batch_size`` rows.
    """
    is_on = (torch.as_tensor(lbl)[:batch_size] == 1).cpu()
    if is_on.shape[0] < batch_size:
        raise IndexError(
            "lbl has %d rows but batch_size is %d" % (is_on.shape[0], batch_size))
    n_attrs = is_on.shape[1]

    # Place values, most significant first: column 0 carries 2 ** (n_attrs - 1).
    place_values = torch.tensor(
        [1 << (n_attrs - 1 - k) for k in range(n_attrs)], dtype=torch.long)
    return (is_on.long() * place_values).sum(dim=1)


In [13]:
def getGeneratorVisualizationLabels(n_features, batch_size):
    """Build a fixed label grid so a 64-image preview covers every attribute combination.

    The first 64 rows enumerate all on/off combinations of the attributes in
    ascending binary order (column 0 is the most significant bit), each
    combination repeated ``64 / 2**k`` times. If ``batch_size`` exceeds 64 the
    array is padded with all-zero rows up to ``batch_size`` rows.

    Args:
        n_features: 1 or 2 selects that many label columns; any other value
            yields three columns.
        batch_size: Requested batch size. Values of 64 or less still return
            64 rows.

    Returns:
        ``float64`` NumPy array of shape ``(max(64, batch_size), k)`` where
        ``k`` is the number of label columns.
    """
    if n_features == 1:
        width = 1
    elif n_features == 2:
        width = 2
    else:
        width = 3

    n_combos = 2 ** width
    # Row i holds the binary digits of i, most significant bit first.
    combos = np.array(
        [[(code >> bit) & 1 for bit in reversed(range(width))] for code in range(n_combos)],
        dtype=np.float64,
    )
    grid = np.repeat(combos, 64 // n_combos, axis=0)

    padding_rows = batch_size - 64
    if padding_rows > 0:
        grid = np.concatenate((grid, np.zeros((padding_rows, width))), axis=0)
    return grid

In [14]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.tensorboard import SummaryWriter


def train_model():
    """Train the conditional GAN and write previews, TensorBoard logs and loss histories.

    Configuration is read from notebook-level globals: ``ngpu``,
    ``input_dir``, ``output_dir``, ``batch_size``, ``image_size``,
    ``num_workers``, ``labels_number``, ``n_labels``, ``g_input_dim``, ``lr``,
    ``beta1``, ``num_epochs`` and ``visualization_step``.

    Side effects:
        * ``training_sample.png`` and ``result_<epoch>_<batch>.png`` are saved
          in ``output_dir``.
        * TensorBoard events are written to ``/log``.
        * ``G_losses.npy`` and ``D_losses.npy`` (one value per batch) are saved
          in ``output_dir`` once training has finished.
    """
    device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
    print("Use device:", device)

    writer = SummaryWriter('/log')
    try:
        data_loader = getDataLoader(input_dir, batch_size, image_size, num_workers, labels_number)

        # Save a grid of real training images for reference.
        show_images = next(iter(data_loader))
        plt.figure(figsize=(8, 8))
        plt.axis("off")
        plt.title("Training Images")
        plt.imshow(np.transpose(vutils.make_grid(show_images[0][0:64], padding=2, normalize=True), (1, 2, 0)))
        plt.savefig(os.path.join(output_dir, 'training_sample.png'))
        plt.close()

        netG = Generator(n_labels * 2).to(device)
        netD = Discriminator().to(device)

        cost_fun = nn.BCELoss()
        real_label = 1.0
        fake_label = 0.0

        optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
        optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
        G_losses = []
        D_losses = []

        for current_epoch in range(num_epochs):
            for batch_index, (data, lbl) in enumerate(data_loader):
                real_data = data.to(device)
                # Use the real size of this batch everywhere instead of the
                # global batch_size, so all tensors of one step always agree.
                b_size = real_data.size(0)

                # ---- Discriminator: real batch ----
                netD.zero_grad()
                lbl_real_disc = getDiscriminatorLabels(lbl, b_size, image_size).to(device)
                targets_real = torch.full((b_size,), real_label, device=device, dtype=torch.float)
                outputs_real = netD(real_data, lbl_real_disc).view(-1)
                real_loss = cost_fun(outputs_real, targets_real)
                real_loss.backward()
                D_x = outputs_real.mean().item()

                # ---- Discriminator: fake batch with random attribute labels ----
                fake_in_lbl_clear = np.random.randint(2, size=(b_size, n_labels))
                fake_in_lbl_g = getGeneratorLabels(fake_in_lbl_clear, b_size).to(device)
                fake_in_lbl_d = getDiscriminatorLabels(fake_in_lbl_clear, b_size, image_size).to(device)
                fake_cat = getGeneratorCategories(fake_in_lbl_clear, b_size).to(device)

                noise = torch.randn(b_size, g_input_dim, 1, 1, device=device)
                fake = netG(noise, fake_in_lbl_g, fake_cat)

                targets_fake = torch.full((b_size,), fake_label, device=device, dtype=torch.float)
                # detach(): the discriminator step must not back-propagate into G.
                outputs_fake = netD(fake.detach(), fake_in_lbl_d).view(-1)
                errD_fake = cost_fun(outputs_fake, targets_fake)
                errD_fake.backward()
                D_G_z1 = outputs_fake.mean().item()

                errD = real_loss + errD_fake
                optimizerD.step()

                # ---- Generator: try to make D label the fakes as real ----
                netG.zero_grad()
                targets_gen = torch.full((b_size,), real_label, device=device, dtype=torch.float)
                outputs_gen = netD(fake, fake_in_lbl_d).view(-1)
                loss_g = cost_fun(outputs_gen, targets_gen)
                loss_g.backward()
                D_G_z2 = outputs_gen.mean().item()
                optimizerG.step()

                G_losses.append(loss_g.item())
                D_losses.append(errD.item())

                # ---- Visualization and logging ----
                if batch_index % visualization_step == 0:
                    print(f'📘 Epoch {current_epoch}/{num_epochs} | Batch {batch_index}/{len(data_loader)}')
                    print(f'Loss_D: {errD.item():.4f} Loss_G: {loss_g.item():.4f} D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}')

                    fake_vis_label_clear = getGeneratorVisualizationLabels(n_labels, b_size)
                    fake_vis_label = getGeneratorLabels(fake_vis_label_clear, b_size).to(device)
                    fake_vis_cat = getGeneratorCategories(fake_vis_label_clear, b_size).to(device)
                    noise_vis = torch.randn(b_size, g_input_dim, 1, 1, device=device)
                    # Preview only: no autograd graph is needed for these images.
                    with torch.no_grad():
                        visualFake = netG(noise_vis, fake_vis_label, fake_vis_cat)
                    vis_grid = vutils.make_grid(visualFake[0:64].detach().cpu(), padding=2, normalize=True)

                    plt.figure(figsize=(8, 8))
                    plt.axis("off")
                    plt.title("Generated Images")
                    plt.imshow(np.transpose(vis_grid, (1, 2, 0)))
                    plt.savefig(os.path.join(output_dir, f'result_{current_epoch}_{batch_index}.png'))
                    plt.close()

                    current_batch = current_epoch * len(data_loader) + batch_index
                    # The "Loss ..." tags now carry the actual loss values; they
                    # used to receive the discriminator's mean outputs, which are
                    # logged under their own tags below.
                    writer.add_scalar('Loss D real', real_loss.item(), current_batch)
                    writer.add_scalar('Loss D fake', errD_fake.item(), current_batch)
                    writer.add_scalar('Loss G', loss_g.item(), current_batch)
                    writer.add_scalar('D_x', D_x, current_batch)
                    writer.add_scalar('D_G_z1', D_G_z1, current_batch)
                    writer.add_scalar('D_G_z2', D_G_z2, current_batch)
                    writer.add_image('Generated Images', vis_grid, current_batch)

            print(f"✅ Finished Epoch {current_epoch + 1}/{num_epochs}")
    finally:
        # Flush and close the event file even if training is interrupted.
        writer.close()

    np.save(os.path.join(output_dir, 'G_losses.npy'), np.array(G_losses))
    np.save(os.path.join(output_dir, 'D_losses.npy'), np.array(D_losses))




In [15]:
# Hyperparameters and configuration (read as globals by train_model and the
# model/dataset helpers defined in the cells above).

# --- Data ---
input_dir = '/content/celeba/img_align_celeba/img_align_celeba'  # image folder
output_dir = '/content/generated'                                # previews and loss histories
# Absolute path: os.path.join(input_dir, csv_name) in the dataset resolves to it unchanged.
csv_name = '/content/celeba/list_attr_celeba.csv'
# Column positions in the attribute CSV used as conditioning labels
# (column 0 holds the image file name).
labels_number = [34, 20]
# Derived rather than typed by hand so it cannot drift from labels_number.
n_labels = len(labels_number)
image_size = 64
n_channels = 3
num_workers = 2

# --- Training ---
batch_size = 64
num_epochs = 2
ngpu = 1                  # 0 forces CPU even when CUDA is available
g_input_dim = 100         # channels of the generator's latent noise input
lr = 0.0002
beta1 = 0.5               # first Adam beta for both optimizers
visualization_step = 200  # log/preview every N batches

# NOTE(review): n_channels above and ngf/ndf below are not referenced by the
# cells visible here.
ngf = 64
ndf = 64

os.makedirs(output_dir, exist_ok=True)

if __name__ == "__main__":
    train_model()

Use device: cuda:0
📘 Epoch 0/2 | Batch 0/3165
Loss_D: 1.3814 Loss_G: 1.0945 D(x): 0.5023 D(G(z)): 0.4998 / 0.3348
📘 Epoch 0/2 | Batch 200/3165
Loss_D: 1.3425 Loss_G: 2.3297 D(x): 0.6525 D(G(z)): 0.4714 / 0.1137
📘 Epoch 0/2 | Batch 400/3165
Loss_D: 1.2260 Loss_G: 1.0475 D(x): 0.5626 D(G(z)): 0.4461 / 0.3758
📘 Epoch 0/2 | Batch 600/3165
Loss_D: 1.2211 Loss_G: 0.8851 D(x): 0.5955 D(G(z)): 0.4752 / 0.4264
📘 Epoch 0/2 | Batch 800/3165
Loss_D: 1.2114 Loss_G: 0.9443 D(x): 0.6069 D(G(z)): 0.4793 / 0.4050
📘 Epoch 0/2 | Batch 1000/3165
Loss_D: 1.2037 Loss_G: 1.0361 D(x): 0.5186 D(G(z)): 0.3787 / 0.3638
📘 Epoch 0/2 | Batch 1200/3165
Loss_D: 1.1659 Loss_G: 0.9923 D(x): 0.5907 D(G(z)): 0.4451 / 0.3871
📘 Epoch 0/2 | Batch 1400/3165
Loss_D: 1.1195 Loss_G: 1.1929 D(x): 0.6291 D(G(z)): 0.4486 / 0.3216
📘 Epoch 0/2 | Batch 1600/3165
Loss_D: 1.1043 Loss_G: 1.2816 D(x): 0.5782 D(G(z)): 0.3761 / 0.2929
📘 Epoch 0/2 | Batch 1800/3165
Loss_D: 1.2080 Loss_G: 1.1055 D(x): 0.5895 D(G(z)): 0.4490 / 0.3589
📘 Epoch 