# This notebook contains the code used to generate all of the data for the plots in the paper "Training Data Size Induced Double Descent for Denoising Neural Networks and the Role of Training Noise Level".


In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import seaborn as sb
from tqdm import tqdm

import torchvision
import torch.nn.functional as F
from torch.nn import init
import torchvision.transforms as Tranforms
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torch.nn as nn
import matplotlib.gridspec as gridspec
from IPython.display import clear_output
%matplotlib inline

# Set Up

In [2]:
def gen_data(m,n,r, S1 = None, S2 = None, device = 'cpu'):
  """Generate a rank-``r`` train/test pair that shares left singular vectors.

  Both matrices are ``m x n`` and have the form ``U @ S @ V.T``. ``U`` (``m x r``)
  is shared, and ``V1`` / ``V2`` (``n x r``) are drawn independently; all
  three have orthonormal columns.

  Args:
    m, n, r: number of rows, number of columns and rank.
    S1, S2: optional ``r x r`` diagonal singular-value matrices for the train
      and test matrix. When omitted, diagonals of squared standard normals are
      drawn. Supplied matrices are moved to ``device``.
    device: device on which all returned tensors live.

  Returns:
    ``(X, Xtst, S1, S2)``.
  """
  U = torch.svd(torch.randn(m,r)).U[:,:r].to(device)
  V1 = torch.svd(torch.randn(r,n)).V[:,:r].to(device)
  V2 = torch.svd(torch.randn(r,n)).V[:,:r].to(device)

  # `is None` is an identity check; `==` on a tensor is an overloaded
  # elementwise comparison and is not a reliable "argument omitted" test.
  if S1 is None:
    S1 = torch.diag(torch.randn(r).square())
  if S2 is None:
    S2 = torch.diag(torch.randn(r).square())

  # Caller-supplied singular values must sit on the same device as U / V.
  S1 = S1.to(device)
  S2 = S2.to(device)

  X = U.mm(S1.mm(V1.t()))
  Xtst = U.mm(S2.mm(V2.t()))

  return X, Xtst, S1, S2

def gen_data_test(m, ntrn, ntst, r, S1 = None, S2 = None, device = 'cpu'):
  """Generate rank-``r`` train/test matrices with different sample counts.

  The training matrix is ``U @ S1 @ V1.T`` (``m x ntrn``) and the test matrix
  is ``U @ S2 @ V2.T`` (``m x ntst``). They share the left singular vectors
  ``U``. ``U``, ``V1`` and ``V2`` all have orthonormal columns.

  Args:
    m: number of rows (feature dimension).
    ntrn, ntst: number of training / test columns.
    r: rank.
    S1, S2: optional ``r x r`` diagonal singular-value matrices. When omitted,
      diagonals of squared standard normals are drawn. Both are moved to
      ``device``.
    device: device on which all returned tensors live.

  Returns:
    ``(X, Xtst, S1, S2)``.
  """
  U = torch.svd(torch.randn(m,r)).U[:,:r].to(device)
  V1 = torch.svd(torch.randn(r,ntrn)).V[:,:r].to(device)
  V2 = torch.svd(torch.randn(r,ntst)).V[:,:r].to(device)

  # Identity check rather than the tensor-overloaded `==`.
  if S1 is None:
    S1 = torch.diag(torch.randn(r).square())
  if S2 is None:
    S2 = torch.diag(torch.randn(r).square())

  S1 = S1.to(device)
  S2 = S2.to(device)

  X = U.mm(S1.mm(V1.t()))
  Xtst = U.mm(S2.mm(V2.t()))

  return X, Xtst, S1, S2

def cal_term_recon(c,thetatrn,thetatst):
  """Return the reconstruction (bias) term for one singular component.

  Computes ``thetatst**2 / (1 + thetatrn**2 * c_eff)**2``. Here ``c_eff`` is
  ``c`` when ``c <= 1``; for ``c > 1`` the factor in the denominator is
  replaced by 1.
  """
  c_eff = c if c <= 1 else 1
  return (thetatst**2)/(1+(thetatrn**2)*c_eff)**2

def cal_term(c, theta):
  """Return the weight-norm (variance) term for one singular component.

  Returns ``inf`` when ``c`` is within ``1e-6`` of 1, where the formula
  diverges. For ``c < 1`` it returns
  ``c**2 * (theta**2 + theta**4) / (1 + theta**2 * c)**2 / (1 - c)``.
  Otherwise it returns ``c/(c-1) * theta**2 / (1 + theta**2)``.
  """
  at_interpolation_threshold = np.abs(c-1) < 1e-6
  if at_interpolation_threshold:
    return np.inf
  if c < 1:
    numerator = theta**2+theta**4
    denominator = (1+(theta**2)*c)**2
    return (c**2)*(numerator/denominator)/(1-c)
  return (c/(c-1))*(theta**2)/(1+(theta**2))

