In [107]:
from google.colab import drive
drive.mount('/gdrive')

Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).


In [108]:
import tarfile
import os

In [109]:
path_to_data = r'/gdrive/MyDrive/Dataset/StylesGan.tar'
with tarfile.open(path_to_data, 'r') as t:
    t.extractall('./data')

In [110]:
path_to_folder = r'/content/data/StylesGan'

In [111]:
from torchvision.datasets import ImageFolder
import torchvision.transforms as tt

In [112]:
std = (0.5,0.5,0.5)
mean = (0.5,0.5,0.5)

In [113]:
dataset = ImageFolder(root = path_to_folder, transform= tt.Compose([tt.ToTensor(), tt.Normalize(mean, std)]))

In [114]:
def denorm(i):
    i = (i*std[0]) + mean[0]
    return i

In [115]:
image,_ = dataset[0]
image.shape

torch.Size([3, 128, 128])

In [116]:
from torch.utils.data import DataLoader
batch_size = 200

In [117]:
data_dl = DataLoader(dataset=dataset, batch_size=batch_size, pin_memory= True, num_workers= 2)

In [118]:
import torch
import torch.nn as nn

In [119]:
device  = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [120]:
D = nn.Sequential(
    nn.Conv2d(3,16, kernel_size=3, stride=1, padding=1), #input 3X128X128
    nn.BatchNorm2d(16),
    nn.LeakyReLU(),
    nn.MaxPool2d(2,2), #output 16X64X64
    
    nn.Conv2d(16,64, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm2d(64),
    nn.LeakyReLU(),
    nn.MaxPool2d(2,2), #output 64X32X32
    
    nn.Conv2d(64,128, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(),
    nn.MaxPool2d(2,2), #output 128X16X16
    
    nn.Conv2d(128,256, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm2d(256),
    nn.LeakyReLU(),
    nn.MaxPool2d(2,2), #output 256X8X8
    
    nn.Conv2d(256,512, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm2d(512),
    nn.LeakyReLU(),
    nn.MaxPool2d(2,2), #output 512X4X4
    
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    
    nn.Linear(512, 128),
    nn.LeakyReLU(),
    nn.Linear(128,1),
    nn.Sigmoid()
)

In [121]:
latent_size = 64

In [122]:
G = nn.Sequential(
    nn.ConvTranspose2d(latent_size, 128, kernel_size=4, stride=1, padding=1),
    nn.BatchNorm2d(128),
    nn.ReLU(), #output 128X4X4
    
    nn.ConvTranspose2d(128,512, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(512),
    nn.ReLU(), #output 512X8X8
    
    nn.ConvTranspose2d(512,256, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(256),
    nn.ReLU(), #output 256X16X16
    
    nn.ConvTranspose2d(256,128, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(128),
    nn.ReLU(), #output 128X32X32
    
    nn.ConvTranspose2d(128,64, kernel_size=4, stride=2, padding = 1),
    nn.BatchNorm2d(64),
    nn.ReLU(), #64X64X64
    
    nn.ConvTranspose2d(64,3, kernel_size=4, stride=2, padding=1),
    nn.Tanh(), #3X128X128
)

In [123]:
D.to(device)
G.to(device)

Sequential(
  (0): ConvTranspose2d(64, 128, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): ConvTranspose2d(128, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU()
  (6): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (8): ReLU()
  (9): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (10): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (11): ReLU()
  (12): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (13): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (14): ReLU()
  (15): ConvTranspose2d(64, 3, kernel_size=(4, 4), strid

In [124]:
loss_func = nn.BCELoss()
d_opt_func = torch.optim.SGD(D.parameters(), lr=3e-5)
g_opt_func = torch.optim.Adam(G.parameters(), lr=3e-4)

In [125]:
import torch
from torchvision.utils import save_image

In [126]:
random_vectors = torch.randn(batch_size, latent_size,1,1).to(device)
save_image_path = r'/gdrive/MyDrive/Dataset/StylesSavedImages'

def save_fake(i):
    fake_image = G(random_vectors)
    name = f'fake_image_{i}.png'
    
    save_image(denorm(fake_image), os.path.join(save_image_path, name), nrow=20)

In [127]:
def dFit(x):
    fake_labels = torch.zeros(batch_size, 1).to(device)
    real_labels = torch.ones(batch_size, 1).to(device)
    
    output = D(x)
    d_loss_real = loss_func(output, real_labels)
    
    y = torch.randn(batch_size, latent_size, 1,1).to(device)
    fake_output = D(G(y))
    d_loss_fake = loss_func(fake_output, fake_labels)
    
    d_loss = d_loss_fake + d_loss_real
    
    d_opt_func.zero_grad()
    g_opt_func.zero_grad()
    
    d_loss.backward()
    
    d_opt_func.step()
    
    return d_loss, output, fake_output

In [128]:
def gFit():
    x = torch.randn(batch_size, latent_size,1,1).to(device)
    label = torch.ones(batch_size, 1).to(device)
    
    gen_output = G(x)
    gen_output_d = D(gen_output)
    
    g_loss = loss_func(gen_output_d, label)
    
    d_opt_func.zero_grad()
    g_opt_func.zero_grad()
    
    g_loss.backward()
    
    g_opt_func.step()
    
    return g_loss

In [129]:
def fit(num_epochs):
    for epoch in range(num_epochs):
        for i,_ in data_dl:
            i = i.to(device)
            d_loss, real_score, fake_score = dFit(i)
            g_loss = gFit()
            if real_score.mean()>0.8:
                gFit()
                gFit()
            if fake_score.mean()>0.7:
                dFit(i)
            
        print(f'epoch_num: {epoch+1} d_loss: {d_loss:.4f}, g_loss: {g_loss:.4f}, real_score: {real_score.mean()}, fake_score: {fake_score.mean()}')
        
        save_fake(epoch)

In [None]:
fit(100)

epoch_num: 1 d_loss: 1.9930, g_loss: 0.3584, real_score: 0.4533725380897522, fake_score: 0.6992700099945068
epoch_num: 2 d_loss: 1.8466, g_loss: 0.4249, real_score: 0.4563266634941101, fake_score: 0.654172956943512
epoch_num: 3 d_loss: 1.7899, g_loss: 0.4508, real_score: 0.4606822729110718, fake_score: 0.637472927570343
epoch_num: 4 d_loss: 1.7502, g_loss: 0.4753, real_score: 0.45973289012908936, fake_score: 0.6220251321792603
epoch_num: 5 d_loss: 1.6793, g_loss: 0.5216, real_score: 0.4591852128505707, fake_score: 0.5937653183937073
epoch_num: 6 d_loss: 1.6241, g_loss: 0.5600, real_score: 0.4599149525165558, fake_score: 0.5713907480239868
epoch_num: 7 d_loss: 1.6014, g_loss: 0.5739, real_score: 0.4619671106338501, fake_score: 0.5634881258010864
epoch_num: 8 d_loss: 1.5607, g_loss: 0.6008, real_score: 0.46517813205718994, fake_score: 0.5485193133354187
epoch_num: 9 d_loss: 1.5244, g_loss: 0.6240, real_score: 0.4692873954772949, fake_score: 0.5359088778495789
