In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
import os
from tqdm import tqdm
import plotly.io as pio

import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader,random_split
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.manifold import TSNE
import plotly.express as px

In [None]:
data_dir = 'dataset'
### With these commands the train and test datasets, respectively, are downloaded
### automatically and stored in the local "data_dir" directory.
train_dataset = torchvision.datasets.MNIST(data_dir, train=True, download=True)
test_dataset  = torchvision.datasets.MNIST(data_dir, train=False, download=True)

In [None]:
train_transform = transforms.Compose([
    transforms.ToTensor(),
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
])

# Set the train transform
train_dataset.transform = train_transform
# Set the test transform
test_dataset.transform = test_transform

In [None]:
m=len(train_dataset)

#random_split randomly split a dataset into non-overlapping new datasets of given lengths
#train (55,000 images), val split (5,000 images)
train_data, val_data = random_split(train_dataset, [int(m - m * 0.2), int(m * 0.2)])

batch_size = 256

# The dataloaders handle shuffling, batching, etc...
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)

valid_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
result = []

## 1 Define Encoder and Decoder classes

In [None]:
class Encoder(nn.Module):
    
    def __init__(self, encoded_space_dim, fc2_input_dim):
        super().__init__()
        
        ### Convolutional section
        self.encoder_cnn = nn.Sequential(
            # First convolutional layer
            nn.Conv2d(1, 8, 3, stride=2, padding=1),
            # nn.BatchNorm2d(8),
            nn.ReLU(True),
            # Second convolutional layer
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            # Third convolutional layer
            nn.Conv2d(16, 32, 3, stride=2, padding=0),
            # nn.BatchNorm2d(32),
            nn.ReLU(True)
        )
        
        ### Flatten layer
        self.flatten = nn.Flatten(start_dim=1)

        ### Linear section
        self.encoder_lin = nn.Sequential(
            # First linear layer
            nn.Linear(3 * 3 * 32, 128),
            nn.ReLU(True),
            # Second linear layer
            nn.Linear(128, encoded_space_dim)
        )
        
    def forward(self, x):
        # Apply convolutions
        x = self.encoder_cnn(x)
        # Flatten
        x = self.flatten(x)
        # Apply linear layers
        x = self.encoder_lin(x)
        return x

In [None]:
class Decoder(nn.Module):
    
    def __init__(self, encoded_space_dim, fc2_input_dim):
        super().__init__()

        ### Linear section
        self.decoder_lin = nn.Sequential(
            # First linear layer
            nn.Linear(encoded_space_dim, 128),
            nn.ReLU(True),
            # Second linear layer
            nn.Linear(128, 3 * 3 * 32),
            nn.ReLU(True)
        )

        ### Unflatten
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, 3, 3))

        ### Convolutional section
        self.decoder_conv = nn.Sequential(
            # First transposed convolution
            nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            # Second transposed convolution
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            # Third transposed convolution
            nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1)
        )
        
    def forward(self, x):
        # Apply linear layers
        x = self.decoder_lin(x)
        # Unflatten
        x = self.unflatten(x)
        # Apply transposed convolutions
        x = self.decoder_conv(x)
        # Apply a sigmoid to force the output to be between 0 and 1 (valid pixel values)
        x = torch.sigmoid(x)
        return x

## 2 Define function

In [None]:
def add_noise(inputs, noise_factor=0.3):
     noise = inputs + torch.randn_like(inputs) * noise_factor
     noise = torch.clip(noise, 0., 1.)
     return noise

In [None]:
## Compute covariance
def cov_cpt(img):
    # img = img.detach().cpu().numpy()
    cov_out = torch.zeros((len(img), img.shape[2], img.shape[3]))
    for ii in range(len(img)):
        pic = img[ii][0]
        cov_img = torch.cov(pic) # (28, 28)
        cov_out[ii] = cov_img

    return cov_out # (256,28, 28) 

