In [1]:
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import DataLoader, Dataset
import albumentations
import albumentations.pytorch
import cv2
import matplotlib.pyplot as plt

In [2]:
# Seed every RNG the notebook relies on so a fresh "Restart & Run All" is repeatable.
SEED = 777
np.random.seed(SEED)
# torch.manual_seed seeds the CPU generator and all CUDA devices; seeding only the
# current CUDA device would leave CPU-side randomness unseeded.
torch.manual_seed(SEED)
# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
class AnimeDataset(Dataset):
    """Dataset that reads every file of one directory as an RGB image via OpenCV.

    Each item is decoded with ``cv2.imread``, converted from BGR to RGB, scaled
    to ``[0, 1]`` as ``float32`` and then passed through ``transform`` if given.

    Args:
        transform: Optional callable invoked as ``transform(image=img)``; it must
            return a mapping whose ``'image'`` entry is the transformed image.
        img_dir: Directory holding the images. Defaults to ``DEFAULT_IMG_DIR``.
            Every directory entry is treated as an image file.
    """

    # Directory used when the caller does not pass ``img_dir``.
    DEFAULT_IMG_DIR = "/home/temp_1/kangsanha/AnimeGan/For_pixel2style2pixel/AnimeGan/data/"

    def __init__(self, transform = None, img_dir = None):
        super().__init__()

        self.img_dir = self.DEFAULT_IMG_DIR if img_dir is None else img_dir
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable between runs and machines.
        self.filenames = sorted(os.listdir(self.img_dir))

        self.transform = transform

    def __getitem__(self, index):
        """Return the image at ``index`` (transformed when a transform is set).

        Raises:
            OSError: If OpenCV cannot read or decode the file.
        """
        img_path = os.path.join(self.img_dir, self.filenames[index])
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Could not read image file: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR -> RGB

        if img.dtype == np.uint8:
            # Scale in float32: a plain `/ 255.0` would promote to float64.
            img = img.astype(np.float32) / 255.0

        if self.transform:
            img = self.transform(image=img)['image']

        return img

    def __len__(self):
        """Return the number of entries found in ``img_dir``."""
        return len(self.filenames)

In [None]:
class Generator(nn.Module):
    """Generator stage built from two modulated convolutions and an RGB head.

    The constructor only wires up sub-modules. ``Conv2DMod``, ``leaky_relu`` and
    ``RGBBlock`` are not defined in the cells above and must be defined before
    this cell runs.
    NOTE(review): the layout resembles a StyleGAN2-style synthesis block - confirm.

    Args:
        latent_dim: Input size of the ``to_style1`` / ``to_style2`` linear layers;
            also passed as the first argument of ``RGBBlock``.
        input_channels: Output size of ``to_style1`` and input channels of ``conv1``.
        filters: Output channels of ``conv1``; input and output channels of
            ``conv2``; output size of ``to_style2`` and of both noise projections.
        upsample: When truthy, ``self.upsample`` is a 2x bilinear ``nn.Upsample``;
            otherwise it is ``None``.
        upsample_rgb: Forwarded to ``RGBBlock``.
        rgba: Forwarded to ``RGBBlock``.
    """
    def __init__(self, latent_dim, input_channels, filters, upsample = True, upsample_rgb = True,
                rgba = False):
        # nn.Module must be initialised before any sub-module is assigned,
        # otherwise the first assignment below raises
        # "AttributeError: cannot assign module before Module.__init__() call".
        super().__init__()

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None
        
        # First convolution: style and noise projections plus the modulated conv.
        self.to_style1 = nn.Linear(latent_dim, input_channels)
        self.to_noise1 = nn.Linear(1, filters)
        self.conv1 = Conv2DMod(input_channels, filters, 3)
        
        # Second convolution: same pattern, operating on `filters` channels.
        self.to_style2 = nn.Linear(latent_dim, filters)
        self.to_noise2 = nn.Linear(1, filters)
        self.conv2 = Conv2DMod(filters, filters, 3)

        self.activation = leaky_relu()
        self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)