In [None]:
##Import required libraries and packages

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import seaborn as sb
from tqdm import tqdm

import torchvision
import torch.nn.functional as F
from torch.nn import init
import torchvision.transforms as Tranforms
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torch.nn as nn
import matplotlib.gridspec as gridspec
from IPython.display import clear_output
%matplotlib inline

# Provisional default device: use the GPU only when one is actually present, so that
# tensors created before the device-check cell do not target a missing CUDA device.
torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
##Mount Google Drive to save tensors
# Mounts the user's Google Drive at /content/drive; the result tensors and figures
# below are written under /content/drive/MyDrive/Denoising/.
# google.colab is only available inside Colab runtimes; elsewhere, replace this cell
# and point the path variables at a local folder instead.

from google.colab import drive
drive.mount('/content/drive')

In [None]:
## Check if cuda is available
# Choose the GPU when present and fall back to the CPU otherwise. Tensors created
# without an explicit device are then allocated on the chosen device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)

print(device)

In [None]:
##Import required Datasets: CIFAR, STL10, SVHN
# Every sample becomes a flat float vector of 3*32*32 = 3072 entries on the compute
# device (STL10 images are resized to 32x32 first). Except for the CIFAR train split,
# pixels are also divided by 5 via Normalize(mean=0, std=5).
# NOTE(review): cifar_train is not rescaled like the other splits; in this notebook it
# is only used for the shape printout below. Confirm this is intentional.

import torchvision
import torchvision.datasets as data
import torchvision.transforms as Transforms

Tflatten = Transforms.Lambda(lambda x: torch.flatten(x))
# Move samples to the device picked in the device-check cell rather than a hard-coded
# "cuda", which fails on CPU-only runtimes. The default argument freezes the value now,
# so later rebinding of the global `device` (the plotting cells set it to 'cpu') does
# not move the test data off the compute device.
Tcuda = Transforms.Lambda(lambda x, target=device: x.to(target))
Tfloat = Transforms.Lambda(lambda x: x.to(torch.float))
Tscale = Transforms.Normalize(mean = [0,0,0], std = [5,5,5])

# Distinct names for the transform pipelines: `T` is reused later as the number of runs.
T_unscaled = Transforms.Compose([Transforms.ToTensor(), Tfloat, Tflatten, Tcuda])
T_scaled = Transforms.Compose([Transforms.ToTensor(), Tfloat, Tscale, Tflatten, Tcuda])
T_stl10 = Transforms.Compose([Transforms.ToTensor(), Transforms.Resize((32,32)), Tfloat, Tscale, Tflatten, Tcuda])

cifar_train = data.CIFAR10("./", train = True, download = True, transform=T_unscaled)
cifar_dataloader = torch.utils.data.DataLoader(cifar_train)
print("CIFAR:", next(iter(cifar_dataloader))[0].shape)
print("CIFAR:", cifar_train.data.shape)

cifar_test = data.CIFAR10("./", train = False, download = True, transform=T_scaled)
cifar_dataloader = torch.utils.data.DataLoader(cifar_test)
print("CIFAR:", next(iter(cifar_dataloader))[0].shape)
print("CIFAR:", cifar_test.data.shape)

stl10_train = data.STL10("./", split = 'train', download = True, transform=T_stl10)
stl10_dataloader = torch.utils.data.DataLoader(stl10_train)
print("STL10:", next(iter(stl10_dataloader))[0].shape)
print("STL10:", stl10_train.data.shape)

svhn_train = data.SVHN("./", split = 'train', download = True, transform=T_scaled)
svhn_dataloader = torch.utils.data.DataLoader(svhn_train)
print("SVHN:", next(iter(svhn_dataloader))[0].shape)
print("SVHN:", svhn_train.data.shape)

In [None]:
##Define all functions required for the theoretical results (refer to Theorems)

