 # ResNet Bias Manipulation Experiment



 This notebook trains a ResNet on a binary classification task (concentric circles).

 It performs the following specific steps:

 1. Trains the model for 100 epochs.

 2. Interrupts training to manually set the bias of one hidden layer to a fixed value.

 3. Continues training for another 100 epochs (with a freshly initialized optimizer) to observe whether the model recovers.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
from sklearn.datasets import make_circles
from matplotlib.colors import LinearSegmentedColormap, to_rgb
import copy
import os

# Configuration for high-resolution inline plots. These are IPython magics:
# they only work inside Jupyter/IPython, not in a plain Python interpreter.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# Device config: use the GPU when available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed setup. NOTE: the seed itself is drawn at random on every run (NumPy's
# global RNG has not been seeded at this point), so a run can only be
# reproduced by reusing the printed seed, e.g. by pinning it instead:
# seed = 107
seed = np.random.randint(1,200)  # integer in [1, 199] (upper bound is exclusive)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
print(f"Global Seed: {seed}")


 ## 1. Model Definition (ResNet)

In [None]:
class DiagonalLinear(nn.Module):
    """Element-wise ("diagonal") linear map with an optional frozen prefix of weights.

    The first ``dim = min(in_features, out_features)`` input features are each
    multiplied by their own scalar weight. The result is zero-padded up to
    ``out_features`` and a learnable bias is added.

    Args:
        in_features: size of the last input dimension.
        out_features: size of the last output dimension.
        fixed_w: optional sequence of leading weights. It is stored as the
            non-trainable buffer ``fixed`` (truncated to ``dim`` entries); the
            remaining weights live in the trainable ``weight_rest``, and
            ``weight`` is then ``None``. Without ``fixed_w`` all ``dim``
            weights are trainable in ``weight`` (initialised to one).
    """

    def __init__(self, in_features, out_features, fixed_w=None):
        super().__init__()
        self.dim = min(in_features, out_features)
        self.out_features = out_features
        if fixed_w is not None:
            frozen = torch.as_tensor(fixed_w, dtype=torch.float32)
            self.k_fixed = min(len(frozen), self.dim)
            self.register_buffer("fixed", frozen[:self.k_fixed])
            trainable_count = max(self.dim - self.k_fixed, 0)
            self.weight_rest = nn.Parameter(torch.ones(trainable_count))
            self.weight = None
        else:
            self.weight = nn.Parameter(torch.ones(self.dim))
            self.fixed = None
        self.bias = nn.Parameter(torch.zeros(out_features))

    def _weight_vec(self, x):
        """Return the full length-``dim`` weight vector, matching ``x``'s device/dtype for the frozen part."""
        if self.fixed is not None:
            frozen = self.fixed.to(x.device, x.dtype)
            return torch.cat((frozen, self.weight_rest), dim=0)
        return self.weight

    def forward(self, x):
        scaled = x[..., :self.dim] * self._weight_vec(x)
        missing = self.out_features - scaled.shape[-1]
        if missing > 0:
            # Output wider than input: the extra features are zeros (plus bias).
            scaled = F.pad(scaled, (0, missing))
        return scaled + self.bias

class ResidualBlock(nn.Module):
    """Fully connected residual block: ``sara_param * act(bn(fc(x))) + skip_param * x``.

    Args:
        features: block width; input and output both have this size.
        skip_param: scale of the identity (skip) path.
        sara_param: scale of the transformed (residual) path.
        activation: ``'relu'``, ``'tanh'`` or ``'id'`` (identity).
        batchnorm: if True, apply ``BatchNorm1d`` after the linear layer.
        gain: gain for the Xavier-normal initialisation of ``fc.weight``.
        bias: whether ``fc`` has a bias term (initialised to zero).

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """

    # Supported activation names -> module classes.
    _ACTIVATIONS = {'relu': nn.ReLU, 'tanh': nn.Tanh, 'id': nn.Identity}

    def __init__(self, features, skip_param=1, sara_param=1, activation='relu', batchnorm=True, gain=1.0, bias=True):
        super(ResidualBlock, self).__init__()
        # Fail fast: previously an unknown name left `self.activation` unset and
        # only surfaced later as an AttributeError inside forward().
        if activation not in self._ACTIVATIONS:
            raise ValueError(
                f"Unsupported activation {activation!r}; "
                f"expected one of {sorted(self._ACTIVATIONS)}"
            )

        self.fc = nn.Linear(features, features, bias=bias)
        nn.init.xavier_normal_(self.fc.weight, gain=gain)
        if self.fc.bias is not None:
            nn.init.zeros_(self.fc.bias)

        if batchnorm:
            self.bn = nn.BatchNorm1d(features)

        self.activation = self._ACTIVATIONS[activation]()

        self.skip_param = skip_param
        self.sara_param = sara_param

    def forward(self, x):
        identity = x
        out = self.fc(x)
        # `bn` only exists when the block was built with batchnorm=True.
        if hasattr(self, 'bn'):
            out = self.bn(out)
        out = self.activation(out)
        return self.sara_param * out + self.skip_param * identity

