# 1d example of embedding restriction

We show that a non-augmented model mapping a 1d input to a 1d output cannot approximate the function $x^2$.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
from matplotlib.colors import to_rgb
from matplotlib.colors import LinearSegmentedColormap
from sklearn.model_selection import train_test_split



import torch
import matplotlib.pyplot as plt

import os

# Dataset / loader parameters shared by the cells below.
n_samples = 300   # number of (x, x^2) training pairs to draw
batch_size = 10   # mini-batch size; also the default of make_x2in1d

# Directory that receives the figures and GIFs produced by this notebook.
subfolder = '1dexample'
# exist_ok=True makes re-running the cell safe and avoids the
# check-then-create race of os.path.exists() followed by os.makedirs().
os.makedirs(subfolder, exist_ok=True)

# # Sample x from uniform distribution in [-1, 1]
# x = 2 * torch.rand(n_samples, 1) - 1  # shape (n_samples, 1), range [-1, 1]

# # Labels: y = x^2
# y = x ** 2  # shape (n_samples, 1)

# # Plot for verification
# plt.scatter(x.numpy(), y.numpy(), alpha=0.6)
# plt.xlabel("x")
# plt.ylabel("y = x^2")
# plt.title("Sampled Data from [-1, 1] with Labels x^2")
# plt.grid(True)
# plt.show()

def make_x2in1d(output_dim, n_samples = 100, plot = True, batch_size = batch_size, filename = None, seed = None):
    """Build a shuffled DataLoader of pairs (x, x^2) with x uniform in [-1, 1).

    Args:
        output_dim: Unused; kept so existing positional calls keep working.
        n_samples: Number of points to draw.
        plot: If True, show a scatter plot of the sampled data.
        batch_size: Mini-batch size of the returned DataLoader. The default is
            the module-level ``batch_size`` captured when the function is
            defined.
        filename: If given (and ``plot`` is True), the scatter plot is also
            written to ``f'{filename}.png'``.
        seed: Seed for NumPy and PyTorch. If None, a seed in [0, 1000) is
            drawn at random. The seed is printed either way so that a run can
            be reproduced by passing it back in.

    Returns:
        A ``DataLoader`` over a ``TensorDataset`` of float32 tensors
        ``(data, labels)``, each of shape ``(n_samples, 1)``.
    """
    if seed is None:
        seed = np.random.randint(1000)
    print(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # torch.rand samples from [0, 1); rescale to [-1, 1).
    data = (2 * torch.rand(n_samples, 1) - 1).to(torch.float32)
    labels = (data ** 2).reshape(-1, 1)

    if plot:
        plt.scatter(data.numpy(), labels.numpy(), alpha=0.6)
        plt.xlabel("x")
        plt.ylabel("y = x^2")
        plt.title("Sampled Data from [-1, 1] with Labels $x^2$")
        plt.grid(True)

        # Save BEFORE showing: plt.show() may release the current figure, in
        # which case a later savefig() would write an empty image.
        if filename is not None:
            plt.savefig(f'{filename}.png', bbox_inches='tight', dpi=300)
            print(f'Plot saved as {filename}.png')

        plt.show()

    # The tensors are used directly; wrapping them in torch.tensor(...) again
    # only copies them and emits a copy-construct UserWarning.
    train_dataset = TensorDataset(data, labels)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    return train_dataloader

# Build the training DataLoader once. It is a notebook-level global that
# subplots_model and generate_gif below read directly.
train_loader = make_x2in1d(1, n_samples=n_samples)

## Plot function

In [None]:
import plots.plots
from plots.plots import psi_manual

def plot_params(model, ax=None):
    """Scatter-plot the weights and biases of every ``nn.Linear`` in ``model``.

    Each layer occupies one x position (its index among the Linear layers);
    all entries of its weight matrix and bias vector are drawn at that x.

    Args:
        model: Module whose ``modules()`` are searched for ``nn.Linear``.
        ax: Axes-like object to draw on. If None, a new small figure is
            created, laid out and shown by this function.

    Returns:
        None.
    """
    linear_layers = [m for m in model.modules() if isinstance(m, nn.Linear)]
    n_layers = len(linear_layers)

    # Remember who owns the figure: `ax` is rebound just below, so testing
    # `ax is None` again at the end would always be False.
    owns_figure = ax is None
    if owns_figure:
        # pyplot is only needed when this function creates the figure itself.
        import matplotlib.pyplot as plt
        _, ax = plt.subplots(figsize=(1, 3))

    if n_layers == 0:
        ax.set_title("No Linear layers found")
        if owns_figure:
            plt.show()
        return

    # Flatten every parameter so layers with more than one weight (or without
    # a bias) are handled as well as the scalar 1d case.
    weight_x, weight_y = [], []
    bias_x, bias_y = [], []
    for idx, layer in enumerate(linear_layers):
        w = layer.weight.detach().cpu().numpy().ravel()
        weight_x.extend([idx] * w.size)
        weight_y.extend(w.tolist())
        if layer.bias is not None:
            b = layer.bias.detach().cpu().numpy().ravel()
            bias_x.extend([idx] * b.size)
            bias_y.extend(b.tolist())

    ax.grid(True, zorder=0)
    ax.scatter(bias_x, bias_y, label='Biases',
               zorder=2, marker='^', color='grey', alpha=0.7)
    ax.scatter(weight_x, weight_y, color='green', label='Weights',
               zorder=2, marker='s', alpha=0.7)
    ax.set_xlabel('Layer Index')
    ax.set_ylabel('Value')
    ax.set_title('Weights and Biases')
    ax.set_ylim(-5, 5)  # fixed range; values outside it are not visible
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8, bbox_transform=ax.transAxes)
    ax.set_xticks(list(range(n_layers)))

    if owns_figure:
        plt.tight_layout()
        plt.show()
        
