In [27]:

### ----- ERROR HANDLING
"""
    do_activations_match_depth(depth, activation)

Check that a per-layer activation collection has exactly one entry for every
layer that activates.

`depth` counts ALL layers, and `depth - 2` of them are expected to activate
(the last layer always uses `identity`, so no activation may be supplied for it).

# Arguments
- `depth`: total number of layers.
- `activation`: collection of activations; only its `length` is inspected.

# Returns
`nothing` when the counts agree.

# Throws
- `ErrorException` when `length(activation) != depth - 2`.
"""
function do_activations_match_depth(depth, activation)
    expected = depth - 2
    provided = length(activation)
    if provided != expected
        error("$(provided) activations provided, but $(expected) layers activate.\nBeware: identity is used for last layer - DO NOT PROVIDE FOR LAST LAYER")
    end
    return nothing
end

"""
    depth_validation(depth)

Reject network depths that are too small.

`depth` refers to ALL layers, so anything below 3 is refused.

# Returns
`nothing` when `depth` is acceptable.

# Throws
- `ErrorException` when `depth < 3`.
"""
function depth_validation(depth)
    too_shallow = depth < 3
    if !too_shallow
        return nothing
    end
    error("Depth must be at least 3 to form an ANN.\nBeware that 'depth' refers to ALL layers.")
end

depth_validation (generic function with 1 method)

In [28]:
### Different architecture types for varying data
### Author: Axel Bjarkar

# ?model_architype(architype, dimIN, dimOUT, depth, activation, approx_num_params, critical_width=Nothing)
# TO SEE ASCII ART OF LAYOUTS ↑↑↑

### ----- IMPORTS
using Flux
include("DenseNTK.jl")

"""
    model_architype(architype, dimIN, dimOUT, depth, activation, approx_num_params,
                    critical_width=Nothing)

Build a network whose layer widths follow the layout named by `architype`.

# Arguments
- `architype`: one of `"LH1"`, `"block"`, `"funnel"`, `"reverse_funnel"`,
  `"hourglass"`, `"diamond"`.
- `dimIN`, `dimOUT`: width of the first and last layer.
- `depth`: total number of layers (must be at least 3). Ignored for `"LH1"`,
  which always has exactly three layers.
- `activation`: a single `Function` used for every activating layer, or a
  `Vector` with one entry per activating layer (`depth - 2` entries). The last
  layer always uses `identity`.
- `approx_num_params`: forwarded to the `widths_*` helpers.
- `critical_width`: hidden width for `"LH1"`; also forwarded to the hourglass
  and diamond helpers. The default is kept as the type `Nothing` for backward
  compatibility and is treated like `nothing`.

# Returns
The model produced by `construct_model(widths, activation)`.

# Throws
- `ErrorException` for an invalid `depth`, an unknown `architype`, a missing
  `critical_width` for `"LH1"`, or an invalid `activation`.
"""
function model_architype(architype, dimIN, dimOUT, depth, activation, approx_num_params, critical_width=Nothing)
    depth_validation(depth)

    # Get appropriate widths.
    # NOTE(review): the helpers below are called with `approx_num_params`; make sure
    # the `widths_*` methods that are loaded accept that argument.
    if architype == "LH1"
        # Without this check the type `Nothing` (or `nothing`) would be stored as a
        # layer width and fail later with an unhelpful error.
        if critical_width === Nothing || critical_width === nothing
            error("Architecture 'LH1' requires 'critical_width' (the width of its single hidden layer).")
        end
        widths = [dimIN, critical_width, dimOUT]
    elseif architype == "block"
        widths = widths_block(dimIN, dimOUT, depth, approx_num_params)
    elseif architype == "funnel"
        # The output width must be dimOUT (dimIN used to be passed twice).
        widths = widths_funnel(dimIN, dimOUT, depth, approx_num_params)
    elseif architype == "reverse_funnel"
        # The output width must be dimOUT (dimIN used to be passed twice).
        widths = widths_reverse_funnel(dimIN, dimOUT, depth, approx_num_params)
    elseif architype == "hourglass"
        widths = widths_hourglass(dimIN, dimOUT, depth, critical_width, approx_num_params)
    elseif architype == "diamond"
        widths = widths_diamond(dimIN, dimOUT, depth, critical_width, approx_num_params)
    else
        current_types = "Current supported types:\n"
        supported_types = ["LH1", "block", "funnel", "reverse_funnel", "hourglass", "diamond"]
        error("'$architype' is not a valid architecture type\n\n$current_types$(join(supported_types, '\n'))\n")
    end

    # Model construction is shared with `construct_model`, which derives the depth
    # from `length(widths)`.
    # NOTE(review): this assumes every `widths_*` helper returns exactly `depth` entries.
    return construct_model(widths, activation)
