Загрузка данных:

In [None]:
#from google.colab import files
#files.upload()

#!mkdir -p ~/.kaggle
#!cp kaggle.json ~/.kaggle/
#!pip install kaggle
#!chmod 600 /root/.kaggle/kaggle.json
#!kaggle competitions download -c gan-getting-started
#!unzip /content/gan-getting-started.zip

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [None]:
import os
from PIL import Image

import random

class PairImageDataset(Dataset):
    """Unpaired (photo, Monet) training dataset for CycleGAN.

    Every ``.jpg`` under the two directories is decoded into memory once, at
    construction time. Item ``index`` pairs the Monet painting at position
    ``index`` with a photo drawn uniformly at random, so the pairing differs
    on every access. Both images are passed through ``transforms``.

    Args:
        transforms: callable applied to each PIL image (e.g. a torchvision
            ``Compose``).
        path_monet: directory searched recursively for Monet ``.jpg`` files.
        path_photoes: directory searched recursively for photo ``.jpg`` files.
    """

    def __init__(self, transforms,
                 path_monet='../input/gan-getting-started/monet_jpg',
                 path_photoes='../input/gan-getting-started/photo_jpg'):
        super().__init__()
        self.path_monet = path_monet
        self.path_photoes = path_photoes
        self.transforms = transforms

        self.images_monet = self._load_images(self._find_jpgs(self.path_monet))
        self.images_photoes = self._load_images(self._find_jpgs(self.path_photoes))

    @staticmethod
    def _find_jpgs(directory):
        """Return the sorted paths of all ``.jpg`` files under *directory*, recursively."""
        # Join with ``root`` (the directory os.walk is currently in), not the
        # top-level directory, so files in subdirectories get their real path.
        # Sorting makes the index -> image mapping reproducible across runs.
        return sorted(
            os.path.join(root, file_name)
            for root, _dirs, filenames in os.walk(directory)
            for file_name in filenames
            if file_name.endswith('.jpg')
        )

    @staticmethod
    def _load_images(paths):
        """Decode every file in *paths* into an in-memory RGB PIL image."""
        images = []
        for path in paths:
            with Image.open(path) as img:
                # convert() forces a full decode and returns a new image, so the
                # file can be closed; it also guarantees 3 channels for
                # grayscale/CMYK JPEGs.
                images.append(img.convert('RGB'))
        return images

    def __len__(self):
        # One epoch walks the smaller of the two collections once.
        return min(len(self.images_monet), len(self.images_photoes))

    def __getitem__(self, index):
        """Return ``(transformed random photo, transformed Monet[index])``."""
        index_photo = random.randint(0, len(self.images_photoes) - 1)
        return (self.transforms(self.images_photoes[index_photo]),
                self.transforms(self.images_monet[index]))
        

In [None]:
# Training-time preprocessing: random horizontal flip (augmentation), then
# PIL image -> float tensor in [0, 1] -> per-channel (x - 0.5) / 0.5, i.e. [-1, 1].
transformer = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [None]:
# Builds the unpaired training set; this decodes every training JPEG up front.
dataset = PairImageDataset(transforms=transformer)

In [None]:
import matplotlib.pyplot as plt

# dataset items are normalized to [-1, 1]; map them back to [0, 1] before
# imshow, which would otherwise clip every negative value (showing wrong colors).
image_photo, image_monet = dataset[0]
plt.imshow((image_photo * 0.5 + 0.5).permute(1, 2, 0))
plt.show()
plt.imshow((image_monet * 0.5 + 0.5).permute(1, 2, 0))
plt.show()

In [None]:
# shuffle=True: PairImageDataset indexes Monet images by position, so without
# shuffling every epoch would visit them in the same fixed order.
dataloader = DataLoader(dataset, batch_size=5, shuffle=True, num_workers=0)

In [None]:
# Use the first CUDA GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Report how many Monet paintings and photos were loaded, one count per line.
print(len(dataset.images_monet), len(dataset.images_photoes), sep="\n")

In [None]:
import itertools

