In [None]:
import treescope
import torch
from pathlib import Path
from src.model import EvolutionOperator
from src.data import DESRESDataset
from src.configs import ModelArgs # noqa
from torch_geometric.loader import DataLoader
from loguru import logger
from dataclasses import asdict
import linear_operator_learning as lol
import matplotlib.pyplot as plt
import numpy as np
import pickle


#treescope.basic_interactive_setup()

In [None]:
# Back-fill the W&B config of every finished run with the model arguments stored in
# its local checkpoint, so runs can be filtered/grouped by them in the W&B UI.
import wandb

protein_id = "CLN025"
traj_id = 0
epoch = 249
api = wandb.Api()

for run in api.runs(f"csml/encoderops-{protein_id}-{traj_id}"):
    if run.state != "finished":
        continue
    data_path = Path(f"logs/encoderops-{protein_id}-{traj_id}/{run.id}")
    ckpt_path = next(data_path.glob(f"**/epoch={epoch}*.ckpt"), None)
    if ckpt_path is None:
        # Skip (rather than abort the whole back-fill with StopIteration) when the
        # checkpoint for this run is not available locally.
        logger.warning(f"No epoch={epoch} checkpoint for run {run.id} under {data_path}; skipping")
        continue
    ckpt_model = EvolutionOperator.load_from_checkpoint(ckpt_path)
    # Only the keys W&B does not know about yet; existing values are never overwritten.
    missing = {k: v for k, v in asdict(ckpt_model.model_args).items() if k not in run.config}
    if missing:
        run.config.update(missing)
        run.update()  # push to the server only when something actually changed
        logger.info(f"Updated run {run.id}")
    else:
        logger.info(f"Run {run.id} was already up to date")

In [None]:
# W&B id of the run that is loaded and inspected in the following cells.
run_id = "yl7v1o69"

In [None]:
def load_run(
    run_id: str,
    protein_id: str,
    traj_id: int = 0,
    epoch: int = 249,
    center_covariance: bool = False,
    progress: bool = True,
    reg: float = 1e-4,
    sort_eigenpairs: bool = False,
):
    """Embed a run's dataset and fit a ridge-regression transfer operator on the embeddings.

    Loads the ``epoch`` checkpoint of ``run_id`` from
    ``logs/encoderops-{protein_id}-{traj_id}/{run_id}``. It then embeds the current and
    lagged items of every sample of the checkpoint's DESRES dataset, and estimates
    ``G = (C_X + reg * I)^{-1} C_XY`` together with its eigendecomposition.

    Args:
        run_id: W&B run id (name of the checkpoint sub-directory).
        protein_id: Protein identifier used in the log-directory name. The dataset itself
            is built from the checkpoint's ``data_args.protein_id``.
        traj_id: Trajectory index used in the log-directory name.
        epoch: Checkpoint epoch to load.
        center_covariance: Mean-center the embeddings when estimating covariances.
        progress: Wrap the dataloader in a tqdm progress bar.
        reg: Ridge term added to the diagonal of the input covariance ``C_X``.
        sort_eigenpairs: If True, eigenpairs are returned sorted by decreasing modulus.
            ``torch.linalg.eig`` gives no ordering guarantee, so with the default (False)
            column ``k`` of ``eigenvectors``/``r_fun`` is not necessarily the k-th
            slowest mode.

    Returns:
        ``(model, results)``. ``results`` has the keys ``embedding_t``,
        ``embedding_lag``, ``eigenvalues``, ``eigenvectors`` and ``r_fun``, where
        ``r_fun = embedding_t @ eigenvectors`` (complex dtype).

    Raises:
        FileNotFoundError: If no matching checkpoint exists under the run directory.
    """
    data_path = Path(f"logs/encoderops-{protein_id}-{traj_id}/{run_id}")
    ckpt_path = next(data_path.glob(f"**/epoch={epoch}*.ckpt"), None)
    if ckpt_path is None:
        raise FileNotFoundError(f"No epoch={epoch} checkpoint found under {data_path}")
    model = EvolutionOperator.load_from_checkpoint(ckpt_path).eval()
    dataset = DESRESDataset(model.data_args.protein_id, lagtime=model.data_args.lagtime)

    # shuffle=False keeps the rows of the embeddings aligned with the dataset order.
    dataloader = DataLoader(dataset, batch_size=128, shuffle=False)
    if progress:
        from tqdm.auto import tqdm

        dataloader = tqdm(dataloader)

    feats_t, feats_lag = [], []
    with torch.no_grad():
        for batch in dataloader:
            for k, v in batch.items():
                batch[k] = v.to(model.device)
            x_t = model._setup_graph_data(batch)
            x_lag = model._setup_graph_data(batch, key="item_lag")
            # no_grad already prevents graph construction, so no detach() is needed.
            feats_t.append(model.forward_nn(x_t).cpu())
            feats_lag.append(model.forward_nn(x_lag).cpu())
    phi_t = torch.cat(feats_t)
    phi_lag = torch.cat(feats_lag)

    cov_X = lol.nn.stats.covariance(phi_t, center=center_covariance)
    cov_X += torch.eye(cov_X.shape[0], device=cov_X.device) * reg
    cov_XY = lol.nn.stats.covariance(phi_t, phi_lag, center=center_covariance)
    G = torch.linalg.solve(cov_X, cov_XY)
    eigenvalues, eigenvectors = torch.linalg.eig(G)  # Q @ diag(l) @ Q^-1 = G
    if sort_eigenpairs:
        order = torch.argsort(eigenvalues.abs(), descending=True)
        eigenvalues, eigenvectors = eigenvalues[order], eigenvectors[:, order]
    # eig returns complex outputs, so cast the real embeddings before projecting.
    r_fun = phi_t.to(eigenvectors.dtype) @ eigenvectors

    results = {
        "embedding_t": phi_t,
        "embedding_lag": phi_lag,
        "eigenvalues": eigenvalues,
        "eigenvectors": eigenvectors,
        "r_fun": r_fun,
    }
    return model, results

