In [1]:
import torch
from utils.utils import save_checkpoint, load_checkpoint, save_some_examples
import torch.nn as nn
import torch.optim as optim
import config
from DataSet.dataset import Satellite2Map_Data
from pix2pix.Generator import Generator
from pix2pix.Discriminator import Discriminator
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm


# cuDNN benchmark mode: autotune convolution algorithms per input shape
# (usually faster, at the cost of bit-exact reproducibility).
torch.backends.cudnn.benchmark = True

# Per-step loss histories. `train` appends to them; `main` reads the latest
# values for logging and checkpoint metadata.
Gen_loss: list[float] = []
Dis_loss: list[float] = []

# Extra generator updates performed after each batch's first generator update.
step_ahead = 4

ERROR:albumentations.check_version:Error fetching version info
Traceback (most recent call last):
  File "c:\Users\Alfredo\AppData\Local\Programs\Python\Python311\Lib\urllib\request.py", line 1348, in do_open
    h.request(req.get_method(), req.selector, req.data, headers,
  File "c:\Users\Alfredo\AppData\Local\Programs\Python\Python311\Lib\http\client.py", line 1286, in request
    self._send_request(method, url, body, headers, encode_chunked)
  File "c:\Users\Alfredo\AppData\Local\Programs\Python\Python311\Lib\http\client.py", line 1332, in _send_request
    self.endheaders(body, encode_chunked=encode_chunked)
  File "c:\Users\Alfredo\AppData\Local\Programs\Python\Python311\Lib\http\client.py", line 1281, in endheaders
    self._send_output(message_body, encode_chunked=encode_chunked)
  File "c:\Users\Alfredo\AppData\Local\Programs\Python\Python311\Lib\http\client.py", line 1041, in _send_output
    self.send(msg)
  File "c:\Users\Alfredo\AppData\Local\Programs\Python\Python311\Lib\h

In [2]:
def _generator_step(netG, netD, x, y, y_fake, OptimizerG, gen_loss):
    """Perform one generator update using a precomputed ``y_fake = netG(x)``.

    Clears netG's gradients first, so each update uses only this step's
    gradient. Appends the loss to the module-level ``Gen_loss`` history.
    Returns the discriminator output on ``(x, y_fake)``, which is used for
    progress reporting.
    """
    OptimizerG.zero_grad()
    d_fake = netD(x, y_fake)
    # Adversarial term: push the discriminator's output on fakes toward "real" (1).
    g_fake_loss = gen_loss(d_fake, torch.ones_like(d_fake))
    # Reconstruction term between the generated and target images.
    recon_loss = gen_loss(y_fake, y)  # * config.L1_LAMBDA
    g_loss = (g_fake_loss + recon_loss) / 2
    Gen_loss.append(g_loss.item())
    g_loss.backward()
    OptimizerG.step()
    return d_fake


def train(netG: Generator, netD: Discriminator, train_dl, OptimizerG: optim.Adam, OptimizerD: optim.Adam, gen_loss, dis_loss):
    """Run one epoch of adversarial training over ``train_dl``.

    For each batch, the discriminator is updated once. The generator is then
    updated ``1 + step_ahead`` times (``step_ahead`` is the module-level
    setting). Per-step losses are appended to the module-level ``Dis_loss`` /
    ``Gen_loss`` lists.

    Note that the ``gen_loss`` / ``dis_loss`` parameters are loss *criteria*
    (callables). They are distinct from the ``Gen_loss`` / ``Dis_loss``
    history lists.

    Args:
        netG: generator, called as ``netG(x)``.
        netD: conditional discriminator, called as ``netD(x, y)``.
        train_dl: iterable of ``(x, y)`` batches.
        OptimizerG: optimizer over netG's parameters.
        OptimizerD: optimizer over netD's parameters.
        gen_loss: criterion used for both the generator's adversarial term and
            its reconstruction term.
        dis_loss: criterion for the discriminator's real (1) / fake (0) targets.
    """
    loop = tqdm(train_dl, dynamic_ncols=True)
    d_real = d_fake = None
    for idx, (x, y) in enumerate(loop):
        x = x.to(config.DEVICE)
        # NOTE(review): y is permuted from channels-last to NCHW while x is used
        # as-is, so presumably the dataset returns them in different layouts.
        # The logged generator losses are in the thousands, which suggests y and
        # netG(x) are on different value scales. Confirm the dataset normalization.
        y = y.to(config.DEVICE).permute(0, 3, 1, 2).float()

        # ---- Discriminator update: real pairs -> 1, fake pairs -> 0 ----
        y_fake = netG(x).float()
        d_real = netD(x, y)
        d_real_loss = dis_loss(d_real, torch.ones_like(d_real))
        # detach() keeps the discriminator loss from backpropagating into netG.
        d_fake = netD(x, y_fake.detach())
        d_fake_loss = dis_loss(d_fake, torch.zeros_like(d_fake))
        d_loss = (d_real_loss + d_fake_loss) / 2

        # Also clears netD gradients accumulated by the previous generator backward passes.
        netD.zero_grad()
        Dis_loss.append(d_loss.item())
        d_loss.backward()
        OptimizerD.step()

        # ---- Generator updates ----
        # The first update reuses y_fake from above. The next `step_ahead`
        # updates use fresh generator outputs.
        d_fake = _generator_step(netG, netD, x, y, y_fake, OptimizerG, gen_loss)
        for _ in range(step_ahead):
            d_fake = _generator_step(netG, netD, x, y, netG(x).float(), OptimizerG, gen_loss)

        if idx % 10 == 0:
            loop.set_postfix(
                d_real=torch.sigmoid(d_real).mean().item(),
                d_fake=torch.sigmoid(d_fake).mean().item(),
            )

    if d_real is None:
        # Without this guard, an empty loader would crash on the unbound d_real below.
        print("train: data loader yielded no batches; nothing to report")
        return
    print("d_real: " + str(torch.sigmoid(d_real).mean().item()))
    print("d_fake: " + str(torch.sigmoid(d_fake).mean().item()))

In [3]:
def _last(history):
    """Return the most recent value in a loss-history list, or NaN if it is empty."""
    return history[-1] if history else float("nan")


def _save_checkpoint_pair(netG, netD, optimizerG, optimizerD, epoch_tag, file_epoch):
    """Save generator and discriminator checkpoints as ./checkpoints/{file_epoch}_<config name>.

    ``epoch_tag`` is forwarded to ``save_checkpoint`` as its epoch argument
    (an int, or ``"final"``), together with the latest logged losses.
    """
    g_last, d_last = _last(Gen_loss), _last(Dis_loss)
    save_checkpoint(netG, optimizerG, epoch_tag, g_last, d_last, filename=f"./checkpoints/{file_epoch}_{config.CHECKPOINT_GEN}")
    save_checkpoint(netD, optimizerD, epoch_tag, g_last, d_last, filename=f"./checkpoints/{file_epoch}_{config.CHECKPOINT_DISC}")


def main():
    """Build the models, optimizers and data loaders, then train for config.NUM_EPOCHS epochs.

    Side effects:
      - optionally loads checkpoints (config.LOAD_MODEL);
      - saves checkpoints every 10th epoch after epoch 0 (config.SAVE_MODEL);
      - always saves a final checkpoint pair;
      - writes example outputs to the "evaluation" folder every 2nd epoch;
      - prints per-epoch losses.
    """
    start = 0
    netD = Discriminator(in_channels=3).to(config.DEVICE)
    netG = Generator(in_channels=3).to(config.DEVICE)
    optimizerD = torch.optim.Adam(netD.parameters(), lr=config.LEARNING_RATE, betas=(config.BETA1, 0.999))
    optimizerG = torch.optim.Adam(netG.parameters(), lr=config.LEARNING_RATE, betas=(config.BETA1, 0.999))
    # Least-squares (MSE) objectives for both networks.
    dis_loss = nn.MSELoss()
    gen_loss = nn.MSELoss()

    if config.LOAD_MODEL:
        # NOTE(review): `start` is 0 here, so `start - 1` passes -1. `start` is also
        # never updated from the loaded checkpoint, so a resumed run restarts epoch
        # numbering at 0 (and may overwrite earlier ./checkpoints files). Confirm what
        # load_checkpoint's second argument means.
        load_checkpoint(config.CHECKPOINT_GEN, start - 1, netG, optimizerG, config.LEARNING_RATE)
        load_checkpoint(config.CHECKPOINT_DISC, start - 1, netD, optimizerD, config.LEARNING_RATE)

    train_dataset = Satellite2Map_Data(root=config.TRAIN_DIR)
    train_dl = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=True, drop_last=True)
    val_dataset = Satellite2Map_Data(root=config.VAL_DIR)
    val_dl = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=True, drop_last=True)

    last_epoch = None
    for epoch in range(start, config.NUM_EPOCHS):
        train(netG, netD, train_dl, optimizerG, optimizerD, gen_loss, dis_loss)
        last_epoch = epoch
        if config.SAVE_MODEL and epoch % 10 == 0 and epoch > 0:
            _save_checkpoint_pair(netG, netD, optimizerG, optimizerD, epoch, epoch)
        if epoch % 2 == 0:
            print("save example")
            try:
                save_some_examples(netG, val_dl, epoch, folder="evaluation")
            except Exception as e:
                # Best-effort: a failure while writing examples should not abort training.
                print(f"Something went wrong with epoch {epoch}: {e}")

        print("Epoch :", epoch, " Gen Loss :", _last(Gen_loss), "Disc Loss :", _last(Dis_loss))

    if last_epoch is None:
        # The loop never ran (start >= NUM_EPOCHS), so there is no epoch to name the file after.
        print("No epochs were run; skipping final checkpoint")
        return
    # The final save is not gated on config.SAVE_MODEL.
    _save_checkpoint_pair(netG, netD, optimizerG, optimizerD, "final", last_epoch)


