# MetroGAN
This code is for the 128x128 dataset.

In [None]:
import glob
import os
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import time
import datetime
import sys
from torch.utils.data import DataLoader
import torch.nn as nn
import torchvision.utils as vutils
import torch.nn.functional as F
import neptune.new as neptune
import functools
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.utils.spectral_norm as spectral_norm
from torchsummary import summary

### GPU and Neptune configuration

In [None]:
# GPU usage configuration: use CUDA device ``cuda_id`` when CUDA is available,
# otherwise fall back to the CPU.
cuda_id = 0
cuda = torch.cuda.is_available()
if cuda:
    # torch.cuda.set_device fails on machines without CUDA, so only call it here.
    torch.cuda.set_device(cuda_id)
device = torch.device("cuda:{}".format(cuda_id) if cuda else "cpu")


# neptune block
# NOTE(review): when this is True, later cells use a global ``run`` object that
# is not created anywhere in this notebook; create a Neptune run before enabling.
neptune_log = False

## dataset and dataloader

In [None]:
class ImageDataset(Dataset):
    """City samples built from the DEM, NTL, Water and Built-up_area folders.

    Files from each sub-folder of ``root`` are paired by their position in the
    sorted file listing, so every sub-folder must contain the same cities.
    Only the label (built-up area) goes through ``transforms_``; the three
    input channels are returned at their stored resolution.
    """

    def __init__(self, root, transforms_=None, mode="train"):
        """
        :param root: root path of the dataset
        :param transforms_: list of transforms applied to the label only;
            ``None`` means no transform.
        :param mode: "train"/"test". In test mode each sample additionally
            carries the city name and the DEM channel as separate entries.
        """
        self.mode = mode
        # transforms.Compose(None) would fail on its first call, so map None to [].
        self.transform = transforms.Compose(transforms_ if transforms_ is not None else [])
        self.light_files = sorted(glob.glob(os.path.join(root, "NTL") + "/*.*"))
        self.water_files = sorted(glob.glob(os.path.join(root, "Water") + "/*.*"))
        self.build_files = sorted(glob.glob(os.path.join(root, "Built-up_area") + "/*.*"))
        self.dem_files = sorted(glob.glob(os.path.join(root, "DEM") + "/*.*"))
        if mode == "test":
            # City names are derived from the DEM file names.
            self.cities_name = sorted(glob.glob(os.path.join(root, "DEM") + "/*.*"))

    @staticmethod
    def _load(path):
        """Read an image file into a float32 tensor, closing the file handle."""
        with Image.open(path) as im:
            return torch.as_tensor(np.array(im), dtype=torch.float)

    def __getitem__(self, index):
        index = index % len(self.dem_files)
        # NTL and DEM are mapped with x*2-1 (presumably stored in [0, 1]);
        # Water and Built-up with x/127.5-1 (presumably 8-bit, [0, 255]) -- TODO confirm.
        img_light = self._load(self.light_files[index]) * 2 - 1
        img_water = self._load(self.water_files[index]) / 127.5 - 1

        img_build_ori = self._load(self.build_files[index]) / 127.5 - 1
        img_build = self.transform(img_build_ori.unsqueeze(0))

        img_dem = self._load(self.dem_files[index]) * 2 - 1

        # Generator input channel order: DEM, night light, water.
        input_img = torch.stack([img_dem, img_light, img_water], dim=0)
        sample = {"input_img": input_img,
                  "label": img_build,
                  "mask": img_water.unsqueeze(0)}
        if self.mode == "test":
            city_path = self.cities_name[index]
            sample["dem"] = img_dem.unsqueeze(0)
            # File stem via os.path so the platform's path separator is honoured.
            sample["name"] = os.path.splitext(os.path.basename(city_path))[0]
        return sample

    def __len__(self):
        return len(self.dem_files)

## parameter configuration