In [None]:
# Load the checkpoint of the selected run together with its embedding-space spectral results.
model, results = load_run(run_id=run_id, protein_id=protein_id)

In [None]:
reg = 1e-4  # ridge term; the same value `load_run` adds to C_X
results_path = Path(f"analysis/results-{protein_id}-{traj_id}.pkl")
if results_path.exists():
    # NOTE: pickle.load can execute arbitrary code -- only load caches produced locally.
    with results_path.open("rb") as f:
        results = pickle.load(f)
else:
    results_path.parent.mkdir(parents=True, exist_ok=True)
    results = {}  # W&B run id -> results dict returned by `load_run`
    for run in api.runs(f"csml/encoderops-{protein_id}-{traj_id}"):
        if run.state != "finished":
            continue
        logger.info(f"Run {run.id}")
        # Discard the returned model so the `model` loaded for `run_id` is not overwritten.
        _, run_results = load_run(run.id, protein_id, traj_id=traj_id, epoch=epoch, progress=False)
        results[run.id] = run_results
        # Re-write the cache after every run so an interrupted loop keeps finished work.
        with results_path.open("wb") as f:
            pickle.dump(results, f)

In [None]:
def implied_timescale(ev, lagtime_ns: float = 5.0):
    """Implied timescales ``t_i = -lagtime / ln|lambda_i|`` of transfer-operator eigenvalues.

    Args:
        ev: Eigenvalues (real or complex) as a torch tensor, NumPy array or sequence.
        lagtime_ns: Lag time of the operator in ns; the timescales use the same unit.

    Returns:
        NumPy array of timescales sorted in descending order. ``|lambda| == 1`` maps to
        ``+inf``, ``|lambda| == 0`` to ``0`` and ``|lambda| > 1`` to a negative value.
    """
    if torch.is_tensor(ev):
        ev = ev.numpy(force=True)
    with np.errstate(divide="ignore"):
        log_mod = np.log(np.abs(np.asarray(ev)))
        # A stationary mode (log|lambda| == +0.0) must give +inf. Dividing by the zero
        # directly yields -inf (1 / -0.0 == -inf), so that case is handled explicitly.
        # errstate silences the expected divide-by-zero warnings from log(0) and from
        # the discarded branch of np.where.
        ts = np.where(log_mod == 0.0, np.inf, -lagtime_ns / log_mod)
    return np.sort(ts)[::-1]

In [None]:
# After the cache-loading cell `results` maps W&B run id -> per-run results, so the run
# selected in `run_id` has to be indexed explicitly.
print(results[run_id]["eigenvalues"])

In [None]:
# `results` is keyed by W&B run id after the cache-loading cell; select `run_id` explicitly.
print(implied_timescale(results[run_id]["eigenvalues"]))

In [None]:
# Transfer operator as returned by the loaded model.
print(model.get_transfer_operator())

In [None]:
# Element-wise ratio between `model.cov` (presumably the input covariance tracked by the
# model -- TODO confirm) and the covariance re-estimated from the stored embeddings of
# `run_id`. The estimator matches `load_run`: uncentered, with a 1e-4 ridge. `cov_X`
# itself is only defined in a later cell, so it is computed here to keep this cell
# runnable on a fresh kernel.
cov_X_est = lol.nn.stats.covariance(results[run_id]["embedding_t"], center=False)
cov_X_est += torch.eye(cov_X_est.shape[0], device=cov_X_est.device) * 1e-4
print(model.cov.cpu() / cov_X_est)

