In [None]:
import torch
# !pip install --upgrade albumentations
from utils import save_checkpoint, load_checkpoint, save_some_examples
import torch.nn as nn
import torch.optim as optim
import config
from DataSet.dataset import Satellite2Map_Data
from pix2pix.Generator import Generator
from pix2pix.Discriminator import Discriminator
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm


# Let the cuDNN autotuner pick the fastest convolution kernels; this pays off when
# input shapes stay constant between iterations (the loaders use drop_last=True).
torch.backends.cudnn.benchmark = True

# Module-level loss histories; train() appends one entry per optimizer update.
Gen_loss, Dis_loss = [], []

In [None]:
def train(netG: Generator, netD: Discriminator, train_dl, OptimizerG: optim.Adam, OptimizerD: optim.Adam, gen_loss, dis_loss, scheduler_G=None, scheduler_D=None, step_ahead = 0):
    """Run one epoch of adversarial training over ``train_dl``.

    Per batch: one discriminator update, then ``1 + step_ahead`` generator
    updates on the same batch. Every update's loss value is appended to the
    module-level ``Dis_loss`` / ``Gen_loss`` lists.

    Args:
        netG: generator, called as ``netG(x)``.
        netD: conditional discriminator, called as ``netD(x, y)``.
        train_dl: iterable of ``(x, y)`` batches (input, target).
        OptimizerG: optimizer over ``netG``'s parameters.
        OptimizerD: optimizer over ``netD``'s parameters.
        gen_loss: criterion used for BOTH generator terms: the adversarial term
            ``gen_loss(netD(x, y_fake), ones)`` and the reconstruction term
            ``gen_loss(y_fake, y)``.
        dis_loss: criterion for the discriminator's real (target 1) and fake
            (target 0) terms.
        scheduler_G, scheduler_D: optional objects with a ``step(metric)``
            method (e.g. ``ReduceLROnPlateau``). Each is stepped ONCE at the end
            of the epoch with the epoch mean of the per-batch final generator /
            discriminator loss. ``None`` (default) disables scheduling.
        step_ahead: number of extra generator updates per batch (default 0).
    """
    loop = tqdm(train_dl, dynamic_ncols=True)
    # Per-batch losses, averaged at the end of the epoch for the schedulers.
    batch_g_losses = []
    batch_d_losses = []
    for idx, (x, y) in enumerate(loop):
        x = x.to(config.DEVICE).float()
        y = y.to(config.DEVICE).float()

        # --- Train Discriminator ---
        y_fake = netG(x).float()
        d_real = netD(x, y)
        d_real_loss = dis_loss(d_real, torch.ones_like(d_real))
        # Detached so the discriminator update does not backpropagate into netG.
        d_fake = netD(x, y_fake.detach())
        d_fake_loss = dis_loss(d_fake, torch.zeros_like(d_fake))
        d_loss = (d_real_loss + d_fake_loss) / 2

        netD.zero_grad()
        Dis_loss.append(d_loss.item())
        d_loss.backward()
        OptimizerD.step()

        # --- Train Generator: 1 + step_ahead updates on this batch ---
        for step in range(1 + step_ahead):
            # Zero before EVERY update, including the first one; otherwise the
            # gradients left in netG by the previous batch's last update would be
            # added into this step.
            OptimizerG.zero_grad()
            if step > 0:
                # Regenerate with the weights produced by the previous update.
                y_fake = netG(x).float()
            d_fake = netD(x, y_fake)
            # NOTE(review): the adversarial term uses gen_loss, not dis_loss. With
            # gen_loss=L1Loss and dis_loss=BCEWithLogitsLoss this is an L1 distance
            # between discriminator outputs and 1 - confirm this is intended.
            g_fake_loss = gen_loss(d_fake, torch.ones_like(d_fake))
            loss = gen_loss(y_fake, y)  # * config.L1_LAMBDA
            g_loss = (g_fake_loss + loss) / 2

            Gen_loss.append(g_loss.item())
            g_loss.backward()
            OptimizerG.step()

        batch_g_losses.append(g_loss.item())
        batch_d_losses.append(d_loss.item())

        if idx % 10 == 0:
            loop.set_postfix(
                d_real=torch.sigmoid(d_real).mean().item(),
                d_fake=torch.sigmoid(d_fake).mean().item(),
            )

    if not batch_g_losses:
        # Empty loader: nothing was trained, so nothing to schedule or report.
        return

    # Step plateau-style schedulers once per epoch on the epoch-mean loss.
    # Stepping per batch on noisy single-batch losses could use up `patience`
    # within a few batches and repeatedly cut the learning rate.
    if scheduler_G is not None:
        scheduler_G.step(sum(batch_g_losses) / len(batch_g_losses))
    if scheduler_D is not None:
        scheduler_D.step(sum(batch_d_losses) / len(batch_d_losses))

    print("d_real: " + str(torch.sigmoid(d_real).mean().item()))
    print("d_fake: " + str(torch.sigmoid(d_fake).mean().item()))