def plot_eigvals(model, ax = None):
    """Plot one curve per layer of ``psi_manual`` values over the interval [-1, 1].

    For every layer index in ``range(model.num_hidden + 1)`` the single-layer
    map ``model.sub_model_new(x, from_layer=layer, to_layer=layer)`` is handed
    to ``psi_manual(..., output_type='eigvals_1d')`` at 100 evenly spaced
    inputs. The exact meaning of the returned value is defined in
    ``plots.plots`` (presumably a per-point eigenvalue of the layer map;
    confirm there).

    Args:
        model: Model exposing ``num_hidden`` and ``sub_model_new``. It is
            switched to eval mode and left that way.
        ax: Axes to draw on. If None, the current pyplot axes are used.
    """
    model.eval()
    n_points = 100

    for layer in range(model.num_hidden + 1):
        def layer_map(x):
            # Restrict the model to exactly this one layer.
            return model.sub_model_new(x, from_layer=layer, to_layer=layer)

        grid = torch.linspace(-1, 1, n_points)
        psi_values = np.zeros(n_points)
        for idx, scalar in enumerate(grid):
            # psi_manual is given a (1, 1) tensor: a batch with one 1d input.
            point = scalar.reshape(1, 1)
            psi_values[idx] = psi_manual(point, layer_map, output_type='eigvals_1d')

        curve_label = f'Layer {layer}'
        if ax is not None:
            ax.plot(grid.numpy(), psi_values, label = curve_label)
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8, bbox_transform=ax.transAxes)
            ax.grid(True)
            ax.set_title('Eigenvalues')
            ax.set_xlabel('Input')
            ax.set_box_aspect(1)
        else:
            plt.plot(grid.numpy(), psi_values, label = curve_label)
            plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
            plt.grid(True)
            plt.title('Eigenvalues')
            plt.xlabel('Input')

    
