<a href="https://colab.research.google.com/github/v-y-l/Machine-Learning-Notebooks/blob/main/Victor's_unbiased_estimation_using_2_neural_networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Unbiased estimation using two neural networks
## Section author: Victor Lin (vl2580)

## Implementation

In [None]:
import torch
import torch.nn as nn
import math
from torch.utils.data import Dataset, DataLoader

class GeluDataset(Dataset):
    def __init__(self, N, d):
        self.W = torch.randn(N, d, d)
        self.x = torch.randn(N, d)
        Z = torch.einsum('bij,bj->bi', self.W, self.x)
        self.y = 0.5 * Z * (1 + torch.erf(Z / math.sqrt(2)))

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.W[idx], self.x[idx], self.y[idx]

class PsiNet(nn.Module):
    def __init__(self, d, m, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d, hidden),
            nn.ReLU(),
            nn.Linear(hidden, m)
        )

    def forward(self, x):
        return self.net(x)

class PhiNet(nn.Module):
    def __init__(self, d, m, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d, hidden),
            nn.ReLU(),
            nn.Linear(hidden, m)
        )

    def forward(self, W):
        B, D, _ = W.shape
        W_flat = W.view(B * D, D)
        out_flat = self.net(W_flat)
        return out_flat.view(B, D, -1)

# --- Hyperparameters ---
# d: input/output dimension; m: hidden feature dimension
# N: dataset size; batch_size: training mini-batch size
# epochs: number of full passes over the data
d, m, N, batch_size, epochs = 2, 16, 1024, 64, 5

# --- Setup data + models ---
dataset = GeluDataset(N, d)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
psi = PsiNet(d, m)
phi = PhiNet(d, m)

# --- Optimizer ---
# Adam: fast convergence with minimal tuning, good for deep MLPs
optimizer = torch.optim.Adam(list(psi.parameters()) + list(phi.parameters()), lr=1e-3)
criterion = nn.MSELoss()

# --- Training loop ---
for epoch in range(1, epochs + 1):
    total_loss = 0.0
    for Wb, xb, yb in loader:
        optimizer.zero_grad()
        ψ = psi(xb)                      # (B, m)
        Φ = phi(Wb)                      # (B, d, m)
        y_pred = torch.bmm(Φ, ψ.unsqueeze(-1)).squeeze(-1)  # (B, d)
        loss = criterion(y_pred, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * Wb.size(0)

    print(f"Epoch {epoch}: MSE = {total_loss / N:.4f}")

## Validation

In [1]:
import numpy as np
import torch

class NeuralGELUComparator:
    def __init__(self, phi_model, psi_model):
        self.phi = phi_model
        self.psi = psi_model

    def gelu_tanh(self, x):
        return 0.5 * x * (1 + np.tanh(np.sqrt(2/np.pi) * (x + 0.044715 * x**3)))

    def compare(self, W_test, x_test):
        with torch.no_grad():
            ψ = self.psi(x_test)                         # (1, m)
            Φ = self.phi(W_test)                         # (1, d, m)
            y_approx = torch.bmm(Φ, ψ.unsqueeze(-1)).squeeze().numpy()

            x_proj = torch.bmm(W_test, x_test.unsqueeze(-1)).squeeze().numpy()
            y_true = self.gelu_tanh(x_proj)

        print("==== Neural φ/ψ Linearization of GELU ====\n")
        print(f"x' = Wx = {x_proj}")
        print(f"Learned φ(W) · ψ(x) = {y_approx}")
        print(f"GELU_tanh(x')      = {y_true}")

        rmse = np.sqrt(np.mean((y_approx - y_true) ** 2))
        print(f"\nRMSE: Linearized φ · ψ vs GELU_tanh = {rmse:.5f}")

class NeuralGELUDemo:
    def __init__(self, phi_model, psi_model, d):
        self.phi = phi_model
        self.psi = psi_model
        self.d = d
        self.comparator = NeuralGELUComparator(phi_model, psi_model)

    def run(self):
        W_test = torch.randn(1, self.d, self.d)
        x_test = torch.randn(1, self.d)
        self.comparator.compare(W_test, x_test)

NeuralGELUDemo(phi, psi, d=2).run()

NameError: name 'phi' is not defined