In [1]:
import torch
from utils.utils import save_checkpoint, load_checkpoint, save_some_examples
import torch.nn as nn
import torch.optim as optim
import config
from DataSet.dataset import Satellite2Map_Data
from pix2pix.Generator import Generator
from pix2pix.Discriminator import Discriminator
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm


# Let cuDNN benchmark and cache the fastest convolution algorithms
# (most effective when input shapes do not change between batches).
torch.backends.cudnn.benchmark = True

# Per-update loss histories: train() appends to these, main() reads the last
# entry for logging and checkpoint metadata. Re-running this cell resets them.
Gen_loss: list = []
Dis_loss: list = []

# Number of additional generator updates performed after each discriminator update.
step_ahead = 4

ERROR:albumentations.check_version:Error fetching version info
Traceback (most recent call last):
  File "c:\Users\Alfredo\AppData\Local\Programs\Python\Python311\Lib\urllib\request.py", line 1348, in do_open
    h.request(req.get_method(), req.selector, req.data, headers,
  File "c:\Users\Alfredo\AppData\Local\Programs\Python\Python311\Lib\http\client.py", line 1286, in request
    self._send_request(method, url, body, headers, encode_chunked)
  File "c:\Users\Alfredo\AppData\Local\Programs\Python\Python311\Lib\http\client.py", line 1332, in _send_request
    self.endheaders(body, encode_chunked=encode_chunked)
  File "c:\Users\Alfredo\AppData\Local\Programs\Python\Python311\Lib\http\client.py", line 1281, in endheaders
    self._send_output(message_body, encode_chunked=encode_chunked)
  File "c:\Users\Alfredo\AppData\Local\Programs\Python\Python311\Lib\http\client.py", line 1041, in _send_output
    self.send(msg)
  File "c:\Users\Alfredo\AppData\Local\Programs\Python\Python311\Lib\h

In [2]:
def _generator_step(netD: Discriminator, x, y, y_fake, OptimizerG: optim.Adam, gen_loss):
    """Run one generator update on an already-computed fake batch ``y_fake``.

    Clears the generator's gradients, scores ``y_fake`` with ``netD``, averages the
    adversarial loss (target: all-ones) with the reconstruction loss against ``y``,
    records the result in the global ``Gen_loss`` list and steps ``OptimizerG``.

    Returns:
        The discriminator output on ``y_fake`` (used for progress reporting).
    """
    # Clear stale generator gradients from the previous update before backward;
    # otherwise this step would accumulate onto them.
    OptimizerG.zero_grad()
    d_fake = netD(x, y_fake)
    g_fake_loss = gen_loss(d_fake, torch.ones_like(d_fake))
    # NOTE(review): the previously commented-out `* config.L1_LAMBDA` suggests a
    # weighted reconstruction term was intended; here it is unweighted and
    # averaged with the adversarial term — confirm this is deliberate.
    recon_loss = gen_loss(y_fake, y)
    g_loss = (g_fake_loss + recon_loss) / 2
    Gen_loss.append(g_loss.item())
    # This backward also writes gradients into netD's parameters; train() clears
    # them with netD.zero_grad() before the next discriminator backward.
    g_loss.backward()
    OptimizerG.step()
    return d_fake


def train(netG: Generator, netD: Discriminator, train_dl, OptimizerG: optim.Adam, OptimizerD: optim.Adam, gen_loss, dis_loss, extra_g_steps: int | None = None):
    """Train ``netG`` and ``netD`` for one pass over ``train_dl``.

    For every batch: one discriminator update (real vs. detached fake), then one
    generator update on the same fake batch, then ``extra_g_steps`` further
    generator updates each on a freshly generated batch.

    Args:
        netG, netD: generator and discriminator being trained.
        train_dl: iterable of ``(x, y)`` batches.
        OptimizerG, OptimizerD: optimizers for ``netG`` and ``netD``.
        gen_loss: loss criterion used for the generator's adversarial and
            reconstruction terms.
        dis_loss: loss criterion used for the discriminator.
        extra_g_steps: additional generator updates per batch; ``None`` uses the
            notebook-level ``step_ahead``.

    Side effects:
        Appends discriminator losses to the global ``Dis_loss`` and generator
        losses to the global ``Gen_loss``; prints the mean sigmoid of the last
        batch's discriminator outputs.
    """
    n_extra = step_ahead if extra_g_steps is None else extra_g_steps
    loop = tqdm(train_dl, dynamic_ncols=True)
    d_real = d_fake = None
    for idx, (x, y) in enumerate(loop):
        x = x.to(config.DEVICE)
        # NOTE(review): only y is permuted (presumably channels-last -> channels-first);
        # x is used as-is. The logged generator loss is in the thousands, which hints
        # that y is not on the same scale as netG's output — confirm against
        # Satellite2Map_Data's transforms.
        y = y.to(config.DEVICE).permute(0, 3, 1, 2).float()

        # --- Discriminator update ---
        y_fake = netG(x).float()
        d_real = netD(x, y)
        d_real_loss = dis_loss(d_real, torch.ones_like(d_real))
        # Detached so the discriminator loss does not backpropagate into netG.
        d_fake = netD(x, y_fake.detach())
        d_fake_loss = dis_loss(d_fake, torch.zeros_like(d_fake))
        d_loss = (d_real_loss + d_fake_loss) / 2

        # Also discards the gradients the generator steps left in netD.
        netD.zero_grad()
        Dis_loss.append(d_loss.item())
        d_loss.backward()
        OptimizerD.step()

        # --- Generator updates ---
        # The first step reuses y_fake: only netD was stepped since it was computed.
        d_fake = _generator_step(netD, x, y, y_fake, OptimizerG, gen_loss)
        for _ in range(n_extra):
            y_fake = netG(x).float()
            d_fake = _generator_step(netD, x, y, y_fake, OptimizerG, gen_loss)

        if idx % 10 == 0:
            loop.set_postfix(
                d_real=torch.sigmoid(d_real).mean().item(),
                d_fake=torch.sigmoid(d_fake).mean().item(),
            )

    if d_real is None:
        # Empty loader: nothing was trained, so there are no outputs to report.
        print("train: data loader yielded no batches")
        return
    print("d_real: " + str(torch.sigmoid(d_real).mean().item()))
    print("d_fake: " + str(torch.sigmoid(d_fake).mean().item()))

In [3]:
def main(start: int = 21):
    """Build the pix2pix models, optionally resume from checkpoints, and train.

    Args:
        start: first epoch index to train (previously hard-coded to 21). When
            ``config.LOAD_MODEL`` is set, checkpoints are loaded via
            ``load_checkpoint(..., start - 1, ...)``.

    Trains for epochs ``start .. config.NUM_EPOCHS - 1``. Every 10th epoch saves
    checkpoints when ``config.SAVE_MODEL`` is set, and every 2nd epoch attempts
    (best-effort) to save example outputs. A final pair of checkpoints is saved
    when training ends.

    Raises:
        ValueError: if the training loader yields no batches.
    """
    import os  # local import: only used to create the checkpoint directory

    # Without this guard the loop below never runs and the final save crashes on
    # an empty Gen_loss / unbound `epoch`.
    if start >= config.NUM_EPOCHS:
        print(f"Nothing to train: start={start} >= NUM_EPOCHS={config.NUM_EPOCHS}")
        return

    netD = Discriminator(in_channels=3).to(config.DEVICE)
    netG = Generator(in_channels=3).to(config.DEVICE)
    optimizerD = torch.optim.Adam(netD.parameters(), lr = config.LEARNING_RATE, betas=(config.BETA1, 0.999))
    optimizerG = torch.optim.Adam(netG.parameters(), lr = config.LEARNING_RATE, betas=(config.BETA1, 0.999))
    dis_loss = nn.MSELoss()
    gen_loss = nn.MSELoss()

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_GEN, start-1, netG, optimizerG, config.LEARNING_RATE
        )
        load_checkpoint(
            config.CHECKPOINT_DISC, start-1, netD, optimizerD, config.LEARNING_RATE
        )

    train_dataset = Satellite2Map_Data(root=config.TRAIN_DIR)
    train_dl = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=True, drop_last=True)
    # With drop_last=True a dataset smaller than one batch yields no batches, which
    # would otherwise surface later as an IndexError on Gen_loss[-1].
    if len(train_dl) == 0:
        raise ValueError(
            f"Training loader is empty: {len(train_dataset)} samples, "
            f"batch_size={config.BATCH_SIZE}, drop_last=True"
        )
    val_dataset = Satellite2Map_Data(root=config.VAL_DIR)
    val_dl = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=True, drop_last=True)

    # Checkpoint filenames below are written under ./checkpoints/.
    os.makedirs("./checkpoints", exist_ok=True)

    for epoch in range(start, config.NUM_EPOCHS):
        train(
            netG, netD, train_dl, optimizerG, optimizerD, gen_loss, dis_loss
        )
        if config.SAVE_MODEL and epoch % 10 == 0 and epoch > 0:
            save_checkpoint(netG, optimizerG, epoch, Gen_loss[-1], Dis_loss[-1], filename= f"./checkpoints/{epoch}_{config.CHECKPOINT_GEN}")
            save_checkpoint(netD, optimizerD, epoch, Gen_loss[-1], Dis_loss[-1], filename= f"./checkpoints/{epoch}_{config.CHECKPOINT_DISC}")
        if epoch % 2 == 0 :
            print("save example")
            # Best-effort: a failure to write examples must not abort training.
            try:
                save_some_examples(netG,val_dl,epoch,folder="evaluation")
            except Exception as e:
                print(f"Something went wrong with epoch {epoch}: {e}")

        print("Epoch :",epoch, " Gen Loss :",Gen_loss[-1], "Disc Loss :",Dis_loss[-1])

    # NOTE(review): unlike the periodic saves, this final save ignores
    # config.SAVE_MODEL, and its filename uses the last epoch number (which can
    # coincide with a periodic checkpoint) — confirm both are intended.
    save_checkpoint(netG, optimizerG, "final", Gen_loss[-1], Dis_loss[-1], filename= f"./checkpoints/{epoch}_{config.CHECKPOINT_GEN}")
    save_checkpoint(netD, optimizerD, "final", Gen_loss[-1], Dis_loss[-1], filename= f"./checkpoints/{epoch}_{config.CHECKPOINT_DISC}")