def subplots_model(model, epoch= None):
    """Create a three-panel diagnostic figure for ``model``.

    Left panel: model prediction on a 100-point grid over [-1, 1] together
    with the training data taken from the global ``train_loader``.
    Middle panel: output of ``plot_params``. Right panel: output of
    ``plot_eigvals``. The figure is left open as the current pyplot figure;
    nothing is returned.

    Args:
        model: Callable model that also exposes ``skip_param`` and
            ``sara_param`` (both are printed in the left panel's title).
        epoch: Value shown in the legend entry of the prediction curve.
    """
    # Evaluate the model on an evenly spaced grid so the curve looks smooth.
    grid = torch.linspace(-1, 1, 100).unsqueeze(1)
    prediction = model(grid).detach().numpy()

    fig, axes = plt.subplots(1, 3, figsize=(12, 4), gridspec_kw={'width_ratios': [6, 1, 6]})
    plt.subplots_adjust(right=0.8)
    fit_ax, params_ax, eig_ax = axes

    # Left panel: fitted curve on top of the training samples.
    train_tensors = train_loader.dataset.tensors
    fit_ax.plot(grid.numpy(), prediction, label=f'Epoch {epoch}', color='C1')
    fit_ax.scatter(train_tensors[0].numpy(),
                   train_tensors[1].numpy(),
                   label='Training Data', color='C0', alpha=0.5, s=10)
    fit_ax.set_xlabel('Input')
    fit_ax.set_ylabel('Output')
    fit_ax.set_title(f'eps={model.skip_param}, delta={model.sara_param}')
    fit_ax.legend()
    fit_ax.grid(True)
    fit_ax.set_box_aspect(1)

    # Middle and right panels are delegated to the helper plotters.
    plot_params(model, ax=params_ax)
    plot_eigvals(model, ax=eig_ax)

    # Final margins and spacing; leaves room on the right for the legends.
    plt.subplots_adjust(left=0.05, right=0.90, wspace=0.3)
    
from models.training import train_model
import io
import imageio
    
def generate_gif(model, gif_name = 'last.gif', epochs_per_frame = 10, num_frames = 30, fps = 5, loop = 'yes', lr = 0.01):
    """Train ``model`` in rounds and save one diagnostic frame per round as a GIF.

    Each round calls ``train_model`` for ``epochs_per_frame`` epochs on the
    global ``train_loader``, renders the figure produced by
    ``subplots_model`` into memory and appends it to the animation.

    Args:
        model: Model to train; it is trained in place across rounds.
        gif_name: Output path of the GIF. The parent directory must exist.
        epochs_per_frame: Training epochs between two consecutive frames.
        num_frames: Number of training rounds, i.e. frames in the GIF.
        fps: Frames per second of the GIF (5 fps is about 200 ms per frame).
        loop: ``'no'`` writes a GIF that plays once; any other value writes
            one that repeats forever.
        lr: Learning rate forwarded to ``train_model``.

    Returns:
        None. The GIF is written to ``gif_name`` and a message is printed.
    """
    frames = []  # one RGB(A) array per rendered frame

    for i in range(num_frames):
        # train_loader is passed twice: the second slot is presumably the
        # evaluation loader of train_model (confirm in models.training).
        # Accuracy and loss history are not needed here.
        model, _, _ = train_model(
            model,
            train_loader,
            train_loader,
            load_file=None,
            epochs=epochs_per_frame,
            lr=lr,
            early_stopping=False,
            cross_entropy=False,
            seed=None
        )

        # The frame is drawn AFTER round i has finished, so (i + 1) rounds of
        # epochs_per_frame epochs have been completed at this point.
        subplots_model(model, epoch = (i + 1) * epochs_per_frame)

        fig = plt.gcf()  # figure that subplots_model just created
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=200)  # render into memory
            buf.seek(0)
            frames.append(imageio.imread(buf))
        plt.close(fig)  # free the figure; otherwise figures pile up

    # imageio's GIF writer uses loop=0 for "repeat forever" and a positive
    # count for a finite number of plays, so 'no' maps to 1 and everything
    # else to 0 (the previous mapping was the wrong way round).
    loop_num = 1 if loop == 'no' else 0
    imageio.mimsave(gif_name, frames, fps=fps, loop = loop_num, subrectangles=False)
    print(f"Saved → {gif_name}")
    

# One layer experiments

In [None]:
# to reload models.resnet module after changes without restarting the kernel
import importlib
import models.resnets
import models.training
importlib.reload(models.resnets) # Reload the module
importlib.reload(models.training) # Reload the module
from models.resnets import ResNet
from models.training import compute_accuracy, train_model, train_until_threshold, plot_loss_curve

# Model Params
# These are notebook-level globals: the ResNet(...) constructions in the cells
# below read them, and several later cells reassign hidden_dim / num_hidden.
input_dim = 1
hidden_dim = 1
num_hidden = 1 # number of hidden layers. The total network has additional 2 layers: input to hidden and hidden to output
output_dim = 1
activation = 'tanh' #'relu' and 'tanh' are supported
input_layer = False #as simple as possible, no input layer
final_sigmoid = False # True supported with binary classification only

