In [None]:
import torch
import torchvision
import torchvision.transforms.functional as f
import torch.nn as nn
import numpy as np
import glob
from tqdm import tqdm
from PIL import Image

from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Side length (pixels) every image is resized to before being flattened.
IM_DIM = 112
# Width of the linear code between encoder and decoder.
N_FACES = 1024
# Number of passes over the dataloader per run of the training cell.
N_EPOCHS = 100
# Mini-batch size handed to the DataLoader.
BATCH_SIZE = 2048
# NOTE(review): not referenced by any cell visible here; the networks hard-code 3 channels.
N_COLORS = 1

# Prefer GPU 3 (the original hard-coded choice), but degrade gracefully so the
# notebook still runs on machines with fewer GPUs or no GPU at all.
_PREFERRED_GPU = 3
if torch.cuda.is_available() and torch.cuda.device_count() > _PREFERRED_GPU:
    DEVICE = torch.device(f"cuda:{_PREFERRED_GPU}")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

In [None]:
class SymmetrifyView(nn.Module):
    """Reshape the input and average it with its mirror image.

    Args:
        shape: Target shape excluding the leading batch dimension. The flip is
            taken along dimension 3 of the reshaped tensor, so ``shape`` needs
            at least three entries (for ``(C, H, W)`` that is the width axis).
    """

    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        """Return ``(v + flip(v, dim=3)) / 2`` where ``v = x.view(-1, *shape)``.

        The input tensor is left untouched: the sum is computed out of place,
        because ``v`` shares storage with ``x`` and an in-place add would
        overwrite the caller's data (and fail for leaf tensors that require
        grad).
        """
        reshaped = x.view((-1, *self.shape))
        mirrored = torch.flip(reshaped, [3])
        return (reshaped + mirrored) / 2
    
class View(nn.Module):
    """Module wrapper around ``Tensor.view`` with an inferred leading dimension.

    Args:
        shape: Target shape excluding the leading (inferred, ``-1``) dimension.
    """

    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        """Return ``x`` viewed as ``(-1, *shape)``."""
        target_shape = (-1, *self.shape)
        return x.view(target_shape)

class Normalize(nn.Module):
    """Standardize the input along its last dimension."""

    def forward(self, x):
        """Subtract the last-dim mean and divide by the last-dim std plus 1e-6."""
        centered = x - x.mean(dim=-1, keepdim=True)
        # The small constant keeps the division finite when the std is zero.
        scale = x.std(dim=-1, keepdim=True) + 1e-6
        return centered / scale
    



# Number of values in one flattened 3-channel IM_DIM x IM_DIM image.
_flat_pixels = 3 * IM_DIM * IM_DIM

# Encoder: resize to IM_DIM x IM_DIM, flatten, project linearly (no bias) to an
# N_FACES-wide code, then standardize that code along its last dimension.
_encoder_layers = [
    nn.Upsample(size=(IM_DIM, IM_DIM)),
    nn.Flatten(),
    nn.Linear(_flat_pixels, N_FACES, bias=False),
    Normalize(),
]
encoder = nn.Sequential(*_encoder_layers)

# Decoder: project the code linearly (no bias) back to a flat image vector,
# flatten, and standardize the result along its last dimension.
_decoder_layers = [
    nn.Linear(N_FACES, _flat_pixels, bias=False),
    nn.Flatten(),
    Normalize(),
]
decoder = nn.Sequential(*_decoder_layers)

In [None]:
# glob returns paths in arbitrary, filesystem-dependent order; sorting makes
# index-based lookups such as dataset[0] reproducible across runs and machines.
im_list = sorted(glob.glob("../celeba_aligned_with_mtcnn/*.jpg"))

class Dataset(torch.utils.data.Dataset):
    """Image-file dataset yielding a (grayscale, RGB) tensor pair per image.

    Args:
        im_list: Sequence of image file paths.
    """

    def __init__(self, im_list):
        super().__init__()
        self.im_list = im_list

    def __len__(self):
        """Number of image paths in the dataset."""
        return len(self.im_list)

    @staticmethod
    def _to_centered_chw(image):
        """Convert an H x W x 3 PIL image to a float CHW tensor in [-0.5, 0.5]."""
        pixels = torch.from_numpy(np.array(image, dtype=np.float32))
        return (pixels / 255 - 0.5).permute(2, 0, 1)

    def __getitem__(self, i):
        """Load image ``i`` and return ``(x_gray, x_rgb)``.

        Both tensors are float, channel-first, with 3 channels and values
        shifted to [-0.5, 0.5]; ``x_gray`` is the grayscale version replicated
        over 3 channels.
        """
        # Open the file once and close it deterministically; convert() loads
        # the pixel data and guarantees 3 channels even for non-RGB files.
        with Image.open(self.im_list[i]) as handle:
            rgb = handle.convert("RGB")
        gray = f.to_grayscale(rgb, 3)
        return self._to_centered_chw(gray), self._to_centered_chw(rgb)
    
