In [None]:
##Import required libraries and packages

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import seaborn as sb
from tqdm import tqdm

import torchvision
import torch.nn.functional as F
from torch.nn import init
import torchvision.transforms as Tranforms
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torch.nn as nn
import matplotlib.gridspec as gridspec
from IPython.display import clear_output
%matplotlib inline

# Use CUDA only when it is actually available (same rule as the device-check cell),
# instead of unconditionally forcing "cuda" as the default tensor device.
torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
## Mount Google Drive so results and figures can be written to it.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
## Check if cuda is available; every tensor created afterwards defaults to `device`.

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)

print(device)

In [None]:
##Import required Datasets: CIFAR, STL10, SVHN

import torchvision
import torchvision.datasets as data
import torchvision.transforms as Transforms

# Each sample is converted to a flat float vector and moved to `device` (chosen in the
# device-check cell) rather than a hard-coded "cuda", so CPU-only runtimes also work.
# The name Tcuda is kept so any other code referring to it keeps working.
Tflatten = Transforms.Lambda(lambda x: torch.flatten(x))
Tcuda = Transforms.Lambda(lambda x: x.to(device))
Tfloat = Transforms.Lambda(lambda x: x.to(torch.float))

def _flat_transform(normalize, resize=None):
  """Compose ToTensor -> [Resize] -> float -> [Normalize(0, 5)] -> flatten -> device.

  normalize: when True, insert Normalize(mean=0, std=5), i.e. divide every channel by 5.
  resize: optional (H, W) size passed to Transforms.Resize right after ToTensor.
  """
  steps = [Transforms.ToTensor()]
  if resize is not None:
    steps.append(Transforms.Resize(resize))
  steps.append(Tfloat)
  if normalize:
    steps.append(Transforms.Normalize(mean = [0,0,0], std = [5,5,5]))
  steps += [Tflatten, Tcuda]
  return Transforms.Compose(steps)

# NOTE(review): the CIFAR-10 *training* transform skips Normalize(std=5) while every
# evaluation transform below applies it, so training pixels are 5x the scale of the
# test pixels. Confirm this mismatch is intentional.
T = _flat_transform(normalize=False)
cifar_train = data.CIFAR10("./", train = True, download = True, transform=T)
cifar_dataloader = torch.utils.data.DataLoader(cifar_train)
print("CIFAR:", next(iter(cifar_dataloader))[0].shape)
print("CIFAR:", cifar_train.data.shape)

T = _flat_transform(normalize=True)
cifar_test = data.CIFAR10("./", train = False, download = True, transform=T)
cifar_dataloader = torch.utils.data.DataLoader(cifar_test)
print("CIFAR:", next(iter(cifar_dataloader))[0].shape)
print("CIFAR:", cifar_test.data.shape)

# STL10 images are resized to 32x32 so their flattened length matches the CIFAR vectors.
T = _flat_transform(normalize=True, resize=(32,32))
stl10_train = data.STL10("./", split = 'train', download = True, transform=T)
stl10_dataloader = torch.utils.data.DataLoader(stl10_train)
print("STL10:", next(iter(stl10_dataloader))[0].shape)
print("STL10:", stl10_train.data.shape)

T = _flat_transform(normalize=True)
svhn_train = data.SVHN("./", split = 'train', download = True, transform=T)
svhn_dataloader = torch.utils.data.DataLoader(svhn_train)
print("SVHN:", next(iter(svhn_dataloader))[0].shape)
print("SVHN:", svhn_train.data.shape)

In [None]:
##Define all functions required for the theoretical results (refer to Theorems)

def eigen_squared(cr, z):
  """Closed-form coefficient used by calc_gen_error_new_new.

  cr: rank ratio (callers pass r/ntrn); z: 1/c when c < 1, otherwise 1.
  Uses np.sqrt, so NumPy-compatible scalars or arrays are accepted.
  """
  root = np.sqrt(4*z*cr**2 + (1-cr + cr*z)**2)
  numerator = z * cr ** 2 + cr ** 2 + z * cr - 2*cr + 1
  denominator = 2*z**2*cr * root
  correction = (1-1/cr)/(2*z**2)
  return numerator/denominator + correction

