In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [None]:
import os
from PIL import Image

class PairImageDataset(Dataset):
    """Dataset of ``(photo, monet)`` tensor pairs for CycleGAN training.

    Every ``.jpg`` file found (recursively, in sorted path order) under each directory is
    decoded into memory once, at construction time. Item ``i`` pairs the i-th photo with the
    i-th painting, and the dataset length is the smaller of the two collections.

    NOTE(review): because pairing is by index and the length is the minimum, only the first
    ``len(images_monet)`` photos are ever served; the remaining photos are loaded but unused.

    Args:
        transforms: callable applied separately to each PIL image before it is returned.
        path_monet: directory containing the Monet ``.jpg`` files.
        path_photoes: directory containing the photo ``.jpg`` files.
    """

    def __init__(self, transforms,
                 path_monet='../input/gan-getting-started/monet_jpg',
                 path_photoes='../input/gan-getting-started/photo_jpg'):
        super().__init__()
        self.path_monet = path_monet
        self.path_photoes = path_photoes
        self.transforms = transforms

        self.images_monet = self._load_images(self.path_monet)
        self.images_photoes = self._load_images(self.path_photoes)

    @staticmethod
    def _list_jpgs(directory):
        """Return the sorted paths of all ``.jpg`` files under ``directory`` (recursive)."""
        # Join with the directory os.walk is currently visiting (`root`), not the top-level
        # directory, so files inside subdirectories get their real path. Sorting makes the
        # order (and therefore the photo/painting pairing) independent of the filesystem.
        return sorted(
            os.path.join(root, file_name)
            for root, _dirs, filenames in os.walk(directory)
            for file_name in filenames
            if file_name.endswith('.jpg')
        )

    @classmethod
    def _load_images(cls, directory):
        """Decode every ``.jpg`` under ``directory`` into an in-memory PIL image."""
        images = []
        for path in cls._list_jpgs(directory):
            # copy() materialises the pixels so the file can be closed when `with` exits.
            with Image.open(path) as img:
                images.append(img.copy())
        return images

    def __len__(self):
        return min(len(self.images_monet), len(self.images_photoes))

    def __getitem__(self, index):
        """Return ``(photo, monet)`` at ``index``, each passed through ``self.transforms``."""
        return (self.transforms(self.images_photoes[index]),
                self.transforms(self.images_monet[index]))
        

In [None]:
# Training-time preprocessing: a random left/right flip as augmentation, then
# conversion of the PIL image to a C x H x W float tensor.
transformer = transforms.Compose(
    [transforms.RandomHorizontalFlip(), transforms.ToTensor()]
)

In [None]:
# Paired (photo, Monet) training data with the augmenting transform above.
dataset = PairImageDataset(transforms=transformer)

In [None]:
import matplotlib.pyplot as plt

# Sanity check: show the first (photo, Monet) pair as the training pipeline sees it.
# RandomHorizontalFlip is random, so either image may appear mirrored between runs.
image_photo, image_monet = dataset[0]
for title, image in (('photo', image_photo), ('Monet', image_monet)):
    plt.imshow(image.permute(1, 2, 0))  # C x H x W tensor -> H x W x C for matplotlib
    plt.title(title)
    plt.axis('off')
    plt.show()

In [None]:
# Batches of 10 (photo, Monet) pairs, loaded in the main process (num_workers=0).
# NOTE(review): shuffle is left at its default (off), so batches keep dataset order.
dataloader = DataLoader(
    dataset,
    num_workers=0,
    batch_size=10,
)