def gen_error_pair(ntrn, ntst, m, c, thetatrn, thetatst):
  """Return the predicted test error contributed by one singular component.

  The error is the reconstruction term divided by ``ntst`` plus the
  weight-norm term divided by ``m``. ``ntrn`` is accepted for signature
  symmetry but is not used by the formula.
  """
  return cal_term_recon(c, thetatrn, thetatst)/ntst + cal_term(c,thetatrn)/m

def gen_error(ntrn, ntst, m, c, theta, psi, Strn, Stst):
  """Return the predicted test error summed over all singular components.

  Component ``k`` uses training singular value ``theta * Strn[k, k]`` and
  test singular value ``psi * Stst[k, k]``. The number of components is
  ``Strn.shape[0]``.
  """
  return sum(
      gen_error_pair(ntrn, ntst, m, c, theta*Strn[k,k], psi*Stst[k,k])
      for k in range(Strn.shape[0]))

def adjust_s(c,m,s):
  """Return the adjusted training singular values for the regime ``c < 1``.

  Evaluates ``sqrt(relu(s**2 * (1 - c/(2-c)) - c/(m*(2-c))))``. This is the
  same expression used by the ``c < 1`` branch of ``calc_opt``.

  Args:
    c: aspect ratio; only ``c < 1`` is handled.
    m: dimension appearing in the ``c/(m*(2-c))`` correction.
    s: tensor of singular values.

  Returns:
    A tensor shaped like ``s`` when ``c < 1``. Returns ``None`` for ``c >= 1``
    (that branch also needs the sample count; see ``calc_opt``).
  """
  if c < 1:
    # Use the parameter `m`; an uppercase `M` here would silently read a
    # notebook global (or raise NameError outside the notebook).
    return (s.square()*(1-c/(2-c))-c/(m*(2-c))).relu().sqrt()
  return None

def calc_opt(c,M,N,thetatst):
  """Return the training singular values computed from ``thetatst``.

  For ``c < 1`` it evaluates ``thetatst**2 * (1 - c/(2-c)) - c/(M*(2-c))``.
  Otherwise it evaluates ``2 * thetatst**2 * (c-1) - 1/N``. In both cases the
  result is clipped at zero and square-rooted.
  """
  if c < 1:
    squared = thetatst.square()*(1-c/(2-c)) - c/(M*(2-c))
  else:
    squared = 2*thetatst.square()*(c-1) - 1/N
  return squared.relu().sqrt()

def calc_opt_unnormalized(c, M, Ntrn, Ntst, thetatst):
  """Apply ``calc_opt`` to singular values expressed at data scale.

  ``thetatst`` is divided by ``sqrt(Ntst)`` before calling ``calc_opt``, and
  the result is multiplied by ``sqrt(Ntrn)``.
  """
  per_sample_thetatst = thetatst/np.sqrt(Ntst)
  return calc_opt(c, M, Ntrn, per_sample_thetatst)*np.sqrt(Ntrn)

def gen_noise_spherical(m,n, bi=False):
  """Return an ``m x n`` noise matrix whose columns have unit 2-norm.

  If ``bi`` is true, the matrix is also right-multiplied by the ``n x n``
  orthogonal factor ``Vh.T`` from the SVD of a random ``1 x n`` vector.
  """
  noise = torch.nn.functional.normalize(torch.randn(m,n), dim = 0)
  if not bi:
    return noise
  right_rotation = torch.linalg.svd(torch.randn(1,n)).Vh.t()
  return noise.mm(right_rotation)

def gen_noise_rademacher(m,n, bi=False):
  """Return a rotated ``+/-1`` (sign) noise matrix scaled by ``1/sqrt(m)``.

  A matrix of random signs is left-multiplied by the full ``m x m`` orthogonal
  ``U`` from the SVD of a random ``m x 1`` vector. If ``bi`` is true, it is
  also right-multiplied by an orthogonal ``n x n`` factor.
  """
  left_rotation = torch.linalg.svd(torch.randn(m,1)).U
  signs = torch.randn(m,n).sign()
  noise = left_rotation.mm(signs)/np.sqrt(m)
  if bi:
    noise = noise.mm(torch.linalg.svd(torch.randn(1,n)).Vh.t())
  return noise