In [None]:
# Hyper-parameters of this run. When neptune_log is on they are also sent to
# Neptune and per-run image / checkpoint folders are created.
parameters = dict(
    img_size=128,
    start_epoch=0,
    n_epochs=400,
    dataset_name="global_city_dataset",
    batch_size=64,
    g_lr=0.0004,
    d_lr=0.0004,
    scheduler="cosine",
    gan_loss="mse",
    decay_epoch=0,
    optimizer='adam',
    b1=0.5,
    b2=0.999,
    input_channels=3,
    output_channels=1,
    input_path="../multi-year_dataset/train",
    test_path="../multi-year_dataset/validate",
    pixel_lamda=30,
    ngf=64,
    sample_interval=50,
    checkpoint_interval=50,
    n_cpu=8,
    debug=False,
    pretrain=False,
    pretrain_save=True,
    load_pretrain=False,
)
if neptune_log:
    run['model/parameters'] = parameters
    # One sub-folder per Neptune run id for sample images and checkpoints.
    for folder_template in ("images/%s", "saved_models/%s"):
        os.makedirs(folder_template % run['sys/id'].fetch(), exist_ok=True)
print(parameters)



## Utility functions

In [None]:
# The function to initialize the model
def weights_init_normal(m):
    """Initialize one module in place; intended for ``model.apply(weights_init_normal)``.

    Modules whose class name contains "Conv" get ``weight`` ~ N(0, 0.02).
    Modules whose class name contains "BatchNorm2d" get ``weight`` ~ N(1, 0.02)
    and ``bias`` = 0. A missing or ``None`` weight/bias (e.g. a
    ``BatchNorm2d(affine=False)``) is skipped instead of raising.

    NOTE(review): wrapper classes whose name contains "Conv" and that expose a
    ``.weight`` alias (such as WSConv2d in this file) are matched too, so this
    overrides their own initialization.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    if "Conv" in classname:
        if weight is not None:
            torch.nn.init.normal_(weight.data, 0.0, 0.02)
    elif "BatchNorm2d" in classname:
        if weight is not None:
            torch.nn.init.normal_(weight.data, 1.0, 0.02)
        bias = getattr(m, "bias", None)
        if bias is not None:
            torch.nn.init.constant_(bias.data, 0.0)

def ShowResults(generator, dataloader_test, step, mode="train"):
    """Plot generated maps next to the ground truth for one validation batch.

    Batch index 1 of ``dataloader_test`` is displayed; if the loader yields
    only one batch, that batch is used. At ``step == 5`` the input channels
    are appended to the grid. With ``neptune_log`` enabled the figure is also
    uploaded under a key built from the global ``epoch``.

    :param generator: generator, evaluated as ``generator(x, step, 1)`` (alpha=1)
    :param dataloader_test: loader yielding dicts with "input_img" and "label"
    :param step: progressive step to evaluate
    :param mode: "train" or "pretrain"; selects the Neptune figure key
    :raises ValueError: if ``neptune_log`` is on and ``mode`` is unknown
    """
    if neptune_log and mode not in ("train", "pretrain"):
        raise ValueError("unknown ShowResults mode: {!r}".format(mode))

    # Take batch 1 and stop instead of loading the whole test set.
    test_batch = None
    for j, batch in enumerate(dataloader_test):
        test_batch = batch
        if j == 1:
            break
    if test_batch is None:
        return

    device = next(generator.parameters()).device
    generator.eval()
    try:
        test_input_img = test_batch["input_img"].to(device)
        test_label = 0.5 * test_batch["label"] + 0.5
        with torch.no_grad():
            fake = 0.5 * generator(test_input_img, step, 1).detach().cpu() + 0.5
        imgs = torch.cat((fake, test_label), dim=0)
        if step == 5:
            inputs = test_input_img.cpu()
            # Append input channels 1, 0, 2 (ImageDataset stacks them as
            # DEM, light, water), mapped from [-1, 1] to [0, 1].
            imgs = torch.cat(
                [imgs] + [0.5 + 0.5 * inputs[:, c, :, :].unsqueeze(1) for c in (1, 0, 2)],
                dim=0)
        fig = plt.figure(dpi=200.0)
        plt.imshow(np.transpose(vutils.make_grid(imgs, padding=2, normalize=False, pad_value=1.0), (1, 2, 0)))
        plt.axis("off")
        if neptune_log:
            key = "city-fig/pretrain_{}" if mode == "pretrain" else "city-fig/{}"
            run[key.format(epoch)].upload(fig)
            run["city-figs"].log(fig)
        plt.show()
        plt.close()
    finally:
        # Always restore training mode, even if plotting fails.
        generator.train()


def get_scheduler(optimizer, lr_policy):
    """
    Build the learning-rate scheduler selected by ``lr_policy``.

    :param optimizer: the optimizer of the network
    :param lr_policy: one of "linear", "step", "plateau" or "cosine"
    :return: the scheduler
    :raises NotImplementedError: if ``lr_policy`` is not recognised
    """
    if lr_policy == 'linear':
        def lambda_rule(epoch):
            # Constant for the first 100 epochs, then linear decay that reaches
            # 0 at epoch parameters["n_epochs"] + 1.
            return 1.0 - max(0, epoch - 100) / float(parameters["n_epochs"] - 99)
        return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    if lr_policy == 'step':
        return lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)
    if lr_policy == 'plateau':
        return lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    if lr_policy == 'cosine':
        # Warm restarts after 2, then 4, 8, ... scheduler steps.
        return lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=2, T_mult=2)
    raise NotImplementedError('learning rate policy [%s] is not implemented' % lr_policy)

## Model

In [None]:
# weighted scaled conv2d
class WSConv2d(nn.Module):
    """Weight-scaled (equalized learning-rate) 2-D convolution.

    The kernel is initialised from N(0, 1) and the input is multiplied at run
    time by ``sqrt(gain / (in_channels * kernel_size**2))``. The bias is taken
    out of the convolution and added afterwards, so it is not scaled.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2
    ):
        super(WSConv2d, self).__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.conv = conv
        fan_in = in_channels * kernel_size ** 2
        self.scale = (gain / fan_in) ** 0.5
        # Detach the bias from the convolution so it can be added unscaled.
        self.bias = conv.bias
        conv.bias = None

        nn.init.normal_(conv.weight)
        nn.init.zeros_(self.bias)

        # Expose the kernel under ``.weight`` as well (alias, same Parameter).
        self.weight = conv.weight

    def forward(self, x):
        bias = self.bias.view(1, -1, 1, 1)
        return self.conv(x * self.scale) + bias

