In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets
import numpy as np
import pickle
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

In [3]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [4]:
class SketchToFaceDataset(Dataset):
    """Index-aligned dataset of (sketch, attributes, real image) triples.

    The three sequences are indexed in parallel: item ``i`` is built from the
    ``i``-th element of each. The dataset length is taken from
    ``sketch_images``.

    Note: ``transform`` is stored on the instance but is NOT applied by
    ``__getitem__``; items are returned exactly as stored.
    """

    def __init__(self, sketch_images, attributes, real_images, transform=None):
        self.sketch_images, self.attributes, self.real_images = (
            sketch_images,
            attributes,
            real_images,
        )
        self.transform = transform

    def __len__(self):
        return len(self.sketch_images)

    def __getitem__(self, idx):
        # Same element order as the constructor arguments: sketch, attributes, image.
        return (
            self.sketch_images[idx],
            self.attributes[idx],
            self.real_images[idx],
        )



In [5]:
# Load the pre-computed sketch images.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load this dataset if its source is trusted.
with open('/kaggle/input/image-dict-3/images_dict_3.pkl', 'rb') as f:
    dict_image = pickle.load(f)

# Attribute table; column 1 holds the CelebA image file name used below.
df_atrb = pd.read_csv('/kaggle/input/atributes-list-40/df_atributes_40_columns.csv')

root_dir = '/kaggle/input/celeba/img_align_celeba'
real_images = []
for image_name in df_atrb.iloc[:, 1]:
    image_path = os.path.join(root_dir, image_name)
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals a missing/unreadable file by returning None rather
        # than raising; fail here with a clear message instead of inside resize.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # OpenCV decodes to BGR; convert to RGB so the channel order matches what
    # torchvision utilities (e.g. save_image) assume.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    resized_image = cv2.resize(image, (128, 128))
    # HWC -> CHW, the layout nn.Conv2d expects. Pixel values stay in 0-255.
    real_images.append(np.transpose(resized_image, (2, 0, 1)))
    

In [6]:
# NOTE(review): sketches are taken in the dict's insertion order while
# attributes and real images follow the CSV row order. The pairing is only
# correct if the pickle was built in the same row order -- confirm (if the
# dict keys are image file names, index it with df_atrb.iloc[:, 1] instead).
sketch_images = list(dict_image.values())
# Columns 2..19 -> 18 attribute values per sample.
attributes = np.array(df_atrb.iloc[:, 2:20])

# Every sample needs a sketch, an attribute row and a real image.
assert len(sketch_images) == len(attributes) == len(real_images), (
    f"length mismatch: {len(sketch_images)} sketches, "
    f"{len(attributes)} attribute rows, {len(real_images)} real images"
)

# Stack into a single ndarray first: torch.tensor() on a Python list of
# ndarrays is very slow (PyTorch's warning about it is silenced above).
sketch_images = torch.tensor(np.stack(sketch_images), dtype=torch.float32).to(device).unsqueeze(1)
attributes = torch.tensor(attributes, dtype=torch.float32).to(device)
real_images = torch.tensor(np.stack(real_images), dtype=torch.float32).to(device)


In [7]:

# Wrap the pre-built tensors in a Dataset and batch them for training.
# NOTE(review): ToTensor() is passed along, but SketchToFaceDataset does not
# apply its transform (and the inputs are already tensors), so it has no effect.
transform = transforms.Compose([
    transforms.ToTensor(),
])

dataset = SketchToFaceDataset(
    sketch_images,
    attributes,
    real_images,
    transform=transform,
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=8)


In [8]:
import torch
import torch.nn as nn