class CycleGAN:
    """CycleGAN-style trainer for unpaired photo <-> Monet translation.

    Owns two generators (photo->Monet and Monet->photo), one discriminator per
    domain, and an Adam optimizer for each network. Adversarial terms use MSE
    against all-ones / all-zeros targets (least-squares GAN); image terms use L1.

    Args:
        generator_monet2photo: network mapping Monet paintings to photos.
        generator_photo2monet: network mapping photos to Monet paintings.
        descriminator_photo: discriminator scoring images as real photos.
        descriminator_monet: discriminator scoring images as real Monets.
    """

    def __init__(self,
                 generator_monet2photo,
                 generator_photo2monet,
                 descriminator_photo,
                 descriminator_monet):
        self.generator_monet2photo = generator_monet2photo
        self.generator_photo2monet = generator_photo2monet
        self.descriminator_monet = descriminator_monet
        self.descriminator_photo = descriminator_photo

        self.descriminator_loss = nn.MSELoss()
        self.image_loss = nn.L1Loss()

        # Naming is historical: optimizer_generator_monet trains monet2photo,
        # optimizer_generator_photo trains photo2monet.
        self.optimizer_generator_monet = self._make_optimizer(self.generator_monet2photo)
        self.optimizer_generator_photo = self._make_optimizer(self.generator_photo2monet)
        self.optimizer_descriminator_monet = self._make_optimizer(self.descriminator_monet)
        self.optimizer_descriminator_photo = self._make_optimizer(self.descriminator_photo)

        # Used by pretrain_generators(), which would otherwise fail with
        # AttributeError because nothing else creates these attributes.
        self.pretrain_loss = nn.L1Loss()
        self.pretrain_optimizer_p2m = self._make_optimizer(self.generator_photo2monet)
        self.pretrain_optimizer_m2p = self._make_optimizer(self.generator_monet2photo)

    @staticmethod
    def _make_optimizer(module):
        """Return the Adam configuration shared by every sub-network."""
        return torch.optim.Adam(module.parameters(), lr=1e-3, betas=(0.5, 0.996))

    def _adversarial_loss(self, score, target_is_real):
        """MSE between *score* and an all-ones (real) or all-zeros (fake) target.

        The target takes the score's own shape, dtype and device, so this works
        for any discriminator output size instead of a hard-coded 16x16 map.
        """
        target = torch.ones_like(score) if target_is_real else torch.zeros_like(score)
        return self.descriminator_loss(score, target)

    def _discriminator_step(self, discriminator, optimizer, real, fake):
        """One discriminator update: push scores on *real* to 1 and on *fake* to 0."""
        optimizer.zero_grad()
        real_score = discriminator(real)
        # detach(): this step must not backpropagate into the generator.
        fake_score = discriminator(fake.detach())
        loss = (self._adversarial_loss(real_score, True)
                + self._adversarial_loss(fake_score, False)) / 2
        loss.backward()
        optimizer.step()

    def fit_one_batch(self, photoes, monets, alpha):
        """Run one optimization step on both generators, then both discriminators.

        Args:
            photoes: batch of real photos (expected on the networks' device).
            monets: batch of real Monet paintings (expected on the same device).
            alpha: accepted for backward compatibility but not used; the
                cycle-consistency weight is fixed at 9.
        """
        # ---- generators ----
        self.optimizer_generator_photo.zero_grad()
        self.optimizer_generator_monet.zero_grad()

        fake_monet = self.generator_photo2monet(photoes)
        fake_photo = self.generator_monet2photo(monets)

        reconstructed_photo = self.generator_monet2photo(fake_monet)
        reconstructed_monet = self.generator_photo2monet(fake_photo)

        # NOTE(review): this term compares each translation with its *source*
        # image (G_photo2monet(photo) vs photo). The identity loss from the
        # CycleGAN paper instead compares G_photo2monet(monet) with monet.
        # Kept unchanged here; confirm which is intended.
        loss_gen = 2 * (self.image_loss(fake_monet, photoes) + self.image_loss(fake_photo, monets)) / 2

        # Adversarial term: generators try to make both discriminators output "real".
        loss_gen = loss_gen + (
            self._adversarial_loss(self.descriminator_monet(fake_monet), True)
            + self._adversarial_loss(self.descriminator_photo(fake_photo), True)) / 2

        # Cycle-consistency term: photo -> Monet -> photo (and vice versa) should round-trip.
        loss_gen = loss_gen + 9 * (self.image_loss(reconstructed_photo, photoes)
                                   + self.image_loss(reconstructed_monet, monets)) / 2

        loss_gen.backward()
        self.optimizer_generator_photo.step()
        self.optimizer_generator_monet.step()

        # ---- discriminators ----
        # Their grads were populated by loss_gen.backward(); each step zeroes them first.
        self._discriminator_step(self.descriminator_monet,
                                 self.optimizer_descriminator_monet,
                                 monets, fake_monet)
        self._discriminator_step(self.descriminator_photo,
                                 self.optimizer_descriminator_photo,
                                 photoes, fake_photo)

    def predict(self, photo):
        """Translate a batch of photos into Monet style."""
        return self.generator_photo2monet(photo)

    def predict_monet(self, photo):
        """Despite the name, run the Monet -> photo generator on *photo*."""
        return self.generator_monet2photo(photo)

    def save(self, epoch):
        """Pickle the whole photo->Monet generator to ``generator_photo2monet_{epoch}.bin``.

        The other three networks are not saved.
        """
        torch.save(self.generator_photo2monet, f'generator_photo2monet_{epoch}.bin')

    def pretrain_generators(self, photoes, monets):
        """Warm-up step: train each generator to reproduce its own input (L1)."""
        self.pretrain_optimizer_p2m.zero_grad()
        fake_photoes = self.generator_photo2monet(photoes)
        loss_photoes = self.pretrain_loss(fake_photoes, photoes)
        loss_photoes.backward()
        self.pretrain_optimizer_p2m.step()

        self.pretrain_optimizer_m2p.zero_grad()
        fake_monets = self.generator_monet2photo(monets)
        loss_monets = self.pretrain_loss(fake_monets, monets)
        loss_monets.backward()
        self.pretrain_optimizer_m2p.step()

    def load(self, path):
        """Replace the photo->Monet generator with a module pickled by ``save()``.

        The loaded module is mapped onto the device of the generator it
        replaces. The optimizers bound to that generator are rebuilt so later
        training steps update the loaded weights, not the discarded module.
        """
        old_param = next(self.generator_photo2monet.parameters(), None)
        target_device = old_param.device if old_param is not None else None
        # SECURITY: this unpickles arbitrary objects; only load trusted files.
        try:
            loaded = torch.load(path, map_location=target_device, weights_only=False)
        except TypeError:
            # torch < 1.13 does not accept the weights_only argument.
            loaded = torch.load(path, map_location=target_device)
        self.generator_photo2monet = loaded
        self.optimizer_generator_photo = self._make_optimizer(loaded)
        self.pretrain_optimizer_p2m = self._make_optimizer(loaded)

    def train(self):
        """Put all four networks in training mode."""
        self.generator_photo2monet.train()
        self.generator_monet2photo.train()
        self.descriminator_monet.train()
        self.descriminator_photo.train()

    def eval(self):
        """Put all four networks in evaluation mode."""
        self.generator_photo2monet.eval()
        self.generator_monet2photo.eval()
        self.descriminator_monet.eval()
        self.descriminator_photo.eval()