def eigen_squared(cr, z):
  """Closed-form factor used by calc_gen_error_new_new (refer to the Theorems).

  cr: ratio r/N (callers pass r/ntrn); z: 1/c when c < 1, otherwise 1.
  """
  root = np.sqrt(4*z*cr**2 + (1-cr + cr*z)**2)
  numerator = z * cr ** 2 + cr ** 2 + z * cr - 2*cr + 1
  denominator = 2*z**2*cr * root
  tail = (1-1/cr)/(2*z**2)
  return numerator/denominator + tail

def scale(cr,z):
  """Closed-form shrinkage factor (refer to the Theorems); cr = r/N, z as in eigen_squared."""
  root = np.sqrt(4*z*cr**2 + (1-cr + cr*z)**2)
  return 0.5 + (1 + z*cr - root)/(2*cr)

def scale_squared(cr,z):
  """Closed-form squared-shrinkage factor (refer to the Theorems); cr = r/N, z as in eigen_squared."""
  root = np.sqrt(4*z*cr**2 + (1-cr + cr*z)**2)
  return -0.5 + (1+cr+z*cr)/(2*root)

def scale_both_squared(cr,z):
  """Return scale(cr, z) - z * scale_squared(cr, z) (refer to the Theorems)."""
  first = scale(cr,z)
  second = scale_squared(cr,z)
  return first - z*second

def var(c,r,d,N):
  """Variance part of the theoretical error (refer to the Theorems).

  c: ratio d/N (callers pass c = M/ntrn); r: signal rank; d: ambient dimension;
  N: number of training samples. c == 1 is singular (division by zero).
  """
  cr = r/N
  if not c < 1:
    # c >= 1 branch
    return r * c * scale(cr,1) / (d*(c-1))
  z = 1/c
  return r * (scale_both_squared(cr, z) + scale_squared(cr, z))/(d*(1-c))

def calc_gen_error_new(M,ntrn,c,ntst,r,S1,L):
  """Theoretical generalization error: bias term plus var(c, r, M, ntrn).

  The bias shrinks each row of the test coefficients L by 1/(1+S1**2*c) when c < 1,
  or by 1/(1+S1**2) when c > 1, and averages the squared entries over L's columns.
  For c == 1 no bias term is added. ntst is accepted but unused.
  """
  if c < 1:
    shrink = 1/(1+S1**2*c)
  elif c > 1:
    shrink = 1/(1+S1**2)
  else:
    shrink = None

  total = 0
  if shrink is not None:
    total += (torch.diag(shrink) @ L).square().sum()/L.shape[1]
  return total + var(c,r,M,ntrn)

# NOTE(review): it is unconfirmed whether the theory requires Cov to be diagonal.
# The function itself only uses Cov through the sum of its squared entries.
def calc_gen_error_new_new(M,ntrn,c,ntst,r,S1,Cov):
  """Theoretical generalization error given a test-covariance root Cov.

  Bias is eigen_squared(r/ntrn, 1/c) * ||Cov||_F^2 / c**2 for c < 1, and
  eigen_squared(r/ntrn, 1) * ||Cov||_F^2 for c > 1 (none for c == 1);
  var(c, r, M, ntrn) is added in all cases. S1 and ntst are accepted but unused.
  """
  cov_energy = (Cov).square().sum()
  if c < 1:
    bias = eigen_squared(r/ntrn, 1/c) * cov_energy/(c**2)
  elif c > 1:
    bias = eigen_squared(r/ntrn, 1) * cov_energy
  else:
    bias = 0
  return bias + var(c,r,M,ntrn)

def calc_gen_error(M,ntrn,c,ntst,r,S1,L):
  """Generalization error with the variance expressed through calc_wnorm(c, r, S1)/M.

  The bias term is the same as in calc_gen_error_new: the rows of L are shrunk by
  1/(1+S1**2*c) (c < 1) or 1/(1+S1**2) (c > 1), and the squared entries are averaged
  over L's columns. ntrn and ntst are accepted but unused.
  """
  if c < 1:
    weights = 1/(1+S1**2*c)
  elif c > 1:
    weights = 1/(1+S1**2)
  else:
    weights = None

  err = 0
  if weights is not None:
    err += (torch.diag(weights) @ L).square().sum()/L.shape[1]
  err += calc_wnorm(c,r,S1)/M
  return err

