In [482]:
import matplotlib.pyplot as plt
import numpy as np
import random
import os
from tqdm import tqdm

import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader,random_split
from torch import nn
import time
import cv2

## 0 Load and Pre-process Dataset

In [483]:
# Load the face dataset under `eybpath`: one sub-directory per subject, one image file per sample.
# Every image is read as grayscale, resized to 32x32 and scaled to [0, 1].
eybpath = './CroppedYale'
# os.listdir returns entries in arbitrary order; sorting keeps the subject/image indices
# (and therefore every later split and figure) identical from run to run and machine to machine.
subjects = sorted(os.listdir(eybpath))
IMG = np.zeros((38, 64, 32, 32))
for ii, s in enumerate(subjects):  # 38 subjects
    subpath = os.path.join(eybpath, s)
    imgs = sorted(os.listdir(subpath))

    for kk, im in enumerate(imgs):  # 64 images in each subject file
        impath = os.path.join(subpath, im)
        img = cv2.imread(impath, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread reports an unreadable / non-image file by returning None, not by raising
            raise ValueError('Could not read image file: {}'.format(impath))
        img = cv2.resize(img, dsize=(32, 32))
        # Normalization
        img = img / 255
        IMG[ii][kk] = img
# IMG to tensor
IMG = torch.from_numpy(IMG) # torch.Size([38, 64, 32, 32])

In [484]:
# Obtain training set, validation set and testing set
# The per-subject image indices are split randomly; the same index split is applied to every subject.
num_images = IMG.shape[1]
num_train = 52
num_val = 6
num_test = 6    # Each subject has 52 training images, 6 validation images, and 6 testing images
if num_train + num_val + num_test != num_images:
    # The test set is "whatever is left", so the three sizes must add up exactly
    raise ValueError('Split sizes {}+{}+{} do not match the {} images per subject'.format(
        num_train, num_val, num_test, num_images))

# Dedicated seeded generator: the split must be the same on every run, otherwise a checkpoint
# saved in an earlier session would be evaluated on images it was trained on.
split_rng = random.Random(0)
list_sub = list(range(num_images))

# train data: torch.Size([38, 52, 32, 32])
m = split_rng.sample(list_sub, num_train)
train_data = IMG[:, m].float()
list_sub = [idx for idx in list_sub if idx not in set(m)]

# val data: torch.Size([38, 6, 32, 32])
k = split_rng.sample(list_sub, num_val)
val_data = IMG[:, k].float()
list_sub = [idx for idx in list_sub if idx not in set(k)]

# test data: torch.Size([38, 6, 32, 32]) -- the remaining indices, in ascending order
test_data = IMG[:, list_sub].float()

In [485]:
def split_to_batch(data, batch_size):
    """Merge the first two axes of a 4-D tensor and cut the result into consecutive mini-batches.

    Args:
        data: tensor indexed as data[a, b, h, w]; axes a and b are merged into one sample axis.
        batch_size: number of samples per batch; the final batch holds the remainder and may be smaller.

    Returns:
        List of tensors shaped (samples_in_batch, 1, h, w), in the original sample order.
    """
    dims = data.shape
    flat = data.reshape((dims[0] * dims[1], dims[2], dims[3]))
    n_chunks = int(np.ceil(flat.shape[0] / batch_size))
    # Slicing past the end is clipped, so the last (possibly shorter) chunk needs no special case.
    # unsqueeze(1) inserts the single channel axis.
    return [
        flat[c * batch_size:(c + 1) * batch_size, :, :].unsqueeze(1)
        for c in range(n_chunks)
    ]

In [486]:
# Split dataset into batches
batch_size = 32
train_batches = split_to_batch(train_data, batch_size)
val_batches = split_to_batch(val_data, batch_size)
test_batches = split_to_batch(test_data, batch_size)

## 1 Define Encoder and Decoder classes

In [487]:
class Encoder(nn.Module):
    """Convolutional encoder mapping a 1-channel 32x32 image to a latent vector.

    Three stride-2, unpadded 3x3 convolutions shrink the spatial size 32 -> 15 -> 7 -> 3,
    giving a 32x3x3 feature map that is flattened and passed through two linear layers.

    Args:
        encoded_space_dim: size of the latent vector produced by the encoder.
        fc2_input_dim: width of the hidden linear layer, i.e. the input size of the
            second fully-connected layer.
    """

    def __init__(self, encoded_space_dim, fc2_input_dim):
        super().__init__()

        ### Convolutional section
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(1, 8, 3, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(8, 16, 3, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=0),
            nn.ReLU(),
        )

        ### Flatten layer
        self.flatten = nn.Flatten(start_dim=1)

        ### Linear section
        # fc2_input_dim used to be ignored in favour of a hard-coded 128; it now sets the hidden width.
        self.encoder_lin = nn.Sequential(
            nn.Linear(3 * 3 * 32, fc2_input_dim),
            nn.ReLU(),
            nn.Linear(fc2_input_dim, encoded_space_dim)
        )

    def forward(self, x):
        """Encode a batch shaped (N, 1, 32, 32) into latent codes shaped (N, encoded_space_dim)."""
        # Apply convolutions
        x = self.encoder_cnn(x)
        # Flatten
        x = self.flatten(x)
        # Apply linear layers
        x = self.encoder_lin(x)
        return x

In [488]:
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector back to a 1-channel 32x32 image.

    Two linear layers expand the latent code to a 32x3x3 feature map; three stride-2
    transposed convolutions then grow the spatial size 3 -> 7 -> 15 -> 32.

    Args:
        encoded_space_dim: size of the latent vector fed to the decoder.
        fc2_input_dim: width of the hidden linear layer, i.e. the input size of the
            second fully-connected layer.
    """

    def __init__(self, encoded_space_dim, fc2_input_dim):
        super().__init__()

        ### Linear section
        # fc2_input_dim used to be ignored in favour of a hard-coded 128; it now sets the hidden width.
        self.decoder_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, fc2_input_dim),
            nn.ReLU(),
            nn.Linear(fc2_input_dim, 3 * 3 * 32),
            nn.ReLU()
        )

        ### Unflatten
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, 3, 3))

        ### Convolutional section
        self.decoder_conv = nn.Sequential(
            # First transposed convolution
            nn.ConvTranspose2d(32, 16, 3, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 8, 3, stride=2),
            nn.ReLU(),
            # Kernel size 4 (instead of 3) makes the final output 32x32 rather than 31x31
            nn.ConvTranspose2d(8, 1, 4, stride=2)
        )

    def forward(self, x):
        """Decode latent codes shaped (N, encoded_space_dim) into images shaped (N, 1, 32, 32)."""
        # Apply linear layers
        x = self.decoder_lin(x)
        # Unflatten
        x = self.unflatten(x)
        # Apply transposed convolutions
        x = self.decoder_conv(x)
        # Apply a sigmoid to force the output to be between 0 and 1 (valid pixel values)
        x = torch.sigmoid(x)
        return x

## 2 Initialize models, loss and optimizer

In [489]:
### Fix the PyTorch RNG so that the weight initialization below is repeatable
torch.manual_seed(0)

### Latent dimension shared by both networks
d = 32

### Build the encoder first, then the decoder (same construction order => same seeded weights)
encoder, decoder = [net(encoded_space_dim=d, fc2_input_dim=128) for net in (Encoder, Decoder)]

In [490]:
### Pick the compute device: CUDA when available, CPU otherwise
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Reconstruction criterion
loss_fn = torch.nn.MSELoss()

### A single Adam optimizer with two parameter groups (encoder and decoder)
lr = 0.001 # Learning rate
params_to_optimize = [{'params': net.parameters()} for net in (encoder, decoder)]
optim = torch.optim.Adam(params_to_optimize, lr=lr)

# Transfer both networks to the selected device.
# The trailing expression is what the notebook echoes as this cell's output.
encoder.to(device)
decoder.to(device)


Decoder(
  (decoder_lin): Sequential(
    (0): Linear(in_features=32, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=288, bias=True)
    (3): ReLU()
  )
  (unflatten): Unflatten(dim=1, unflattened_size=(32, 3, 3))
  (decoder_conv): Sequential(
    (0): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2))
    (1): ReLU()
    (2): ConvTranspose2d(16, 8, kernel_size=(3, 3), stride=(2, 2))
    (3): ReLU()
    (4): ConvTranspose2d(8, 1, kernel_size=(4, 4), stride=(2, 2))
  )
)

## 3 Train model

In [491]:
### Training function
def train_epoch_den(encoder, decoder, device, dataloader, optimizer, val_lambda0=1, val_lambda=1, val_lambdaprime=1):
    """Run one optimisation pass over `dataloader` and return the mean loss terms.

    For every batch x the function computes z = encoder(x), x_hat = decoder(z) and
    z_hat = encoder(x_hat), then minimises
        val_lambda0 * MSE(x, x_hat) + val_lambda * MSE(z, z_hat) + val_lambdaprime * MSE(z, 0).

    Args:
        encoder, decoder: modules being trained; both are switched to train mode.
        device: device each batch is moved to before the forward pass.
        dataloader: iterable yielding image batches only (no labels).
        optimizer: optimizer stepped once per batch.
        val_lambda0, val_lambda, val_lambdaprime: weights of the three loss terms.

    Returns:
        Tuple (mean MSE(x, x_hat), mean MSE(z, z_hat), mean MSE(z, 0), mean weighted total),
        each averaged over the batches.
    """
    encoder.train()
    decoder.train()

    mse = torch.nn.MSELoss()
    rec_hist, cyc_hist, reg_hist, total_hist = [], [], [], []

    for batch in dataloader:
        batch = batch.to(device)

        # Forward pass: x -> z -> x_hat -> z_hat
        z = encoder(batch)
        x_hat = decoder(z)
        z_hat = encoder(x_hat)

        # Individual terms and their weighted sum
        rec_term = mse(batch, x_hat)
        cyc_term = mse(z, z_hat)
        reg_term = mse(z, torch.zeros_like(z))
        total = val_lambda0 * rec_term + val_lambda * cyc_term + val_lambdaprime * reg_term

        # Gradient step
        optimizer.zero_grad()
        total.backward()
        optimizer.step()

        # Record this batch's values as detached NumPy scalars
        for hist, term in ((rec_hist, rec_term), (cyc_hist, cyc_term),
                           (reg_hist, reg_term), (total_hist, total)):
            hist.append(term.detach().cpu().numpy())

    return np.mean(rec_hist), np.mean(cyc_hist), np.mean(reg_hist), np.mean(total_hist)

In [492]:
### Testing function
def test_epoch_den(encoder, decoder, device, dataloader, loss_fn):
    # Set evaluation mode for encoder and decoder
    encoder.eval()
    decoder.eval()
    with torch.no_grad(): # No need to track the gradients
        # Define the lists to store the outputs for each batch
        conc_out = []
        conc_label = []
        # img_rec = []
        for image_batch in dataloader:
            # Move tensor to the proper device
            image_batch = image_batch.to(device)
            # Encode data
            encoded_data = encoder(image_batch)
            # Decode data
            decoded_data = decoder(encoded_data)
            # Append the network output and the original image to the lists
            conc_out.append(decoded_data.cpu())
            conc_label.append(image_batch.cpu())
            # img_rec.append(decoded_data.squeeze())  # (Batch_size, 32, 32)
        # Create a single tensor with all the values in the lists
        conc_out = torch.cat(conc_out)
        conc_label = torch.cat(conc_label)
        # img_rec = torch.cat(img_rec)
        # Evaluate global loss
        val_loss = loss_fn(conc_out, conc_label)
        val_loss = val_loss.numpy()
        # img_rec = img_rec.cpu().numpy()

    return val_loss

In [None]:
### Training cycle
# Trains the autoencoder for `num_epochs` epochs, records the per-epoch losses and
# finally pickles [encoder, decoder] under ./model with a timestamped file name.
num_epochs = 30
val_lambda0 = 1
val_lambda = 1
val_lambdaprime = 1
history_da={'train_loss':[],'val_loss':[]}
loss1 = []   # per-epoch mean MSE(x, x_hat)
loss2 = []   # per-epoch mean MSE(z, z_hat)
loss3 = []   # per-epoch mean MSE(z, 0)
loss = []    # per-epoch mean weighted total
v_loss = []  # per-epoch validation reconstruction loss

# Create the output directory up front so that a missing folder cannot make
# torch.save fail after the whole training run has already been paid for.
model_dir = './model'
os.makedirs(model_dir, exist_ok=True)

for epoch in range(num_epochs):
    # print('EPOCH %d/%d' % (epoch + 1, num_epochs))
    ### Training (use the training function)
    train_loss1, train_loss2, train_loss3, train_loss=train_epoch_den(
        encoder=encoder,
        decoder=decoder,
        device=device,
        dataloader=train_batches,
        optimizer=optim,
        val_lambda0=val_lambda0,
        val_lambda=val_lambda,
        val_lambdaprime=val_lambdaprime)

    loss.append(train_loss)
    loss1.append(train_loss1)
    loss2.append(train_loss2)
    loss3.append(train_loss3)

    ### Validation  (use the testing function)
    val_loss = test_epoch_den(
        encoder=encoder,
        decoder=decoder,
        device=device,
        dataloader=val_batches,
        loss_fn=loss_fn)
    v_loss.append(val_loss)
    history_da['train_loss'].append(train_loss)
    history_da['val_loss'].append(val_loss)
    # Print loss
    print('EPOCH {}/{} \t train loss {:.3f} \t val loss {:.3f}'.format(epoch + 1, num_epochs, train_loss, val_loss))

# Use num_epochs rather than the loop variable: `epoch` is undefined when num_epochs == 0.
prefix = model_dir + '/CAE_dim{}'.format(d)+ '_lmd_{}_{}_{}'.format(val_lambda0, val_lambda, val_lambdaprime) + '_epc{}_'.format(num_epochs)
# Only the timestamp goes through strftime, so the prefix can never be misread as a format string.
name = prefix + time.strftime('%m%d_%H_%M_%S') + '.pth'
# The whole modules are pickled as a list [encoder, decoder]; later cells unpack the loaded object the same way.
torch.save([encoder, decoder], name)

## 4 Results

### 4.1 Training Loss and Reconstructed Images

In [None]:
# Figure 1: total training loss (dashed, thicker) together with the validation loss
# and the three individual training terms.
fig_all, ax_all = plt.subplots()
ax_all.plot(loss, '--', linewidth=2.5)
for curve in (v_loss, loss1, loss2, loss3):
    ax_all.plot(curve)
ax_all.set_xlabel('Epoch')
ax_all.set_ylabel('Loss')
ax_all.legend(labels=['loss', 'val_loss', 'x-x_hat', 'z-z_hat', 'z'])
#plt.savefig('./figs/Loss_dim{}_lmd_{}_{}_{}.png'.format(d, val_lambda0, val_lambda, val_lambdaprime))
plt.show()

# Figure 2: only the two latent-space terms.
fig_z, ax_z = plt.subplots()
for curve in (loss2, loss3):
    ax_z.plot(curve)
ax_z.set_xlabel('Epoch')
ax_z.set_ylabel('Loss')
ax_z.legend(labels=['z-z_hat', 'z'])
#plt.savefig('./figs/Loss_z_dim{}_lmd_{}_{}_{}.png'.format(d, val_lambda0, val_lambda, val_lambdaprime))
plt.show()

In [None]:
# Reload the trained networks and show 5 test images (top row) next to their reconstructions (bottom row).
# map_location places the loaded weights on `device`, the same device the inputs are moved to below,
# so a checkpoint saved on a CUDA machine also loads on a CPU-only one.
# NOTE: the file holds whole pickled modules and unpickling can execute code -- only load checkpoints you created.
# NOTE(review): newer PyTorch releases may default torch.load to weights_only=True, which rejects pickled
# modules; pass weights_only=False in that case -- confirm against the installed version.
# NOTE(review): the training cell saves under a timestamped name (`name`); update this path after retraining.
encoder, decoder = torch.load('./model/CAE_dim32_lmd_1_1_1_epc30_0925_23_37_15.pth', map_location=device)
# Evaluation mode only needs to be set once, not on every loop iteration
encoder.eval()
decoder.eval()

plt.figure(figsize=(10,4.5))
n = 5
for i in range(n):
    # Sixth test image (index 5) of subject i+7, with batch and channel axes added
    img = test_data[i+7][5].unsqueeze(0).unsqueeze(0).to(device)
    with torch.no_grad():
        rec_img = decoder(encoder(img))

    # Top row: original
    ax = plt.subplot(2,n,i+1)
    plt.imshow(img.cpu().squeeze().numpy(), cmap='gist_gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    if i == n//2:
        ax.set_title('Original images')

    # Bottom row: reconstruction
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(rec_img.cpu().squeeze().numpy(), cmap='gist_gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    if i == n//2:
        ax.set_title('Reconstructed images')
#plt.savefig('./figs/Rec_imgs_dim{}_lmd_{}_{}_{}_a.png'.format(d, val_lambda0, val_lambda, val_lambdaprime))
plt.show()


### 4.2 SNR_x vs. SNR_z with noisy z (encoded data)

In [498]:
def add_noise(inputs, noise_factor=0.3):
    """Return `inputs` corrupted by Gaussian noise and limited to the range [0, 1].

    Args:
        inputs: tensor to corrupt.
        noise_factor: multiplier applied to standard-normal noise of the same shape as `inputs`.

    Returns:
        float32 tensor: inputs + noise, with every value clamped to [0, 1].
    """
    perturbed = inputs + torch.randn_like(inputs) * noise_factor
    # Values outside [0, 1] are clamped, so this helper is only suitable for data living in that range
    return torch.clamp(perturbed, 0., 1.).to(torch.float32)

In [499]:
### SNR_x vs. SNR_z
# For each target latent SNR (in dB), add white Gaussian noise of matching power to the encoded
# test set, decode it, and measure the SNR of the reconstruction against the original images.
#
# Fixes relative to the earlier version of this cell:
#   * the dB value is converted to a linear power ratio with 10**(snr/10)
#     (pow(10, log10(snr/10)) was an identity and did no conversion);
#   * the noise standard deviation is the square root of the noise power
#     (the power itself was being used as the standard deviation);
#   * the latent codes are no longer clipped to [0, 1] -- they are not pixel values and may be negative;
#   * the test set is encoded once instead of once per SNR value (the encoding does not depend on the SNR).
#
# map_location places the loaded weights on `device`, the device the inputs are moved to below.
# NOTE: the file holds whole pickled modules and unpickling can execute code -- only load checkpoints you created.
# NOTE(review): newer PyTorch releases may default torch.load to weights_only=True, which rejects pickled
# modules; pass weights_only=False in that case -- confirm against the installed version.
encoder, decoder = torch.load('./model/CAE_dim32_lmd_1_1_1_epc30_0925_23_37_15.pth', map_location=device)
encoder.eval()
decoder.eval()

snrs = np.linspace(1, 100, 100)   # target SNR_z values, in dB
snrs_x = np.zeros(len(snrs))

with torch.no_grad(): # No need to track the gradients
    # Original images (x) and their latent codes (z), gathered on the CPU
    conc_label = torch.cat([image_batch for image_batch in test_batches])
    conc_z = torch.cat([encoder(image_batch.to(device)).cpu() for image_batch in test_batches])

    # Mean signal power of z and of x (equal to the MSE against an all-zero tensor)
    Ez = conc_z.pow(2).mean().item()
    Ex = conc_label.pow(2).mean().item()

    for i, snr in enumerate(snrs):
        # SNR_dB = 10*log10(Ez / noise_power)  =>  noise_power = Ez / 10**(SNR_dB/10)
        noise_power = Ez / (10 ** (float(snr) / 10))
        sigma = noise_power ** 0.5

        # Add noise to the latent codes
        z_tilde = conc_z + torch.randn_like(conc_z) * sigma
        # Decode noisy data (x_hat)
        decoded_data = decoder(z_tilde.to(device)).cpu()

        # Reconstruction error power and resulting SNR_x in dB
        lossx = (decoded_data - conc_label).pow(2).mean().item()
        snr_x = 10 * np.log10(Ex / lossx)
        snrs_x[i] = snr_x

In [None]:
# Reconstruction SNR as a function of the latent-space SNR
fig_snr, ax_snr = plt.subplots()
ax_snr.plot(snrs, snrs_x)
ax_snr.set_xlabel('SNR_z (dB)')
ax_snr.set_ylabel('SNR_x (dB)')
#plt.savefig('./figs/SNR_x_vs_SNR_z_dim{}_lmd_{}_{}_{}.png'.format(d, val_lambda0, val_lambda, val_lambdaprime))
plt.show()