In [None]:
from linear_operator_learning.nn.stats import covariance
# Re-estimate G = (C_X + 1e-4 * I)^{-1} C_XY for `run_id` from its stored embeddings,
# using the same estimator as `load_run`. `results` is keyed by W&B run id after the
# cache-loading cell, hence the explicit `results[run_id]`.
phi_t = results[run_id]["embedding_t"]
phi_lag = results[run_id]["embedding_lag"]
center_covariance = False
cov_X = covariance(phi_t, center=center_covariance)
cov_X += torch.eye(cov_X.shape[0], device=cov_X.device) * 1e-4
cov_XY = covariance(phi_t, phi_lag, center=center_covariance)
G = torch.linalg.solve(cov_X, cov_XY)

In [None]:
# NOTE(review): this cell only contained the stray expression `p`, which raises
# NameError on a fresh kernel; it has been removed.

In [None]:
# Compare implied timescales (left) and the four largest eigenvalue moduli (right) across
# all cached runs. Runs with regularization == 1e-4, a minimum encoder LR and normalize_lin
# are drawn in blue (and their model args printed); all other runs are drawn in red.
fig, ax = plt.subplots(ncols=2, figsize=(10, 5))
for rid, values in results.items():
    # Same checkpoint layout `load_run` used to build `results`. The loop names are chosen
    # so the notebook-level `run_id`, `model` and `epoch` are not overwritten.
    ckpt_dir = Path(f"logs/encoderops-{protein_id}-{traj_id}/{rid}")
    run_model = EvolutionOperator.load_from_checkpoint(next(ckpt_dir.glob(f"**/epoch={epoch}*.ckpt")))
    args = run_model.model_args
    highlight = (
        args.regularization == 0.0001
        and args.min_encoder_lr is not None
        and args.normalize_lin
    )
    color = "b" if highlight else "r"
    if highlight:
        print(rid, args)

    eigvals = values["eigenvalues"]
    eigvals = eigvals.numpy(force=True) if torch.is_tensor(eigvals) else np.asarray(eigvals)
    ts = implied_timescale(eigvals)
    ax[0].plot(range(1, len(ts) + 1), ts, "x-", label=rid, color=color)
    # torch.linalg.eig returns eigenvalues in no particular order: sort by modulus first,
    # then keep the four largest.
    ax[1].plot(np.sort(np.abs(eigvals))[::-1][:4], "x-", color=color)

ax[0].set_xscale("log")
ax[0].set_yscale("log")
ax[0].axhline(5)  # presumably the lag time: 5 ns is `implied_timescale`'s default lagtime_ns
ax[0].set(xlabel="mode index", ylabel="implied timescale [ns]")
ax[1].set_ylim(0.8, 1)
ax[1].set(xlabel="rank", ylabel=r"$|\lambda|$")
fig.tight_layout()

In [None]:
# Raw (complex, unordered) eigenvalues of run p52qnu9d.
results["p52qnu9d"]["eigenvalues"]

In [None]:
# Eigenvalues of run p52qnu9d on the complex plane, together with the unit circle |lambda| = 1.
eigs = results["p52qnu9d"]["eigenvalues"]
fig, ax = plt.subplots()
ax.scatter(eigs.real, eigs.imag, marker="x", label="Ridge Regression new")
theta = np.linspace(0, 2 * np.pi, 100)
ax.plot(np.cos(theta), np.sin(theta), color="k", lw=0.5)
ax.grid(alpha=0.2)
ax.legend()
ax.set(xlabel=r"Re $\lambda$", ylabel=r"Im $\lambda$")
ax.axis("equal");

In [None]:
# Eigenfunction values (embedding_t @ eigenvectors) of run p52qnu9d, one row per sample.
r_fun = results["p52qnu9d"]["r_fun"]

In [None]:
from mlcolvar.utils.fes import compute_fes

In [None]:
# Real part of column 0 of `r_fun` along the dataset, with a hard-coded reference level
# at 1.55.
# NOTE: torch.linalg.eig does not order eigenpairs, so column 0 is not guaranteed to be
# the dominant/slowest mode.
fig, ax = plt.subplots()
ax.plot(r_fun[:, 0].real, ".", color="k", markersize=0.5)
ax.axhline(1.55, color="r")
ax.set(xlabel="sample index", ylabel="Re r_fun[:, 0]");

In [None]:
# Real parts of columns 0 and 1 of `r_fun` against each other; the red line marks the
# hard-coded 1.55 level on the column-0 axis.
fig, ax = plt.subplots()
ax.scatter(r_fun[:, 0].real, r_fun[:, 1].real, color="k", s=0.5)
ax.axvline(1.55, color="r")
ax.set(
    xlabel="Re r_fun[:, 0]",
    ylabel="Re r_fun[:, 1]",
    title="They look like they are concentrated on the vertices of a simplex!!!",
);