def scale(cr,z):
  """Return 0.5 + (1 + z*cr - sqrt(4*z*cr^2 + (1 - cr + cr*z)^2)) / (2*cr)."""
  discriminant_root = np.sqrt(4*z*cr**2 + (1-cr + cr*z)**2)
  shifted = 1 + z*cr - discriminant_root
  return 0.5 + shifted/(2*cr)

def scale_squared(cr,z):
  """Return (1 + cr + z*cr) / (2*sqrt(4*z*cr^2 + (1 - cr + cr*z)^2)) - 0.5."""
  discriminant_root = np.sqrt(4*z*cr**2 + (1-cr + cr*z)**2)
  ratio = (1+cr+z*cr)/(2*discriminant_root)
  return ratio - 0.5

def scale_both_squared(cr,z):
  """Return scale(cr, z) minus z times scale_squared(cr, z)."""
  first = scale(cr,z)
  second = scale_squared(cr,z)
  return first - z*second

def var(c,r,d,N):
  """Variance term of the theoretical error, with cr = r/N.

  c < 1: r * (scale_both_squared(cr, 1/c) + scale_squared(cr, 1/c)) / (d*(1-c)).
  otherwise: r * c * scale(cr, 1) / (d*(c-1)); this divides by zero at c == 1.
  """
  cr = r/N
  if c < 1:
    inv_c = 1/c
    combined = scale_both_squared(cr, inv_c) + scale_squared(cr, inv_c)
    return r * combined/(d*(1-c))
  return r * c * scale(cr,1) / (d*(c-1))

def calc_gen_error_new(M,ntrn,c,ntst,r,S1,L):
  """Shrunk-projection bias (as in calc_gen_error) plus var(c, r, M, ntrn).

  ntst is unused. When c == 1 the bias term is skipped (0).
  """
  bias = 0
  if c<1:
    shrink = 1/(1+S1**2*c)
  elif c>1:
    shrink = 1/(1+S1**2)
  else:
    shrink = None
  if shrink is not None:
    bias += (torch.diag(shrink) @ L).square().sum()/L.shape[1]
  return bias + var(c,r,M,ntrn)

# TODO (original author's note): Cov might need to be diagonal.
def calc_gen_error_new_new(M,ntrn,c,ntst,r,S1,Cov):
  """eigen_squared-weighted squared Frobenius mass of Cov, plus var(c, r, M, ntrn).

  S1 and ntst are unused. When c == 1 the first term is skipped (0).
  """
  rank_ratio = r/ntrn
  if c<1:
    bias = eigen_squared(rank_ratio, 1/c) * (Cov).square().sum()/(c**2)
  elif c>1:
    bias = eigen_squared(rank_ratio, 1) * (Cov).square().sum()
  else:
    bias = 0
  return 0 + bias + var(c,r,M,ntrn)

def calc_gen_error(M,ntrn,c,ntst,r,S1,L):
  """Theoretical generalization error: shrunk-projection bias plus weight-norm term.

  Args:
    M: data dimension d; divides the calc_wnorm term.
    ntrn, ntst: unused; kept for signature compatibility.
    c: aspect ratio (the sweep passes M/N); must not equal 1.
    r: number of singular values forwarded to calc_wnorm.
    S1: 1-D tensor of (noise-scaled) singular values; its length sets the diag size.
    L: 2-D tensor with len(S1) rows; the bias is averaged over its L.shape[1] columns.
  Returns:
    Scalar tensor.
  Raises:
    ValueError: if c == 1. Previously both branches were skipped and calc_wnorm
      returned 0, so the function silently returned 0 there.
  """
  if c == 1:
    raise ValueError("calc_gen_error is undefined at c == 1: the 1/(1-c) and 1/(c-1) terms diverge")
  # c < 1 shrinks each direction by 1/(1 + c*s^2); c > 1 by 1/(1 + s^2).
  shrink = 1/(1+S1**2*c) if c<1 else 1/(1+S1**2)
  bias = (torch.diag(shrink) @ L).square().sum()/L.shape[1]
  return bias + calc_wnorm(c,r,S1)/M