def gen_noise_poisson(m,n, bi=False):
  """Return rotated, centred Poisson(1) noise scaled by ``1/sqrt(m)``.

  The entries are ``Poisson(1) - 1``, left-multiplied by the full orthogonal
  ``U`` from the SVD of a random ``m x 1`` vector. If ``bi`` is true, the
  result is also right-multiplied by an orthogonal ``n x n`` factor.
  """
  left_rotation = torch.linalg.svd(torch.randn(m,1)).U
  centred = torch.poisson(torch.ones(m,n)) - 1
  noise = left_rotation.mm(centred)/np.sqrt(m)
  if bi:
    noise = noise.mm(torch.linalg.svd(torch.randn(1,n)).Vh.t())
  return noise

def gen_noise_bernoulli(m,n, bi=False):
  """Return rotated, centred Bernoulli(1/2) noise scaled by ``1/sqrt(m/4)``.

  The entries are ``Bernoulli(1/2) - 0.5`` (i.e. ``+/-0.5``), left-multiplied
  by the full orthogonal ``U`` from the SVD of a random ``m x 1`` vector. If
  ``bi`` is true, the result is also right-multiplied by an orthogonal
  ``n x n`` factor.
  """
  left_rotation = torch.linalg.svd(torch.randn(m,1)).U
  centred = torch.bernoulli(torch.ones(m,n)/2) - 0.5
  noise = left_rotation.mm(centred)/np.sqrt(m/4)
  if bi:
    noise = noise.mm(torch.linalg.svd(torch.randn(1,n)).Vh.t())
  return noise


# Theoretical Curves for rank 1 data

## Changing $c$ by changing $N$

In [None]:
# Sweep c = M/Ntrn by varying the number of training samples N (rank-1 data),
# recording theory (var, bias, Err) and Monte-Carlo estimates (Err_emp*).
M = 1000
N = torch.arange(2000,100,-100)
Ntst = 1000
Err = torch.zeros(N.shape[0])
bias = torch.zeros(N.shape[0])
var = torch.zeros(N.shape[0])
Err_emp = torch.zeros(N.shape[0])
Err_emp_bias = torch.zeros(N.shape[0])
Err_emp_var = torch.zeros(N.shape[0])

# Test singular value. The same value is used for the empirical test matrix
# and for the theoretical bias / error, so the two are comparable.
thetatst = torch.diag(torch.tensor([1.0]))*np.sqrt(Ntst)
T = 50
for i in range(N.shape[0]):
  c = M/N[i]
  print(c, N[i])
  # Training singular value sqrt(Ntrn); an alternative choice is
  # torch.diag(calc_opt(c, M, N[i], torch.tensor([1.0]))).
  thetatrn = torch.diag(torch.tensor([1.0]))*N[i].sqrt()
  print(thetatrn)
  for k in range(T):
    X, Xtst, S1, S2 = gen_data_test(M, N[i], Ntst, 1, S1 = thetatrn, S2 = thetatst)
    A = torch.randn_like(X)/np.sqrt(M)
    Atst = torch.randn_like(Xtst)/np.sqrt(M)
    # Least-squares linear denoiser fitted on the noisy training data
    W = X.mm(torch.pinverse(X+A))
    Yp = W.mm(Xtst + Atst)
    Err_emp[i] += (Xtst - Yp).square().sum()/(T*Ntst)
    Err_emp_bias[i] += (Xtst - W.mm(Xtst)).square().sum()/(T*Ntst)
    Err_emp_var[i] += W.square().sum()/(T*Ntst)

  var[i] = cal_term(c, thetatrn[0,0])
  bias[i] = cal_term_recon(c, thetatrn[0,0], thetatst[0,0])
  Err[i] = gen_error(N[i],Ntst,M,c,1,1,thetatrn, thetatst)

In [None]:
# Empirical variance rescaled by Ntst. The interpolation point Ntrn == M
# (c = 1), where the theory diverges, is marked as infinite. The mask replaces
# a hard-coded index that only matched one particular N grid.
v = Err_emp_var*Ntst
v[N == M] = np.inf

In [None]:
# Theoretical variance versus Ntrn/M. The label gives plt.legend() an artist
# to show; without it the legend call is empty.
plt.plot(N/M, var, label="theory")
plt.xlabel("1/C = Ntrn/M")
plt.ylabel("Variance")
plt.legend()
plt.savefig("N-var")

