In [None]:
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --apt-packages libomp5 libopenblas-dev

In [2]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import matplotlib.pyplot as plt

import os
from glob import glob

# import torch_xla.core.xla_model as xm
# import torch_xla.distributed.parallel_loader as pl
# import torch_xla.distributed.xla_multiprocessing as xmp

import warnings
warnings.filterwarnings("ignore")


In [3]:
class ImageDataset(Dataset):
    def __init__(self, monet_path, photo_path, transform=None):
        super().__init__()
        self.monet_path = monet_path
        self.photo_path = photo_path
        self.monet_files = glob(self.monet_path+'/*')
        self.photo_files = glob(self.photo_path+'/*')
        self.transform = transform

    def __getitem__(self, idx):
        idx2 = np.random.randint(len(self.photo_files))
        x1 = Image.open(self.monet_files[idx])
        x2 = Image.open(self.photo_files[idx2])
        if self.transform is not None:
            x1 = self.transform(x1)
            x2 = self.transform(x2)
        return x1, x2
    
    def __len__(self):
        return len(self.monet_files)

In [4]:
feature_transform = T.Compose([T.Resize((256,256)),T.ToTensor(),T.RandomVerticalFlip()])

ds = ImageDataset('/kaggle/input/gan-getting-started/monet_jpg','/kaggle/input/gan-getting-started/photo_jpg', transform=feature_transform)

In [None]:
!export XLA_USE_BF16=1

In [5]:
class G_up_Block(nn.Module):
    def __init__(self, in_channels, out_channels, k, stride=1, padding=0):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, k, stride=stride ,padding=padding),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        x = self.layers(x)
        return x

class Residual_Block(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.layers = nn.Sequential(
            G_up_Block(channels, channels, 3, 1, 1),
            G_up_Block(channels, channels, 3, 1, 1)
        )
    def forward(self, x):
        return x + self.layers(x)

class G_down_Block(nn.Module):
    def __init__(self, in_channels, out_channels, k, stride=1, padding=0, output_padding=1):
        super().__init__()
        self.layers = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, k, stride, padding, output_padding),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        x = self.layers(x)
        return x

class Generator(nn.Module):
    def __init__(self, num_residual_blocks):
        super().__init__()
        self.num_residual_blocks = num_residual_blocks
        self.up = nn.Sequential(
            G_up_Block(3,64,7,1,3),
            G_up_Block(64,128,3,2,1),
            G_up_Block(128,256,3,2,1)
        )
        self.residual = nn.Sequential(*[Residual_Block(256) for _ in range(self.num_residual_blocks)])
        self.down = nn.Sequential(
            G_down_Block(256,128,3,2,1,1),
            G_down_Block(128,64,3,2,1,1),
            G_up_Block(64,3,7,1,3)
        )
    def forward(self, x):
        x = self.up(x)
        x = self.residual(x)
        x = self.down(x)
        return x

class D_Block(nn.Module):
    def __init__(self, in_channels, out_channels, stride=2, padding=1, norm=True):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, stride, padding),
            nn.InstanceNorm2d(out_channels) if norm else nn.Identity(),
            nn.LeakyReLU(0.2, inplace=True)
        )
    def forward(self, x):
        x = self.layers(x)
        return x

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            D_Block(3,64,2,1,norm=False),
            D_Block(64,128,2,1),
            D_Block(128,256,2,1),
            D_Block(256,512,1,1),
            D_Block(512,1,1,1),
        )
    def forward(self, x):
        x = self.layers(x)
        x = torch.sigmoid(x)
        return x