def calc_gen_error_regression(M,ntrn,c,ntst,r,S1,L,betahat):
  """Regression variant of calc_gen_error: the bias is projected through betahat.T.

  Args:
    M: data dimension d; divides the calc_wnorm_regression term.
    ntrn, ntst: unused; kept for signature compatibility.
    c: aspect ratio; must not equal 1.
    r: number of singular values forwarded to calc_wnorm_regression.
    S1: 1-D tensor of (noise-scaled) singular values.
    L: 2-D tensor with len(S1) rows; the bias is averaged over L.shape[1] columns.
    betahat: 2-D tensor whose row count matches len(S1). Its first column is read by
      calc_wnorm_regression (presumably shape (r, 1) — confirm against callers).
  Returns:
    Scalar tensor.
  Raises:
    ValueError: if c == 1. Previously both branches were skipped and the function
      silently returned 0.
  """
  if c == 1:
    raise ValueError("calc_gen_error_regression is undefined at c == 1: the 1/(1-c) and 1/(c-1) terms diverge")
  shrink = 1/(1+S1**2*c) if c<1 else 1/(1+S1**2)
  bias = (betahat.T @ torch.diag(shrink) @ L).square().sum()/L.shape[1]
  return bias + calc_wnorm_regression(c,r,S1,betahat)/M

def calc_lower_bound(M,ntrn,c,ntst,r,S1,L,alpha):
  """Lower bound: (||shrunk L||_F - alpha*(sqrt(wnorm)+1))^2 / L.shape[0] + wnorm/M.

  Args:
    M: data dimension d; divides the wnorm term.
    ntrn, ntst: unused; kept for signature compatibility.
    c: aspect ratio; must not equal 1.
    r: number of singular values forwarded to calc_wnorm.
    S1: 1-D tensor of (noise-scaled) singular values.
    L: 2-D tensor with len(S1) rows.
    alpha: scalar multiplying (sqrt(wnorm) + 1) in the slack term.
  Raises:
    ValueError: if c == 1. Previously this failed with an obscure AttributeError,
      because calc_wnorm returns the int 0 there.
  """
  if c == 1:
    raise ValueError("calc_lower_bound is undefined at c == 1: the 1/(1-c) and 1/(c-1) terms diverge")
  wnorm_root = calc_wnorm(c,r,S1).sqrt()
  shrink = 1/(1+S1**2*c) if c<1 else 1/(1+S1**2)
  bias = (torch.diag(shrink) @ L).square().sum().sqrt()
  # NOTE(review): normalised by L.shape[0] (rows) here, but calc_gen_error divides
  # by L.shape[1] (columns) — confirm which normalisation the bound intends.
  bias = (bias-alpha*(wnorm_root+1))**2/L.shape[0]
  return bias + wnorm_root.square()/M

def calc_upper_bound(M,ntrn,c,ntst,r,S1,L,alpha):
  """Upper bound: (||shrunk L||_F + alpha*(sqrt(wnorm)+1))^2 / L.shape[0] + wnorm/M.

  Args:
    M: data dimension d; divides the wnorm term.
    ntrn, ntst: unused; kept for signature compatibility.
    c: aspect ratio; must not equal 1.
    r: number of singular values forwarded to calc_wnorm.
    S1: 1-D tensor of (noise-scaled) singular values.
    L: 2-D tensor with len(S1) rows.
    alpha: scalar multiplying (sqrt(wnorm) + 1) in the slack term.
  Raises:
    ValueError: if c == 1. Previously this failed with an obscure AttributeError,
      because calc_wnorm returns the int 0 there.
  """
  if c == 1:
    raise ValueError("calc_upper_bound is undefined at c == 1: the 1/(1-c) and 1/(c-1) terms diverge")
  wnorm_root = calc_wnorm(c,r,S1).sqrt()
  shrink = 1/(1+S1**2*c) if c<1 else 1/(1+S1**2)
  bias = (torch.diag(shrink) @ L).square().sum().sqrt()
  # NOTE(review): normalised by L.shape[0] (rows) here, but calc_gen_error divides
  # by L.shape[1] (columns) — confirm which normalisation the bound intends.
  bias = (bias+alpha*(wnorm_root+1))**2/L.shape[0]
  return bias + wnorm_root.square()/M