In [None]:
class Block(nn.Module):
    """3x3 stride-1 conv -> InstanceNorm -> LeakyReLU(0.01).

    With padding 1 the spatial size (H, W) is unchanged; only the channel
    count goes from in_channels to out_channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(negative_slope=0.01),
        ]
        # The attribute name "module" and the layer order define the
        # state_dict keys, so pickled checkpoints depend on them.
        self.module = nn.Sequential(*layers)

    def forward(self, X):
        out = self.module(X)
        return out

    
class BlockConv2(nn.Module):
    """Downsampling stage: 4x4 stride-2 conv -> InstanceNorm -> LeakyReLU(0.01).

    With padding 1 this halves even spatial sizes: (H, W) -> (H/2, W/2).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        norm = nn.InstanceNorm2d(out_channels)
        act = nn.LeakyReLU(negative_slope=0.01)
        # Attribute name "model" and layer order are part of the state_dict keys.
        self.model = nn.Sequential(conv, norm, act)

    def forward(self, X):
        downsampled = self.model(X)
        return downsampled

In [None]:
class UpNet(nn.Module):
    """Decoder stage: 2x nearest-neighbour upsampling followed by two Blocks.

    The first Block maps in_channels -> out_channels; the second keeps
    out_channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2, mode='nearest')
        widen = Block(in_channels, out_channels)
        refine = Block(out_channels, out_channels)
        # Attribute name "model" and layer order are part of the state_dict keys.
        self.model = nn.Sequential(upsample, widen, refine)

    def forward(self, X):
        upsampled = self.model(X)
        return upsampled

In [None]:
class ResBlock(nn.Module):
    """Two stacked Blocks with an identity skip connection: ``model(X) + X``.

    The addition needs ``model(X)`` to have the same shape as ``X``, so in
    practice out_channels must equal in_channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        first = Block(in_channels, out_channels)
        second = Block(out_channels, out_channels)
        self.model = nn.Sequential(first, second)

    def forward(self, X):
        residual = self.model(X)
        return residual + X
    
