In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [None]:
# The only preprocessing step is torchvision's ToTensor conversion.
transform = transforms.Compose([transforms.ToTensor()])

# Training split: fetched into ./data if missing, served in shuffled batches of 64.
mnist_trainset = datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
train_loader = DataLoader(mnist_trainset, batch_size=64, shuffle=True)

# Test split: used for the validation loss, the figures and the classifier check.
# It is shuffled too, so "the first batch" is a different one on every pass.
mnist_testset = datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)
test_loader = DataLoader(mnist_testset, batch_size=64, shuffle=True)

In [None]:
#original
def myError(x,y,l,W):
  """Return the reconstruction loss plus the weighted penalty term.

  The data term is the mean over all elements of 0.5 * (x - y)**2; the
  penalty term is (l / 2) * W.

  Args:
    x: reconstruction.
    y: target, broadcastable against x.
    l: penalty coefficient.
    W: penalty value (the callers pass the result of calcW).
  """
  half_squared_error = (1/2)*(x-y)**2
  data_term = half_squared_error.mean()
  penalty_term = (l/2)*W
  return data_term + penalty_term

def kl_divergence(r, r_hat):
    """Bernoulli KL divergence between a target activation and observed ones.

    The raw activations are squashed with a sigmoid and averaged along
    dim 1, giving one mean activation per row. The result is the sum over
    rows of

        r * log(r / r_hat) + (1 - r) * log((1 - r) / (1 - r_hat)).

    Args:
        r: target mean activation, a Python scalar strictly between 0 and 1.
        r_hat: 2-D floating-point tensor of raw activations.

    Returns:
        A 0-dim tensor on the same device and with the same dtype as r_hat.
    """
    r_hat = torch.mean(torch.sigmoid(r_hat), 1)
    # Build the target directly on r_hat's device and in its dtype. This
    # removes the dependency on a module-level `device` and keeps the target
    # in full precision instead of round-tripping through float32.
    r = torch.full_like(r_hat, r)
    loss = torch.sum(
        r * torch.log(r / r_hat) + (1 - r) * torch.log((1 - r) / (1 - r_hat))
    )
    return loss
# define the sparse loss function

def sparse_loss(r, input):
    """Sum the KL sparsity penalty over the hidden activations of `model`.

    `input` is pushed through the children of the module-level `model` in
    registration order, with LeakyReLU(0.2) applied after each child, and
    kl_divergence(r, activation) is accumulated for every child except the
    last two.

    The selection is purely positional. For the AutoEncoder in this notebook
    the children are the eight Linear layers followed by its `leaky`
    activation module (present by the time the training loop calls this
    function, because the model has already been run), so "all but the last
    two" means enc1 .. dec3, i.e. the hidden layers.

    Args:
        r: target mean activation forwarded to kl_divergence.
        input: batch fed to the first child of `model`.

    Returns:
        The accumulated penalty as a 0-dim tensor, or the int 0 when `model`
        has fewer than three children.
    """
    leaky = nn.LeakyReLU(0.2)
    values = input
    loss = 0
    # Slice the child list directly instead of wrapping the modules in a
    # NumPy array just to count them. The two trailing children are not
    # evaluated at all: their outputs were never used.
    hidden_layers = list(model.children())[:-2]
    for layer in hidden_layers:
        values = leaky(layer(values))
        loss += kl_divergence(r, values)
    return loss

def calcW(model):
  """Return the L2 weight penalty over the inner children of `model`.

  For every child except the first and the last (children()[1:-1]) the mean
  of the squared weights is computed, and these means are summed.

  The selection is positional. For the AutoEncoder in this notebook, once its
  `leaky` activation is registered as the last child, the slice covers
  enc2 .. dec4.

  Args:
    model: an nn.Module whose selected children expose a `.weight` parameter.

  Returns:
    A 0-dim tensor that is part of the autograd graph, or the int 0 when
    fewer than three children exist.
  """
  regularised_layers = list(model.children())[1:-1]
  loss = 0
  for layer in regularised_layers:
    # Use the parameter itself rather than `.data`: `.data` is detached from
    # autograd, which made the penalty a constant that changed the reported
    # loss but never produced a gradient for the weights.
    loss += (layer.weight ** 2).mean()
  return loss

