In [10]:
import math
import json
from pprint import pprint
from models.model_bank import DynamicOrigami


def calculate_width(P, F, C, A, B):
    """
    Calculate the width (W) of layers in a neural network given the desired number of parameters.

    Solves ``(B - 2) * W**2 + (F + A + C) * W = P`` for W, takes the
    ``(-b + sqrt(discriminant)) / (2a)`` root, rounds it to the nearest
    integer and never returns less than 1.

    Parameters:
        - P: int, total number of parameters
        - F: int, number of features (input size)
        - C: int, number of classes (output size)
        - A: int, number of fold layers
        - B: int, total number of linear layers (including the first and last)
    Returns:
        - W: int, calculated width of the layers (always >= 1)

    NOTE(review): the parameter model presumably ignores biases and counts one
    W-sized term per fold layer -- confirm against the actual layer definitions.
    NOTE(review): B < 2 makes the quadratic coefficient negative; the formula is
    still evaluated as written in that case -- confirm this is intended for
    architectures with a single linear layer.
    """
    # Coefficients for the quadratic equation aW^2 + bW + c = 0
    a = (B - 2)
    b = (F + A + C)
    c = -P

    if a == 0:
        # No W^2 term: the equation is linear, bW = P.
        if b == 0:
            # Width does not influence the parameter count at all; fall back
            # to the minimum width instead of dividing by zero.
            return 1
        W = P / b
    else:
        discriminant = b**2 - 4 * a * c
        if discriminant < 0:
            # No real solution: use the minimum width.
            return 1
        W = (-b + math.sqrt(discriminant)) / (2 * a)

    return max(1, round(W))



def count_params(n_features, n_classes, min_size=4, scale=2, n_sizes=6, min_dim=1000):
    """
    Build the geometric ladder of parameter budgets used for the model sizes.

    The smallest budget is ``max(min_dim, n_features * n_classes) * min_size``;
    every following budget multiplies the previous one by ``scale``.

    Parameters:
        - n_features: int, number of input features
        - n_classes: int, number of output classes
        - min_size: int, multiplier applied to the base to get the smallest budget
        - scale: int, growth factor between consecutive budgets
        - n_sizes: int, how many budgets to produce
        - min_dim: int, lower bound used in place of n_features * n_classes
    Returns:
        - list of int, one parameter budget per model size, in order of index
    """
    # The base does not depend on the step, so compute it once.
    smallest = max(min_dim, n_features * n_classes) * min_size
    return [int(smallest * scale**step) for step in range(n_sizes)]



def layer_widths(layers:list, n_features:int, n_classes:int, 
                 min_size:int=8, scale:int=4, repeat:int=6, verbose=0) -> list:
    """
    Compute one layer width per model size for the given layer stack.

    For each size index ``i`` in ``range(repeat)`` the parameter budget is
    ``n_features * n_classes * min_size * scale**i`` and the width is obtained
    from ``calculate_width`` for that budget.

    Parameters:
        - layers: list of str, layer names; a name counts as a linear layer if
          it contains "linear" and as a fold layer if it contains "fold"
          (both case-insensitive)
        - n_features: int, number of input features
        - n_classes: int, number of output classes
        - min_size: int, multiplier for the smallest budget
        - scale: int, growth factor between consecutive budgets
        - repeat: int, number of model sizes
        - verbose: int, when > 0 the last budget is printed
    Returns:
        - list of int, one width per model size (empty when repeat is 0)
    """
    count_linear = sum("linear" in layer.lower() for layer in layers)
    count_fold = sum("fold" in layer.lower() for layer in layers)

    budgets = [int(n_features * n_classes * min_size * scale**i) for i in range(repeat)]
    widths = [calculate_width(P=budget, F=n_features, C=n_classes, A=count_fold, B=count_linear)
              for budget in budgets]

    # Only report when at least one budget exists; with repeat == 0 there is
    # nothing to print (previously this raised UnboundLocalError).
    if verbose > 0 and budgets:
        print("Biggest model:", budgets[-1])
    return widths
        
# Example call: widths for a Linear-Fold-Linear stack with 784 input features
# and 10 classes, using layer_widths' default min_size/scale/repeat.
# NOTE(review): the output saved below this cell looks stale relative to the
# current defaults (the first width evaluates to 79 here) -- re-run to confirm.
widths = layer_widths(["Linear", "Fold", "Linear", ], 784, 10)
width = widths[0]  # width for the first (size index 0) budget
print(widths)

