In [None]:
import torchvision
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_tensor
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
from google.colab import files
import torchvision.models as models
from torchvision.transforms import Resize
import torch

In [None]:
# Upload the training input through the Colab file picker: either the dataset
# archive (data-driven approach) or the single source image (data-free approach).
uploaded = files.upload()

Saving original_water.jpg to original_water.jpg


In [None]:
# Print the working directory; the relative paths used in this notebook resolve against it.
!pwd

/content


In [None]:
# List the working directory to confirm the uploaded file is present.
!ls

my_data  original_water.jpg  sample_data


In [None]:
# Set d_driven = True to use the data-driven approach (train on a folder of images);
# leave it False to use the data-free approach (train on random crops of one image).
d_driven = False

In [None]:
#unzip the dataset
if(d_driven):
  !unzip dotted.zip -d /content

In [None]:
# Create the folder that receives the generated images.
# -p makes the command a no-op when the folder already exists, so the cell can be re-run.
!mkdir -p /content/my_data

In [None]:
# Configurable parameters
lr = 0.0002  # learning rate given to both Adam optimizers
batch_size = 8  # number of images per DataLoader batch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use the GPU when one is available
epochs = 100  # number of passes over the dataloader performed by train()

In [None]:
# Data-driven approach: base name used for the saved generator checkpoint
# and for the generated image files.
if(d_driven):
    textureName = "data_driven_dotted"

In [None]:
# Data-free approach: read the source texture as an RGB numpy array, pick the
# base name for the output files and set how many random crops form one epoch.
if(not d_driven):
    imagePath = "original_water.jpg"
    image = np.array(Image.open(imagePath).convert('RGB'))
    textureName, size = "water-texture", 50

In [None]:
#The dataset class. Serves random crops either of every .jpg in a folder (data-driven) or of one single image (data-free)
class Textures(Dataset):
    """Dataset of normalized, resized and randomly cropped texture images.

    Every item goes through the same pipeline: per-channel normalization with
    mean 0.5 / std 0.5, resize to 1024x1024, then a random 512x512 crop.

    Args:
        principal_image: source image for the data-free mode; it is converted
            with ``to_tensor`` once, at construction. Unused in data-driven mode.
        size: value reported by ``len()`` in data-free mode, i.e. how many
            random crops of ``principal_image`` make up one epoch.
        directory: folder whose ``.jpg`` files are served in data-driven mode.
        data_driven: selects the mode explicitly. When left as ``None`` the
            module-level ``d_driven`` flag is read once, here.
    """

    def __init__(self, principal_image=None, size=0, directory="", data_driven=None):
        super(Textures, self).__init__()
        # Freeze the mode on the instance: __len__ and __getitem__ must agree
        # with what was prepared here even if the global flag is changed later.
        self.data_driven = d_driven if data_driven is None else bool(data_driven)
        self.principal_image = principal_image
        self.data = size

        self.trans = transforms.Compose([
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            transforms.Resize((1024, 1024)),
            transforms.RandomCrop((512, 512))
        ])

        if self.data_driven:
            # os.listdir returns entries in arbitrary order; sort so that a given
            # index always refers to the same file.
            self.image_paths = sorted(
                os.path.join(directory, filename)
                for filename in os.listdir(directory)
                if filename.endswith('.jpg')
            )
        else:
            # Convert once up front instead of on every __getitem__ call.
            self.principal_image = to_tensor(principal_image)

    def __len__(self):
        """Return the number of files (data-driven) or the configured ``size`` (data-free)."""
        if self.data_driven:
            return len(self.image_paths)
        return self.data

    def __getitem__(self, index):
        """Return one transformed crop; ``index`` selects the file in data-driven mode only."""
        if self.data_driven:
            # The context manager releases the file handle as soon as the pixels are copied out.
            with Image.open(self.image_paths[index]) as handle:
                pixels = np.array(handle.convert('RGB'))
            return self.trans(to_tensor(pixels))

        # Data-free: every index yields a fresh random crop of the same image.
        return self.trans(self.principal_image)

