In [1]:
import os, torch, torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from PIL import Image
from tqdm import tqdm
from torch import nn, optim

In [None]:
class DownBlock(nn.Module):
    """Encoder stage: 4x4 stride-2 convolution, optional batch norm, LeakyReLU(0.2).

    Args:
        in_channels: Number of channels in the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        normalize: When true, a ``BatchNorm2d`` layer sits between the
            convolution and the activation.
    """

    def __init__(self, in_channels, out_channels, normalize=True):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels,
                         kernel_size=4, stride=2, padding=1, bias=False)
        maybe_norm = (nn.BatchNorm2d(out_channels),) if normalize else ()
        # The activation is in-place: it overwrites the tensor produced by the
        # preceding layer instead of allocating a new one.
        activation = nn.LeakyReLU(0.2, inplace=True)
        self.model = nn.Sequential(conv, *maybe_norm, activation)

    def forward(self, x):
        """Apply the block to ``x`` and return the resulting feature map."""
        return self.model(x)

class UpBlock(nn.Module):
    """Decoder stage: 4x4 stride-2 transposed convolution, batch norm, ReLU.

    Args:
        in_channels: Number of channels in the incoming feature map.
        out_channels: Number of channels produced by the transposed convolution.
        dropout: Drop probability. Any falsy value (the default ``0.0``) means
            no ``Dropout`` layer is added at all.
    """

    def __init__(self, in_channels, out_channels, dropout=0.0):
        super().__init__()
        deconv = nn.ConvTranspose2d(in_channels, out_channels,
                                    kernel_size=4, stride=2, padding=1, bias=False)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU(inplace=True)
        # Dropout is appended only when a non-zero probability was requested.
        tail = (nn.Dropout(dropout),) if dropout else ()
        self.model = nn.Sequential(deconv, norm, activation, *tail)

    def forward(self, x):
        """Apply the block to ``x`` and return the resulting feature map."""
        return self.model(x)

class GeneratorUNet(nn.Module):
    """U-Net generator built from eight ``DownBlock`` and seven ``UpBlock`` stages.

    The encoder outputs are kept and concatenated (along the channel axis)
    with the decoder activations of the mirrored stage. A final transposed
    convolution followed by ``Tanh`` produces the output image.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the generated image.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        # Encoder stages as (input width, output width, use batch norm).
        # The first and the last stage skip normalisation.
        encoder_spec = (
            (in_channels, 256, False),
            (256, 512, True),
            (512, 1024, True),
            (1024, 1024, True),
            (1024, 1024, True),
            (1024, 1024, True),
            (1024, 1024, True),
            (1024, 1024, False),
        )
        for stage, (c_in, c_out, use_norm) in enumerate(encoder_spec, start=1):
            setattr(self, f"down{stage}", DownBlock(c_in, c_out, normalize=use_norm))

        # Decoder stages as (input width, output width, dropout probability).
        # From the second stage on, the input width includes the channels of
        # the concatenated encoder skip tensor.
        decoder_spec = (
            (1024, 1024, 0.5),
            (2048, 1024, 0.5),
            (2048, 1024, 0.5),
            (2048, 1024, 0.0),
            (2048, 1024, 0.0),
            (2048, 512, 0.0),
            (1024, 256, 0.0),
        )
        for stage, (c_in, c_out, p_drop) in enumerate(decoder_spec, start=1):
            setattr(self, f"up{stage}", UpBlock(c_in, c_out, dropout=p_drop))

        # 512 input channels = 256 from up7 + 256 from the down1 skip tensor.
        self.final = nn.Sequential(
            nn.ConvTranspose2d(512, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Translate the input batch ``x`` into an output image batch."""
        # Encoder pass: remember every stage output for the skip connections.
        skips = []
        hidden = x
        for stage in range(1, 9):
            hidden = getattr(self, f"down{stage}")(hidden)
            skips.append(hidden)

        # The bottleneck output feeds the first decoder stage on its own.
        hidden = self.up1(skips.pop())

        # Remaining decoder stages consume the skips in reverse order
        # (down7 for up2, ..., down2 for up7).
        for stage in range(2, 8):
            hidden = getattr(self, f"up{stage}")(torch.cat([hidden, skips.pop()], 1))

        # The last remaining skip is the down1 output.
        return self.final(torch.cat([hidden, skips.pop()], 1))

