In [8]:
import torch
# !pip install --upgrade albumentations
from utils import save_checkpoint, load_checkpoint, save_some_examples
import torch.nn as nn
import torch.optim as optim
import config
from DataSet.dataset import Satellite2Map_Data
from pix2pix.Generator import Generator
from pix2pix.Discriminator import Discriminator
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm


torch.backends.cudnn.benchmark = True
Gen_loss = []
Dis_loss = []

In [2]:
def train(netG: Generator, netD: Discriminator, train_dl, OptimizerG: optim.Adam, OptimizerD: optim.Adam, gen_loss, dis_loss, step_ahead = 0):
    loop = tqdm(train_dl, dynamic_ncols=True)
    for idx, (x, y) in enumerate(loop):
        x = x.to(config.DEVICE).float()
        y = y.to(config.DEVICE).float()
        
        # Train Discriminator
        y_fake = netG(x).float()
        d_real = netD(x, y)
        d_real_loss = dis_loss(d_real, torch.ones_like(d_real))
        d_fake = netD(x, y_fake.detach())
        d_fake_loss = dis_loss(d_fake, torch.zeros_like(d_fake))
        d_loss = (d_real_loss + d_fake_loss) / 2

        netD.zero_grad()
        Dis_loss.append(d_loss.item())
        d_loss.backward()
        OptimizerD.step()

        # Train Generator
        d_fake = netD(x, y_fake)
        g_fake_loss = gen_loss(d_fake, torch.ones_like(d_fake))
        loss = gen_loss(y_fake, y)  # * config.L1_LAMBDA
        g_loss = (g_fake_loss + loss) / 2

        Gen_loss.append(g_loss.item())
        g_loss.backward()
        OptimizerG.step()

        for _ in range(step_ahead):
            OptimizerG.zero_grad()
            y_fake = netG(x).float()
            d_fake = netD(x, y_fake)
            g_fake_loss = gen_loss(d_fake, torch.ones_like(d_fake))
            loss = gen_loss(y_fake, y)  # * config.L1_LAMBDA

            g_loss = (g_fake_loss + loss) / 2
            Gen_loss.append(g_loss.item())
            g_loss.backward()
            OptimizerG.step()

        if idx % 10 == 0:
            loop.set_postfix(
                d_real=torch.sigmoid(d_real).mean().item(),
                d_fake=torch.sigmoid(d_fake).mean().item(),
            )
    print("d_real: " + str(torch.sigmoid(d_real).mean().item()))
    print("d_fake: " + str(torch.sigmoid(d_fake).mean().item()))

def validate(netG, netD, val_dl, gen_loss, dis_loss, 
            scheduler_G, scheduler_D):
    # Calculate validation loss
        val_gen_loss = 0.0
        val_dis_loss = 0.0
        with torch.no_grad():
            for val_x, val_y in val_dl:
                val_x = val_x.to(config.DEVICE).float()
                val_y = val_y.to(config.DEVICE).float()

                # Calculate validation losses (similar to training loop)
                val_y_fake = netG(val_x).float()
                val_d_real = netD(val_x, val_y)
                val_d_fake = netD(val_x, val_y_fake)

                val_d_real_loss = dis_loss(val_d_real, torch.ones_like(val_d_real))
                val_d_fake_loss = dis_loss(val_d_fake, torch.zeros_like(val_d_fake))
                val_d_loss = (val_d_real_loss + val_d_fake_loss) / 2

                val_g_fake_loss = gen_loss(val_d_fake, torch.ones_like(val_d_fake))
                val_loss = gen_loss(val_y_fake, val_y)
                val_g_loss = (val_g_fake_loss + val_loss) / 2

                val_gen_loss += val_g_loss.item()
                val_dis_loss += val_d_loss.item()

            val_gen_loss /= len(val_dl)
            val_dis_loss /= len(val_dl)

        # Update schedulers based on validation loss
        scheduler_G.step(val_gen_loss)
        print("Current LR for Generator:", scheduler_G.get_last_lr()[0]) 
        
        scheduler_D.step(val_dis_loss)
        print("Current LR for Discriminator:", scheduler_D.get_last_lr()[0]) 