In [None]:
#Defines the Discriminator class
class Discriminator(nn.Module):
    """Convolutional discriminator that maps a 3-channel image to sigmoid scores.

    The network is a stack of 4x4 convolutions: an un-normalized stem, four
    conv -> batch-norm -> LeakyReLU stages (three with stride 2, one with
    stride 1 and no padding) and a one-channel head followed by a sigmoid.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        slope = 0.2

        # Stem: strided convolution without normalization.
        layers = [
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(slope, inplace=True),
        ]

        # (in_channels, out_channels, stride, padding) of each normalized stage.
        stages = (
            (64, 128, 2, 1),
            (128, 256, 2, 1),
            (256, 512, 2, 1),
            (512, 1024, 1, 0),
        )
        for c_in, c_out, stride, pad in stages:
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=pad))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(slope, inplace=True))

        # Head: collapse to a single channel and squash into (0, 1).
        layers.append(nn.Conv2d(1024, 1, kernel_size=4, stride=1, padding=0))
        layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the convolutional stack and return the sigmoid scores."""
        scores = self.model(x)
        return scores


In [None]:
#Residual block used by the Generator
class _ResidualBlock(nn.Sequential):
    """``nn.Sequential`` whose output is added back to its input (skip connection).

    Subclassing ``nn.Sequential`` keeps the child indices, and therefore the
    ``state_dict`` keys, identical to those of a plain ``nn.Sequential``.
    """

    def forward(self, x):
        return x + super(_ResidualBlock, self).forward(x)


#Defines the Generator class
class Generator(nn.Module):
    """Fully convolutional generator mapping a 3-channel tensor to a 3-channel image.

    Layout: 7x7 stem, two stride-2 downsampling convolutions, nine residual
    blocks at 256 channels, three 2x upsampling stages and a 7x7 output layer
    with ``tanh``. There is one more upsampling stage than downsampling stages,
    so the output is twice as large as the input along height and width.

    NOTE: the residual blocks add their input to their output. Checkpoints
    saved by a version without that skip connection still load (the keys are
    the same) but will produce different images.
    """

    def __init__(self):
        super(Generator, self).__init__()

        # Initial layer
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: two stride-2 convolutions (spatial size / 4)
        model += [
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True),
        ]

        # Residual blocks
        model += [self._make_residual_block(256) for _ in range(9)]

        # Extra upsampling stage at constant width; this is what makes the
        # output twice the size of the input.
        model += [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 256, 3, stride=1, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True),
        ]

        # Upsampling: undo the two downsampling steps (spatial size * 4)
        model += [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 128, 3, stride=1, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]

        # Final layer: back to 3 channels, squashed to [-1, 1] by tanh
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, 3, 7),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*model)


    # Builder for the residual blocks
    def _make_residual_block(self, input):
        """Build one residual block that keeps ``input`` channels and the spatial size.

        The returned module computes ``x + F(x)`` where ``F`` is
        pad -> conv3x3 -> instance-norm -> ReLU -> pad -> conv3x3 -> instance-norm.
        """
        block = _ResidualBlock(
            nn.ReflectionPad2d(1),
            nn.Conv2d(input, input, 3),
            nn.InstanceNorm2d(input),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(input, input, 3),
            nn.InstanceNorm2d(input),
        )
        return block

    def forward(self, x):
        """Generate an image from ``x``; the result has twice the height and width of ``x``."""
        return self.model(x)