In [None]:
def main():
    """Resume training at epoch ``start`` (231) with MSE losses for G and D.

    Builds single-channel (``rgb_on=False``) networks and Adam optimizers from
    ``config``. When ``config.LOAD_MODEL`` is set, it passes ``start - 1`` to
    ``load_checkpoint`` (presumably the checkpoint epoch - confirm in utils).
    It then trains up to ``config.NUM_EPOCHS`` with ``step_ahead=2`` and no
    learning-rate scheduling. Every 10th epoch it saves checkpoints (if
    ``config.SAVE_MODEL``) and example outputs. After the loop it saves a final
    checkpoint pair and examples.

    Note: a later cell redefines ``main``.
    """
    import types  # local: only needed for the no-op scheduler below

    start = 231
    rgb_on = False
    channels = 3 if rgb_on else 1

    netD = Discriminator(in_channels=channels).to(config.DEVICE)
    netG = Generator(in_channels=channels).to(config.DEVICE)
    optimizerD = torch.optim.Adam(
        netD.parameters(), lr=config.LEARNING_RATE_DISC, betas=(config.BETA1, 0.999)
    )
    optimizerG = torch.optim.Adam(
        netG.parameters(), lr=config.LEARNING_RATE_GEN, betas=(config.BETA1, 0.999)
    )
    gen_loss = nn.MSELoss()
    dis_loss = nn.MSELoss()

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_GEN, start-1, netG, optimizerG, config.LEARNING_RATE_GEN
        )
        load_checkpoint(
            config.CHECKPOINT_DISC, start-1, netD, optimizerD, config.LEARNING_RATE_DISC
        )

    train_dataset = Satellite2Map_Data(root=config.TRAIN_DIR, rgb_on=rgb_on)
    train_dl = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=True, drop_last=True)
    val_dataset = Satellite2Map_Data(root=config.VAL_DIR, rgb_on=rgb_on)
    val_dl = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=True, drop_last=True)

    # train() takes scheduler_G / scheduler_D. This run does not schedule the LR,
    # so pass an object whose step() does nothing. Calling train() without them
    # raised a TypeError.
    no_scheduler = types.SimpleNamespace(step=lambda *args, **kwargs: None)

    last_epoch = None
    for epoch in range(start, config.NUM_EPOCHS):
        train(
            netG, netD, train_dl, optimizerG, optimizerD, gen_loss, dis_loss,
            no_scheduler, no_scheduler, step_ahead=2
        )
        last_epoch = epoch
        if config.SAVE_MODEL and epoch % 10 == 0 and epoch > 0:
            save_checkpoint(netG, optimizerG, epoch, Gen_loss[-1], Dis_loss[-1], filename= f"./checkpoints/{epoch}_{config.CHECKPOINT_GEN}")
            save_checkpoint(netD, optimizerD, epoch, Gen_loss[-1], Dis_loss[-1], filename= f"./checkpoints/{epoch}_{config.CHECKPOINT_DISC}")
        if epoch % 10 == 0 :
            print("save example")
            try:
                save_some_examples(netG,val_dl,epoch,folder="evaluation")
            except Exception as e:
                print(f"Something went wrong saving with epoch {epoch}: {e}")

        print("Epoch :",epoch, " Gen Loss :",Gen_loss[-1], "Disc Loss :",Dis_loss[-1])

    if last_epoch is None:
        # The loop never ran: there is no trained state or loss history to save.
        print(f"Nothing to train: start={start} >= NUM_EPOCHS={config.NUM_EPOCHS}")
        return

    save_checkpoint(netG, optimizerG, "final", Gen_loss[-1], Dis_loss[-1], filename= f"./checkpoints/{last_epoch}_{config.CHECKPOINT_GEN}")
    save_checkpoint(netD, optimizerD, "final", Gen_loss[-1], Dis_loss[-1], filename= f"./checkpoints/{last_epoch}_{config.CHECKPOINT_DISC}")
    print("save example")
    try:
        save_some_examples(netG,val_dl,config.NUM_EPOCHS,folder="evaluation")
    except Exception as e:
        print(f"Something went wrong with the last epoch: {e}")


if __name__ == '__main__':
    main()