class ResNet(nn.Module):
    """Fully connected residual network.

    Layout: optional input layer (``nn.Linear`` or ``DiagonalLinear`` followed by
    the activation) -> ``num_hidden`` ``ResidualBlock`` s of width ``hidden_dim``
    -> output ``nn.Linear`` (optionally followed by ``nn.Sigmoid``).

    Args:
        input_dim, hidden_dim, output_dim: layer widths.
        num_hidden: number of residual blocks.
        skip_param, sara_param: scales of the skip / residual paths in each block.
        activation: ``'relu'``, ``'tanh'`` or ``'id'``.
        final_sigmoid: append ``nn.Sigmoid`` to the output layer (outputs in
            [0, 1], as required by ``nn.BCELoss``).
        batchnorm: use ``BatchNorm1d`` inside the residual blocks.
        input_layer: if False, inputs go straight into the residual blocks, so
            ``input_dim`` must equal ``hidden_dim``.
        input_layer_diagonal: use a ``DiagonalLinear`` input layer (only relevant
            when ``input_layer`` is True). Note: ``bias`` is not forwarded to it;
            it always has a bias.
        fixed_w: frozen leading weights for the ``DiagonalLinear`` input layer.
        bias: whether the ``nn.Linear`` layers have bias terms.

    Raises:
        ValueError: for an unsupported ``activation``, or when
            ``input_layer=False`` and ``input_dim != hidden_dim``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_hidden, skip_param=1, sara_param=1, 
                 activation='relu', final_sigmoid=True, batchnorm=True, input_layer=True, 
                 input_layer_diagonal=False, fixed_w=None, bias=True):
        super(ResNet, self).__init__()
        activations = {'relu': nn.ReLU, 'tanh': nn.Tanh, 'id': nn.Identity}
        # Validate before building any layer so that invalid configurations fail
        # at construction time instead of with an obscure error in forward().
        if activation not in activations:
            raise ValueError(
                f"Unsupported activation {activation!r}; expected one of {sorted(activations)}"
            )
        if not input_layer and input_dim != hidden_dim:
            raise ValueError(
                "input_layer=False feeds inputs directly into the residual blocks, "
                f"so input_dim must equal hidden_dim (got {input_dim} and {hidden_dim})"
            )

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_hidden = num_hidden
        self.input_layer_exists = input_layer
        self.activation_name = activation  # Store for saving

        self.activation = activations[activation]()

        if self.input_layer_exists:
            if input_layer_diagonal:
                self.input_fc = DiagonalLinear(self.input_dim, self.hidden_dim, fixed_w=fixed_w)
            else:
                self.input_fc = nn.Linear(input_dim, hidden_dim, bias=bias)
            # The same layer object is registered twice (as `input_fc` and as
            # `input_layer.0`), so the state_dict contains both key prefixes.
            self.input_layer = nn.Sequential(self.input_fc, self.activation)

        self.res_blocks = nn.Sequential(
            *[ResidualBlock(hidden_dim, skip_param=skip_param, sara_param=sara_param, 
                            activation=activation, batchnorm=batchnorm, bias=bias) 
              for _ in range(num_hidden)]
        )

        # Final layer
        if final_sigmoid:
            self.output_fc = nn.Sequential(nn.Linear(hidden_dim, output_dim, bias=bias), nn.Sigmoid())
        else:
            self.output_fc = nn.Linear(hidden_dim, output_dim, bias=bias)

    def forward(self, x):
        if self.input_layer_exists:
            x = self.input_layer(x)
        x = self.res_blocks(x)
        return self.output_fc(x)


 ## 2. Training Logic

 (Adapted from the provided training script. Loading a saved model via `load_file` is not implemented.)

In [None]:
def compute_accuracy(y_pred, y_true, type='class'):
    """Score predictions against targets.

    Args:
        y_pred: predictions. For ``'class'`` they are thresholded at 0.5, so
            they are expected to be probabilities.
        y_true: targets (0/1 labels for ``'class'``).
        type: ``'class'`` -> fraction of thresholded predictions equal to the
            targets (correct count divided by ``y_true.shape[0]``);
            ``'reg'`` -> ``1 - MSE`` clamped to [0, 1]. The name shadows the
            builtin but is kept for caller compatibility.

    Returns:
        The score as a Python float.

    Raises:
        ValueError: if ``type`` is neither ``'class'`` nor ``'reg'``
            (previously the function silently returned None).
    """
    if y_pred.shape != y_true.shape and y_pred.numel() == y_true.numel():
        # Same number of elements but different shapes, e.g. (N, 1) vs (N,):
        # without this, broadcasting would compare an (N, N) grid and
        # produce a meaningless score.
        y_pred = y_pred.reshape(y_true.shape)

    if type == 'class':
        y_pred_binary = (y_pred >= 0.5).int()
        correct = (y_pred_binary == y_true.int()).sum().item()
        return correct / y_true.shape[0]
    if type == 'reg':
        mse = torch.mean((y_pred - y_true) ** 2)
        return torch.clamp(1.0 - mse, min=0.0, max=1.0).item()
    raise ValueError(f"Unknown accuracy type {type!r}; expected 'class' or 'reg'")

def _evaluate(model, test_loader, cross_entropy):
    """Return the sample-weighted mean accuracy of ``model`` over ``test_loader`` (model left in train mode)."""
    acc_type = 'class' if cross_entropy else 'reg'
    acc_summed = 0.0
    n_seen = 0
    model.eval()
    with torch.no_grad():
        for X_test, y_test in test_loader:
            X_test, y_test = X_test.to(device), y_test.to(device)
            n_batch = y_test.shape[0]
            # Weight by batch size so a smaller final batch does not skew the mean.
            acc_summed += compute_accuracy(model(X_test), y_test, type=acc_type) * n_batch
            n_seen += n_batch
    model.train()
    if n_seen == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy")
    return acc_summed / n_seen


def train_model(model, train_loader, test_loader,
                load_file=None, epochs=300, lr=0.01, early_stopping=True, patience=300, cross_entropy=True, seed=None):
    """Train ``model`` with Adam and evaluate it on ``test_loader`` after every epoch.

    Args:
        model: network to train; it is moved to the global ``device``.
        train_loader, test_loader: iterables of ``(X, y)`` batches.
        load_file: loading is not implemented; must be None.
        epochs: maximum number of training epochs.
        lr: Adam learning rate. A new optimizer is created on every call.
        early_stopping: track the best evaluation accuracy, roll back to the
            best weights after each non-improving epoch, stop after
            ``patience`` consecutive non-improving epochs, and return the
            best weights.
        patience: see ``early_stopping``.
        cross_entropy: use ``nn.BCELoss`` with classification accuracy;
            otherwise ``nn.MSELoss`` with ``1 - MSE`` accuracy.
        seed: currently unused; kept for call compatibility.

    Returns:
        ``(model, best_acc, losses)``. ``best_acc`` is the best accuracy when
        early stopping is enabled, otherwise the accuracy after the last
        epoch. ``losses`` holds, per completed epoch, the mean of the
        per-batch training losses.

    Raises:
        NotImplementedError: if ``load_file`` is given (previously the
            function silently returned None, which crashed callers that
            unpack the result).
    """
    if load_file is not None:
        raise NotImplementedError("Loading a model from 'load_file' is not implemented.")

    model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # BCELoss expects probabilities in [0, 1], i.e. a ResNet built with
    # final_sigmoid=True. For raw logits use nn.BCEWithLogitsLoss instead.
    criterion = nn.BCELoss() if cross_entropy else nn.MSELoss()

    best_acc = 0
    # Snapshot the starting weights so a rollback is always possible, even if
    # the first epoch does not beat best_acc (previously a NameError).
    best_model_state = copy.deepcopy(model.state_dict()) if early_stopping else None
    patience_counter = 0
    losses = []

    for epoch in range(epochs):
        epoch_loss = 0
        for batch_X, batch_y in train_loader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)

            loss = criterion(model(batch_X), batch_y)
            epoch_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        losses.append(epoch_loss / len(train_loader))

        acc = _evaluate(model, test_loader, cross_entropy)

        if not early_stopping:
            best_acc = acc
            continue

        if acc > best_acc:
            best_acc = acc
            best_model_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"⏹️ Early stopping at epoch {epoch}, best acc: {best_acc:.3f}")
                break
        # NOTE(review): rolling back to the best weights after *every*
        # non-improving epoch (Adam's moment estimates are not rolled back) is
        # kept from the original script. Confirm this is intended rather than
        # only restoring the best weights once at the end.
        if patience_counter > 0:
            model.load_state_dict(best_model_state)

    if early_stopping:
        # The early-stopping `break` skips the rollback above; make sure the
        # returned model always matches the reported best_acc.
        model.load_state_dict(best_model_state)

    return model, best_acc, losses

def plot_loss_curve(losses, title="Training Loss", filename=None):
    """Plot a per-epoch loss curve and optionally save it as ``<filename>.png``.

    Args:
        losses: loss values, one per epoch (the x-axis is the list index).
        title: axes title.
        filename: path without extension. When given, the figure is also
            saved at 300 dpi on a white background before being shown.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(losses, label="Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title(title)
    ax.grid(True)
    ax.legend()
    fig.tight_layout()
    if filename is not None:
        fig.savefig(filename + '.png', dpi=300, bbox_inches='tight', facecolor='white')
    plt.show()


 ## 3. Helper Functions (Plotting & Bias Setting)

In [None]:
def plot_decision_boundary(model, X, y, title="Decision Boundary", amount_levels=50, plotrange=(-2.5, 2.5), file_name=None, footnote=None, resolution=200):
    """Visualize the decision boundary of a 2-D-input model over a square grid.

    The model is evaluated on a ``resolution`` x ``resolution`` grid covering
    ``plotrange`` on both axes. The outputs are drawn as filled contours with
    levels spanning [0, 1] (so they are expected to be probabilities, e.g.
    from ``final_sigmoid=True``), and the data points are overlaid.

    Args:
        model: network mapping ``(N, 2)`` float tensors on ``device`` to one
            value per point.
        X: data points with two feature columns; a torch tensor (any device)
            or an array-like.
        y: labels used to colour the points.
        title: axes title.
        amount_levels: number of contour levels / colormap entries.
        plotrange: ``(min, max)`` used for both axes.
        file_name: path without extension; if given, saved as ``.png`` at 300 dpi.
        footnote: optional text placed at the bottom of the figure.
        resolution: number of grid points per axis (default 200, as before).
    """
    colors = [to_rgb("C0"), [1, 1, 1], to_rgb("C1")]
    cm = LinearSegmentedColormap.from_list("Custom", colors, N=amount_levels)

    x_min, x_max = plotrange[0], plotrange[1]
    y_min, y_max = plotrange[0], plotrange[1]
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, resolution), np.linspace(y_min, y_max, resolution))
    grid = np.c_[xx.ravel(), yy.ravel()]
    grid_tensor = torch.tensor(grid, dtype=torch.float32).to(device)

    # Predict in eval mode (e.g. BatchNorm running statistics), then restore
    # the caller's mode instead of silently leaving the model in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            preds = model(grid_tensor).cpu().numpy().reshape(xx.shape)
    finally:
        model.train(was_training)

    fig, ax = plt.subplots(figsize=(6, 5))
    levels = np.linspace(0., 1., amount_levels).tolist()
    contour = ax.contourf(xx, yy, preds, levels=levels, cmap=cm, alpha=0.8)

    # Scatter needs host data; accept tensors on any device as well as arrays.
    X_cpu = X.cpu() if torch.is_tensor(X) else np.asarray(X)
    y_cpu = y.cpu() if torch.is_tensor(y) else np.asarray(y)

    ax.scatter(X_cpu[:, 0], X_cpu[:, 1], s=25, c=y_cpu.squeeze(), cmap=cm, edgecolors='black', linewidths=0.5, alpha=0.9)
    ax.set_title(title)
    ax.set_xlabel('$x_1$')
    ax.set_ylabel('$x_2$')
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)

    colorbar_ticks = np.linspace(0, 1, 9)
    cb = plt.colorbar(contour, ax=ax, label='Prediction Probability', ticks=colorbar_ticks)
    cb.set_ticklabels([f"{tick:.2f}" for tick in colorbar_ticks])

    if footnote:
        plt.figtext(0.5, 0, footnote, ha="center", fontsize=8)

    if file_name:
        plt.savefig(file_name + '.png', bbox_inches='tight', dpi=300, facecolor='white')
    plt.show()