class ConvBlock(nn.Module):
    """Two weight-scaled 3x3 convolutions applied to ``cat([in_feature, skip])``.

    Each convolution is followed by LeakyReLU(0.2) and, when ``use_batchnorm``
    is set, by its own BatchNorm2d.

    :param in_channels: channels of the concatenated input (in_feature + skip)
    :param out_channels: channels produced by both convolutions
    :param use_batchnorm: apply BatchNorm after each activation
    """
    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        super(ConvBlock, self).__init__()
        self.use_bn = use_batchnorm
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        # One BatchNorm per convolution so their running statistics and
        # affine parameters are not mixed.
        self.bn = nn.BatchNorm2d(out_channels, momentum=0.8)
        self.bn2 = nn.BatchNorm2d(out_channels, momentum=0.8)

        # Expose conv1's kernel under ``.weight`` (alias, same Parameter).
        self.weight = self.conv1.weight

    def forward(self, in_feature, skip):
        x = torch.cat([in_feature, skip], 1)
        x = self.leaky(self.conv1(x))
        x = self.bn(x) if self.use_bn else x
        x = self.leaky(self.conv2(x))
        x = self.bn2(x) if self.use_bn else x
        return x

# convolution block in decoder
class UpConv(nn.Module):
    """Decoder block: concatenate ``in_feature`` with ``skip`` on channels, then
    ConvTranspose2d(k=4, s=2, p=1) (doubles H and W) -> optional BatchNorm2d ->
    LeakyReLU(0.2).
    """

    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        super(UpConv, self).__init__()
        self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
        self.use_bn = use_batchnorm
        self.bn = nn.BatchNorm2d(out_channels)
        # Alias of the transposed-conv kernel under ``.weight``.
        self.weight = self.conv.weight

    def forward(self, in_feature, skip):
        merged = torch.cat([in_feature, skip], 1)
        upsampled = self.conv(merged)
        normed = self.bn(upsampled) if self.use_bn else upsampled
        return self.leaky(normed)