[158, 631, 2525, 10098, 40393, 161573]


In [15]:
# Build one model configuration per (architecture, dataset, parameter budget)
# and save them all to BenchmarkTests/ablation_archs.json.
ablation_archs = {}
learning_rate = 0.001
soft_fold = True
has_stretch = True
crease = None
fold_in = False
leak = 0
repeat4stds = 30
# Layer-name convention: a trailing "0" marks the input linear layer
# (input_dim -> width), a trailing "1" the output linear layer
# (width -> class_count); other linear layers are width -> width.
architectures = [["Linear0", "Linear", "Linear", "Linear1"],
                 ["Linear0", "Linear", "Linear", "Fold", "Linear1"],
                 ["Linear0", "Fold", "Linear", "Linear", "Fold", "Linear1"],
                 ["Linear0", "Fold", "Linear", "Fold", "Linear", "Fold", "Linear1"], 
                 ["Linear0", "Linear", "Linear", "Linear", "Linear1"], 
                 ["Fold", "Fold", "Fold", "Fold", "Fold", "Fold", "Fold", "Fold", "Linear1"]]
arch_names = ["4Linear", "41Alt", "42Alt", "43Alt", "5Linear", "8Fold"]
dataset_dims =  [54,        784,        28,     1024,       784]
dataset_ccs =   [7,         10,         2,      10,         10]
dataset_names = ["Cover",   "Digits",   "Higgs", "Cifar10", "Fashion"]

for aname, architecture in zip(arch_names, architectures):
    # Substitute soft folds once per architecture, under a separate name so the
    # loop variable is not rebound inside the inner loops.
    if soft_fold:
        layers = [layer if layer != "Fold" else "Soft" + layer for layer in architecture]
    else:
        layers = list(architecture)
    # Same case-insensitive substring counting that layer_widths uses.
    count_linear = sum("linear" in layer.lower() for layer in layers)
    count_fold = sum("fold" in layer.lower() for layer in layers)

    for dname, input_dim, class_count in zip(dataset_names, dataset_dims, dataset_ccs):
        
        n_params = count_params(input_dim, class_count, min_size=4, scale=3, n_sizes=5)
        # Size each model for exactly the budget that appears in its name.
        # Previously the widths came from layer_widths' own defaults
        # (min_size=8, scale=4, no min_dim floor), so the name and the actual
        # model size disagreed.
        widths = [calculate_width(P=n_param, F=input_dim, C=class_count,
                                  A=count_fold, B=count_linear)
                  for n_param in n_params]
        
        for n_param, width in zip(n_params, widths):
 
            name = f"{aname}_{dname}_{n_param}"
            arch = {"learning_rate": learning_rate, 
                    "repeat": repeat4stds,
                    "structure": []}
            
            for layer in layers:
                if "Linear" in layer:
                    if layer[-1] == "0":
                        inf = input_dim
                        out = width
                    elif layer[-1] == "1":
                        inf = width
                        out = class_count
                    else:
                        inf = width
                        out = width
                    arch["structure"].append({"params": {"in_features": inf, 
                                                         "out_features": out}, 
                                              "type": "Linear"})
                else:
                    arch["structure"].append({"params": {"has_stretch": has_stretch,
                                                         "width": width,
                                                         "crease": crease,
                                                         }, 
                                              "type": layer})
            # Human-readable layer list with the 0/1 position markers stripped.
            arch["string"] = ["".join(filter(str.isalpha, layer)) for layer in layers]
            ablation_archs[name] = {"architecture": arch,
                                    "n_classes": class_count,
            }
# save ablation architectures
with open("BenchmarkTests/ablation_archs.json", "w") as f:
    json.dump(ablation_archs, f)
print("Ablation architectures saved.")

pprint(ablation_archs)

