In [138]:
import torch, math, time, os, random
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances
from sklearn.preprocessing import StandardScaler
from scipy import linalg

In [139]:
import warnings
# Install a blanket "ignore" filter: no category or module is given, so every warning is suppressed from here on.
warnings.filterwarnings("ignore")

In [140]:
# Seed Python's `random`, NumPy and PyTorch from one shared value.
# `seed` is also read by other cells in this notebook (e.g. as `random_state` for the toy datasets).
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Request deterministic cuDNN kernels and switch off its benchmark/autotune mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Pick the GPU when CUDA is available, otherwise the CPU; tensors and models below are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


Device: cpu


In [141]:
# Small constant used as a diagonal ridge for covariances and as a guard in denominators.
_eps = 1e-8

def cov_matrix(X, eps=_eps):
    """Return the ridge-regularised sample covariance of the rows of ``X``.

    Parameters
    ----------
    X : np.ndarray, shape (N, D)
        One observation per row, one feature per column.
    eps : float
        Value added to every diagonal entry of the covariance.

    Returns
    -------
    np.ndarray, shape (D, D)
        ``np.cov(X, rowvar=False) + eps * I``. When ``X`` has fewer than two
        rows an all-zero ``(D, D)`` matrix is returned instead (no ridge is
        added in that case).
    """
    n_features = X.shape[1]
    if X.shape[0] <= 1:
        return np.zeros((n_features, n_features))
    # np.cov squeezes the single-feature case down to a 0-d array, which would
    # break the shape-based ridge below; force the result back to (D, D).
    C = np.atleast_2d(np.cov(X, rowvar=False))
    C = C + np.eye(n_features) * eps
    return C

def cov_volume(X, eps=_eps):
    """Covariance "volume" of the rows of ``X``: ``sqrt(det(C))``.

    ``C`` is the regularised covariance from ``cov_matrix``. The value is
    obtained from the log-determinant; ``0.0`` is returned when the sign of
    the determinant is zero or negative.
    """
    covariance = cov_matrix(X, eps=eps)
    det_sign, log_det = np.linalg.slogdet(covariance)
    # sqrt(det) == exp(0.5 * log|det|), meaningful only for a positive determinant
    return 0.0 if det_sign <= 0 else float(np.exp(0.5 * log_det))

def log_cov_volume(X, eps=_eps):
    """Log of the covariance "volume": ``0.5 * log(det(C))``.

    ``C`` is the regularised covariance from ``cov_matrix``. Returns ``nan``
    when the sign of the determinant is zero or negative.
    """
    covariance = cov_matrix(X, eps=eps)
    det_sign, log_det = np.linalg.slogdet(covariance)
    return np.nan if det_sign <= 0 else 0.5 * log_det

def fisher_ratio(acts, labels):
    """Scalar Fisher ratio ``trace(S_b) / (trace(S_w) + _eps)``.

    Parameters
    ----------
    acts : np.ndarray, shape (N, D)
        One activation vector per row.
    labels : array-like, shape (N,)
        Class label of each row.

    Returns
    -------
    float
        Between-class scatter trace divided by within-class scatter trace
        (the latter offset by ``_eps`` to avoid division by zero).

    Notes
    -----
    Only the traces are needed, so the D x D scatter matrices are never built:
      trace(S_b) = sum_c N_c * ||mean_c - mean||^2
      trace(S_w) = sum_c sum_{x in c} ||x - mean_c||^2
    """
    # Accumulate in float64 so float32 activations do not lose precision.
    acts = np.asarray(acts, dtype=np.float64)
    labels = np.asarray(labels)
    overall_mean = acts.mean(axis=0)
    tr_between = 0.0
    tr_within = 0.0
    # np.unique only yields labels that occur, so every class has >= 1 row.
    for c in np.unique(labels):
        Xc = acts[labels == c]
        mean_c = Xc.mean(axis=0)
        tr_between += Xc.shape[0] * np.sum((mean_c - overall_mean) ** 2)
        # A single-sample class contributes exactly zero here.
        tr_within += np.sum((Xc - mean_c) ** 2)
    return tr_between / (tr_within + _eps)

def avg_movement_between_layers(projected_acts):
    """Mean per-sample displacement between each pair of consecutive layers.

    ``projected_acts`` is a list of ``(N, k)`` arrays with matching shapes.
    For every adjacent pair the Euclidean distance each row travels is
    computed and averaged over rows; the result has ``len(projected_acts) - 1``
    entries (empty when fewer than two layers are given).
    """
    return [
        np.linalg.norm(nxt - cur, axis=1).mean()
        for cur, nxt in zip(projected_acts[:-1], projected_acts[1:])
    ]