def calc_gen_error_regression(M,ntrn,c,ntst,r,S1,L,betahat):
  """Regression variant of calc_gen_error, weighted by the coefficients betahat.

  The bias projects the shrunk test coefficients through betahat.T. The variance uses
  calc_wnorm_regression(c, r, S1, betahat)/M. betahat is indexed as betahat[i, 0]
  elsewhere, so it is presumably 2-D with r rows (TODO confirm). ntrn and ntst are unused.
  """
  if c < 1:
    weights = 1/(1+S1**2*c)
  elif c > 1:
    weights = 1/(1+S1**2)
  else:
    weights = None

  err = 0
  if weights is not None:
    err += (betahat.T @ torch.diag(weights) @ L).square().sum()/L.shape[1]
  err += calc_wnorm_regression(c,r,S1,betahat)/M
  return err

def calc_lower_bound(M,ntrn,c,ntst,r,S1,L,alpha):
  """Lower bound on the generalization error, with slack alpha.

  Let ||B|| be the Frobenius norm of the shrunk test coefficients and w = sqrt(calc_wnorm).
  The bound is (||B|| - alpha*(w+1))**2 / L.shape[0] + w**2 / M.
  NOTE(review): this normalises by L.shape[0], while calc_gen_error* divide by
  L.shape[1]. Confirm which is intended.
  """
  w_root = calc_wnorm(c,r,S1).sqrt()
  if c < 1:
    damp = 1/(1+S1**2*c)
  elif c > 1:
    damp = 1/(1+S1**2)

  signal = (torch.diag(damp) @ L).square().sum().sqrt()
  bias = (signal - alpha*(w_root+1))**2/L.shape[0]
  return bias + w_root.square()/M

def calc_upper_bound(M,ntrn,c,ntst,r,S1,L,alpha):
  """Upper bound on the generalization error, with slack alpha.

  Let ||B|| be the Frobenius norm of the shrunk test coefficients and w = sqrt(calc_wnorm).
  The bound is (||B|| + alpha*(w+1))**2 / L.shape[0] + w**2 / M.
  NOTE(review): this normalises by L.shape[0], while calc_gen_error* divide by
  L.shape[1]. Confirm which is intended.
  """
  w_root = calc_wnorm(c,r,S1).sqrt()
  if c < 1:
    damp = 1/(1+S1**2*c)
  elif c > 1:
    damp = 1/(1+S1**2)

  signal = (torch.diag(damp) @  L).square().sum().sqrt()
  bias = (signal + alpha*(w_root+1))**2/L.shape[0]
  return bias + w_root.square()/M

def calc_W_minus_I_norm(c,r,S1,M):
  """Squared (sqrt(term) + 1) bound built from the first singular value S1[0] only.

  Returns 0 when c == 1. r and M are accepted but unused.
  NOTE(review): only S1[0] is used (the original iterates range(1)). Presumably this is
  intentional because S1 is sorted in descending order; confirm it should not loop over r.
  S1[0] must support .sqrt() (e.g. a torch tensor element).
  """
  if c < 1:
    s = S1[0]
    root = ((c**2*(s**2 + s**4))/((1+s**2*c)**2*(1-c))).sqrt()
    return (root+1)**2
  if c > 1:
    s = S1[0]
    root = ((c*s**2)/((1+s**2)*(c-1))).sqrt()
    return (root+1).square()
  return 0

def calc_wnorm(c,r,S1):
  """Sum over the first r singular values S1[i] of the variance weight (refer to the Theorems).

  c < 1: (s^2 + s^4) / ((1/c + s^2)^2 * (1-c));  c > 1: c*s^2 / ((1+s^2)*(c-1)).
  Returns the integer 0 when c == 1 or r == 0.
  """
  if c < 1:
    return sum(((S1[i]**2 + S1[i]**4))/((1/c+S1[i]**2)**2*(1-c)) for i in range(r))
  if c > 1:
    return sum((c*S1[i]**2)/((1+S1[i]**2)*(c-1)) for i in range(r))
  return 0

