In [1]:
import torch.nn as nn
import torch
import torchvision

In [2]:
class CNNBlock(nn.Module):
  def __init__(self, in_chan, out_chan, stride = 2):
    super().__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_chan, out_chan, stride=stride, kernel_size = 4, bias = False, padding_mode = "reflect"),
        nn.BatchNorm2d(out_chan),
        nn.LeakyReLU(0.2)
    )
  
  def forward(self, x):
    return self.conv(x)

In [3]:
class Discriminator(nn.Module):
  def __init__(self, in_chan, features = [64,128,256,512]):
    super().__init__()
    self.init = nn.Sequential(
        nn.Conv2d(in_chan*2, features[0], kernel_size = 4, padding = 1, padding_mode="reflect", stride=2),
        nn.LeakyReLU(0.2)
    )
    layers = []
    in_chan = features[0]
    for feature in features[1:]:
      layers.append(
          CNNBlock(in_chan, feature, stride = 1 if feature == features[-1] else 2),

      )
      in_chan = feature
    layers.append(nn.Conv2d(in_chan, 1, kernel_size = 4, stride = 1, padding = 1, padding_mode = "reflect"))
    self.model = nn.Sequential(*layers)
    
  def forward(self, x,y):
    x = torch.cat([x,y], dim = 1)
    x = self.init(x)
    return self.model(x)

In [4]:
def test():
  x = torch.rand((1,3,256,256))
  y =  torch.rand((1,3,256,256))
  disc = Discriminator(3)
  return disc(x,y)

In [5]:
class Block(nn.Module):
  def __init__(self, in_chan, out, down = True, act = "relu", drop = False):
    super().__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_chan ,out , 4,2,1,bias = False, padding_mode = "reflect")
        if down
        else nn.ConvTranspose2d(in_chan ,out, 4,2,1,bias= False),
        nn.BatchNorm2d(out),
        nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2)
    )
    self.drop = nn.Dropout(0.5)
    self.use_drop = drop
  
  def forward(self, x):
    x = self.conv(x)
    return self.drop(x) if self.use_drop else x

In [6]:
class Generator(nn.Module):
  def __init__(self, in_chan, features = 64):
    super().__init__()
    self.initial_down = nn.Sequential(
        nn.Conv2d(in_chan , features, 4,2,1,padding_mode = "reflect"),
        nn.LeakyReLU(0.2)
    )

    self.down1 = Block(features, features*2, down =True, act = "leaky", drop = False)
    self.down2 = Block(features*2, features*4, down =True, act = "leaky", drop = False)
    self.down3 = Block(features*4, features*8, down =True, act = "leaky", drop = False)
    self.down4 = Block(features*8, features*8, down =True, act = "leaky", drop = False)
    self.down5 = Block(features*8, features*8, down =True, act = "leaky", drop = False)
    self.down6 = Block(features*8, features*8, down =True, act = "leaky", drop = False)
    self.bottleneck = nn.Sequential(
        nn.Conv2d(features*8, features*8, 4,2,1, padding_mode = "reflect"),
        nn.ReLU()
    )
    self.up1 = Block(features*8, features*8, down = False, act = "relu", drop = True)
    self.up2 = Block(features*8*2, features*8, down = False, act = "relu", drop = True)
    self.up3 = Block(features*8*2, features*8, down = False, act = "relu", drop = False)
    self.up4 = Block(features*8*2, features*8, down = False, act = "relu", drop = False)
    self.up5 = Block(features*8*2, features*4, down = False, act = "relu", drop = False)
    self.up6 = Block(features*4*2, features*2, down = False, act = "relu", drop = False)
    self.up7 = Block(features*2*2, features, down = False, act = "relu", drop = False)
    
    
    self.final_up = nn.Sequential(
        nn.ConvTranspose2d(features*2, in_chan, 4,2,1),
        nn.Tanh(),
    )
  
  def forward(self, x):
    d1 = self.initial_down(x)
    d2 = self.down1(d1)
    d3 = self.down2(d2)
    d4 = self.down3(d3)
    d5 = self.down4(d4)
    d6 = self.down5(d5)
    d7 = self.down6(d6)

    bottleneck = self.bottleneck(d7)

    up1 = self.up1(bottleneck)
    up2 = self.up2(torch.cat([up1, d7],dim = 1))
    up3 = self.up3(torch.cat([up2, d6],dim = 1))
    up4 = self.up4(torch.cat([up3, d5],dim = 1))
    up5 = self.up5(torch.cat([up4, d4],dim = 1))
    up6 = self.up6(torch.cat([up5, d3],dim = 1))
    up7 = self.up7(torch.cat([up6, d2],dim = 1))

    return self.final_up(torch.cat([up7, d1], dim = 1))