# Linear CKA: similarity between two sets of representations of the same samples
def linear_CKA(X, Y):
    """Linear CKA between ``X`` (N, D1) and ``Y`` (N, D2).

    Both inputs are column-centred, then the squared Frobenius norm of the
    cross-product ``X^T Y`` is divided by the product of the Frobenius norms
    of ``X^T X`` and ``Y^T Y`` (plus ``_eps``).
    """
    Xc = X - X.mean(0, keepdims=True)
    Yc = Y - Y.mean(0, keepdims=True)
    cross_norm = np.linalg.norm(Xc.T @ Yc, 'fro')
    x_norm = np.linalg.norm(Xc.T @ Xc, 'fro')
    y_norm = np.linalg.norm(Yc.T @ Yc, 'fro')
    return cross_norm**2 / ((x_norm * y_norm) + _eps)

# Cosine similarity of two weight arrays after flattening
def weight_cosine_similarity(w1, w2):
    """Cosine similarity between the flattened ``w1`` and ``w2``.

    When the flattened lengths differ, both vectors are truncated to the
    shorter length before comparison. ``_eps`` is added to the denominator.
    """
    flat1, flat2 = w1.flatten(), w2.flatten()
    shared = min(len(flat1), len(flat2))
    u, v = flat1[:shared], flat2[:shared]
    scale = (np.linalg.norm(u) * np.linalg.norm(v)) + _eps
    return float(np.dot(u, v) / scale)



In [142]:
from sklearn.datasets import make_moons, make_circles

def get_toy(dataset='moons', n=1000, noise=0.1):
    """Build a 2-class toy dataset and return it as tensors on ``device``.

    ``dataset == 'moons'`` selects ``make_moons``; any other value selects
    ``make_circles`` with ``factor=0.5``. Both use the notebook-level ``seed``
    as ``random_state``. Returns ``(X, y)`` with ``X`` float32 and ``y`` long.
    """
    shared_kwargs = dict(n_samples=n, noise=noise, random_state=seed)
    if dataset == 'moons':
        points, targets = make_moons(**shared_kwargs)
    else:
        points, targets = make_circles(factor=0.5, **shared_kwargs)
    features = torch.tensor(points, dtype=torch.float32).to(device)
    classes = torch.tensor(targets, dtype=torch.long).to(device)
    return features, classes

def get_mnist(subset=5000, flatten=True):
    """Load the first ``subset`` MNIST training images as tensors on ``device``.

    The dataset is read from (and downloaded to, if needed) ``./data``. The
    leading ``min(subset, len(dataset))`` examples are pulled through a single
    unshuffled batch. With ``flatten=True`` each image is reshaped to one
    vector per sample. Returns ``(images, labels)``.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    train_set = datasets.MNIST(root='./data', train=True, download=True, transform=to_tensor)
    n_keep = min(subset, len(train_set))
    head = torch.utils.data.Subset(train_set, list(range(n_keep)))
    # A single batch of size `subset` holds the whole selection.
    single_batch = DataLoader(head, batch_size=subset, shuffle=False)
    imgs, labels = next(iter(single_batch))
    if flatten:
        imgs = imgs.view(imgs.size(0), -1)
    return imgs.to(device), labels.to(device)


In [143]:
class MLP(nn.Module):
    """Plain fully connected network with a configurable number of linear layers.

    Parameters
    ----------
    in_dim, hidden_dim, out_dim : int
        Input width, width of every hidden layer, and output width.
    depth : int
        Number of ``nn.Linear`` layers (>= 1). ``depth == 1`` is a single
        ``in_dim -> out_dim`` map with no hidden layer.
    activation : str
        ``'relu'`` selects ``F.relu``; ``'tanh'`` and any other value select
        ``torch.tanh``.
    bias : bool
        Passed to every ``nn.Linear``.

    Raises
    ------
    ValueError
        If ``depth`` is smaller than 1 (previously this silently produced a
        two-layer network).
    """

    def __init__(self, in_dim, hidden_dim, out_dim, depth, activation='tanh', bias=True):
        super().__init__()
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")
        # Widths at the layer boundaries: in -> hidden x (depth-1) -> out.
        widths = [in_dim] + [hidden_dim] * (depth - 1) + [out_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(w_in, w_out, bias=bias) for w_in, w_out in zip(widths[:-1], widths[1:])]
        )
        # Unknown activation names fall back to tanh.
        self.act = F.relu if activation == 'relu' else torch.tanh

    def forward(self, x, return_activations=False):
        """Run the network on ``x``.

        The nonlinearity is applied after every layer except the last. With
        ``return_activations=True`` the result is ``(output, acts)`` where
        ``acts`` holds the post-nonlinearity output of each intermediate layer
        followed by the raw output of the final layer; otherwise only the
        output is returned.
        """
        acts = []
        last_index = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i != last_index:
                x = self.act(x)
            if return_activations:
                acts.append(x)
        return (x, acts) if return_activations else x


In [144]:
def quick_train(model, X, y, epochs=20, lr=1e-3, batch_size=None, verbose=False):
    """Train ``model`` on ``(X, y)`` with Adam and cross-entropy loss.

    ``batch_size=None`` means full-batch training. Batches are reshuffled each
    epoch. With ``verbose=True`` the running loss and accuracy are printed on
    every fifth epoch and on the last one. Returns the model (moved to
    ``device``).
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    effective_bs = X.shape[0] if batch_size is None else batch_size
    batches = DataLoader(TensorDataset(X, y), batch_size=effective_bs, shuffle=True)
    for epoch in range(epochs):
        model.train()
        seen, hits, loss_sum = 0.0, 0, 0.0
        for xb, yb in batches:
            optimizer.zero_grad()
            logits = model(xb)
            batch_loss = F.cross_entropy(logits, yb)
            batch_loss.backward()
            optimizer.step()
            n_batch = xb.size(0)
            # Weight the batch-mean loss by batch size to get a per-sample epoch average.
            loss_sum += float(batch_loss.item()) * n_batch
            hits += (logits.argmax(1) == yb).sum().item()
            seen += n_batch
        should_log = verbose and (epoch % 5 == 0 or epoch == epochs - 1)
        if should_log:
            print(f"Ep {epoch+1}/{epochs} loss={loss_sum/seen:.4f} acc={hits/seen:.4f}")
    return model


