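# Verify that the solver recovers simulated PTRs

Simulate 10 samples, each from a single randomly chosen genome with a random log-PTR, fit `TorchSolver` to the resulting coverage matrix, and compare the inferred PTRs against the ground truth.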

In [1]:
import sys
sys.path.append('/home/phil/aptr')
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from src.simulation_new import simulate_sample
from src.database import RnaDB
from src.torch_solver import TorchSolver, solve_table


In [3]:
# RnaDB instance used below to pick genomes (complete_genomes), simulate samples,
# and build the genome objects handed to the solver.
rnadb = RnaDB()

In [86]:
# Simulate N_SAMPLES samples, each from a single randomly chosen genome with a
# random log-PTR, and record the ground-truth log-PTR for each (genome, sample).
N_SAMPLES = 10
SEED = 42
# Seed the global NumPy RNG so genome choice and log-PTRs are reproducible.
# NOTE(review): assumes simulate_sample also draws from the global np.random state — confirm.
np.random.seed(SEED)

# Sample genomes without replacement: a repeated genome would create duplicate
# row labels in the genome-indexed tables built from `genomes`, breaking alignment.
genomes = list(np.random.choice(rnadb.complete_genomes, size=N_SAMPLES, replace=False))

samples = []
truth_records = []
for i, genome in enumerate(genomes):
    log_ptr = np.random.rand()  # uniform on [0, 1)
    samples.append(simulate_sample(genome=genome, log_ptr=log_ptr, db=rnadb))
    truth_records.append({"genome": genome, "log_ptr": log_ptr, "sample": i})

# One column per sample after the transpose.
otu_matrix = pd.DataFrame(samples).T

# Build the ground-truth table in one go (DataFrame.append was removed in pandas 2.0
# and is quadratic inside a loop). Rows: genomes; columns: integer sample indices;
# values: true log-PTRs.
ptrs = pd.DataFrame(truth_records).pivot(index="genome", columns="sample", values="log_ptr")

otu_matrix.head()

In [87]:
# Fit the solver to the simulated coverage matrix (one column per sample).
N_EPOCHS = 10

# Element 0 of generate_genome_objects' result is what the solver takes as its genomes.
# NOTE(review): inferred rows are later labelled with `genomes`, so this assumes the
# genome objects come back in the same order as `genomes` — confirm.
genome_objects = rnadb.generate_genome_objects(genomes)[0]
solver = TorchSolver(genomes=genome_objects, coverages=otu_matrix.values)

# Bind the result to a real name rather than `_`, which IPython uses for output history.
train_output = solver.train(epochs=N_EPOCHS)

Epoch 0:	 1.767502908478491e-05
Epoch 1:	 3.975238087150501e-06
Epoch 2:	 1.585175709806208e-06
Epoch 3:	 7.739557759123272e-07
Epoch 4:	 4.178955919087457e-07
Epoch 5:	 2.4362910266972904e-07
Epoch 6:	 1.5554380183857575e-07
Epoch 7:	 1.0822384410857921e-07
Epoch 8:	 8.115236482808541e-08
Epoch 9:	 6.498063243043362e-08


In [90]:
# Pull the fitted parameters out of the solver as NumPy arrays.
# `.numpy()` shares memory with the tensor, so masking these arrays in place would
# write NaNs into solver.B_hat itself; np.where below builds a new array instead.
a_hat = solver.A_hat.detach().cpu().numpy()
b_hat = solver.B_hat.detach().cpu().numpy()

# Hide log-PTR estimates wherever the matching A_hat entry is negative.
# NOTE(review): presumably A_hat < 0 marks a genome as absent from a sample — confirm threshold.
b_masked = np.where(a_hat < 0, np.nan, b_hat)

# Rows labelled by genome, columns by sample, matching otu_matrix.
inferred_ptrs = pd.DataFrame(b_masked, index=genomes, columns=otu_matrix.columns)
inferred_ptrs

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
744872.3,0.227501,,,,,,,,,
591001.3,,0.719544,,,,,,,,
1250006.5,,,0.20009,,,,,,,
1150469.3,,,,0.301373,,,,,,
632348.3,,,,,0.74688,,,,,
350058.8,,,,,,0.569775,,,,
1223515.3,,,,,,,1.0,,,
457425.27,,,,,,,,0.471837,,
880070.3,,,,,,,,,0.0,
768490.3,,,,,,,,,,0.442


In [93]:
# Both tables hold log-PTRs; exponentiate and take true PTR minus inferred PTR
# for every (genome, sample) cell, aligned on genome and sample labels.
true_ptr = np.exp(ptrs.loc[:, inferred_ptrs.columns])
inferred_ptr = np.exp(inferred_ptrs)
true_ptr.sub(inferred_ptr)

sample,0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0
1150469.3,,,,-0.052962,,,,,,
1223515.3,,,,,,,-0.150685,,,
1250006.5,,,0.004334,,,,,,,
350058.8,,,,,,0.400388,,,,
457425.27,,,,,,,,0.152423,,
591001.3,,-0.167627,,,,,,,,
632348.3,,,,,-0.224506,,,,,
744872.3,0.072839,,,,,,,,,
768490.3,,,,,,,,,,-0.180615
880070.3,,,,,,,,,0.025459,