def set_layer_bias(model, layer_name, b_specific):
    """Fill every entry of the bias of submodule ``layer_name`` with ``b_specific``.

    ``layer_name`` is a dotted attribute path relative to ``model``, e.g.
    ``'res_blocks.2.fc'``; numeric parts address children of an
    ``nn.Sequential``. Failures are reported with a printed message, not raised.

    Args:
        model: root module.
        layer_name: dotted path of the target submodule.
        b_specific: scalar written into every bias entry.

    Returns:
        True if the bias was set; False if the layer was not found or has no
        bias. (Callers that ignore the return value behave as before.)
    """
    module = model
    # Only the path lookup is guarded, so unrelated AttributeErrors are no
    # longer misreported as a missing layer.
    try:
        for part in layer_name.split('.'):
            module = getattr(module, part)
    except AttributeError:
        print(f"❌ Could not find layer '{layer_name}'.")
        return False

    bias = getattr(module, 'bias', None)
    if bias is None:
        print(f"⚠️ Layer '{layer_name}' has no bias parameter.")
        return False

    # In-place write on a leaf Parameter must bypass autograd.
    with torch.no_grad():
        bias.fill_(b_specific)
    print(f"✅ Set bias of '{layer_name}' to {b_specific}")
    return True


 ## 4. Experiment Setup: Data & Model