In [145]:
def get_activations(model, X):
    """Collect the per-layer outputs of ``model`` for input ``X``.

    The model is moved to ``device``, put in eval mode and run without
    gradient tracking via ``model(X, return_activations=True)``. Each layer
    output is returned as a detached float tensor on the CPU.
    """
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        _output, layer_outputs = model(X, return_activations=True)
    return [t.detach().cpu().float() for t in layer_outputs]

def project_activations_consistent(acts_list, n_components=2):
    """Reduce every layer's activations to ``n_components`` columns with PCA.

    Parameters
    ----------
    acts_list : list
        One ``(N, D_l)`` array per layer; CPU tensors exposing ``.numpy()``
        and plain NumPy arrays are both accepted.
    n_components : int
        Number of columns in every returned array.

    Returns
    -------
    list of np.ndarray
        One ``(N, n_components)`` array per layer.

    Notes
    -----
    * A separate PCA is fitted for each layer, so the returned coordinates
      live in per-layer bases that are NOT aligned with each other. Distances
      computed between layers from these projections should be read with that
      in mind.
    * PCA cannot return more components than ``min(N, D_l)``. When a layer has
      fewer usable components than requested, the missing columns on the
      right are filled with zeros so that all layers share one shape.
    """
    reduced = []
    for layer_acts in acts_list:
        a = layer_acts.numpy() if hasattr(layer_acts, "numpy") else np.asarray(layer_acts)
        n_samples, n_features = a.shape
        # Cap by samples as well as features; sklearn rejects anything larger.
        k = min(n_components, n_features, n_samples)
        if k <= 0:
            reduced.append(np.zeros((n_samples, n_components)))
            continue
        red = PCA(n_components=k).fit_transform(a)
        if k < n_components:
            # Zero-pad on the right up to the requested width.
            red = np.hstack([red, np.zeros((n_samples, n_components - k))])
        reduced.append(red)
    return reduced  # list of (N, n_components) arrays

In [146]:
def _scatter_traces(a_np, labels_np, classes):
    """Return ``(trace(S_b), trace(S_w))`` for one layer without building D x D matrices."""
    a64 = np.asarray(a_np, dtype=np.float64)
    overall = a64.mean(axis=0)
    tr_between = 0.0
    tr_within = 0.0
    for c in classes:
        Xc = a64[labels_np == c]
        if Xc.shape[0] == 0:
            continue
        mean_c = Xc.mean(axis=0)
        # trace(N_c * d d^T) == N_c * ||d||^2
        tr_between += Xc.shape[0] * np.sum((mean_c - overall) ** 2)
        # trace of the class scatter matrix == sum of squared deviations
        tr_within += np.sum((Xc - mean_c) ** 2)
    return tr_between, tr_within