class ResNet(nn.Module):
    """ResNet-style image-to-image generator: 3 channels in, 3 tanh channels out.

    Encoder: Block(3->16), then two stride-2 BlockConv2 stages (16->64->128).
    Bottleneck: seven 128-channel ResBlocks. Decoder: two 2x UpNet stages
    (128->64->16), a 3x3 conv to 3 channels, and tanh. For inputs whose height
    and width are divisible by 4, the output has the input's spatial size.
    """

    def __init__(self):
        super().__init__()
        encoder = [
            Block(3, 16),
            BlockConv2(16, 64),
            BlockConv2(64, 128),
        ]
        bottleneck = [ResBlock(128, 128) for _ in range(7)]
        decoder = [
            UpNet(128, 64),
            UpNet(64, 16),
            nn.Conv2d(16, 3, 3, 1, 1),
            nn.Tanh(),
        ]
        # A single flat Sequential keeps the submodule indices model.0 ... model.13
        # that state_dict keys and pickled checkpoints rely on.
        self.model = nn.Sequential(*encoder, *bottleneck, *decoder)

    def forward(self, X):
        generated = self.model(X)
        return generated

In [None]:
class DiscriminatorRes(nn.Module):
    """Convolutional discriminator returning one raw score per image patch.

    Four stride-2 BlockConv2 stages reduce (H, W) by 16. A final 3x3 conv
    produces a single-channel score map with no output activation, which is
    flattened to ``(B, H/16 * W/16)`` (256 scores for a 256x256 input).
    """

    def __init__(self):
        super().__init__()
        layers = [
            Block(3, 16),
            BlockConv2(16, 32),    # 256 -> 128
            Block(32, 64),
            BlockConv2(64, 64),    # 128 -> 64
            Block(64, 128),
            BlockConv2(128, 128),  # 64 -> 32
            BlockConv2(128, 128),  # 32 -> 16
            nn.Conv2d(128, 1, 3, 1, 1),
            nn.Flatten(),
        ]
        self.model = nn.Sequential(*layers)

        # Total downsampling factor of the stride-2 stages; not read in this class.
        self.scale_factor = 16

    def forward(self, x):
        scores = self.model(x)
        return scores

In [None]:
# Keyword arguments are evaluated left to right, so the networks are built
# (and randomly initialised) in the same order as a positional call.
model = CycleGAN(
    generator_monet2photo=ResNet().to(device),
    generator_photo2monet=ResNet().to(device),
    descriminator_photo=DiscriminatorRes().to(device),
    descriminator_monet=DiscriminatorRes().to(device),
)

# Swap in a previously trained photo -> Monet generator.
model.load('../input/last-120/generator_photo2monet_50 (2).bin')

In [None]:
from tqdm.auto import tqdm
from torchvision.utils import make_grid

def convert_image(x):
    """Arrange *x* into one display-ready image grid, min-max scaled to [0, 1].

    *x* may be a single ``(C, H, W)`` image or a ``(B, C, H, W)`` batch; a batch
    is laid out as a single row. ``normalize=True`` makes make_grid rescale the
    result by its min and max.
    """
    # make_grid's images-per-row keyword is ``nrow``. ``nrows`` is not one of
    # its parameters; depending on the torchvision version it is silently
    # ignored or rejected. For a single 3-D image the row size is irrelevant,
    # so use 1 rather than the channel count.
    images_per_row = x.shape[0] if x.dim() == 4 else 1
    return make_grid(x, nrow=images_per_row, normalize=True)

num_epochs = 0  # 0 skips training entirely
i = 0  # global batch counter; drives the preview cadence below

for epoch in tqdm(range(num_epochs)):
    for photo_batch, monet_batch in tqdm(dataloader):
        i += 1
        # NOTE: CycleGAN.fit_one_batch does not use alpha (its loss weights are fixed).
        alpha = 10
        if epoch > 10:
            alpha = 4
        # Move each batch to the device once and reuse it for training and preview.
        photo_batch = photo_batch.to(device)
        monet_batch = monet_batch.to(device)
        model.fit_one_batch(photo_batch, monet_batch, alpha)

        if i % 150 == 1:
            # Preview only: no autograd graph is needed for these forward passes.
            with torch.no_grad():
                predicted_photo = model.predict(photo_batch)  # photo -> Monet
                predicted_monet = model.predict_monet(monet_batch)  # Monet -> photo

            fig = plt.figure(figsize=(15, 15))
            panels = (photo_batch[0], monet_batch[0], predicted_photo[0], predicted_monet[0])
            for position, panel in enumerate(panels, start=1):
                fig.add_subplot(1, 4, position)
                plt.imshow(convert_image(panel).permute(1, 2, 0).cpu().numpy())

            plt.title(f"epoch {epoch + 1}")
            plt.show()
    if epoch % 5 == 4:
        model.save(epoch + 1)