In [None]:
# Empirical bias rescaled by Ntst, with the c = 1 point (Ntrn == M) masked.
b = Err_emp_bias*Ntst
b[N == M] = np.inf
# Theoretical bias versus Ntrn/M (labelled so the legend is not empty).
plt.plot(N/M, bias, label="theory")
plt.xlabel("1/C = Ntrn/M")
plt.ylabel("bias")
plt.legend()
plt.savefig("N-bias")

## Changing $c$ by changing $M$

In [None]:
# Sweep c = M/Ntrn by varying the dimension M at fixed N (rank-1 data).
N = 1000
M = torch.arange(10,2000,10)
Err = torch.zeros(M.shape[0])
var = torch.zeros(M.shape[0])
bias = torch.zeros(M.shape[0])
Err_emp = torch.zeros(M.shape[0])
Err_emp_bias = torch.zeros(M.shape[0])
Err_emp_var = torch.zeros(M.shape[0])
T = 5

# Test singular value, shared by the empirical test matrix and every
# theoretical quantity below.
thetatst = torch.diag(torch.tensor([1.0]))*np.sqrt(N)
for i in range(M.shape[0]):
  m_i = int(M[i])  # plain int so it can be used as a tensor dimension
  c = m_i/N
  print(c, N)
  # Alternative choice: torch.diag(calc_opt(c, m_i, N, torch.tensor([1.0])))
  thetatrn = torch.diag(torch.tensor([1.0]))*np.sqrt(N)
  print(thetatrn)
  for j in range(T):
    X, Xtst, S1, S2 = gen_data(m_i, N, 1, S1 = thetatrn, S2 = thetatst)
    A = torch.randn_like(X)/np.sqrt(m_i)
    Atst = torch.randn_like(Xtst)/np.sqrt(m_i)
    W = X.mm(torch.pinverse(X+A))
    Yp = W.mm(Xtst + Atst)
    Err_emp[i] += (Xtst - Yp).square().sum()/(T*N)
    Err_emp_bias[i] += (Xtst - W.mm(Xtst)).square().sum()/(T*N)
    Err_emp_var[i] += W.square().sum()/(T*N)

  var[i] = cal_term(c, thetatrn[0,0])
  bias[i] = cal_term_recon(c, thetatrn[0,0], thetatst[0,0])
  # Use the same thetatrn / thetatst as var and bias (not leftover loop values).
  Err[i] = gen_error(N,N,m_i,c,1,1,thetatrn,thetatst)

In [None]:
# Theoretical variance versus M/Ntrn (labelled so plt.legend() has an entry).
plt.plot(M/N, var, label="theory")
plt.xlabel("C = M/Ntrn")
plt.ylabel("Variance")
plt.legend()
plt.savefig("M-var")

In [None]:
# Theoretical bias versus M/Ntrn (labelled so plt.legend() has an entry).
plt.plot(M/N, bias, label="theory")
plt.xlabel("C = M/Ntrn")
plt.ylabel("bias")
plt.legend()
plt.savefig("M-bias")

# Approximation Error for the Formula

## Low SNR

In [None]:
#Generate Strn and Stst
# Compare the closed-form test error (gen_error) with Monte-Carlo estimates
# for rank-r data whose singular values are squared standard normals (low SNR),
# over ranks R and aspect ratios c = m/n.
C = 10

theory_error = torch.zeros(11,2*C).to('cuda')
emperical_error = torch.zeros(11,2*C).to('cuda')

# theory_norm / theory_recon split theory_error into its two parts; the
# emperical_* counterparts are allocated but not filled in by this cell.
theory_norm = torch.zeros(11,2*C).to('cuda')
emperical_norm = torch.zeros(11,2*C).to('cuda')

theory_recon = torch.zeros(11,2*C).to('cuda')
emperical_recon = torch.zeros(11,2*C).to('cuda')

T = 10

R = [1,2,3,5,10,20,50,100,150,200,250]

for r_idx, r in enumerate(R):

    for i in range(2*C):
        Strn = torch.diag(torch.randn(r).square()).to('cuda')
        Stst = torch.diag(torch.randn(r).square()).to('cuda')
        m = 2500
        n = np.maximum(r,int(((i+1)/C)*m))
        c = m/n

        print("n = ",n,", m = ",m,", c = ",c,", r = ",r)

        theory_error[r_idx, i] = gen_error(n,n,m,c,1,1,Strn,Stst)

        # Per-component split using the same /n and /m factors as
        # gen_error_pair, so theory_recon + theory_norm sums to theory_error.
        for k in range(r):
            theory_recon[r_idx, i] += cal_term_recon(c, Strn[k,k], Stst[k,k])/n
            theory_norm[r_idx, i] += cal_term(c, Strn[k,k])/m

        # T data draws, each broadcast to a batch of T independent noise draws.
        for t in range(T):
            X, Xtst, S1, S2 = gen_data(m,n,r,S1 = Strn, S2 = Stst, device = 'cuda')

            X = torch.ones(T,1,1, device = X.device)*X
            Xtst = torch.ones(T,1,1, device = Xtst.device)*Xtst
            A = torch.randn_like(X)/np.sqrt(m)
            Atst = torch.randn_like(Xtst)/np.sqrt(m)

            Y = X+A
            Ytst = Xtst + Atst

            W = torch.matmul(X,Y.pinverse())
            emperical_error[r_idx, i] += (Xtst - W.bmm(Ytst)).square().sum()/(T*T*n)

        # Relative deviation of the formula from the empirical estimate
        print((theory_error[r_idx, i]-emperical_error[r_idx, i])/emperical_error[r_idx, i])