In [None]:
# Data generation: two noisy concentric circles (inner circle at 0.5x the
# outer radius), scaled up by a factor of 2.
n_points = 4000
batch_size = 100
X, y = make_circles(n_samples=n_points, noise=0.05, factor=0.5, random_state=seed)
X = X * 2.0  # Scale

X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32).reshape(-1, 1)

train_ds = TensorDataset(X_tensor, y_tensor)
# The same dataset is used for training and evaluation, so the reported
# "test" accuracy is really training accuracy; this keeps the effect of the
# bias perturbation and the recovery clearly visible.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=False)

# Model params
num_hidden = 5  # number of residual blocks; an output layer (and an input layer if input_layer=True) is added
input_dim = 2
hidden_dim = 2
output_dim = 1
skip_param = 1
sara_param = .2
activation = 'tanh'  # 'relu', 'tanh' and 'id' are supported
final_sigmoid = True  # outputs in [0, 1], required by the BCELoss used with cross_entropy=True below
batchnorm = False
input_layer = False  # inputs feed the residual blocks directly (requires input_dim == hidden_dim)
bias_setting = True  # linear layers get bias terms; without them the bias manipulation below is a no-op

# Build the model from the configuration above (previously several of these
# values were hard-coded in the call and silently ignored the variables).
model = ResNet(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim,
               num_hidden=num_hidden, skip_param=skip_param, sara_param=sara_param,
               activation=activation, final_sigmoid=final_sigmoid, bias=bias_setting,
               batchnorm=batchnorm, input_layer=input_layer).to(device)