In [None]:
def plot_ae_outputs_den(encoder,decoder,n=5,noise_factor=0.3):
    plt.figure(figsize=(10,4.5))
    for i in range(n):

      ax = plt.subplot(3,n,i+1)
      img = test_dataset[i][0].unsqueeze(0)
      image_noisy = add_noise(img,noise_factor)     
      image_noisy = image_noisy.to(device)

      encoder.eval()
      decoder.eval()

      with torch.no_grad():
         rec_img = decoder(encoder(image_noisy))

      plt.imshow(img.cpu().squeeze().numpy(), cmap='gist_gray')
      ax.get_xaxis().set_visible(False)
      ax.get_yaxis().set_visible(False)  
      if i == n//2:
        ax.set_title('Original images')
      ax = plt.subplot(3, n, i + 1 + n)
      plt.imshow(image_noisy.cpu().squeeze().numpy(), cmap='gist_gray')
      ax.get_xaxis().set_visible(False)
      ax.get_yaxis().set_visible(False)  
      if i == n//2:
        ax.set_title('Corrupted images')

      ax = plt.subplot(3, n, i + 1 + n + n)
      plt.imshow(rec_img.cpu().squeeze().numpy(), cmap='gist_gray')  
      ax.get_xaxis().set_visible(False)
      ax.get_yaxis().set_visible(False)  
      if i == n//2:
         ax.set_title('Reconstructed images')
    plt.subplots_adjust(left=0.1,
                    bottom=0.1, 
                    right=0.7, 
                    top=0.9, 
                    wspace=0.3, 
                    hspace=0.3)     
    plt.show()   

In [None]:
def plot_train_loss(train_loss, train_loss_x, train_loss_z):
    fig, ax = plt.subplots() 
    ax.plot(range(len(train_loss)), train_loss, label='train_loss') 
    ax.plot(range(len(train_loss_x)), train_loss_x, label='train_loss_x') 
    ax.plot(range(len(train_loss_z)), train_loss_z, label='train_loss_z')
    ax.set_xlabel('Training epoch') 
    ax.set_ylabel('Loss') 
    ax.legend() 

    plt.show() 
    

In [None]:
def plot_train_val_loss(train_loss, val_loss):
    fig, ax = plt.subplots() 
    ax.plot(range(len(train_loss)), train_loss, label='train_loss') 
    ax.plot(range(len(val_loss)), val_loss, label='val_loss') 

    ax.set_xlabel('Epochs') 
    ax.set_ylabel('Loss') 
    ax.legend() 
    plt.show() 
    return fig

## 3 Train model

### 3.1 Non-feedback model

In [None]:
torch.manual_seed(0)
### Initialize the two networks
d = 4
encoder = Encoder(encoded_space_dim=d,fc2_input_dim=128)
decoder = Decoder(encoded_space_dim=d,fc2_input_dim=128)
### Define the loss function
loss_fn = torch.nn.MSELoss()
lr= 0.001 # Learning rate
params_to_optimize = [
    {'params': encoder.parameters()},
    {'params': decoder.parameters()}
]
optim = torch.optim.Adam(params_to_optimize, lr=lr)
encoder.to(device)
decoder.to(device)