end

"""
    construct_model(widths, activation)

Assemble a `Chain` of `DenseNTK` layers from a list of layer widths.

Layer `k` maps `widths[k]` to `widths[k + 1]`, so `length(widths) - 1` layers are
created. The final layer always uses `identity`.

# Arguments
- `widths`: layer widths; its length is taken as the network depth.
- `activation`: a single `Function` applied to every layer except the last, or a
  `Vector` holding one activation per layer except the last
  (`length(widths) - 2` entries).

# Throws
- `ErrorException` if `activation` is neither a `Function` nor a `Vector`, or if a
  `Vector` has the wrong number of entries.
"""
function construct_model(widths, activation)
    n_widths = length(widths)
    final_layer = n_widths - 1

    shared_activation = activation isa Function
    if !shared_activation
        if !(activation isa Vector)
            error("Invalid activation type: must be a Function or Vector of Functions.")
        end
        # One activation is expected for every layer except the last.
        do_activations_match_depth(n_widths, activation)
    end

    layers = map(1:final_layer) do k
        # The last layer is linear; earlier layers use the supplied activation.
        act = k == final_layer ? identity : (shared_activation ? activation : activation[k])
        DenseNTK(widths[k], widths[k + 1], act)
    end

    return Chain(layers...)
end


construct_model (generic function with 1 method)

In [32]:

# Positive root of A*n^2 + B*n + C = 0, rounded to the nearest integer.
# Falls back to the linear solution when A == 0 (the quadratic formula would give 0/0).
function _block_hidden_width(A, B, C)
    root = A == 0 ? -C / B : (-B + sqrt(B^2 - 4 * A * C)) / (2 * A)
    return round(Int, root)
end

"""
    block(dimIN, dimOUT, depth, approx_num_params, activations)

Build a network in which every hidden layer has the same width, chosen so that the
total parameter count is approximately `approx_num_params`.

# Arguments
- `dimIN`, `dimOUT`: width of the first and last layer.
- `depth`: total number of layers (must be at least 3).
- `approx_num_params`: target number of parameters.
- `activations`: forwarded unchanged to `construct_model`.

# Returns
The model produced by `construct_model(widths, activations)`.

# Throws
- `ErrorException` if `depth < 3` or if the target parameter count is too small to
  give the hidden layers a width of at least 1.
"""
function block(dimIN, dimOUT, depth, approx_num_params, activations)
    depth_validation(depth)

    widths = zeros(Int, depth)
    widths[1] = dimIN
    widths[end] = dimOUT

    # A, B, C found by algebra. With n nodes per hidden layer, counting weights and
    # one bias per node, the parameter count is
    #   (dimIN*n + n) + (depth - 3)*(n^2 + n) + (n*dimOUT + dimOUT)
    # and setting it equal to approx_num_params gives A*n^2 + B*n + C = 0.
    A = depth - 3
    B = dimIN + 1 + dimOUT + depth - 3
    C = dimOUT - approx_num_params

    nodes = _block_hidden_width(A, B, C)
    if nodes < 1
        error("approx_num_params = $approx_num_params is too small for a block network with dimIN = $dimIN, dimOUT = $dimOUT, depth = $depth.")
    end

    widths[2:(depth - 1)] .= nodes

    return construct_model(widths, activations)
end

# Example run: build a block network and display the resulting model.
# Width of the first layer.
dimIN = 1
# Width of the last layer.
dimOUT = 10
# Total number of layers, so depth - 2 = 2 activations are supplied below.
depth = 4
# Target parameter count passed as `approx_num_params`.
params = 30_081


# One σ per activating layer; the last layer gets `identity` inside `construct_model`.
block(dimIN, dimOUT, depth, params, [σ,σ])

