In [None]:
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
from matplotlib.colors import to_rgb
from matplotlib.colors import LinearSegmentedColormap

# Juptyer magic: For export. Makes the plots size right for the screen 
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# %config InlineBackend.figure_formats = ['svg'] 


torch.backends.cudnn.deterministic = True
seed = np.random.randint(1,200)
# seed = 107 #59
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
print(seed)
g = torch.Generator()
g.manual_seed(seed)

subfolder = 'resnet_2d_critgen'
import os
if not os.path.exists(subfolder):
    os.makedirs(subfolder)

# ResNet Parameters and Training Data

In [None]:
# Model Params
num_hidden = 5 # number of hidden layers. The total network has 2 additional layers: input to hidden and hidden to output
input_dim = 2 # passed to ResNet as `input_dim`
hidden_dim = 2 # passed to ResNet as `hidden_dim`
output_dim = 1 # passed to ResNet and make_circles_uniform as `output_dim`
skip_param = 1 # passed to ResNet as `skip_param` (presumably the skip-connection weight — TODO confirm)
sara_param = 0.2 # passed to ResNet; inject_crit_point also uses it to scale the injected weights (-1/sara_param * I)
activation = 'tanh' # 'relu' and 'tanh' are supported
batchnorm = False # passed to ResNet as `batchnorm`
input_layer = False # passed to ResNet as `input_layer` (presumably toggles an extra input layer — TODO confirm)
bias = True # passed to ResNet as `bias`; experimental setting (original note: "crazy try to simplify")

# Training Params
load_file = None # passed to train_until_threshold as `load_file` (presumably None = train from scratch — TODO confirm)
cross_entropy = True # True supported with binary classification only
final_sigmoid = True # needed for cross entropy
num_epochs = 200 # epochs of the initial training run (train_until_threshold)
max_retries = 1 # passed to train_until_threshold together with threshold=0.8


In [None]:
import importlib
import models.training

# Reload first, then bind names: a `from ... import` executed before the reload keeps
# pointing at the old function objects, so edits to models/training.py would be ignored.
importlib.reload(models.training)
from models.training import make_circles_uniform

# Generate training data
n_points = 4000  # number of points in the dataset
batch_size = 100

inner_radius = 0.5
outer_radius = 1
buffer = 0.2

plotrange = [-2.5, 2.5]

train_loader, test_loader = make_circles_uniform(
    output_dim=output_dim, n_samples=n_points, inner_radius=inner_radius,
    outer_radius=outer_radius, buffer=buffer, cross_entropy=cross_entropy,
    batch_size=batch_size, seed=seed, filename=subfolder + '/circles_dataset',
)

# ResNet case

Trying to establish a ResNet configuration that is stable with respect to initialization.

In [None]:
# Pick up edits to the project modules without restarting the kernel:
# reload every module first, then (re)bind the names used from it.
import importlib
import models.resnets
import models.training

for _module in (models.resnets, models.training):
    importlib.reload(_module)
del _module

from models.resnets import ResNet
from models.training import (
    compute_accuracy,
    train_model,
    train_until_threshold,
    plot_loss_curve,
)

In [None]:
import plots.plots
# Reload BEFORE the from-import so the bound names refer to the freshly loaded functions.
importlib.reload(plots.plots)
from plots.plots import plot_decision_boundary, plot_level_sets

# Initial training run of the ResNet with the configuration defined above.
model_res, acc_res, losses_res = train_until_threshold(
    ResNet,
    train_loader, test_loader,
    load_file=load_file, max_retries=max_retries, threshold=0.8, early_stopping=False,
    input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim, num_hidden=num_hidden,
    skip_param=skip_param, sara_param=sara_param, activation=activation, input_layer=input_layer,
    final_sigmoid=final_sigmoid, batchnorm=batchnorm, bias=bias, epochs=num_epochs,
    cross_entropy=cross_entropy, seed=seed,
)

plot_loss_curve(losses_res, title="ResNet Model Loss Curve", filename=subfolder + '/ff6_res')

# Configuration summary printed under every figure.
footnote = f'num_hidden={num_hidden}, hidden_dim={hidden_dim}, output_dim={output_dim}, act={activation}, seed={seed}, ce={cross_entropy}'



In [None]:

# Decision boundary of the freshly trained model (before any critical-point injection).
# Each stage gets its own file suffix so later plots do not overwrite this figure.
X_test, y_test = next(iter(test_loader))
plot_decision_boundary(model_res, X_test, y_test, show=True,
                       file_name=subfolder + '/ff6hiddencirc' + str(num_hidden) + '_trained',
                       footnote=footnote, amount_levels=100, show_points=False)

In [None]:
import torch

def set_layer_params(model, layer_name, b_specific=None, w_specific=None, freeze=False):
    """
    Set the bias and/or weights of one sub-module to given values (scalars OR tensors).

    All requested values are validated before anything is written, so a shape mismatch
    leaves the layer completely untouched (no half-applied update).

    Args:
        model: The PyTorch model.
        layer_name: Dotted path to the sub-module (e.g., 'res_blocks.2.fc').
        b_specific: Scalar (float) OR Tensor to set the bias to; None leaves the bias as is.
        w_specific: Scalar (float) OR Tensor to set the weights to; None leaves the weights as is.
        freeze: (bool) If True, the modified parameters get requires_grad=False.
            If False, they get requires_grad=True, so a parameter frozen by an earlier
            call becomes trainable again (matching the "(Trainable)" message).

    Returns:
        None. Problems (unknown layer, missing parameter, shape mismatch) are reported
        with print() rather than raised.
    """
    # 1. Navigate to the layer. Only the lookup is guarded, so unrelated
    #    AttributeErrors are not misreported as "layer not found".
    module = model
    try:
        for part in layer_name.split('.'):
            module = getattr(module, part)
    except AttributeError:
        print(f"❌ Could not find layer '{layer_name}' in the model.")
        return

    # 2. Collect the requested updates and validate tensor shapes before writing anything.
    requests = (
        ('weight', w_specific, 'Weights', f"⚠️ Layer '{layer_name}' has no trainable weight parameter."),
        ('bias', b_specific, 'Bias', f"⚠️ Layer '{layer_name}' has no bias parameter."),
    )
    updates = []
    for attr, value, label, missing_msg in requests:
        if value is None:
            continue
        param = getattr(module, attr, None)
        if param is None:
            print(missing_msg)
            continue
        if isinstance(value, torch.Tensor) and value.shape != param.shape:
            print(f"❌ Shape mismatch for {label}! Layer expects {param.shape}, got {value.shape}")
            return
        updates.append((attr, param, value))

    # 3. Write the values, then apply the requested trainability.
    for attr, param, value in updates:
        with torch.no_grad():
            if isinstance(value, torch.Tensor):
                param.copy_(value.to(param.device))
            else:
                param.fill_(value)
        param.requires_grad_(not freeze)
        what = 'weights' if attr == 'weight' else 'bias'
        state = 'Frozen ❄️' if freeze else 'Trainable'
        print(f"✅ Set {what} of '{layer_name}' ({state})")

In [None]:
from plots.plots import model_to_func_incl_output_layer

# Input point at which the critical point will be injected: the origin of the 2-D input plane.
x_0 = torch.zeros(2)

def inject_crit_point(model, x_0, layer = 4, verbose = False, freeze = False):
    """Inject a critical point at x_0 by overwriting the `fc` layer of one residual block.

    The weights of ``res_blocks.<layer>.fc`` are set to ``-1/sara_param * I`` and its bias to
    ``-W @ h_0``, where ``h_0`` is the output of the network up to block ``layer - 1`` at x_0.
    Hence the pre-activation ``W @ h_0 + b`` inside that block is exactly 0 at x_0.

    Args:
        model: ResNet modified in place (must contain ``res_blocks.<layer>.fc``).
        x_0: Input point at which the critical point is injected.
        layer: Index of the residual block to modify, counted from 0.
        verbose: If True, print the bias, the pre-activation check, the model output at x_0
            and ``psi_manual(x_0, model)``, and report whether psi vanishes there.
        freeze: Forwarded to ``set_layer_params`` (True freezes the modified parameters).

    Note:
        The scale is read from the notebook-level ``sara_param``; it must match the value
        the model was built with.
    """
    from plots.plots import psi_manual

    tol = 1e-6  # tolerance for the float "is zero" checks below
    prev_layer_func = model_to_func_incl_output_layer(model, from_layer = 0, to_layer = layer - 1, output_layer = False)

    # No autograd graph needed: these tensors are only copied into the parameters.
    with torch.no_grad():
        h_0 = prev_layer_func(x_0)  # hidden state entering block `layer`
        # Scaled identity, sized/typed/placed like the hidden state (not a hard-coded 2x2 CPU matrix).
        my_custom_weights = -1 / sara_param * torch.eye(h_0.shape[-1], dtype=h_0.dtype, device=h_0.device)
        # bias chosen so that inside sigma'(.) the value is 0 at x_0
        my_custom_bias = -my_custom_weights @ h_0

    set_layer_params(model, 'res_blocks.' + str(layer) + '.fc',  # e.g., 'res_blocks.4.fc'
                     w_specific=my_custom_weights,
                     b_specific=my_custom_bias, freeze = freeze)

    if verbose:
        print(f'{my_custom_bias = }')
        with torch.no_grad():
            test_inside = my_custom_weights @ prev_layer_func(x_0) + my_custom_bias
        # Single float-safe verdict instead of an element-wise `== 0` tensor.
        print(f'{test_inside = }, Result is correct?: ',
              torch.allclose(test_inside, torch.zeros_like(test_inside), atol=tol))
        output = model(x_0)
        print(f'output at x_0: {output}')
        # Evaluate psi on the model that was passed in (not the global `model_res`).
        grad_at_x_0 = psi_manual(x_0, model)
        # as_tensor + all(): works for scalar or vector psi and tolerates float round-off.
        if torch.all(torch.as_tensor(grad_at_x_0).detach().abs() <= tol):
            print("✅ Critical point successfully injected at x_0!")
        else:
            print("❌ Failed to inject critical point at x_0.")
        print(f'{ grad_at_x_0 = }')
        
    
from plots.plots import psi_manual  # also used directly by later cells

# Inject the critical point into residual block 4 of the trained model and freeze the
# modified parameters (requires_grad=False) so they are excluded from gradient updates.
inject_crit_point(model_res, x_0, layer = 4, verbose = True, freeze= True)
# #tests to see if the choice is generating a critical point in the right layer


In [None]:
# Inspect the parameters of the layer modified by inject_crit_point ('res_blocks.4.fc').
linear_layers = [module for module in model_res.modules() if isinstance(module, nn.Linear)]
print(len(linear_layers))
# Look the layer up by name: its position in `modules()` depends on how many other
# Linear layers (e.g. an input layer) are registered before the residual blocks.
layer = dict(model_res.named_modules())['res_blocks.4.fc']
print(layer)
W = layer.weight.detach().cpu()
b = layer.bias.detach().cpu() if layer.bias is not None else None  # bias may be disabled
print(W)
print(b)





In [None]:
# Decision boundary directly after the critical-point injection (own file, no overwrite).
X_test, y_test = next(iter(test_loader))
plot_decision_boundary(model_res, X_test, y_test, show=True,
                       file_name=subfolder + '/ff6hiddencirc' + str(num_hidden) + '_injected',
                       footnote=footnote, amount_levels=100, show_points=False)

In [None]:
num_late_epochs = 500

# Continue training after the critical-point injection; use the configured loss
# setting rather than a hard-coded True.
model_res, acc2, losses2 = train_model(model_res, train_loader, test_loader,
                                load_file = None, epochs=num_late_epochs, early_stopping = False,
                                cross_entropy=cross_entropy, seed = seed)

In [None]:
# Concatenate both training phases; list() guarantees concatenation even if the loss
# histories are arrays/tensors (where `+` would add element-wise).
all_losses = list(losses_res) + list(losses2)
# Index where the second phase starts, i.e. where the injection happened.
injection_epoch = len(losses_res)

# 1. Plot Loss Curve
plt.figure(figsize=(8, 5))
plt.plot(all_losses, label="Training Loss", color='blue')
plt.axvline(x=injection_epoch, color='red', linestyle='--',
            label=f'Critical-point injection (epoch {injection_epoch})')
plt.title("ResNet Training Loss with Critical-Point Injection")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig(f"{subfolder}/full_loss_curve.png", dpi=300)
plt.show()

# 2. Decision boundary after continued training (own file, no overwrite).
X_test, y_test = next(iter(test_loader))
plot_decision_boundary(model_res, X_test, y_test, show=True,
                       file_name=subfolder + '/ff6hiddencirc' + str(num_hidden) + '_injected_trained',
                       footnote=footnote, amount_levels=100, show_points=False)

In [None]:
# Re-inject the critical point into the retrained model and plot it (own file, no overwrite).
inject_crit_point(model_res, x_0, layer = 4, verbose = True)
X_test, y_test = next(iter(test_loader))
plot_decision_boundary(model_res, X_test, y_test, show=True,
                       file_name=subfolder + '/ff6hiddencirc' + str(num_hidden) + '_reinjected',
                       footnote=footnote, amount_levels=100, show_points=False)

In [None]:
num_late_epochs = 500

# Train again after the re-injection, with the configured loss setting.
model_res, acc2, losses2 = train_model(model_res, train_loader, test_loader,
                                load_file = None, epochs=num_late_epochs, early_stopping = False,
                                cross_entropy=cross_entropy, seed = seed)

X_test, y_test = next(iter(test_loader))
plot_decision_boundary(model_res, X_test, y_test, show=True,
                       file_name=subfolder + '/ff6hiddencirc' + str(num_hidden) + '_reinjected_trained',
                       footnote=footnote, amount_levels=100, show_points=False)



In [None]:
# psi at the injected point x_0 for the current model; inject_crit_point treats psi == 0 as
# "x_0 is a critical point", so this shows whether training preserved the injected point.
psi_manual(x_0, model_res)

In [None]:
importlib.reload(plots.plots) # Reload the module
from plots.plots import plot_level_sets

# Level sets of the model after the critical-point injection (saved as levelsets_after_critgen).
# NOTE(review): `plotrange` is passed both positionally (the global [-2.5, 2.5]) and as the
# keyword `plotrange=[-1., 1.]`. If the second positional parameter of plot_level_sets is itself
# named `plotrange`, this raises "got multiple values for argument" — confirm the signature.
plot_level_sets(model_res, plotrange, amount_levels=1000, grid_size=500, file_name=subfolder + '/levelsets_after_critgen', footnote= footnote, plotrange= [-1.,1.])

In [None]:
# Inject the critical point once more after the second training phase (own file, no overwrite).
inject_crit_point(model_res, x_0, layer = 4, verbose = True)
X_test, y_test = next(iter(test_loader))
plot_decision_boundary(model_res, X_test, y_test, show=True,
                       file_name=subfolder + '/ff6hiddencirc' + str(num_hidden) + '_reinjected2',
                       footnote=footnote, amount_levels=100, show_points=False)

psi_manual(x_0, model_res)

In [None]:
# Reload the plotting module (to pick up edits) and call its plot_weightmatrix helper on model_res.
importlib.reload(plots.plots) # Reload the module
from plots.plots import plot_weightmatrix
plot_weightmatrix(model_res)

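# Singular-value computations of the Jacobian and plotting
We want to locate the singular points of the network map within the compact plotting domain, i.e. the points where the smallest singular value of the Jacobian vanishes.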

In [None]:
# Define a grid over the input space.
grid_size = 50  # Adjust as needed.

importlib.reload(plots.plots)  # Reload the module
from plots.plots import psi_manual, model_to_func, model_to_func_incl_output_layer, sv_plot

# Put the model in evaluation mode.
model = model_res
model.eval()
to_layer = num_hidden  # previously hard-coded to 5, i.e. num_hidden for this configuration
plot_levels = 20
sv_range = [-2.5, 2.5]  # input window shared by both plots

func_incl_output_layer = model_to_func_incl_output_layer(model, from_layer = 0, to_layer = to_layer)
# NOTE(review): the second plot is titled 'no output layer' but `output_layer=True` is passed here
# (an earlier variant of this line used output_layer=False) — confirm which one is intended.
func = model_to_func(model, from_layer = 0, to_layer = to_layer, output_layer = True)

sv_plot(func_incl_output_layer, v_index = 0, x_range = sv_range, y_range = sv_range, vmax = 1, title = 'incl output layer', grid_size = grid_size, plot_levels = plot_levels)
sv_plot(func, v_index = 0, x_range = sv_range, y_range = sv_range, vmax = 1, title = 'no output layer', grid_size = grid_size, plot_levels = plot_levels)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Put the model in evaluation mode.
model.eval()

num_hidden = model.num_hidden

plot_levels = 15
vmax = None
grid_size = 50
sv_range = [-3, 3]  # same input window for both rows so the Min/Max SV panels are comparable

scaler = 0.9
amount_layers = num_hidden + 1

# squeeze=False keeps `axes` 2-D even with a single column, so axes[row, col] is always valid.
fig, axes = plt.subplots(2, amount_layers, figsize=(scaler * 5 * (amount_layers), scaler * 10), dpi=100, squeeze=False)

for layer in range(amount_layers):
    func = model_to_func_incl_output_layer(model, from_layer = 0, to_layer = layer)

    # Row 0: v_index 0 ('Min SV'); row 1: v_index 1 ('Max SV').
    for row, (v_index, label) in enumerate([(0, 'Min SV'), (1, 'Max SV')]):
        ax = axes[row, layer]
        cs = sv_plot(func, v_index = v_index, ax = ax, grid_size=grid_size, plot_levels = plot_levels,
                     vmax = vmax, x_range=sv_range, y_range=sv_range)
        fig.colorbar(cs, ax=ax)
        ax.set_title(f'{label}\n layer_in = 0, layer_out = {layer}')
        ax.set_xlabel('x1')
        ax.set_ylabel('x2')
        ax.set_aspect('equal')

plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Put the model in evaluation mode.
model.eval()

num_hidden = model.num_hidden

plot_levels = 30
grid_size = 100
amount_layers = num_hidden + 1

scaler = 1

# squeeze=False keeps `axes` 2-D even with a single column, so axes[row, col] is always valid.
fig, axes = plt.subplots(2, amount_layers, figsize=(scaler * 5 * (amount_layers), scaler * 10), dpi=100, squeeze=False)

# Columns run from the deepest block back to block 0; column i shows the map from block
# `layer` to the output, plotted around the point where the input origin lands at that block.
for i, layer in enumerate(range(amount_layers - 1, -1, -1)):
    func = model_to_func_incl_output_layer(model, from_layer = layer, to_layer = amount_layers)
    func_range = model_to_func_incl_output_layer(model, from_layer = 0, to_layer = layer, output_layer=False)

    output = func_range(torch.tensor([0.0, 0.0]))
    print(output)
    # Window of half-width 1 centred on the image of the origin.
    x_center, y_center = output[0].item(), output[1].item()
    x_range = [x_center - 1, x_center + 1]
    y_range = [y_center - 1, y_center + 1]
    print(f'{x_range = }')
    print(f'{y_range = }')

    # Row 0: v_index 0 ('Min SV'); row 1: v_index 1 ('Max SV'), both on the same window.
    for row, (v_index, label) in enumerate([(0, 'Min SV'), (1, 'Max SV')]):
        ax = axes[row, i]
        sv_plot(func, v_index = v_index, ax = ax, grid_size=grid_size, plot_levels = plot_levels,
                x_range=x_range, y_range=y_range)
        ax.set_title(f'{label}\n layer_in = {layer}, layer_out = {amount_layers}')
        ax.set_xlabel('x1')
        ax.set_ylabel('x2')
        ax.set_aspect('equal')

plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Put the model in evaluation mode.
model.eval()

num_hidden = model.num_hidden
n_cols = num_hidden + 2

# squeeze=False keeps `axes` 2-D, so axes[row, col] is valid for any column count.
fig, axes = plt.subplots(2, n_cols, figsize=(5 * n_cols, 10), squeeze=False)  # Adjust figsize if needed

for layer in range(n_cols):
    func = model_to_func(model, from_layer=0, to_layer = layer)

    # Row 0: v_index 0 ('Min SV'); row 1: v_index 1 ('Max SV').
    for row, (v_index, label) in enumerate([(0, 'Min SV'), (1, 'Max SV')]):
        ax = axes[row, layer]
        sv_plot(func, v_index = v_index, ax = ax, grid_size=100)
        ax.set_title(f'{label}\n layer_in = 0, layer_out = {layer}')
        ax.set_xlabel('x1')
        ax.set_ylabel('x2')
        ax.set_aspect('equal')

plt.tight_layout()
plt.show()


In [None]:
# Put the model in evaluation mode.
model.eval()

n_cols = model.num_hidden + 1
# squeeze=False keeps `axes` 2-D, so axes[row, col] is valid for any column count.
fig, axes = plt.subplots(2, n_cols, figsize=(5 * n_cols, 10), squeeze=False)  # Adjust figsize if needed

# NOTE(review): these panels are titled 'Max EV mod' (v_index 0) / 'Min EV mod' (v_index 1)
# although sv_plot is called with its default output type, while the SV cells above label
# v_index 0 as 'Min SV'. Confirm what sv_plot returns here and align the titles.
for layer in range(n_cols):
    func = model_to_func(model, from_layer=layer, to_layer = layer)

    for row, (v_index, label) in enumerate([(0, 'Max EV mod'), (1, 'Min EV mod')]):
        ax = axes[row, layer]
        cs = sv_plot(func, v_index=v_index, ax = ax, grid_size=100)
        fig.colorbar(cs, ax=ax)
        ax.set_title(f'{label}\n layer {layer}')
        ax.set_xlabel('x1')
        ax.set_ylabel('x2')
        ax.set_aspect('equal')

plt.tight_layout()

plt.savefig('SV_each_layer.png', dpi=600, bbox_inches='tight', facecolor='white')
plt.show()


In [None]:
import plots.plots; importlib.reload(plots.plots)
from plots.plots import sv_plot

# NOTE(review): `models_with_skipparams` (presumably a dict skip_param -> trained model) is not
# defined anywhere in this notebook, so this cell raises NameError on Restart & Run All.
# Build it (e.g. with a sweep over skip_param) before running this cell.
skip = 0.0
model = models_with_skipparams[skip]
# Put the model in evaluation mode.
model.eval()

# One column per block of *this* model, not the global num_hidden of model_res.
n_cols = model.num_hidden + 1
# squeeze=False keeps `axes` 2-D, so axes[row, col] is valid for any column count.
fig, axes = plt.subplots(2, n_cols, figsize=(5 * n_cols, 10), squeeze=False)  # Adjust figsize if needed

for layer in range(n_cols):
    func = model_to_func(model, from_layer=layer, to_layer = layer)

    # Row 0: v_index 0 ('Max EV'); row 1: v_index 1 ('Min EV') of the 'eigmods' output.
    for row, (v_index, label) in enumerate([(0, 'Max EV'), (1, 'Min EV')]):
        ax = axes[row, layer]
        cs = sv_plot(func, v_index=v_index, ax = ax, grid_size=100, output_type='eigmods')
        fig.colorbar(cs, ax=ax)
        ax.set_title(f'{label}\n layer {layer}')
        ax.set_xlabel('x1')
        ax.set_ylabel('x2')
        ax.set_aspect('equal')

plt.tight_layout()

plt.savefig('EV_each_layer' + str(skip) + '.png', dpi=600, bbox_inches='tight', facecolor='white')
plt.show()

In [None]:
import plots.plots; importlib.reload(plots.plots)
from plots.plots import sv_plot

# NOTE(review): `models_with_skipparams` (presumably a dict skip_param -> trained model) is not
# defined anywhere in this notebook, so this cell raises NameError on Restart & Run All.
# Build it (e.g. with a sweep over skip_param) before running this cell.
skip = 1.0
model = models_with_skipparams[skip]
# Put the model in evaluation mode.
model.eval()

# One column per block of *this* model, not the global num_hidden of model_res.
n_cols = model.num_hidden + 1
# squeeze=False keeps `axes` 2-D, so axes[row, col] is valid for any column count.
fig, axes = plt.subplots(2, n_cols, figsize=(5 * n_cols, 10), squeeze=False)  # Adjust figsize if needed

for layer in range(n_cols):
    func = model_to_func(model, from_layer=layer, to_layer = layer)

    # Row 0: v_index 0 ('Max EV'); row 1: v_index 1 ('Min EV') of the 'eigmods' output.
    for row, (v_index, label) in enumerate([(0, 'Max EV'), (1, 'Min EV')]):
        ax = axes[row, layer]
        cs = sv_plot(func, v_index=v_index, ax = ax, grid_size=100, output_type='eigmods')
        fig.colorbar(cs, ax=ax)
        ax.set_title(f'{label}\n layer {layer}')
        ax.set_xlabel('x1')
        ax.set_ylabel('x2')
        ax.set_aspect('equal')

plt.tight_layout()

plt.savefig('EV_each_layer' + str(skip) + '.png', dpi=600, bbox_inches='tight', facecolor='white')
plt.show()

In [None]:
import plots.plots; importlib.reload(plots.plots)
from plots.plots import sv_plot

# NOTE(review): `models_with_skipparams` (presumably a dict skip_param -> trained model) is not
# defined anywhere in this notebook, so this cell raises NameError on Restart & Run All.
# Build it (e.g. with a sweep over skip_param) before running this cell.
skip = 2.0
model = models_with_skipparams[skip]
# Put the model in evaluation mode.
model.eval()

# One column per block of *this* model, not the global num_hidden of model_res.
n_cols = model.num_hidden + 1
# squeeze=False keeps `axes` 2-D, so axes[row, col] is valid for any column count.
fig, axes = plt.subplots(2, n_cols, figsize=(5 * n_cols, 10), squeeze=False)  # Adjust figsize if needed

for layer in range(n_cols):
    func = model_to_func(model, from_layer=layer, to_layer = layer)

    # Row 0: v_index 0 ('Max EV'); row 1: v_index 1 ('Min EV') of the 'eigmods' output.
    for row, (v_index, label) in enumerate([(0, 'Max EV'), (1, 'Min EV')]):
        ax = axes[row, layer]
        cs = sv_plot(func, v_index=v_index, ax = ax, grid_size=100, output_type='eigmods')
        fig.colorbar(cs, ax=ax)
        ax.set_title(f'{label}\n layer {layer}')
        ax.set_xlabel('x1')
        ax.set_ylabel('x2')
        ax.set_aspect('equal')

plt.tight_layout()

plt.savefig('EV_each_layer' + str(skip) + '.png', dpi=600, bbox_inches='tight', facecolor='white')
plt.show()