def noise(array,noise_factor):
    """Add scaled standard-normal noise to `array` and clip to [0, 1].

    Args:
        array: values to corrupt; only `.shape` and `+` are required of it.
        noise_factor: multiplier applied to the N(0, 1) samples.

    Returns:
        The corrupted values limited to the closed interval [0, 1].
        NOTE(review): the notebook passes torch tensors here; the type and
        dtype that come back in that case depend on NumPy/PyTorch interop
        and should be confirmed.
    """
    gaussian = np.random.normal(loc=0.0, scale=1.0, size=array.shape)
    perturbed = array + noise_factor * gaussian
    return np.clip(perturbed, 0.0, 1.0)


def evaluateOutputs(test_loader,classifier,autoencoder,noise_factor):
  """Accuracy of `classifier` on denoised reconstructions of noisy test images.

  Every batch is flattened to 784 features, corrupted with
  noise(..., noise_factor), reconstructed by `autoencoder` and classified.
  The predicted label is the index of the largest classifier output.

  Args:
    test_loader: iterable of (images, labels) batches.
    classifier: module mapping a (batch, 784) double tensor to one score per
      class. The scores are presumably log-probabilities, but only their
      argmax is used.
    autoencoder: denoising model applied to the corrupted batch.
    noise_factor: noise scale forwarded to noise().

  Returns:
    Fraction of correctly classified samples, as a float in [0, 1].

  Raises:
    ValueError: if `test_loader` yields no samples.
  """
  correct_count, all_count = 0, 0
  # Pure inference: skip autograd bookkeeping for both networks.
  with torch.no_grad():
    for images,labels in test_loader:
      X_test = images.reshape(-1,784)
      noisy = noise(X_test,noise_factor).to(device)
      denoised = autoencoder(noisy)
      # Classify the whole batch in one call instead of one sample at a time.
      scores = classifier(denoised.reshape(-1,784).double())
      # exp() is monotonic, so the argmax of the raw scores equals the argmax
      # of the exponentiated scores.
      pred_labels = scores.argmax(dim=1).cpu()
      correct_count += int((pred_labels == labels.cpu()).sum().item())
      all_count += len(labels)
  if all_count == 0:
    raise ValueError('test_loader yielded no samples')
  return (correct_count/all_count)

In [None]:
# Use the GPU when CUDA is available, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device  # last expression of the cell, so the notebook displays the choice

In [None]:
# Number of features of one flattened 28x28 image.
input_size = 28 * 28
# Learning rate handed to Adam for the first training run.
learning_rate = 1e-4 / 2