def calc_W_minus_I_norm(c,r,S1,M):
  """Return (sqrt(t) + 1)^2 for the leading singular value S1[0] only; 0 when c == 1.

  t = c^2 (s^2 + s^4) / ((1 + c s^2)^2 (1 - c)) for c < 1, and
  t = c s^2 / ((1 + s^2)(c - 1)) for c > 1, with s = S1[0] (a tensor, since .sqrt() is used).
  r and M are unused.
  NOTE(review): only S1[0] contributes (original loop was range(1)) — confirm this is
  intended rather than range(r).
  """
  if c<1:
    s = S1[0]
    t = ((c**2*(s**2 + s**4))/((1+s**2*c)**2*(1-c)))
    return (t.sqrt()+1)**2
  if c>1:
    s = S1[0]
    t = ((c*s**2)/((1+s**2)*(c-1)))
    return (t.sqrt()+1).square()
  return 0

def calc_wnorm(c,r,S1):
  """Sum over the first r entries s of S1 of the weight-norm term; 0 when c == 1.

  c < 1: (s^2 + s^4) / ((1/c + s^2)^2 (1 - c))
  c > 1: c s^2 / ((1 + s^2)(c - 1))
  """
  if c<1:
    term = lambda s: ((s**2 + s**4))/((1/c+s**2)**2*(1-c))
  elif c>1:
    term = lambda s: (c*s**2)/((1+s**2)*(c-1))
  else:
    return 0
  # sum() starts from 0 and adds left to right, the same order as an explicit loop.
  return sum((term(S1[k]) for k in range(r)), 0)

def calc_wnorm_regression(c,r,S1,betahat):
  """calc_wnorm with each of the first r terms weighted by betahat[i, 0]**2; 0 when c == 1.

  betahat must support 2-D indexing betahat[i, 0] for i < r.
  """
  if c<1:
    term = lambda k: betahat[k,0]**2 * (c**2*(S1[k]**2 + S1[k]**4))/((1+S1[k]**2*c)**2*(1-c))
  elif c>1:
    term = lambda k: betahat[k,0]**2 * (c*S1[k]**2)/((1+S1[k]**2)*(c-1))
  else:
    return 0
  return sum((term(k) for k in range(r)), 0)

In [None]:
#Define a folder named Denoising and subfolders named dataRanks and figures to store your data and figures

# Single root for all outputs. "MyDrive" matches the spelling used by the save/load
# cells; path_figures previously used "My Drive", inconsistent with the rank paths.
denoising_root = "/content/drive/MyDrive/Denoising/"

path1_rank = denoising_root + "dataRanks/"
path2_rank = denoising_root + "dataRanks/"  # NOTE(review): identical to path1_rank — confirm intended

path_figures = denoising_root + "figures/"

## Generating Figures 5 and 6

In [None]:
##Optimal Noise level and Generalization Error for optimal Noise
# For each training-set size N: take the rank-r SVD of the first N CIFAR training images,
# project each evaluation set onto the top-r left singular vectors, then grid-search
# the noise scale theta (eta = 1/theta) that minimises calc_gen_error.

r = 50      # rank kept from the training-data SVD
M = 3072    # data dimension d (length of each flattened image vector)
Ns = torch.tensor([500,750,1000,1250,1500,1750,2000,2250,2500,2600,2700,2800,2900,3000,3020,3130,3200,3300,3400,3500,3750,4000,4250,4500,4750,5000,5250,5500]).cpu()
N_THETAS = 2000  # grid resolution for theta

T = Ns.shape[0]

opt_thetas = torch.zeros(T)
opt_thetas_cifar = torch.zeros(T)
opt_thetas_stl10 = torch.zeros(T)
opt_thetas_svhn = torch.zeros(T)

opt_gen_error = torch.zeros(T)
opt_gen_error_cifar = torch.zeros(T)
opt_gen_error_stl10 = torch.zeros(T)
opt_gen_error_svhn = torch.zeros(T)