Decoder(
  (decoder_lin): Sequential(
    (0): Linear(in_features=4, out_features=128, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=128, out_features=288, bias=True)
    (3): ReLU(inplace=True)
  )
  (unflatten): Unflatten(dim=1, unflattened_size=(32, 3, 3))
  (decoder_conv): Sequential(
    (0): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(16, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (4): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(8, 1, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
  )
)

In [None]:
### Training function
from email.mime import image


def train_epoch_den(encoder, decoder, device, dataloader, loss_fn, optimizer,noise_factor=0.3):
    # Set train mode for both the encoder and the decoder
    encoder.train()
    decoder.train()
    train_loss = []
    # Iterate the dataloader (we do not need the label values, this is unsupervised learning)
    for image_batch, _ in dataloader: # with "_" we just ignore the labels (the second element of the dataloader tuple)
        # Move tensor to the proper device
        image_noisy = add_noise(image_batch,noise_factor)
        image_noisy = image_noisy.to(device)    
        # Encode data
        encoded_data = encoder(image_noisy)
        # Decode data
        decoded_data = decoder(encoded_data)
        # Compute covariance
        c_x = cov_cpt(image_noisy)
        c_x_bar = cov_cpt(decoded_data)
        # Evaluate loss
        loss = loss_fn(c_x, c_x_bar)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Print batch loss
        # print('\t partial train loss (single batch): %f' % (loss.data))
        train_loss.append(loss.detach().cpu().numpy())

    return np.mean(train_loss)

In [None]:
### Testing function
def test_epoch_den(encoder, decoder, device, dataloader, loss_fn,noise_factor=0.3):
    # Set evaluation mode for encoder and decoder
    encoder.eval()
    decoder.eval()
    with torch.no_grad(): # No need to track the gradients
        # Define the lists to store the outputs for each batch
        conc_out = []
        conc_label = []
        for image_batch, _ in dataloader:
            # Move tensor to the proper device
            image_noisy = add_noise(image_batch,noise_factor)
            image_noisy = image_noisy.to(device)
            # Encode data
            encoded_data = encoder(image_noisy)
            # Decode data
            decoded_data = decoder(encoded_data)
            # Append the network output and the original image to the lists
            conc_out.append(decoded_data.cpu())
            conc_label.append(image_batch.cpu())
        # Create a single tensor with all the values in the lists
        conc_out = torch.cat(conc_out)
        conc_label = torch.cat(conc_label) 
        # Evaluate global loss
        val_loss = loss_fn(conc_out, conc_label)
    return val_loss.data

In [None]:
### Training cycle
noise_factor = 0.3
num_epochs = 10
history_da={'train_loss':[],'val_loss':[]}

for epoch in range(num_epochs):
    ### Training (use the training function)
    train_loss=train_epoch_den(
        encoder=encoder, 
        decoder=decoder, 
        device=device, 
        dataloader=train_loader, 
        loss_fn=loss_fn, 
        optimizer=optim,noise_factor=noise_factor)
    ### Validation  (use the testing function)
    val_loss = test_epoch_den(
        encoder=encoder, 
        decoder=decoder, 
        device=device, 
        dataloader=valid_loader, 
        loss_fn=loss_fn,noise_factor=noise_factor)
    # Print Validationloss
    history_da['train_loss'].append(train_loss)
    history_da['val_loss'].append(val_loss)
    print('\n EPOCH {}/{} \t train loss {:.3f} \t val loss {:.3f}'.format(epoch + 1, num_epochs,train_loss,val_loss))


In [None]:
plot_ae_outputs_den(encoder,decoder,noise_factor=noise_factor)

In [None]:
res=test_epoch_den(encoder,decoder,device,test_loader,loss_fn)
print(res.item())
result.append({'lambda_x':1, 'lambda_z':0, 'test_loss':res.item()})

In [None]:
fig = plot_train_val_loss(history_da['train_loss'], history_da['val_loss'])
# fig.savefig('fig/covariance_1_0')

### 3.2 Feedback model

#### 3.2.1 $\mathcal{L}=\|C_x-C_{\hat{x}}\|_{F}^{2}+\lambda\|z-\hat{z}\|_{2}^{2}$

In [22]:
### Training function
def train_epoch_den(encoder, decoder, device, dataloader, loss_fn, optimizer, para, noise_factor=0.3):
    # Set train mode for both the encoder and the decoder
    encoder.train()
    decoder.train()
    train_loss = []
    train_loss_x = []
    train_loss_z = []
    # Iterate the dataloader (we do not need the label values, this is unsupervised learning)
    for image_batch, _ in dataloader: # with "_" we just ignore the labels (the second element of the dataloader tuple)
        # Move tensor to the proper device
        image_noisy = add_noise(image_batch,noise_factor)
        image_noisy = image_noisy.to(device)    
        # Encode data
        encoded_data = encoder(image_noisy)
        # Decode data
        decoded_data = decoder(encoded_data)
        # Feedback data
        feedback_data = encoder(decoded_data)
        # Evaluate loss
        c_x = cov_cpt(image_noisy)
        c_x_bar = cov_cpt(decoded_data)
        
        loss_x = loss_fn(c_x, c_x_bar)
        loss_z = loss_fn(feedback_data, encoded_data)
        
        loss = loss_x + para * loss_z
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Print batch loss
        # print('\t partial train loss (single batch): %f' % (loss.data))
        train_loss.append(loss.detach().cpu().numpy())
        train_loss_x.append(loss_x.detach().cpu().numpy())
        train_loss_z.append(loss_z.detach().cpu().numpy())

    return np.mean(train_loss), np.mean(train_loss_x), np.mean(train_loss_z)

In [23]:
### Testing function
def test_epoch_den(encoder, decoder, device, dataloader, loss_fn,noise_factor=0.3):
    # Set evaluation mode for encoder and decoder
    encoder.eval()
    decoder.eval()
    with torch.no_grad(): # No need to track the gradients
        # Define the lists to store the outputs for each batch
        conc_out = []
        conc_label = []
        for image_batch, _ in dataloader:
            # Move tensor to the proper device
            image_noisy = add_noise(image_batch,noise_factor)
            image_noisy = image_noisy.to(device)
            # Encode data
            encoded_data = encoder(image_noisy)
            # Decode data
            decoded_data = decoder(encoded_data)
            # Append the network output and the original image to the lists
            conc_out.append(decoded_data.cpu())
            conc_label.append(image_batch.cpu())
        # Create a single tensor with all the values in the lists
        conc_out = torch.cat(conc_out)
        conc_label = torch.cat(conc_label) 
        # Evaluate global loss
        val_loss = loss_fn(conc_out, conc_label)
    return val_loss.data

$\lambda=1$

In [24]:
torch.manual_seed(0)
### Initialize the two networks
d = 4
encoder = Encoder(encoded_space_dim=d,fc2_input_dim=128)
decoder = Decoder(encoded_space_dim=d,fc2_input_dim=128)
### Define the loss function
loss_fn = torch.nn.MSELoss()
lr= 0.001 # Learning rate
params_to_optimize = [
    {'params': encoder.parameters()},
    {'params': decoder.parameters()}
]
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
optim = torch.optim.Adam(params_to_optimize, lr=lr)
encoder.to(device)
decoder.to(device)

Decoder(
  (decoder_lin): Sequential(
    (0): Linear(in_features=4, out_features=128, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=128, out_features=288, bias=True)
    (3): ReLU(inplace=True)
  )
  (unflatten): Unflatten(dim=1, unflattened_size=(32, 3, 3))
  (decoder_conv): Sequential(
    (0): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(16, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (4): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(8, 1, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
  )
)

In [None]:
### Training cycle
noise_factor = 0.3
num_epochs = 10
para = 1
history_da={'train_loss':[], 'val_loss':[], 'train_loss_x':[], 'train_loss_z':[]}

for epoch in range(num_epochs):
    ### Training (use the training function)
    train_loss, train_loss_x, train_loss_z=train_epoch_den(
        encoder=encoder, 
        decoder=decoder, 
        device=device, 
        dataloader=train_loader, 
        loss_fn=loss_fn, 
        para=para,
        optimizer=optim,noise_factor=noise_factor)
    ### Validation  (use the testing function)
    val_loss = test_epoch_den(
        encoder=encoder, 
        decoder=decoder, 
        device=device, 
        dataloader=valid_loader, 
        loss_fn=loss_fn,noise_factor=noise_factor)
    # Print Validationloss
    history_da['train_loss'].append(train_loss)
    history_da['train_loss_x'].append(train_loss_x)
    history_da['train_loss_z'].append(train_loss_z)
    history_da['val_loss'].append(val_loss)

    print('\n EPOCH {}/{} \t train loss {:.5f} \t val loss {:.5f} \t train loss x {:.5f} \t train loss z {:.5f}'.format(epoch + 1, num_epochs,train_loss,val_loss,train_loss_x,train_loss_z))

In [None]:
plot_ae_outputs_den(encoder,decoder,noise_factor=noise_factor)

In [None]:
res=test_epoch_den(encoder,decoder,device,test_loader,loss_fn)
print(res.item())
result.append({'lambda_x':1, 'lambda_z':para, 'test_loss':res.item()})

In [None]:
fig = plot_train_val_loss(history_da['train_loss'], history_da['val_loss'])
# fig.savefig('fig/covariance_1_1')

In [None]:
plot_train_loss(history_da['train_loss'], history_da['train_loss_x'], history_da['train_loss_z'])

$\lambda=0.5$

In [30]:
torch.manual_seed(0)
### Initialize the two networks
d = 4
encoder = Encoder(encoded_space_dim=d,fc2_input_dim=128)
decoder = Decoder(encoded_space_dim=d,fc2_input_dim=128)
### Define the loss function
loss_fn = torch.nn.MSELoss()
lr= 0.001 # Learning rate
params_to_optimize = [
    {'params': encoder.parameters()},
    {'params': decoder.parameters()}
]
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
optim = torch.optim.Adam(params_to_optimize, lr=lr)
encoder.to(device)
decoder.to(device)

Decoder(
  (decoder_lin): Sequential(
    (0): Linear(in_features=4, out_features=128, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=128, out_features=288, bias=True)
    (3): ReLU(inplace=True)
  )
  (unflatten): Unflatten(dim=1, unflattened_size=(32, 3, 3))
  (decoder_conv): Sequential(
    (0): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(16, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (4): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(8, 1, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
  )
)

In [None]:
### Training cycle
noise_factor = 0.3
num_epochs = 10
para = 0.5
history_da={'train_loss':[], 'val_loss':[], 'train_loss_x':[], 'train_loss_z':[]}

for epoch in range(num_epochs):
    ### Training (use the training function)
    train_loss, train_loss_x, train_loss_z=train_epoch_den(
        encoder=encoder, 
        decoder=decoder, 
        device=device, 
        dataloader=train_loader, 
        loss_fn=loss_fn, 
        para=para,
        optimizer=optim,noise_factor=noise_factor)
    ### Validation  (use the testing function)
    val_loss = test_epoch_den(
        encoder=encoder, 
        decoder=decoder, 
        device=device, 
        dataloader=valid_loader, 
        loss_fn=loss_fn,noise_factor=noise_factor)
    # Print Validationloss
    history_da['train_loss'].append(train_loss)
    history_da['train_loss_x'].append(train_loss_x)
    history_da['train_loss_z'].append(train_loss_z)
    history_da['val_loss'].append(val_loss)

    print('\n EPOCH {}/{} \t train loss {:.5f} \t val loss {:.5f} \t train loss x {:.5f} \t train loss z {:.5f}'.format(epoch + 1, num_epochs,train_loss,val_loss,train_loss_x,train_loss_z))

In [None]:
plot_ae_outputs_den(encoder,decoder,noise_factor=noise_factor)

In [None]:
res=test_epoch_den(encoder,decoder,device,test_loader,loss_fn)
print(res.item())
result.append({'lambda_x':1, 'lambda_z':para, 'test_loss':res.item()})

In [None]:
fig = plot_train_val_loss(history_da['train_loss'], history_da['val_loss'])
# fig.savefig('fig/covariance_1_05')

In [None]:
plot_train_loss(history_da['train_loss'], history_da['train_loss_x'], history_da['train_loss_z'])

$\lambda=2$

In [36]:
torch.manual_seed(0)
### Initialize the two networks
d = 4
encoder = Encoder(encoded_space_dim=d,fc2_input_dim=128)
decoder = Decoder(encoded_space_dim=d,fc2_input_dim=128)
### Define the loss function
loss_fn = torch.nn.MSELoss()
lr= 0.001 # Learning rate
params_to_optimize = [
    {'params': encoder.parameters()},
    {'params': decoder.parameters()}
]
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
optim = torch.optim.Adam(params_to_optimize, lr=lr)
encoder.to(device)
decoder.to(device)

Decoder(
  (decoder_lin): Sequential(
    (0): Linear(in_features=4, out_features=128, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=128, out_features=288, bias=True)
    (3): ReLU(inplace=True)
  )
  (unflatten): Unflatten(dim=1, unflattened_size=(32, 3, 3))
  (decoder_conv): Sequential(
    (0): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(16, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (4): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(8, 1, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
  )
)

In [None]:
### Training cycle
noise_factor = 0.3
num_epochs = 10
para = 2
history_da={'train_loss':[], 'val_loss':[], 'train_loss_x':[], 'train_loss_z':[]}

for epoch in range(num_epochs):
    ### Training (use the training function)
    train_loss, train_loss_x, train_loss_z=train_epoch_den(
        encoder=encoder, 
        decoder=decoder, 
        device=device, 
        dataloader=train_loader, 
        loss_fn=loss_fn, 
        para=para,
        optimizer=optim,noise_factor=noise_factor)
    ### Validation  (use the testing function)
    val_loss = test_epoch_den(
        encoder=encoder, 
        decoder=decoder, 
        device=device, 
        dataloader=valid_loader, 
        loss_fn=loss_fn,noise_factor=noise_factor)
    # Print Validationloss
    history_da['train_loss'].append(train_loss)
    history_da['train_loss_x'].append(train_loss_x)
    history_da['train_loss_z'].append(train_loss_z)
    history_da['val_loss'].append(val_loss)

    print('\n EPOCH {}/{} \t train loss {:.5f} \t val loss {:.5f} \t train loss x {:.5f} \t train loss z {:.5f}'.format(epoch + 1, num_epochs,train_loss,val_loss,train_loss_x,train_loss_z))

In [None]:
plot_ae_outputs_den(encoder,decoder,noise_factor=noise_factor)

In [None]:
res=test_epoch_den(encoder,decoder,device,test_loader,loss_fn)
print(res.item())
result.append({'lambda_x':1, 'lambda_z':para, 'test_loss':res.item()})

In [None]:
fig = plot_train_val_loss(history_da['train_loss'], history_da['val_loss'])
# fig.savefig('fig/covariance_1_2')

In [None]:
plot_train_loss(history_da['train_loss'], history_da['train_loss_x'], history_da['train_loss_z'])

#### 3.2.2 $\mathcal{L}=\lambda\|x-\hat{x}\|_{2}^{2}+\|z-\hat{z}\|_{2}^{2}$

In [42]:
### Training function
def train_epoch_den(encoder, decoder, device, dataloader, loss_fn, optimizer, para, noise_factor=0.3):
    # Set train mode for both the encoder and the decoder
    encoder.train()
    decoder.train()
    train_loss = []
    train_loss_x = []
    train_loss_z = []
    # Iterate the dataloader (we do not need the label values, this is unsupervised learning)
    for image_batch, _ in dataloader: # with "_" we just ignore the labels (the second element of the dataloader tuple)
        # Move tensor to the proper device
        image_noisy = add_noise(image_batch,noise_factor)
        image_noisy = image_noisy.to(device)    
        # Encode data
        encoded_data = encoder(image_noisy)
        # Decode data
        decoded_data = decoder(encoded_data)
        # Feedback data
        feedback_data = encoder(decoded_data)
        # Evaluate loss
        c_x = cov_cpt(image_noisy)
        c_x_bar = cov_cpt(decoded_data)
        
        loss_x = loss_fn(c_x, c_x_bar)
        loss_z = loss_fn(feedback_data, encoded_data)
        
        loss = para * loss_x + loss_z
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Print batch loss
        # print('\t partial train loss (single batch): %f' % (loss.data))
        train_loss.append(loss.detach().cpu().numpy())
        train_loss_x.append(loss_x.detach().cpu().numpy())
        train_loss_z.append(loss_z.detach().cpu().numpy())

    return np.mean(train_loss), np.mean(train_loss_x), np.mean(train_loss_z)

$\lambda=0.5$

In [43]:
torch.manual_seed(0)
### Initialize the two networks
d = 4
encoder = Encoder(encoded_space_dim=d,fc2_input_dim=128)
decoder = Decoder(encoded_space_dim=d,fc2_input_dim=128)
### Define the loss function
loss_fn = torch.nn.MSELoss()
lr= 0.001 # Learning rate
params_to_optimize = [
    {'params': encoder.parameters()},
    {'params': decoder.parameters()}
]
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
optim = torch.optim.Adam(params_to_optimize, lr=lr)
encoder.to(device)
decoder.to(device)

Decoder(
  (decoder_lin): Sequential(
    (0): Linear(in_features=4, out_features=128, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=128, out_features=288, bias=True)
    (3): ReLU(inplace=True)
  )
  (unflatten): Unflatten(dim=1, unflattened_size=(32, 3, 3))
  (decoder_conv): Sequential(
    (0): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(16, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (4): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(8, 1, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
  )
)

In [None]:
### Training cycle
noise_factor = 0.3
num_epochs = 10
para = 0.5
history_da={'train_loss':[], 'val_loss':[], 'train_loss_x':[], 'train_loss_z':[]}

for epoch in range(num_epochs):
    ### Training (use the training function)
    train_loss, train_loss_x, train_loss_z=train_epoch_den(
        encoder=encoder, 
        decoder=decoder, 
        device=device, 
        dataloader=train_loader, 
        loss_fn=loss_fn, 
        para=para,
        optimizer=optim,noise_factor=noise_factor)
    ### Validation  (use the testing function)
    val_loss = test_epoch_den(
        encoder=encoder, 
        decoder=decoder, 
        device=device, 
        dataloader=valid_loader, 
        loss_fn=loss_fn,noise_factor=noise_factor)
    # Print Validationloss
    history_da['train_loss'].append(train_loss)
    history_da['train_loss_x'].append(train_loss_x)
    history_da['train_loss_z'].append(train_loss_z)
    history_da['val_loss'].append(val_loss)

    print('\n EPOCH {}/{} \t train loss {:.5f} \t val loss {:.5f} \t train loss x {:.5f} \t train loss z {:.5f}'.format(epoch + 1, num_epochs,train_loss,val_loss,train_loss_x,train_loss_z))

In [None]:
plot_ae_outputs_den(encoder,decoder,noise_factor=noise_factor)

In [None]:
res=test_epoch_den(encoder,decoder,device,test_loader,loss_fn)
print(res.item())
result.append({'lambda_x':para, 'lambda_z':1, 'test_loss':res.item()})

In [None]:
fig = plot_train_val_loss(history_da['train_loss'], history_da['val_loss'])
# fig.savefig('fig/covariance_05_1')

In [None]:
plot_train_loss(history_da['train_loss'], history_da['train_loss_x'], history_da['train_loss_z'])

## 4 Result

In [49]:
pd.DataFrame(result)

Unnamed: 0,lambda_x,lambda_z,test_loss
0,1.0,0.0,0.348588
1,1.0,1.0,0.312755
2,1.0,0.5,0.321387
3,1.0,2.0,0.309205
4,0.5,1.0,0.305212