In [7]:
def test_gen():
  gen = Generator(3)
  x = torch.rand((1,3,256,256))
  return gen(x)

In [8]:
!cd /content/drive/MyDrive/Colab_Notebooks

In [None]:
!pip install albumentations --upgrade

In [11]:
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

device = torch.device("cuda")
lr = 2e-4
batch_size = 16
image_size = 256
chan_img = 3
l1_lambda = 100

num_epochs = 500
both_trans = A.Compose([A.Resize(width=256, height = 256)], additional_targets={"image0": "image"})
transform_only_input = A.Compose([
                                  A.ColorJitter(0.2),
                                  A.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5,0.5], max_pixel_value=255.0),
                                  ToTensorV2()
])

transform_only_mask = A.Compose([
                                 A.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5,0.5], max_pixel_value=255.0),
                                 ToTensorV2()
])

In [37]:
from PIL import Image
import numpy as np
import os
from torch.utils.data import Dataset

class Mapdataset(Dataset):
  def __init__(self, root_dir):
    self.root_dir = root_dir
    self.list_files = os.listdir(root_dir)
    print(self.list_files)

  def __len__(self):
    return len(self.list_files)
  
  def __getitem__(self, index):
    img_file = self.list_files[index]
    img_path = os.path.join(self.root_dir, img_file)
    image = np.array(Image.open(img_path))
    input_image = image[:, :600, :]
    target_image = image[:, 600:, :]
    
    aug = both_trans(image = input_image, image0 = target_image)
    input_image, target_image = aug['image'], aug['image0']
    input_image = transform_only_input(image = input_image)['image']
    target_image = transform_only_mask(image = target_image)['image']

    return input_image, target_image

In [38]:
import torch.optim as optim
from tqdm import tqdm
from torchvision.utils import save_image
from torch.utils.data import DataLoader

def save_some_examples(gen, val_loader, epoch, folder):
    x, y = next(iter(val_loader))
    x, y = x.to(device), y.to(device)
    gen.eval()
    with torch.no_grad():
        y_fake = gen(x)
        y_fake = y_fake * 0.5 + 0.5
        save_image(y_fake, folder + f"/y_gen_{epoch}.png")
        save_image(x * 0.5 + 0.5, folder + f"/input_{epoch}.png")
        if epoch == 1:
            save_image(y * 0.5 + 0.5, folder + f"/label_{epoch}.png")
    gen.train()


In [47]:
def train_fn(disc, gen, loader,opt_disc, opt_gen, l1,bce, g_scaler, d_scaler):
  loop = tqdm(loader, leave=True)

  for idx, (x,y) in enumerate(loop):
    x,y = x.to(device), y.to(device)

    with torch.cuda.amp.autocast():
      y_fake = gen(x)
      d_real = disc(x,y)
      d_fake = disc(x,y_fake.detach())
      d_real_loss = bce(d_real, torch.ones_like(d_real))
      d_fake_loss = bce(d_fake, torch.zeros_like(d_fake))
      d_loss = (d_real_loss + d_fake_loss)/2
  
    disc.zero_grad()
    d_scaler.scale(d_loss).backward()
    d_scaler.step(opt_disc)
    d_scaler.update()

    with torch.cuda.amp.autocast():
      d_fake = disc(x, y_fake)
      g_fake_loss = bce(d_fake, torch.ones_like(d_fake))
      L1 = l1(y_fake, y) * l1_lambda
      g_loss = g_fake_loss + L1

    opt_gen.zero_grad()
    g_scaler.scale(g_loss).backward()
    g_scaler.step(opt_gen)
    g_scaler.update()