def calc_wnorm_regression(c,r,S1,betahat):
  """Like calc_wnorm, but each term is weighted by betahat[i, 0]**2.

  c < 1: b^2 * c^2 (s^2 + s^4) / ((1 + s^2 c)^2 (1-c));  c > 1: b^2 * c s^2 / ((1+s^2)(c-1)).
  Returns the integer 0 when c == 1 or r == 0.
  """
  if c < 1:
    return sum(betahat[i,0]**2 * (c**2*(S1[i]**2 + S1[i]**4))/((1+S1[i]**2*c)**2*(1-c))
               for i in range(r))
  if c > 1:
    return sum(betahat[i,0]**2 * (c*S1[i]**2)/((1+S1[i]**2)*(c-1))
               for i in range(r))
  return 0

In [None]:
#Define a folder named Denoising and subfolders named dataRanks and figures to store your data and figures
# path1_rank receives the theoretical error tensors and path2_rank the empirical ones
# (currently the same folder). All paths use the same "MyDrive" mount name.
import os

path1_rank = "/content/drive/MyDrive/Denoising/dataRanks/"
path2_rank = "/content/drive/MyDrive/Denoising/dataRanks/"

path_figures = "/content/drive/MyDrive/Denoising/figures/"

# torch.save / plt.savefig fail on a missing folder. Create the folders up front so the
# error does not surface only after a long simulation has finished.
for _folder in (path1_rank, path2_rank, path_figures):
  os.makedirs(_folder, exist_ok=True)

## Generate Figure 9

In [None]:
## Low SNR error (Figure 9): a denoiser is trained on synthetic rank-r Gaussian data and
## evaluated on real images (CIFAR-10 test split, STL10 train split, SVHN train split).

M = 3072                                              # ambient dimension d = 3*32*32
N = torch.arange(1050,10500,550).to(torch.int).cpu()  # training-set sizes; c = M/N

r_values = [50]   # rank(s) of the synthetic training signal

Ntst = 2500       # number of test images taken from each dataset
T = 200 #Number of runs
SEED = 0
torch.manual_seed(SEED)  # reproducible synthetic data and noise draws

# Buffers live on the default device, like every tensor created in the loop below.
# `.to(device)` is avoided because the plotting cells rebind `device` to 'cpu'.
Err_stl10 = torch.zeros(len(r_values),N.shape[0]) #theoretical error
Err_emp_stl10 = torch.zeros(len(r_values),N.shape[0]) #empirical error

Err_svhn = torch.zeros(len(r_values),N.shape[0]) #theoretical error
Err_emp_svhn = torch.zeros(len(r_values),N.shape[0]) #empirical error

Err_cifar = torch.zeros(len(r_values),N.shape[0]) #theoretical error
Err_emp_cifar = torch.zeros(len(r_values),N.shape[0]) #empirical error
print(Err_stl10.shape)

def _first_test_batch(dataset):
  """Return the first Ntst samples (unshuffled) with one flattened image per column."""
  loader = torch.utils.data.DataLoader(dataset, batch_size = Ntst, shuffle = False)
  return next(iter(loader))[0].T

# The test batches depend on neither r nor N, so load them once, outside the loops.
Xtst_cifar = _first_test_batch(cifar_test)
Xtst_stl10 = _first_test_batch(stl10_train)
Xtst_svhn = _first_test_batch(svhn_train)