In [None]:
# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
class Discriminator(nn.Module):
    """Real/fake classifier returning one probability per image, shape (batch, 1).

    Five convolutions (each followed by ReLU, then BatchNorm) shrink the image to a
    1-channel map, which is flattened and scored by a two-layer MLP ending in Sigmoid.
    The ``Linear(16 * 16, 32)`` layer means the conv stack must produce a 16 x 16 map,
    i.e. 256 x 256 inputs for square images. The layers are moved to the module-level
    ``device`` when constructed.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, kernel, stride, padding) per conv stage, with the
        # resulting shape for a 3 x 256 x 256 input.
        conv_specs = [
            (3, 16, 3, 1, 1),   # 16 x 256 x 256
            (16, 8, 4, 2, 1),   # 8 x 128 x 128
            (8, 8, 4, 2, 1),    # 8 x 64 x 64
            (8, 4, 4, 2, 1),    # 4 x 32 x 32
            (4, 1, 4, 2, 1),    # 1 x 16 x 16
        ]
        layers = []
        for in_ch, out_ch, kernel, stride, pad in conv_specs:
            layers.extend([
                nn.Conv2d(in_ch, out_ch, kernel, stride, pad),
                nn.ReLU(),
                nn.BatchNorm2d(out_ch),
            ])
        layers.extend([
            nn.Flatten(),
            nn.Linear(16 * 16, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        ])
        # One nn.Sequential stored as `model`: forward() and saved checkpoints rely on
        # this attribute name and on the layer order.
        self.model = nn.Sequential(*layers).to(device)

    def forward(self, X):
        return self.model(X)

In [None]:
class GeneratorPhoto(nn.Module):
    """Encoder/decoder generator with an input skip connection.

    ``model`` downsamples by 4 with two strided convolutions, squeezes to a 2-channel
    bottleneck, then upsamples twice (nearest neighbour) back to the input resolution
    and emits 8 channels. Those are concatenated with the 3-channel input (11 channels)
    and refined by ``after_concat`` into a 3-channel image in (0, 1) (Sigmoid).
    H and W must be divisible by 4 for the concatenation to line up. Layers are moved
    to the module-level ``device`` when constructed.

    Args:
        alpha: negative slope of every LeakyReLU.
    """

    def __init__(self, alpha = 0.01):
        super().__init__()
        self.alpha = alpha

        def stage(in_ch, out_ch, kernel=3, stride=1, padding=1):
            # The repeated unit of this network: Conv -> LeakyReLU -> BatchNorm.
            return [
                nn.Conv2d(in_ch, out_ch, kernel, stride, padding),
                nn.LeakyReLU(self.alpha),
                nn.BatchNorm2d(out_ch),
            ]

        # Shapes in the comments are for a 3 x 256 x 256 input.
        self.model = nn.Sequential(
            *stage(3, 16),                                  # 16 x 256 x 256
            *stage(16, 8, 4, 2, 1),                         # 8 x 128 x 128
            *stage(8, 4, 4, 2, 1),                          # 4 x 64 x 64
            nn.Conv2d(4, 2, 3, 1, 1),                       # 2 x 64 x 64 bottleneck, no activation
            nn.Upsample(scale_factor=2, mode='nearest'),
            *stage(2, 4),                                   # 4 x 128 x 128
            nn.Upsample(scale_factor=2, mode='nearest'),
            *stage(4, 16),                                  # 16 x 256 x 256
            nn.Conv2d(16, 8, 3, 1, 1),                      # 8 x 256 x 256
        ).to(device)

        # 8 decoded channels + 3 input channels = 11.
        self.after_concat = nn.Sequential(
            *stage(11, 6),
            *stage(6, 6),
            nn.Conv2d(6, 3, 3, 1, 1),
            nn.Sigmoid(),
        ).to(device)

    def forward(self, X):
        decoded = self.model(X)
        with_input = torch.cat((decoded, X), 1)
        return self.after_concat(with_input)

In [None]:
class Generator(nn.Module):
    """Plain encoder/decoder generator: 3-channel image in, 3-channel map out.

    Two strided convolutions downsample by 4, a 2-channel conv forms the bottleneck, and
    two nearest-neighbour upsamplings restore the resolution (H and W divisible by 4).
    Every conv except the bottleneck and the last one is followed by LeakyReLU(alpha)
    then BatchNorm. There is no final activation, so the output is unbounded. Layers are
    moved to the module-level ``device`` when constructed.

    Args:
        alpha: negative slope of every LeakyReLU.
    """

    def __init__(self, alpha = 0.01):
        super().__init__()
        self.alpha = alpha

        def stage(in_ch, out_ch, kernel=3, stride=1, padding=1):
            # Conv -> LeakyReLU -> BatchNorm.
            return [
                nn.Conv2d(in_ch, out_ch, kernel, stride, padding),
                nn.LeakyReLU(self.alpha),
                nn.BatchNorm2d(out_ch),
            ]

        # Shapes in the comments are for a 3 x 256 x 256 input.
        self.model = nn.Sequential(
            *stage(3, 16),                                  # 16 x 256 x 256
            *stage(16, 8, 4, 2, 1),                         # 8 x 128 x 128
            *stage(8, 4, 4, 2, 1),                          # 4 x 64 x 64
            nn.Conv2d(4, 2, 3, 1, 1),                       # 2 x 64 x 64 bottleneck, no activation
            nn.Upsample(scale_factor=2, mode='nearest'),
            *stage(2, 4),                                   # 4 x 128 x 128
            nn.Upsample(scale_factor=2, mode='nearest'),
            *stage(4, 16),                                  # 16 x 256 x 256
            nn.Conv2d(16, 3, 3, 1, 1),                      # 3 x 256 x 256
        ).to(device)

    def forward(self, X):
        return self.model(X)

In [None]:
class Block(nn.Module):
    """Shape-preserving unit: 3x3 conv (stride 1, padding 1) -> LeakyReLU(0.01) -> BatchNorm.

    Only the channel count changes (``in_channels`` -> ``out_channels``).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # Stored as `module` so state_dict keys and pickled checkpoints keep matching.
        self.module = nn.Sequential(conv, nn.LeakyReLU(negative_slope=0.01), nn.BatchNorm2d(out_channels))

    def forward(self, X):
        return self.module(X)


