<a href="https://colab.research.google.com/github/vasudhavenkatesan/pix2pix/blob/main/pix2pix.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Importing necessary libraries

In [14]:
import numpy as np
import os
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
from torchvision.utils import make_grid
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from torchvision.transforms import InterpolationMode

## Loading data

In [15]:
class MapDataset(Dataset):
  """Paired-image dataset built from side-by-side image files.

  Every file in ``root_dir`` is treated as one picture holding two images
  next to each other: the columns left of ``split_column`` form the input
  image and the remaining columns form the target image.

  Args:
    root_dir: Directory whose entries are the combined image files.
    split_column: Pixel column at which each picture is cut into its
      (input, target) halves. Defaults to 600.
  """

  def __init__(self, root_dir, split_column=600):
    self.root_dir = root_dir
    self.split_column = split_column
    self.list_files = os.listdir(self.root_dir)
    print(self.list_files)
    # Resize/CenterCrop fix the spatial size at 256x512; ToTensor yields
    # values in [0, 1] and Normalize(0.5, 0.5) rescales them to [-1, 1].
    self.data_transforms = transforms.Compose([
        transforms.Resize(size=(256, 512), interpolation=InterpolationMode.NEAREST),
        transforms.CenterCrop(size=(256, 512)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

  def __len__(self):
    """Return the number of files found in ``root_dir``."""
    return len(self.list_files)

  def __getitem__(self, index):
    """Load one file and return its ``(input_image, target_image)`` tensors.

    Args:
      index: Position of the file in ``self.list_files``.

    Returns:
      A tuple of two transformed tensors, each of shape (3, 256, 512).

    Raises:
      ValueError: If the picture is not wider than ``split_column``, which
        would leave an empty target half.
    """
    img_file = self.list_files[index]
    img_path = os.path.join(self.root_dir, img_file)
    # Stay in PIL: Resize/CenterCrop accept PIL Images or tensors, not NumPy
    # arrays. convert("RGB") guarantees the three channels Normalize expects
    # and returns a fully loaded copy, so the file can be closed right away.
    with Image.open(img_path) as raw_image:
      image = raw_image.convert("RGB")
    width, height = image.size
    if width <= self.split_column:
      raise ValueError(
          f"{img_path} is {width}px wide; expected more than "
          f"{self.split_column}px to split into input and target halves")
    # PIL crop boxes are (left, upper, right, lower).
    input_image = image.crop((0, 0, self.split_column, height))
    target_image = image.crop((self.split_column, 0, width, height))
    input_image = self.data_transforms(input_image)
    target_image = self.data_transforms(target_image)
    return input_image, target_image
    


In [20]:
from __main__ import MapDataset

# Root folder of the maps dataset on the mounted Google Drive.
data_dir = "/content/drive/MyDrive/maps_dataset/maps/"

# Build the training split first, then the validation split; each
# MapDataset constructor prints the file listing of its directory.
dataset_train = MapDataset(root_dir=os.path.join(data_dir, "train"))
dataset_val = MapDataset(root_dir=os.path.join(data_dir, "val"))

# Both loaders use the same batching configuration.
loader_options = dict(batch_size=24, shuffle=True, num_workers=2)
dataloader_train = DataLoader(dataset=dataset_train, **loader_options)
dataloader_val = DataLoader(dataset=dataset_val, **loader_options)

# Report how many validation files were found.
print(len(dataset_val))

['1009.jpg', '1003.jpg', '1007.jpg', '1002.jpg', '1008.jpg', '1005.jpg', '1006.jpg', '1004.jpg', '1001.jpg', '1000.jpg', '100.jpg', '1.jpg', '10.jpg', '1055.jpg', '1057.jpg', '1058.jpg', '1056.jpg', '1054.jpg', '105.jpg', '1051.jpg', '1052.jpg', '1046.jpg', '1050.jpg', '1049.jpg', '1048.jpg', '1047.jpg', '1053.jpg', '1045.jpg', '1043.jpg', '104.jpg', '1037.jpg', '1042.jpg', '1040.jpg', '1044.jpg', '1038.jpg', '1041.jpg', '1039.jpg', '1036.jpg', '1029.jpg', '1034.jpg', '1030.jpg', '1031.jpg', '1032.jpg', '103.jpg', '1035.jpg', '1028.jpg', '1033.jpg', '1027.jpg', '1021.jpg', '1020.jpg', '1019.jpg', '1024.jpg', '1022.jpg', '1026.jpg', '102.jpg', '1025.jpg', '1023.jpg', '1018.jpg', '101.jpg', '1016.jpg', '1010.jpg', '1017.jpg', '1015.jpg', '1014.jpg', '1013.jpg', '1011.jpg', '1012.jpg', '112.jpg', '1088.jpg', '1085.jpg', '109.jpg', '1086.jpg', '1089.jpg', '1082.jpg', '1095.jpg', '1092.jpg', '1094.jpg', '110.jpg', '1093.jpg', '111.jpg', '1096.jpg', '11.jpg', '1091.jpg', '1090.jpg', '1084.jp