# Training Params
load_file = None
cross_entropy = False #True
num_epochs = 100

In [None]:
# Baseline: one hidden layer of width 1 with skip_param=0, trained for
# num_epochs epochs on the x^2 data.
model_base = ResNet(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim, num_hidden=num_hidden, activation=activation, skip_param=0, final_sigmoid=final_sigmoid, input_layer=input_layer)

# Use the training parameters defined above instead of repeating literals, so
# that editing them in one place actually affects this run.
model_base, acc_base, losses_base = train_model(
    model_base, train_loader, train_loader,
    load_file=load_file, epochs=num_epochs, lr=0.01,
    early_stopping=False, cross_entropy=cross_entropy, seed=None)

plot_loss_curve(losses_base, title="Base Model Loss Curve", filename = 'ff1d')

# Evaluate on an evenly spaced grid. The names avoid shadowing the builtin
# `input`, which would otherwise stay broken for the rest of the session.
x_grid = torch.linspace(-1, 1, 100).unsqueeze(1)  # shape (100, 1), range [-1, 1]
y_pred = model_base(x_grid).detach().numpy()

plt.plot(x_grid.numpy(), y_pred, label='Model Output', color='C1')
plt.scatter(train_loader.dataset.tensors[0].numpy(), train_loader.dataset.tensors[1].numpy(), label='Training Data', color='C0', alpha=0.5, s=10)
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('Model Output vs Training Data')
plt.legend()

In [None]:
import io
import imageio      # v2 API keeps .mimsave unchanged
import matplotlib.pyplot as plt
import torch

# -----------------------------  CONFIG  ---------------------------------
gif_name = subfolder + "/skip0_training.gif"
fps = 5                    # frames per second of the GIF (≈ 200 ms per frame)
frames = []                # not read in this cell: generate_gif collects its own frames
# ------------------------------------------------------------------------

hidden_dim   = 1
num_hidden   = 1
skip_param   = 0
epo   = 10                 # training epochs per GIF frame
num_frames = 20

model_skip0 = ResNet(input_dim=input_dim,
                     hidden_dim=hidden_dim,
                     output_dim=output_dim,
                     num_hidden=num_hidden,
                     activation=activation,
                     skip_param=skip_param, final_sigmoid=final_sigmoid, input_layer=input_layer)

# Forward the config values explicitly; before, fps and epo were defined but
# silently ignored in favour of generate_gif's defaults.
generate_gif(model_skip0, gif_name = gif_name, epochs_per_frame = epo,
             num_frames = num_frames, fps = fps)

In [None]:
from IPython.display import Image, display

# The GIF is written to subfolder + "/skip0_training.gif" by the cell above,
# so it has to be read from the same location (not from the working directory).
display(Image(filename=subfolder + "/skip0_training.gif", width=800))

In [None]:
import io
import imageio      # v2 API keeps .mimsave unchanged
import matplotlib.pyplot as plt
import torch

# -----------------------------  CONFIG  ---------------------------------
gif_name = subfolder + "/skip1_training.gif"
fps = 5                    # frames per second of the GIF (≈ 200 ms per frame)
frames = []                # not read in this cell: generate_gif collects its own frames
# ------------------------------------------------------------------------

hidden_dim   = 1
num_hidden   = 1
skip_param   = 1

model_skip1 = ResNet(input_dim=input_dim,
                     hidden_dim=hidden_dim,
                     output_dim=output_dim,
                     num_hidden=num_hidden,
                     activation=activation,
                     skip_param=skip_param, final_sigmoid=final_sigmoid, input_layer=input_layer)

# Pass fps explicitly so the CONFIG value above is the one that takes effect.
generate_gif(model_skip1, gif_name=gif_name, fps=fps)

In [None]:
# Show the GIF written by the previous cell. Image and display come from the
# IPython.display import in an earlier cell.
display(Image(filename=subfolder + "/skip1_training.gif", width=600))

In [None]:
skip_param = 1
hidden_dim   = 1
num_hidden   = 1