In [None]:
class AutoEncoder(nn.Module):
    """Fully connected autoencoder.

    Layer widths: input_size -> 512 -> 256 -> 128 -> 64 -> 128 -> 256 -> 512
    -> input_size. Every layer is followed by LeakyReLU(0.2) except the last,
    which is followed by a sigmoid so the output lies in (0, 1).

    `input_size` is read from the module-level variable when the model is
    constructed.
    """

    def __init__(self):
        super().__init__()
        # encoder
        self.enc1 = nn.Linear(input_size,512,bias=True)
        self.enc2 = nn.Linear(in_features=512, out_features=256,bias=True)
        self.enc3 = nn.Linear(in_features=256, out_features=128,bias=True)
        self.enc4 = nn.Linear(in_features=128, out_features=64,bias=True)
        # decoder
        self.dec1 = nn.Linear(in_features=64, out_features=128,bias=True)
        self.dec2 = nn.Linear(in_features=128, out_features=256,bias=True)
        self.dec3 = nn.Linear(in_features=256, out_features=512,bias=True)
        self.dec4 = nn.Linear(in_features=512, out_features=input_size,bias=True)
        # Shared activation, created once instead of on every forward call.
        # It must stay the LAST registered child: sparse_loss and calcW slice
        # model.children() by position, and this is the order forward()
        # used to produce when it assigned self.leaky itself.
        self.leaky = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Encode and decode `x`, returning a reconstruction of the same width."""
        # encoding
        x = self.leaky(self.enc1(x))
        x = self.leaky(self.enc2(x))
        x = self.leaky(self.enc3(x))
        x = self.leaky(self.enc4(x))
        # decoding
        x = self.leaky(self.dec1(x))
        x = self.leaky(self.dec2(x))
        x = self.leaky(self.dec3(x))
        x = torch.sigmoid(self.dec4(x))
        return x

# Double-precision autoencoder on the selected device, optimised with Adam.
model = AutoEncoder().double().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Histories filled by the training cell. `Outputs` is created here but is not
# written to by any of the visible cells.
Error, Validation, Outputs = [], [], []

In [None]:
# Hyper-parameters of the regularised denoising objective.
Epochs=90
lamda=1    # coefficient of the weight penalty (myError applies lamda/2)
beta=.1    # coefficient of the KL sparsity penalty
rho=0.05   # target mean activation for the KL sparsity penalty

noise_factor=0.35
for epoch in range(Epochs):
  # ---------------- training pass ----------------
  train_losses=[]
  for images,_ in train_loader:
    clean_data=images.reshape(-1,input_size)
    corrupted_data=noise(clean_data,noise_factor).to(device)
    clean_data=clean_data.to(device)

    # forward: reconstruct the clean image from its corrupted version
    output = model(corrupted_data)
    W=calcW(model).to(device)
    loss = myError(output,clean_data,lamda,W) + beta*sparse_loss(rho, corrupted_data).to(device)

    # backward
    optimizer.zero_grad()
    loss.backward()
    train_losses.append(loss.item())
    # optimizer.step() already runs without gradient tracking, so it needs
    # no extra torch.no_grad() wrapper.
    optimizer.step()

  # ---------------- validation pass ----------------
  with torch.no_grad():
    validation_losses=[]  # one entry per test batch
    # The weights do not change during validation, so the penalty is the
    # same for every batch and is computed once.
    W=calcW(model).to(device)
    for images,_ in test_loader:
      clean_data_validation=images.reshape(-1,input_size)
      corrupted_data_validation=noise(clean_data_validation,noise_factor).to(device)
      clean_data_validation=clean_data_validation.to(device)

      output_validation=model(corrupted_data_validation)
      validation = myError(output_validation,clean_data_validation,lamda,W)+beta*sparse_loss(rho, corrupted_data_validation).to(device)
      validation_losses.append(validation.item())

  # ---------------- log ----------------
  epoch_loss=np.array(train_losses).mean()
  epoch_validation=np.array(validation_losses).mean()
  Error.append(epoch_loss)
  Validation.append(epoch_validation)
  print('epoch [{}/{}], loss:{:.4f} , validation loss:{:.4f}'
          .format(epoch + 1, Epochs, epoch_loss, epoch_validation))

In [None]:
# One figure per history: training objective first, validation objective second.
for title, history in (('Error ', Error), ('Validation', Validation)):
    plt.title(title)
    plt.plot(history)
    plt.show()
#torch.save(model,'model')

# Denoising visualization

In [None]:
noise_factor=0.35

# NOTE(review): torch.load unpickles complete model objects, so only load
# checkpoint files that you trust.
# NOTE(review): the file names say "0.25" and "3layers" while noise_factor
# above is 0.35 -- confirm these are the intended checkpoints.
model_proposed=torch.load('ae_mnist0.25_3layers-proposed').double().to(device)
model_bibliography=torch.load('ae_mnist0.25_3layers-full_reg').double().to(device)

# A single test batch is enough for the figure.
X_test,_=next(iter(test_loader))
X_test=X_test.reshape(-1,784)
input1=noise(X_test,noise_factor).to(device)
# Inference only: no autograd graph is needed for plotting.
with torch.no_grad():
    output_proposed=model_proposed(input1)
    output_bibliography=model_bibliography(input1)

# Back to host memory, reshaped to 28x28 images for imshow.
input1=input1.cpu().detach().numpy().reshape(-1,28,28)
output_proposed=output_proposed.cpu().detach().numpy().reshape(-1,28,28)
output_bibliography=output_bibliography.cpu().detach().numpy().reshape(-1,28,28)

# One row per image source, one column per example (the first three of the batch).
rows = (
    ('Noisy Picture:', input1),
    ('Clean Picture bibliography:', output_bibliography),
    ('Clean Picture proposed:', output_proposed),
)
n_examples = 3
fig, axes = plt.subplots(len(rows), n_examples, figsize=(25, 15), squeeze=False)

fig.suptitle('SNR=5.89dB, 3 Layers ', fontsize=25)
fig.set_tight_layout(True)

for row_axes, (title, pictures) in zip(axes, rows):
    for col, ax in enumerate(row_axes):
        ax.set_title(title)
        ax.imshow(pictures[col],cmap='gray')
plt.show()

In [None]:
# Visual check of the autoencoder currently bound to `model`: corrupt one
# test batch, denoise it, and show the first image before and after.
for X_test,_ in test_loader:
    X_test=X_test.reshape(-1,784)
    noisy_batch=noise(X_test,noise_factor).to(device)
    denoised_batch=model(noisy_batch)

    # Move to host memory and restore the 28x28 layout for plotting.
    input1=noisy_batch.cpu().detach().numpy().reshape(-1,28,28)
    output=denoised_batch.cpu().detach().numpy().reshape(-1,28,28)

    break  # only the first batch is used
plt.imshow(input1[0],cmap='gray')
plt.show()
plt.imshow(output[0],cmap='gray')

# Classifier validation

In [None]:
# Checkpoints for the comparison; both are cast to double and moved to `device`.
proposed_path = 'ae_mnist0.25_3layers-proposed'
bibliography_path = 'ae_mnist0.25_3layers-full_reg'
model_proposed = torch.load(proposed_path).double().to(device)
model_bibliography = torch.load(bibliography_path).double().to(device)

In [None]:
# Same two variable names, rebound to another pair of checkpoints.
# NOTE(review): this file name ends in "full-reg" (hyphen) whereas the other
# checkpoint name in this notebook uses "full_reg" -- confirm the file exists.
proposed_path = 'ae_mnist0.35_4layers-proposed'
bibliography_path = 'ae_mnist0.35_4layers-full-reg'
model_proposed = torch.load(proposed_path).double().to(device)
model_bibliography = torch.load(bibliography_path).double().to(device)

In [None]:
# Mean classifier accuracy on denoised test images over 10 passes of the
# test set; the noise is redrawn on every pass.
noise_factor=0.35
classifier=torch.load('MNISTclassifier').double().to(device)
mean_acc=[]
mean_acc.extend(
    evaluateOutputs(test_loader,classifier,model,noise_factor) for _ in range(10)
)
print(np.array(mean_acc).mean())

In [None]:
# Same measurement as the 10-pass cell, averaged over 100 passes.
noise_factor=0.35
classifier=torch.load('MNISTclassifier').double().to(device)
mean_acc=[]
mean_acc.extend(
    evaluateOutputs(test_loader,classifier,model,noise_factor) for _ in range(100)
)
print(np.array(mean_acc).mean())

# Rest of experiments

In [None]:
# Train one fresh autoencoder per noise level and save each to disk.
# Epochs, lamda, beta and rho are taken from the earlier training cell.
learning_rate=1e-4/4
for noise_factor in [0.1, 0.25,0.35,0.4]:
    model=AutoEncoder().double().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    Error=[]
    Validation=[]
    Outputs=[]

    for epoch in range(Epochs):
        # ---------------- training pass ----------------
        train_losses=[]
        for images,_ in train_loader:
            clean_data=images.reshape(-1,input_size)
            corrupted_data=noise(clean_data,noise_factor).to(device)
            clean_data=clean_data.to(device)

            # forward: reconstruct the clean image from its corrupted version
            output = model(corrupted_data)
            W=calcW(model).to(device)
            loss = myError(output,clean_data,lamda,W) + beta*sparse_loss(rho, corrupted_data).to(device)

            # backward
            optimizer.zero_grad()
            loss.backward()
            train_losses.append(loss.item())
            # optimizer.step() already runs without gradient tracking, so it
            # needs no extra torch.no_grad() wrapper.
            optimizer.step()

        # ---------------- validation pass ----------------
        with torch.no_grad():
            validation_losses=[]  # one entry per test batch
            # The weights do not change during validation, so the penalty is
            # the same for every batch and is computed once.
            W=calcW(model).to(device)
            for images,_ in test_loader:
                clean_data_validation=images.reshape(-1,input_size)
                corrupted_data_validation=noise(clean_data_validation,noise_factor).to(device)
                clean_data_validation=clean_data_validation.to(device)

                output_validation=model(corrupted_data_validation)
                validation = myError(output_validation,clean_data_validation,lamda,W)+beta*sparse_loss(rho, corrupted_data_validation).to(device)
                validation_losses.append(validation.item())

        # ---------------- log ----------------
        epoch_loss=np.array(train_losses).mean()
        epoch_validation=np.array(validation_losses).mean()
        Error.append(epoch_loss)
        Validation.append(epoch_validation)
        print('epoch [{}/{}], loss:{:.4f} , validation loss:{:.4f}'
          .format(epoch + 1, Epochs, epoch_loss, epoch_validation))

    # NOTE(review): this saves "..._4layers-proposed", while the evaluation
    # cells further down load "..._3layers-proposed" -- confirm which
    # checkpoints are meant to be evaluated.
    torch.save(model,'ae_mnist'+str(noise_factor)+'_4layers-proposed')

In [None]:
# Mean classifier accuracy over 100 passes for the checkpoint named for noise 0.1.
noise_factor=0.1
classifier=torch.load('MNISTclassifier').double().to(device)
model=torch.load('ae_mnist0.1_3layers-proposed').double().to(device)
mean_acc=[]
mean_acc.extend(
    evaluateOutputs(test_loader,classifier,model,noise_factor) for _ in range(100)
)
print(np.array(mean_acc).mean())

In [None]:
# Mean classifier accuracy over 100 passes for the checkpoint named for noise 0.25.
noise_factor=0.25
classifier=torch.load('MNISTclassifier').double().to(device)
model=torch.load('ae_mnist0.25_3layers-proposed').double().to(device)
mean_acc=[]
mean_acc.extend(
    evaluateOutputs(test_loader,classifier,model,noise_factor) for _ in range(100)
)
print(np.array(mean_acc).mean())

In [None]:
# Mean classifier accuracy over 100 passes for the checkpoint named for noise 0.35.
noise_factor=0.35
classifier=torch.load('MNISTclassifier').double().to(device)
model=torch.load('ae_mnist0.35_3layers-proposed').double().to(device)
mean_acc=[]
mean_acc.extend(
    evaluateOutputs(test_loader,classifier,model,noise_factor) for _ in range(100)
)
print(np.array(mean_acc).mean())

In [None]:
# Mean classifier accuracy over 100 passes for the checkpoint named for noise 0.4.
noise_factor=0.4
classifier=torch.load('MNISTclassifier').double().to(device)
model=torch.load('ae_mnist0.4_3layers-proposed').double().to(device)
mean_acc=[]
mean_acc.extend(
    evaluateOutputs(test_loader,classifier,model,noise_factor) for _ in range(100)
)
print(np.array(mean_acc).mean())