In [None]:
def main():
    start = 231
    rgb_on = False
    channels = 3 if rgb_on else 1

    netD = Discriminator(in_channels=channels).to(config.DEVICE)
    netG = Generator(in_channels=channels).to(config.DEVICE)
    optimizerD = torch.optim.Adam(
        netD.parameters(), lr=config.LEARNING_RATE_DISC, betas=(config.BETA1, 0.999)
    )
    optimizerG = torch.optim.Adam(
        netG.parameters(), lr=config.LEARNING_RATE_GEN, betas=(config.BETA1, 0.999)
    )
    gen_loss = nn.MSELoss()
    dis_loss = nn.MSELoss()

    
    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_GEN, start-1, netG, optimizerG, config.LEARNING_RATE_GEN
        )
        load_checkpoint(
            config.CHECKPOINT_DISC, start-1, netD, optimizerD, config.LEARNING_RATE_DISC
        )
    
    train_dataset = Satellite2Map_Data(root=config.TRAIN_DIR, rgb_on=rgb_on)
    train_dl = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=True, drop_last=True)
    val_dataset = Satellite2Map_Data(root=config.VAL_DIR, rgb_on=rgb_on)
    val_dl = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=True, drop_last=True)
    
    for epoch in range(start, config.NUM_EPOCHS):
        train(
            netG, netD,train_dl,optimizerG,optimizerD,gen_loss,dis_loss, step_ahead=2
        )
        if config.SAVE_MODEL and epoch % 10 == 0 and epoch > 0:
            save_checkpoint(netG, optimizerG, epoch, Gen_loss[-1], Dis_loss[-1], filename= f"./checkpoints/{epoch}_{config.CHECKPOINT_GEN}")
            save_checkpoint(netD, optimizerD, epoch, Gen_loss[-1], Dis_loss[-1], filename= f"./checkpoints/{epoch}_{config.CHECKPOINT_DISC}")
        if epoch % 10 == 0 :
            print("save example")
            try:
                save_some_examples(netG,val_dl,epoch,folder="evaluation")
            except Exception as e:
                print(f"Something went wrong saving with epoch {epoch}: {e}")

        print("Epoch :",epoch, " Gen Loss :",Gen_loss[-1], "Disc Loss :",Dis_loss[-1])
    save_checkpoint(netG, optimizerG, "final", Gen_loss[-1], Dis_loss[-1], filename= f"./checkpoints/{epoch}_{config.CHECKPOINT_GEN}")
    save_checkpoint(netD, optimizerD, "final", Gen_loss[-1], Dis_loss[-1], filename= f"./checkpoints/{epoch}_{config.CHECKPOINT_DISC}")
    print("save example")
    try:
        save_some_examples(netG,val_dl,config.NUM_EPOCHS,folder="evaluation")
    except Exception as e:
        print(f"Something went wrong with the last epoch")


if __name__ == '__main__':
    main()

In [4]:
import os
from DataSet.dataset import Satellite2Map_Data
from torch.utils.data import DataLoader
import config

def get_folder_names(folder_path):
  folder_names = []
  for item in os.listdir(folder_path):
    item_path = os.path.join(folder_path, item)
    if os.path.isdir(item_path):
      folder_names.append(item)
  return folder_names

folder_path = "./versions (Gen-Disc)"
folder_names = get_folder_names(folder_path)

val_dataset = Satellite2Map_Data(root=config.VAL_DIR, rgb_on=False)
val_dl = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=True, drop_last=True)

tests = []

for i, batch in enumerate(val_dl):
    if i >= 1:  # Stop after collecting 1 batch
        break
    tests = batch

In [7]:

from utils import save_matrix
import config
import numpy as np
import torch
from pix2pix.Generator import Generator


def load_model(model,folder_path, lr):
  checkpoint = torch.load(
    f"{folder_path}/399_{config.CHECKPOINT_GEN}", map_location=config.DEVICE
  )
  model.load_state_dict(checkpoint["state_dict"])


