In [0]:
import glob
import math
import os
import random

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms

from google.colab import files
# Ask for the dataset archive only when it is not already in the working
# directory, so Restart & Run All does not stop at the upload dialog again.
# Using a statement (not a bare last expression) also keeps the returned
# {filename: bytes} dict from being echoed into the cell output.
if not os.path.exists('SPAI_lens_dataset.zip'):
  files.upload()

In [0]:
! unzip SPAI_lens_dataset.zip

In [0]:
# Peek at one raw class-A training image before any augmentation.
# smpl_img is reused by the transform preview cell below.
smpl_img = Image.open("SPAI_lens_dataset/train/A/33.png")
fig_sample, ax_sample = plt.subplots()
ax_sample.imshow(smpl_img)
ax_sample.set_title("Raw training image: train/A/33.png")
ax_sample.axis("off");

In [0]:
# Tensor <-> PIL converters, reused for previews.
TT = transforms.ToTensor()
TPIL = transforms.ToPILImage()
# Augmentation: rotate by a random angle drawn from [0, 360] degrees.
RR = transforms.RandomRotation(degrees=(0, 360))
# Pipeline handed to both ImageDatasets: rotate the PIL image, then tensorize.
myTransforms = transforms.Compose([RR, TT])
# NOTE(review): for 8-bit images ToTensor yields values in [0, 1], while the
# Generator below ends in Tanh ([-1, 1]). If real and generated batches are fed
# to the same Discriminator, consider rescaling real images to [-1, 1] --
# confirm against the training loop.

# Sanity check: run the sample image through the pipeline and view it as PIL.
img_R = myTransforms(smpl_img)
plt.imshow(TPIL(img_R))

In [0]:
class ImageDataset(torch.utils.data.Dataset):
  """Unpaired two-folder image dataset: each item is a random A image plus a random B image.

  ``root`` must contain sub-directories ``A`` and ``B``. Every file matching
  ``*.*`` in each folder is a candidate. ``__getitem__`` ignores its index and
  draws one file from each folder independently (with replacement) using
  Python's ``random`` module, so seed ``random`` for reproducible draws.

  Args:
    root: directory holding ``A`` and ``B``. It is joined with ``os.path.join``,
      so a trailing separator is optional ('.../train/' and '.../train' both work).
    transforms: callable applied to each opened PIL image. If None, a loaded
      copy of the PIL image itself is returned.
    epoch_len: number of items reported by ``__len__`` (default 150).

  Raises:
    FileNotFoundError: if ``A`` or ``B`` has no matching files. This usually
      means a wrong ``root`` or a failed unzip, which would otherwise surface
      only later as an opaque ``random`` error.
  """

  def __init__(self, root, transforms=None, epoch_len=150):
    self.transforms = transforms
    self.epoch_len = epoch_len
    # sorted() makes the file order (and hence seeded draws) independent of
    # the filesystem's directory-listing order.
    self.files_A = sorted(glob.glob(os.path.join(root, 'A', '*.*')))
    self.files_B = sorted(glob.glob(os.path.join(root, 'B', '*.*')))
    for name, found in (('A', self.files_A), ('B', self.files_B)):
      if not found:
        raise FileNotFoundError(
            f"no files matching '*.*' in {os.path.join(root, name)!r}")

  def _load(self, path):
    # Open the file in a context manager so the handle is released promptly.
    # PIL decodes lazily, so pixels are forced into memory with load() before
    # the file is closed.
    with Image.open(path) as img:
      img.load()
      return img.copy() if self.transforms is None else self.transforms(img)

  def __getitem__(self, index):
    # `index` is deliberately unused: each call makes a fresh random A/B draw.
    tr_A = self._load(random.choice(self.files_A))  # random lens image (class A)
    tr_B = self._load(random.choice(self.files_B))  # random class-B image
    # 0-d float32 label tensors. default_collate turns plain Python floats into
    # float64 batches, which would not match the float32 Sigmoid output of
    # Discriminator; losses such as nn.BCELoss expect matching dtypes.
    return {'A': tr_A, 'B': tr_B,
            'A_label': torch.tensor(1.0), 'B_label': torch.tensor(0.0)}

  def __len__(self):
    # Nominal epoch size: items are random draws, not a fixed enumeration.
    return self.epoch_len

In [0]:
# Prepare the data: one DataLoader per split, all built the same way.

data_path = 'SPAI_lens_dataset/'
BATCH_SIZE = 32  # samples per batch, shared by every split