model_test = ResNet(input_dim=input_dim,
                     hidden_dim=hidden_dim,
                     output_dim=output_dim,
                     num_hidden=num_hidden,
                     activation=activation,
                     skip_param=skip_param, final_sigmoid=final_sigmoid, batchnorm=False, input_layer=input_layer)

generate_gif(model_test, gif_name = subfolder + 'test.gif', num_frames = 50)

In [None]:
display(Image(filename=subfolder + "test.gif", width=900))

In [None]:
def plot_weight_heatmaps(model, title=''):
    """Show one heatmap per ``nn.Linear`` layer of ``model``.

    ``nn.Linear.weight`` has shape ``(out_features, in_features)``, so image
    rows (y axis) index output units and columns (x axis) index input units.
    The colour scale is fixed to [-5, 5] for every layer so that panels are
    comparable and negative weights remain distinguishable.

    Args:
        model: Module whose ``modules()`` are searched for ``nn.Linear``.
        title: Text appended to the figure title.

    Returns:
        None. The figure is shown with ``plt.show()``; if the model has no
        Linear layer a message is printed instead.
    """
    linear_layers = [module for module in model.modules() if isinstance(module, nn.Linear)]

    n_layers = len(linear_layers)
    if n_layers == 0:
        print("No Linear layers with parameters found.")
        return

    # squeeze=False keeps `axes` 2d even when there is a single layer.
    fig, axes = plt.subplots(1, n_layers, figsize=(3 * n_layers, 3), squeeze=False)

    for i, (ax, layer) in enumerate(zip(axes[0], linear_layers)):
        weight = layer.weight.detach().cpu().numpy()
        # Symmetric limits: with vmin=0 every negative weight was clipped to
        # the same colour as 0.
        im = ax.imshow(weight, cmap='viridis', vmin=-5, vmax=5, aspect='equal')

        ax.set_title(f"Layer {i}")
        # Columns are inputs and rows are outputs (labels were swapped before).
        ax.set_xlabel("In")
        ax.set_ylabel("Out")

        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    fig.suptitle("Weight Matrices Heatmaps - " + title)
    # Leave some room at the top for the suptitle.
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()


# Two Layers

In [None]:
import models.resnets; importlib.reload(models.resnets)
from models.resnets import ResNet

# Two hidden layers of width 1 with skip_param = 0; sara_param is not passed,
# so ResNet uses whatever default it defines.
skip_param = 0
hidden_dim = 1
num_hidden = 2

model = ResNet(input_dim=input_dim,
                     hidden_dim=hidden_dim,
                     output_dim=output_dim,
                     num_hidden=num_hidden,
                     activation=activation,
                     skip_param=skip_param, final_sigmoid=final_sigmoid, batchnorm=False, input_layer=input_layer)

# Train for 50 frames (generate_gif's default of 10 epochs per frame), show
# the animation, then plot the per-layer curves of the trained model.
generate_gif(model, gif_name = subfolder + '/test2.gif', num_frames = 50)
display(Image(filename=subfolder + "/test2.gif", width=900))
plot_eigvals(model)

In [None]:
import models.resnets; importlib.reload(models.resnets)
from models.resnets import ResNet

# Two hidden layers of width 1 with eps (skip_param) = 1 and
# delta (sara_param) = 1, as encoded in the output file name.
skip_param = 1
sara_param = 1
hidden_dim = 1
num_hidden = 2

# sara_param must be passed explicitly: it was assigned above but never
# forwarded, so the model silently used ResNet's default instead of the value
# advertised in the GIF name.
model = ResNet(input_dim=input_dim,
                     hidden_dim=hidden_dim,
                     output_dim=output_dim,
                     num_hidden=num_hidden,
                     activation=activation,
                     skip_param=skip_param, sara_param=sara_param, final_sigmoid=final_sigmoid, batchnorm=False, input_layer=input_layer)

generate_gif(model, gif_name = subfolder + '/1d2l_eps1delta1.gif', num_frames = 50)
display(Image(filename=subfolder + '/1d2l_eps1delta1.gif', width=900))
plot_eigvals(model)
    

In [None]:
import models.resnets; importlib.reload(models.resnets)
from models.resnets import ResNet

