# Simulated one-dimensional latent space

In [None]:
import string
import pandas as pd

from itertools import combinations
from gpytorch.kernels import RQKernel
import torch
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

## Simulate data

In [None]:
def sim(seed, p=5):
    """Simulate a full combinatorial landscape over ``p`` binary mutations.

    All ``2 ** p`` genotypes are projected onto a single latent coordinate
    ``z = X @ W``. The phenotype surface is ``sigmoid(z)`` plus a small draw
    from a zero-mean multivariate normal whose covariance is a scaled
    ``RQKernel`` matrix; observations add independent Gaussian noise with
    standard deviation 0.15.

    Note: this seeds torch's global random number generator with ``seed``.

    Returns a tuple ``(W, X, z, y, Z, f_grid)``:
        W      -- (p, 1) mutation effects on the latent coordinate
        X      -- (2 ** p, p) 0/1 genotype matrix (row 0 has no mutations)
        z      -- (2 ** p, 1) latent coordinate of every genotype
        y      -- (2 ** p,) noisy phenotype of every genotype
        Z      -- (100, 1) evenly spaced grid spanning [z.min(), z.max()]
        f_grid -- (100,) noise-free surface evaluated on ``Z``
    """
    n_variants = 2 ** p
    n_grid = 100
    n_total = n_variants + n_grid

    torch.random.manual_seed(seed)
    W = torch.randn(p, 1) * np.sqrt(2) - .1

    # Row 0 is left all-zero; rows 1.. enumerate every non-empty subset of
    # mutations, ordered first by subset size and then lexicographically.
    X = torch.zeros(n_variants, p)
    subsets = (
        subset
        for order in range(1, p + 1)
        for subset in combinations(range(p), order)
    )
    for row, subset in enumerate(subsets, start=1):
        X[row, list(subset)] = 1

    z = torch.mm(X, W)
    Z = torch.linspace(z.min(), z.max(), n_grid)[:, None]

    # Genotype coordinates and the grid are sampled jointly so that the
    # observed points and the plotted curve come from the same function draw.
    z_all = torch.cat((z, Z), 0)

    kernel = RQKernel()
    with torch.no_grad():
        K = kernel(z_all).evaluate()
        # Small diagonal jitter keeps the covariance numerically positive definite.
        covariance = 0.005 * K + torch.eye(n_total) * 1e-7
        perturbation = torch.distributions.MultivariateNormal(
            torch.zeros(n_total), covariance
        ).rsample()
        f = perturbation + torch.sigmoid(z_all[:, 0])

    y = f[:n_variants] + torch.randn(n_variants) * 0.15

    return W, X, z, y, Z, f[n_variants:]

# Simulate a landscape over p = 5 mutations with a fixed seed (100) so the
# figure below is reproducible.
p = 5
W, X, z, y, Z, f = sim(100, p=p)

plt.figure(figsize=(4, 3), dpi=300)
# Curve: the noise-free surface returned by sim, evaluated on the grid Z.
plt.plot(Z, f)
# Points: the noisy phenotype y of each genotype at its latent coordinate z.
plt.scatter(z, y, c="C2", alpha=0.8)
# Dashed line at z = 0, the latent coordinate of the all-zero row of X.
plt.axvline(0, c="k", ls="--")

# One horizontal arrow per mutation, starting at 0 with length W[i] (that
# mutation's shift along z); each arrow is drawn 0.05 lower than the previous.
for i in range(p):
    plt.arrow(0, -.05*i, W[i].item(), 0, color=f"C{3+i}", width=0.01)
    
plt.ylabel("phenotype")
plt.xlabel("$z_1$")
    
# Trailing None: presumably so the notebook does not echo the return value of
# the last plotting call.
None

### Convert to dataframe

In [None]:
# Encode every genotype (row of X) as a colon-separated list of substitutions,
# naming mutation i after the i-th lowercase letter ("+a", "+b", ...). The
# all-zero row becomes the empty string.
df = pd.DataFrame(
    {
        "substitutions": [
            ":".join(
                f"+{string.ascii_lowercase[i]}"
                for i in np.where(X[j, :].numpy())[0]
            )
            for j in range(X.shape[0])
        ],
        # Convert the torch tensor explicitly (as is done for X above) instead
        # of relying on pandas' implicit coercion of array-like objects, so the
        # column is guaranteed to hold plain floats rather than tensor objects.
        "phenotype": y.numpy(),
    },
)

# Preview the first rows.
df.head()

### Build LANTERN dataset

In [None]:
from lantern.dataset import Dataset
# Wrap the dataframe (columns "substitutions" and "phenotype") in a LANTERN
# Dataset. NOTE(review): presumably Dataset locates these columns by its
# default names — confirm against the LANTERN documentation.
ds = Dataset(df)
# Bare expression so the notebook displays the dataset.
ds

In [None]:
# Number of observations in the dataset; expected to be 32, matching the
# 2 ** p (p = 5) rows of df.
len(ds)

In [None]:
# Index the dataset to get its first element (a tuple of x_0, y_0), i.e. the
# entry built from the first row of df.
ds[0]

## Build model

In [None]:
from lantern.model.basis import VariationalBasis

# Build the basis component of the model from the dataset.
# NOTE(review): K=8 is presumably the number of latent dimensions and must
# match the K passed to the surface — confirm against the LANTERN docs.
basis = VariationalBasis.fromDataset(ds, K=8, meanEffectsInit=True)

In [None]:
from lantern.model.surface import Phenotype

# Build the surface component of the model from the dataset, using the same
# K=8 as the basis.
# NOTE(review): presumably K is the latent dimensionality — confirm.
surface = Phenotype.fromDataset(ds, K=8)

In [None]:
from lantern.model import Model

# Combine the basis and the surface into a single model.
model = Model(basis, surface)

## Train model

In [None]:
from torch.optim import Adam

# Loss object derived from the model; N is set to the number of observations.
loss = model.loss(N=len(ds))
# Pull the whole dataset at once (full-batch training).
# NOTE: this rebinds the notebook-level X and y, which previously held the
# simulated genotype matrix and phenotypes.
X, y = ds[: len(ds)]

# Optimize the parameters exposed by the loss object.
# NOTE(review): presumably these include the model's parameters — confirm.
optimizer = Adam(loss.parameters(), lr=0.01)
hist = []  # total loss recorded at every iteration
for i in range(500):
    optimizer.zero_grad()
    yhat = model(X)
    # The loss returns a mapping of individual terms; their sum is the
    # scalar objective that is back-propagated.
    lss = loss(yhat, y)
    total = sum(lss.values())
    total.backward()
    optimizer.step()
    hist.append(total.item())

# Training curve: each iteration uses the full dataset, so one iteration is
# one epoch.
plt.figure(figsize=(4, 3), dpi=300)
plt.plot(hist)
plt.xlabel("epoch")
plt.ylabel("loss")