In [None]:
import os
from PIL import Image

class ImageDataset(Dataset):
    """Photo-only dataset used to generate Monet-style outputs for submission.

    Every ``.jpg`` under both directories is decoded into memory at
    construction. Item ``index`` is the photo at position ``index`` passed
    through ``transforms``. ``images_monet`` is still loaded so the attribute
    stays available, but ``__getitem__`` does not use it.

    Args:
        transforms: callable applied to each PIL photo.
        path_monet: directory searched recursively for Monet ``.jpg`` files.
        path_photoes: directory searched recursively for photo ``.jpg`` files.
    """

    def __init__(self, transforms,
                 path_monet='../input/gan-getting-started/monet_jpg',
                 path_photoes='../input/gan-getting-started/photo_jpg'):
        super().__init__()
        self.path_monet = path_monet
        self.path_photoes = path_photoes
        self.transforms = transforms

        self.images_monet = self._load_images(self._find_jpgs(self.path_monet))
        self.images_photoes = self._load_images(self._find_jpgs(self.path_photoes))

    @staticmethod
    def _find_jpgs(directory):
        """Return the sorted paths of all ``.jpg`` files under *directory*, recursively."""
        # Join with os.walk's ``root`` so files in subdirectories get their real
        # path; sorting makes the output numbering reproducible across runs.
        return sorted(
            os.path.join(root, file_name)
            for root, _dirs, filenames in os.walk(directory)
            for file_name in filenames
            if file_name.endswith('.jpg')
        )

    @staticmethod
    def _load_images(paths):
        """Decode every file in *paths* into an in-memory RGB PIL image."""
        images = []
        for path in paths:
            with Image.open(path) as img:
                # convert() decodes fully and returns a new image (so the file
                # can be closed) and guarantees 3 channels.
                images.append(img.convert('RGB'))
        return images

    def __len__(self):
        return len(self.images_photoes)

    def __getitem__(self, index):
        """Return the transformed photo at *index*."""
        return self.transforms(self.images_photoes[index])
        

In [None]:
# Inference preprocessing: no augmentation; PIL image -> tensor in [0, 1] ->
# per-channel (x - 0.5) / 0.5, matching the [-1, 1] range used in training.
transformer_ans = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [None]:
# Photo-only dataset whose items are translated and saved for the submission.
dataset_for_ans = ImageDataset(transforms=transformer_ans)

In [None]:
# Create the directory that the inference loop fills with generated JPEGs
# ("../images/<n>.jpg"). NOTE(review): plain mkdir reports an error when the
# directory already exists; `mkdir -p` would keep re-runs quiet.
!mkdir ../images

In [None]:
# One photo per batch, in dataset order (shuffle=False is the DataLoader default).
ph_dl = DataLoader(dataset_for_ans, batch_size=1, shuffle=False, pin_memory=True)

In [None]:
from tqdm.auto import tqdm

trans = transforms.ToPILImage()

# The generator ends in tanh and was trained on inputs normalized with
# mean = std = 0.5, so x * 0.5 + 0.5 maps its output back to [0, 1].
# Min-max rescaling (make_grid(normalize=True)) would instead stretch each
# image's brightness/contrast individually.
model.eval()
t = tqdm(ph_dl, leave=False, total=len(ph_dl))
for i, photo in enumerate(t):
    with torch.no_grad():
        output = model.predict(photo.to(device))
    # ph_dl yields one photo per batch; take it instead of hard-coding 3x256x256.
    pred_monet = (output[0] * 0.5 + 0.5).clamp(0, 1).cpu()
    if i % 400 == 0:
        fig = plt.figure(figsize=(15, 15))
        plt.imshow(pred_monet.permute(1, 2, 0))
        plt.show()
    img = trans(pred_monet).convert("RGB")
    img.save(os.path.join("../images", f"{i + 1}.jpg"))

In [None]:
import shutil

# Zip the generated images into /kaggle/working/images.zip for submission.
# NOTE(review): images were written to the relative "../images"; this assumes
# the working directory is /kaggle/working so that resolves to /kaggle/images.
shutil.make_archive(base_name="/kaggle/working/images", format='zip', root_dir="/kaggle/images")