In [None]:
class BlockConv2(nn.Module):
    """Downsampling unit: 4x4 conv, stride 2, padding 1 -> LeakyReLU(0.01) -> BatchNorm.

    For even H and W the spatial size is halved.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.01),
            nn.BatchNorm2d(out_channels),
        ]
        # Stored as `model` so state_dict keys and pickled checkpoints keep matching.
        self.model = nn.Sequential(*layers)

    def forward(self, X):
        return self.model(X)

In [None]:
class UNet(nn.Module):
    """Encoder/decoder generator using max-unpooling plus an input skip connection.

    Shape comments are for a (batch, 3, 256, 256) input, which is what the notebook uses.
    The encoder halves the resolution four times (two strided convs, two max-pools with
    saved indices) down to 8 x 16 x 16. The decoder undoes each max-pool with
    MaxUnpool2d and the recorded indices, and each strided conv with nearest-neighbour
    upsampling. The decoded 3-channel map is concatenated with the input and refined to
    a 3-channel image in (0, 1) (Sigmoid).

    Submodule attribute names must not change: forward() looks them up by name, and
    whole-module checkpoints loaded with torch.load carry these names.
    """

    def __init__(self):
        super().__init__()
        # --- encoder ---
        self.conv1 = nn.Sequential(BlockConv2(3, 16), Block(16, 16))   # 16 x 128 x 128
        self.pool1 = nn.MaxPool2d(4, 2, 1, return_indices=True)        # 16 x 64 x 64
        self.conv2 = nn.Sequential(BlockConv2(16, 8), Block(8, 8))     # 8 x 32 x 32
        self.pool2 = nn.MaxPool2d(4, 2, 1, return_indices=True)        # 8 x 16 x 16
        self.conv3 = Block(8, 8)                                       # 8 x 16 x 16

        # --- decoder ---
        self.unpool1 = nn.MaxUnpool2d(4, 2, 1)                         # 8 x 32 x 32 (inverts pool2)
        self.up1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),               # 8 x 64 x 64
            Block(8, 8),
            Block(8, 16),                                              # 16 x 64 x 64
        )
        self.unpool2 = nn.MaxUnpool2d(4, 2, 1)                         # 16 x 128 x 128 (inverts pool1)
        self.up2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),               # 16 x 256 x 256
            Block(16, 16),
            Block(16, 3),                                              # 3 x 256 x 256
        )

        # --- refinement of [decoded, input] (3 + 3 = 6 channels) ---
        self.last = nn.Sequential(Block(6, 9), Block(9, 6), nn.Conv2d(6, 3, 3, 1, 1))
        self.sigmoid = nn.Sigmoid()

    def forward(self, X):
        encoded, pool1_indices = self.pool1(self.conv1(X))
        encoded, pool2_indices = self.pool2(self.conv2(encoded))
        decoded = self.conv3(encoded)
        # Indices are consumed in reverse order: unpool1 undoes pool2, unpool2 undoes pool1.
        decoded = self.up1(self.unpool1(decoded, pool2_indices))
        decoded = self.up2(self.unpool2(decoded, pool1_indices))
        refined = self.last(torch.cat((decoded, X), 1))
        return self.sigmoid(refined)

In [None]:
# Number of Monet paintings, then number of photos, one per line.
print(len(dataset.images_monet), len(dataset.images_photoes), sep='\n')

In [None]:
class CycleGAN:
    """CycleGAN-style trainer for photo <-> Monet translation.

    Holds two generators (photo -> Monet, Monet -> photo) and one discriminator per
    domain, each with its own Adam optimizer. Generators minimise an adversarial BCE term
    plus an MSE cycle-reconstruction term weighted by ``alpha``. Discriminators minimise
    BCE on real images (target 1) versus generated images (target 0).

    The ``descriminator`` spelling is kept because it is part of the public attribute,
    parameter and file names.
    """

    def __init__(self,
                 generator_monet2photo,
                 generator_photo2monet,
                 descriminator_photo,
                 descriminator_monet,
                 generator_lr=1e-3,
                 descriminator_lr=3e-5,
                 pretrain_lr=1e-3):
        """
        Args:
            generator_monet2photo: module mapping Monet images to photos.
            generator_photo2monet: module mapping photos to Monet images.
            descriminator_photo: module scoring photos. nn.BCELoss requires its output
                to be probabilities in [0, 1].
            descriminator_monet: module scoring Monet images (same requirement).
            generator_lr: Adam learning rate of both generators during fit_one_batch.
            descriminator_lr: Adam learning rate of both discriminators.
            pretrain_lr: Adam learning rate used by pretrain_generators.
        """
        self.generator_monet2photo = generator_monet2photo
        self.generator_photo2monet = generator_photo2monet
        self.descriminator_monet = descriminator_monet
        self.descriminator_photo = descriminator_photo

        self.generator_lr = generator_lr
        self.descriminator_lr = descriminator_lr
        self.pretrain_lr = pretrain_lr

        self.descriminator_loss = nn.BCELoss()
        self.image_loss = nn.MSELoss()
        self.pretrain_loss = nn.MSELoss()

        self._build_optimizers()

    def _build_optimizers(self):
        """(Re)create every optimizer over the modules this object currently holds."""
        adam = torch.optim.Adam
        self.optimizer_generator_monet2photo = adam(self.generator_monet2photo.parameters(), lr=self.generator_lr)
        self.optimizer_generator_photo2monet = adam(self.generator_photo2monet.parameters(), lr=self.generator_lr)

        self.optimizer_descriminator_monet = adam(self.descriminator_monet.parameters(), lr=self.descriminator_lr)
        self.optimizer_descriminator_photo = adam(self.descriminator_photo.parameters(), lr=self.descriminator_lr)

        self.pretrain_optimizer_p2m = adam(self.generator_photo2monet.parameters(), lr=self.pretrain_lr)
        self.pretrain_optimizer_m2p = adam(self.generator_monet2photo.parameters(), lr=self.pretrain_lr)

    def fit_one_batch(self, photoes, monets, alpha):
        """Run one generator update followed by one discriminator update.

        Args:
            photoes: batch of photos on the same device as the networks.
            monets: batch of Monet paintings on the same device as the networks.
            alpha: weight of the cycle-consistency (MSE) term relative to the adversarial term.
        """
        self._generator_step(photoes, monets, alpha)
        self._discriminator_step(photoes, monets)

    def _generator_step(self, photoes, monets, alpha):
        """Update both generators to fool the discriminators while staying cycle-consistent."""
        # photo -> Monet -> photo. The translated batch is computed once and reused for
        # both the cycle term and the adversarial term.
        fake_monet = self.generator_photo2monet(photoes)
        cycled_photo = self.generator_monet2photo(fake_monet)
        score_fake_monet = self.descriminator_monet(fake_monet)
        loss_photo2monet = (alpha * self.image_loss(cycled_photo, photoes)
                            + self.descriminator_loss(score_fake_monet, torch.ones_like(score_fake_monet)))

        # Monet -> photo -> Monet.
        fake_photo = self.generator_monet2photo(monets)
        cycled_monet = self.generator_photo2monet(fake_photo)
        score_fake_photo = self.descriminator_photo(fake_photo)
        loss_monet2photo = (alpha * self.image_loss(cycled_monet, monets)
                            + self.descriminator_loss(score_fake_photo, torch.ones_like(score_fake_photo)))

        self.optimizer_generator_monet2photo.zero_grad()
        self.optimizer_generator_photo2monet.zero_grad()
        # A single backward over the summed losses gives the same gradients as two separate
        # backward calls. The discriminators also receive gradients here; _discriminator_step
        # clears them before using its own.
        (loss_photo2monet + loss_monet2photo).backward()
        self.optimizer_generator_monet2photo.step()
        self.optimizer_generator_photo2monet.step()

    def _discriminator_step(self, photoes, monets):
        """Update both discriminators: real images -> 1, freshly generated images -> 0."""
        # Only the discriminators are updated here, so the generator outputs need no graph.
        with torch.no_grad():
            fake_monets = self.generator_photo2monet(photoes)
            fake_photos = self.generator_monet2photo(monets)

        real_monets_score = self.descriminator_monet(monets)
        fake_monets_score = self.descriminator_monet(fake_monets)
        loss_descriminator_monet = (
            self.descriminator_loss(real_monets_score, torch.ones_like(real_monets_score))
            + self.descriminator_loss(fake_monets_score, torch.zeros_like(fake_monets_score)))

        real_photos_score = self.descriminator_photo(photoes)
        fake_photos_score = self.descriminator_photo(fake_photos)
        loss_descriminator_photos = (
            self.descriminator_loss(real_photos_score, torch.ones_like(real_photos_score))
            + self.descriminator_loss(fake_photos_score, torch.zeros_like(fake_photos_score)))

        self.optimizer_descriminator_monet.zero_grad()
        self.optimizer_descriminator_photo.zero_grad()
        (loss_descriminator_monet + loss_descriminator_photos).backward()
        self.optimizer_descriminator_monet.step()
        self.optimizer_descriminator_photo.step()

    def predict(self, photo):
        """Translate a batch of photos into Monet style with the photo -> Monet generator.

        NOTE(review): this does not switch the generator to eval mode. BatchNorm layers
        therefore normalise with the current batch's statistics and keep updating their
        running averages. Confirm this is intended before adding .eval(), which changes
        the outputs.
        """
        return self.generator_photo2monet(photo)

    def save(self, epoch):
        """Write all four networks as whole pickled modules to ``<name>_<epoch>.bin`` in the cwd."""
        torch.save(self.generator_photo2monet, f'generator_photo2monet_{epoch}.bin')
        torch.save(self.generator_monet2photo, f'generator_monet2photo_{epoch}.bin')
        torch.save(self.descriminator_monet, f'descriminator_monet_{epoch}.bin')
        torch.save(self.descriminator_photo, f'descriminator_photo_{epoch}.bin')

    def load(self,
             path_generator_photo2monet,
             path_generator_monet2photo,
             path_descriminator_monet,
             path_descriminator_photo,
             map_location=None):
        """Replace all four networks with whole modules previously written by save().

        The optimizers are rebuilt afterwards because they hold references to parameter
        tensors. Without the rebuild, they would keep stepping the discarded modules, and
        training after a load would leave the loaded networks unchanged. save() does not
        store optimizer state, so the Adam moments start fresh.

        Args:
            path_*: checkpoint file for the corresponding network.
            map_location: forwarded to torch.load, e.g. 'cpu' to load GPU checkpoints on
                a CPU-only machine. None keeps torch.load's default behaviour.

        Security: checkpoints are full pickles, so only load files you trust.
        """
        self.generator_photo2monet = self._load_module(path_generator_photo2monet, map_location)
        self.generator_monet2photo = self._load_module(path_generator_monet2photo, map_location)
        self.descriminator_monet = self._load_module(path_descriminator_monet, map_location)
        self.descriminator_photo = self._load_module(path_descriminator_photo, map_location)
        self._build_optimizers()

    @staticmethod
    def _load_module(path, map_location):
        """torch.load a whole pickled nn.Module from ``path``."""
        import inspect

        kwargs = {'map_location': map_location}
        # Whole-module checkpoints need full unpickling. Newer PyTorch releases default to
        # weights_only=True, which rejects them; older releases do not have the argument.
        if 'weights_only' in inspect.signature(torch.load).parameters:
            kwargs['weights_only'] = False
        return torch.load(path, **kwargs)

    def pretrain_generators(self, photoes, monets):
        """Identity pretraining step: teach each generator to reproduce its own input.

        The photo -> Monet generator is fitted (MSE) to return ``photoes`` unchanged, and
        the Monet -> photo generator to return ``monets`` unchanged. Each uses its own
        pretraining optimizer.
        """
        self.pretrain_optimizer_p2m.zero_grad()
        fake_photoes = self.generator_photo2monet(photoes)
        loss_photoes = self.pretrain_loss(fake_photoes, photoes)
        loss_photoes.backward()
        self.pretrain_optimizer_p2m.step()

        self.pretrain_optimizer_m2p.zero_grad()
        fake_monets = self.generator_monet2photo(monets)
        loss_monets = self.pretrain_loss(fake_monets, monets)
        loss_monets.backward()
        self.pretrain_optimizer_m2p.step()






In [None]:
# Freshly initialised networks act as placeholders: load() below replaces all four
# with the checkpoints written by save() at epoch 25.
model = CycleGAN(
    generator_monet2photo=UNet().to(device),
    generator_photo2monet=UNet().to(device),
    descriminator_photo=Discriminator().to(device),
    descriminator_monet=Discriminator().to(device),
)

model.load(
    path_generator_photo2monet='../input/unet25/generator_photo2monet_25.bin',
    path_generator_monet2photo='../input/unet25/generator_monet2photo_25.bin',
    path_descriminator_monet='../input/unet25/descriminator_monet_25.bin',
    path_descriminator_photo='../input/unet25/descriminator_photo_25.bin',
)

In [None]:
import os
from PIL import Image

class ImageDataset(Dataset):
    """Dataset that serves the photos only, used to generate the Monet-style outputs.

    Every ``.jpg`` file found (recursively, in sorted path order) under each directory
    is decoded into memory once, at construction time. Item ``i`` is the i-th photo
    passed through ``transforms``.

    Args:
        transforms: callable applied to each PIL photo before it is returned.
        path_monet: directory of Monet ``.jpg`` files. They are loaded into
            ``images_monet``, which this class never reads (kept for compatibility).
        path_photoes: directory of photo ``.jpg`` files.
    """

    def __init__(self, transforms,
                 path_monet='../input/gan-getting-started/monet_jpg',
                 path_photoes='../input/gan-getting-started/photo_jpg'):
        super().__init__()
        self.path_monet = path_monet
        self.path_photoes = path_photoes
        self.transforms = transforms

        self.images_monet = self._load_images(self.path_monet)
        self.images_photoes = self._load_images(self.path_photoes)

    @staticmethod
    def _list_jpgs(directory):
        """Return the sorted paths of all ``.jpg`` files under ``directory`` (recursive)."""
        # Join with the directory os.walk is currently visiting (`root`), not the top-level
        # directory, so files inside subdirectories get their real path. Sorting makes the
        # output order independent of the filesystem.
        return sorted(
            os.path.join(root, file_name)
            for root, _dirs, filenames in os.walk(directory)
            for file_name in filenames
            if file_name.endswith('.jpg')
        )

    @classmethod
    def _load_images(cls, directory):
        """Decode every ``.jpg`` under ``directory`` into an in-memory PIL image."""
        images = []
        for path in cls._list_jpgs(directory):
            # copy() materialises the pixels so the file can be closed when `with` exits.
            with Image.open(path) as img:
                images.append(img.copy())
        return images

    def __len__(self):
        return len(self.images_photoes)

    def __getitem__(self, index):
        """Return the photo at ``index`` passed through ``self.transforms``."""
        return self.transforms(self.images_photoes[index])
        

In [None]:
# Inference-time preprocessing: tensor conversion only, without random augmentation.
transformer_ans = transforms.Compose([transforms.ToTensor()])

In [None]:
# Photos to be translated into Monet style.
dataset_for_ans = ImageDataset(transforms=transformer_ans)

In [None]:
# Output directory for the generated images. exist_ok=True makes re-running this cell
# safe; the shell `mkdir` it replaces errors when the directory already exists.
os.makedirs('../images', exist_ok=True)

In [None]:
# One photo per batch. Pinned host memory speeds up copies to the GPU.
ph_dl = DataLoader(dataset_for_ans, pin_memory=True, batch_size=1)

In [None]:
from tqdm.auto import tqdm

trans = transforms.ToPILImage()

# Stylise every photo and write it to ../images/<n>.jpg, numbered from 1 in loader order.
# Each batch is iterated item by item, so this works for any batch size or image size.
# NOTE(review): nothing here calls .eval() on the generator, so BatchNorm uses each
# batch's own statistics (and keeps updating its running averages). Confirm this is
# intended before changing it, because it changes the produced images.
image_index = 0
for photo_batch in tqdm(ph_dl, leave=False, total=len(ph_dl)):
    with torch.no_grad():
        pred_batch = model.predict(photo_batch.to(device)).cpu()
    for pred_monet in pred_batch:
        if image_index == 0:
            # Preview the first stylised image as a sanity check.
            plt.imshow(pred_monet.permute(1, 2, 0))
            plt.show()
        image_index += 1
        img = trans(pred_monet).convert("RGB")
        img.save(f"../images/{image_index}.jpg")

In [None]:
import shutil

# Zip the directory the prediction cell writes to ("../images", relative to the working
# directory). Resolving it here, instead of hard-coding "/kaggle/images", keeps the
# two paths in sync even when the notebook is not run from /kaggle/working.
shutil.make_archive("/kaggle/working/images", 'zip', os.path.abspath("../images"))