# Start training when this cell/script is executed directly
# (a notebook kernel also runs cells with __name__ == "__main__").
if "__main__" == __name__:
    main()

  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6371028423309326
d_fake: 0.5288757085800171
Epoch : 21  Gen Loss : 5623.4072265625 Disc Loss : 0.23264126479625702


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.7146334052085876
d_fake: 0.5294208526611328
save example
Epoch : 22  Gen Loss : 5097.41552734375 Disc Loss : 0.22778692841529846


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6681365966796875
d_fake: 0.5281392335891724
Epoch : 23  Gen Loss : 4963.517578125 Disc Loss : 0.1456485241651535


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.5721899271011353
d_fake: 0.5295363664627075
save example
Epoch : 24  Gen Loss : 4773.03857421875 Disc Loss : 0.3145977258682251


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.7413450479507446
d_fake: 0.5216503143310547
Epoch : 25  Gen Loss : 5542.5234375 Disc Loss : 0.18820640444755554


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6614007949829102
d_fake: 0.5236060619354248
save example
Epoch : 26  Gen Loss : 4200.27490234375 Disc Loss : 0.1378692388534546


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.7073062658309937
d_fake: 0.5230106711387634
Epoch : 27  Gen Loss : 3676.354248046875 Disc Loss : 0.13460862636566162


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.7148381471633911
d_fake: 0.5181271433830261
save example
Epoch : 28  Gen Loss : 4977.34521484375 Disc Loss : 0.10851652175188065


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6537049412727356
d_fake: 0.519221305847168
Epoch : 29  Gen Loss : 3509.880859375 Disc Loss : 0.1477772295475006


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6682382822036743
d_fake: 0.5161553025245667
=> Saving checkpoint
=> Saving checkpoint
save example
Epoch : 30  Gen Loss : 4486.74560546875 Disc Loss : 0.11118482798337936


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6746047735214233
d_fake: 0.5141677260398865
Epoch : 31  Gen Loss : 3550.29443359375 Disc Loss : 0.14807040989398956


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.711966335773468
d_fake: 0.5161095857620239
save example
Epoch : 32  Gen Loss : 4149.6787109375 Disc Loss : 0.08293257653713226


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.7316001653671265
d_fake: 0.5192095041275024
Epoch : 33  Gen Loss : 3082.14208984375 Disc Loss : 0.15809984505176544


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.7515535950660706
d_fake: 0.5201954245567322
save example
Epoch : 34  Gen Loss : 4344.85791015625 Disc Loss : 0.1810034066438675


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6821620464324951
d_fake: 0.515355110168457
Epoch : 35  Gen Loss : 3307.2607421875 Disc Loss : 0.10354313999414444


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.7161970734596252
d_fake: 0.5191569924354553
save example
Epoch : 36  Gen Loss : 5520.1689453125 Disc Loss : 0.09895314276218414


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6759148240089417
d_fake: 0.5097847580909729
Epoch : 37  Gen Loss : 4549.5263671875 Disc Loss : 0.10212407261133194


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6663081645965576
d_fake: 0.5141023397445679
save example
Epoch : 38  Gen Loss : 3998.505859375 Disc Loss : 0.1278851330280304


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6637830138206482
d_fake: 0.5200480818748474
Epoch : 39  Gen Loss : 3782.253662109375 Disc Loss : 0.11723489314317703


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6669871807098389
d_fake: 0.5186947584152222
=> Saving checkpoint
=> Saving checkpoint
save example
Epoch : 40  Gen Loss : 4024.422607421875 Disc Loss : 0.08549428731203079


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.5772103071212769
d_fake: 0.5192499160766602
Epoch : 41  Gen Loss : 5124.45556640625 Disc Loss : 0.3033181428909302


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6611862778663635
d_fake: 0.518269419670105
save example
Epoch : 42  Gen Loss : 4925.66845703125 Disc Loss : 0.10078191757202148


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.7675027847290039
d_fake: 0.5145741105079651
Epoch : 43  Gen Loss : 4093.198974609375 Disc Loss : 0.24894589185714722


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6537558436393738
d_fake: 0.515307605266571
save example
Epoch : 44  Gen Loss : 3703.07275390625 Disc Loss : 0.13419033586978912


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6931123733520508
d_fake: 0.5126268267631531
Epoch : 45  Gen Loss : 5076.2001953125 Disc Loss : 0.07192487269639969


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.8293105959892273
d_fake: 0.5146068334579468
save example
Epoch : 46  Gen Loss : 5714.49658203125 Disc Loss : 0.7093728184700012


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6622035503387451
d_fake: 0.5180482268333435
Epoch : 47  Gen Loss : 3969.59228515625 Disc Loss : 0.11428926140069962


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6775172352790833
d_fake: 0.5205495357513428
save example
Epoch : 48  Gen Loss : 4126.40283203125 Disc Loss : 0.08046229183673859


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.8113366365432739
d_fake: 0.5138424634933472
Epoch : 49  Gen Loss : 4374.87060546875 Disc Loss : 0.5625137686729431


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.7440994381904602
d_fake: 0.5134973526000977
=> Saving checkpoint
=> Saving checkpoint
save example
Epoch : 50  Gen Loss : 3703.696533203125 Disc Loss : 0.1059546247124672


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6990605592727661
d_fake: 0.5217775702476501
Epoch : 51  Gen Loss : 4945.9091796875 Disc Loss : 0.0818416178226471


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6908401846885681
d_fake: 0.5147282481193542
save example
Epoch : 52  Gen Loss : 4157.16943359375 Disc Loss : 0.09278328716754913


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.7693265676498413
d_fake: 0.5136529803276062
Epoch : 53  Gen Loss : 5854.47802734375 Disc Loss : 0.18497586250305176


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6567491888999939
d_fake: 0.5208284258842468
save example
Epoch : 54  Gen Loss : 3600.17529296875 Disc Loss : 0.10678358376026154


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.7856647968292236
d_fake: 0.5083008408546448
Epoch : 55  Gen Loss : 4493.13818359375 Disc Loss : 0.23533159494400024


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6967257857322693
d_fake: 0.5203234553337097
save example
Epoch : 56  Gen Loss : 3790.264404296875 Disc Loss : 0.07263428717851639


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.7500354647636414
d_fake: 0.5160638689994812
Epoch : 57  Gen Loss : 4254.08544921875 Disc Loss : 0.09572631120681763


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.7792720198631287
d_fake: 0.513077974319458
save example
Epoch : 58  Gen Loss : 4520.0322265625 Disc Loss : 0.1805599480867386


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6790910959243774
d_fake: 0.517051100730896
Epoch : 59  Gen Loss : 6111.12646484375 Disc Loss : 0.07830654829740524


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6748472452163696
d_fake: 0.5190443396568298
=> Saving checkpoint
=> Saving checkpoint
save example
Epoch : 60  Gen Loss : 4374.3935546875 Disc Loss : 0.07343828678131104


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6437650918960571
d_fake: 0.5260228514671326
Epoch : 61  Gen Loss : 3376.714599609375 Disc Loss : 0.12009070813655853


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.7365204095840454
d_fake: 0.5162258744239807
save example
Epoch : 62  Gen Loss : 3821.31494140625 Disc Loss : 0.07954257726669312


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.751140296459198
d_fake: 0.524662971496582
Epoch : 63  Gen Loss : 4908.76806640625 Disc Loss : 0.16638030111789703


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.669464111328125
d_fake: 0.5199269652366638
save example
Epoch : 64  Gen Loss : 4327.30712890625 Disc Loss : 0.07802525907754898


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6757120490074158
d_fake: 0.5196437239646912
Epoch : 65  Gen Loss : 3313.566162109375 Disc Loss : 0.0714927613735199


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6894662976264954
d_fake: 0.522530198097229
save example
Epoch : 66  Gen Loss : 4427.681640625 Disc Loss : 0.06923410296440125


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.7018680572509766
d_fake: 0.520572304725647
Epoch : 67  Gen Loss : 4353.56982421875 Disc Loss : 0.07020314037799835


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.758481502532959
d_fake: 0.5118030905723572
save example
Epoch : 68  Gen Loss : 4732.2861328125 Disc Loss : 0.1130058765411377


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6521049737930298
d_fake: 0.5172221064567566
Epoch : 69  Gen Loss : 4267.357421875 Disc Loss : 0.10010175406932831


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.7113751173019409
d_fake: 0.5211033821105957
=> Saving checkpoint
=> Saving checkpoint
save example
Epoch : 70  Gen Loss : 5203.40625 Disc Loss : 0.0558880940079689


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6947903037071228
d_fake: 0.5237586498260498
Epoch : 71  Gen Loss : 3212.933837890625 Disc Loss : 0.06599438190460205


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.7765645980834961
d_fake: 0.5139904618263245
save example
Epoch : 72  Gen Loss : 4905.04541015625 Disc Loss : 0.1510770320892334


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.7955854535102844
d_fake: 0.5161843299865723
Epoch : 73  Gen Loss : 4520.65380859375 Disc Loss : 0.2042582482099533


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6667552590370178
d_fake: 0.5218737721443176
save example
Epoch : 74  Gen Loss : 4185.0693359375 Disc Loss : 0.09111988544464111


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.7097680568695068
d_fake: 0.5136268138885498
Epoch : 75  Gen Loss : 4349.68212890625 Disc Loss : 0.051938943564891815


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6699801087379456
d_fake: 0.5264964699745178
save example
Epoch : 76  Gen Loss : 3450.843017578125 Disc Loss : 0.07685264199972153


  0%|          | 0/34 [00:00<?, ?it/s]

d_real: 0.6099199652671814
d_fake: 0.5182011723518372
Epoch : 77  Gen Loss : 4109.4365234375 Disc Loss : 0.19986648857593536


  0%|          | 0/34 [00:00<?, ?it/s]