In [None]:
import os
from DataSet.dataset import Satellite2Map_Data
from torch.utils.data import DataLoader
import config

def get_folder_names(folder_path):
  """Return the names (not full paths) of the immediate subdirectories of ``folder_path``.

  Plain files are skipped. Order follows ``os.listdir`` (filesystem-dependent).
  """
  return [
    entry
    for entry in os.listdir(folder_path)
    if os.path.isdir(os.path.join(folder_path, entry))
  ]

# Directory with one sub-folder per trained variant; each is expected to hold a
# generator checkpoint (loaded by load_model below).
folder_path = "./versions (Gen-Disc)"
folder_names = get_folder_names(folder_path)

# shuffle=False so every run scores the models on the same validation samples.
# With shuffle=True and no seed, the model ranking changed from run to run.
val_dataset = Satellite2Map_Data(root=config.VAL_DIR, rgb_on=False)
val_dl = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=config.NUM_WORKERS, pin_memory=True, drop_last=True)

# Evaluate on a single batch; `tests` stays [] if the loader yields nothing.
tests = next(iter(val_dl), [])

In [13]:

from utils import save_matrix
import config
import numpy as np
import torch
from pix2pix.Generator import Generator





def load_model(model,folder_path, lr, epoch=399):
  """Load generator weights from ``{folder_path}/{epoch}_{config.CHECKPOINT_GEN}`` into ``model``.

  Args:
    model: module whose ``load_state_dict`` receives ``checkpoint["state_dict"]``.
    folder_path: directory containing the checkpoint file.
    lr: unused; kept for call compatibility (no optimizer state is restored here).
    epoch: epoch prefix of the checkpoint filename. Defaults to 399, the value
      that used to be hard-coded.
  """
  # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
  checkpoint = torch.load(
    f"{folder_path}/{epoch}_{config.CHECKPOINT_GEN}", map_location=config.DEVICE
  )
  model.load_state_dict(checkpoint["state_dict"])

def save_some_examples(gen, val_loader, epoch, folder):
    """Run ``gen`` on the first batch of ``val_loader`` and save outputs and targets.

    Writes ``{folder}/y_gen_{epoch}.pkl`` (generator output) and
    ``{folder}/label_{epoch}.pkl`` (targets) via ``save_matrix``, creating
    ``folder`` if needed. ``gen`` is switched to eval mode for the forward pass
    and back to train mode afterwards, even if saving raises.

    Note: this shadows the ``save_some_examples`` imported from ``utils`` in the
    first cell.
    """
    import os  # local: this cell otherwise relies on `os` imported by an earlier cell

    os.makedirs(folder, exist_ok=True)
    x, y = next(iter(val_loader))
    x, y = x.to(config.DEVICE).float(), y.to(config.DEVICE).float()
    gen.eval()
    try:
        with torch.no_grad():
            y_fake = gen(x)
            save_matrix(y_fake, folder + f"/y_gen_{epoch}.pkl")
            save_matrix(y, folder + f"/label_{epoch}.pkl")
    finally:
        # Always restore training mode so a failed save does not leave gen in eval mode.
        gen.train()

def calculate_rmse(matrix1, matrix2):
  """Root-mean-square error between two equally shaped arrays.

  Returns:
    sqrt(mean((matrix1 - matrix2) ** 2)) over all elements.

  Raises:
    ValueError: if ``matrix1.shape != matrix2.shape``.
  """
  same_shape = matrix1.shape == matrix2.shape
  if not same_shape:
    raise ValueError("Matrices must have the same shape.")
  return np.sqrt(np.mean((matrix1 - matrix2) ** 2))


# Score each variant's generator by its mean per-sample RMSE on the single batch
# `tests` (inputs in tests[0], targets in tests[1]); lower is better.
model_error = []
for folder in folder_names:
    netG = Generator(in_channels=1).to(config.DEVICE)

    load_model(netG,f"{folder_path}/{folder}",config.LEARNING_RATE_GEN)
    # Eval mode once per model (it used to be re-set for every sample).
    netG.eval()

    errors = []
    inputs, targets = tests[0], tests[1]
    for i in range(len(inputs)):
        # Add a leading batch dimension. Assumes each sample is (C, H, W);
        # for C == 1 this matches the previous unsqueeze(1).
        x = inputs[i].to(config.DEVICE).float().unsqueeze(0)
        y = targets[i].to(config.DEVICE).float().unsqueeze(0)
        with torch.no_grad():
            y_fake = netG(x)
        y_fake = y_fake.cpu().numpy().squeeze()
        y = y.cpu().numpy().squeeze()

        errors.append(calculate_rmse(y,y_fake))

    model_error.append(np.mean(errors))