In [None]:
#Defines the Perceptual Loss class
class PerceptualLoss(nn.Module):
    """Mean-squared error between VGG16 feature maps of two image batches.

    Features come from the first 16 layers of the pretrained VGG16 feature
    extractor, kept frozen and in eval mode.

    Both inputs are expected in [-1, 1] (the range produced by the generator's
    tanh and by the dataset's mean-0.5/std-0.5 normalization). They are mapped
    back to [0, 1] and then normalized with the ImageNet statistics the
    pretrained weights were trained with.
    """

    def __init__(self):
        super(PerceptualLoss, self).__init__()
        self.vgg = models.vgg16(pretrained=True).features[:16].eval()
        for param in self.vgg.parameters():
            param.requires_grad = False

        # Registered as buffers so that .to(device) moves them along with the VGG weights.
        self.register_buffer("imagenet_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("imagenet_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def _to_vgg_space(self, images):
        """Map a batch from [-1, 1] to ImageNet-normalized values."""
        unit_range = (images + 1.0) / 2.0
        return (unit_range - self.imagenet_mean) / self.imagenet_std

    def forward(self, generated, target):
        """Return the MSE between the VGG features of ``generated`` and ``target``."""
        generated_features = self.vgg(self._to_vgg_space(generated))
        target_features = self.vgg(self._to_vgg_space(target))
        loss = nn.functional.mse_loss(generated_features, target_features)
        return loss

In [None]:
# Initialize the losses: binary cross-entropy for the adversarial term and the
# VGG-based perceptual loss, which is moved to the training device.
criterion_GAN = nn.BCELoss()
criterion_perceptual = PerceptualLoss().to(device)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:05<00:00, 110MB/s] 


In [None]:
#The main function used for the training
def train(discriminator, generator, dataloader, optim_D, optim_G):
    """Alternate discriminator and generator updates, then save and download the generator.

    Reads the module-level ``epochs``, ``device``, ``criterion_GAN``,
    ``criterion_perceptual`` and ``textureName``.

    Args:
        discriminator: network scoring images as real or fake.
        generator: network turning noise into images.
        dataloader: iterable yielding batches of real images.
        optim_D: optimizer for the discriminator parameters.
        optim_G: optimizer for the generator parameters.
    """
    for epoch in range(epochs):
        progress = tqdm(dataloader, leave=True)

        for real_images in progress:
            # The noise has half the height and width of the real batch.
            dims = real_images.shape
            noise_shape = (dims[0], dims[1], dims[2] // 2, dims[3] // 2)

            real_images = real_images.to(device)
            noise = torch.normal(0, 1, size=noise_shape).to(device)
            fake_images = generator(noise)

            # Discriminator update; the fakes are detached so that no
            # gradient reaches the generator during this step.
            optim_D.zero_grad()
            pred_real = discriminator(real_images).squeeze()
            pred_fake = discriminator(fake_images.detach()).squeeze()
            real_term = criterion_GAN(pred_real, torch.ones_like(pred_real))
            fake_term = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
            loss_D = (real_term + fake_term) / 2
            loss_D.backward()
            optim_D.step()

            # Generator update: fool the discriminator and match the real
            # batch in perceptual feature space.
            optim_G.zero_grad()
            pred_fake = discriminator(fake_images).squeeze()
            adversarial_term = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
            perceptual_term = criterion_perceptual(fake_images, real_images)
            loss_G = adversarial_term + perceptual_term
            loss_G.backward()
            optim_G.step()

            progress.set_postfix(Disc_Loss=loss_D.item(), Gen_Loss=loss_G.item(), Epoch=epoch)

    # Save the Generator weights and offer them for download
    checkpoint_path = f'generator_{textureName}.pth'
    torch.save(generator.state_dict(), checkpoint_path)
    files.download(checkpoint_path)

In [None]:
#initialize the Discriminator, the Generator, Adam optimizers and the DataLoader
# The networks go to the configured `device` (not a hard-coded "cuda") so this
# cell also runs on a CPU-only runtime and matches where train() puts the data.
discriminator = Discriminator().to(device)
generator = Generator().to(device)

optim_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
optim_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))

# Data-driven: every .jpg in the "dotted" folder; data-free: `size` random crops of `image`.
if(d_driven):
    dataset = Textures(directory = "dotted")
else:
    dataset = Textures(principal_image=image, size = size)

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Run the training loop; when it finishes, the generator weights are saved to
# generator_<textureName>.pth and offered for download.
train(discriminator, generator, dataloader, optim_D, optim_G)