#convolution block in encoder
class DownConv(nn.Module):
    """Encoder block: Conv2d(k=4, s=2, p=1) -> optional BatchNorm2d -> LeakyReLU(0.2).

    The stride-2 convolution halves even spatial sizes.
    """

    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        super(DownConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
        self.use_bn = use_batchnorm
        self.bn = nn.BatchNorm2d(out_channels)
        # Alias of the conv kernel under ``.weight``.
        self.weight = self.conv.weight

    def forward(self, in_feature):
        h = self.conv(in_feature)
        return self.leaky(self.bn(h) if self.use_bn else h)
    
##############################
# Progressive Generator
#############################
class Generator(nn.Module):
    """Progressive U-Net generator.

    Encoder: seven stride-2 ``DownConv`` blocks (``down1`` .. ``down7``).
    Decoder: ``initial`` up-samples the bottleneck once; then ``steps + 1``
    ``UpConv`` blocks from ``prog_blocks`` each merge the matching encoder
    skip (deepest first) and double the resolution. ``out_layers[k]`` is the
    1x1 output head for the feature map that enters progressive block ``k``
    (``out_layers[0]`` is ``initial_out``).

    NOTE(review): the decoder channel plan is hard-coded for the defaults
    (``ngf=64``, ``max_channels=512``); other values will not line up.
    """

    def __init__(self, in_channels=3, out_channels=1, max_channels=512, ngf=64):
        super(Generator, self).__init__()

        # Encoder; the outermost and innermost blocks skip BatchNorm.
        self.down1 = DownConv(in_channels, ngf, False)
        self.down2 = DownConv(ngf, 2 * ngf)
        self.down3 = DownConv(2 * ngf, 4 * ngf)
        self.down4 = DownConv(4 * ngf, 8 * ngf)
        self.down5 = DownConv(8 * ngf, 8 * ngf)
        self.down6 = DownConv(8 * ngf, 8 * ngf)
        self.down7 = DownConv(8 * ngf, 8 * ngf, False)

        # Bottleneck decoder stage: 2x transposed-conv up-sampling, then a 3x3 conv.
        self.initial = nn.Sequential(
            nn.ConvTranspose2d(max_channels, max_channels, kernel_size=4, stride=2, padding=1, bias=True),
            nn.BatchNorm2d(max_channels, momentum=0.8),
            nn.LeakyReLU(0.2),
            nn.Conv2d(max_channels, max_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(max_channels, momentum=0.8),
            nn.LeakyReLU(0.2),
        )
        self.initial_out = nn.Conv2d(max_channels, out_channels, kernel_size=1, stride=1, padding=0)

        self.prog_blocks = nn.ModuleList([])
        self.out_layers = nn.ModuleList([self.initial_out])

        # (input channels including the skip, output/input ratio) per progressive block.
        block_plan = (
            (1024, 0.5),
            (1024, 0.5),
            (1024, 0.25),
            (512, 0.25),
            (256, 0.25),
            (128, 0.25),
        )
        for block_in, ratio in block_plan:
            block_out = int(block_in * ratio)
            self.prog_blocks.append(UpConv(block_in, block_out))
            self.out_layers.append(
                nn.Conv2d(block_out, out_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        """Return tanh(alpha * generated + (1 - alpha) * upscaled).

        ``alpha`` is expected to be a scalar in [0, 1] and both maps to share a shape.
        """
        blended = alpha * generated + (1 - alpha) * upscaled
        return torch.tanh(blended)

    def forward(self, x, steps=5, alpha=1e-5, gf=64):
        """Produce the output of progressive step ``steps``.

        :param x: generator input
        :param steps: index of the last progressive block to run (0..5)
        :param alpha: fade-in weight of the newest block versus the
            nearest-neighbour up-sampled output of the previous resolution
        :param gf: unused
        """
        steps = int(steps)

        # Encoder pass; every map except the bottleneck is kept as a skip.
        skips = []
        h = x
        for down in (self.down1, self.down2, self.down3, self.down4, self.down5, self.down6):
            h = down(h)
            skips.append(h)
        bottleneck = self.down7(h)
        skips.reverse()  # deepest skip first: d6, d5, ..., d1

        out = self.initial(bottleneck)
        for idx in range(steps + 1):
            upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
            out = self.prog_blocks[idx](out, skips[idx])

        return self.fade_in(
            alpha,
            self.out_layers[steps](upscaled),
            self.out_layers[steps + 1](out),
        )

##############################
#        Discriminator
##############################

class Discriminator(nn.Module):
    """Conditional PatchGAN-style discriminator.

    ``forward`` concatenates ``img_A`` and ``img_B`` along channels. The first
    convolution expects ``in_channels + 1`` channels, so this presumably means
    a 1-channel map plus an ``in_channels``-channel condition. The result
    passes through three stride-2 blocks and two 4x4 valid convolutions,
    producing a score map.
    """

    def __init__(self, in_channels=3, dropout=0.5):
        super(Discriminator, self).__init__()

        widths = [in_channels + 1, 64, 128, 256]
        layers = []
        for idx in range(3):
            layers.append(nn.Conv2d(widths[idx], widths[idx + 1], 4, stride=2, padding=1))
            if idx > 0:  # the first block has no normalization
                layers.append(nn.InstanceNorm2d(widths[idx + 1]))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers += [
            nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            nn.Dropout(dropout),
        ]
        self.main_model = nn.Sequential(*layers)

        self.rest_model = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0)
        )

    def forward(self, img_A, img_B, matching=False):
        """Score ``img_A`` conditioned on ``img_B``.

        Returns ``(scores, features)`` only when ``matching is True``;
        otherwise only the score map.
        """
        features = self.main_model(torch.cat((img_A, img_B), 1))
        scores = self.rest_model(features)
        return (scores, features) if matching is True else scores


## Initialization

In [None]:
# Epochs spent at each progressive step; step k trains at 4 * 2**k pixels.
prog_epochs = [20, 40, 40, 100, 100, parameters['n_epochs']]


def _make_loader(path, image_size, mode, batch_size):
    """Return a shuffled DataLoader over ``path`` with labels resized to ``image_size``."""
    transform_resize = [transforms.Resize((image_size, image_size))]
    print(transform_resize)
    return DataLoader(
        ImageDataset(path, transforms_=transform_resize, mode=mode),
        batch_size=batch_size,
        shuffle=True,
        num_workers=parameters['n_cpu'],
    )


# training dataloader
def get_loader_train(image_size):
    """Training loader with ``parameters['batch_size']`` samples per batch."""
    return _make_loader(parameters["input_path"], image_size, "train", parameters['batch_size'])


# testing dataloader
def get_loader_test(image_size):
    """Validation loader in "test" mode (8 samples per batch, includes city names)."""
    return _make_loader(parameters["test_path"], image_size, "test", 8)


# Tensor type
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# Output size of the discriminator for 128x128 inputs, computed by hand:
# three stride-2 convs (128 -> 64 -> 32 -> 16), then two 4x4 valid convs (16 -> 13 -> 10).
patch = (1, 10, 10)
use_dropout = True

# Initialize generator and discriminator
generator = Generator().to(device)
discriminator = Discriminator().to(device)
# torchsummary builds its dummy input on "cuda" unless told otherwise.
summary(generator, (3, 128, 128), device=device.type)

# Run id (saved_models/CIT-<load_id>) to resume from when start_epoch != 0.
load_id = 161
if parameters['start_epoch'] != 0:
    # Load the checkpoints named for epoch start_epoch - 1; map_location lets
    # weights saved on a GPU load on the current device.
    generator.load_state_dict(torch.load(
        "saved_models/CIT-{}/generator_{}.pth".format(load_id, parameters['start_epoch'] - 1),
        map_location=device))
    discriminator.load_state_dict(torch.load(
        "saved_models/CIT-{}/discriminator_{}.pth".format(load_id, parameters['start_epoch'] - 1),
        map_location=device))
else:
    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=parameters['g_lr'], betas=(parameters['b1'], parameters['b2']))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=parameters['d_lr'], betas=(parameters['b1'], parameters['b2']))

