In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import numpy as np
from torch.optim.lr_scheduler import StepLR
import torchvision.utils as vutils
from torch.utils.data import DataLoader, TensorDataset
from scipy import linalg
from scipy.stats import entropy
import tqdm
import cv2


In [None]:
# Core configuration.
# The transform turns every dataset example into 5 crops, so a DataLoader batch of
# `batch_size` examples becomes 5 * batch_size images once flattened for training.
batch_size = 128 // 5   # dataset examples per DataLoader batch (not images)
image_size = 64         # image height/width in pixels; models are built for 64x64
nc = 3                  # image channels (RGB)
nz = 100                # length of the generator's latent vector
ngf = 64                # generator base feature-map count
ndf = 32                # NOTE(review): netD is built with Discriminator() defaults (ndf=64), so this is unused

In [None]:
# Mount Google Drive; the dataset and checkpoints live under /content/drive/MyDrive/cv2.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [None]:
# Load the pre-saved WikiArt portrait subset (a Hugging Face `datasets` dataset) from Drive.
# In a fresh Colab runtime install the library first, e.g.:  %pip install -q datasets==2.18.0
from datasets import load_from_disk

PORTRAIT_DATA_PATH = '/content/drive/MyDrive/cv2/wikiart_portrait'
portrait_data = load_from_disk(PORTRAIT_DATA_PATH)

Collecting datasets
  Downloading datasets-2.18.0-py3-none-any.whl (510 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.5/510.5 kB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: xxhash, dill, multiprocess, datasets
Successfully installed datasets

In [None]:
from torchvision.transforms.functional import five_crop
from torchvision.transforms.functional import resize
from PIL import Image


def _to_rgb(img):
    """Force 3 channels so grayscale/RGBA paintings stack with RGB ones (PIL inputs only)."""
    return img.convert("RGB") if isinstance(img, Image.Image) else img


def _five_crop_90(x):
    """Return the 4 corner crops + centre crop, each 90% of the tensor's height and width."""
    return five_crop(x, (int(x.shape[1] * 0.9), int(x.shape[2] * 0.9)))


def _resize_and_stack(crops, size=(64, 64)):
    """Resize every crop to `size` and stack them into one (num_crops, C, H, W) tensor."""
    return torch.stack([resize(crop, size) for crop in crops])


# Each image becomes a (5, 3, 64, 64) tensor with values in [-1, 1]. The [-1, 1] range
# matches the generator's Tanh output; with plain ToTensor() real images would sit in
# [0, 1] and D could tell real from fake by value range alone.
transform = transforms.Compose([
    transforms.Lambda(_to_rgb),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    transforms.Lambda(_five_crop_90),
    transforms.Lambda(_resize_and_stack),
])


def apply_transform(examples):
    """`with_transform` hook: replace each example's image with its stack of crops.

    `examples` maps column names to lists of values; only the 'image' column is changed.
    """
    examples['image'] = [transform(image) for image in examples["image"]]
    return examples

In [None]:
# Attach the crop/normalise transform; it runs lazily whenever an item is indexed.
# (A preceding `with_format("torch")` was dead code: its result was overwritten at once,
# and `with_transform` installs its own formatting anyway.)
ds = portrait_data.with_transform(apply_transform)

In [None]:
# Preview the crops produced for the first dataset example, side by side in one figure.
sample_crops = ds[0]['image']  # stacked crops, (num_crops, C, H, W)
num_crops = sample_crops.size(0)

# normalize=True min-max rescales the grid into [0, 1] for display, so the preview looks
# right whether the transform yields tensors in [0, 1] or in [-1, 1].
grid = vutils.make_grid(sample_crops, nrow=num_crops, padding=2, normalize=True)

fig, ax = plt.subplots(figsize=(3 * num_crops, 3))
ax.imshow(grid.permute(1, 2, 0).cpu().numpy())
ax.axis('off')
ax.set_title(f"{num_crops} crops of example 0")
plt.show()

In [None]:
class Generator(nn.Module):
    """DCGAN generator: latent tensor (N, nz, 1, 1) -> image (N, nc, 64, 64) in [-1, 1]."""

    def __init__(self, nz=100, ngf=64, nc=3):
        super().__init__()

        def up_block(in_ch, out_ch, stride, padding):
            # ConvTranspose -> BatchNorm -> ReLU; conv bias is omitted because BatchNorm follows.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]

        layers = []
        layers += up_block(nz, ngf * 8, 1, 0)       # 1x1   -> 4x4
        layers += up_block(ngf * 8, ngf * 4, 2, 1)  # 4x4   -> 8x8
        layers += up_block(ngf * 4, ngf * 2, 2, 1)  # 8x8   -> 16x16
        layers += up_block(ngf * 2, ngf, 2, 1)      # 16x16 -> 32x32
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))  # 32x32 -> 64x64
        layers.append(nn.Tanh())
        # Same module order as a literal Sequential, so state_dict keys are main.0 ... main.13.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)