class Generator(nn.Module):
    """Conditional generator: sketch + attribute vector -> 3-channel image.

    Two parallel branches each apply three 3x3 convolutions at full
    resolution:

    * branch A runs on the 1-channel sketch;
    * branch B runs on the attribute vector after ``fc`` projects it to a
      single (image_size x image_size) map.

    Each branch concatenates its three 64-channel maps (192 channels), giving
    384 channels in total. A four-step stride-2 encoder and a four-step
    transposed-conv decoder with skip connections then produce a 3-channel
    output at the input resolution. No output activation is applied, so the
    output range is unbounded.

    Args:
        num_attributes: length of the attribute vector (default 18).
        image_size: side length of the square sketch / output image
            (default 128). Must be a positive multiple of 16 so the four
            downsampling steps and the skip connections line up.

    Raises:
        ValueError: if ``image_size`` is not a positive multiple of 16.
    """

    def __init__(self, num_attributes=18, image_size=128):
        super().__init__()
        if image_size <= 0 or image_size % 16 != 0:
            raise ValueError(
                f"image_size must be a positive multiple of 16, got {image_size}"
            )
        self.image_size = image_size

        # Sketch branch (A): 1 -> 64 -> 64 -> 64 channels, resolution preserved.
        # A single PReLU (one learnable slope) is shared by the whole branch.
        self.conv1_A = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.conv2_A = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv3_A = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.prelu_A = nn.PReLU()

        # Attribute branch (B): vector -> one image_size x image_size map -> convs.
        self.fc = nn.Linear(num_attributes, image_size * image_size)
        self.conv1_B = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.conv2_B = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv3_B = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.prelu_B = nn.PReLU()

        # Encoder: 384 = 2 branches x 3 concatenated 64-channel maps.
        self.conv1_down = nn.Conv2d(384, 64, kernel_size=3, stride=2, padding=1)
        self.conv2_down = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)
        self.conv3_down = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)
        self.conv4_down = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)
        self.prelu_down = nn.PReLU()

        # Decoder: inputs after the first step are 128 = 64 upsampled + 64 skip channels.
        self.deconv1_up = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv2_up = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv3_up = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv4_up = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv_out = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)
        self.prelu_up = nn.PReLU()

    def forward(self, sketch_image, attributes):
        """Generate images from sketches and attribute vectors.

        Args:
            sketch_image: tensor of shape (batch, 1, image_size, image_size).
            attributes: tensor of shape (batch, num_attributes).

        Returns:
            Tensor of shape (batch, 3, image_size, image_size).
        """
        # Sketch branch: keep all three intermediate maps (192 channels).
        out1_A = self.prelu_A(self.conv1_A(sketch_image))
        out2_A = self.prelu_A(self.conv2_A(out1_A))
        out3_A = self.prelu_A(self.conv3_A(out2_A))
        out_A = torch.cat((out1_A, out2_A, out3_A), dim=1)

        # Attribute branch: project to a 1-channel spatial map, then convolve.
        attr = self.fc(attributes)
        attr = attr.view(attributes.size(0), 1, self.image_size, self.image_size)
        out1_B = self.prelu_B(self.conv1_B(attr))
        out2_B = self.prelu_B(self.conv2_B(out1_B))
        out3_B = self.prelu_B(self.conv3_B(out2_B))
        out_B = torch.cat((out1_B, out2_B, out3_B), dim=1)

        combined = torch.cat((out_A, out_B), dim=1)  # 384 channels

        # Encoder: each stride-2 conv halves the spatial size.
        down1 = self.prelu_down(self.conv1_down(combined))
        down2 = self.prelu_down(self.conv2_down(down1))
        down3 = self.prelu_down(self.conv3_down(down2))
        down4 = self.prelu_down(self.conv4_down(down3))

        # Decoder: each transposed conv doubles the spatial size and is joined
        # with the encoder map of the same size (skip connection).
        up1 = torch.cat((self.prelu_up(self.deconv1_up(down4)), down3), dim=1)
        up2 = torch.cat((self.prelu_up(self.deconv2_up(up1)), down2), dim=1)
        up3 = torch.cat((self.prelu_up(self.deconv3_up(up2)), down1), dim=1)
        up4 = self.prelu_up(self.deconv4_up(up3))

        return self.conv_out(up4)


