# Verify solver works as intended

In [1]:
import sys
sys.path.append('/home/phil/aptr')
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from src.simulation_new import simulate_sample
from src.database import RnaDB
from src.torch_solver import TorchSolver, solve_table


In [3]:
rnadb = RnaDB()

In [61]:
samples = []
genomes = []
ptrs = pd.DataFrame()
for i in range(10):
    genome = np.random.choice(rnadb.complete_genomes)
    genomes.append(genome)
    log_ptr = np.random.rand()
    sample = simulate_sample(
        genome = genome,
        log_ptr=log_ptr,
        db = rnadb,
    )
    samples.append(sample)
    ptrs = ptrs.append(
        {"genome": genome, "ptr": log_ptr, "sample": i}, ignore_index=True
    )
otu_matrix = pd.DataFrame(samples).T
otu_matrix

ptrs = ptrs.pivot(columns="genome", index="sample", values="ptr")

In [69]:
solver = TorchSolver(
    genomes=rnadb.generate_genome_objects(genomes)[0],
    coverages=otu_matrix.T.values
)
_ = solver.train(epochs=10)

Epoch 0:	 2.4934621251304634e-05
Epoch 1:	 5.502673957380466e-06
Epoch 2:	 2.07315724765067e-06
Epoch 3:	 9.856578344624722e-07
Epoch 4:	 5.195817607273057e-07
Epoch 5:	 2.92365143650386e-07
Epoch 6:	 1.7785500006084476e-07
Epoch 7:	 1.175301989064792e-07
Epoch 8:	 8.266682272051185e-08
Epoch 9:	 6.102203542468487e-08


(array([[ 3.2394073, -2.9004035, -2.5778015, -2.664317 , -2.640255 ,
         -2.7309008, -3.0073214, -2.870785 , -2.8461971, -2.7706692],
        [-2.5813768,  3.1996098, -2.5778015, -2.664317 , -2.640255 ,
         -2.7309008, -3.0073214, -2.870785 , -2.8461971, -2.7706692],
        [-2.5813768, -2.9004035,  3.1557233, -2.664317 , -2.640255 ,
         -2.7309008, -3.0073214, -2.870785 , -2.8461971, -2.7706692],
        [-2.5813768, -2.9004035, -2.5778015,  3.2261872, -2.640255 ,
         -2.7309008, -3.0073214, -2.870785 , -2.8461971, -2.7706692],
        [-2.5813768, -2.9004035, -2.5778015, -2.664317 ,  3.1653726,
         -2.7309008, -3.0073214, -2.870785 , -2.8461971, -2.7706692],
        [-2.5813768, -2.9004035, -2.5778015, -2.664317 , -2.640255 ,
          2.9930618, -3.0073214, -2.870785 , -2.8461971, -2.7706692],
        [-2.5813768, -2.9004035, -2.5778015, -2.664317 , -2.640255 ,
         -2.7309008,  3.168074 , -2.870785 , -2.8461971, -2.7706692],
        [-2.5813768, -2.900

In [55]:
a = solver.a_hat.detach().numpy()
b = solver.b_hat.detach().numpy()

b[a < 0] = np.nan

In [65]:
inferred_ptrs = pd.DataFrame(b, columns=genomes, index=otu_matrix.columns)

In [68]:
np.exp(inferred_ptrs) - np.exp(ptrs)

Unnamed: 0,1076934.5,1131462.4,1196835.3,136857.5,218491.5,396513.4,491915.6,529120.14,550540.3,771875.3
0,,,-0.797464,,,,,,,
1,,,,,,,,,-0.881902,
2,,,,,,,,,,-0.23657
3,,,,0.199424,,,,,,
4,,-0.325927,,,,,,,,
5,-0.979878,,,,,,,,,
6,,,,,,,-0.932963,,,
7,,,,,,,,0.111862,,
8,,,,,-1.418324,,,,,
9,,,,,,0.113326,,,,