def calculate_rmse(matrix1, matrix2):

  # Ensure matrices have the same shape
  if matrix1.shape != matrix2.shape:
    raise ValueError("Matrices must have the same shape.")

  squared_differences = (matrix1 - matrix2) ** 2
  mean_squared_error = np.mean(squared_differences)
  rmse = np.sqrt(mean_squared_error)

  return rmse

import matplotlib.pyplot as plt
model_error = []
for folder in folder_names:
    netG = Generator(in_channels=1).to(config.DEVICE)
    
    load_model(netG,f"{folder_path}/{folder}",config.LEARNING_RATE_GEN)

    errors = []

    for i in range(len(tests[0])):
      x, y = tests[0][i], tests[1][i]
      x, y = x.to(config.DEVICE).float().unsqueeze(1), y.to(config.DEVICE).float().unsqueeze(1)
      netG.eval()
      with torch.no_grad():
          y_fake = netG(x)
      y_fake = y_fake.cpu().detach().numpy().squeeze()
      y = y.cpu().detach().numpy().squeeze()
      errors.append(calculate_rmse(y,y_fake))
    model_error.append(np.mean(errors))

indexed_list = list(enumerate(model_error))
sorted_list = sorted(indexed_list, key= lambda x: x[1])

for tuple in sorted_list:
  print(f"{folder_names[tuple[0]]} with a rmse of {tuple[1]}")


L1 with MSE with a rmse of 0.07165253162384033
L1 with L1 with a rmse of 0.10108126699924469
MSE with L1 with a rmse of 0.10357613116502762
MSE with MSE with a rmse of 0.11847350746393204
L1 with BCEwithLogits with a rmse of 0.14302858710289001
MSE with BCEwithLogits with a rmse of 0.2506081461906433


In [9]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
def main():
    start = 51
    rgb_on = False
    channels = 3 if rgb_on else 1

    # batch_size = [16,32]

    netD = Discriminator(in_channels=channels).to(config.DEVICE)
    netG = Generator(in_channels=channels).to(config.DEVICE)
    
    optimizerD = torch.optim.Adam(
        netD.parameters(), lr=config.LEARNING_RATE_DISC, betas=(config.BETA1, 0.999)
    )
    optimizerG = torch.optim.Adam(
        netG.parameters(), lr=config.LEARNING_RATE_GEN, betas=(config.BETA1, 0.999)
    )


    gen_loss = nn.L1Loss()
    dis_loss = nn.MSELoss()

    
    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_GEN, start-1, netG, optimizerG, config.LEARNING_RATE_GEN
        )
        load_checkpoint(
            config.CHECKPOINT_DISC, start-1, netD, optimizerD, config.LEARNING_RATE_DISC
        )
    
    train_dataset = Satellite2Map_Data(root=config.TRAIN_DIR, rgb_on=rgb_on)
    train_dl = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=True, drop_last=True)
    val_dataset = Satellite2Map_Data(root=config.VAL_DIR, rgb_on=rgb_on)
    val_dl = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=True, drop_last=True)
    
    for epoch in range(start, config.NUM_EPOCHS):
        
        train(netG, netD, train_dl, optimizerG, optimizerD, gen_loss, dis_loss, step_ahead=2)
        
        # validate(netG, netD, val_dl, gen_loss, dis_loss, scheduler_G, scheduler_D)
        

        if config.SAVE_MODEL and epoch % 10 == 0 and epoch > 0:
            save_checkpoint(netG, optimizerG, epoch, Gen_loss[-1], Dis_loss[-1], filename= f"./checkpoints/{epoch}_{config.CHECKPOINT_GEN}")
            save_checkpoint(netD, optimizerD, epoch, Gen_loss[-1], Dis_loss[-1], filename= f"./checkpoints/{epoch}_{config.CHECKPOINT_DISC}")
        if epoch % 10 == 0 :
            print("save example")
            try:
                save_some_examples(netG,val_dl,epoch,folder="evaluation")
            except Exception as e:
                print(f"Something went wrong saving with epoch {epoch}: {e}")

        print("Epoch :",epoch, " Gen Loss :",Gen_loss[-1], "Disc Loss :",Dis_loss[-1])
    save_checkpoint(netG, optimizerG, "final", Gen_loss[-1], Dis_loss[-1], filename= f"./checkpoints/{epoch}_{config.CHECKPOINT_GEN}")
    save_checkpoint(netD, optimizerD, "final", Gen_loss[-1], Dis_loss[-1], filename= f"./checkpoints/{epoch}_{config.CHECKPOINT_DISC}")
    print("save example")
    try:
        save_some_examples(netG,val_dl,config.NUM_EPOCHS,folder="evaluation")
    except Exception as e:
        print(f"Something went wrong with the last epoch")