def _first_columns(dataset, n):
  """Return the first n samples of `dataset` (unshuffled) stacked as columns."""
  loader = torch.utils.data.DataLoader(dataset, batch_size = n, shuffle = False)
  return next(iter(loader))[0].T


for i,N in tqdm(list(enumerate(Ns))):

  Ntst = N.item()

  Xtrn = _first_columns(cifar_train, Ntst)
  # Reduced SVD: only the top-r factors are used below, so there is no need to build
  # the full d x d U and N x N Vh that the default full_matrices=True computes.
  # (Per-row sign flips of the factors do not affect the squared sums in calc_gen_error.)
  U,S,Vh = torch.linalg.svd(Xtrn, full_matrices=False)
  U = U[:,:r]
  S = S[:r]
  Vh = Vh[:r,:]
  Xtrn = U @ torch.diag(S) @ Vh  # rank-r reconstruction of the training data

  # NOTE(review): if a dataset holds fewer than Ntst samples, the DataLoader's single
  # batch contains all of them and that matrix has fewer than Ntst columns.
  # calc_gen_error averages over L.shape[1], so this still runs. Check STL10's
  # train split at the largest Ns.
  Xtst_cifar = _first_columns(cifar_test, Ntst)
  Xtst_stl10 = _first_columns(stl10_train, Ntst)
  Xtst_svhn = _first_columns(svhn_train, Ntst)

  print(Xtrn.shape, Xtst_cifar.shape)

  # Coordinates of each evaluation set in the top-r left singular basis of the training data.
  L_cifar = U.T @ Xtst_cifar
  L_stl10 = U.T @ Xtst_stl10
  L_svhn = U.T @ Xtst_svhn

  # Training data in the same basis: U_r^T Xtrn = diag(S_r) Vh_r.
  L = torch.diag(S) @ Vh

  thetas = torch.linspace(0.01,3.5,N_THETAS) #eta is 1/theta
  gen_error = torch.zeros(N_THETAS)
  gen_error_cifar = torch.zeros(N_THETAS)
  gen_error_stl10 = torch.zeros(N_THETAS)
  gen_error_svhn = torch.zeros(N_THETAS)

  c = M/N  # aspect ratio d/N passed to calc_gen_error
  for j,theta in enumerate(thetas):
    Sigma = theta*S
    gen_error[j] = calc_gen_error(M,N,c,N,r,Sigma,L)
    gen_error_cifar[j] = calc_gen_error(M,N,c,N,r,Sigma,L_cifar)
    gen_error_stl10[j] = calc_gen_error(M,N,c,N,r,Sigma,L_stl10)
    gen_error_svhn[j] = calc_gen_error(M,N,c,N,r,Sigma,L_svhn)

  opt_thetas[i] = thetas[gen_error.argmin()]
  opt_thetas_cifar[i] = thetas[gen_error_cifar.argmin()]
  opt_thetas_stl10[i] = thetas[gen_error_stl10.argmin()]
  opt_thetas_svhn[i] = thetas[gen_error_svhn.argmin()]

  opt_gen_error[i] = gen_error.min()
  opt_gen_error_cifar[i] = gen_error_cifar.min()
  opt_gen_error_stl10[i] = gen_error_stl10.min()
  opt_gen_error_svhn[i] = gen_error_svhn.min()


In [None]:
# Persist (optimal theta, optimal generalization error) for each evaluation set.
# Paths are relative to the working directory (/content on Colab).
for results, out_path in [
    ((opt_thetas,opt_gen_error), "drive/MyDrive/Denoising/opt-theta-opt-error-cifar-train-finer.pt"),
    ((opt_thetas_cifar,opt_gen_error_cifar), "drive/MyDrive/Denoising/opt-theta-opt-error-cifar-test-finer.pt"),
    ((opt_thetas_stl10,opt_gen_error_stl10), "drive/MyDrive/Denoising/opt-theta-opt-error-stl10-test-finer.pt"),
    ((opt_thetas_svhn,opt_gen_error_svhn), "drive/MyDrive/Denoising/opt-theta-opt-error-svhn-test-finer.pt"),
]:
  torch.save(results, out_path)