scheduler_G = get_scheduler(optimizer_G, parameters['scheduler'])
scheduler_D = get_scheduler(optimizer_D, parameters['scheduler'])

## Loss functions

In [None]:
# water mask loss function
class MaskLoss(nn.Module):
    """Penalty on predicted values that fall inside the (water) mask.

    ``img`` and ``mask`` are both mapped from [-1, 1] to [0, 1]. The loss is
    the sum over the whole batch of their product, divided by the spatial
    area H*W of ``img``. For 128x128 inputs this is a division by 128*128.
    """
    def __init__(self):
        super(MaskLoss, self).__init__()

    def forward(self, img, mask):
        area = img.shape[-2] * img.shape[-1]
        return torch.sum((img * 0.5 + 0.5) * (mask * 0.5 + 0.5) / area)


# Loss functions
if parameters['gan_loss'] == "mse":
    criterion_GAN = torch.nn.MSELoss(reduction="mean").to(device)
elif parameters['gan_loss'] == "bce":
    criterion_GAN = torch.nn.BCEWithLogitsLoss(reduction="mean").to(device)
else:
    # Fail here rather than with a NameError on criterion_GAN during training.
    raise ValueError("unsupported gan_loss: {!r} (expected 'mse' or 'bce')".format(parameters['gan_loss']))