## High SNR

In [None]:
#Generate Strn and Stst
# Same comparison as the low-SNR cell, but the singular values are scaled
# by sqrt(n) (high SNR). Results go to the *2 tensors.
C = 10

theory_error2 = torch.zeros(11,2*C).to('cuda')
emperical_error2 = torch.zeros(11,2*C).to('cuda')

# theory_norm2 / theory_recon2 split theory_error2 into its two parts; the
# emperical_* counterparts are allocated but not filled in by this cell.
theory_norm2 = torch.zeros(11,2*C).to('cuda')
emperical_norm2 = torch.zeros(11,2*C).to('cuda')

theory_recon2 = torch.zeros(11,2*C).to('cuda')
emperical_recon2 = torch.zeros(11,2*C).to('cuda')

T = 10

R = [1,2,3,5,10,20,50,100,150,200,250]

for r_idx, r in enumerate(R):

    for i in range(2*C):
        m = 2500
        n = np.maximum(r,int(((i+1)/C)*m))
        c = m/n
        Strn = (torch.diag(torch.randn(r)).square()*np.sqrt(n)).to('cuda')
        Stst = (torch.diag(torch.randn(r)).square()*np.sqrt(n)).to('cuda')

        print("n = ",n,", m = ",m,", c = ",c,", r = ",r)

        theory_error2[r_idx, i] = gen_error(n,n,m,c,1,1,Strn,Stst)

        # Per-component split using the same /n and /m factors as
        # gen_error_pair, so theory_recon2 + theory_norm2 sums to theory_error2.
        for k in range(r):
            theory_recon2[r_idx, i] += cal_term_recon(c, Strn[k,k], Stst[k,k])/n
            theory_norm2[r_idx, i] += cal_term(c, Strn[k,k])/m

        # T data draws, each broadcast to a batch of T independent noise draws.
        for t in range(T):
            X, Xtst, S1, S2 = gen_data(m,n,r,S1 = Strn, S2 = Stst, device = 'cuda')

            X = torch.ones(T,1,1, device = X.device)*X
            Xtst = torch.ones(T,1,1, device = Xtst.device)*Xtst
            A = torch.randn_like(X)/np.sqrt(m)
            Atst = torch.randn_like(Xtst)/np.sqrt(m)

            Y = X+A
            Ytst = Xtst + Atst

            W = torch.matmul(X,Y.pinverse())
            emperical_error2[r_idx, i] += (Xtst - W.bmm(Ytst)).square().sum()/(T*T*n)

        # Relative deviation of the formula from the empirical estimate
        print((theory_error2[r_idx, i]-emperical_error2[r_idx, i])/emperical_error2[r_idx, i])

# Introduction Double Descent Figure

## Deep Network Denoising  

### Linear Rank 1 and Non Linear Synthetic

In [None]:
# Synthetic data for the introduction figure. The signal is a rank-r latent
# U @ V.T in dimension 100 (250000 train / 1000 test samples). It is optionally
# passed through a fixed random two-layer ReLU map. Rows of X / Xtst are samples.
device = 'cuda'
nonlinear = True  # False: linear rank-r data; True: non-linear synthetic data

W0 = torch.randn(100,100)
W1 = torch.randn(100,100)

r = 1 # change rank for non - linear
U = torch.svd(torch.randn(100,100)).U[:,:r]
Vtrn = torch.svd(torch.randn(250000,r)).U[:,:r]
Vtst = torch.svd(torch.randn(1000,r)).U[:,:r]

X = U.mm(Vtrn.t()).t()       # (250000, 100)
Xtst = U.mm(Vtst.t()).t()    # (1000, 100)