100%|██████████| 7/7 [00:21<00:00,  3.01s/it, Disc_Loss=0.575, Epoch=0, Gen_Loss=3.79]
100%|██████████| 7/7 [00:19<00:00,  2.72s/it, Disc_Loss=0.609, Epoch=1, Gen_Loss=4.45]
100%|██████████| 7/7 [00:19<00:00,  2.76s/it, Disc_Loss=0.0243, Epoch=2, Gen_Loss=7.51]
100%|██████████| 7/7 [00:19<00:00,  2.81s/it, Disc_Loss=0.0173, Epoch=3, Gen_Loss=7.82]
100%|██████████| 7/7 [00:19<00:00,  2.82s/it, Disc_Loss=0.0127, Epoch=4, Gen_Loss=8.22]
100%|██████████| 7/7 [00:19<00:00,  2.80s/it, Disc_Loss=1.03, Epoch=5, Gen_Loss=8.39]
100%|██████████| 7/7 [00:19<00:00,  2.82s/it, Disc_Loss=0.515, Epoch=6, Gen_Loss=3.37]
100%|██████████| 7/7 [00:19<00:00,  2.80s/it, Disc_Loss=0.22, Epoch=7, Gen_Loss=4.75]
100%|██████████| 7/7 [00:19<00:00,  2.83s/it, Disc_Loss=0.387, Epoch=8, Gen_Loss=6.12]
100%|██████████| 7/7 [00:19<00:00,  2.80s/it, Disc_Loss=0.0836, Epoch=9, Gen_Loss=5.65]
100%|██████████| 7/7 [00:19<00:00,  2.81s/it, Disc_Loss=0.0578, Epoch=10, Gen_Loss=6.85]
100%|██████████| 7/7 [00:19<00:00,  2.8

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# Data-free approach: re-open the source image as RGB and turn it into a tensor
if(not d_driven):
    image = to_tensor(np.array(Image.open(imagePath).convert('RGB')))

In [None]:
# Transformation applied to the source image tensor at generation time:
# take a random 512x512 crop, then normalize every channel with (x - 0.5) / 0.5.
trans = transforms.Compose([
    transforms.RandomCrop((512, 512)),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

In [None]:
# Load the trained Generator weights and switch the network to eval mode
gen = Generator().to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint written
# on a GPU runtime also loads on a CPU-only runtime.
gen.load_state_dict(torch.load(f'generator_{textureName}.pth', map_location=device))
gen.eval()

Generator(
  (model): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1))
    (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (3): ReLU(inplace=True)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (9): ReLU(inplace=True)
    (10): Sequential(
      (0): ReflectionPad2d((1, 1, 1, 1))
      (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      (2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (3): ReLU(inplace=True)
      (4): ReflectionPad2d((1, 1, 1, 1))
      (5): Conv2d(256, 256, kernel_size=(3, 3), stride

In [None]:
# Data-free approach: build the 512x512 resize used on the generator output and
# draw 10 random crops of the original image, each as a batch of one on `device`.
if(not d_driven):
    resize_transform = Resize((512, 512))
    images = [trans(image).unsqueeze(0).to(device) for _crop in range(10)]

In [None]:
#function used to save and download the result both in the data-free and data-driven approach
def save_and_download(output, i):
    """Write ``output`` to ``my_data/<textureName>_<i>.png`` and trigger its download.

    Args:
        output: image tensor as produced by the generator.
        i: index used in the file name.
    """
    target_path = f"my_data/{textureName}_{i}.png"
    # Undo the mean-0.5/std-0.5 normalization before writing the image.
    rescaled = output * 0.5 + 0.5
    save_image(rescaled, target_path)
    files.download(target_path)

In [None]:
#generate images using the newly trained model

# Inference only: no_grad avoids building autograd graphs for the forward passes.
with torch.no_grad():
    for i in range(7):
        if(d_driven):
            # Batch of 8 noise maps; the generator doubles the spatial size.
            noise = torch.normal(0, 1, size=(8, 3, 256, 256)).to(device)
            output = gen(noise)
        else:
            # `images` and `resize_transform` exist only in the data-free approach,
            # so they are referenced only in this branch. The noise matches the
            # shape of one reference crop and the doubled output is resized back.
            noise = torch.normal(0, 1, size=images[0].shape).to(device)
            output = gen(noise)
            output = resize_transform(output)

        save_and_download(output, i)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>