Ablation architectures saved.
{'41Alt_Cifar10_1105920': {'architecture': {'learning_rate': 0.001,
                                            'repeat': 30,
                                            'string': ['Linear',
                                                       'Linear',
                                                       'Linear',
                                                       'SoftFold',
                                                       'Linear'],
                                            'structure': [{'params': {'in_features': 1024,
                                                                      'out_features': 2046},
                                                           'type': 'Linear'},
                                                          {'params': {'in_features': 2046,
                                                                      'out_features': 2046},
                                                           'type': 'Li

In [16]:
# Show the names of the configurations currently held in ablation_archs.
ablation_archs.keys()

dict_keys(['4Linear_Cover_4000', '4Linear_Cover_12000', '4Linear_Cover_36000', '4Linear_Cover_108000', '4Linear_Cover_324000', '4Linear_Digits_31360', '4Linear_Digits_94080', '4Linear_Digits_282240', '4Linear_Digits_846720', '4Linear_Digits_2540160', '4Linear_Higgs_4000', '4Linear_Higgs_12000', '4Linear_Higgs_36000', '4Linear_Higgs_108000', '4Linear_Higgs_324000', '4Linear_Cifar10_40960', '4Linear_Cifar10_122880', '4Linear_Cifar10_368640', '4Linear_Cifar10_1105920', '4Linear_Cifar10_3317760', '4Linear_Fashion_31360', '4Linear_Fashion_94080', '4Linear_Fashion_282240', '4Linear_Fashion_846720', '4Linear_Fashion_2540160', '41Alt_Cover_4000', '41Alt_Cover_12000', '41Alt_Cover_36000', '41Alt_Cover_108000', '41Alt_Cover_324000', '41Alt_Digits_31360', '41Alt_Digits_94080', '41Alt_Digits_282240', '41Alt_Digits_846720', '41Alt_Digits_2540160', '41Alt_Higgs_4000', '41Alt_Higgs_12000', '41Alt_Higgs_36000', '41Alt_Higgs_108000', '41Alt_Higgs_324000', '41Alt_Cifar10_40960', '41Alt_Cifar10_122880', 

In [10]:
# Print a single entry (position read_idx in file order) from the older
# architecture file.
# The context manager closes the file handle once the JSON has been parsed.
with open("BenchmarkTests/architectures.json", "r") as f:
    old_archs = json.load(f)
read_idx = 15
for i, (name, info) in enumerate(old_archs.items()):
    if i == read_idx:
        pprint(info)
        break

{'learning_rate': 0.01,
 'repeat': 3,
 'string': ['SoftFold', 'SoftFold', 'SoftFold', 'Linear'],
 'structure': [{'params': {'has_stretch': False, 'width': 1.1},
                'type': 'SoftFold'},
               {'params': {'has_stretch': False, 'width': 1.1},
                'type': 'SoftFold'},
               {'params': {'has_stretch': False, 'width': 1.1},
                'type': 'SoftFold'},
               {'params': {'in_features': 1.1, 'out_features': 1.1},
                'type': 'Linear'}]}


In [11]:
# load ablation architectures and check that every configuration can be
# instantiated as a DynamicOrigami model
# The context manager closes the file handle once the JSON has been parsed.
with open("BenchmarkTests/ablation_archs.json", "r") as f:
    ablation_archs = json.load(f)
for name, info in ablation_archs.items():
    print(name, end=":\t")
    # Construction raising here stops the loop, leaving the failing name as the
    # last thing printed.
    model = DynamicOrigami(info["architecture"]["structure"], info["n_classes"], iknowaboutthecutlayer=True)
    print("Succeeded.")

FullLinear_Cover_1512:	Succeeded.
FullLinear_Cover_3024:	Succeeded.
FullLinear_Cover_6048:	Succeeded.
FullLinear_Cover_12096:	Succeeded.
FullLinear_Cover_24192:	Succeeded.
FullLinear_Digits_31360:	Succeeded.
FullLinear_Digits_62720:	Succeeded.
FullLinear_Digits_125440:	Succeeded.
FullLinear_Digits_250880:	Succeeded.
FullLinear_Digits_501760:	Succeeded.
FullLinear_Higgs_224:	Succeeded.
FullLinear_Higgs_448:	Succeeded.
FullLinear_Higgs_896:	Succeeded.
FullLinear_Higgs_1792:	Succeeded.
FullLinear_Higgs_3584:	Succeeded.
FullLinear_Cifar10_40960:	Succeeded.
FullLinear_Cifar10_81920:	Succeeded.
FullLinear_Cifar10_163840:	Succeeded.
FullLinear_Cifar10_327680:	Succeeded.
FullLinear_Cifar10_655360:	Succeeded.
FullLinear_Fashion_31360:	Succeeded.
FullLinear_Fashion_62720:	Succeeded.
FullLinear_Fashion_125440:	Succeeded.
FullLinear_Fashion_250880:	Succeeded.
FullLinear_Fashion_501760:	Succeeded.
21Alt_Cover_1512:	Succeeded.
21Alt_Cover_3024:	Succeeded.
21Alt_Cover_6048:	Succeeded.
21Alt_Cover_120