# Estimating Factor Analysis
One of the simplest models for generative auto-encoders is Factor Analysis. We'll use our method to learn it with the loss function


$$ \mathcal L(\boldsymbol \theta, y^{(i)})  = {y^{(i)}}^\top \hat y_{\boldsymbol \theta}^{(i)} - \mathbb E_{p_{\boldsymbol\theta}}[y^\top \hat y_{\boldsymbol \theta}]$$

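where $\hat y_{\boldsymbol \theta} = \text{decoder}_{\boldsymbol \theta}(\text{encoder}_{\boldsymbol \theta}(y))$ is the reconstruction of $y$. Inside the expectation, $y\sim p_{\boldsymbol \theta}$ is drawn from the current model. In the code below, the optimizer minimizes the negative of this quantity, with $y^\top \hat y$ replaced by the mean of the entrywise products. It also adds a squared-error penalty between the latent codes used to simulate $y$ and the encoder's estimates of them.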

In [10]:
import torch
from torch import nn
import numpy as np
import math
import matplotlib.pyplot as plt
import copy

In [11]:
# DIMENSIONS: n rows (observations), p observed variables, q latent factors.
n, p, q = 100, 50, 5
DIMENSION_Y = (n, p)  # shape of an observation matrix y
DIMENSION_Z = (n, q)  # shape of a latent matrix z

In [12]:
class FA(nn.Module):
    """Factor-analysis style generative model with an MLP encoder.

    Generative model: ``y = decoder(z) + eps`` with ``z ~ N(0, I)`` and
    ``eps ~ N(0, diag(var_y))``. ``forward`` maps an observation ``y`` to its
    reconstruction ``decoder(encoder(y))``.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        # Fixed (not learned) per-coordinate noise variance. Registered as a
        # buffer so it follows the module through `.to(device)` / dtype casts
        # and is included in `state_dict()`.
        self.register_buffer("var_y", torch.ones(p))

    def forward(self, y):
        """Return the reconstruction ``decoder(encoder(y))`` of ``y``."""
        z = self.encoder(y)
        return self.decoder(z)

    def sample(self, z=None):
        """Draw ``(y, z)`` from the model without tracking gradients.

        Args:
            z: Optional latent matrix. When ``None``, a standard-normal matrix
                of shape ``DIMENSION_Z`` is drawn.

        Returns:
            Tuple ``(y, z)`` with ``y = decoder(z) + eps``, where ``eps`` has
            the same shape as ``decoder(z)`` (one row per row of ``z``).
        """
        with torch.no_grad():
            if z is None:
                z = torch.randn(DIMENSION_Z)
            mean = self.decoder(z)
            # Shape the noise like the decoded mean rather than the fixed
            # DIMENSION_Y, so a caller-supplied z with any number of rows works.
            eps = torch.randn_like(mean) * torch.sqrt(self.var_y)
            y = mean + eps

        return (y, z)

class Decoder(nn.Module):
    """Linear decoder mapping latent codes to observation means via ``z @ w``."""

    def __init__(self, latent_dim=None, obs_dim=None):
        """Create a ``(latent_dim, obs_dim)`` weight with standard-normal entries.

        Args:
            latent_dim: Number of latent factors; defaults to the module-level ``q``.
            obs_dim: Number of observed coordinates; defaults to the module-level ``p``.
        """
        super().__init__()
        # Resolve defaults lazily so the module-level sizes are read only when
        # the caller does not pass explicit dimensions.
        if latent_dim is None:
            latent_dim = q
        if obs_dim is None:
            obs_dim = p
        self.w = nn.Parameter(torch.randn(latent_dim, obs_dim))

    def forward(self, z):
        """Return ``z @ w`` (no bias, no nonlinearity)."""
        return z @ self.w


class Encoder(nn.Module):
    """MLP encoder mapping observations to latent codes.

    Architecture: Linear(obs_dim, hidden_dim) -> ReLU ->
    Linear(hidden_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, latent_dim).
    """

    def __init__(self, obs_dim=None, latent_dim=None, hidden_dim=100):
        """Build the MLP.

        Args:
            obs_dim: Input width; defaults to the module-level ``p``.
            latent_dim: Output width; defaults to the module-level ``q``.
            hidden_dim: Width of both hidden layers (default 100).
        """
        super().__init__()
        if obs_dim is None:
            obs_dim = p
        if latent_dim is None:
            latent_dim = q

        # Attribute name kept as `encoder_model` so state_dict keys are unchanged.
        self.encoder_model = nn.Sequential(
            nn.Linear(in_features=obs_dim, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=latent_dim),
        )

    def forward(self, y):
        """Return the latent code for observations ``y``."""
        return self.encoder_model(y)

class GMLoss(nn.Module):
    """Moment-contrast loss plus a latent squared-error penalty.

    Computes ``-mean(input * target - input_sim * target_sim)
    + mean((z_hat - z_sim) ** 2)``. The means are taken over all entries.
    """

    def forward(self, input, target, input_sim, target_sim, z_hat, z_sim):
        # Entrywise contrast between the observed and the simulated moment.
        moment_gap = input * target - input_sim * target_sim
        # Mean squared difference between the two latent tensors.
        latent_mse = (z_hat - z_sim).pow(2).mean()
        return latent_mse - moment_gap.mean()

In [14]:
# Draw a synthetic dataset from a randomly initialised "true" model.
model_true = FA()
y_true, z_true = model_true.sample()
assert y_true.shape == DIMENSION_Y and z_true.shape == DIMENSION_Z

# Fresh model to be fitted to y_true.
model = FA()
loss_fn = GMLoss()
epochs = 1000
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Fit diagnostic: MSE between the observed data and its reconstruction.
# This is only reported; the optimizer minimizes `loss_fn`, not this quantity.
eval_fn = nn.MSELoss()

for epoch in range(1, epochs + 1):
    optimizer.zero_grad()

    # Reconstruct the observed data.
    target_true = model(y_true)

    # Simulate from the current model and reconstruct the simulated sample.
    # sample() runs under no_grad, so y_sim and z_sim carry no gradient.
    y_sim, z_sim = model.sample()
    z_hat = model.encoder(y_sim)
    target_sim = model.decoder(z_hat)

    loss = loss_fn(y_true, target_true, y_sim, target_sim, z_hat, z_sim)
    loss.backward()
    optimizer.step()

    # Every 10 epochs, report the reconstruction MSE.
    # - `target_true` was computed before this epoch's optimizer.step(), so the
    #   value reflects the pre-update parameters.
    # - The printed label reads "Loss" (kept for continuity with earlier
    #   output), but the number is the eval MSE, not the training objective.
    # NOTE(review): the recorded run shows this MSE falling until about epoch
    # 100 and then growing without bound. A plausible cause: the simulated
    # sample is drawn without gradient, so nothing bounds the decoder scale
    # against the -mean(y_true * target_true) term. TODO confirm.
    if epoch % 10 == 0:
        with torch.no_grad():
            fit_mse = eval_fn(target_true, y_true).item()
        print("Epoch {} Loss {:.2f}".format(epoch, fit_mse))

Epoch 10 Loss 6.32
Epoch 20 Loss 6.23
Epoch 30 Loss 6.15
Epoch 40 Loss 6.07
Epoch 50 Loss 6.02
Epoch 60 Loss 5.98
Epoch 70 Loss 5.94
Epoch 80 Loss 5.91
Epoch 90 Loss 5.90
Epoch 100 Loss 5.90
Epoch 110 Loss 5.90
Epoch 120 Loss 5.91
Epoch 130 Loss 5.98
Epoch 140 Loss 6.09
Epoch 150 Loss 6.30
Epoch 160 Loss 6.63
Epoch 170 Loss 7.13
Epoch 180 Loss 7.86
Epoch 190 Loss 8.88
Epoch 200 Loss 10.25
Epoch 210 Loss 12.03
Epoch 220 Loss 14.34
Epoch 230 Loss 17.25
Epoch 240 Loss 20.88
Epoch 250 Loss 25.35
Epoch 260 Loss 30.88
Epoch 270 Loss 37.91
Epoch 280 Loss 46.44
Epoch 290 Loss 56.57
Epoch 300 Loss 69.13
Epoch 310 Loss 84.20
Epoch 320 Loss 102.34
Epoch 330 Loss 125.58
Epoch 340 Loss 153.76
Epoch 350 Loss 186.35
Epoch 360 Loss 227.56
Epoch 370 Loss 275.68
Epoch 380 Loss 335.69
Epoch 390 Loss 407.63
Epoch 400 Loss 493.31
Epoch 410 Loss 592.38
Epoch 420 Loss 717.10
Epoch 430 Loss 860.74
Epoch 440 Loss 1016.44
Epoch 450 Loss 1200.59
Epoch 460 Loss 1418.50
Epoch 470 Loss 1664.98
Epoch 480 Loss 1947.4