if nonlinear:
    # Apply the map to the (100, N) column layout, then transpose back to
    # (N, 100) so downstream code sees M = 100 features per row.
    X = W1.mm(W0.mm(X.t()).relu()).relu().t()
    Xtst = W1.mm(W0.mm(Xtst.t()).relu()).relu().t()

Xtst = Xtst.to(device)

In [None]:
def do_sim(theta_train, theta_test, Xtrn, Xtst, bias = False, steps = 1500, lr = 1e-3):
    """Train a 3-layer ReLU denoiser on one noisy draw and return its test error.

    ``Xtrn`` (``Ntrn x M``) and ``Xtst`` (``Ntst x M``) hold one sample per
    row. The model is trained with full-batch Adam to map
    ``theta_train*Xtrn + noise`` to ``theta_train*Xtrn``. The noise entries are
    ``randn / sqrt(M)``. The model is then evaluated on ``theta_test*Xtst + noise``.

    Args:
        theta_train, theta_test: signal scalings for training and test.
        Xtrn, Xtst: clean data matrices. All work runs on ``Xtrn.device``.
        bias: whether the ``nn.Linear`` layers have bias terms.
        steps: number of full-batch Adam steps.
        lr: Adam learning rate.

    Returns:
        ``(error, w_norms)``: the test MSE against ``theta_test*Xtst`` and the
        sum of squared model parameters, both as 0-d tensors.
    """
    # Derive device and sizes from the inputs rather than notebook globals
    # (`device`, `Ntst`), so the function works for any Xtst.
    dev = Xtrn.device
    Xtst = Xtst.to(dev)
    M = Xtrn.shape[1]
    Ntrn = Xtrn.shape[0]
    n_test = Xtst.shape[0]

    X = theta_train*Xtrn + torch.randn(Ntrn,M, device = dev)/np.sqrt(M)
    Y = theta_train*Xtrn

    X_tst = theta_test*Xtst + torch.randn(n_test,M, device = dev)/np.sqrt(M)

    model = torch.nn.Sequential(nn.Linear(M,M, bias = bias),
                            nn.ReLU(),
                            nn.Linear(M,M, bias = bias),
                            nn.ReLU(),
                            nn.Linear(M,M, bias = bias))

    model = model.to(dev)

    optimizer = torch.optim.Adam(model.parameters(), lr = lr)

    for _ in range(steps):
        optimizer.zero_grad()
        Y_pred = model(X)
        loss = nn.functional.mse_loss(Y_pred,Y)
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        Y_tst = model(X_tst)
        error = nn.functional.mse_loss(Y_tst,theta_test*Xtst)
        w_norms = 0
        for param in model.parameters():
            w_norms += param.square().sum()

    # Free GPU memory eagerly; this is called many times in a loop.
    del model
    del X,Y,Y_tst,X_tst
    return error, w_norms

In [None]:
Ndata = torch.tensor([100,125,150,175,200,250,300,500,700,900,
         1000,1500,2000,2500,3000,4000,5000, 6000, 7000, 8000, 9000, 10000,15000
         ,20000,25000,30000,40000,50000,60000,70000,80000,90000,100000])

psi_norm = 0.1
# Training (= test) signal strengths as multiples of psi_norm.
theta_scales = [0.5, 1, 2]
M = 100  # data dimension (X has 100 columns)

# Normalise the test set once to ||Xtst||_F = sqrt(Ntst), on the training device.
Xtst = Xtst.to(device)
Ntst = Xtst.shape[0]
Xtst = Xtst * (np.sqrt(Ntst)/Xtst.norm())

T = 10
avg_error = torch.zeros((len(Ndata),len(theta_scales),T), device = device)
theta = torch.zeros((len(Ndata),len(theta_scales)), device = device)
w_norm = torch.zeros((len(Ndata),len(theta_scales),T), device = device)

for k in range(len(Ndata)):
    # Plain int: with a tensor, '%s' below would write "tensor(100)" into the filename.
    Ntrn = int(Ndata[k])
    print(Ntrn)

    # Out-of-place normalisation (||Xtrn||_F = sqrt(Ntrn)) so X is never modified.
    Xtrn = X[:Ntrn,:].to(device)
    Xtrn = Xtrn * (np.sqrt(Ntrn)/Xtrn.norm())

    print(Xtrn.shape, Xtst.shape)

    for i, scale in enumerate(theta_scales):
        theta[k,i] = scale*psi_norm

        for j in tqdm(range(T)):
            a, w = do_sim(theta[k,i], theta[k,i], Xtrn, Xtst, bias = False)
            avg_error[k,i,j] = a
            w_norm[k,i,j] = w

    # Checkpoint the full result tensors after every training-set size.
    number = '%s.pt'%Ntrn
    torch.save(avg_error,"rank50nl-3l-error-without-bias-"+number)
    torch.save(theta,"rank50nl-3l-theta-without-bias-"+number)
    torch.save(w_norm,"rank50nl-3l-wnorm-without-bias-"+number)