criterion_pixelwise = torch.nn.L1Loss().to(device)
criterion_water = MaskLoss().to(device)

## Training

In [None]:
##################
#  Training
################
prev_time = time.time()
img_list = []
G_losses = []
D_losses = []
iters = 0
torch.cuda.empty_cache()


def _save_test_outputs(run_id, epoch, step, alpha, loader):
    """Save test-set PNGs and model checkpoints for ``epoch``.

    For every city in ``loader``, the label, generated map, DEM and water mask
    are written under ``images/<run_id>/<epoch>/``. Generator and
    discriminator state dicts are written to ``saved_models/<run_id>/``.
    """
    os.makedirs("images/%s/%d" % (run_id, epoch), exist_ok=True)
    # eval() so BatchNorm neither uses nor updates statistics from test data.
    generator.eval()
    try:
        for test_batch in loader:
            test_input_img = test_batch["input_img"].cuda()
            test_label = 0.5 * test_batch["label"] + 0.5
            test_dem = test_batch["dem"] * 0.5 + 0.5
            test_water = test_batch["mask"] * 0.5 + 0.5
            with torch.no_grad():
                # Pass the current step/alpha explicitly: the generator's
                # defaults (steps=5, alpha=1e-5) would return almost only the
                # up-sampled previous-resolution output.
                fake = 0.5 * generator(test_input_img, step, alpha).detach().cpu() + 0.5
            for k in range(test_label.shape[0]):
                name = test_batch['name'][k]
                vutils.save_image(test_label[k], "./images/%s/%d/%s_%s.png" % (run_id, epoch, name, 'builtarea'))
                vutils.save_image(fake[k], "./images/%s/%d/%s_%s.png" % (run_id, epoch, name, 'output'))
                vutils.save_image(test_dem[k], "./images/%s/%d/%s_%s.png" % (run_id, epoch, name, 'dem'))
                vutils.save_image(test_water[k], "./images/%s/%d/%s_%s.png" % (run_id, epoch, name, 'water'))
    finally:
        generator.train()
    # Save model checkpoints
    torch.save(generator.state_dict(), "saved_models/%s/generator_%d.pth" % (run_id, epoch))
    torch.save(discriminator.state_dict(), "saved_models/%s/discriminator_%d.pth" % (run_id, epoch))