def main():
  disc = Discriminator(3).to(device)
  gen = Generator(3).to(device)
  opt_disc = optim.Adam(disc.parameters(), lr = lr, betas = (0.5, 0.999))
  opt_gen = optim.Adam(gen.parameters(), lr = lr, betas = (0.5, 0.999))
  loss = nn.BCEWithLogitsLoss()
  l1_loss = nn.L1Loss()

  train_ds = Mapdataset(root_dir="/content/drive/MyDrive/Colab_Notebooks/maps/train")
  train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle =True, num_workers=0)
  g_scaler = torch.cuda.amp.GradScaler()
  d_scaler= torch.cuda.amp.GradScaler()
  val_ds = Mapdataset(root_dir="/content/drive/MyDrive/Colab_Notebooks/maps/val")
  val_loader = DataLoader(val_ds, shuffle=False, batch_size = 1)

  for epoch in range(num_epochs):
    train_fn(disc, gen, train_loader, opt_disc, opt_gen, l1_loss, loss, g_scaler, d_scaler)
    save_some_examples(gen, val_loader, epoch, folder = "evaluation")

In [49]:
main()






  0%|          | 0/69 [00:00<?, ?it/s][A[A[A[A[A

['1016.jpg', '1009.jpg', '1013.jpg', '101.jpg', '1012.jpg', '1011.jpg', '1014.jpg', '1010.jpg', '1015.jpg', '1008.jpg', '1007.jpg', '1004.jpg', '1003.jpg', '1002.jpg', '1006.jpg', '1005.jpg', '1001.jpg', '100.jpg', '1000.jpg', '10.jpg', '1.jpg', '1043.jpg', '1041.jpg', '1040.jpg', '1042.jpg', '1037.jpg', '1039.jpg', '1038.jpg', '104.jpg', '1036.jpg', '1035.jpg', '1034.jpg', '1032.jpg', '1033.jpg', '1031.jpg', '1030.jpg', '1028.jpg', '1029.jpg', '1027.jpg', '103.jpg', '1025.jpg', '1026.jpg', '102.jpg', '1022.jpg', '1023.jpg', '1020.jpg', '1021.jpg', '1024.jpg', '1019.jpg', '1018.jpg', '1017.jpg', '108.jpg', '1079.jpg', '1072.jpg', '1074.jpg', '1077.jpg', '1075.jpg', '1073.jpg', '1078.jpg', '1076.jpg', '1071.jpg', '1070.jpg', '107.jpg', '1069.jpg', '1067.jpg', '1068.jpg', '1066.jpg', '1063.jpg', '1065.jpg', '1064.jpg', '1062.jpg', '1061.jpg', '1060.jpg', '1055.jpg', '1057.jpg', '1054.jpg', '106.jpg', '1059.jpg', '1058.jpg', '1056.jpg', '1053.jpg', '1052.jpg', '1051.jpg', '1050.jpg', '104

[1;30;43mВыходные данные были обрезаны до нескольких последних строк (5000).[0m




 90%|████████▉ | 62/69 [00:29<00:03,  2.14it/s][A[A[A[A[A




 91%|█████████▏| 63/69 [00:29<00:02,  2.14it/s][A[A[A[A[A




 93%|█████████▎| 64/69 [00:30<00:02,  2.15it/s][A[A[A[A[A




 94%|█████████▍| 65/69 [00:30<00:01,  2.17it/s][A[A[A[A[A




 96%|█████████▌| 66/69 [00:30<00:01,  2.18it/s][A[A[A[A[A




 97%|█████████▋| 67/69 [00:31<00:00,  2.16it/s][A[A[A[A[A




 99%|█████████▊| 68/69 [00:31<00:00,  2.16it/s][A[A[A[A[A




100%|██████████| 69/69 [00:32<00:00,  2.14it/s]





  0%|          | 0/69 [00:00<?, ?it/s][A[A[A[A[A




  1%|▏         | 1/69 [00:00<00:31,  2.15it/s][A[A[A[A[A




  3%|▎         | 2/69 [00:00<00:31,  2.15it/s][A[A[A[A[A




  4%|▍         | 3/69 [00:01<00:30,  2.15it/s][A[A[A[A[A




  6%|▌         | 4/69 [00:01<00:30,  2.14it/s][A[A[A[A[A




  7%|▋         | 5/69 [00:02<00:30,  2.13it/s][A[A[A[A[A




  9%

KeyboardInterrupt: ignored