In [18]:
# Task 3: Implement the XOR network and the Gradient Descent Algorithm

In [None]:
import torch
from typing import Tuple, Dict


# Load XOR data
def load_xor() -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the XOR truth table as a pair of float32 tensors.

    Returns:
        (X, y) where X has shape (4, 2) and lists the binary input pairs in
        the order 00, 01, 10, 11, and y has shape (4,) and holds the XOR of
        each pair.
    """
    bits = (0, 1)
    input_pairs = [[a, b] for a in bits for b in bits]
    xor_targets = [a ^ b for a, b in input_pairs]
    return (
        torch.tensor(input_pairs, dtype=torch.float32),
        torch.tensor(xor_targets, dtype=torch.float32),
    )

In [None]:
def dsigmoid(h):
    """Derivative of the sigmoid, expressed through its output.

    h is expected to already be sigmoid(z), so the derivative with respect
    to z is h * (1 - h).
    """
    complement = 1 - h
    return h * complement


def dtanh(h):
    """Derivative of tanh, expressed through its output.

    h is expected to already be tanh(z), so the derivative with respect
    to z is 1 - h**2.
    """
    h_squared = h**2
    return 1 - h_squared


def drelu(z):
    """Derivative of ReLU with respect to its pre-activation z.

    Returns a float tensor that is 1 where z is strictly positive and 0
    elsewhere (including at z == 0).
    """
    positive_mask = z > 0
    return positive_mask.float()


class XORNet:
    """
    Minimal 2-2-1 MLP for learning XOR, trained with hand-derived backprop.
    Total parameters = 9: W1(2x2)=4, b1(2)=2, W2(2)=2, b2(1)=1

    The hidden layer uses a selectable activation (sigmoid / tanh / relu); the
    output unit is always a sigmoid. Parameters are plain tensors that `step`
    updates in place; autograd is not used.
    """

    def __init__(
        self,
        activation="sigmoid",
        weight_scale: float = 0.5,
        seed: int = 42,
        dtype=torch.float32,
    ):
        """
        activation: hidden-layer activation, one of "sigmoid", "tanh", "relu".
        weight_scale: standard deviation of the zero-mean normal distribution
            every parameter is drawn from.
        seed: when not None, passed to torch.manual_seed before the parameters
            are drawn (this reseeds torch's global RNG). None skips seeding.
        dtype: dtype of all parameters; inputs are cast to it in `forward`.

        Raises ValueError if `activation` is not one of the supported names.
        """
        self.dtype = dtype
        self.seed = seed

        if self.seed is not None:
            torch.manual_seed(self.seed)

        # Draw order (W1, b1, W2, b2) is fixed so a given seed always yields
        # the same initial parameters.
        self.W1 = torch.normal(
            mean=0.0, std=weight_scale, size=(2, 2), dtype=self.dtype
        )
        self.b1 = torch.normal(mean=0.0, std=weight_scale, size=(2,), dtype=self.dtype)
        self.W2 = torch.normal(mean=0.0, std=weight_scale, size=(2,), dtype=self.dtype)
        self.b2 = torch.normal(mean=0.0, std=weight_scale, size=(), dtype=self.dtype)

        if activation == "sigmoid":
            self.act = torch.sigmoid
        elif activation == "tanh":
            self.act = torch.tanh
        elif activation == "relu":
            self.act = torch.relu
        else:
            raise ValueError("Unknown activation function")

        self.activation_name = activation

    @staticmethod
    def _match_targets(y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Return y reshaped to y_hat's shape when both hold the same number of
        elements; otherwise return y unchanged.

        Guards against a column-vector target (N,1) silently broadcasting
        against predictions of shape (N,) into an (N,N) matrix.
        """
        if y.shape != y_hat.shape and y.numel() == y_hat.numel():
            return y.reshape(y_hat.shape)
        return y

    def forward(self, X: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        X: (N,2)
        Returns y_hat (N,) and cache (for backpropagation)
        """
        # Cast to the parameter dtype so e.g. float32 data works with a
        # float64 network; this is a no-op when the dtype already matches.
        X = X.to(self.W1.dtype)
        z1 = X @ self.W1 + self.b1  # (N,2)
        h = self.act(z1)  # (N,2)
        z2 = h @ self.W2 + self.b2  # (N,)
        y_hat = torch.sigmoid(z2)  # Output layer uses sigmoid
        cache = {"X": X, "z1": z1, "h": h, "z2": z2, "y_hat": y_hat}
        return y_hat, cache

    @staticmethod
    def mse(y_hat: torch.Tensor, y: torch.Tensor) -> float:
        """Mean squared error between predictions and targets as a Python float."""
        y = XORNet._match_targets(y_hat, y)
        return float(torch.mean((y_hat - y) ** 2).item())

    @staticmethod
    def miscls(y_hat: torch.Tensor, y: torch.Tensor, thr: float = 0.5) -> int:
        """
        Number of samples whose thresholded prediction (y_hat >= thr -> 1,
        else 0) differs from the target y.
        """
        y = XORNet._match_targets(y_hat, y)
        y_bin = (y_hat >= thr).to(torch.int)
        return int(torch.sum(y_bin != y).item())

    def gradients(
        self, cache: Dict[str, torch.Tensor], y: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Compute gradients for the whole batch (4 samples). Loss L = mean((y_hat-y)^2)
        Let N=4. Chain rule:
          dL/dy_hat = 2/N * (y_hat - y)
          y_hat = sigmoid(z2) → dy_hat/dz2 = y_hat*(1-y_hat)
          z2 = h·W2 + b2
          h = act(z1), z1 = X·W1 + b1

        cache: the dict returned by `forward` for the same batch.
        y: targets with N elements; reshaped to match y_hat if needed.
        Returns a dict with keys "W1", "b1", "W2", "b2", each shaped like the
        corresponding parameter.
        """
        z1 = cache["z1"]  # (N,2)
        h = cache["h"]  # (N,2)
        y_hat = cache["y_hat"]  # (N,)
        X = cache["X"].to(h.dtype)  # (N,2)
        y = self._match_targets(y_hat, y).to(y_hat.dtype)  # (N,)

        N = X.shape[0]
        dL_dyhat = (2.0 / N) * (y_hat - y)  # (N,)

        # Output layer: sigmoid
        dyhat_dz2 = y_hat * (1.0 - y_hat)  # (N,)
        dL_dz2 = dL_dyhat * dyhat_dz2  # (N,)

        # For W2, b2
        # z2_i = sum_j h_ij * W2_j + b2
        dL_dW2 = h.t() @ dL_dz2  # (2,)
        dL_db2 = torch.sum(dL_dz2)  # ()

        # Backprop to hidden layer output h
        dL_dh = dL_dz2.unsqueeze(1) * self.W2.unsqueeze(0)  # (N,2)

        # Hidden layer activation derivative. sigmoid/tanh derivatives are
        # written in terms of the activation output h; relu needs z1.
        if self.activation_name == "sigmoid":
            dh_dz1 = dsigmoid(h)  # (N,2)
        elif self.activation_name == "tanh":
            dh_dz1 = dtanh(h)
        elif self.activation_name == "relu":
            dh_dz1 = drelu(z1)
        else:
            raise ValueError("Unknown activation function")

        dL_dz1 = dL_dh * dh_dz1  # (N,2)

        # z1 = X·W1 + b1
        dL_dW1 = X.t() @ dL_dz1  # (2,2)
        dL_db1 = torch.sum(dL_dz1, dim=0)  # (2,)

        return {
            "W1": dL_dW1,
            "b1": dL_db1,
            "W2": dL_dW2,
            "b2": dL_db2,
        }

    def step(self, grads: Dict[str, torch.Tensor], lr: float) -> None:
        """Apply one gradient-descent update in place: param -= lr * grad."""
        self.W1 -= lr * grads["W1"]
        self.b1 -= lr * grads["b1"]
        self.W2 -= lr * grads["W2"]
        self.b2 -= lr * grads["b2"]

    def train(
        self,
        X: torch.Tensor,
        y: torch.Tensor,
        lr: float = 0.5,
        epochs: int = 5000,
        verbose_every: int = 500,
    ) -> Dict[str, list]:
        """
        Full-batch gradient descent for `epochs` iterations.

        Each epoch records the MSE and misclassification count measured
        before that epoch's parameter update. A falsy `verbose_every`
        (0 / None) disables progress printing.

        Returns {"mse": [...], "mis": [...]} with one entry per epoch.
        """
        history = {"mse": [], "mis": []}
        for ep in range(1, epochs + 1):
            y_hat, cache = self.forward(X)
            loss = self.mse(y_hat, y)
            mis = self.miscls(y_hat, y)
            history["mse"].append(loss)
            history["mis"].append(mis)

            grads = self.gradients(cache, y)
            self.step(grads, lr)

            if verbose_every and (ep % verbose_every == 0):
                print(f"[{ep:5d}] mse={loss:.6f} mis={mis}")
        return history

    def predict(self, X: torch.Tensor) -> torch.Tensor:
        """Return hard 0/1 predictions (int tensor) using a 0.5 threshold."""
        y_hat, _ = self.forward(X)
        return (y_hat >= 0.5).int()