for i, r in enumerate(r_values):
  print(r)
  for j in range(N.shape[0]):
    c = M/N[j]

    # Synthetic rank-r training signal: random orthonormal basis Q (M x r) times Gaussian
    # coefficients. Reduced SVDs avoid building M x M / N x N factors that would be
    # sliced away immediately.
    Q = torch.linalg.svd(torch.randn(M,r), full_matrices=False).U[:,:r]
    Xtrn = Q @ torch.randn(r,N[j])/np.sqrt(r)

    print(Xtrn.shape, Xtst_cifar.shape)
    print(c)

    # Keep only the top-r singular triplets; U spans the training-signal subspace.
    U,S,Vh = torch.linalg.svd(Xtrn, full_matrices=False)
    U = U[:,:r]
    Vh = Vh[:r,:]
    Xtrn = U @ torch.diag(S[:r]) @ Vh

    # Coordinates of the test images in that subspace, and their projections back to R^M.
    L_cifar = U.T @ Xtst_cifar
    L_stl10 = U.T @ Xtst_stl10
    L_svhn = U.T @ Xtst_svhn

    Xtst_cifar_proj = U @ L_cifar
    Xtst_stl10_proj = U @ L_stl10
    Xtst_svhn_proj = U @ L_svhn

    Err_cifar[i,j] = calc_gen_error_new(M,N[j],c,Ntst,r,S[:r],L_cifar)
    Err_stl10[i,j] = calc_gen_error_new(M,N[j],c,Ntst,r,S[:r],L_stl10)
    Err_svhn[i,j] = calc_gen_error_new(M,N[j],c,Ntst,r,S[:r],L_svhn)

    targets = ((Xtst_cifar_proj, Err_emp_cifar),
               (Xtst_stl10_proj, Err_emp_stl10),
               (Xtst_svhn_proj, Err_emp_svhn))

    for k in tqdm(range(T)):
      # Fresh training noise each run; the learned linear denoiser is W = Xtrn (Xtrn + Atrn)^+.
      Atrn = torch.randn_like(Xtrn)/np.sqrt(M)
      W = Xtrn.mm(torch.pinverse(Xtrn+Atrn))

      for Xproj, Err_emp_ds in targets:
        Atst = torch.randn_like(Xproj)/np.sqrt(M)
        Yp = W.mm(Xproj + Atst)
        Err_emp_ds[i,j] += (Xproj - Yp).square().sum()/(T*Ntst)

    # Relative gap between the Monte-Carlo estimate and the theory.
    print((Err_emp_cifar[i,j]-Err_cifar[i,j]).abs()/Err_emp_cifar[i,j])
    print((Err_emp_stl10[i,j]-Err_stl10[i,j]).abs()/Err_emp_stl10[i,j])
    print((Err_emp_svhn[i,j]-Err_svhn[i,j]).abs()/Err_emp_svhn[i,j])

    # Checkpoint after every N so partial results survive a disconnected runtime.
    torch.save(Err_emp_cifar,path2_rank+"cifar-emp-train-gaussian.pt")
    torch.save(Err_cifar,path1_rank+"cifar-train-gaussian.pt")
    torch.save(Err_emp_stl10,path2_rank+"stl10-emp-train-gaussian.pt")
    torch.save(Err_stl10,path1_rank+"stl10-train-gaussian.pt")
    torch.save(Err_emp_svhn,path2_rank+"svhn-emp-train-gaussian.pt")
    torch.save(Err_svhn,path1_rank+"svhn-train-gaussian.pt")
    

In [None]:
# Figure font sizes: 20 pt base text, 14 pt tick labels, 13 pt legend.
plt.rcParams.update({
    'font.size': 20,
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
    'legend.fontsize': 13,
})

In [None]:
##Generate figures for Gaussian training data
# CIFAR-10 test images: theory (solid line) vs. empirical estimate (dots) against 1/c = N/d.
device = 'cpu'   # map_location for the loads in this cell and the next two cells
M = 3072
N = torch.arange(1050,10500,550).to(torch.int).cpu()

# x-axis values for the N grid used in the Figure 9 runs.
Cinverse = (N/M).cpu().numpy()

Err_emp_cifar_gaussian = torch.load(path2_rank+"cifar-emp-train-gaussian.pt",map_location = device).numpy()
Err_cifar_gaussian = torch.load(path1_rank+"cifar-train-gaussian.pt",map_location = device).numpy()