### For MNIST

In [None]:
mnist_train = torchvision.datasets.MNIST('./MNIST_data', train=True, download=True,
                           transform=Tranforms.ToTensor())
mnist_test = torchvision.datasets.MNIST('./MNIST_data', train=False, download=True,
                           transform=Tranforms.ToTensor())

Ndata = [30,40,50,100,200,300,500,750,1000,1200,1400,1600,1800,2000,2500,3000,4000,5000,7500,10000,15000,20000,30000,40000,50000]

# Load the data once. Per-Ntrn training slices are cloned below, so the
# in-place normalisation never modifies X.
X = mnist_train.data.to(device).to(torch.float32)/256
print(X.max())

Xtst = mnist_test.data.to(device).to(torch.float32)/256
Ntst = Xtst.shape[0]
Xtst = Xtst.reshape(Ntst,784)
Xtst *= np.sqrt(Ntst)/Xtst.norm()  # ||Xtst||_F = sqrt(Ntst)

psi_norm = 0.1
T = 50

for Ntrn in Ndata:
    print(Ntrn)

    Xtrn = X[:Ntrn].reshape(Ntrn,784).clone()

    print(Xtrn.shape, Xtst.shape)

    avg_error = torch.zeros((2,T), device = device)
    theta = torch.zeros(2, device = device)
    w_norm = torch.zeros((2,T), device = device)

    Xtrn *= np.sqrt(Ntrn)/Xtrn.norm()  # ||Xtrn||_F = sqrt(Ntrn)

    # Low (0.5x) and high (2x) training/test signal strength relative to psi_norm.
    for i, scale in enumerate([0.5, 2]):
        theta[i] = scale*psi_norm

        for j in range(T):
            a, w = do_sim(theta[i], theta[i], Xtrn, Xtst, bias = False)
            avg_error[i,j] = a
            w_norm[i,j] = w

    number = '%s.pt'%Ntrn
    torch.save(avg_error,"moretheta-3l-error-without-bias-finer-"+number)
    torch.save(theta,"moretheta-3l-theta-without-bias-finer-"+number)
    torch.save(w_norm,"moretheta-3l-wnorm-without-bias-finer-"+number)

## Linear Network

In [None]:
import torch
import torchvision
import torchvision.transforms as Tranforms

# Download (if needed) and load MNIST train/test splits as tensors.
mnist_train = torchvision.datasets.MNIST(root='./MNIST_data', train=True,
                                         transform=Tranforms.ToTensor(), download=True)
mnist_test = torchvision.datasets.MNIST(root='./MNIST_data', train=False,
                                        transform=Tranforms.ToTensor(), download=True)

In [None]:
# MNIST training images flattened to rows of 784 pixels, scaled by 1/256.
X = mnist_train.data.to(torch.float32).reshape(-1,784)/256

In [None]:
# MNIST test images as columns (784 x Ntst), scaled by 1/256, on the GPU.
flat_test = mnist_test.data.to(torch.float32).reshape(-1,784)/256
Xtst = flat_test.t().to('cuda')
del flat_test

In [None]:
# Alternative to the MNIST X / Xtst: a rank-r latent signal in dimension 784
# passed through a fixed random two-layer ReLU map. X is (50000, 784) with
# samples as rows; Xtst is (784, 10000) with samples as columns, on the GPU.
W0 = torch.randn(784,784)
W1 = torch.randn(784,784)

r = 5
U = torch.svd(torch.randn(784,784)).U[:,:r]
Vtrn = torch.svd(torch.randn(50000,r)).U[:,:r]
Vtst = torch.svd(torch.randn(10000,r)).U[:,:r]

def _relu_map(Z):
    # Column-wise two-layer ReLU map using the fixed random weights W0, W1.
    hidden = W0.mm(Z).relu()
    return W1.mm(hidden).relu()

X = _relu_map(U.mm(Vtrn.t())).t()
Xtst = _relu_map(U.mm(Vtst.t())).to('cuda')

In [None]:
# Number of test samples: Xtst is (784, Ntst) with one sample per column.
Ntst = Xtst.shape[1]

In [None]:
from tqdm import tqdm

# Empirical weight norm and test error of the least-squares linear denoiser
# W = Xtrn @ pinv(Xtrn + A) as Ntrn varies, for three signal strengths.
m = 784