# Start training when run as a script or executed directly in the notebook kernel.
if __name__ == "__main__":
    main()

  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.656052827835083
d_fake: 0.5084251761436462
save example
Epoch : 0  Gen Loss : 4794.04931640625 Disc Loss : 0.3655484616756439


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6322439312934875
d_fake: 0.5149700045585632
Epoch : 1  Gen Loss : 4160.90576171875 Disc Loss : 0.3844643235206604


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6381873488426208
d_fake: 0.520123302936554
save example
Epoch : 2  Gen Loss : 4164.61865234375 Disc Loss : 0.3307587802410126


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6408076286315918
d_fake: 0.5281713008880615
Epoch : 3  Gen Loss : 3938.68212890625 Disc Loss : 0.32576969265937805


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6501224637031555
d_fake: 0.5314621329307556
save example
Epoch : 4  Gen Loss : 4318.56640625 Disc Loss : 0.3174758851528168


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6442902088165283
d_fake: 0.5305838584899902
Epoch : 5  Gen Loss : 4637.86279296875 Disc Loss : 0.30401840806007385


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6590099334716797
d_fake: 0.5345675945281982
save example
Epoch : 6  Gen Loss : 4891.27587890625 Disc Loss : 0.2626647651195526


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6568162441253662
d_fake: 0.5309591293334961
Epoch : 7  Gen Loss : 5142.92529296875 Disc Loss : 0.2798402011394501


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6449084877967834
d_fake: 0.5335103273391724
save example
Epoch : 8  Gen Loss : 4076.576904296875 Disc Loss : 0.252987802028656


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6317894458770752
d_fake: 0.5312263369560242
Epoch : 9  Gen Loss : 4226.91796875 Disc Loss : 0.2562687397003174


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6566479206085205
d_fake: 0.5332763195037842
=> Saving checkpoint
=> Saving checkpoint
save example
Epoch : 10  Gen Loss : 4127.47705078125 Disc Loss : 0.23229260742664337


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6426325440406799
d_fake: 0.5313881635665894
Epoch : 11  Gen Loss : 4487.0908203125 Disc Loss : 0.21673735976219177


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6483606100082397
d_fake: 0.5337108373641968
save example
Epoch : 12  Gen Loss : 3929.832763671875 Disc Loss : 0.19222375750541687


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.663670003414154
d_fake: 0.530544638633728
Epoch : 13  Gen Loss : 5765.04296875 Disc Loss : 0.16284428536891937


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6621705293655396
d_fake: 0.5351824760437012
save example
Epoch : 14  Gen Loss : 3427.88427734375 Disc Loss : 0.16095639765262604


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.687138557434082
d_fake: 0.5292856693267822
Epoch : 15  Gen Loss : 5174.642578125 Disc Loss : 0.18264588713645935


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6607246994972229
d_fake: 0.5291721224784851
save example
Epoch : 16  Gen Loss : 3837.575927734375 Disc Loss : 0.16428112983703613


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6692271828651428
d_fake: 0.5304210782051086
Epoch : 17  Gen Loss : 3845.30517578125 Disc Loss : 0.1579919159412384


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6745588779449463
d_fake: 0.5305412411689758
save example
Epoch : 18  Gen Loss : 4174.01611328125 Disc Loss : 0.13598528504371643


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.684827446937561
d_fake: 0.5286145806312561
Epoch : 19  Gen Loss : 4108.0068359375 Disc Loss : 0.12462253868579865


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6762755513191223
d_fake: 0.5287330746650696
=> Saving checkpoint
=> Saving checkpoint
save example
Epoch : 20  Gen Loss : 4715.240234375 Disc Loss : 0.13954868912696838


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6608948707580566
d_fake: 0.5295476317405701
Epoch : 21  Gen Loss : 4687.28955078125 Disc Loss : 0.14492960274219513


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6714622974395752
d_fake: 0.5244181156158447
save example
Epoch : 22  Gen Loss : 5077.20458984375 Disc Loss : 0.16911625862121582


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.664324164390564
d_fake: 0.5216342806816101
Epoch : 23  Gen Loss : 5403.0439453125 Disc Loss : 0.13874229788780212


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6646665930747986
d_fake: 0.5222618579864502
save example
Epoch : 24  Gen Loss : 2851.070068359375 Disc Loss : 0.12413232028484344


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6984195113182068
d_fake: 0.5219180583953857
Epoch : 25  Gen Loss : 5072.98095703125 Disc Loss : 0.10863343626260757


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6495094895362854
d_fake: 0.5235030055046082
save example
Epoch : 26  Gen Loss : 4619.36962890625 Disc Loss : 0.14204344153404236


  0%|          | 0/34 [00:00<?, ?it/s]