if __name__ == '__main__':
    main()

  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6235193014144897
d_fake: 0.6157655119895935
Epoch : 51  Gen Loss : 0.47037339210510254 Disc Loss : 0.25937414169311523


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6289966106414795
d_fake: 0.6244072318077087
Epoch : 52  Gen Loss : 0.4208671450614929 Disc Loss : 0.2508722245693207


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6263059377670288
d_fake: 0.6230446696281433
Epoch : 53  Gen Loss : 0.48991918563842773 Disc Loss : 0.23657044768333435


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6226921081542969
d_fake: 0.6250010132789612
Epoch : 54  Gen Loss : 0.44930577278137207 Disc Loss : 0.2575656771659851


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6168069839477539
d_fake: 0.6128154397010803
Epoch : 55  Gen Loss : 0.4556320011615753 Disc Loss : 0.24820345640182495


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6233754754066467
d_fake: 0.6158583760261536
Epoch : 56  Gen Loss : 0.47338977456092834 Disc Loss : 0.2506832480430603


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6157385110855103
d_fake: 0.6082164645195007
Epoch : 57  Gen Loss : 0.5085278749465942 Disc Loss : 0.24061518907546997


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.637104868888855
d_fake: 0.6235648989677429
Epoch : 58  Gen Loss : 0.46954989433288574 Disc Loss : 0.24667340517044067


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6227491497993469
d_fake: 0.6086909174919128
Epoch : 59  Gen Loss : 0.4928242862224579 Disc Loss : 0.24077382683753967


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6272919178009033
d_fake: 0.5974645018577576
=> Saving checkpoint
=> Saving checkpoint
save example
Epoch : 60  Gen Loss : 0.5482182502746582 Disc Loss : 0.22387485206127167


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6452912092208862
d_fake: 0.6154126524925232
Epoch : 61  Gen Loss : 0.49925050139427185 Disc Loss : 0.23716309666633606


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6130025386810303
d_fake: 0.6062058806419373
Epoch : 62  Gen Loss : 0.5319962501525879 Disc Loss : 0.2579275071620941


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6184369921684265
d_fake: 0.6079829335212708
Epoch : 63  Gen Loss : 0.5026132464408875 Disc Loss : 0.23722630739212036


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6414462924003601
d_fake: 0.6283531785011292
Epoch : 64  Gen Loss : 0.44599205255508423 Disc Loss : 0.22865648567676544


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6378320455551147
d_fake: 0.6331707835197449
Epoch : 65  Gen Loss : 0.4804353415966034 Disc Loss : 0.2106930911540985


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6393111348152161
d_fake: 0.6200671195983887
Epoch : 66  Gen Loss : 0.4398031234741211 Disc Loss : 0.23644328117370605


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6429591178894043
d_fake: 0.609325110912323
Epoch : 67  Gen Loss : 0.4490903317928314 Disc Loss : 0.249503493309021


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6566177606582642
d_fake: 0.6123566031455994
Epoch : 68  Gen Loss : 0.49331632256507874 Disc Loss : 0.19680723547935486


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6400961875915527
d_fake: 0.5671025514602661
Epoch : 69  Gen Loss : 0.5980520844459534 Disc Loss : 0.19901831448078156


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6303287744522095
d_fake: 0.596813976764679
=> Saving checkpoint
=> Saving checkpoint
save example
Epoch : 70  Gen Loss : 0.5263544321060181 Disc Loss : 0.24657008051872253


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6452862620353699
d_fake: 0.5595701932907104
Epoch : 71  Gen Loss : 0.5754839777946472 Disc Loss : 0.21508651971817017


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6176212430000305
d_fake: 0.6079372763633728
Epoch : 72  Gen Loss : 0.4803730845451355 Disc Loss : 0.23595482110977173


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6390725374221802
d_fake: 0.5875298976898193
Epoch : 73  Gen Loss : 0.5264860987663269 Disc Loss : 0.13207542896270752


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6111326217651367
d_fake: 0.5400695204734802
Epoch : 74  Gen Loss : 0.6034974455833435 Disc Loss : 0.23688337206840515


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6313586235046387
d_fake: 0.6393184065818787
Epoch : 75  Gen Loss : 0.46379154920578003 Disc Loss : 0.14888378977775574


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6588175296783447
d_fake: 0.6293607354164124
Epoch : 76  Gen Loss : 0.4401484727859497 Disc Loss : 0.1173267811536789


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.649379551410675
d_fake: 0.5854743719100952
Epoch : 77  Gen Loss : 0.5180613994598389 Disc Loss : 0.12502089142799377


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6137650609016418
d_fake: 0.6475090980529785
Epoch : 78  Gen Loss : 0.37969714403152466 Disc Loss : 0.18967154622077942


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6649156212806702
d_fake: 0.5671074390411377
Epoch : 79  Gen Loss : 0.561975359916687 Disc Loss : 0.06700417399406433


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6487431526184082
d_fake: 0.61783367395401
=> Saving checkpoint
=> Saving checkpoint
save example
Epoch : 80  Gen Loss : 0.43846842646598816 Disc Loss : 0.20701012015342712


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6247807741165161
d_fake: 0.5593615174293518
Epoch : 81  Gen Loss : 0.5498307943344116 Disc Loss : 0.24747173488140106


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6036369800567627
d_fake: 0.5855877995491028
Epoch : 82  Gen Loss : 0.515152096748352 Disc Loss : 0.37225961685180664


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6222595572471619
d_fake: 0.5701043009757996
Epoch : 83  Gen Loss : 0.5349119901657104 Disc Loss : 0.24051982164382935


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.663594663143158
d_fake: 0.5518708825111389
Epoch : 84  Gen Loss : 0.5675978660583496 Disc Loss : 0.07537961006164551


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6126070022583008
d_fake: 0.5804324746131897
Epoch : 85  Gen Loss : 0.5203972458839417 Disc Loss : 0.24057844281196594


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6663045883178711
d_fake: 0.5884319543838501
Epoch : 86  Gen Loss : 0.47993576526641846 Disc Loss : 0.2278970181941986


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.5867527723312378
d_fake: 0.6238021850585938
Epoch : 87  Gen Loss : 0.42462021112442017 Disc Loss : 0.26057472825050354


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.7101292610168457
d_fake: 0.6010503768920898
Epoch : 88  Gen Loss : 0.4354051649570465 Disc Loss : 0.3425627648830414


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.624064564704895
d_fake: 0.5754176378250122
Epoch : 89  Gen Loss : 0.4716932773590088 Disc Loss : 0.253073126077652


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.632189929485321
d_fake: 0.5880082845687866
=> Saving checkpoint
=> Saving checkpoint
save example
Epoch : 90  Gen Loss : 0.4648759961128235 Disc Loss : 0.27512142062187195


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6224684119224548
d_fake: 0.5196035504341125
Epoch : 91  Gen Loss : 0.5920268297195435 Disc Loss : 0.2032679319381714


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.7012205719947815
d_fake: 0.5593804717063904
Epoch : 92  Gen Loss : 0.5315811038017273 Disc Loss : 0.20180678367614746


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6308937668800354
d_fake: 0.5833318829536438
Epoch : 93  Gen Loss : 0.45169344544410706 Disc Loss : 0.24326376616954803


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6684141755104065
d_fake: 0.5953670740127563
Epoch : 94  Gen Loss : 0.4105716049671173 Disc Loss : 0.23050625622272491


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.7010435461997986
d_fake: 0.5728591084480286
Epoch : 95  Gen Loss : 0.4857684373855591 Disc Loss : 0.2015601098537445


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6469628810882568
d_fake: 0.5599107146263123
Epoch : 96  Gen Loss : 0.5169044733047485 Disc Loss : 0.08755996078252792


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6555792689323425
d_fake: 0.5575547814369202
Epoch : 97  Gen Loss : 0.5038688778877258 Disc Loss : 0.18372982740402222


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.631230354309082
d_fake: 0.5256814956665039
Epoch : 98  Gen Loss : 0.5760529041290283 Disc Loss : 0.18224337697029114


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6033867597579956
d_fake: 0.6454627513885498
Epoch : 99  Gen Loss : 0.3008301854133606 Disc Loss : 0.2343977838754654


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6193440556526184
d_fake: 0.6269538402557373
=> Saving checkpoint
=> Saving checkpoint
save example
Epoch : 100  Gen Loss : 0.3560456335544586 Disc Loss : 0.25950542092323303


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.639859139919281
d_fake: 0.6424740552902222
Epoch : 101  Gen Loss : 0.3144470453262329 Disc Loss : 0.13275745511054993


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6659231185913086
d_fake: 0.6284561157226562
Epoch : 102  Gen Loss : 0.3398282825946808 Disc Loss : 0.1887303739786148


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6439167261123657
d_fake: 0.554067075252533
Epoch : 103  Gen Loss : 0.5064145922660828 Disc Loss : 0.10070352256298065


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.7009296417236328
d_fake: 0.6378844976425171
Epoch : 104  Gen Loss : 0.3202197253704071 Disc Loss : 0.13161110877990723


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.61845862865448
d_fake: 0.623954176902771
Epoch : 105  Gen Loss : 0.34605127573013306 Disc Loss : 0.2557285726070404


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6273531913757324
d_fake: 0.6226269006729126
Epoch : 106  Gen Loss : 0.3480178713798523 Disc Loss : 0.25098228454589844


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6230254173278809
d_fake: 0.6046208143234253
Epoch : 107  Gen Loss : 0.38614434003829956 Disc Loss : 0.23954111337661743


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6362920999526978
d_fake: 0.5744204521179199
Epoch : 108  Gen Loss : 0.44897156953811646 Disc Loss : 0.23789429664611816


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.615638792514801
d_fake: 0.6229678988456726
Epoch : 109  Gen Loss : 0.33677491545677185 Disc Loss : 0.2527112066745758


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6725109815597534
d_fake: 0.6324481964111328
=> Saving checkpoint
=> Saving checkpoint
save example
Epoch : 110  Gen Loss : 0.33412280678749084 Disc Loss : 0.20800499618053436


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6590949296951294
d_fake: 0.6179581880569458
Epoch : 111  Gen Loss : 0.34493908286094666 Disc Loss : 0.1321454644203186


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.661935031414032
d_fake: 0.6332874298095703
Epoch : 112  Gen Loss : 0.3165108561515808 Disc Loss : 0.2512596845626831


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6387253403663635
d_fake: 0.540273129940033
Epoch : 113  Gen Loss : 0.5083649754524231 Disc Loss : 0.11046530306339264


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6385794878005981
d_fake: 0.6111198663711548
Epoch : 114  Gen Loss : 0.36752402782440186 Disc Loss : 0.28056448698043823


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.63204425573349
d_fake: 0.4581799805164337
Epoch : 115  Gen Loss : 0.6744469404220581 Disc Loss : 0.24915434420108795


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6129060983657837
d_fake: 0.602178156375885
Epoch : 116  Gen Loss : 0.3829113841056824 Disc Loss : 0.21276192367076874


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.643004834651947
d_fake: 0.5813228487968445
Epoch : 117  Gen Loss : 0.4252481460571289 Disc Loss : 0.23561668395996094


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.7052788734436035
d_fake: 0.5482174754142761
Epoch : 118  Gen Loss : 0.49259257316589355 Disc Loss : 0.031774312257766724


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.62336665391922
d_fake: 0.557591438293457
Epoch : 119  Gen Loss : 0.4657444953918457 Disc Loss : 0.27694299817085266


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6436311602592468
d_fake: 0.575439453125
=> Saving checkpoint
=> Saving checkpoint
save example
Epoch : 120  Gen Loss : 0.4368552565574646 Disc Loss : 0.24147655069828033


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.5804324746131897
d_fake: 0.4909334182739258
Epoch : 121  Gen Loss : 0.5989611148834229 Disc Loss : 0.3229075074195862


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6398410797119141
d_fake: 0.5466489195823669
Epoch : 122  Gen Loss : 0.4934775233268738 Disc Loss : 0.25247612595558167


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6397034525871277
d_fake: 0.484725683927536
Epoch : 123  Gen Loss : 0.623105525970459 Disc Loss : 0.12744764983654022


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.66544109582901
d_fake: 0.5404675006866455
Epoch : 124  Gen Loss : 0.505303680896759 Disc Loss : 0.18235577642917633


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6359036564826965
d_fake: 0.588605523109436
Epoch : 125  Gen Loss : 0.39600670337677 Disc Loss : 0.12499891966581345


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6989916563034058
d_fake: 0.621171772480011
Epoch : 126  Gen Loss : 0.33693063259124756 Disc Loss : 0.11801183223724365


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6234903335571289
d_fake: 0.634121835231781
Epoch : 127  Gen Loss : 0.3098459839820862 Disc Loss : 0.17196601629257202


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6315714120864868
d_fake: 0.6244916915893555
Epoch : 128  Gen Loss : 0.32678842544555664 Disc Loss : 0.17588654160499573


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.5887383222579956
d_fake: 0.47242915630340576
Epoch : 129  Gen Loss : 0.6445836424827576 Disc Loss : 0.33710601925849915


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6043174266815186
d_fake: 0.6287502646446228
=> Saving checkpoint
=> Saving checkpoint
save example
Epoch : 130  Gen Loss : 0.31569215655326843 Disc Loss : 0.21784235537052155


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6762218475341797
d_fake: 0.5721760392189026
Epoch : 131  Gen Loss : 0.4377667307853699 Disc Loss : 0.1548035740852356


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6664565801620483
d_fake: 0.5844076871871948
Epoch : 132  Gen Loss : 0.4170745611190796 Disc Loss : 0.1319301575422287


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6745532155036926
d_fake: 0.5240064263343811
Epoch : 133  Gen Loss : 0.5258451700210571 Disc Loss : 0.2519257366657257


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6514376997947693
d_fake: 0.5396255850791931
Epoch : 134  Gen Loss : 0.4924446642398834 Disc Loss : 0.21345248818397522


  0%|          | 0/51 [00:00<?, ?it/s]