Ns = torch.tensor([100,200,300,400,500,600,700,725,750,775,800,850,900,1000,1500,2000,2500,3000,5000,7500,10000])
theta = [0.5,1,2]

emperical_norm  = torch.zeros(3,Ns.shape[0]).to('cuda')
emperical_error = torch.zeros(3,Ns.shape[0]).to('cuda')

print(emperical_error)

T = 100
# Unit-Frobenius-norm test signal (samples are columns). It is rescaled per
# theta below instead of repeatedly rescaling the global Xtst in place.
Xtst_unit = Xtst/Xtst.norm()

for i in range(Ns.shape[0]):
  Ntrn = int(Ns[i])

  Xtrn_unit = X[:Ntrn,:].t().to('cuda')
  Xtrn_unit = Xtrn_unit/Xtrn_unit.norm()

  for j in range(3):
    # Signal strength theta[j]: ||Xtrn||_F = theta*sqrt(Ntrn), ||Xs||_F = theta*sqrt(Ntst)
    Xtrn = theta[j]*np.sqrt(Ntrn)*Xtrn_unit
    Xs = theta[j]*np.sqrt(Ntst)*Xtst_unit

    print(Ntrn)
    print(Xtrn.shape)
    for k in tqdm(range(T)):
      A = torch.randn_like(Xtrn)/np.sqrt(m)
      Atst = torch.randn_like(Xs)/np.sqrt(m)

      Y = Xtrn + A
      Ytst = Xs + Atst

      W = torch.mm(Xtrn,Y.pinverse())

      emperical_norm[j,i] += (W.norm().square())/T
      emperical_error[j,i] += (Xs - W.mm(Ytst)).norm().square()/(T*Ntst)

# More Indepth Figure For MNIST and CIFAR10

The code below shows the CIFAR-10 example; to reproduce the MNIST version, replace the CIFAR-10 dataset with MNIST.

In [None]:
# Download (if needed) and load the CIFAR-10 train/test splits.
cifar_train = torchvision.datasets.CIFAR10(root='./CIFAR_data', train=True,
                                           transform=Tranforms.ToTensor(), download=True)
cifar_test = torchvision.datasets.CIFAR10(root='./CIFAR_data', train=False,
                                          transform=Tranforms.ToTensor(), download=True)

In [None]:
import os

Ndata = [10,20,30,40,50,100,200,300,500,750,1000,1250,1500,1750,2000,2500,3000,4000,5000,7500,10000,12500,15000]

# Load CIFAR-10 once, flattened to one row per image. Training slices are
# cloned per Ntrn below, so the in-place normalisation never modifies X.
X = torch.tensor(cifar_train.data).to(device).to(torch.float32)
print(X.max())

Xtst = torch.tensor(cifar_test.data).to(device).to(torch.float32)
Ntst = Xtst.shape[0]
Xtst = Xtst.reshape(Ntst,-1)
Xtst *= np.sqrt(Ntst)/Xtst.norm()  # ||Xtst||_F = sqrt(Ntst)

psi_norm = 0.1
T = 5
N = 20  # number of training-noise levels swept per Ntrn

# torch.save does not create directories. Note that the name says "Cifar100"
# although the data is CIFAR-10; it is kept so existing result paths still match.
os.makedirs("Cifar100-2layer-relu", exist_ok=True)

for Ntrn in Ndata:
    Xtrn = X[:Ntrn].reshape(Ntrn,-1).clone()

    print(Xtrn.shape, Xtst.shape)

    avg_error = torch.zeros((N,T), device = device)
    theta = torch.zeros(N, device = device)
    w_norm = torch.zeros((N,T), device = device)

    Xtrn *= np.sqrt(Ntrn)/Xtrn.norm()  # ||Xtrn||_F = sqrt(Ntrn)

    # Sweep the training signal strength from psi_norm/N up to psi_norm,
    # keeping the test signal strength fixed at psi_norm.
    for i in tqdm(range(N)):
        theta[i] = psi_norm*((i+1)/N)

        for j in range(T):
            a, w = do_sim(theta[i], psi_norm, Xtrn, Xtst, bias = False)
            avg_error[i,j] = a
            w_norm[i, j] = w

    number = '%s.pt'%Ntrn
    torch.save(avg_error,"Cifar100-2layer-relu/error-without-bias-"+number)
    torch.save(theta,"Cifar100-2layer-relu/theta-without-bias-"+number)
    torch.save(w_norm,"Cifar100-2layer-relu/wnorm-without-bias-"+number)