class PatchDiscriminator(nn.Module):
    """Convolutional discriminator that scores a channel-wise concatenated image pair.

    The network ends in a single-channel convolution with no activation, so
    it returns raw logits.

    Args:
        in_channels: Channel count of the two inputs after concatenation.
    """

    def __init__(self, in_channels=6):
        super().__init__()

        def norm_stage(c_in, c_out, stride):
            # Bias-free convolution + batch norm + LeakyReLU(0.2).
            return (
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            )

        self.model = nn.Sequential(
            # Input stage: no normalisation, convolution keeps its bias.
            nn.Conv2d(in_channels, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            *norm_stage(128, 256, stride=2),
            *norm_stage(256, 512, stride=2),
            *norm_stage(512, 1024, stride=1),
            # Output stage: one logit per spatial location.
            nn.Conv2d(1024, 1, kernel_size=4, stride=1, padding=1),
        )

    def forward(self, x, y):
        """Concatenate ``x`` and ``y`` on the channel axis and return the logits."""
        pair = torch.cat([x, y], dim=1)
        return self.model(pair)

class Pix2PixDataset(Dataset):
    """Paired-image dataset where every file stores two pictures side by side.

    Each image is split vertically at ``width // 2``. Both halves go through
    the same pipeline: resize to ``image_size`` x ``image_size``, conversion
    to a tensor, and normalisation with mean 0.5 / std 0.5 per channel.

    Args:
        root_dir: Directory that directly contains the image files.
        image_size: Side length both halves are resized to.
    """

    # Extensions (lower-case) that are treated as images.
    _EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, image_size=256):
        self.files = self._list_image_files(root_dir)
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5] * 3, [0.5] * 3),
        ])

    @staticmethod
    def _list_image_files(root_dir):
        """Return the image paths directly under ``root_dir`` in sorted order.

        ``os.listdir`` returns entries in arbitrary order, so the names are
        sorted to make the index -> file mapping reproducible. The extension
        check is case-insensitive so files such as ``photo.JPG`` are kept.
        """
        return [
            os.path.join(root_dir, name)
            for name in sorted(os.listdir(root_dir))
            if name.lower().endswith(Pix2PixDataset._EXTENSIONS)
        ]

    def __len__(self):
        """Number of image files found in ``root_dir``."""
        return len(self.files)

    def __getitem__(self, i):
        """Return ``(left_half, right_half)`` of the ``i``-th image as tensors.

        The halves are named ``sat`` and ``mapp``; presumably the left half
        is the satellite photo and the right half the map (verify against
        the dataset layout).
        """
        # The context manager closes the underlying file deterministically;
        # ``convert`` returns an independent in-memory copy.
        with Image.open(self.files[i]) as raw:
            img = raw.convert("RGB")
        w = img.width // 2
        sat = self.transform(img.crop((0, 0, w, img.height)))
        mapp = self.transform(img.crop((w, 0, img.width, img.height)))
        return sat, mapp

# Output folders for the per-epoch sample grids and the generator checkpoints.
for out_dir in ("/kaggle/working/samples", "/kaggle/working/checkpoints"):
    os.makedirs(out_dir, exist_ok=True)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Networks.
gen = GeneratorUNet().to(device)
disc = PatchDiscriminator().to(device)

# Both optimizers share the same Adam hyper-parameters.
adam_kwargs = dict(lr=2e-4, betas=(0.5, 0.999))
opt_g = optim.Adam(gen.parameters(), **adam_kwargs)
opt_d = optim.Adam(disc.parameters(), **adam_kwargs)

# Losses: adversarial term on raw logits plus an L1 reconstruction term
# weighted by ``lambda_l1``.
bce = nn.BCEWithLogitsLoss()
l1 = nn.L1Loss()
lambda_l1 = 100