def analyze_model(model, X, y, n_components=3, do_project=True, eps=_eps, verbose=True):
    """Compute per-layer representation statistics for ``model`` on ``(X, y)``.

    Parameters
    ----------
    model : module accepted by ``get_activations``
    X, y : tensors
        Inputs and integer class labels (``y`` is copied to the CPU).
    n_components : int
        Width of the PCA projections used for the movement statistic.
    do_project : bool
        When False the ``'projected'`` and ``'movements'`` entries are omitted.
    eps : float
        Diagonal ridge used for the covariance volumes.
    verbose : bool
        When True (default) the full results dict is printed before returning.

    Returns
    -------
    dict with keys
        ``'means'``, ``'variances'``   : per-layer feature-wise mean / variance
        ``'volumes'``, ``'logvolumes'``: ``{class: [value per layer]}``; ``nan``
                                         for classes with fewer than 2 samples
        ``'fisher'``                   : ``fisher_ratio`` per layer
        ``'projected'``, ``'movements'``: PCA projections and mean displacement
                                         between consecutive layers
        ``'weight_cosine_seq'``        : cosine similarity of consecutive
                                         weight tensors
        ``'CKA'``                      : (L, L) matrix of pairwise linear CKA
        ``'trace_between'``, ``'trace_within'`` : scatter-matrix traces per layer
    """
    acts = get_activations(model, X)
    # Convert once; every statistic below works on NumPy arrays.
    acts_np = [a.numpy() for a in acts]
    labels_np = y.cpu().numpy()
    classes = np.unique(labels_np)
    results = {}

    # per-layer moment stats
    results['means'] = [a.mean(axis=0) for a in acts_np]
    results['variances'] = [a.var(axis=0) for a in acts_np]

    # per-class volumes & log-volumes, plus the Fisher ratio of each layer
    vols = {int(c): [] for c in classes}
    logvols = {int(c): [] for c in classes}
    fisher = []
    for a_np in acts_np:
        for c in classes:
            Xc = a_np[labels_np == c]
            if Xc.shape[0] < 2:
                # covariance is undefined for a single sample
                vols[int(c)].append(np.nan)
                logvols[int(c)].append(np.nan)
            else:
                vols[int(c)].append(cov_volume(Xc, eps=eps))
                logvols[int(c)].append(log_cov_volume(Xc, eps=eps))
        fisher.append(fisher_ratio(a_np, labels_np))
    results['volumes'] = vols
    results['logvolumes'] = logvols
    results['fisher'] = fisher

    # projected activations & movements
    if do_project:
        projected = project_activations_consistent(acts, n_components=n_components)
        results['projected'] = projected
        results['movements'] = avg_movement_between_layers(projected)

    # weight-based redundancy: cosine similarity between sequential layer weights
    weights = [p.detach().cpu().numpy() for name, p in model.named_parameters() if 'weight' in name]
    results['weight_cosine_seq'] = [
        weight_cosine_similarity(w_a, w_b) for w_a, w_b in zip(weights[:-1], weights[1:])
    ]

    # CKA matrix between layer activations (symmetric, so fill both halves at once)
    num_layers = len(acts_np)
    CKA = np.zeros((num_layers, num_layers))
    for i in range(num_layers):
        for j in range(i, num_layers):
            CKA[i, j] = CKA[j, i] = linear_CKA(acts_np[i], acts_np[j])
    results['CKA'] = CKA

    # between/within scatter traces per layer
    tb, tw = [], []
    for a_np in acts_np:
        tr_between, tr_within = _scatter_traces(a_np, labels_np, classes)
        tb.append(tr_between)
        tw.append(tr_within)
    results['trace_between'] = tb
    results['trace_within'] = tw

    if verbose:
        print("RESULTS\n")
        print(results)

    return results

In [147]:
def run_experiment_toy(dataset='moons', depth_list=[3,6], hidden=12, epochs=40, quick=True):
    """Train and analyse one tanh MLP per depth on a 1000-point toy dataset.

    For every entry of ``depth_list`` a 2-in / 2-out ``MLP`` with ``hidden``
    units per hidden layer is trained with ``quick_train`` and passed to
    ``analyze_model``. In quick mode the epoch count is ``max(5, epochs // 8)``
    instead of ``epochs``. Returns ``(X, y, summary)`` where ``summary`` maps
    depth to the analysis dict.
    """
    X, y = get_toy(dataset=dataset, n=1000, noise=0.12)
    X = X.to(device)
    y = y.to(device)
    n_epochs = max(5, epochs // 8) if quick else epochs
    summary = {}
    for depth in depth_list:
        print(f"Running depth={depth} ...")
        net = MLP(in_dim=2, hidden_dim=hidden, out_dim=2, depth=depth, activation='tanh').to(device)
        net = quick_train(net, X, y, epochs=n_epochs, lr=1e-3, verbose=True)
        summary[depth] = analyze_model(net, X, y, n_components=3, do_project=True)
    return X, y, summary



In [148]:
# Example quick run (toy)
# Xtoy, ytoy, toy_summary = run_experiment_toy(dataset='moons', depth_list=[3,6], hidden=6, epochs=40, quick=True)