In [9]:
class Discriminator(nn.Module):
    """Conditional discriminator scoring (face image, attributes) pairs.

    The image passes through three stride-2 convolutions, giving 256 channels
    at 1/8 of the input size. The attribute vector is projected to 256
    features and broadcast over that grid. The concatenated 512-channel map is
    reduced to one sigmoid probability per sample.

    Args:
        num_attributes: length of the attribute vector (default 18).
        image_size: side length of the square 3-channel input images
            (default 128). Must be a multiple of 8 and at least 32.

    Raises:
        ValueError: if ``image_size`` is not a multiple of 8 or is below 32.
    """

    def __init__(self, num_attributes=18, image_size=128):
        super().__init__()
        if image_size % 8 != 0 or image_size < 32:
            raise ValueError(
                f"image_size must be a multiple of 8 and >= 32, got {image_size}"
            )

        # Image encoder: each 4x4 / stride-2 / pad-1 conv halves the spatial size.
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Attribute embedding, broadcast spatially in forward().
        self.fc = nn.Linear(num_attributes, 256)

        # 512 = 256 image channels + 256 attribute channels.
        self.out = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0)

        # self.out (4x4 kernel, no padding) shrinks the (image_size / 8) grid
        # by 3 per side: 128 -> 16 -> 13, i.e. 169 values per sample.
        side = image_size // 8 - 3
        # NOTE: there is no nonlinearity between linear and linear2, so together
        # they compute a single affine map.
        self.linear = nn.Linear(side * side, 13)
        self.linear2 = nn.Linear(13, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, face_image, attributes):
        """Return the probability that each (image, attributes) pair is real.

        Args:
            face_image: tensor of shape (batch, 3, image_size, image_size).
            attributes: tensor of shape (batch, num_attributes).

        Returns:
            Tensor of shape (batch, 1) with values in [0, 1].
        """
        out_img = self.conv(face_image)

        # (batch, 256) -> (batch, 256, H, W) to match the image feature map.
        attr = self.fc(attributes)
        attr = attr.unsqueeze(2).unsqueeze(3)
        attr = attr.expand(attr.size(0), attr.size(1), out_img.size(2), out_img.size(3))

        combined = torch.cat([out_img, attr], dim=1)

        validity = self.out(combined)
        # Flatten per sample. view(-1, 169) could silently regroup values across
        # samples if the spatial size differed from the one the layers were
        # built for; keeping the batch dim makes a mismatch fail in self.linear.
        validity = validity.view(validity.size(0), -1)
        l1 = self.linear(validity)
        return self.sigmoid(self.linear2(l1))


In [10]:
# Adversarial objective: binary cross-entropy between the discriminator's
# sigmoid output (a probability) and the 1 = real / 0 = fake target labels.
# No pixel-wise reconstruction loss is currently used.
adversarial_loss = nn.BCELoss()


In [11]:
CHECKPOINT_DIR = "/kaggle/working/"
checkpoint_interval = 1
os.makedirs(CHECKPOINT_DIR, exist_ok=True)