# Datasets and loaders. Validation yields one randomly chosen pair at a time.
train = Pix2PixDataset("/kaggle/input/pix2pix-maps/train", image_size=256)
val = Pix2PixDataset("/kaggle/input/pix2pix-maps/val", image_size=256)
loader = DataLoader(
    train,
    batch_size=8,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
valloader = DataLoader(val, batch_size=1, shuffle=True)

In [None]:
# Total number of passes over the training set. The progress-bar label is
# derived from this value so the two cannot drift apart (the label used to
# read "/50" while the loop actually ran 300 epochs).
NUM_EPOCHS = 300

for epoch in range(NUM_EPOCHS):
    gen.train(); disc.train()
    loop = tqdm(loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    for sat, mapp in loop:
        sat, mapp = sat.to(device), mapp.to(device)

        # --- Discriminator step -------------------------------------------
        # ``fake`` is detached so this backward pass does not reach the
        # generator's parameters.
        fake = gen(sat)
        real_pred = disc(sat, mapp)
        fake_pred = disc(sat, fake.detach())
        d_loss = (bce(real_pred, torch.ones_like(real_pred)) +
                  bce(fake_pred, torch.zeros_like(fake_pred))) * 0.5
        opt_d.zero_grad(); d_loss.backward(); opt_d.step()

        # --- Generator step -----------------------------------------------
        # Re-score the (non-detached) fake with the freshly updated
        # discriminator; the generator is rewarded for "real" verdicts and
        # penalised for L1 distance to the target.
        fake_pred = disc(sat, fake)
        g_adv = bce(fake_pred, torch.ones_like(fake_pred))
        g_l1 = l1(fake, mapp) * lambda_l1
        g_loss = g_adv + g_l1
        opt_g.zero_grad(); g_loss.backward(); opt_g.step()

        loop.set_postfix(d_loss=d_loss.item(), g_loss=g_loss.item())

    # --- Per-epoch visual sample ------------------------------------------
    gen.eval()
    with torch.no_grad():
        sat, real = next(iter(valloader))
        sat, real = sat.to(device), real.to(device)
        fake = gen(sat)
        # Inputs are normalised with mean 0.5 / std 0.5 and the generator ends
        # in Tanh, so all three tensors live in [-1, 1]. Map them back to
        # [0, 1] explicitly rather than min-max rescaling the whole batch,
        # which would make brightness/contrast vary from epoch to epoch.
        panel = torch.cat([sat, fake, real], dim=0).cpu() * 0.5 + 0.5
        grid = make_grid(panel.clamp(0, 1), nrow=1)
        save_image(grid, f"/kaggle/working//samples/epoch_{epoch+1}.png")

    # Keep a generator checkpoint every fifth epoch.
    if (epoch + 1) % 5 == 0:
        torch.save(gen.state_dict(), f"/kaggle/working//checkpoints/gen_epoch_{epoch+1}.pth")

print("Training has ended!")

Epoch 1/50: 100%|██████████| 137/137 [02:01<00:00,  1.13it/s, d_loss=0.768, g_loss=12]  
Epoch 2/50: 100%|██████████| 137/137 [02:01<00:00,  1.13it/s, d_loss=0.0413, g_loss=15.8] 
Epoch 3/50: 100%|██████████| 137/137 [02:01<00:00,  1.13it/s, d_loss=0.0314, g_loss=16.2]
Epoch 4/50: 100%|██████████| 137/137 [02:01<00:00,  1.13it/s, d_loss=0.0236, g_loss=14]   
Epoch 5/50: 100%|██████████| 137/137 [02:01<00:00,  1.13it/s, d_loss=0.00555, g_loss=16.9]
Epoch 6/50: 100%|██████████| 137/137 [02:01<00:00,  1.13it/s, d_loss=0.00142, g_loss=17.7]
Epoch 7/50: 100%|██████████| 137/137 [02:01<00:00,  1.13it/s, d_loss=0.0851, g_loss=13.7] 
Epoch 8/50: 100%|██████████| 137/137 [02:01<00:00,  1.13it/s, d_loss=0.325, g_loss=17.2]  
Epoch 9/50: 100%|██████████| 137/137 [02:01<00:00,  1.13it/s, d_loss=0.0146, g_loss=16.9] 
Epoch 10/50: 100%|██████████| 137/137 [02:01<00:00,  1.13it/s, d_loss=0.278, g_loss=10.9] 
Epoch 11/50: 100%|██████████| 137/137 [02:01<00:00,  1.13it/s, d_loss=0.00704, g_loss=15.1]
E