d_real: 0.6421248912811279
d_fake: 0.5118874907493591
Epoch : 135  Gen Loss : 0.5479847192764282 Disc Loss : 0.13464485108852386


  0%|          | 0/51 [00:00<?, ?it/s]

In [None]:
# Sharpening the image
import cv2
import numpy as np
from skimage import filters

# Assuming your GAN-generated elevation image is loaded as a NumPy array
elevation_image = # ... your loading code here ...

# --- Method 1: Classic Unsharp Masking ---
# Highly effective for general sharpening
gaussian_blurred = cv2.GaussianBlur(elevation_image, (5, 5), 0) 
sharpened_image = cv2.addWeighted(elevation_image, 1.5, gaussian_blurred, -0.5, 0)

# --- Method 2: Laplacian Sharpening ---
# Emphasizes edges, might be too strong for subtle elevation changes
laplacian = cv2.Laplacian(elevation_image, cv2.CV_64F)
sharpened_image = elevation_image - 0.2 * laplacian 

# --- Method 3: Using Scikit-Image for Unsharp Masking ---
# Offers more control over the sharpening parameters
sharpened_image = filters.unsharp_mask(elevation_image, radius=3, amount=1.0)

# --- Display or Save the Result ---
cv2.imshow("Sharpened Elevation", sharpened_image)
cv2.waitKey(0)
cv2.destroyAllWindows()

# Or save the image:
cv2.imwrite("sharpened_elevation.png", sharpened_image) 