def train(generator, discriminator, dataloader, num_epochs=50, lr=0.0002,
          checkpoint_dir=None, save_every=None):
    """Train the conditional GAN with an adversarial (BCE) objective only.

    Each batch performs one generator update followed by one discriminator
    update, both using Adam(lr, betas=(0.5, 0.999)). The generator sees only
    the adversarial term; ``real_faces`` are used solely as the
    discriminator's real examples. Every ``save_every`` epochs the state dicts
    of both networks are written to ``checkpoint_dir`` as
    ``generator_epoch_{k}.pth`` / ``discriminator_epoch_{k}.pth``, where ``k``
    is the 1-based epoch number.

    Uses the module-level ``adversarial_loss`` and ``device``.

    Args:
        generator: module called as ``generator(sketches, attributes)``.
        discriminator: module called as ``discriminator(faces, attributes)``.
            Its output is fed to ``adversarial_loss`` against (batch, 1)
            targets.
        dataloader: iterable of ``(sketches, attributes, real_faces)`` batches.
        num_epochs: number of passes over ``dataloader``.
        lr: Adam learning rate for both optimizers.
        checkpoint_dir: output directory; defaults to ``CHECKPOINT_DIR``
            (read at call time).
        save_every: checkpoint interval in epochs; defaults to the module-level
            ``checkpoint_interval`` (read at call time).

    Raises:
        ValueError: if the checkpoint interval is smaller than 1.
    """
    ckpt_dir = CHECKPOINT_DIR if checkpoint_dir is None else checkpoint_dir
    interval = checkpoint_interval if save_every is None else save_every
    if interval < 1:
        raise ValueError(f"checkpoint interval must be >= 1, got {interval}")
    os.makedirs(ckpt_dir, exist_ok=True)

    optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

    # Put both modules in training mode (matters if eval() was called on them).
    generator.train()
    discriminator.train()

    for epoch in range(num_epochs):
        for i, (sketches, attributes, real_faces) in enumerate(dataloader):
            batch_size = sketches.size(0)
            sketches = sketches.to(device)
            attributes = attributes.to(device)
            real_faces = real_faces.to(device)

            # BCE targets: 1 = real, 0 = fake (plain constants, no grad).
            valid = torch.ones((batch_size, 1), device=device)
            fake = torch.zeros((batch_size, 1), device=device)

            # ---- Generator step: push D to label generated faces as real ----
            optimizer_G.zero_grad()
            gen_faces = generator(sketches, attributes)
            g_loss = adversarial_loss(discriminator(gen_faces, attributes), valid)
            g_loss.backward()
            optimizer_G.step()
            # g_loss.backward() also filled D's .grad buffers; they are cleared
            # by optimizer_D.zero_grad() below before D's own backward pass.

            # ---- Discriminator step ----
            optimizer_D.zero_grad()
            real_loss = adversarial_loss(discriminator(real_faces, attributes), valid)
            # detach(): D's loss must not backpropagate into the generator.
            fake_loss = adversarial_loss(discriminator(gen_faces.detach(), attributes), fake)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

        if (epoch + 1) % interval == 0:
            torch.save(generator.state_dict(),
                       os.path.join(ckpt_dir, f"generator_epoch_{epoch+1}.pth"))
            torch.save(discriminator.state_dict(),
                       os.path.join(ckpt_dir, f"discriminator_epoch_{epoch+1}.pth"))
            print(f"Checkpoint saved at epoch {epoch+1}.")

In [12]:
def inference(generator_checkpoint, sketch_image, attributes, device):
    """Load generator weights from a checkpoint and run one forward pass.

    Args:
        generator_checkpoint: path to a state dict saved from a ``Generator``.
        sketch_image: sketch tensor, passed as the generator's first input.
        attributes: attribute tensor, passed as the generator's second input.
        device: device the model and both inputs are moved to.

    Returns:
        The generator output, computed in eval mode without gradient tracking.
    """
    model = Generator().to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    weights = torch.load(generator_checkpoint, map_location=device)
    model.load_state_dict(weights)
    model.eval()

    with torch.no_grad():
        return model(sketch_image.to(device), attributes.to(device))

In [13]:
# Fresh, untrained networks placed on the selected device.
generator, discriminator = Generator().to(device), Discriminator().to(device)


In [14]:
# Short smoke-test run of two epochs (checkpoints end up at epochs 1 and 2).
train(
    generator=generator,
    discriminator=discriminator,
    dataloader=dataloader,
    num_epochs=2,
)


[Epoch 0/2] [Batch 0/2] [D loss: 0.6769] [G loss: 0.6577]
[Epoch 0/2] [Batch 1/2] [D loss: 0.3946] [G loss: 0.6345]
Checkpoint saved at epoch 1.
[Epoch 1/2] [Batch 0/2] [D loss: 0.3531] [G loss: 0.7073]
[Epoch 1/2] [Batch 1/2] [D loss: 0.3094] [G loss: 0.7761]
Checkpoint saved at epoch 2.


In [15]:
from torchvision.utils import save_image

# Smoke test: random noise stands in for a real sketch and attribute vector.
# A seeded RNG keeps the saved image reproducible across runs.
rng = torch.Generator().manual_seed(0)
example_sketch = torch.randn(1, 1, 128, 128, generator=rng)
example_attributes = torch.randn(1, 18, generator=rng)
# NOTE: "epoch_2" matches num_epochs=2 in the training call above.
generated_image = inference(os.path.join(CHECKPOINT_DIR, "generator_epoch_2.pth"),
                            example_sketch, example_attributes, device)

# The real images used in training are raw 0-255 pixel values (never
# normalized), so that is the generator's target range. save_image expects
# [0, 1] and clamps anything outside it, so rescale before saving.
save_image(generated_image / 255.0, "generated_image.png")
print("Generated image saved as 'generated_image.png'.")

Generated image saved as 'generated_image.png'.