In [28]:
def train(loader, monet_G, photo_G, monet_D, photo_D, L1, mse, G_optim, D_optim, device, lambda_cycle=10):
    total_G_loss = 0.0 
    total_D_loss = 0.0 
    running_D_loss = 0.0  
    running_G_loss = 0.0
    print_interval = 100
    for batch_id, (monet, photo) in enumerate(loader):        
        monet = monet.to(device)
        photo = photo.to(device)

        fake_monet = monet_G(photo)
        fake_photo = photo_G(monet)
        
        #train discriminator
        critic_monet_real = monet_D(monet)
        critic_monet_fake = monet_D(fake_monet.detach())
        critic_monet_real_loss = mse(critic_monet_real, torch.ones_like(critic_monet_real))
        critic_monet_fake_loss = mse(critic_monet_fake, torch.zeros_like(critic_monet_fake))
        critic_monet_loss = critic_monet_real_loss + critic_monet_fake_loss
        
        critic_photo_real = photo_D(photo)
        critic_photo_fake = photo_D(fake_photo.detach())
        critic_photo_real_loss = mse(critic_photo_real, torch.ones_like(critic_photo_real))
        critic_photo_fake_loss = mse(critic_photo_fake, torch.zeros_like(critic_photo_fake))
        critic_photo_loss = critic_photo_real_loss + critic_photo_fake_loss
        
        D_loss = critic_monet_loss + critic_photo_loss
        
        D_optim.zero_grad()
        D_loss.backward()
        #xm.optimizer_step(D_optim)
        D_optim.step()
        
        running_D_loss += D_loss.item()
        total_D_loss += D_loss.item()
        
        #train generator
        critic_monet_fake = monet_D(fake_monet)
        critic_photo_fake = photo_D(fake_photo)
        
        #adverserial loss
        gen_monet_loss = mse(critic_monet_fake, torch.ones_like(critic_monet_fake))
        gen_photo_loss = mse(critic_photo_fake, torch.ones_like(critic_photo_fake))
        
        #cycle loss
        cycle_monet = monet_G(fake_photo)
        cycle_photo = photo_G(fake_monet)
        cycle_monet_loss = L1(monet, cycle_monet)
        cycle_photo_loss = L1(photo, cycle_photo)
        
        G_loss = (gen_monet_loss + gen_photo_loss) + (cycle_monet_loss + cycle_photo_loss)*lambda_cycle #+ (identity_monet_loss + identity_photo_loss)

        G_optim.zero_grad()
        G_loss.backward(retain_graph=True)
        #xm.optimizer_step(G_optim)
        G_optim.step()
        
        running_G_loss += G_loss.item()
        total_G_loss += G_loss.item()
        
        if batch_id % print_interval == print_interval-1:
            print(f'[{step:5d}] G_loss: {running_G_loss / print_interval:.3f}, D_loss:{running_G_loss / print_interval:.3f}')
            running_G_loss = 0.0
            running_D_loss = 0.0
        
        return total_G_loss/len(loader), total_D_loss/len(loader)

In [29]:
G_losses, D_losses = [], []

In [36]:
ds.__len__()

300

In [34]:
loader = DataLoader(ds, batch_size=1, num_workers=2, drop_last=True, shuffle=True)
print(len(loader))

300


In [30]:
def _run():
    global G_losses
    global D_losses
    
    batch_size = 32
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    #device = xm.xla_device()
    
    #train_sampler = torch.utils.data.distributed.DistributedSampler(ds,num_replicas=xm.xrt_world_size(),rank=xm.get_ordinal(),shuffle=True)
    #loader = DataLoader(ds, batch_size=batch_size, num_workers=4, drop_last=True, sampler=train_sampler)
    
    loader = DataLoader(ds, batch_size=batch_size, num_workers=2, drop_last=True, shuffle=True)
    monet_G = Generator(9).to(device)
    photo_G = Generator(9).to(device)
    monet_D = Discriminator().to(device)
    photo_D = Discriminator().to(device)
    
    #lr = 0.4 * 1e-3 * xm.xrt_world_size()
    lr = 1e-3
    betas = (0.5,0.999)

    G_optim = optim.Adam(list(monet_G.parameters())+list(photo_G.parameters()), lr=lr, betas=betas)
    D_optim = optim.Adam(list(monet_D.parameters())+list(photo_D.parameters()), lr=lr, betas=betas)
    
    L1 = nn.L1Loss()
    mse = nn.MSELoss()
    
    EPOCHS = 10
    
    for epoch in range(EPOCHS):
        #para_loader = pl.ParallelLoader(loader, [device])
        print(f'[{epoch+1:2d}/{EPOCHS}]')
        #g_loss, d_loss = train(para_loader.per_device_loader(device), monet_G, photo_G, monet_D, photo_D, L1, mse, D_optim, G_optim)
        g_loss, d_loss = train(loader, monet_G, photo_G, monet_D, photo_D, L1, mse, G_optim, D_optim, device)
        G_losses.append(g_loss)
        D_losses.append(d_loss)

In [31]:
_run()

[ 1/10]


OutOfMemoryError: CUDA out of memory. Tried to allocate 512.00 MiB (GPU 0; 15.90 GiB total capacity; 13.99 GiB already allocated; 137.75 MiB free; 14.89 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
def _mp_fn(rank, flags):
    torch.set_default_tensor_type('torch.FloatTensor')
    a = _run()

In [None]:
FLAGS = {}
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')