plt.plot(Cinverse,Err_cifar_gaussian[0,:],color="green")
plt.plot(Cinverse,Err_emp_cifar_gaussian[0,:],'.',markersize = 15,color="green")

plt.yscale("log")
plt.xlabel("1/c = N/d")
plt.ylabel("Generalization Error")
plt.legend(['Theoretical Result','Empirical Result'])
plt.savefig(path_figures+"CIFAR_gaussian_error.pdf", bbox_inches='tight', facecolor='white', dpi = 300, format = 'pdf')

In [None]:
# STL10 test images: theory (solid line) vs. empirical Monte-Carlo estimate (dots).
Err_emp_stl10_gaussian = torch.load(path2_rank + "stl10-emp-train-gaussian.pt",
                                    map_location=device).numpy()
Err_stl10_gaussian = torch.load(path1_rank + "stl10-train-gaussian.pt",
                                map_location=device).numpy()

curve_color = "blue"
plt.plot(Cinverse, Err_stl10_gaussian[0, :], color=curve_color)
plt.plot(Cinverse, Err_emp_stl10_gaussian[0, :], '.', markersize=15, color=curve_color)

plt.xlabel("1/c = N/d")
plt.ylabel("Generalization Error")
plt.yscale("log")
plt.legend(['Theoretical Result', 'Empirical Result'])
plt.savefig(path_figures + "STL10_gaussian_error.pdf",
            bbox_inches='tight', facecolor='white', dpi=300, format='pdf')

In [None]:
# SVHN test images: theory (solid line) vs. empirical Monte-Carlo estimate (dots).
Err_emp_svhn_gaussian = torch.load(path2_rank + "svhn-emp-train-gaussian.pt",
                                   map_location=device).numpy()
Err_svhn_gaussian = torch.load(path1_rank + "svhn-train-gaussian.pt",
                               map_location=device).numpy()

curve_color = "red"
plt.plot(Cinverse, Err_svhn_gaussian[0, :], color=curve_color)
plt.plot(Cinverse, Err_emp_svhn_gaussian[0, :], '.', markersize=15, color=curve_color)

plt.xlabel("1/c = N/d")
plt.ylabel("Generalization Error")
plt.yscale("log")
plt.legend(['Theoretical Result', 'Empirical Result'])
plt.savefig(path_figures + "SVHN_gaussian_error.pdf",
            bbox_inches='tight', facecolor='white', dpi=300, format='pdf')

## Generate Figure 10


In [None]:
## Figure 10: training and test signals are both i.i.d. rank-r Gaussian.

M = 3072
N = torch.arange(500,6010,550).to(torch.int).cpu()   # training-set sizes; c = M/N

r_values = [50]

Ntst = 5000
T = 50 #Number of runs
SEED = 0
torch.manual_seed(SEED)  # reproducible synthetic data and noise draws

# Buffers live on the default device, like the simulation tensors. The plotting cell above
# rebinds `device` to 'cpu', and accumulating CUDA results into a CPU buffer
# (Err_emp[i,j] += ...) can raise a device-mismatch error.
Err = torch.zeros(len(r_values),N.shape[0]) #theoretical error
Err_emp = torch.zeros(len(r_values),N.shape[0]) #empirical error