# Repeated runs with two hidden layers. Both lists contain only the value 1,
# so the nine iterations build identically configured models; the runs can
# differ only through whatever randomness ResNet and train_model introduce.
# NOTE(review): later cells also reset counter to 0 and write
# subfolder + '/running<k>.gif', so their GIFs overwrite the ones saved here.
counter = 0

for skip in [1, 1, 1]:
    for sara in [1, 1, 1]:
        skip_param = skip
        sara_param = sara
        hidden_dim = 1
        num_hidden = 2

        model_running = ResNet(input_dim=input_dim,
                            hidden_dim=hidden_dim,
                            output_dim=output_dim,
                            num_hidden=num_hidden,
                            activation=activation,
                            skip_param=skip_param, sara_param=sara_param, final_sigmoid=final_sigmoid, batchnorm=False, input_layer=input_layer)
        # One GIF per run, numbered in execution order.
        gif_name = subfolder + '/running' + str(counter) + '.gif'
        generate_gif(model_running, gif_name = gif_name, fps = 2, num_frames = 50)
        counter += 1
        display(Image(filename=gif_name, width=900))
    

In [None]:
import models.resnets; importlib.reload(models.resnets)
from models.resnets import ResNet

# Repeated runs with three hidden layers. Both lists contain only the value 1,
# so the four iterations build identically configured models.
# NOTE(review): counter restarts at 0, so these files reuse the names
# subfolder + '/running0.gif' ... '/running3.gif' written by other cells.
counter = 0

for skip in [1, 1]:
    for sara in [1, 1]:
        skip_param = skip
        sara_param = sara
        hidden_dim = 1
        num_hidden = 3

        model_running = ResNet(input_dim=input_dim,
                            hidden_dim=hidden_dim,
                            output_dim=output_dim,
                            num_hidden=num_hidden,
                            activation=activation,
                            skip_param=skip_param, sara_param=sara_param, final_sigmoid=final_sigmoid, batchnorm=False, input_layer=input_layer)
        # One GIF per run, numbered in execution order.
        gif_name = subfolder + '/running' + str(counter) + '.gif'
        generate_gif(model_running, gif_name = gif_name, fps = 2, num_frames = 50)
        counter += 1
        display(Image(filename=gif_name, width=900))
    

In [None]:
import models.resnets; importlib.reload(models.resnets)
from models.resnets import ResNet, ResidualBlock

# Single run (skip = 1, sara = 0.1) with two hidden layers whose residual
# block weights are set by hand before training starts.
# NOTE(review): counter restarts at 0, so the GIF is written to
# subfolder + '/running0.gif', a name also used by earlier cells.
counter = 0

for skip in [1]:
    for sara in [0.1]:
        skip_param = skip
        sara_param = sara
        hidden_dim = 1
        num_hidden = 2

        model_running = ResNet(input_dim=input_dim,
                            hidden_dim=hidden_dim,
                            output_dim=output_dim,
                            num_hidden=num_hidden,
                            activation=activation,
                            skip_param=skip_param, sara_param=sara_param, final_sigmoid=final_sigmoid, batchnorm=False, input_layer=input_layer)
        
        # No seed is set here, so any remaining random initialisation differs between runs.

        # Find all ResidualBlocks inside the model, in modules() order.
        blocks = [m for m in model_running.modules() if isinstance(m, ResidualBlock)]


        # In-place edits of parameters must happen outside autograd tracking.
        with torch.no_grad():
            # Alternative initialisation, currently disabled:
            # for block in blocks:
            #     nn.init.xavier_normal_(block.fc.weight, gain=1/sara_param)
        

            # Overwrite row 0 of each block's fc weight with a constant scaled
            # by 1/sara_param (with sara_param = 0.1 these are -20 and -5).
            blocks[0].fc.weight[0] = -2/sara_param
            blocks[1].fc.weight[0] = -0.5/sara_param
    
            
        gif_name = subfolder + '/running' + str(counter) + '.gif'
        
        
        
        # One epoch per frame so the early training dynamics stay visible.
        generate_gif(model_running, gif_name = gif_name, fps = 2, epochs_per_frame = 1, num_frames = 40, lr = 0.01)
        counter += 1
        display(Image(filename=gif_name, width=900))
    