<a href="https://colab.research.google.com/github/vasudhavenkatesan/pix2pix/blob/main/pix2pix.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Importing necessary libraries

In [14]:
import numpy as np
import os
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
from torchvision.utils import make_grid
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from torchvision.transforms import InterpolationMode

# Loading data

In [15]:
class MapDataset(Dataset):
    """Paired-image dataset for pix2pix-style training on side-by-side images.

    Each file in ``root_dir`` holds two images placed side by side: the left
    half is returned as the input image and the right half as the target
    image. Both halves go through the same pipeline: a nearest-neighbour
    resize to ``image_size``, ``ToTensor`` (PIL RGB -> float tensor in
    ``[0, 1]``, channels first), then per-channel normalization with mean 0.5
    and std 0.5, which maps values to ``[-1, 1]``.

    Args:
        root_dir: Directory containing the paired image files.
        image_size: ``(height, width)`` that each half is resized to.
            Defaults to ``(256, 512)``, the size used by the original pipeline.
    """

    def __init__(self, root_dir, image_size=(256, 512)):
        self.root_dir = root_dir
        # Sort so the index -> file mapping is deterministic (os.listdir order
        # is arbitrary). Skip hidden entries and sub-directories such as
        # Colab's ".ipynb_checkpoints", which cannot be opened as images.
        self.list_files = sorted(
            name
            for name in os.listdir(self.root_dir)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(self.root_dir, name))
        )
        # NOTE(review): if each half is square (e.g. 600x600), a (256, 512)
        # target stretches it horizontally; pix2pix usually trains on 256x256.
        # The default is kept for compatibility -- confirm the intent.
        self.data_transforms = transforms.Compose([
            # Resize accepts PIL images or tensors (not NumPy arrays), which is
            # why __getitem__ crops the halves with PIL.
            transforms.Resize(size=image_size, interpolation=InterpolationMode.NEAREST),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        return len(self.list_files)

    @staticmethod
    def _split_boxes(width, height):
        """Return PIL crop boxes ``(left, upper, right, lower)`` for the left and right halves."""
        mid = width // 2
        return (0, 0, mid, height), (mid, 0, width, height)

    def __getitem__(self, index):
        """Return ``(input_image, target_image)`` tensors for the file at ``index``."""
        img_path = os.path.join(self.root_dir, self.list_files[index])
        # The context manager closes the underlying file. convert("RGB")
        # yields a fully loaded 3-channel copy: it drops alpha and expands
        # grayscale so the 3-channel Normalize always applies.
        with Image.open(img_path) as raw:
            image = raw.convert("RGB")
        input_box, target_box = self._split_boxes(*image.size)
        input_image = self.data_transforms(image.crop(input_box))
        target_image = self.data_transforms(image.crop(target_box))
        return input_image, target_image
    


In [None]:
# Build train/val datasets and loaders. MapDataset is defined in the cell above.
# Assumes Google Drive is mounted at /content/drive (Colab) -- adjust data_dir otherwise.
data_dir = "/content/drive/MyDrive/maps_dataset/maps/"
BATCH_SIZE = 24
NUM_WORKERS = 2

dataset_train = MapDataset(root_dir=os.path.join(data_dir, "train"))
dataset_val = MapDataset(root_dir=os.path.join(data_dir, "val"))

# Shuffle only the training data; a fixed validation order keeps evaluation
# metrics and sample grids reproducible from run to run.
dataloader_train = DataLoader(dataset=dataset_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
dataloader_val = DataLoader(dataset=dataset_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

print(len(dataset_val))