n_steps = int(np.log2(parameters['img_size'] / 2))
step = 0
for num_epochs in prog_epochs:
    alpha = 1e-5  # start with very low alpha
    dataloader = get_loader_train(4 * 2 ** step)
    dataloader_test = get_loader_test(4 * 2 ** step)
    print(f"Current image size: {4 * 2 ** step}")
    # The discriminator and the GAN / water losses are only used at the last
    # step; earlier steps train the generator on the pixel-wise loss alone.
    final_step = step == n_steps - 1

    for epoch in range(num_epochs):
        for i, batch in enumerate(dataloader):
            last_batch = i == len(dataloader) - 1
            # Model input_imgs
            input_img = batch["input_img"].cuda()
            label = batch["label"].cuda()
            mask = batch["mask"].cuda()

            # Adversarial ground truths
            valid = Tensor(np.ones((input_img.size(0), *patch))).cuda()
            fake = Tensor(np.zeros((input_img.size(0), *patch))).cuda()

            fake_img = generator(input_img, step, alpha).cuda()
            # ---------------------
            #  Train Discriminator
            # ---------------------
            if final_step:
                optimizer_D.zero_grad()

                # Real loss
                pred_real = discriminator(label, input_img)
                loss_real = criterion_GAN(pred_real, valid)

                # Fake loss (detached so only the discriminator is updated)
                pred_fake = discriminator(fake_img.detach(), input_img)
                loss_fake = criterion_GAN(pred_fake, fake)

                # Total loss
                loss_D = (loss_real + loss_fake) * 0.5

                loss_D.backward()
                optimizer_D.step()

            # ------------------
            #  Train Generators
            # ------------------
            optimizer_G.zero_grad()

            # Pixel-wise loss
            loss_pixel = criterion_pixelwise(fake_img, label)
            if final_step:
                # GAN loss
                pred_fake = discriminator(fake_img, input_img)
                loss_GAN = criterion_GAN(pred_fake, valid)
                # water constraint loss
                loss_water = criterion_water(fake_img, mask)
                loss_G = loss_GAN + loss_pixel * parameters['pixel_lamda'] + 100 * loss_water
            else:
                loss_G = loss_pixel * parameters['pixel_lamda']
            loss_G.backward()
            optimizer_G.step()

            #===============
            # training debug
            #---------------
            if parameters["debug"] == True:
                with torch.no_grad():
                    # Same step/alpha as training so the output matches the label size.
                    debug_fake = generator(input_img, step, alpha).cuda()
                n = input_img.size(0)  # actual batch size (the last batch may be smaller)
                imgs_error_test = torch.cat(
                    (debug_fake * 0.5 + 0.5, label * 0.5 + 0.5, (debug_fake - label) * 0.5 + 0.5),
                    dim=0).cpu()
                c = 3
                r = 5
                plt.figure(dpi=200, figsize=(c, r), facecolor="white", linewidth=0)
                # Columns: generated, label, error; rows: the first r samples.
                for col in range(c):
                    for row in range(r):
                        plt.subplot(r, c, c * row + col + 1)
                        plt.xticks([])
                        plt.yticks([])
                        plt.imshow(np.transpose(imgs_error_test[n * col + row], (1, 2, 0)))
                plt.subplots_adjust(wspace=0.0, hspace=0.05)
                plt.show()
                plt.close()

            # --------------
            #  Log Progress
            # --------------
            if final_step:
                # Determine approximate time left
                batches_done = epoch * len(dataloader) + i
                batches_left = parameters["n_epochs"] * len(dataloader) - batches_done
                time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
                prev_time = time.time()
                # Print log
                sys.stdout.write(
                    "\r[Epoch %d/%d][Batch %d/%d][min: %f][D loss: %f] [G loss: %f, pixel: %f, GAN: %f, mask: %f] ETA: %s"
                    % (
                        epoch,
                        parameters["n_epochs"],
                        i,
                        len(dataloader),
                        fake_img.detach().min().item(),
                        loss_D.item(),
                        loss_G.item(),
                        loss_pixel.item(),
                        loss_GAN.item(),
                        loss_water.item(),
                        time_left,
                    )
                )
                if neptune_log:
                    run["train/Dloss"].log(loss_D.item())
                    run["train/Gloss"].log(loss_G.item())
                    run["train/pixelloss"].log(loss_pixel.item())
                    run["train/GANloss"].log(loss_GAN.item())
                    run["train/lr"].log(scheduler_G.get_last_lr()[0])
                G_losses.append(loss_G.item())
                D_losses.append(loss_D.item())
                # Show Results
                if ((epoch % 10 == 0) or (epoch == parameters["n_epochs"] - 1)) and last_batch:
                    ShowResults(generator, dataloader_test, step, "train")

                # progress scheduler (stepped once per batch)
                scheduler_G.step()
                scheduler_D.step()
                # Save generated results and checkpoints once per qualifying
                # epoch, on its last batch (not on every batch of that epoch).
                if neptune_log and last_batch and (
                        epoch == parameters["n_epochs"] - 1
                        or epoch % parameters['checkpoint_interval'] == 5):
                    _save_test_outputs(run['sys/id'].fetch(), epoch, step, alpha, dataloader_test)
            else:
                # Print log
                sys.stdout.write(
                    "\r[Step: %i][Epoch %d/%d][Batch %d/%d][min: %f][Loss:%f]"
                    % (
                        step,
                        epoch,
                        prog_epochs[step],
                        i,
                        len(dataloader),
                        fake_img.detach().min().item(),
                        loss_G.item(),
                    )
                )
                if neptune_log:
                    run["train/pixelloss"].log(loss_G.item())
                if ((epoch % 10 == 0) or (epoch == parameters["n_epochs"] - 1)) and last_batch:
                    ShowResults(generator, dataloader_test, step, "train")
            # Linearly ramp the fade-in weight towards 1 over this step's iterations.
            alpha += 1 / (
                prog_epochs[step] * len(dataloader)
            )
            alpha = min(alpha, 1)
    step += 1  # progress to the next img size

In [None]:
# NOTE(review): bare `quit` only evaluates the name; it is never called, so
# this cell does not stop the kernel or the script.
quit