subfolder = 'resnet_bias_experiment'
os.makedirs(subfolder, exist_ok=True)

print("Model initialized.")


 ## 5. Stage 1: Train for 100 Epochs

In [None]:
print("--- Stage 1: Training (0 - 100 Epochs) ---")
# early_stopping=False: all 100 epochs run and acc1 is the accuracy after the last one.
model, acc1, losses1 = train_model(model, train_loader, test_loader, 
                                   epochs=100, lr=0.01, early_stopping=False, cross_entropy=True)

print(f"Stage 1 Accuracy: {acc1:.4f}")

# Visualize Stage 1 on the first batch of the unshuffled evaluation loader
# (a 100-point subset). X_plot / y_plot are reused for the final plot.
X_plot, y_plot = next(iter(test_loader))
plot_decision_boundary(model, X_plot, y_plot, 
                       title=f"Boundary after {len(losses1)} Epochs (Acc: {acc1:.2f})", 
                       file_name=f"{subfolder}/stage1")


 ## 6. Bias Manipulation

 We target the linear layer of the **3rd residual block** (`res_blocks.2.fc`; indices start at 0) and set every entry of its bias vector to **5.0**.

In [None]:
# Layer to perturb (dotted path into `model`) and the constant written into
# every entry of its bias vector.
target_layer, bias_value = 'res_blocks.2.fc', 5.0