# Build the dataset and report its size plus the value range of the first
# grayscale sample as a quick sanity check of the preprocessing.
dataset = Dataset(im_list)
print("# images:", len(dataset))
first_gray = dataset[0][0]
print(first_gray.min())
print(first_gray.max())

# Shuffled mini-batches, loaded by 10 worker processes.
dataloader = torch.utils.data.DataLoader(
    dataset,
    shuffle=True,
    batch_size=BATCH_SIZE,
    num_workers=10,
)

In [None]:
# Preview the grayscale version of sample 10: channels-last layout for
# matplotlib, with the -0.5 shift applied by the dataset undone.
preview_hwc = dataset[10][0].permute(1, 2, 0).detach().cpu().numpy()
plt.imshow(preview_hwc + 0.5)

In [None]:
# Move both networks to the training device, then optimize their parameters
# jointly with a single AdamW instance.
for net in (encoder, decoder):
    net.to(DEVICE)
trainable = [*encoder.parameters(), *decoder.parameters()]
opt = torch.optim.AdamW(trainable, lr=1e-3, weight_decay=0.1)

In [None]:
# Training bookkeeping, kept in its own cell so the training cell can be
# re-run without resetting the loss history or the counters.
train_loss, test_loss = [], []
iters = epoch = 0

In [None]:
# Training loop. Counters (iters, epoch) and train_loss live in an earlier
# cell, so re-running this cell continues training instead of restarting.

# Loop-invariant helpers, built once instead of on every iteration.
to_image = View((3, IM_DIM, IM_DIM))
symmetrify = SymmetrifyView((3, IM_DIM, IM_DIM))
# `reduction` must be passed by keyword: the first positional parameter of
# nn.MSELoss is the deprecated `size_average` flag.
mse = nn.MSELoss(reduction="mean")

for _ in range(N_EPOCHS):
    for X_gray, X_rgb in dataloader:

        iters += 1
        X_gray = X_gray.to(DEVICE)
        X_rgb = X_rgb.to(DEVICE)

        # Target: half original, half left-right symmetrized RGB image.
        # The clone protects X_rgb in case the symmetrizer works in place.
        X_rgb_sym = symmetrify(X_rgb.clone())
        X_soft_sym = 0.5 * X_rgb + 0.5 * X_rgb_sym

        # Reconstruction of the target from the grayscale input.
        y = to_image(decoder(encoder(X_gray)))

        # Decode random codes as well; sampled directly on the training device
        # to avoid a host-to-device copy per step.
        noise = torch.randn(X_rgb.shape[0], N_FACES, device=DEVICE)
        y_noise = to_image(decoder(noise))

        # Reconstruction term plus a 10x-weighted term that pulls the images
        # decoded from random codes toward the real batch.
        loss = mse(X_soft_sym, y) + 10 * mse(y_noise, X_rgb)
        train_loss.append(loss.item())
        loss.backward()
        opt.step()
        opt.zero_grad()

        if iters % 50 == 0:
            # Every 50 steps: refresh the preview, plot the loss, checkpoint.
            clear_output(wait=True)
            plt.figure(dpi=130)
            plt.subplot(1, 2, 1)
            plt.title("Input")
            # Clamp to [0, 1] after undoing the -0.5 shift; imshow would clip
            # out-of-range float RGB values anyway, but with a warning.
            plt.imshow((0.5 + X_soft_sym[0]).clamp(0, 1).permute(1, 2, 0).detach().cpu().numpy())
            plt.axis('off')

            plt.subplot(1, 2, 2)
            plt.title("Output")
            plt.imshow((0.5 + y[0]).clamp(0, 1).permute(1, 2, 0).detach().cpu().numpy())
            plt.axis('off')
            plt.show()

            plt.semilogy(train_loss)
            plt.title("train loss")
            plt.grid()
            plt.show()
            print("epoch:", epoch)
            # Saves the whole pickled module (not a state_dict), so loading it
            # requires the Normalize class to be importable.
            torch.save(decoder, "eigenfaces_new.pt")
    epoch += 1