Chain(
  DenseNTK(Float32[-0.37557358; 0.48491597; … ; 0.30102775; -1.2416043;;], Float32[1.5125196, -0.33850855, -0.083546065, 0.0999315, 2.0458007, -0.94083095, -2.5482404, -0.9288291, -0.09738325, 0.099693835  …  -1.1703217, -0.4228017, -0.81690973, -0.5821124, 0.6065722, 1.0762348, -1.0848103, 0.8173694, -2.1110027, -1.7434909], NNlib.σ),  [90m# 334 parameters[39m
  DenseNTK(Float32[1.200404 -1.9892446 … -1.3110242 1.316334; -0.15121916 -2.0997684 … 2.892227 0.685615; … ; 0.9470955 0.57674223 … -0.10624681 0.9021404; -0.53077334 -1.1780746 … -0.18228237 0.30758098], Float32[0.06752559, -0.6431602, -0.59302855, -0.19768971, 0.24681294, -0.50889575, 0.05497908, -1.2088491, -1.362163, 0.8213387  …  1.481869, 0.75512195, -0.7272934, -0.36818293, -0.5731397, -1.5194638, -0.2790977, -0.793775, 0.043785892, 0.10491545], NNlib.σ),  [90m# 28_056 parameters[39m
  DenseNTK(Float32[-1.2654766 0.21954335 … 0.4437208 0.9263125; 0.43022233 -1.0366275 … 0.5155695 -0.24308722; … ; -1.6863629 -1

In [None]:


"""
    widths_reverse_funnel(dimIN::Int, dimOUT::Int, depth::Int)

Return `depth` layer widths that move linearly from `dimIN` to `dimOUT`.

Intermediate widths are `dimIN` plus a rounded multiple of the per-layer step
`(dimOUT - dimIN) / (depth - 1)`. The first and last entries are pinned to
`dimIN` and `dimOUT` exactly.
"""
function widths_reverse_funnel(dimIN::Int, dimOUT::Int, depth::Int)
    widths = zeros(Int, depth)
    slope = (dimOUT - dimIN) / (depth - 1)

    # `offset` is the number of steps taken away from the input layer.
    for (layer, offset) in enumerate(0:(depth - 1))
        widths[layer] = dimIN + round(Int, slope * offset)
    end

    # Pin the end points so rounding can never move them.
    widths[1] = dimIN
    widths[depth] = dimOUT

    return widths
end


In [None]:

"""
    widths_hourglass(dimIN, dimOUT, depth, min_width)

Return `depth` layer widths that shrink from `dimIN` down to `min_width` and then
grow back up to `dimOUT`.

An odd `depth` has a single narrowest layer in the middle; an even `depth` has two
adjacent narrowest layers. Widths on either side change in equal, rounded steps.
The last entry is pinned to `dimOUT`.
"""
function widths_hourglass(dimIN, dimOUT, depth, min_width)
    widths = zeros(Int, depth)

    # Indices of the narrowest layer(s); they coincide when depth is odd.
    waist_lo = (depth + 1) ÷ 2
    waist_hi = depth ÷ 2 + 1
    widths[waist_lo] = min_width
    widths[waist_hi] = min_width

    # Both ramps span the same number of steps.
    ramp_len = waist_lo - 1
    shrink = (dimIN - min_width) / ramp_len
    grow = (dimOUT - min_width) / ramp_len

    # Shrinking side: layers before the waist.
    for k in 0:(ramp_len - 1)
        widths[k + 1] = dimIN - round(Int, shrink * k)
    end
    # Growing side: layers after the waist.
    for k in 1:(depth - waist_hi)
        widths[waist_hi + k] = min_width + round(Int, grow * k)
    end

    widths[depth] = dimOUT

    return widths
end


In [None]:

"""
    widths_diamond(dimIN::Int, dimOUT::Int, depth::Int, max_width::Int)

Return `depth` layer widths that expand from `dimIN` up to `max_width` and then
contract down to `dimOUT`.

An odd `depth` has a single widest layer in the middle; an even `depth` has two
adjacent widest layers. Widths on either side change in equal, rounded steps.
The first and last entries are pinned to `dimIN` and `dimOUT`.
"""
function widths_diamond(dimIN::Int, dimOUT::Int, depth::Int, max_width::Int)
    widths = zeros(Int, depth)

    # Indices of the widest layer(s); they coincide when depth is odd.
    peak_first = (depth + 1) ÷ 2
    peak_last = depth ÷ 2 + 1
    widths[peak_first] = max_width
    widths[peak_last] = max_width

    # Both slopes span the same number of steps.
    n_steps = peak_first - 1
    up_step = (max_width - dimIN) / n_steps
    down_step = (max_width - dimOUT) / n_steps

    # Expansion phase: layers before the peak.
    for layer in 1:n_steps
        widths[layer] = dimIN + round(Int, up_step * (layer - 1))
    end
    # Contraction phase: layers after the peak.
    for layer in (peak_last + 1):depth
        widths[layer] = max_width - round(Int, down_step * (layer - peak_last))
    end

    # Pin the end points so rounding can never move them.
    widths[1] = dimIN
    widths[depth] = dimOUT

    return widths
end