for i, r in enumerate(r_values):
  print(r)
  for j in range(N.shape[0]):
    c = M/N[j]
    # Reduced SVDs: only the leading r singular vectors are used.
    Q = torch.linalg.svd(torch.randn(M,r), full_matrices=False).U[:,:r]
    Xtrn = Q @ torch.randn(r,N[j])/np.sqrt(r)

    # i.i.d. Gaussian test coefficients with entries of variance 1/r.
    L = torch.randn(r,Ntst)/np.sqrt(r)

    print(Xtrn.shape, L.shape)
    print(c)

    U,S,Vh = torch.linalg.svd(Xtrn, full_matrices=False)
    Xtrn = U[:,:r] @ torch.diag(S[:r]) @ Vh[:r,:]
    Xtst = U[:,:r] @ L

    # Theory with Cov = I*sqrt(1/r), a square root of L's per-column covariance I/r.
    Err[i,j] = calc_gen_error_new_new(M,N[j],c,Ntst,r,S[:r],torch.eye(r)*np.sqrt(1/r))

    for k in tqdm(range(T)):
      # Fresh training noise each run; linear denoiser W = Xtrn (Xtrn + Atrn)^+.
      Atrn = torch.randn_like(Xtrn)/np.sqrt(M)
      W = Xtrn.mm(torch.pinverse(Xtrn+Atrn))

      Atst = torch.randn_like(Xtst)/np.sqrt(M)
      Yp = W.mm(Xtst + Atst)
      Err_emp[i,j] += (Xtst - Yp).square().sum()/(T*Ntst)

    # Relative gap between the Monte-Carlo estimate and the theory.
    print((Err_emp[i,j]-Err[i,j]).abs()/Err_emp[i,j])
    torch.save(Err_emp,path2_rank+"iid-both-emp.pt")
    torch.save(Err,path1_rank+"iid-both.pt")
 

In [None]:
##Run for more 'N' values to obtain smoother theory curve
## Both training and test data is IID. Only the theoretical error is recomputed, on a
## 10x finer N grid; this overwrites iid-both.pt and leaves iid-both-emp.pt untouched.

M = 3072
N = torch.arange(500,6010,55).to(torch.int).cpu() #more c values for theory curve

r_values = [50]

Ntst = 5000
SEED = 0
torch.manual_seed(SEED)

Err = torch.zeros(len(r_values),N.shape[0]) #theoretical error
print(Err.shape)

for i, r in enumerate(r_values):
  print(r)
  for j in range(N.shape[0]):
    c = M/N[j]
    print(c)

    # The theory only needs the top-r singular values of Xtrn = Q @ G. Q has orthonormal
    # columns, so Xtrn and G have the same singular values, and the M x N matrix (and its
    # SVD) never has to be formed.
    G = torch.randn(r,N[j])/np.sqrt(r)
    S = torch.linalg.svdvals(G)

    Err[i,j] = calc_gen_error_new_new(M,N[j],c,Ntst,r,S[:r],torch.eye(r)*np.sqrt(1/r))

    torch.save(Err,path1_rank+"iid-both.pt")    ##Rewrite the original tensor
 

In [None]:
# Figure font sizes: 20 pt base text, 14 pt tick labels, 13 pt legend.
plt.rcParams.update({
    'font.size': 20,
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
    'legend.fontsize': 13,
})

In [None]:
##Generate figures for both test and train as IID
# The theory curve is drawn on the fine N grid (step 55) that the refinement cell writes to
# iid-both.pt. The empirical dots use the coarse grid (step 550) from the Monte-Carlo cell.
device = 'cpu'
M = 3072
N = torch.arange(500,6010,550).to(torch.int).cpu()
N_theory = torch.arange(500,6010,55).to(torch.int)

Cinverse, Cinverse_theory = ((grid/M).cpu().numpy() for grid in (N, N_theory))

Err_emp_iid_both = torch.load(path2_rank + "iid-both-emp.pt", map_location=device).numpy()
Err_iid_both = torch.load(path1_rank + "iid-both.pt", map_location=device).numpy()

plt.plot(Cinverse_theory, Err_iid_both[0, :], color="orange")
plt.plot(Cinverse, Err_emp_iid_both[0, :], '.', markersize=15, color="orange")

plt.xlabel("1/c = N/d")
plt.ylabel("Generalization Error")
plt.yscale("log")
plt.legend(['Theoretical Result', 'Empirical Result'], fontsize=13.5)
plt.savefig(path_figures + "iid_both_error.pdf",
            bbox_inches='tight', facecolor='white', dpi=300, format='pdf')