class Discriminator(nn.Module):
    """DCGAN discriminator scoring (N, nc, 64, 64) images as real (-> 1) or fake (-> 0).

    Args:
        nc: input image channels.
        ndf: base feature-map count, doubled after each downsampling stage.
        output_size: number of sigmoid scores per image produced by the head.
        num_style_classes: accepted for backward compatibility; currently unused
            (no style-classification head is implemented).
    """

    def __init__(self, nc=3, ndf=64, output_size=1, num_style_classes=3):
        super(Discriminator, self).__init__()
        # Shared body: four stride-2 convs halve the spatial size each time (64 -> 4).
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Real/fake head: a 4x4 valid conv collapses the 4x4 feature map to 1x1.
        self.real_fake_head = nn.Sequential(
            nn.Conv2d(ndf * 8, output_size, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return scores in [0, 1] of shape (N, output_size) for 64x64 input.

        Larger inputs leave a spatial map, giving (N, output_size * h * w) instead.
        """
        features = self.main(x)
        scores = self.real_fake_head(features)
        # Flatten per sample. The previous view(-1, 1) spread one sample's
        # output_size scores across several rows whenever output_size > 1.
        return torch.flatten(scores, start_dim=1)



# Shape smoke test: push a tiny random batch through freshly built G and D and check
# the expected shapes (G -> (N, 3, 64, 64) images, D -> (N, 1) scores).
latent_dim = nz  # reuse the configured latent size instead of repeating the literal 100
batchsize = 2

g = Generator(nz=latent_dim)
with torch.no_grad():  # shape check only; no autograd graph needed
    z = torch.randn(batchsize, latent_dim, 1, 1)
    out = g(z)
print(out.size())
assert out.shape == (batchsize, 3, 64, 64), out.shape

d = Discriminator()
with torch.no_grad():
    x = torch.randn((batchsize, 3, 64, 64))
    out = d(x)
print(out.size())
assert out.shape == (batchsize, 1), out.shape

In [None]:

def collate_fn(examples):
    """Batch transformed examples for the DataLoader.

    Stacks each example's 'image' tensor along a new leading batch dimension and
    gathers the 'style' labels into a tensor, preserving example order.
    """
    pixel_values = torch.stack([ex["image"] for ex in examples])
    labels = torch.tensor([ex["style"] for ex in examples])
    return {"image": pixel_values, "style": labels}

# Each batch holds `batch_size` examples, i.e. 5 * batch_size crops after the training loop
# flattens them. shuffle=True gives every epoch a fresh example order rather than
# always replaying the dataset's stored order.
dataloader = DataLoader(ds, collate_fn=collate_fn, batch_size=batch_size, shuffle=True)

In [None]:
# Optimisation hyper-parameters (lr and beta1 are shared by the G and D Adam optimisers).
num_epochs = 1000   # passes over the dataloader
lr = 1e-4           # Adam learning rate
beta1 = 0.5         # Adam beta1

In [None]:
def weights_init(m):
    """DCGAN-style initialisation, meant to be used as ``net.apply(weights_init)``.

    Modules whose class name contains 'Conv' get weights ~ N(0, 0.02); modules whose
    class name contains 'BatchNorm' get scale ~ N(1, 0.02) and shift = 0. Any other
    module is left untouched.
    """
    name = type(m).__name__
    if 'Conv' in name:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
# Binary cross-entropy on the discriminator's sigmoid outputs.
criterion = nn.BCELoss()

# Fixed latent batch; not referenced by the training loop below
# (presumably intended for tracking generator progress).
fixed_noise = torch.randn(256, nz, 1, 1, device=device)

# BCE targets for real and generated samples.
real_label, fake_label = 1., 0.

# Build both networks on `device` and apply the DCGAN init (Module.to/.apply return the module).
# NOTE(review): Discriminator() uses its default ndf=64, not the `ndf` config value.
netD = Discriminator().to(device).apply(weights_init)
netG = Generator().to(device).apply(weights_init)

# Both players use Adam with the same learning rate and betas.
adam_kwargs = dict(lr=lr, betas=(beta1, 0.999))
optimizerD = optim.Adam(netD.parameters(), **adam_kwargs)
optimizerG = optim.Adam(netG.parameters(), **adam_kwargs)

In [None]:
from torchsummary import summary

# Layer-by-layer output shapes and parameter counts for G on one latent vector (nz, 1, 1).
# torchsummary builds its dummy input on the given device (CUDA by default), so pass the
# device the model lives on. The trailing ';' suppresses the notebook's echo of
# summary()'s return value (the table itself is printed by summary()).
summary(netG, (nz, 1, 1), device=device.type);

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ConvTranspose2d-1            [-1, 512, 4, 4]         819,200
       BatchNorm2d-2            [-1, 512, 4, 4]           1,024
              ReLU-3            [-1, 512, 4, 4]               0
   ConvTranspose2d-4            [-1, 256, 8, 8]       2,097,152
       BatchNorm2d-5            [-1, 256, 8, 8]             512
              ReLU-6            [-1, 256, 8, 8]               0
   ConvTranspose2d-7          [-1, 128, 16, 16]         524,288
       BatchNorm2d-8          [-1, 128, 16, 16]             256
              ReLU-9          [-1, 128, 16, 16]               0
  ConvTranspose2d-10           [-1, 64, 32, 32]         131,072
      BatchNorm2d-11           [-1, 64, 32, 32]             128
             ReLU-12           [-1, 64, 32, 32]               0
  ConvTranspose2d-13            [-1, 3, 64, 64]           3,072
             Tanh-14            [-1, 3,

In [None]:
# Summarise D at the resolution it is trained on. At larger inputs (e.g. 256x256) the head
# produces a spatial map (13x13) rather than one score per image, which is misleading here.
summary(netD, (nc, image_size, image_size), device=device.type);

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           3,072
         LeakyReLU-2         [-1, 64, 128, 128]               0
            Conv2d-3          [-1, 128, 64, 64]         131,072
       BatchNorm2d-4          [-1, 128, 64, 64]             256
         LeakyReLU-5          [-1, 128, 64, 64]               0
            Conv2d-6          [-1, 256, 32, 32]         524,288
       BatchNorm2d-7          [-1, 256, 32, 32]             512
         LeakyReLU-8          [-1, 256, 32, 32]               0
            Conv2d-9          [-1, 512, 16, 16]       2,097,152
      BatchNorm2d-10          [-1, 512, 16, 16]           1,024
        LeakyReLU-11          [-1, 512, 16, 16]               0
           Conv2d-12            [-1, 1, 13, 13]           8,192
          Sigmoid-13            [-1, 1, 13, 13]               0
Total params: 2,765,568
Trainable param

In [None]:
# G_losses_file = open('/content/drive/MyDrive/cv2/G_losses_scaled_baseline_bs_128.txt', 'a+')
# D_losses_file = open('/content/drive/MyDrive/cv2/D_losses_scaled_baseline_bs_128.txt', 'a+')

In [None]:
# DCGAN training loop: one discriminator update followed by one generator update per batch.
# Losses are kept in memory (G_losses / D_losses) and appended to text files on Drive every
# `log_every` iterations. Models and optimizer states are checkpointed every 10 epochs.
img_list = []
G_losses = []
D_losses = []
iters = 0

start_epoch = 0  # set to the last saved epoch when continuing from a checkpoint
log_every = 10
# NOTE(review): file names say "bs_128", but batch_size is 128 // 5 examples (125 crops) per step.
G_LOSSES_PATH = '/content/drive/MyDrive/cv2/G_losses_scaled_baseline_bs_128.txt'
D_LOSSES_PATH = '/content/drive/MyDrive/cv2/D_losses_scaled_baseline_bs_128.txt'

print("Starting Training Loop...")

# The loss files are opened here, not in a separate (commented-out) cell, so this cell
# runs on a fresh kernel. The `with` block also closes them if training is interrupted.
with open(G_LOSSES_PATH, 'a') as G_losses_file, open(D_LOSSES_PATH, 'a') as D_losses_file:
    for epoch in range(start_epoch, num_epochs):

        for i, data in enumerate(dataloader, 0):

            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            netD.zero_grad()
            # Collapse (batch, crops, C, H, W) into a flat image batch (batch * crops, C, H, W).
            real = data['image'].to(device)
            real = real.view(-1, real.size(2), real.size(3), real.size(4))

            b_size = real.size(0)
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            output = netD(real).view(-1)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            ## Train with an all-fake batch of the same size
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            fake = netG(noise)
            label.fill_(fake_label)
            # detach(): the discriminator update must not backpropagate into G.
            output = netD(fake.detach()).view(-1)
            errD_fake = criterion(output, label)
            # These gradients accumulate with those from the real batch.
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            # D was just updated, so re-score the same (non-detached) fake batch.
            output = netD(fake).view(-1)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()

            # Record losses every iteration; print and persist every `log_every` iterations.
            G_losses.append(errG.item())
            D_losses.append(errD.item())
            if i % log_every == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (epoch, num_epochs, i, len(dataloader),
                         D_losses[-1], G_losses[-1], D_x, D_G_z1, D_G_z2))
                G_losses_file.write(str(G_losses[-1]) + '\n')
                D_losses_file.write(str(D_losses[-1]) + '\n')

            iters += 1

        if (epoch + 1) % 10 == 0:
            # Persist buffered loss lines at least as often as checkpoints are written.
            G_losses_file.flush()
            D_losses_file.flush()
            checkpoint = {'netG': netG.state_dict(),
                          'netD': netD.state_dict(),
                          # Optimizer states + epoch let a resumed run keep Adam's moment estimates.
                          'optimizerG': optimizerG.state_dict(),
                          'optimizerD': optimizerD.state_dict(),
                          'epoch': epoch + 1}
            torch.save(checkpoint,
                       '/content/drive/MyDrive/cv2/scaled_baseline_bs_128_epoch_{}.pt'.format(epoch + 1))

Starting Training Loop...
[120/1000][0/565]	Loss_D: 0.0640	Loss_G: 3.1265	D(x): 0.9637	D(G(z)): 0.0255 / 0.1051
[120/1000][10/565]	Loss_D: 0.0527	Loss_G: 5.5977	D(x): 0.9697	D(G(z)): 0.0202 / 0.0082
[120/1000][20/565]	Loss_D: 0.0742	Loss_G: 5.2739	D(x): 0.9742	D(G(z)): 0.0449 / 0.0103
[120/1000][30/565]	Loss_D: 0.0738	Loss_G: 5.3046	D(x): 0.9404	D(G(z)): 0.0104 / 0.0175
[120/1000][40/565]	Loss_D: 0.0657	Loss_G: 4.5307	D(x): 0.9685	D(G(z)): 0.0313 / 0.0216
[120/1000][50/565]	Loss_D: 0.0487	Loss_G: 5.5019	D(x): 0.9743	D(G(z)): 0.0212 / 0.0076
[120/1000][60/565]	Loss_D: 0.0850	Loss_G: 4.6021	D(x): 0.9817	D(G(z)): 0.0584 / 0.0224
[120/1000][70/565]	Loss_D: 0.3057	Loss_G: 6.1265	D(x): 0.9957	D(G(z)): 0.2170 / 0.0060
[120/1000][80/565]	Loss_D: 0.1072	Loss_G: 5.4146	D(x): 0.9377	D(G(z)): 0.0342 / 0.0140
[120/1000][90/565]	Loss_D: 0.4531	Loss_G: 2.5875	D(x): 0.7629	D(G(z)): 0.1161 / 0.1492
[120/1000][100/565]	Loss_D: 0.0952	Loss_G: 4.3422	D(x): 0.9431	D(G(z)): 0.0312 / 0.0297
[120/1000][110/56