In [None]:
# Reload the (optimal theta, optimal generalization error) pairs saved above.
saved_result_paths = {
  "cifar_train": "drive/MyDrive/Denoising/opt-theta-opt-error-cifar-train-finer.pt",
  "cifar_test": "drive/MyDrive/Denoising/opt-theta-opt-error-cifar-test-finer.pt",
  "stl10_test": "drive/MyDrive/Denoising/opt-theta-opt-error-stl10-test-finer.pt",
  "svhn_test": "drive/MyDrive/Denoising/opt-theta-opt-error-svhn-test-finer.pt",
}
opt_thetas, opt_gen_error = torch.load(saved_result_paths["cifar_train"])
opt_thetas_cifar, opt_gen_error_cifar = torch.load(saved_result_paths["cifar_test"])
opt_thetas_stl10, opt_gen_error_stl10 = torch.load(saved_result_paths["stl10_test"])
opt_thetas_svhn, opt_gen_error_svhn = torch.load(saved_result_paths["svhn_test"])

### Plot the optimal noise level $\eta = 1/\theta$


In [None]:
# Global font sizes shared by every figure below.
plt.rcParams.update({
  'font.size': 20,
  'xtick.labelsize': 14,
  'ytick.labelsize': 14,
  'legend.fontsize': 13,
})

In [None]:
Ns = torch.tensor([500,750,1000,1250,1500,1750,2000,2250,2500,2600,2700,2800,2900,3000,3020,3130,3200,3300,3400,3500,3750,4000,4250,4500,4750,5000,5250,5500]).cpu()

# The x-values are N/d = 1/c (the sweep passes c = d/N = M/N to calc_gen_error),
# so the axis is labelled 1/c rather than c.
fig, ax = plt.subplots()
ax.plot(Ns.cpu()/M, 1/opt_thetas.cpu(), c = "orange")
ax.set_xlabel(r"$\frac{1}{c} = \frac{N}{d}$")
ax.set_ylabel(r"Optimal $\eta$")
ax.set_yscale("log")
fig.savefig(path_figures+"train-test-opt-sigma.pdf", facecolor = "white", bbox_inches = "tight", dpi = 300)

In [None]:
Ns = torch.tensor([500,750,1000,1250,1500,1750,2000,2250,2500,2600,2700,2800,2900,3000,3020,3130,3200,3300,3400,3500,3750,4000,4250,4500,4750,5000,5250,5500]).cpu()
M = 3072

# The x-values are N/d = 1/c (the sweep passes c = d/N = M/N to calc_gen_error),
# so the axis is labelled 1/c rather than c.
fig, ax = plt.subplots()
ax.plot(Ns.cpu()/M, 1/opt_thetas_cifar.cpu(), c = "g")
ax.set_xlabel(r"$\frac{1}{c} = \frac{N}{d}$")
ax.set_ylabel(r"Optimal $\eta$")
ax.set_yscale("log")
fig.savefig(path_figures+"cifar-test-opt-sigma.pdf", facecolor = "white", bbox_inches = "tight", dpi = 300)

In [None]:
Ns = torch.tensor([500,750,1000,1250,1500,1750,2000,2250,2500,2600,2700,2800,2900,3000,3020,3130,3200,3300,3400,3500,3750,4000,4250,4500,4750,5000,5250,5500]).cpu()
M = 3072

# The x-values are N/d = 1/c (the sweep passes c = d/N = M/N to calc_gen_error),
# so the axis is labelled 1/c rather than c.
fig, ax = plt.subplots()
ax.plot(Ns.cpu()/M, 1/opt_thetas_stl10.cpu(), c = "b")
ax.set_xlabel(r"$\frac{1}{c} = \frac{N}{d}$")
ax.set_ylabel(r"Optimal $\eta$")
ax.set_yscale("log")
fig.savefig(path_figures+"stl10-test-opt-sigma.pdf", facecolor = "white", bbox_inches = "tight", dpi = 300)



In [None]:
Ns = torch.tensor([500,750,1000,1250,1500,1750,2000,2250,2500,2600,2700,2800,2900,3000,3020,3130,3200,3300,3400,3500,3750,4000,4250,4500,4750,5000,5250,5500]).cpu()
M = 3072