print(f"--- Interruption: Setting bias of {target_layer} to {bias_value} ---")
set_layer_bias(model, target_layer, bias_value)


In [None]:
def plot_weightmatrix(model, title='', ax=None):
    """Plot the weight matrix of every ``nn.Linear`` layer of ``model`` as a heatmap.

    One figure is drawn per layer, in ``model.modules()`` order, using the
    diverging ``"seismic"`` colormap with 0 mapped to its centre (white).
    Entry ``(i, j)`` is shown with 1-based tick labels.

    Args:
        model: module whose ``nn.Linear`` submodules are visualised. Other layer
            types (e.g. ``DiagonalLinear``) are skipped.
        title: optional prefix prepended to each per-layer title.
        ax: accepted for backward compatibility but ignored, since one figure
            is created per layer.
    """
    from matplotlib.colors import Normalize, TwoSlopeNorm

    linear_layers = [module for module in model.modules() if isinstance(module, nn.Linear)]

    for i, layer in enumerate(linear_layers):
        W = layer.weight.detach().cpu().numpy()
        val_min, val_max = float(W.min()), float(W.max())

        if val_min < 0 < val_max:
            # Mixed signs: scale each side separately but keep 0 at the centre.
            norm = TwoSlopeNorm(vmin=val_min, vcenter=0.0, vmax=val_max)
        else:
            # Single-signed (or all-zero) matrix: TwoSlopeNorm would reject these
            # limits, so use a symmetric range to keep 0 mapped to white.
            bound = max(abs(val_min), abs(val_max)) or 1.0
            norm = Normalize(vmin=-bound, vmax=bound)

        fig, axis = plt.subplots()
        image = axis.imshow(W, cmap="seismic", norm=norm, origin="upper")
        fig.colorbar(image, ax=axis, label="Value")

        # Matrix indices (1-based), derived from the actual shape rather than
        # assuming a 2x2 matrix.
        n_rows, n_cols = W.shape
        axis.set_xticks(range(n_cols))
        axis.set_xticklabels([str(j + 1) for j in range(n_cols)])  # columns j
        axis.set_yticks(range(n_rows))
        axis.set_yticklabels([str(r + 1) for r in range(n_rows)])  # rows i
        axis.set_xlabel("j")
        axis.set_ylabel("i")
        layer_title = f"{i}th layer"
        axis.set_title(f"{title} {layer_title}" if title else layer_title)

        plt.show()
        

# Collect the nn.Linear layers of the (bias-modified) model. The original
# referenced the undefined name `model_res`, which raised a NameError.
linear_layers = [module for module in model.modules() if isinstance(module, nn.Linear)]


 ## 7. Stage 2: Continue Training (100 Epochs)

In [None]:
print("--- Stage 2: Training (100 - 200 Epochs) ---")
# Note: train_model creates a fresh Adam optimizer on every call, so Stage 2
# does not resume Stage 1's optimizer state. It is a 'restart' from the
# bias-modified weights.
model, acc2, losses2 = train_model(model, train_loader, test_loader, 
                                   epochs=100, lr=0.01, early_stopping=False, cross_entropy=True)

print(f"Stage 2 Accuracy: {acc2:.4f}")


 ## 8. Results & Visualization

In [None]:
# Combine the per-epoch losses of both stages into one curve.
all_losses = losses1 + losses2
# Index in all_losses where Stage 2 (after the bias injection) begins. It is
# derived from the recorded losses instead of hard-coding 100, so the marker
# and labels stay correct if the stage lengths change (e.g. early stopping).
injection_epoch = len(losses1)

# 1. Plot loss curve
plt.figure(figsize=(8, 5))
plt.plot(all_losses, label="Training Loss", color='blue')
plt.axvline(x=injection_epoch, color='red', linestyle='--', label=f'Bias Injection (Epoch {injection_epoch})')
plt.title("ResNet Training Loss with Bias Perturbation")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig(f"{subfolder}/full_loss_curve.png", dpi=300)
plt.show()

# 2. Plot final decision boundary on the same points as the Stage 1 plot
plot_decision_boundary(model, X_plot, y_plot, 
                       title=f"Boundary after Bias Recovery (Acc: {acc2:.2f})", 
                       footnote=f"Epoch {len(all_losses)}, Bias set to {bias_value} at Ep {injection_epoch}",
                       file_name=f"{subfolder}/stage2")