if not model_error:
    print("No models were evaluated.")
else:
    # min() returns the first index on ties, same as the previous stable sort.
    best_idx = min(range(len(model_error)), key=model_error.__getitem__)
    print(f"The best model is {folder_names[best_idx]} with a rmse of {model_error[best_idx]}")


The best model is L1 with BCEwithLogits with a rmse of 0.5253028869628906


In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
def main():
    """Train G/D from epoch ``start`` with Adam + ReduceLROnPlateau.

    Uses L1 loss for the generator and BCEWithLogits for the discriminator, an
    initial learning rate of 0.001 for both optimizers, and ``step_ahead=2``
    extra generator updates per batch. Every 10th epoch it saves checkpoints
    (if ``config.SAVE_MODEL``) and example outputs. After the loop it saves a
    final checkpoint pair and examples.

    Note: this redefines the ``main`` from an earlier cell.
    """
    start = 0
    rgb_on = False
    channels = 3 if rgb_on else 1

    netD = Discriminator(in_channels=channels).to(config.DEVICE)
    netG = Generator(in_channels=channels).to(config.DEVICE)

    optimizerD = torch.optim.Adam(
        netD.parameters(), lr=0.001, betas=(config.BETA1, 0.999)
    )
    optimizerG = torch.optim.Adam(
        netG.parameters(), lr=0.001, betas=(config.BETA1, 0.999)
    )

    # Halve the LR after `patience` scheduler steps without improvement of the loss.
    scheduler_G = ReduceLROnPlateau(optimizerG, 'min', factor=0.5, patience=5, verbose=True)
    scheduler_D = ReduceLROnPlateau(optimizerD, 'min', factor=0.5, patience=5, verbose=True)

    gen_loss = nn.L1Loss()
    dis_loss = nn.BCEWithLogitsLoss()

    # NOTE(review): with start == 0 this passes start-1 == -1 to load_checkpoint,
    # and the learning rates passed (config.LEARNING_RATE_*) differ from the 0.001
    # used above - confirm load_checkpoint's handling before enabling LOAD_MODEL.
    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_GEN, start-1, netG, optimizerG, config.LEARNING_RATE_GEN
        )
        load_checkpoint(
            config.CHECKPOINT_DISC, start-1, netD, optimizerD, config.LEARNING_RATE_DISC
        )

    train_dataset = Satellite2Map_Data(root=config.TRAIN_DIR, rgb_on=rgb_on)
    train_dl = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=True, drop_last=True)
    val_dataset = Satellite2Map_Data(root=config.VAL_DIR, rgb_on=rgb_on)
    val_dl = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=True, drop_last=True)

    last_epoch = None
    for epoch in range(start, config.NUM_EPOCHS):
        train(
            netG, netD, train_dl, optimizerG, optimizerD, gen_loss, dis_loss,
            scheduler_G, scheduler_D, step_ahead=2
        )
        last_epoch = epoch
        if config.SAVE_MODEL and epoch % 10 == 0 and epoch > 0:
            save_checkpoint(netG, optimizerG, epoch, Gen_loss[-1], Dis_loss[-1], filename= f"./checkpoints/{epoch}_{config.CHECKPOINT_GEN}")
            save_checkpoint(netD, optimizerD, epoch, Gen_loss[-1], Dis_loss[-1], filename= f"./checkpoints/{epoch}_{config.CHECKPOINT_DISC}")
        if epoch % 10 == 0 :
            print("save example")
            try:
                save_some_examples(netG,val_dl,epoch,folder="evaluation")
            except Exception as e:
                print(f"Something went wrong saving with epoch {epoch}: {e}")

        print("Epoch :",epoch, " Gen Loss :",Gen_loss[-1], "Disc Loss :",Dis_loss[-1])

    if last_epoch is None:
        # The loop never ran: there is no trained state or loss history to save.
        print(f"Nothing to train: start={start} >= NUM_EPOCHS={config.NUM_EPOCHS}")
        return

    save_checkpoint(netG, optimizerG, "final", Gen_loss[-1], Dis_loss[-1], filename= f"./checkpoints/{last_epoch}_{config.CHECKPOINT_GEN}")
    save_checkpoint(netD, optimizerD, "final", Gen_loss[-1], Dis_loss[-1], filename= f"./checkpoints/{last_epoch}_{config.CHECKPOINT_DISC}")
    print("save example")
    try:
        save_some_examples(netG,val_dl,config.NUM_EPOCHS,folder="evaluation")
    except Exception as e:
        print(f"Something went wrong with the last epoch: {e}")


if __name__ == '__main__':
    main()