<a href="https://colab.research.google.com/github/yspennstate/RMT_pruning_2/blob/main/rmt_pruning_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install tracywidom

Collecting tracywidom
  Downloading TracyWidom-0.3.0.tar.gz (12 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: tracywidom
  Building wheel for tracywidom (setup.py) ... [?25l[?25hdone
  Created wheel for tracywidom: filename=TracyWidom-0.3.0-py3-none-any.whl size=11726 sha256=7b5c759477451adfdf05f9b807f3254206d5f20f6a546c52a557643845bf80e7
  Stored in directory: /root/.cache/pip/wheels/9b/d7/24/27c1c32a4f1030307028bfedb8bb7de2968c4ea64103705110
Successfully built tracywidom
Installing collected packages: tracywidom
Successfully installed tracywidom-0.3.0


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import random
import copy
import math
import scipy
from scipy.integrate import quad
from TracyWidom import TracyWidom

The next block contains all the code necessary to create the neural network

In [None]:
class networkModel(nn.Module):
  """Fully connected classifier built from a list of layer widths.

  ``dims`` lists the width of every layer (input width first); one
  ``nn.Linear`` is created for each consecutive pair.  A truthy
  ``without_rel[j]`` means that linear layer ``j`` is NOT followed by a ReLU.
  """

  def __init__(self, without_rel, dims):
    super(networkModel, self).__init__()
    self.without_rel = without_rel
    self.fc = nn.ModuleList()
    self.dims = dims
    fan_in = dims[0]
    for fan_out in dims[1:]:
      self.fc.append(nn.Linear(fan_in, fan_out, bias=True))
      fan_in = fan_out

  def forward(self, x):
    """Return log-probabilities computed over dim 1 of the flattened output."""
    # Everything from dim 2 onwards is merged into one feature axis.
    x = torch.flatten(x, start_dim=2)
    for j, layer in enumerate(self.fc):
      x = layer(x)
      if not self.without_rel[j]:
        x = F.relu(x)
    x = torch.flatten(x, start_dim=1)
    return F.log_softmax(x, dim=1)

  def getLayers(self):
    """Number of linear layers."""
    return len(self.fc)

  def getLayerMatrix(self, index):
    """Weight matrix of linear layer ``index`` as a NumPy array."""
    weight = torch.as_tensor(self.fc[index].weight)
    return weight.cpu().detach().numpy()

  def getWithout(self):
    """The ``without_rel`` flags given at construction."""
    return self.without_rel

  def getDims(self):
    """The ``dims`` list given at construction."""
    return self.dims

  def getParameterCount(self):
    """Total number of scalar parameters (weights and biases)."""
    return sum(param.numel() for param in self.parameters())



class NetworkModelModified(nn.Module):
    """Fully connected network whose linear layers are ``SplittableLinear``.

    Args:
        dims: layer widths, input width first.
        relu_layers: ``relu_layers[i]`` truthy means a ReLU follows layer ``i``.
        alpha, beta, goodnessOfFitCutoff: forwarded unchanged to every
            ``SplittableLinear``.
    """

    def __init__(self, dims, relu_layers, alpha, beta, goodnessOfFitCutoff):
        super(NetworkModelModified, self).__init__()
        # The accessors below read these attributes; previously they were never
        # set, so getDims()/getWithout() raised AttributeError.
        self.dims = dims
        self.relu_layers = relu_layers
        # Same convention as networkModel: True means "no ReLU after this layer".
        self.without_rel = [not relu_layers[i] for i in range(len(dims) - 1)]

        layers = []
        for i in range(len(dims) - 1):
            layers.append(SplittableLinear(dims[i], dims[i + 1], alpha, beta, goodnessOfFitCutoff, name=f'layer {i+1}'))
            if relu_layers[i]:
                layers.append(nn.ReLU())
        self.model = nn.Sequential(*layers)
        self._initialize_weights()

    def forward(self, x):
        # NOTE(review): unlike networkModel.forward, no log_softmax is applied
        # here; confirm that callers pair this model with a matching loss.
        x = torch.flatten(x, start_dim=1)
        return self.model(x)

    def _initialize_weights(self):
        """Draw every ``nn.Linear`` weight from N(0, 1/fan_in) and zero its bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=1/np.sqrt(m.weight.size(1)))
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _splittable_layers(self):
        """Return the ``SplittableLinear`` modules of ``self.model`` in order."""
        return [m for m in self.model if isinstance(m, SplittableLinear)]

    def getLayers(self):
        """Number of (splittable) linear layers."""
        return len(self._splittable_layers())

    def getLayerMatrix(self, index):
        """Effective weight matrix of linear layer ``index`` as a NumPy array.

        Delegates to ``SplittableLinear.get_matrix`` so a layer that has been
        split is reported as the product of its two factors.
        """
        return self._splittable_layers()[index].get_matrix()

    def getWithout(self):
        """Per-layer flags, True where NO ReLU follows the layer."""
        return self.without_rel

    def getDims(self):
        """The ``dims`` list given at construction."""
        return self.dims

    def getParameterCount(self):
        """Total number of scalar parameters currently registered."""
        return sum(param.numel() for param in self.parameters())


import torch
import torch.nn.functional as F
import torch.optim as optim


def train(args, model, device, train_loader, optimizer, epoch, showTrainingLoss=True,
          l1_lambda=0.000005, l2_lambda=0.000005):
    """Run one training epoch and return the training accuracy in percent.

    Args:
        args: dict; only ``args['log_interval']`` is read (batches between
            progress prints).
        model: module whose output is passed to ``F.nll_loss``.
        device: device the model and every batch are moved to.
        train_loader: iterable of ``(data, target)`` batches exposing
            ``.dataset`` and ``len()``.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: epoch number, used only in the progress message.
        showTrainingLoss: print per-batch progress when True.
        l1_lambda: weight of the L1 penalty over all parameters.
        l2_lambda: weight of the squared-L2 penalty over all parameters.

    Returns:
        ``100 * correct / len(train_loader.dataset)``; predictions are taken
        from the outputs computed during training (before each update).
    """
    model.train()
    model.to(device)

    correct = 0
    total_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        output = model(data)
        loss = F.nll_loss(output, target)

        # A zero coefficient contributes nothing, so skip the reduction then.
        if l1_lambda:
            loss = loss + l1_lambda * sum(p.abs().sum() for p in model.parameters())
        if l2_lambda:
            loss = loss + l2_lambda * sum((p ** 2).sum() for p in model.parameters())

        # The reported loss includes the regularisation terms.
        total_loss += loss.item()

        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

        loss.backward()
        optimizer.step()

        if (batch_idx + 1) % args['log_interval'] == 0 and showTrainingLoss:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

    avg_loss = total_loss / len(train_loader)
    accuracy = 100. * correct / len(train_loader.dataset)

    print(f'Train set: Average loss: {avg_loss:.4f}, Accuracy: {correct}/{len(train_loader.dataset)} ({accuracy:.0f}%)')
    return accuracy


def test(args, m, device, test_loader):
  """Evaluate ``m`` on ``test_loader`` and return the accuracy in percent.

  Args:
    args: unused; kept for signature parity with ``train``.
    m: module whose output is passed to ``F.nll_loss``.
    device: device every batch is moved to (the model is not moved here).
    test_loader: iterable of ``(data, target)`` batches exposing ``.dataset``.

  Returns:
    ``100 * correct / len(test_loader.dataset)``.
  """
  m.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = m(data)
      # Summed (not averaged) per batch so the division below is per sample.
      test_loss += F.nll_loss(output, target, reduction='sum').item()
      pred = output.argmax(dim=1, keepdim=True)
      correct += pred.eq(target.view_as(pred)).sum().item()
  # The loss was summed over samples, so average over samples, not batches.
  test_loss /= len(test_loader.dataset)
  print(f'test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100.*correct/len(test_loader.dataset):.0f}%)')
  return (100 * correct / len(test_loader.dataset))


# The name matches pytest's default ``test*`` pattern; opt out of collection so
# importing this helper into a test module does not turn it into a test case.
test.__test__ = False



import torch
from torchvision import datasets, transforms
import torch.optim as optim

# Preprocessing shared by the training and test datasets built in neuralInit:
# convert each image to a tensor, then normalise its single channel with
# mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

def neuralInit(seed, device, model, lr, momentum, batchSize):
    """Seed torch, build the FashionMNIST loaders and an SGD optimizer.

    The ``device`` argument is ignored: the device is re-derived from CUDA
    availability and the model is moved there.

    Returns:
        ``(args, model, optimizer, test_loader, train_loader)`` where ``args``
        is the settings dict consumed by ``train``/``test``.
    """
    args = dict(
        batch_size=batchSize,
        test_batch_size=batchSize,
        lr=lr,
        momentum=momentum,
        no_cuda=False,
        seed=seed,
        log_interval=50,
    )
    use_cuda = torch.cuda.is_available() and not args['no_cuda']
    torch.manual_seed(seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    model.to(device)

    # Worker / pinned-memory options are only passed when running on CUDA.
    loader_options = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    # The training split is downloaded on demand.
    train_set = datasets.FashionMNIST('../data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args['batch_size'], shuffle=True, **loader_options)

    test_set = datasets.FashionMNIST('../data', train=False, transform=transform)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=args['test_batch_size'], shuffle=True, **loader_options)

    optimizer = optim.SGD(model.parameters(), lr=args['lr'], momentum=args['momentum'])

    return args, model, optimizer, test_loader, train_loader


In [None]:
class CNNModel(nn.Module):
    """Stack of 3x3 convolutions followed by fully connected layers.

    Args:
        conv_weights: channel counts, input channels first; one ``nn.Conv2d``
            (kernel size 3) plus ``BatchNorm2d`` per consecutive pair.
        fc_weights: widths of the fully connected part, input width first; one
            ``nn.Linear`` plus ``BatchNorm1d`` per consecutive pair.
        without_rel: ``[conv_flags, fc_flags]``; a truthy flag applies ReLU
            after the corresponding layer (see ``forward``).
    """

    def __init__(self, conv_weights, fc_weights, without_rel):
        super(CNNModel, self).__init__()

        self.dims_conv = conv_weights
        self.dims_fc = fc_weights

        self.without_rel = without_rel

        self.fc = nn.ModuleList()
        self.conv = nn.ModuleList()
        self.bn_conv = nn.ModuleList()
        self.bn_fc = nn.ModuleList()

        for k in range(1, len(conv_weights)):
            self.conv.append(nn.Conv2d(conv_weights[k-1], conv_weights[k], 3))
            self.bn_conv.append(nn.BatchNorm2d(conv_weights[k]))

        for k in range(1, len(fc_weights)):
            self.fc.append(nn.Linear(fc_weights[k-1], fc_weights[k]))
            self.bn_fc.append(nn.BatchNorm1d(fc_weights[k]))

        # Index 0 -> conv batch norms, index 1 -> fc batch norms.
        self.bn = [self.bn_conv, self.bn_fc]

    def forward(self, x):
        """Return per-sample log-probabilities over the classes (dim 1)."""
        for k in range(len(self.conv)):
            x = self.conv[k](x)
            if self.without_rel[0][k]:
                x = F.relu(x)
            x = self.bn[0][k](x)
            # Pool after every second conv layer except the first.
            if k % 2 == 0 and k != 0:
                x = F.max_pool2d(x, 2)
            # F.dropout defaults to training=True; tie it to the module mode so
            # evaluation is deterministic.
            x = F.dropout(x, 0.35, training=self.training)

        x = torch.flatten(x, start_dim=1)

        for k in range(len(self.fc)):
            x = self.fc[k](x)
            # Batch norm / ReLU / dropout are applied after the first fc layer only.
            if k == 0:
                x = self.bn[1][k](x)
                if self.without_rel[1][k]:
                    x = F.relu(x)
                x = F.dropout(x, training=self.training)
        # Normalise over the class axis; dim=0 would normalise across the batch.
        return F.log_softmax(x, dim=1)

    def getLayers(self):
        """Number of convolutional layers."""
        return len(self.conv)

    def getLayerMatrixOriginal(self, index):
        """4-D kernel of conv layer ``index`` as a NumPy array."""
        return torch.as_tensor(self.conv[index].weight).cpu().detach().numpy()

    def getLayerMatrixConv(self, index):
        """Kernel of conv layer ``index`` reshaped to (out_channels, in*kh*kw)."""
        m = self.conv[index].weight.shape
        flat = self.conv[index].weight.reshape(m[0], m[1]*m[2]*m[3])
        return torch.as_tensor(flat).cpu().detach().numpy()

    def getLayerMatrixFC(self, index):
        """Weight matrix of fully connected layer ``index`` as a NumPy array."""
        return torch.as_tensor(self.fc[index].weight).cpu().detach().numpy()

    def getWithout(self):
        return self.without_rel

    def getDims(self):
        return [self.dims_conv, self.dims_fc]

    def getDimsFC(self):
        return self.dims_fc

    def getDimsConv(self):
        return self.dims_conv

    def getParameterCount(self):
        """Total number of scalar parameters."""
        return sum(param.numel() for param in self.parameters())

class CNNModelModified(nn.Module):
    """Variant of the CNN whose layers are ``SplittableConv``/``SplittableLinear``.

    Args:
        conv_weights: channel counts, input channels first.
        fc_weights: widths of the fully connected part, input width first.
        without_rel: ``[conv_flags, fc_flags]``; a truthy flag applies ReLU
            after the corresponding layer (see ``forward``).
        alpha, beta: forwarded unchanged to every splittable layer.
        goodnessOfFitCutoff: pair; element 0 is passed to the conv layers and
            element 1 to the fully connected layers.
    """

    def __init__(self, conv_weights, fc_weights, without_rel, alpha, beta, goodnessOfFitCutoff):
        super(CNNModelModified, self).__init__()

        self.dims_conv = conv_weights
        self.dims_fc = fc_weights

        self.without_rel = without_rel

        self.fc = nn.ModuleList()
        self.conv = nn.ModuleList()
        self.bn_conv = nn.ModuleList()
        self.bn_fc = nn.ModuleList()

        for k in range(1, len(conv_weights)):
            X = SplittableConv(
                conv_weights[k-1],
                conv_weights[k],
                3,
                alpha=alpha,
                beta=beta,
                goodnessOfFitCutoff=goodnessOfFitCutoff[0],
                name=f'conv layer {k}'
                )

            self.conv.append(X)
            self.bn_conv.append(nn.BatchNorm2d(conv_weights[k]))

        for k in range(1, len(fc_weights)):
            X = SplittableLinear(
                fc_weights[k-1],
                fc_weights[k],
                alpha=alpha,
                beta=beta,
                goodnessOfFitCutoff=goodnessOfFitCutoff[1],
                name=f'fc layer {k}',
                )
            self.fc.append(X)
            self.bn_fc.append(nn.BatchNorm1d(fc_weights[k]))

        # Index 0 -> conv batch norms, index 1 -> fc batch norms.
        self.bn = [self.bn_conv, self.bn_fc]

    def forward(self, x):
        """Return per-sample log-probabilities over the classes (dim 1)."""
        for k in range(len(self.conv)):
            x = self.conv[k](x)
            if self.without_rel[0][k]:
                x = F.relu(x)
            x = self.bn[0][k](x)
            # Pool after every second conv layer except the first.
            if k % 2 == 0 and k != 0:
                x = F.max_pool2d(x, 2)
            # F.dropout defaults to training=True; tie it to the module mode so
            # evaluation is deterministic.
            x = F.dropout(x, 0.35, training=self.training)

        x = torch.flatten(x, start_dim=1)

        for k in range(len(self.fc)):
            x = self.fc[k](x)
            # Batch norm / ReLU / dropout are applied after the first fc layer only.
            if k == 0:
                x = self.bn[1][k](x)
                if self.without_rel[1][k]:
                    x = F.relu(x)
                x = F.dropout(x, training=self.training)
        # Normalise over the class axis; dim=0 would normalise across the batch.
        return F.log_softmax(x, dim=1)

    def getLayers(self):
        """Number of convolutional layers."""
        return len(self.conv)

    # The splittable wrappers have no ``weight`` attribute of their own (the
    # weights live on their layer1/layer2 sub-modules), so the accessors below
    # go through ``get_matrix()``, which also covers layers that were split.

    def getLayerMatrixOriginal(self, index):
        """Effective kernel of conv layer ``index`` shaped (out, in, k, k)."""
        layer = self.conv[index]
        matrix = layer.get_matrix()
        return matrix.reshape(matrix.shape[0], -1, layer.kernel_size, layer.kernel_size)

    def getLayerMatrixConv(self, index):
        """Effective kernel of conv layer ``index`` shaped (out, in*k*k)."""
        return self.conv[index].get_matrix()

    def getLayerMatrixFC(self, index):
        """Effective weight matrix of fully connected layer ``index``."""
        return self.fc[index].get_matrix()

    def getWithout(self):
        return self.without_rel

    def getDims(self):
        return [self.dims_conv, self.dims_fc]

    def getDimsFC(self):
        return self.dims_fc

    def getDimsConv(self):
        return self.dims_conv

    def getParameterCount(self):
        """Total number of scalar parameters currently registered."""
        return sum(param.numel() for param in self.parameters())

The next block contains all the code needed to compute $\lambda_+$, the estimated upper edge of the Marchenko–Pastur (MP) bulk, using the BEMA algorithm.

In [None]:
def mpDensity(ndf, pdim, var = 1):
  """Return the support edges ``(a, b)`` used by ``dmp``/``pmp``/``qmp``.

  With ``gamma = ndf/pdim`` the edges are ``var*(1 -/+ sqrt(1/gamma))**2``.
  """
  ratio = ndf/pdim
  root = math.sqrt(1/ratio)
  lower_edge = var*(1-root)**2
  upper_edge = var*(1+root)**2
  return lower_edge, upper_edge

def dmp(x, ndf, pdim, var=1, log = False):
  """Density (or log-density when ``log`` is true) at the scalar ``x``.

  Uses ``gamma = ndf/pdim`` and the support edges ``(a, b)`` from
  ``mpDensity``.  Returns 0 (``-inf`` on the log scale) outside the open
  interval ``(a, b)``.
  """
  gamma = ndf/pdim
  a, b = mpDensity(ndf, pdim, var)

  # When gamma == 1 the lower edge is 0 and the density diverges there.  Only
  # +0 is treated as the divergent point; copysign tells +0 from -0 without the
  # 1/x trick, which raises ZeroDivisionError in Python.
  if gamma == 1 and x == 0 and math.copysign(1.0, x) > 0:
    return math.inf

  # Outside the support (``or``: a point cannot be both <= a and >= b).
  if x <= a or x >= b:
    return -math.inf if log else 0

  if log:
    # The ``log`` flag shadows any log function, so call math.log explicitly.
    return (math.log(gamma)
            - (math.log(2) + math.log(math.pi) + math.log(var) + math.log(x))
            + 0.5*math.log(x - a) + 0.5*math.log(b - x))
  return gamma/(2*math.pi*var*x) * math.sqrt((x - a)*(b - x))

def pmp(q, ndf, pdim, var=1, lower_tail = True, log_p = False):
  """Distribution function at the scalar ``q`` (``gamma = ndf/pdim``).

  The continuous part (density ``dmp`` on ``(a, b)``) carries total mass
  ``min(1, gamma)``; when ``gamma < 1`` the remaining ``1 - gamma`` is added as
  a point mass at zero.

  Args:
    lower_tail: return P(X <= q) when true, otherwise the upper tail.
    log_p: return the natural log of the probability (``-inf`` for 0).
  """
  gamma = ndf/pdim
  a, b = mpDensity(ndf, pdim, var)
  f = lambda x : dmp(x, ndf, pdim, var)
  # Total mass of the continuous part; using 1 here made the lower-tail CDF
  # exceed 1 for gamma < 1 once the point mass was added.
  bulk_mass = min(1, gamma)

  if lower_tail:
    if q <= a:
      p = 0
    elif q >= b:
      p = bulk_mass
    else:
      p = quad(f, a, q)[0]
    if gamma < 1 and q >= 0:
      p += (1 - gamma)
  else:
    if q <= a:
      p = bulk_mass
    elif q >= b:
      p = 0
    else:
      p = quad(f, q, b)[0]
    if gamma < 1 and q <= 0:
      p += (1 - gamma)

  if log_p:
    # math.log(0) raises ValueError; a zero probability maps to -inf.
    return math.log(p) if p > 0 else -math.inf
  return p

def qmp(p, ndf, pdim, var=1, lower_tail = True, log_p = False ):
  """Quantile function: the ``q`` for which ``pmp(q, ...)`` equals ``p``.

  Args:
    p: probability (its natural log when ``log_p`` is true).
    lower_tail: when false, ``p`` is an upper-tail probability.
    log_p: interpret ``p`` on the log scale.
  """
  # Imported here so the function does not rely on ``scipy.optimize`` having
  # been loaded as a side effect of another import.
  from scipy.optimize import brentq

  svr = ndf/pdim
  # Undo the log scale first, then flip the tail: 1 - p is only meaningful for
  # a probability, not for a log-probability.
  if log_p:
    p = math.exp(p)
  if not lower_tail:
    p = 1 - p

  a, b = mpDensity(ndf, pdim, var)

  if p <= 0:
    # The int literal -0 is just 0; use a float to get a real signed zero.
    return -0.0 if svr <= 1 else a
  if p >= 1:
    return b
  if svr < 1:
    # Probabilities within the point mass at zero (weight 1 - svr).
    if p < 1 - svr:
      return -0.0
    if p == 1 - svr:
      return 0.0

  # Otherwise invert the CDF numerically on the support [a, b].
  return brentq(lambda x: pmp(x, ndf, pdim, var) - p, a, b)



#bema_inside is where the BEMA algorithm is calculated
#use bema_mat_wrapper instead
def bema_inside(pdim, ndf, eigs, alpha, beta):
  """Estimate the bulk scale and the upper bulk edge with BEMA.

  Args:
    pdim, ndf: the two matrix dimensions; ``gamma = pdim/ndf``.
    eigs: eigenvalues (sorted internally in ascending order).
    alpha: fraction of eigenvalues ignored at each end when fitting.
    beta: the threshold uses the ``1 - beta`` Tracy-Widom (beta=1) quantile.

  Returns:
    ``(sigma_sq, lamda_plus, l2)``: the least-squares scale, the edge with the
    Tracy-Widom correction, and the uncorrected edge ``sigma_sq*(1+sqrt(gamma))**2``.
  """
  pTilde = min(pdim, ndf)
  gamma = pdim/ndf
  ev = np.sort(eigs)
  ind = range(int(alpha*pTilde), int((1-alpha)*pTilde))
  # Least-squares fit (through the origin) of the middle eigenvalues against
  # the quantiles returned by qmp.
  q = [qmp(i/pTilde, ndf, pdim, 1) for i in ind]
  lamda = [ev[i] for i in ind]
  sigma_sq = np.dot(q, lamda)/np.dot(q, q)

  t_b = TracyWidom(beta=1).cdfinv(1-beta)
  edge = (1+np.sqrt(gamma))**2
  # The exponent is 4/3; ``**4/3`` is parsed as ``(...)**4 / 3``.
  tw_shift = t_b*ndf**(-2/3)*gamma**(-1/6)*(1+np.sqrt(gamma))**(4/3)
  lamda_plus = sigma_sq*(edge + tw_shift)
  l2 = sigma_sq*edge
  return sigma_sq, lamda_plus, l2


#entry point for running BEMA on a matrix
def bema_mat_wrapper(matrix,pReal,nReal, alpha, beta, goodnessOfFitCutoff, show = False):
    """Fit BEMA to ``matrix`` (``pReal`` x ``nReal``) and judge the fit.

    Returns ``(eigenvalues, p/n, sigma_sq, lamda_plus, goodFit)`` where ``p`` is
    the smaller and ``n`` the larger dimension, and ``goodFit`` tells whether
    the value returned by ``error`` is below ``goodnessOfFitCutoff``.
    With ``show`` set, two diagnostic histograms are drawn.
    """
    # Always form the smaller of the two Gram matrices, so p <= n below.
    transposed = matrix.transpose()
    if pReal <= nReal:
      p, n = pReal, nReal
      gram = np.matmul(matrix, transposed)
    else:
      p, n = nReal, pReal
      gram = np.matmul(transposed, matrix)
    matrix_norm = gram/nReal

    v = np.linalg.eigvalsh(matrix_norm)
    sigma_sq, lamda_plus, l2 = bema_inside(p,n,v, alpha, beta)
    gamma = p/n
    LinfError = error(v, alpha, min(p,n), gamma, sigma_sq)
    goodFit = bool(LinfError < goodnessOfFitCutoff)

    if show:
      print("error", LinfError)
      tail = v[-min(p,n):]
      plt.hist(tail, bins = 100, color="black", label = "Empirical Density", density = True)

      # Keep only the eigenvalues before the first one exceeding lamda_plus.
      Z = tail
      beyond = np.flatnonzero(Z > lamda_plus)
      if beyond.size:
        Z = Z[:beyond[0]]
      Y = MP_Density_Wrapper(gamma, sigma_sq, Z)

      plt.axvline(x = lamda_plus, label = "Lambda Plus", color = "red")
      plt.legend()
      plt.title("Empirical Distribution Density")
      plt.show()

      eigsTruncated = [value for value in tail if value < lamda_plus]
      plt.hist(eigsTruncated, bins = 100, color = "black", label = "Truncated Empirical Density", density = True)
      plt.plot(Z,Y, color = "orange", label = "Predicted Density")
      plt.legend()
      plt.title("Density Comparison Zoomed")
      plt.show()

    return v, gamma, sigma_sq, lamda_plus, goodFit

The next block contains the code that measures how well the MP distribution predicted by BEMA approximates the empirical data: it computes the $L^\infty$ distance between the theoretical CDF and the empirical CDF over the range of eigenvalues sampled by BEMA.

In [None]:
#MP density at a single point x (helper for MP_Density_Wrapper)
def MP_Density_Inner(gamma, sigma_sq,x):
  """Density value at ``x``; ``x`` must lie inside the support [lm, lp]."""
  root_gamma = math.sqrt(gamma)
  upper_edge = sigma_sq*pow(1+root_gamma,2)
  lower_edge = sigma_sq*pow(1-root_gamma,2)
  numerator = math.sqrt((upper_edge-x)*(x-lower_edge))
  denominator = gamma*x*2*math.pi*sigma_sq
  return numerator/denominator

#MP density evaluated at every sample point
def MP_Density_Wrapper(gamma,sigma_sq,samplePoints):
  """Return an array of density values, 0 for points outside [lm, lp]."""
  root_gamma = math.sqrt(gamma)
  upper_edge = sigma_sq*pow(1+root_gamma,2)
  lower_edge = sigma_sq*pow(1-root_gamma,2)

  values = [
    MP_Density_Inner(gamma, sigma_sq, point) if lower_edge <= point <= upper_edge else 0
    for point in samplePoints
  ]
  return np.array(values)

#helper function to compute MP CDF
def MP_CDF_inner(gamma, sigma_sq, x):
  """Closed-form CDF term for a scalar ``x`` in the support [lm, lp].

  For ``gamma > 1`` the caller (``MP_CDF``) adds ``(gamma-1)/(2*gamma)`` to
  this value.  At or beyond the support edges the limiting value of the
  formula is returned, because evaluating the formula there divides by zero.
  """
  lp = sigma_sq*pow(1+math.sqrt(gamma),2)
  lm = sigma_sq*pow(1-math.sqrt(gamma),2)

  # Limits of the expression below as x -> lp and x -> lm respectively.
  if x >= lp:
    return 1.0 if gamma <= 1 else (gamma + 1)/(2*gamma)
  if x <= lm:
    return 0.0 if gamma <= 1 else (gamma - 1)/(2*gamma)

  r = math.sqrt((lp - x)/(x - lm))

  F = math.pi * gamma + (1/sigma_sq)*math.sqrt((lp - x)* (x - lm))
  F += -(1+gamma)*math.atan((r*r-1)/(2*r))
  # The last term carries a factor (1 - gamma) and vanishes for gamma == 1,
  # where its argument would also divide by zero.
  if gamma != 1:
    F += (1-gamma) *math.atan((lm *r*r - lp)/(2 *sigma_sq *(1-gamma)*r))
  F /= 2 * math.pi * gamma
  return F

#at the sample points compute the theoretical MP CDF
def MP_CDF(gamma, sigma_sq, samplePoints):
  """Return the theoretical MP CDF at every point of ``samplePoints``.

  For ``gamma > 1`` the distribution has a point mass ``(gamma-1)/gamma`` below
  the continuous part, which is what the CDF equals for ``x < lm``.
  """
  lp = sigma_sq*pow(1+math.sqrt(gamma),2)
  lm = sigma_sq*pow(1-math.sqrt(gamma),2)

  output = []
  for x in samplePoints:
    if gamma <= 1:
      if x < lm:
        output.append(0)
      elif x >= lp:
        # A CDF equals 1 above the support (this branch used to append 0).
        output.append(1)
      else:
        output.append(MP_CDF_inner(gamma, sigma_sq,x))
    else:
      if x < lm:
        output.append( (gamma-1)/gamma)
      elif x >= lp:
        output.append(1)
      else:
        output.append((gamma-1)/(2*gamma)+ MP_CDF_inner(gamma, sigma_sq, x))
  return np.array(output)


def empiricalCDF(S):
  """Return ``[0/n, 1/n, ..., (n-1)/n]`` for ``n = len(S)``; values of S are unused."""
  count = len(S)
  return np.arange(count)/count


def error(singular_values,alpha,pTilde, gamma, sigma_sq, show = False):
  """Largest absolute gap between the theoretical and empirical CDF.

  Only the entries with index in ``[alpha*n, (1-alpha)*n)`` are compared, where
  ``n = len(singular_values)``; the ``pTilde`` argument is overwritten by that
  length and therefore has no effect.  With ``show`` set, the differences and
  both CDFs are plotted.
  """
  count = len(singular_values)
  first, last = int(alpha*count), int((1-alpha)*count)
  sampled = singular_values[np.arange(first, last)]

  theoretical = MP_CDF(gamma, sigma_sq, sampled)
  # Map the sampled ranks onto the [alpha, 1 - alpha] probability range.
  empirical = alpha + (1-2*alpha)*empiricalCDF(sampled)
  difference = theoretical - empirical

  if show:
    plt.hist(difference, label = "Difference histogram")
    plt.legend()
    plt.show()
    positions = np.arange(len(empirical))
    plt.plot(positions, empirical, label = "empirical")
    plt.plot(positions, theoretical, label = "theoretical")
    plt.legend()
    plt.show()
  return np.linalg.norm(difference, np.inf)

In [None]:
def bema_scheduler(epoch):
   """Linear decay from 1 at epoch 0 with slope -1/300, floored at 0."""
   decayed = -1/300*epoch + 1
   return decayed if decayed > 0 else 0


In [None]:
class Splittable(nn.Module):
    """Shared logic for layers that can be factorised into two thinner layers.

    Subclasses are expected to set ``layer1``, ``layer2``, ``splitted``,
    ``alpha``, ``beta``, ``goodnessOfFitCutoff_``, ``name``, ``in_features`` and
    ``out_features``, and to implement ``get_matrix``, ``set_params`` and
    ``make_splitted_layers``.
    """

    def forward(self, x):
        return self.layer2(self.layer1(x))

    @property
    def goodnessOfFitCutoff(self):
        return self.goodnessOfFitCutoff_

    @property
    def param_numbers(self):
        """Number of weight entries currently used by the layer."""
        if self.splitted:
            return (self.in_features + self.out_features) * self.layer1.out_features
        return self.in_features * self.out_features

    def fit_MP(self, U, singular_values, V, save_name, show=False):
        """Fit BEMA to the layer's spectrum.

        Args:
            U, singular_values, V: output of ``np.linalg.svd`` on the weight matrix.
            save_name: unused.
            show: draw diagnostic histograms when true.

        Returns:
            ``(Splus, goodFit)``: the singular-value threshold corresponding to
            ``lamda_plus``, and whether the fit error is below the cutoff.
        """
        # Squared singular values, scaled by the number of columns of the matrix.
        eigenvals = singular_values**2 / V.shape[0]
        eigenvals = np.sort(eigenvals)

        p = min(U.shape[0], V.shape[0])
        n = max(U.shape[0], V.shape[0])
        gamma = p / n
        sigma_sq, lamda_plus, l2 = bema_inside(p, n, eigenvals, self.alpha, self.beta)
        # Undo the scaling above to get back to the singular-value scale.
        Splus = math.sqrt(V.shape[1] * lamda_plus)
        LinfError = error(eigenvals, self.alpha, p, gamma, sigma_sq)
        goodFit = LinfError < self.goodnessOfFitCutoff
        if show:
            v = eigenvals
            print("error", LinfError)
            plt.hist(v[-min(p,n):], bins = 100, color="black", label = "Empirical Density", density = True)
            Z = v[-min(p,n):]
            for t in range(len(Z)):
                if Z[t] > lamda_plus:
                    Z = Z[:t]
                    break
            Y = MP_Density_Wrapper(gamma, sigma_sq, Z)
            plt.axvline(x = lamda_plus, label = "Lambda Plus", color = "red")
            plt.legend()
            plt.title("Empirical Distribution Density")
            plt.show()

            eigsTruncated = [i for i in v[-min(p,n):] if i < lamda_plus]
            plt.hist(eigsTruncated, bins = 100, color = "black", label = "Truncated Empirical Density", density = True)
            plt.plot(Z,Y, color = "orange", label = "Predicted Density")
            plt.legend()
            plt.title("Density Comparison Zoomed")
            plt.show()

        return Splus, goodFit

    def split(self, ratio, save_name, show=False, layer_type='fc'):
        """Replace the layer by a rank-``inner_dim`` factorisation if that saves parameters.

        ``inner_dim`` keeps every singular value above the fitted threshold
        plus a fraction ``ratio`` of the remaining ones.

        Args:
            ratio: fraction of the non-significant singular values to keep.
            save_name: forwarded to ``fit_MP``.
            show: forwarded to ``fit_MP``.
            layer_type: unused.

        Returns:
            A short status message describing what was done.
        """
        matrix = self.get_matrix()
        U, S, V = np.linalg.svd(matrix)
        # ``show`` used to be accepted but never passed on.
        Splus, goodFit = self.fit_MP(U, S, V, save_name, show=show)
        if not goodFit:
            return f" {self.name} no good fit"

        # Plain ints keep layer construction and the status message free of
        # NumPy scalar types.
        significant_singulars = int(np.sum(S > Splus))
        inner_dim = int((S.shape[0] - significant_singulars) * ratio) + significant_singulars

        # New weights are created where the current weights live, so a layer on
        # an accelerator stays there.
        device = self.layer1.weight.device

        def as_weight(array):
            return torch.from_numpy(array).float().to(device)

        if self.param_numbers <= (matrix.shape[0] + matrix.shape[1]) * inner_dim:
            # Factorising would not save parameters: keep one layer, but project
            # its weights onto the leading ``inner_dim`` singular directions.
            if not self.splitted:
                new_weights = (U[:, :inner_dim] * S[None, :inner_dim]) @ V[:inner_dim, :]
                self.set_params(
                    "layer1",
                    as_weight(new_weights),
                    bias=None,
                    change_bias=False,
                )
            return f" {self.name} not enough param reduc"

        sqrt_S = np.sqrt(S)
        new_weights1 = sqrt_S[:inner_dim, None] * V[:inner_dim, :]
        new_weights2 = U[:, :inner_dim] * sqrt_S[None, :inner_dim]

        # After a split the bias lives on layer2, so a second split must take it
        # from there; reading layer1 would silently drop it.
        bias_holder = self.layer2 if self.splitted else self.layer1
        old_bias = getattr(bias_holder, "bias", None)
        bias = None if old_bias is None else nn.Parameter(old_bias.detach().clone())

        self.layer1, self.layer2 = self.make_splitted_layers(inner_dim)
        self.set_params("layer1", as_weight(new_weights1), bias=None)
        self.set_params("layer2", as_weight(new_weights2), bias)
        self.splitted = True
        return f" {self.name} splitted, new dims {(self.in_features,inner_dim,self.out_features)}"




class SplittableLinear(Splittable):
    """Linear layer that can later be replaced by two stacked, thinner linears.

    Until it is split, ``layer1`` is an ordinary ``nn.Linear`` and ``layer2``
    an ``nn.Identity``.

    Args:
        in_features, out_features: sizes of the linear map.
        alpha, beta, goodnessOfFitCutoff: stored for the inherited fitting code.
        name: label used in status messages.
        bias: whether ``layer1`` has a bias.
    """

    def __init__(
        self,
        in_features,
        out_features,
        alpha,
        beta,
        goodnessOfFitCutoff,
        name="splittable_linear",
        bias=True,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.layer1 = nn.Linear(in_features, out_features, bias=bias)
        # An Identity has no parameters, so nothing needs initialising on it.
        self.layer2 = nn.Identity()
        self.splitted = False
        self.alpha = alpha
        self.beta = beta
        self.goodnessOfFitCutoff_ = goodnessOfFitCutoff
        self.name = name

        # Weights ~ N(0, 1/in_features), bias zero.
        nn.init.normal_(self.layer1.weight, mean=0, std=np.sqrt(1.0/in_features))
        if bias:
            nn.init.zeros_(self.layer1.bias)

    def forward(self, x):
        if self.splitted:
            return self.layer2(self.layer1(x))
        return self.layer1(x)

    @property
    def param_numbers(self):
        """Number of weight entries currently used by the layer."""
        if self.splitted:
            return (self.in_features + self.out_features) * self.layer1.out_features
        return self.in_features * self.out_features

    @staticmethod
    def from_layer(linear, alpha, beta, goodnessOfFitCutoff):
        """Build a ``SplittableLinear`` that wraps the parameters of ``linear``.

        Declared static so that it also works when called on an instance.
        """
        has_bias = linear.bias is not None
        splittable_lin = SplittableLinear(
            linear.in_features,
            linear.out_features,
            alpha,
            beta,
            goodnessOfFitCutoff,
            bias=has_bias,
        )
        splittable_lin.set_params("layer1", linear.weight, linear.bias)
        return splittable_lin

    def __str__(self):
        return (
            f"Linear(in_features={self.in_features}, out_features={self.out_features})"
        )

    def get_matrix(self):
        """Effective weight matrix as a NumPy array (``layer2 @ layer1`` once split)."""
        first = torch.as_tensor(self.layer1.weight).cpu().detach().numpy()
        if not self.splitted:
            return first
        second = torch.as_tensor(self.layer2.weight).cpu().detach().numpy()
        return second @ first

    def set_params(self, which_layer, weight, bias, change_bias=True):
        """Replace the weight (and, if ``change_bias``, the bias) of one sub-layer."""
        assert which_layer in ["layer1", "layer2"]
        target = getattr(self, which_layer)
        target.weight = nn.Parameter(weight)
        if change_bias:
            target.bias = None if bias is None else nn.Parameter(bias)

    def make_splitted_layers(self, inner_dim):
        """Create the two bias-free linears ``in -> inner_dim -> out``."""
        layer1 = nn.Linear(self.in_features, inner_dim, bias=False)
        layer2 = nn.Linear(inner_dim, self.out_features, bias=False)
        return layer1, layer2


class SplittableConv(Splittable):
    """2-D convolution that can later be split into a conv plus a 1x1 conv.

    Until it is split, ``layer1`` is an ordinary ``nn.Conv2d`` with a square
    kernel and ``layer2`` an ``nn.Identity``.  For the inherited fitting code
    the kernel is viewed as a matrix with ``in_features = k*k*in_channels``
    columns and ``out_features = out_channels`` rows.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        alpha,
        beta,
        goodnessOfFitCutoff,
        name="splittable_conv",
        stride=1,
        padding=0,
        groups=1,
        bias=True,
        padding_mode="zeros",
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_features = kernel_size * kernel_size * in_channels
        self.out_features = out_channels
        self.stride = stride
        self.kernel_size = kernel_size
        self.padding = padding
        self.groups = groups
        self.padding_mode = padding_mode

        self.layer1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation=1,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
        )
        self.layer2 = nn.Identity()
        self.splitted = False
        self.alpha = alpha
        self.beta = beta
        self.goodnessOfFitCutoff_ = goodnessOfFitCutoff
        self.name = name

    @property
    def param_numbers(self):
        """Number of weight entries currently used by the layer."""
        if self.splitted:
            return np.prod(self.layer1.weight.shape)+np.prod(self.layer2.weight.shape)
        return self.in_features * self.out_features

    def __str__(self):
        return f"SplittableConv2d({self.in_channels},{self.out_channels},kernel_size=({self.kernel_size},{self.kernel_size}),stride=({self.stride},{self.stride}))"

    @staticmethod
    def from_layer(conv, alpha, beta, goodnessOfFitCutoff):
        """Build a ``SplittableConv`` that wraps the parameters of ``conv``.

        Only square kernels with equal strides and paddings in both directions
        are supported (enforced by the assertions).
        """
        assert conv.kernel_size[0] == conv.kernel_size[1]
        assert conv.stride[0] == conv.stride[1]
        assert conv.padding[0] == conv.padding[1]
        splittable_conv = SplittableConv(
            conv.in_channels,
            conv.out_channels,
            conv.kernel_size[0],
            alpha,
            beta,
            goodnessOfFitCutoff,
            stride=conv.stride[0],
            padding=conv.padding[0],
            groups=conv.groups,
            bias=conv.bias is not None,
            padding_mode=conv.padding_mode,
        )
        # set_params needs the target sub-layer as its first argument; omitting
        # it raised a TypeError.
        splittable_conv.set_params("layer1", conv.weight, conv.bias)
        return splittable_conv

    def get_matrix(self):
        """Effective kernel as an (out_channels, -1) NumPy matrix."""
        first = torch.as_tensor(self.layer1.weight).cpu().detach().numpy()
        first = np.reshape(first, (first.shape[0], -1))
        if not self.splitted:
            return first
        second = torch.as_tensor(self.layer2.weight).cpu().detach().numpy()
        second = np.reshape(second, (second.shape[0], -1))
        return second @ first

    def set_params(self, which_layer, weight, bias, change_bias=True):
        """Replace the weight (and, if ``change_bias``, the bias) of one sub-layer.

        A 2-D ``weight`` is reshaped to the current weight shape of the target.
        """
        assert which_layer in ["layer1", "layer2"]
        target = getattr(self, which_layer)
        if len(weight.shape) == 2:
            weight = torch.reshape(weight, target.weight.shape)
        target.weight = nn.Parameter(weight)
        if change_bias:
            target.bias = None if bias is None else nn.Parameter(bias)

    def make_splitted_layers(self, inner_dim):
        """Create the pair: a conv to ``inner_dim`` channels, then a 1x1 conv."""
        layer1 = nn.Conv2d(
            self.in_channels,
            inner_dim,
            self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            groups=self.groups,
            bias=False,
            padding_mode=self.padding_mode,
        )
        layer2 = nn.Conv2d(
            inner_dim,
            self.out_channels,
            1,
            stride=1,
            padding=0,
            groups=1,
            bias=True,
            padding_mode=self.padding_mode,
        )
        return layer1, layer2


class SplittableLayer(nn.Module):
    """Layer that can be replaced by a low-rank pair of layers.

    Until :meth:`split` succeeds the module is a plain ``nn.Linear``
    (``layer_type='fc'``) or a 3x3 ``nn.Conv2d`` (``layer_type='conv'``)
    followed by an identity.  A successful split factorises the weight as
    ``W ~= W2 @ W1`` and installs two smaller layers instead.

    NOTE(review): the ``'conv'`` branches apply matrix routines directly to
    the 4-D kernel tensor and look experimental; they are reproduced as they
    were and should be confirmed before being relied on.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        alpha,
        beta,
        goodnessOfFitCutoff,
        name = "splittable_linear",
        bias = True,
        layer_type = 'fc',
    ):
        """Create an unsplit layer.

        in_dim / out_dim: feature counts ('fc') or channel counts ('conv').
        alpha, beta: forwarded to ``bema_inside`` / ``error`` by fit_MP.
        goodnessOfFitCutoff: fit error below which a split is allowed.
        name: label used in the status strings returned by split().
        bias: whether the initial layer carries a bias.
        layer_type: 'fc' or 'conv'.
        """
        super().__init__()
        if layer_type == "fc":
            self.in_features = in_dim
            self.out_features = out_dim
            self.layer1 = nn.Linear(in_dim, out_dim, bias=bias)

        if layer_type == "conv":
            self.in_channels = in_dim
            self.out_channels = out_dim
            self.layer1 = nn.Conv2d(in_dim, out_dim, 3, bias=bias)

        # Identity until split() installs the second factor.
        self.layer2 = nn.Identity()
        self.splitted = False
        self.alpha = alpha
        self.beta = beta
        self.goodnessOfFitCutoff_ = goodnessOfFitCutoff
        self.name = name
        self.layer_type = layer_type

    def forward(self, x):
        """Apply ``layer1`` and then ``layer2`` to ``x``."""
        return self.layer2(self.layer1(x))

    @property
    def goodnessOfFitCutoff(self):
        """Threshold the fit error is compared against in fit_MP."""
        return self.goodnessOfFitCutoff_

    @property
    def param_numbers(self):
        """Number of weight entries currently held (biases excluded)."""
        if self.layer_type == "conv":
            # Conv layers have no in_features/out_features, so count the
            # stored kernels directly.
            count = self.layer1.weight.numel()
            if self.splitted:
                count += self.layer2.weight.numel()
            return count
        if self.splitted:
            return (self.in_features + self.out_features) * self.layer1.out_features
        return self.in_features * self.out_features

    def fit_MP(self, U, singular_values, V, save_name,show = False):
        """Fit the spectrum and report the cut-off singular value.

        The squared singular values divided by ``V.shape[0]`` are passed to
        ``bema_inside`` and ``error`` (both defined elsewhere in this file).
        Returns ``(Splus, goodFit)`` where ``Splus`` is
        ``sqrt(V.shape[1] * lambda_plus)`` and ``goodFit`` tells whether the
        fit error is below ``self.goodnessOfFitCutoff``.  ``save_name`` is
        accepted for interface compatibility and not used.  With
        ``show=True`` two diagnostic histograms are drawn.
        """
        eigenvals = np.sort(singular_values**2 / V.shape[0])

        p = min(U.shape[0], V.shape[0])
        n = max(U.shape[0], V.shape[0])
        gamma = p / n
        sigma_sq, lamda_plus, _ = bema_inside(p, n, eigenvals, self.alpha, self.beta)
        Splus = math.sqrt(V.shape[1] * lamda_plus)
        LinfError = error(eigenvals, self.alpha, p, gamma, sigma_sq)
        goodFit = LinfError < self.goodnessOfFitCutoff
        if show:
            self._plot_fit(eigenvals[-min(p, n):], gamma, sigma_sq, lamda_plus, LinfError)
        return Splus, goodFit

    def _plot_fit(self, tail, gamma, sigma_sq, lamda_plus, LinfError):
        """Draw the full and the truncated spectrum histograms."""
        print("error", LinfError)
        plt.hist(tail, bins = 100, color="black", label = "Empirical Density", density = True)
        # `tail` is sorted ascending, so everything before the first value
        # above lambda_plus is the part the predicted density is drawn over.
        Z = tail[: np.searchsorted(tail, lamda_plus, side="right")]
        Y = MP_Density_Wrapper(gamma, sigma_sq, Z)
        plt.axvline(x = lamda_plus, label = "Lambda Plus", color = "red")
        plt.legend()
        plt.title("Empirical Distribution Density")
        plt.show()

        eigsTruncated = [i for i in tail if i < lamda_plus]
        plt.hist(eigsTruncated, bins = 100, color = "black", label = "Truncated Empirical Density", density = True)
        plt.plot(Z,Y, color = "orange", label = "Predicted Density")
        plt.legend()
        plt.title("Density Comparison Zoomed")
        plt.show()

    def split(self, ratio, save_name, show = False, layer_type = 'fc'):
        """Try to replace the layer by a rank-``inner_dim`` factorisation.

        ``inner_dim`` keeps every singular value above the fitted cut-off
        plus the fraction ``ratio`` of the remaining ones.  Returns a status
        string: the fit was not good, the split would not save parameters
        (the weight is only truncated in place), or the layer was split.
        """
        # New tensors are created from NumPy on the CPU; remember where the
        # layer lives so the replacement parameters end up on the same device.
        device = self.layer1.weight.device
        matrix = self.get_matrix()
        U, S, V = np.linalg.svd(matrix)

        if layer_type == "fc":
            Splus, goodFit = self.fit_MP(U, S, V, save_name, show = show)
        if layer_type == "conv":
            Splus, goodFit = self.fit_MP(U, np.array(S).flatten(), V, save_name, show = show)

        if not goodFit:
            return f" {self.name} not good fit"

        # Plain int so layer constructors and the status string are clean.
        significant_singulars = int(np.sum(np.array(S) > Splus))
        if layer_type == "fc":
            inner_dim = (
                int((S.shape[0] - significant_singulars) * ratio) + significant_singulars
            )
        if layer_type == "conv":
            inner_dim = (
                int((S.shape[1] - significant_singulars) * ratio) + significant_singulars
            )
        old_num_param = np.prod(matrix.shape)
        # Debug traces below are kept so notebook output stays comparable.
        print(old_num_param , matrix.shape, inner_dim, significant_singulars, "erqwt")
        print()
        if self.param_numbers <= (matrix.shape[0] + matrix.shape[1]) * inner_dim:
            # An already-split layer has rank <= its current inner width, so
            # truncating to a larger rank changes nothing; writing the
            # full-size product into layer1 would make it incompatible with
            # layer2.  Only the unsplit layer is truncated in place.
            if not self.splitted:
                if layer_type == "fc":
                    new_weights = (U[:, :inner_dim] * S[None, :inner_dim]) @ V[:inner_dim, :]
                if layer_type =="conv":
                    new_weights = (U[:,:inner_dim, :,:] * S[None,:inner_dim,:,:]) @ V [:,:inner_dim,:, :]

                self.set_params(
                    'layer1',
                    torch.from_numpy(new_weights).float().to(device),
                    bias=None,
                    change_bias=False,
                )
            return f" {self.name} not enough param reduc"
        print(U.shape, S.shape, V.shape,"10293847")
        if layer_type == "fc":
            print(U[:, :inner_dim].shape, np.sqrt(S)[None, :inner_dim].shape, np.sqrt(S)[:inner_dim, None].shape , V[:inner_dim, :].shape,"pppp")
            # Square-root trick: each factor carries sqrt(S).
            new_weights1 = np.sqrt(S)[:inner_dim, None] * V[:inner_dim, :]
            new_weights2 = U[:, :inner_dim] * np.sqrt(S)[None, :inner_dim]
            print(new_weights1.shape, new_weights2.shape," fc new")

        if layer_type == "conv":
            print(U.shape, S.shape, V.shape, "gjhgjhg")
            print(U[:,:inner_dim,:,:].shape, np.sqrt(S)[None,:inner_dim,:3,:3].shape, np.sqrt(S)[:inner_dim,None,:3,:3].shape, V[:inner_dim,:, :,:].shape,"asfd")
            new_weights1 =  np.sqrt(S)[:inner_dim,None,:3,:3] * V[:inner_dim,:, :,:]
            new_weights2 = U[:,:inner_dim,:,:] * np.sqrt(S)[None,:inner_dim,:3,:3]
            print(new_weights1.shape, new_weights2.shape, " conv new")

        # The bias moves to the second factor; the first one has none.
        old_bias = getattr(self.layer1, "bias", None)
        bias = None if old_bias is None else nn.Parameter(old_bias.clone())
        self.layer1, self.layer2 = self.make_splitted_layers(inner_dim)
        self.set_params("layer1", torch.from_numpy(new_weights1).float().to(device), bias=None)
        self.set_params("layer2", torch.from_numpy(new_weights2).float().to(device), bias)
        self.splitted = True
        if layer_type == "fc":
            return f" {self.name} splitted, new dims {(self.in_features, inner_dim, self.out_features)}"
        if layer_type == "conv":
            return f" {self.name} splitted, new dims {(self.in_channels, inner_dim, self.out_channels)}"

    def from_layer(self, layer, alpha, beta, goodnessOfFitCutoff):
        """Build a new SplittableLayer that shares ``layer``'s parameters.

        The new layer has the same ``layer_type`` as ``self`` and takes its
        dimensions, weight and bias from ``layer``.
        """
        bias = layer.bias is not None

        if self.layer_type == 'fc':
            in_dim = layer.in_features
            out_dim = layer.out_features

        if self.layer_type == 'conv':
            in_dim = layer.in_channels
            out_dim = layer.out_channels

        splittable_layer = SplittableLayer(
            in_dim,
            out_dim,
            alpha,
            beta,
            goodnessOfFitCutoff,
            bias=bias,
            layer_type=self.layer_type,
        )
        splittable_layer.set_params("layer1", layer.weight, layer.bias)
        return splittable_layer

    def __str__(self):
        if self.layer_type == 'fc':
            return (f"Linear(in_features={self.in_features}, out_features={self.out_features})")
        if self.layer_type == 'conv':
            return (f"Convolutional(in_channels={self.in_channels}, out_channels={self.out_channels})")
        # __str__ must return a string even for an unexpected layer_type.
        return super().__str__()

    def get_matrix(self):
        """Return the effective weight as a NumPy array (``W2 @ W1`` if split)."""
        layerMatrix = self.layer1.weight.detach().cpu().numpy()
        if not self.splitted:
            return layerMatrix
        layerMatrix2 = self.layer2.weight.detach().cpu().numpy()
        return layerMatrix2 @ layerMatrix

    def set_params(self, which_layer, weight, bias, change_bias=True):
        """Replace the weight (and, if ``change_bias``, the bias) of a sub-layer.

        ``which_layer`` is "layer1" or "layer2"; ``bias=None`` removes the bias.
        """
        assert which_layer in ["layer1", "layer2"]
        target = getattr(self, which_layer)
        target.weight = nn.Parameter(weight)
        if change_bias:
            target.bias = None if bias is None else nn.Parameter(bias)

    def make_splitted_layers(self, inner_dim):
        """Create the bias-free layer pair ``in -> inner_dim -> out``.

        split() assigns the real weights (and the bias of the second layer)
        afterwards through set_params().
        """
        if self.layer_type == 'fc':
            layer1 = nn.Linear(self.in_features, inner_dim, bias=False)
            layer2 = nn.Linear(inner_dim, self.out_features, bias=False)
        if self.layer_type == 'conv':
            layer1 = nn.Conv2d(self.in_channels, inner_dim, 3, bias=False)
            layer2 = nn.Conv2d(inner_dim, self.out_channels, 3, bias=False)
        return layer1, layer2

The following block contains the splitting algorithms, each of which processes a single layer.

In [None]:
# Interface to the rest of the code: dispatches to the per-layer algorithm.
def innerAlgWrapper(model, index, lambda_plus, eigsToKeep, goodFit, alg = 0):
    """Run splitting algorithm ``alg`` on layer ``index`` of ``model``.

    ``alg`` selects algZero..algFour (0-4); -1 is a no-op.  Returns the
    ``(model, capChange)`` pair produced by the chosen algorithm.

    Raises:
        ValueError: if ``alg`` is not one of -1, 0, 1, 2, 3, 4.  Previously
            this fell through and returned None, which only surfaced later
            as an unpacking error in the caller.
    """
    if alg == -1:
        return model, 0
    if alg == 0:
        return algZero(model, index, lambda_plus, eigsToKeep, goodFit)
    if alg == 1:
        return algOne(model, index, lambda_plus, eigsToKeep, goodFit)
    if alg == 2:
        return algTwo(model, index, lambda_plus, eigsToKeep, goodFit)
    if alg == 3:
        return algThree(model, index, lambda_plus, eigsToKeep, goodFit)
    if alg == 4:
        return algFour(model, index, lambda_plus, eigsToKeep, goodFit)
    raise ValueError(f"unknown splitting algorithm id: {alg!r}")



# Split + truncate; only split when that does not pass the size check below.
# No modification at all when the fit is not good.
def algZero(model, index, lambda_plus, eigsToKeep, goodFit):
    """Truncate layer ``index`` to ``eigsToKeep`` singular values.

    The weight ``W`` (outputDim x inputDim) is approximated by ``W2 @ W1``
    with ``W1 = sqrt(S) V`` (eigsToKeep x inputDim) and ``W2 = U sqrt(S)``
    (outputDim x eigsToKeep).  If ``inputDim * outputDim`` is smaller than
    ``eigsToKeep * (inputDim + outputDim)`` the layer is kept and only its
    weight is replaced by the product; otherwise it is replaced by two
    linear layers and ``model.fc``, ``model.without_rel`` and ``model.dims``
    are updated.

    ``lambda_plus`` is unused here; it is part of the common signature.
    Returns ``(model, capChange)`` with ``capChange == 1`` when a layer was
    inserted.
    """
    if not goodFit:
        return model, 0

    dims = model.getDimsFC()
    layerMatrix = model.getLayerMatrix(index)
    inputDim = dims[index]
    outputDim = dims[index + 1]

    U, S, V = np.linalg.svd(layerMatrix, full_matrices=False)
    # There are only len(S) singular values to keep.
    eigsToKeep = min(eigsToKeep, len(S))
    root = np.sqrt(S[:eigsToKeep])

    # PyTorch computes x W^T + b, so with W = W2 W1 the input has to pass
    # through W1 first and W2 second: x W^T = x W1^T W2^T.
    w2 = U[:, :eigsToKeep] * root
    w1 = root[:, None] * V[:eigsToKeep, :]

    if inputDim * outputDim < eigsToKeep * (inputDim + outputDim):
        # Splitting would not reduce the parameter count: truncate in place.
        model.fc[index].weight = torch.nn.Parameter(torch.from_numpy(w2 @ w1).float())
        return model, 0

    old_bias = model.fc[index].bias

    # Constructor dimensions match the weights that are assigned, so
    # in_features / out_features describe the layers correctly.
    layerOne = nn.Linear(inputDim, eigsToKeep, bias=False)
    layerOne.weight = torch.nn.Parameter(torch.from_numpy(w1).float())

    layerTwo = nn.Linear(eigsToKeep, outputDim, bias=old_bias is not None)
    layerTwo.weight = torch.nn.Parameter(torch.from_numpy(w2).float())
    if old_bias is not None:
        # Wrap explicitly: a module only accepts a Parameter (or None) here.
        layerTwo.bias = torch.nn.Parameter(old_bias.detach().float())

    model.fc[index] = layerOne
    model.fc.insert(index + 1, layerTwo)
    # The first factor gets the flag True; the original flag moves to the second.
    model.without_rel.insert(index, True)
    model.dims.insert(index + 1, eigsToKeep)
    return model, 1

# Truncate the spectrum at eigsToKeep when the fit is good; never split.
def algOne(model, index, lambda_plus, eigsToKeep, goodFit):
    """Replace layer ``index``'s weight by its rank-``eigsToKeep`` truncation.

    Nothing is changed when ``goodFit`` is false.  ``lambda_plus`` is not
    used.  Always returns ``(model, 0)`` because no layer is added.
    """
    if not goodFit:
        return model, 0

    layer_dims = model.getDims()
    weights = model.getLayerMatrix(index)

    # The weight matrix is n_out x n_in.
    n_in = layer_dims[index]
    n_out = layer_dims[index + 1]

    left, singular, right = np.linalg.svd(weights)
    spectrum = np.zeros((n_out, n_in))
    for pos in range(eigsToKeep):
        spectrum[pos, pos] = singular[pos]

    truncated = left @ spectrum @ right
    model.fc[index].weight = torch.nn.Parameter(torch.from_numpy(truncated).float())
    return model, 0

# Shrink the small singular values when the fit is good; never split.
def algTwo(model, index, lambda_plus, eigsToKeep, goodFit):
    """Halve every singular value of layer ``index`` that is <= ``lambda_plus``.

    Nothing is changed when ``goodFit`` is false.  ``eigsToKeep`` is not
    used.  Always returns ``(model, 0)`` because no layer is added.
    """
    if not goodFit:
        return model, 0

    layer_dims = model.getDims()
    weights = model.getLayerMatrix(index)

    # The weight matrix is n_out x n_in.
    n_in = layer_dims[index]
    n_out = layer_dims[index + 1]

    left, singular, right = np.linalg.svd(weights)
    spectrum = np.zeros((n_out, n_in))
    for pos, value in enumerate(singular):
        spectrum[pos, pos] = value * 0.5 if value <= lambda_plus else value

    shrunk = left @ spectrum @ right
    model.fc[index].weight = torch.nn.Parameter(torch.from_numpy(shrunk).float())
    return model, 0

# alg3 = prune uniformly at random below the predicted rank if the fit is good;
# split when the size check below allows it.
def algThree(model, index, lambda_plus, eigsToKeep, goodFit):
    """Keep the leading singular directions plus a random sample of the rest.

    The first ``computePredictedRank(model, index)`` singular values are
    always kept; up to ``eigsToKeep`` further ones are drawn at random
    (without replacement) from the remaining indices.  As in algZero the
    layer is either overwritten with the truncated product or replaced by
    two linear layers, in which case ``model.fc``, ``model.without_rel`` and
    ``model.dims`` are updated.

    ``lambda_plus`` is unused.  Returns ``(model, capChange)``.
    """
    if not goodFit:
        return model, 0

    dims = model.getDims()
    layerMatrix = model.getLayerMatrix(index)

    # layerMatrix is outputDim x inputDim.
    inputDim = dims[index]
    outputDim = dims[index + 1]

    (U, S, V) = np.linalg.svd(layerMatrix)
    effectiveRank = computePredictedRank(model, index)

    eigMask = list(range(effectiveRank))
    sampleDomain = list(range(effectiveRank, len(S)))
    eigsToKeepBelow = min(len(sampleDomain), eigsToKeep)
    eigMask.extend(random.sample(sampleDomain, eigsToKeepBelow))
    eigsToKeep = len(eigMask)

    # PyTorch computes x W^T + b, so with W = W2 W1 the input has to pass
    # through W1 first and W2 second: x W^T = x W1^T W2^T.
    root = np.sqrt(S[eigMask])
    w2 = U[:, eigMask] * root
    w1 = root[:, None] * V[eigMask, :]

    if inputDim * outputDim < eigsToKeep * (inputDim + outputDim):
        # Splitting would not reduce the parameter count: truncate in place.
        model.fc[index].weight = torch.nn.Parameter(torch.from_numpy(w2 @ w1).float())
        return model, 0

    old_bias = model.fc[index].bias

    # Constructor dimensions match the weights that are assigned, so
    # in_features / out_features describe the layers correctly.
    layerOne = nn.Linear(inputDim, eigsToKeep, bias=False)
    layerOne.weight = torch.nn.Parameter(torch.from_numpy(w1).float())

    layerTwo = nn.Linear(eigsToKeep, outputDim, bias=old_bias is not None)
    layerTwo.weight = torch.nn.Parameter(torch.from_numpy(w2).float())
    if old_bias is not None:
        # Wrap explicitly: a module only accepts a Parameter (or None) here.
        layerTwo.bias = torch.nn.Parameter(old_bias.detach().float())

    model.fc[index] = layerOne
    model.fc.insert(index + 1, layerTwo)
    model.without_rel.insert(index, True)
    model.dims.insert(index + 1, eigsToKeep)
    return model, 1

def algFour(model, index, lambda_plus, eigsToKeep, goodFit):
  """Convolutional counterpart of algZero operating on ``model.conv[index]``.

  Returns ``(model, capChange)``; nothing is changed when ``goodFit`` is
  false.  ``lambda_plus`` is not used.

  NOTE(review): the SVD is taken of the array returned by
  ``getLayerMatrixOriginal`` as-is.  The indexing below (``S[k][j][i%3]``)
  implies a 4-D array, in which case NumPy factorises every trailing 2-D
  slice separately - presumably each 3x3 kernel - rather than the flattened
  weight matrix.  The later slicing ``[:, :eigsToKeep]`` / ``[:eigsToKeep, :]``
  then cuts along the first two axes.  Confirm this is the intended
  factorisation before relying on it.
  """
  if not goodFit: return model, 0

  # NOTE(review): `dims` is not used below.
  dims = model.getDimsConv()
  layerMatrix = model.getLayerMatrixOriginal(index)
  #layerMatrix has size outputDim x inputDim
  #so layer 1 will have eigsToKeep x inputDim
  #layer 2 will have outputDim x eigsToKeep

  outputDim, inputDim  = layerMatrix.shape[:2]

  (U, S, V) = np.linalg.svd(layerMatrix)
  s = np.zeros(layerMatrix.shape)

  # Fill the diagonal of each [k][j] slice with square roots of its singular
  # values.  The index is taken modulo 3, so for eigsToKeep > 3 the same
  # three entries are simply rewritten (assumes 3x3 kernels - TODO confirm).
  for k in range(layerMatrix.shape[0]):
    for j in range(layerMatrix.shape[1]):
      for i in range(eigsToKeep):
        s[k][j][i%3][i%3] = S[k][j][i%3]**(1/2)

  #the flip here comes from pytorch not computing Wx + b
  #instead it computes x W^T + b
  #thus if W = W2 W1 then we want W1 then W2 so that
  #x W^T + b = x W1^T W2^T + b
  w2 = np.matmul(U,s)[:,:eigsToKeep]
  w1 = np.matmul(s,V)[:eigsToKeep,:]
  # NOTE(review): `withoutrel` is not used below.
  withoutrel = model.getWithout()[0]
  #only split if this lowers total param numbers
  capChange = 0
  if inputDim*outputDim < eigsToKeep * (inputDim + outputDim):
    # No split: overwrite the existing kernel with the product of the factors.
    model.conv[index].weight = torch.nn.Parameter(torch.from_numpy(w2 @ w1).float())
  else:
    capChange = 1

    # NOTE(review): the constructor channel counts here do not match the
    # order in which the layers are applied (layerOne is placed first but is
    # declared eigsToKeep -> outputDim); the weights assigned afterwards
    # decide the actual shapes.
    layerOne = nn.Conv2d(eigsToKeep, outputDim, 3, bias = False)
    layerOne.weight = torch.nn.Parameter(torch.from_numpy(w1).float())

    if model.conv[index].bias is not None:
      layerTwo = nn.Conv2d(inputDim, eigsToKeep, 3, bias = True)
      layerTwo.bias = model.conv[index].bias.float()

    else:
      layerTwo = nn.Conv2d(inputDim, eigsToKeep, 3, bias = False)

    layerTwo.weight = torch.nn.Parameter(torch.from_numpy(w2).float())
    # Replace the layer by the pair and record the new intermediate width.
    model.conv[index] = layerOne
    model.conv.insert(index+1, layerTwo)
    model.without_rel.insert(index, True)
    model.dims.insert(index+1, eigsToKeep)
  return model, capChange

The following block contains all the splitting code. It takes as input the model and which layer to be split, and returns the model with the selected layer split (according to the BEMA algorithm). splitWrapper should be called whenever you want to split the network, and it will apply the SVD algorithm to every layer.

The algorithm consists of the following. We find lambda+ and then keep X% of the singular values below lambda+ where the specific percentage is governed by our bema_scheduler. We split according to the square root trick when doing so reduces the total number of parameters of the network, with the caveat that we never split to a layer with size smaller than the end layer (to prevent information bottlenecks).

In [None]:
# Computes the transformed lambda_plus, eigsToKeep, and whether the fit is good.
def computeEigsToKeep(model, layerMatrix, dims, epoch, goodnessOfFitCutoff, show = False):
    """Decide how many singular values of ``layerMatrix`` to keep.

    ``bema_mat_wrapper`` (defined elsewhere in this file) supplies the
    eigenvalues, the fitted edge ``lambda_plus`` and the fit verdict.  All
    eigenvalues above the edge are kept together with the fraction
    ``bema_scheduler(epoch)`` of those at or below it, but never fewer than
    ``dims[-1]``.

    ``model`` is unused; it is kept for interface compatibility.

    Returns ``(lpTransformed, eigsToKeep, goodFit)`` where ``lpTransformed``
    is ``sqrt(lambda_plus) * sqrt(n)`` with ``n = layerMatrix.shape[1]``.
    """
    (p, n) = layerMatrix.shape

    eigs, gamma, sigma_sq, lambda_plus, goodFit = bema_mat_wrapper(
        layerMatrix, p, n, 0.2, 0.1, goodnessOfFitCutoff, show = show
    )

    # lt = number of eigenvalues that do not exceed lambda_plus.  The scan
    # stops at the first one above the edge, so `eigs` is assumed to be
    # sorted ascending (TODO confirm against bema_mat_wrapper).  The index of
    # that first eigenvalue is exactly the count of those before it; the
    # previous `i - 1` undercounted by one and could request more singular
    # values than exist when every eigenvalue lies above the edge.
    lt = len(eigs)
    for i, eig in enumerate(eigs):
        if eig > lambda_plus:
            lt = i
            break

    # gt = number of eigenvalues greater than lambda_plus.
    gt = len(eigs) - lt
    kept = gt + int(bema_scheduler(epoch) * lt)

    # Sanity check that stops the network from pruning itself below the
    # output dimension.
    eigsToKeep = max(kept, dims[-1])
    lpTransformed = np.sqrt(lambda_plus) * np.sqrt(n)

    return lpTransformed, eigsToKeep, goodFit

# Computes the rank predicted by BEMA.
def computePredictedRank(model, index, goodnessOfFitCutoff = 0.01, show = False):
    """Count the eigenvalues of layer ``index`` that exceed ``lambda_plus``.

    The eigenvalues and ``lambda_plus`` come from ``bema_mat_wrapper``.  With
    ``show=True`` a histogram of the eigenvalues is drawn with a vertical
    line at ``lambda_plus``.
    """
    model.getDims()
    layerMatrix = model.getLayerMatrix(index)
    rows, cols = layerMatrix.shape

    eigs, gamma, sigma_sq, lambda_plus, goodFit = bema_mat_wrapper(
        layerMatrix, rows, cols, 0.25, 0.1, goodnessOfFitCutoff
    )
    above_edge = sum(1 for value in eigs if value > lambda_plus)

    if show:
        plt.hist(eigs, bins = 200)
        plt.axvline(x = lambda_plus)
        plt.show()
    return above_edge


# Call splitWrapper whenever the split algorithm should be applied.
def splitWrapper(model, device, epoch, goodnessOfFitCutoff, show = False, alg = [0]):
    """Apply the selected splitting algorithm(s) to every large layer.

    ``alg`` lists algorithm ids; ``goodnessOfFitCutoff[k]`` is the cutoff
    used for ``alg[k]``.  Id 0 walks the fully connected layers, id 4 the
    convolutional ones and id -1 is skipped.  Only layers with more than
    300000 weight entries are processed.  The model is moved back to
    ``device`` after each processed layer and returned at the end.

    Raises:
        ValueError: for an id this wrapper has no layer accessors for.
    """
    print("cutting. at this epoch we keep", bema_scheduler(epoch))
    for k, alg_id in enumerate(alg):
        if alg_id == -1:
            continue
        if alg_id == 0:
            dims_getter, matrix_getter = "getDimsFC", "getLayerMatrixFC"
        elif alg_id == 4:
            dims_getter, matrix_getter = "getDimsConv", "getLayerMatrixConv"
        else:
            # Previously `cap` was left unbound (or stale from an earlier
            # pass) for these ids.
            raise ValueError(f"splitWrapper does not support algorithm id {alg_id!r}")

        cap = getattr(model, dims_getter)()
        # layersAdded shifts the index past layers inserted by a split.
        layersAdded = 0
        for j in range(len(cap) - 1):
            position = j + layersAdded
            layer = getattr(model, matrix_getter)(position)
            if np.prod(layer.shape) <= 300000:
                continue
            # Pass the matrix itself: computeEigsToKeep reads its shape.
            lpTransformed, eigsToKeep, goodFit = computeEigsToKeep(
                model, layer, cap, epoch, goodnessOfFitCutoff[k], show = show
            )
            model, capChange = innerAlgWrapper(
                model, position, lpTransformed, eigsToKeep, goodFit, alg = alg_id
            )
            model.to(device)
            layersAdded += capChange
    return model

In [None]:
def getMeanAndVar(yData):
    """Return the element-wise mean and (population) variance across runs.

    ``yData`` is a non-empty sequence of equally long numeric sequences, one
    per run.  The variance is computed as ``E[x^2] - E[x]^2`` and clamped at
    zero: rounding can make that difference slightly negative for (nearly)
    identical runs, which would turn into NaN once a square root is taken.
    """
    length = len(yData[0])
    mean = np.zeros(length)
    squaredMean = np.zeros(length)

    for run in yData:
        mean += run
        squaredMean += np.square(run)
    mean = mean / len(yData)
    squaredMean = squaredMean / len(yData)

    variance = np.maximum(squaredMean - np.square(mean), 0.0)
    return mean, variance

def graphData(xData, yData, title = None):
    """Plot the mean accuracy across runs with a one-standard-deviation band.

    ``yData`` holds one accuracy series per run; ``title`` is optional.
    """
    mean, variance = getMeanAndVar(yData)
    std = np.sqrt(variance)
    lower = mean - std
    upper = mean + std

    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    if title != None:
        plt.title(title)

    plt.plot(xData, mean, label = "Mean", color = "black")
    plt.plot(xData, lower, "--", label = "Mean - 1std", color = "red")
    plt.plot(xData, upper, "--", label = "Mean + 1std", color = "blue")
    # Shade both halves of the band around the mean.
    for edge in (upper, lower):
        plt.fill_between(xData, mean, edge, color = "grey", alpha = 0.5)
    plt.legend()
    plt.show()


def graphDataBoth(xData, yData0, yData1):
    """Plot mean and one-standard-deviation bands of two experiments together.

    ``yData0`` is labelled "Normal" (black) and ``yData1`` "Pruned" (red).
    """
    plt.title("Normal Vs Pruned Comparison")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")

    series = zip((yData0, yData1), ("Normal", "Pruned"), ("black", "red"))
    for runs, tag, colour in series:
        mean, variance = getMeanAndVar(runs)
        std = np.sqrt(variance)
        lower = mean - std
        upper = mean + std
        plt.plot(xData, mean, label = f"Mean {tag}", color = colour)
        plt.plot(xData, lower, "--", label = f"Mean - 1std {tag}", color = colour)
        plt.plot(xData, upper, "--", label = f"Mean + 1std {tag}", color = colour)
        for edge in (upper, lower):
            plt.fill_between(xData, mean, edge, color = "grey", alpha = 0.5)
    plt.legend()
    plt.show()

The following block is for trying out the various algorithms and adjusting the hyperparameters; it compares the accuracy of our modified algorithm with that of a stock SGD training run.

In [None]:
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
import copy
import os
from google.colab import drive

# Mount Google Drive (Colab only) so checkpoints survive the session.
drive.mount('/content/drive')

# Directory the training loop below writes checkpoints to.
save_dir = "/content/drive/MyDrive/DNN_models_5/"
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

# Training hyperparameters
trainCycles = 1000  # number of epochs per training loop
seed = 4435912
lr = 0.02
m = 50  # number of pruning thresholds swept after training (1/(1+k), k < m)
momentum = 0.9
batchSize = 128
splitFreq = 13  # layers are offered for splitting every splitFreq epochs
showTrainingLoss = False
showSpectrumDuringSplit = False
goodnessOfFitCutoff = [1]

# Network topology setup
dims = [28*28, 3000, 3000, 3000, 10]
# Per-layer activation flags.  NOTE(review): in networkModel.forward a True
# entry means "skip the ReLU after this layer"; presumably
# networkModelModified reads them the same way - confirm.  The list has
# len(dims) entries, one more than there are layers.
no_rel = [False for i in range(len(dims)-2)]+[True,True]

def count_total_params(model):
    """Return the total number of scalar entries over all parameters."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

def naive_prune(layer, threshold):
    """Zero, in place, every weight of ``layer`` with magnitude <= ``threshold``."""
    with torch.no_grad():
        keep = layer.weight.abs() > threshold
        layer.weight.mul_(keep.float())

def count_nonzero_params(model):
    """Return how many parameter entries of ``model`` are non-zero."""
    return sum(torch.count_nonzero(param).item() for param in model.parameters())

# New parameters
trialRuns = 1  # number of independent repetitions of the whole experiment

if __name__ == '__main__':
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # One accuracy series per trial run; "Stock" is the baseline model and
    # "Modified" the model whose layers are split during training.
    yTestDataStock = []
    yTestDataModified = []
    yTrainDataStock = []
    yTrainDataModified = []

    for trainCycle in range(trialRuns):
        seed = seed + trainCycle  # Adjust the seed calculation
        torch.manual_seed(seed)
        print("Seed:", seed)  # Print the seed value

        x, y_test, y_test2, y_train, y_train2 = [], [], [], [], []

        torch.manual_seed(seed)  # Ensure `model2` gets the same initial weights as `model`
        print("Seed:", seed)  # Print the seed value

        # Baseline: a cutoff of 0 means no layer ever passes the fit test.
        model = networkModelModified(no_rel, dims, alpha=0.25, beta=0.9, goodnessOfFitCutoff=[0])
        model.to(device)
        (args, model, optimizer, test_loader, train_loader) = neuralInit(seed, device, model, lr, momentum, batchSize)
        optimizer = optim.SGD(model.parameters(), lr=args['lr'], momentum=args['momentum'])

        # Model whose layers are split during training.
        model2 = networkModelModified(no_rel, dims, alpha=0.25, beta=0.1, goodnessOfFitCutoff=goodnessOfFitCutoff)
        model2.to(device)
        (args, model2, optimizer2, test_loader, train_loader) = neuralInit(seed, device, model2, lr, momentum, batchSize)
        optimizer2 = optim.SGD(model2.parameters(), lr=args['lr'], momentum=args['momentum'])

        num_params_unpruned = sum([p.numel() for p in model2.parameters() if p.requires_grad])
        accuracies = []
        num_params_pruned = []
        threshold_values = [1/(1+k) for k in range(m)]  # Example thresholds

        for epoch in range(1, trainCycles + 1):
            # Update learning rate and split at specified frequency
            lr *= 0.96
            if epoch % splitFreq == 0:
                for i, layer in enumerate(model2.fc):
                    print(layer.split(bema_scheduler(epoch), save_name=f'layer_{i}_epoch_{epoch}', show=showSpectrumDuringSplit))
                    print("Cutting. At this epoch we keep", bema_scheduler(epoch))
                # Splitting replaces parameters, so the optimizer is rebuilt.
                optimizer2 = optim.SGD(model2.parameters(), lr=lr, momentum=args['momentum'])

            # Training and testing
            y_train2.append(train(args, model2, device, train_loader, optimizer2, epoch, showTrainingLoss))
            y_test2.append(test(args, model2, device, test_loader))
            print("Number of parameters of DNN", sum([p.numel() for p in model2.parameters() if p.requires_grad]))

            if epoch % 10 == 0:
                # NOTE(review): this checkpoints `model`, which has not been
                # trained yet at this point; `model2` may have been intended.
                torch.save(model.state_dict(), os.path.join(save_dir, f'model_epoch_{epoch}.pth'))

        # Sweep magnitude-pruning thresholds on copies of the trained model.
        for threshold in threshold_values:
            # Clone the model for each threshold
            model_pruned = copy.deepcopy(model2)

            # Apply naive pruning with current threshold
            for module in model_pruned.modules():
                if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
                    naive_prune(module, threshold)

            # Calculate the number of nonzero parameters after pruning
            num_nonzero = count_nonzero_params(model_pruned)
            num_params_pruned.append(num_nonzero)

            # Evaluate the model's accuracy after pruning
            accuracy = test(args, model_pruned, device, test_loader)
            accuracies.append(accuracy)

        # Calculate and print the number of nonzero parameters and the model's accuracy
        num_nonzero = count_nonzero_params(model2)
        print("Number of nonzero parameters after pruning:", num_nonzero)

        # Assuming 'test' function returns accuracy, adjust if it returns different metrics
        accuracy = test(args, model2, device, test_loader)
        print("Model accuracy after pruning:", accuracy)

        # Calculate percentages of parameters kept
        params_kept_percentages = [100 * num / num_params_unpruned for num in num_params_pruned]

        # Plotting accuracy vs number of parameters kept, only showing accuracies >= 90%
        plt.figure(figsize=(10, 6))
        for num, acc, pct in zip(num_params_pruned, accuracies, params_kept_percentages):
            if acc >= 90:  # Only plot if accuracy is >= 90%
                plt.scatter(num, acc, marker='o')
                plt.annotate(f"{pct:.1f}%", (num, acc), textcoords="offset points", xytext=(0, 10), ha='center')

        plt.xlabel("Number of Parameters Kept")
        plt.ylabel("Test Set Accuracy")
        plt.title("Test Set Accuracy vs Number of Parameters Kept (Accuracies >= 90%)")
        plt.grid(True)
        plt.show()

        # We now train normally, setting GoF to 0
        for epoch in range(1, trainCycles + 1):
            # NOTE(review): `lr` keeps decaying from where the first loop
            # left off, but the baseline optimizer is built from args['lr'].
            lr = .96 * lr
            if epoch % splitFreq == 0:
                for i, layer in enumerate(model.fc):
                    print(layer.split(bema_scheduler(epoch), save_name=f'layer_{i}_epoch_{epoch}', show=showSpectrumDuringSplit))
                optimizer = optim.SGD(model.parameters(), lr=args['lr'], momentum=args['momentum'])

            # The baseline must be stepped by its own optimizer; reusing
            # optimizer2 would update model2's parameters and leave `model`
            # untrained until the first rebuild.
            y_train.append(train(args, model, device, train_loader, optimizer, epoch, showTrainingLoss))
            y_test.append(test(args, model, device, test_loader))
            print("Number of parameters of DNN", sum([p.numel() for p in model.parameters() if p.requires_grad]))

        yTestDataStock.append(y_test)
        yTestDataModified.append(y_test2)
        yTrainDataStock.append(y_train)
        yTrainDataModified.append(y_train2)

    xData = [i for i in range(len(yTestDataStock[0]))]

    plt.figure()
    plt.plot(xData, yTrainDataStock[0], label="Training Accuracy")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.title("Training Accuracy vs Epochs")
    plt.legend()
    plt.grid(True)
    plt.show()

    # Filter epochs and accuracies together so both axes keep equal length.
    shown = [(step, acc) for step, acc in zip(xData, yTestDataStock[0]) if acc >= 90]
    plt.figure()
    plt.plot([step for step, _ in shown], [acc for _, acc in shown], label="Test Accuracy (>= 90%)")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.title("Test Accuracy (>= 90%) vs Epochs")
    plt.legend()
    plt.grid(True)
    plt.show()

    graphData(xData, yTestDataStock, title="Normal Test")
    graphData(xData, yTestDataModified, title="Pruned Test")
    graphData(xData, yTrainDataStock, title="Normal Train")
    graphData(xData, yTrainDataModified, title="Pruned Train")
    graphDataBoth(xData, yTestDataStock, yTestDataModified)
    graphDataBoth(xData, yTrainDataStock, yTrainDataModified)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Seed: 4435912
Seed: 4435912
Train set: Average loss: 2.1360, Accuracy: 49913/60000 (83%)
test set: Average loss: 50.631303, Accuracy: 8548/10000 (85%)
Number of parameters of DNN 20391010
Train set: Average loss: 1.9563, Accuracy: 52685/60000 (88%)
test set: Average loss: 44.784083, Accuracy: 8712/10000 (87%)
Number of parameters of DNN 20391010
Train set: Average loss: 1.8707, Accuracy: 53535/60000 (89%)
test set: Average loss: 42.377631, Accuracy: 8785/10000 (88%)
Number of parameters of DNN 20391010
Train set: Average loss: 1.7950, Accuracy: 54189/60000 (90%)
test set: Average loss: 46.568660, Accuracy: 8694/10000 (87%)
Number of parameters of DNN 20391010
Train set: Average loss: 1.7278, Accuracy: 54802/60000 (91%)
test set: Average loss: 43.390534, Accuracy: 8749/10000 (87%)
Number of parameters of DNN 20391010
Train set: Average loss: 1.6667, Accuracy: 

In [None]:
# Training hyperparameters for the in-training magnitude-pruning experiment.
trainCycles = 20  # number of epochs
seed = 4434543598112
lr = 0.01
m=20  # length of the threshold list 1/(1+k), k < m, built below
momentum = 0.9
batchSize = 128
splitFreq = 4  # layers are offered for splitting every splitFreq epochs
showTrainingLoss = False
showSpectrumDuringSplit = False
goodnessOfFitCutoff = [1]

# Network topology setup
dims = [28*28, 1000, 10]
# Per-layer activation flags, all False here.  NOTE(review): in
# networkModel.forward False means the ReLU IS applied; presumably
# networkModelModified reads them the same way - confirm.
no_rel = [False for i in range(len(dims))]



def count_total_params(model):
    """Return the total number of scalar entries over all parameters."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total


def naive_prune(layer, threshold):
    """Zero, in place, every weight of ``layer`` with magnitude <= ``threshold``."""
    with torch.no_grad():
        keep = layer.weight.abs() > threshold
        layer.weight.mul_(keep.float())

def count_nonzero_params(model):
    """Return how many parameter entries of ``model`` are non-zero."""
    return sum(torch.count_nonzero(param).item() for param in model.parameters())


# NOTE(review): `alg` is not referenced by the code in this cell.
alg = [0]

#new parameters
trialRuns = 1

# NOTE(review): only the device selection is inside the __main__ guard; the
# rest of the cell runs unconditionally and needs `device` to be defined.
if __name__ == '__main__':
  use_cuda = not False and torch.cuda.is_available()
  device = torch.device("cuda:0" if use_cuda else "cpu:1")

# Accuracy series per trial run; nothing is appended to them in this cell.
yTestDataStock = []
yTestDataModified = []
yTrainDataStock = []
yTrainDataModified = []

for trainCycle in range(trialRuns):

  seed = seed + trainCycle  # Adjust the seed calculation
  torch.manual_seed(seed)
  print("Seed:", seed)  # Print the seed value




  x, y_test, y_test2, y_train, y_train2 = [], [], [], [], []

  torch.manual_seed(seed)  # Ensure `model2` gets the same initial weights as `model`
  print("Seed:", seed)  # Print the seed value

  # Baseline model.  It is built and given an optimizer but is not trained
  # anywhere in this cell.
  model = networkModelModified(no_rel, dims, alpha=0.25, beta=0.9, goodnessOfFitCutoff=[0])

  model.to(device)

  (args, model,optimizer, test_loader, train_loader) = neuralInit(seed, device, model, lr, momentum, batchSize)

  optimizer = optim.SGD(model.parameters(), lr=args['lr'], momentum=args['momentum'])

  # Model that is split and magnitude-pruned during training.
  model2 = networkModelModified(no_rel, dims, alpha=0.25, beta=0.1, goodnessOfFitCutoff=goodnessOfFitCutoff)

  model2.to(device)

  # NOTE(review): this rebinds `optimizer` (the baseline's optimizer created
  # above) to whatever neuralInit returns for model2.
  (args, model2,optimizer, test_loader, train_loader) = neuralInit(seed, device, model2, lr, momentum, batchSize)

  optimizer2 = optim.SGD(model2.parameters(), lr=args['lr'], momentum=args['momentum'])

  num_params_unpruned = sum([p.numel() for p in model2.parameters() if p.requires_grad])
  # NOTE(review): the three lists below are not used again in this cell.
  accuracies = []
  num_params_pruned = []
  threshold_values = [1/(1+k) for k in range(m)]  # Example thresholds

  l = 5  # Define the frequency of pruning, every 'l' epochs
  initial_threshold = 0.001
  max_threshold = 0.02  # Maximum threshold value
  threshold_increment = (max_threshold - initial_threshold) / trainCycles  # Increment per epoch

  # Function to calculate the pruning threshold based on the current epoch
  def calculate_pruning_threshold(epoch, initial_threshold=0.01, increment=0.01):
      """Return the linearly growing threshold for ``epoch``, capped at max_threshold.

      ``max_threshold`` is read from the enclosing scope.  The defaults differ
      from the values computed above; the call below passes both explicitly.
      """
      return min(initial_threshold + increment * epoch, max_threshold)

  # Training loop
  for epoch in range(1, trainCycles + 1):
      # Update learning rate and split at specified frequency
      # NOTE(review): `lr` is decayed here, but the optimizer below is rebuilt
      # with args['lr']; whether args['lr'] follows `lr` is not visible here.
      lr *= 0.96
      # `epoch != 0` always holds because the range starts at 1.
      if epoch % splitFreq == 0 and epoch != 0:
          for i, layer in enumerate(model2.fc):
              print(layer.split(bema_scheduler(epoch), save_name=f'layer_{i}_epoch_{epoch}', show=showSpectrumDuringSplit))
              print("Cutting. At this epoch we keep", bema_scheduler(epoch))
          # Rebuild the optimizer over model2's parameters after the split pass.
          optimizer2 = optim.SGD(model2.parameters(), lr=args['lr'], momentum=args['momentum'])

      # Training and testing
      y_train2.append(train(args, model2, device, train_loader, optimizer2, epoch, showTrainingLoss))
      y_test2.append(test(args, model2, device, test_loader))

      # Prune every 'l' epochs
      if epoch % l == 0:
          threshold = calculate_pruning_threshold(epoch, initial_threshold, threshold_increment)
          # model2.modules() yields nested sub-modules too, so Linear/Conv2d
          # layers inside wrapper modules are reached as well.
          for module in model2.modules():
              if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
                  naive_prune(module, threshold)

          # Calculate and print the number of nonzero parameters after pruning
          num_nonzero = count_nonzero_params(model2)
          print(f"Epoch {epoch}: Pruned with threshold {threshold}. Nonzero parameters: {num_nonzero}")

          # Evaluate the model's accuracy after pruning
          accuracy = test(args, model2, device, test_loader)
          print(f"Model accuracy after pruning at Epoch {epoch}: {accuracy}")

      print("Number of parameters of DNN", sum([p.numel() for p in model2.parameters() if p.requires_grad]))



In [None]:
# Training hyperparameters for the fully connected experiment.
trainCycles = 20  # upper bound of the epoch range used by the trial loop
seed = 4434543598112
lr = 0.02
momentum = 0.9
batchSize = 128
m = 50  # number of naive-pruning thresholds swept after training
splitFreq = 5  # layers are split whenever epoch % splitFreq == 0
showTrainingLoss = False
showSpectrumDuringSplit = False
goodnessOfFitCutoff = [1]

# Network topology: input size, two hidden widths, output size.
dims = [28 * 28, 1000, 1000, 10]
# Per-layer flags handed to the model as its first argument; the original
# comment labels them "Layers without relu". The list has len(dims) entries.
no_rel = [False] * len(dims)

alg = [0]  # not referenced in this cell -- presumably read elsewhere; verify

# Number of independent trials run by the loop below.
trialRuns = 1

# Bind the device unconditionally: the trial loop uses `device` regardless of
# how this code is executed, so hiding it behind a __main__ guard would leave
# it undefined on import. The CPU fallback is plain "cpu" (no device index).
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

# Per-trial result containers (this cell only initialises them).
yTestDataStock = []
yTestDataModified = []
yTrainDataStock = []
yTrainDataModified = []

for trainCycle in range(trialRuns):
  # One trial: build a baseline and a split-pruned MLP, train the latter,
  # then sweep naive-pruning thresholds over copies of the trained model and
  # plot accuracy against the number of surviving parameters.

  seed = seed + trainCycle  # offset the running seed by the trial index
  torch.manual_seed(seed)
  print("Seed:", seed)

  x, y_test, y_test2, y_train, y_train2 = [], [], [], [], []

  # Baseline model: constructed and initialised here, but not trained in this cell.
  model = networkModelModified(no_rel, dims, alpha=0.25, beta=0.1, goodnessOfFitCutoff=[0])
  model.to(device)
  (args, model, optimizer, test_loader, train_loader) = neuralInit(seed, device, model, lr, momentum, batchSize)
  optimizer = optim.SGD(model.parameters(), lr=args['lr'], momentum=args['momentum'])

  # Model that is trained with periodic layer splitting.
  model2 = networkModelModified(no_rel, dims, alpha=0.25, beta=0.9, goodnessOfFitCutoff=goodnessOfFitCutoff)
  model2.to(device)
  # NOTE: this call rebinds `optimizer` to the value neuralInit returns for model2.
  (args, model2, optimizer, test_loader, train_loader) = neuralInit(seed, device, model2, lr, momentum, batchSize)
  optimizer2 = optim.SGD(model2.parameters(), lr=args['lr'], momentum=args['momentum'])

  num_params_unpruned = sum([p.numel() for p in model2.parameters() if p.requires_grad])
  accuracies = []
  num_params_pruned = []
  threshold_values = [1/(1+k) for k in range(m)]  # 1, 1/2, ..., 1/m

  # NOTE(review): range(1, trainCycles) runs trainCycles - 1 epochs -- confirm intended.
  for epoch in range(1, trainCycles):
    # Split every splitFreq epochs (epoch starts at 1, so no zero check is needed).
    if epoch % splitFreq == 0:
      lr = .96*lr  # decays the module-level lr; the new optimizer below uses it
      for i, layer in enumerate(model2.fc):
        print(layer.split(bema_scheduler(epoch), save_name=f'layer_{i}_epoch_{epoch}', show=showSpectrumDuringSplit))
        print("cutting. at this epoch we keep", bema_scheduler(epoch))
      # Rebuild the optimizer over model2's current parameters after the split.
      optimizer2 = optim.SGD(model2.parameters(), lr, momentum=args['momentum'])

    y_train2.append(train(args, model2, device, train_loader, optimizer2, epoch, showTrainingLoss))
    y_test2.append(test(args, model2, device, test_loader))
    print("Numer of parameters of DNN", sum([p.numel() for p in model2.parameters() if p.requires_grad]))

  for threshold in threshold_values:
    # Work on a copy so every threshold starts from the same trained weights.
    model_pruned = copy.deepcopy(model2)

    # Apply naive pruning with the current threshold to every conv/linear module.
    for module in model_pruned.modules():
      if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        naive_prune(module, threshold)

    num_params_pruned.append(count_nonzero_params(model_pruned))
    accuracies.append(test(args, model_pruned, device, test_loader))

  # These figures describe model2 itself; naive pruning above only touched copies.
  num_nonzero = count_nonzero_params(model2)
  print("Number of nonzero parameters after pruning:", num_nonzero)

  accuracy = test(args, model2, device, test_loader)
  print("Model accuracy after pruning:", accuracy)

  # Percentage of parameters kept per threshold, computed once after the sweep
  # instead of being rebuilt on every sweep iteration.
  params_kept_percentages = [100 * num / num_params_unpruned for num in num_params_pruned]

  # Plot accuracy vs number of parameters kept, showing only points at or
  # above min_accuracy (the same value is interpolated into the title).
  min_accuracy = 85
  plt.figure(figsize=(10, 6))
  for num, acc, pct in zip(num_params_pruned, accuracies, params_kept_percentages):
    if acc >= min_accuracy:
      plt.scatter(num, acc, marker='o')
      plt.annotate(f"{pct:.1f}%", (num, acc), textcoords="offset points", xytext=(0,10), ha='center')

  plt.xlabel("Number of Parameters Kept")
  plt.ylabel("Test Set Accuracy")
  plt.title(f"Test Set Accuracy vs Number of Parameters Kept (Accuracies >= {min_accuracy}%)")
  plt.grid(True)
  plt.show()



In [None]:
# Training hyperparameters for the fully connected experiment.
trainCycles = 20  # upper bound of the epoch range used by the trial loop
seed = 4434543598112
lr = 0.02
momentum = 0.9
batchSize = 128
m = 50  # number of naive-pruning thresholds swept after training
splitFreq = 5  # layers are split whenever epoch % splitFreq == 0
showTrainingLoss = False
showSpectrumDuringSplit = False
goodnessOfFitCutoff = [0]

# Network topology: input size, two hidden widths, output size.
dims = [28 * 28, 1000, 1000, 10]
# Per-layer flags handed to the model as its first argument; the original
# comment labels them "Layers without relu". The list has len(dims) entries.
no_rel = [False] * len(dims)

alg = [0]  # not referenced in this cell -- presumably read elsewhere; verify

# Number of independent trials run by the loop below.
trialRuns = 1

# Bind the device unconditionally: the trial loop uses `device` regardless of
# how this code is executed, so hiding it behind a __main__ guard would leave
# it undefined on import. The CPU fallback is plain "cpu" (no device index).
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

# Per-trial result containers (this cell only initialises them).
yTestDataStock = []
yTestDataModified = []
yTrainDataStock = []
yTrainDataModified = []

for trainCycle in range(trialRuns):
  # One trial: build a baseline and a split-pruned MLP, train the latter,
  # then sweep naive-pruning thresholds over copies of the trained model and
  # plot accuracy against the number of surviving parameters.

  seed = seed + trainCycle  # offset the running seed by the trial index
  torch.manual_seed(seed)
  print("Seed:", seed)

  x, y_test, y_test2, y_train, y_train2 = [], [], [], [], []

  # Baseline model: constructed and initialised here, but not trained in this cell.
  model = networkModelModified(no_rel, dims, alpha=0.25, beta=0.1, goodnessOfFitCutoff=[0])
  model.to(device)
  (args, model, optimizer, test_loader, train_loader) = neuralInit(seed, device, model, lr, momentum, batchSize)
  optimizer = optim.SGD(model.parameters(), lr=args['lr'], momentum=args['momentum'])

  # Model that is trained with periodic layer splitting.
  model2 = networkModelModified(no_rel, dims, alpha=0.25, beta=0.9, goodnessOfFitCutoff=goodnessOfFitCutoff)
  model2.to(device)
  # NOTE: this call rebinds `optimizer` to the value neuralInit returns for model2.
  (args, model2, optimizer, test_loader, train_loader) = neuralInit(seed, device, model2, lr, momentum, batchSize)
  optimizer2 = optim.SGD(model2.parameters(), lr=args['lr'], momentum=args['momentum'])

  num_params_unpruned = sum([p.numel() for p in model2.parameters() if p.requires_grad])
  accuracies = []
  num_params_pruned = []
  threshold_values = [1/(1+k) for k in range(m)]  # 1, 1/2, ..., 1/m

  # NOTE(review): range(1, trainCycles) runs trainCycles - 1 epochs -- confirm intended.
  for epoch in range(1, trainCycles):
    # Split every splitFreq epochs (epoch starts at 1, so no zero check is needed).
    if epoch % splitFreq == 0:
      lr = .96*lr  # decays the module-level lr; the new optimizer below uses it
      for i, layer in enumerate(model2.fc):
        print(layer.split(bema_scheduler(epoch), save_name=f'layer_{i}_epoch_{epoch}', show=showSpectrumDuringSplit))
        print("cutting. at this epoch we keep", bema_scheduler(epoch))
      # Rebuild the optimizer over model2's current parameters after the split.
      optimizer2 = optim.SGD(model2.parameters(), lr, momentum=args['momentum'])

    y_train2.append(train(args, model2, device, train_loader, optimizer2, epoch, showTrainingLoss))
    y_test2.append(test(args, model2, device, test_loader))
    print("Numer of parameters of DNN", sum([p.numel() for p in model2.parameters() if p.requires_grad]))

  for threshold in threshold_values:
    # Work on a copy so every threshold starts from the same trained weights.
    model_pruned = copy.deepcopy(model2)

    # Apply naive pruning with the current threshold to every conv/linear module.
    for module in model_pruned.modules():
      if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        naive_prune(module, threshold)

    num_params_pruned.append(count_nonzero_params(model_pruned))
    accuracies.append(test(args, model_pruned, device, test_loader))

  # These figures describe model2 itself; naive pruning above only touched copies.
  num_nonzero = count_nonzero_params(model2)
  print("Number of nonzero parameters after pruning:", num_nonzero)

  accuracy = test(args, model2, device, test_loader)
  print("Model accuracy after pruning:", accuracy)

  # Percentage of parameters kept per threshold, computed once after the sweep
  # instead of being rebuilt on every sweep iteration.
  params_kept_percentages = [100 * num / num_params_unpruned for num in num_params_pruned]

  # Plot accuracy vs number of parameters kept, showing only points at or
  # above min_accuracy. The title is derived from the same value so it can no
  # longer disagree with the filter (it previously claimed 85% while filtering at 70).
  min_accuracy = 70
  plt.figure(figsize=(10, 6))
  for num, acc, pct in zip(num_params_pruned, accuracies, params_kept_percentages):
    if acc >= min_accuracy:
      plt.scatter(num, acc, marker='o')
      plt.annotate(f"{pct:.1f}%", (num, acc), textcoords="offset points", xytext=(0,10), ha='center')

  plt.xlabel("Number of Parameters Kept")
  plt.ylabel("Test Set Accuracy")
  plt.title(f"Test Set Accuracy vs Number of Parameters Kept (Accuracies >= {min_accuracy}%)")
  plt.grid(True)
  plt.show()



In [None]:
def graphDataBothTest(xData, yData0, yData1):
  """Plot test-set accuracy of the normal and pruned networks on one figure.

  For each of ``yData0`` ("Normal", black) and ``yData1`` ("Pruned", red) the
  mean and variance come from ``getMeanAndVar``; the mean is drawn solid, the
  mean -/+ one standard deviation dashed, and the band between them is
  shaded grey. The figure is shown before returning.
  """
  plt.title("Accuracy on Test Set of Normal Vs Pruned DNNs")
  plt.xlabel("Epoch")
  plt.ylabel("Accuracy")
  series = zip((yData0, yData1), ("Normal", "Pruned"), ("black", "red"))
  for runs, tag, shade in series:
    center, var = getMeanAndVar(runs)
    spread = np.sqrt(var)
    lower = center - spread
    upper = center + spread
    plt.plot(xData, center, label="Mean " + tag, color=shade)
    plt.plot(xData, lower, "--", label="Mean - 1std " + tag, color=shade)
    plt.plot(xData, upper, "--", label="Mean + 1std " + tag, color=shade)
    plt.fill_between(xData, center, upper, color="grey", alpha=0.5)
    plt.fill_between(xData, center, lower, color="grey", alpha=0.5)
  plt.legend()
  plt.show()

def graphDataBothTrain(xData, yData0, yData1):
  """Plot training-set accuracy of the normal and pruned networks on one figure.

  For each of ``yData0`` ("Normal", black) and ``yData1`` ("Pruned", red) the
  mean and variance come from ``getMeanAndVar``; the mean is drawn solid, the
  mean -/+ one standard deviation dashed, and the band between them is
  shaded grey. The figure is shown before returning.
  """
  plt.title("Accuracy on Training Set of Normal Vs Pruned DNNs")
  plt.xlabel("Epoch")
  plt.ylabel("Accuracy")
  series = zip((yData0, yData1), ("Normal", "Pruned"), ("black", "red"))
  for runs, tag, shade in series:
    center, var = getMeanAndVar(runs)
    spread = np.sqrt(var)
    lower = center - spread
    upper = center + spread
    plt.plot(xData, center, label="Mean " + tag, color=shade)
    plt.plot(xData, lower, "--", label="Mean - 1std " + tag, color=shade)
    plt.plot(xData, upper, "--", label="Mean + 1std " + tag, color=shade)
    plt.fill_between(xData, center, upper, color="grey", alpha=0.5)
    plt.fill_between(xData, center, lower, color="grey", alpha=0.5)
  plt.legend()
  plt.show()

# Draw the per-run and combined accuracy curves. The x axis is sized from the
# first recorded run, so skip plotting when no run has been recorded: indexing
# yTestDataStock[0] on an empty list would raise IndexError.
if yTestDataStock:
  xData = list(range(len(yTestDataStock[0])))
  graphData(xData, yTestDataStock, title = "Normal Test")
  graphData(xData, yTestDataModified, title = "Pruned Test")
  graphData(xData, yTrainDataStock, title = "Normal Train")
  graphData(xData, yTrainDataModified, title = "Pruned Train")
  graphDataBothTest(xData, yTestDataStock, yTestDataModified)
  graphDataBothTrain(xData, yTrainDataStock, yTrainDataModified)
else:
  xData = []
  print("No recorded runs to plot: yTestDataStock is empty.")

In [None]:
# Training hyperparameters for the convolutional experiment.
trainCycles = 30  # upper bound of the epoch range used by the trial loop
seed = 44345434
lr = 0.02
lr1 = lr  # copy of the initial learning rate
momentum = .9
batchSize = 128
splitFreq = 11  # layers are split whenever epoch % splitFreq == 0
m = 50  # number of naive-pruning thresholds swept after training
showTrainingLoss = False
showSpectrumDuringSplit = False
goodnessOfFitCutoff = [.6, .05]

# Network topology setup.
conv_dims = [1, 32, 64, 128, 256, 512]  # article [1, 64, 128, 256, 512]
original_conv_dims = conv_dims  # alias of the same list object, not a copy

fc_dims = [4608, 1000, 10]
original_fc_dims = fc_dims  # alias of the same list object, not a copy


#dims = [conv_dims, fc_dims]
#original_dims = dims

# Per-layer flags passed to CNNModelModified as `no_rel`: one list for the
# conv stack and one (`c`) for the fully connected stack. The original comment
# labels them "Layers with relu" -- TODO confirm polarity against the model.
c = [True] + [False] * (len(fc_dims) - 1)
no_rel = [[True] * len(conv_dims), c]

alg = [4, 0]  # not referenced in this cell -- presumably read elsewhere; verify

# Number of independent trials run by the loop below.
trialRuns = 1

# Bind the device unconditionally: the trial loop uses `device` regardless of
# how this code is executed, so hiding it behind a __main__ guard would leave
# it undefined on import. The CPU fallback is plain "cpu" (no device index).
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

# Per-trial result containers (this cell only initialises them).
yTestDataStock = []
yTestDataModified = []
yTrainDataStock = []
yTrainDataModified = []



for trainCycle in range(trialRuns):
  # One trial: build a baseline and a split-pruned CNN, train the latter,
  # then sweep naive-pruning thresholds over copies of the trained model and
  # plot accuracy against the number of surviving parameters.
  seed = seed + trainCycle  # offset the running seed by the trial index
  torch.manual_seed(seed)
  print("Seed:", seed)

  # Baseline model: constructed and initialised here, but not trained in this cell.
  model = CNNModelModified(conv_dims, fc_dims, no_rel, alpha=0.25, beta=0.1, goodnessOfFitCutoff=[0,0])
  model.to(device)

  (args, model, optimizer, test_loader, train_loader) = neuralInit(seed, device, model, lr, momentum, batchSize)

  x, y_test, y_test2, y_train, y_train2 = [], [], [], [], []
  y_test_loss, y_test2_loss, y_train_loss, y_train2_loss = [], [], [], []

  # Model that is trained with periodic layer splitting. Unlike `model`, it is
  # not passed through neuralInit; it reuses the `args` returned above.
  model2 = CNNModelModified(conv_dims, fc_dims, no_rel, alpha=0.2, beta=0.9, goodnessOfFitCutoff=goodnessOfFitCutoff)
  model2.to(device)
  optimizer2 = optim.SGD(model2.parameters(), lr=args['lr'], momentum=args['momentum'])

  num_params_unpruned = sum([p.numel() for p in model2.parameters() if p.requires_grad])
  accuracies = []
  num_params_pruned = []
  threshold_values = [1/(20*(1+k)) for k in range(m)]  # 1/20, 1/40, ..., 1/(20*m)

  torch.manual_seed(trainCycle)  # reseeds with the trial index, not `seed`
  # NOTE(review): range(1, trainCycles) runs trainCycles - 1 epochs -- confirm intended.
  for epoch in range(1, trainCycles):
    # Split every splitFreq epochs (epoch starts at 1, so no zero check is needed).
    if epoch % splitFreq == 0:
      # NOTE(review): the decayed `lr` is not passed to the optimizer rebuilt
      # below (it uses args['lr']), so this decay has no effect on training
      # in this cell -- confirm which learning rate is intended.
      lr = 0.96*lr
      # NOTE(review): fc and conv layers share the same save_name pattern.
      for i, layer in enumerate(model2.fc):
        print(layer.split(bema_scheduler(epoch), save_name=f'layer_{i}_epoch_{epoch}', show=showSpectrumDuringSplit, layer_type = 'fc'))
        print("cutting. at this epoch we keep", bema_scheduler(epoch))
      for i, layer in enumerate(model2.conv):
        print(layer.split(bema_scheduler(epoch), save_name=f'layer_{i}_epoch_{epoch}', show=showSpectrumDuringSplit, layer_type = 'conv'))
        print("cutting. at this epoch we keep", bema_scheduler(epoch))
      # Rebuild the optimizer over model2's current parameters after the split.
      optimizer2 = optim.SGD(model2.parameters(), lr = args['lr'], momentum=args['momentum'])

    train_accuracy = train(args, model2, device, train_loader, optimizer2, epoch, showTrainingLoss)
    y_train2.append(train_accuracy)
    print("Numer of parameters of DNN", sum([p.numel() for p in model2.parameters() if p.requires_grad]))

  for threshold in threshold_values:
    # Work on a copy so every threshold starts from the same trained weights.
    model_pruned = copy.deepcopy(model2)

    # Apply naive pruning with the current threshold to every conv/linear module.
    for module in model_pruned.modules():
      if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        naive_prune(module, threshold)

    num_params_pruned.append(count_nonzero_params(model_pruned))
    accuracies.append(test(args, model_pruned, device, test_loader))

  # These figures describe model2 itself; naive pruning above only touched copies.
  num_nonzero = count_nonzero_params(model2)
  #print("Number of nonzero parameters after pruning:", num_nonzero)

  # Assuming 'test' function returns accuracy, adjust if it returns different metrics
  accuracy = test(args, model2, device, test_loader)
  #print("Model accuracy after pruning:", accuracy)

  # Percentage of parameters kept per threshold, computed once after the sweep
  # instead of being rebuilt on every sweep iteration.
  params_kept_percentages = [100 * num / num_params_unpruned for num in num_params_pruned]

  # Plot accuracy vs number of parameters kept, showing only points at or
  # above min_accuracy (the same value is interpolated into the title).
  min_accuracy = 80
  plt.figure(figsize=(10, 6))
  for num, acc, pct in zip(num_params_pruned, accuracies, params_kept_percentages):
    if acc >= min_accuracy:
      plt.scatter(num, acc, marker='o')
      plt.annotate(f"{pct:.1f}%", (num, acc), textcoords="offset points", xytext=(0,10), ha='center')

  plt.xlabel("Number of Parameters Kept")
  plt.ylabel("Test Set Accuracy")
  plt.title(f"Test Set Accuracy vs Number of Parameters Kept (Accuracies >= {min_accuracy}%)")
  plt.grid(True)
  plt.show()




In [None]:
def graphDataBothTest(xData, yData0, yData1):
  """Plot test-set accuracy of the normal and modified networks on one figure.

  For each of ``yData0`` ("Normal", black) and ``yData1`` ("Modified", red)
  the mean and variance come from ``getMeanAndVar``; the mean is drawn solid,
  the mean -/+ one standard deviation dashed, and the band between them is
  shaded grey. The figure is shown before returning.
  """
  plt.title("Accuracy on Test Set of Normal Vs Pruned DNNs")
  plt.xlabel("Epoch")
  plt.ylabel("Accuracy")
  series = zip((yData0, yData1), ("Normal", "Modified"), ("black", "red"))
  for runs, tag, shade in series:
    center, var = getMeanAndVar(runs)
    spread = np.sqrt(var)
    lower = center - spread
    upper = center + spread
    plt.plot(xData, center, label="Mean " + tag, color=shade)
    plt.plot(xData, lower, "--", label="Mean - 1std " + tag, color=shade)
    plt.plot(xData, upper, "--", label="Mean + 1std " + tag, color=shade)
    plt.fill_between(xData, center, upper, color="grey", alpha=0.5)
    plt.fill_between(xData, center, lower, color="grey", alpha=0.5)
  plt.legend()
  plt.show()

def graphDataBothTrain(xData, yData0, yData1):
  """Plot training-set accuracy of the normal and pruned networks on one figure.

  For each of ``yData0`` ("Normal", black) and ``yData1`` ("Pruned", red) the
  mean and variance come from ``getMeanAndVar``; the mean is drawn solid, the
  mean -/+ one standard deviation dashed, and the band between them is
  shaded grey. The figure is shown before returning.
  """
  plt.title("Accuracy on Training Set of Normal Vs Pruned DNNs")
  plt.xlabel("Epoch")
  plt.ylabel("Accuracy")
  series = zip((yData0, yData1), ("Normal", "Pruned"), ("black", "red"))
  for runs, tag, shade in series:
    center, var = getMeanAndVar(runs)
    spread = np.sqrt(var)
    lower = center - spread
    upper = center + spread
    plt.plot(xData, center, label="Mean " + tag, color=shade)
    plt.plot(xData, lower, "--", label="Mean - 1std " + tag, color=shade)
    plt.plot(xData, upper, "--", label="Mean + 1std " + tag, color=shade)
    plt.fill_between(xData, center, upper, color="grey", alpha=0.5)
    plt.fill_between(xData, center, lower, color="grey", alpha=0.5)
  plt.legend()
  plt.show()

# Draw the per-run and combined accuracy curves. The x axis is sized from the
# first recorded run, so skip plotting when no run has been recorded: indexing
# yTestDataStock[0] on an empty list would raise IndexError.
if yTestDataStock:
  xData = list(range(len(yTestDataStock[0])))
  graphData(xData, yTestDataStock, title = "Normal Test")
  graphData(xData, yTestDataModified, title = "Pruned Test")
  graphData(xData, yTrainDataStock, title = "Normal Train")
  graphData(xData, yTrainDataModified, title = "Pruned Train")
  graphDataBothTest(xData, yTestDataStock, yTestDataModified)
  graphDataBothTrain(xData, yTrainDataStock, yTrainDataModified)
else:
  xData = []
  print("No recorded runs to plot: yTestDataStock is empty.")