def make_dataloader(split, batch_size=BATCH_SIZE, shuffle=True):
  """Build a DataLoader for one dataset split.

  Args:
    split: sub-directory of ``data_path`` holding the ``A``/``B`` folders,
      e.g. 'train/'.
    batch_size: items per batch.
    shuffle: forwarded to DataLoader. ImageDataset ignores the index it is
      given, so this does not change which images are drawn.

  Returns:
    A DataLoader over ``ImageDataset(data_path + split)`` that uses the
    ``myTransforms`` augmentation pipeline.
  """
  dataset = ImageDataset(data_path + split, transforms=myTransforms)
  return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


train_dataloader = make_dataloader('train/')
test_dataloader = make_dataloader('test/')

In [0]:
# Sanity-check the loader: batches per epoch, then the tensor shape and the
# label values for the A half and the B half of one batch.
print('Number of Batches:', len(train_dataloader))

smpl_batch = next(iter(train_dataloader))

for shape_msg, img_key, label_key in (('A Data shape: ', 'A', 'A_label'),
                                      ('B Data shape: ', 'B', 'B_label')):
  print(shape_msg, smpl_batch[img_key].shape)
  print(smpl_batch[label_key])

In [0]:
class Discriminator(torch.nn.Module):
  """DCGAN-style discriminator: a batch of images -> one sigmoid score per image.

  Args:
    nc: number of input image channels.
    nfm: base feature-map width. The four hidden conv stages use nfm, 2*nfm,
      4*nfm and 8*nfm channels.

  Sized for nc x 64 x 64 inputs. Each stride-2 stage halves H and W
  (64 -> 32 -> 16 -> 8 -> 4), and the final 4x4 unpadded conv reduces 4 x 4 to
  1 x 1, so forward() returns shape (B, 1, 1, 1). Other input sizes give a
  different (or invalid) output map.

  NOTE(review): the DCGAN guidelines use LeakyReLU(0.2) in the discriminator
  and no BatchNorm on its first layer; this model keeps plain ReLU and
  BatchNorm on every hidden stage.
  """

  def __init__(self, nc, nfm):
    super().__init__()
    widths = [nc, nfm, nfm * 2, nfm * 4, nfm * 8]
    stages = []
    for in_ch, out_ch in zip(widths, widths[1:]):
      # 4x4 kernel, stride 2, padding 1: halves the spatial size.
      stages.extend([nn.Conv2d(in_ch, out_ch, 4, 2, 1),
                     nn.BatchNorm2d(out_ch),
                     nn.ReLU()])
    # Collapse the remaining 4 x 4 map into a single score, squashed by Sigmoid.
    stages.extend([nn.Conv2d(widths[-1], 1, 4, 1, 0), nn.Sigmoid()])
    self.disc = nn.Sequential(*stages)

  def forward(self, inputs):
    """Return the sigmoid scores for a batch of images."""
    return self.disc(inputs)
  
class Generator(torch.nn.Module):
  """DCGAN-style generator: latent code -> nc x 64 x 64 image squashed by Tanh.

  Args:
    nc: number of output image channels.
    nfm: base feature-map width. The hidden stages use 8*nfm, 4*nfm, 2*nfm
      and nfm channels.
    nz: number of latent channels.

  Expects latent inputs shaped (B, nz, 1, 1). The first transposed conv
  (4x4, stride 1, no padding) maps 1 x 1 to 4 x 4. Each following stride-2
  stage doubles H and W (4 -> 8 -> 16 -> 32 -> 64), and Tanh bounds the
  output to [-1, 1].
  """

  def __init__(self, nc, nfm, nz):
    super().__init__()
    widths = [nfm * 8, nfm * 4, nfm * 2, nfm]
    # Project the 1 x 1 latent to a (8*nfm) x 4 x 4 feature map.
    stages = [nn.ConvTranspose2d(nz, widths[0], 4, 1, 0),
              nn.BatchNorm2d(widths[0]),
              nn.ReLU()]
    for in_ch, out_ch in zip(widths, widths[1:]):
      # 4x4 kernel, stride 2, padding 1: doubles the spatial size.
      stages += [nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1),
                 nn.BatchNorm2d(out_ch),
                 nn.ReLU()]
    # Final upsample to the image resolution and channel count.
    stages += [nn.ConvTranspose2d(widths[-1], nc, 4, 2, 1), nn.Tanh()]
    self.gen = nn.Sequential(*stages)

  def forward(self, inputs):
    """Return generated images for a batch of latent codes."""
    return self.gen(inputs)