# The x-values are N/d = 1/c (the sweep passes c = d/N = M/N to calc_gen_error),
# so the axis is labelled 1/c rather than c.
fig, ax = plt.subplots()
ax.plot(Ns.cpu()/M, 1/opt_thetas_svhn.cpu(), c = "r")
ax.set_xlabel(r"$\frac{1}{c} = \frac{N}{d}$")
ax.set_ylabel(r"Optimal $\eta$")
ax.set_yscale("log")
fig.savefig(path_figures+"svhn-test-opt-sigma.pdf", facecolor = "white", bbox_inches = "tight", dpi = 300)

### Plot optimal Generalization Error

In [None]:
Ns = torch.tensor([500,750,1000,1250,1500,1750,2000,2250,2500,2600,2700,2800,2900,3000,3020,3130,3200,3300,3400,3500,3750,4000,4250,4500,4750,5000,5250,5500]).cpu()
M = 3072

# The x-values are N/d = 1/c (the sweep passes c = d/N = M/N to calc_gen_error),
# so the axis is labelled 1/c rather than c.
fig, ax = plt.subplots()
ax.plot(Ns.cpu()/M, opt_gen_error.cpu(), c = "r")
ax.set_xlabel(r"$\frac{1}{c} = \frac{N}{d}$")
ax.set_ylabel("Generalization Error")
ax.set_yscale("log")
fig.savefig(path_figures+"train-test-opt-gen-error.pdf", facecolor = "white", bbox_inches = "tight", dpi = 300)

In [None]:
Ns = torch.tensor([500,750,1000,1250,1500,1750,2000,2250,2500,2600,2700,2800,2900,3000,3020,3130,3200,3300,3400,3500,3750,4000,4250,4500,4750,5000,5250,5500]).cpu()
M = 3072

# The x-values are N/d = 1/c (the sweep passes c = d/N = M/N to calc_gen_error),
# so the axis is labelled 1/c rather than c.
fig, ax = plt.subplots()
ax.plot(Ns.cpu()/M, opt_gen_error_cifar.cpu(), c = "r")
ax.set_xlabel(r"$\frac{1}{c} = \frac{N}{d}$")
ax.set_ylabel("Generalization Error")
ax.set_yscale("log")
fig.savefig(path_figures+"cifar-test-opt-gen-error.pdf", facecolor = "white", bbox_inches = "tight", dpi = 300)

In [None]:
Ns = torch.tensor([500,750,1000,1250,1500,1750,2000,2250,2500,2600,2700,2800,2900,3000,3020,3130,3200,3300,3400,3500,3750,4000,4250,4500,4750,5000,5250,5500]).cpu()
M = 3072

# The x-values are N/d = 1/c (the sweep passes c = d/N = M/N to calc_gen_error),
# so the axis is labelled 1/c rather than c.
fig, ax = plt.subplots()
ax.plot(Ns.cpu()/M, opt_gen_error_stl10.cpu(), c = "r")
ax.set_xlabel(r"$\frac{1}{c} = \frac{N}{d}$")
ax.set_ylabel("Generalization Error")
ax.set_yscale("log")
fig.savefig(path_figures+"stl10-test-opt-gen-error.pdf", facecolor = "white", bbox_inches = "tight", dpi = 300)

In [None]:
Ns = torch.tensor([500,750,1000,1250,1500,1750,2000,2250,2500,2600,2700,2800,2900,3000,3020,3130,3200,3300,3400,3500,3750,4000,4250,4500,4750,5000,5250,5500]).cpu()
M = 3072

# The x-values are N/d = 1/c (the sweep passes c = d/N = M/N to calc_gen_error),
# so the axis is labelled 1/c rather than c.
fig, ax = plt.subplots()
ax.plot(Ns.cpu()/M, opt_gen_error_svhn.cpu(), c = "r")
ax.set_xlabel(r"$\frac{1}{c} = \frac{N}{d}$")
ax.set_ylabel("Generalization Error")
ax.set_yscale("log")
fig.savefig(path_figures+"svhn-test-opt-gen-error.pdf", facecolor = "white", bbox_inches = "tight", dpi = 300)