# Torch sanity checks
> Verifying that the `TorchSolver` class is working as expected.

## Data generation

In [1]:
import sys
# sys.path.append('/home/phil/aptr')
# sys.path.append('~/aptr')
sys.path.append('/Users/phil/Columbia/aPTR')
%load_ext autoreload
%autoreload 2

In [91]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from src.torch_solver import TorchSolver
from src.database import RnaDB

from src.solve_table import solve_all, score_predictions
from src.simulation import simulate_from_ids

In [3]:
# Load a database object

db = RnaDB()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.obj[key] = value


In [74]:
# Get 10 complete genomes

genomes = db.db[db.db["n_contigs"] == 1]["genome"].unique()
genomes_to_use = np.random.choice(genomes, 10, replace=False)

# Simulate reads from those genomes
reads, ptrs, coverages, otus = simulate_from_ids(
    db=db.db,
    ids=genomes_to_use,
    fasta_path="/home/phil/aptr/data/seqs",
    n_samples=5,
    scale=1e4,
    shuffle=False,  # Suppress shuffling to conserve memory
    fastq=False
)


Generating sample 0 for organism 158190.3...
Generating sample 0 for organism 1538644.4...
Generating sample 0 for organism 655815.4...
Generating sample 0 for organism 713604.4...
Generating sample 0 for organism 1448139.3...
Generating sample 0 for organism 41514.7...
Generating sample 0 for organism 158822.7...
Generating sample 0 for organism 135487.3...
Generating sample 0 for organism 1042156.4...
Generating sample 0 for organism 467705.9...
Sample RNA:
[{'b7f41b025015db33fc3c70196e412905': 15, '370320f84fb907dc06dff584270a9ee5': 2}, {'c27fc07a69127cb5342a030dc64f6e78': 9, '83aa7ef534f9dd6e2f2559e2ad40b4bf': 5}, {'198bdbdc91dcecab92f812c764c142fb': 0, '2e6daf417dfd38a37f76ce679cf8366f': 1}, {'ba3405f355c4a84fa05463556ef463cb': 0, 'd29e6176bf0c86b43bc365e98acfc528': 0, '3c80218ac1294cf64d118bbc672c0595': 0}, {'f39f35965b1d20a9601584dfe085cc88': 9, '1f8c69aac679b5c6491970ab010f52e6': 1, 'b85b2777aac3f29822017387febe1f2e': 2}, {'556a8129572431ad0f63471d81f5c333': 12, 'ad1139d8d17b72

Now let's take a second to look at all of the simulation outputs:

In [76]:
print(reads) # Should be None, since we didn't ask for fastq output

None


In [77]:
print(ptrs) # A #{genomes} x #{samples} matrix of PTRs
# In this case, should be 10x5
print(ptrs.shape)

[[1.296476   1.22836694 1.69989888 1.37453742 1.28500123]
 [1.81078708 1.50818829 1.12666137 1.84335023 1.67940566]
 [1.17272828 1.32361942 1.73033557 1.28527683 1.9838819 ]
 [1.76765625 1.52760728 1.95606811 1.15563401 1.39603746]
 [1.13740399 1.81536428 1.50787392 1.95740882 1.1959742 ]
 [1.29915352 1.44219008 1.28869098 1.16721832 1.89333107]
 [1.56818202 1.90765779 1.06025922 1.62804339 1.50799788]
 [1.34288533 1.89917305 1.49072891 1.12754738 1.08295759]
 [1.07589977 1.99070998 1.69732122 1.28146579 1.54071823]
 [1.69656163 1.23173143 1.85870135 1.49720755 1.55908333]]
(10, 5)


In [78]:
print(coverages)
print(coverages.shape) # Should also be 10x5, positive integers

[[18053 12196  1756 15427 16842]
 [18605   728  3885 15899  6352]
 [ 5194 16834   976 14236 39945]
 [ 2648  8829 33764  6526  7175]
 [13334 37327 21571 16154  4521]
 [31733  3175  2411  5635 14281]
 [  982   924 16316  9404  1791]
 [ 3925  3015 30935 10028   854]
 [ 2856   611  2178 18642   279]
 [32485 31673  5689   618  8531]]
(10, 5)


In [79]:
print(otus)

                                     0     1     2     3     4
b7f41b025015db33fc3c70196e412905  15.0   5.0   0.0  11.0  11.0
370320f84fb907dc06dff584270a9ee5   2.0   4.0   0.0   1.0   3.0
c27fc07a69127cb5342a030dc64f6e78   9.0   1.0   3.0  11.0   9.0
83aa7ef534f9dd6e2f2559e2ad40b4bf   5.0   0.0   1.0   6.0   0.0
198bdbdc91dcecab92f812c764c142fb   0.0   1.0   0.0   3.0   8.0
2e6daf417dfd38a37f76ce679cf8366f   1.0   1.0   0.0   4.0   8.0
ba3405f355c4a84fa05463556ef463cb   0.0   1.0   1.0   0.0   0.0
d29e6176bf0c86b43bc365e98acfc528   0.0   1.0   1.0   2.0   1.0
3c80218ac1294cf64d118bbc672c0595   0.0   1.0   1.0   0.0   0.0
f39f35965b1d20a9601584dfe085cc88   9.0  33.0  21.0  18.0   3.0
1f8c69aac679b5c6491970ab010f52e6   1.0   4.0   3.0   1.0   0.0
b85b2777aac3f29822017387febe1f2e   2.0   2.0   4.0   4.0   0.0
556a8129572431ad0f63471d81f5c333  12.0   0.0   1.0   3.0   4.0
ad1139d8d17b7276132087bf4682045c   5.0   2.0   0.0   0.0   1.0
c6fe5d3d4af66c90ff7df941c23967f1   6.0   0.0   1.0   0.

## Torch checks

Remember, the TorchSolver object assumes you have a 'genomes' dict which has the keys:
* seqs -> indices for which sequence the genome has
* pos -> start position for RNA gene

In [80]:
solver_genomes, md5s = db.generate_genome_objects(genomes_to_use)
print(solver_genomes)
print(len(solver_genomes)) # Should be 10, since there are 10 genomes

[{'id': '158190.3', 'pos': array([0.36870599, 0.05386492, 0.05206312, 0.4038798 ]), 'seqs': [0, 0, 0, 1]}, {'id': '1538644.4', 'pos': array([0.03367286, 0.01410787, 0.03505736, 0.08001561, 0.42412868,
       0.13796746, 0.22024391, 0.19420117]), 'seqs': [2, 2, 2, 2, 3, 3, 2, 2]}, {'id': '655815.4', 'pos': array([0.3692223 , 0.60811316, 0.00403924]), 'seqs': [4, 5, 5]}, {'id': '713604.4', 'pos': array([0.12651841, 0.34689584, 0.22539473, 0.25593295]), 'seqs': [6, 7, 8, 7]}, {'id': '1448139.3', 'pos': array([0.16435212, 0.00907711, 0.02554103, 0.06524982, 0.01749113,
       0.19055063, 0.81160136, 0.07228189, 0.09243017, 0.03625559]), 'seqs': [9, 10, 9, 9, 9, 9, 9, 9, 9, 11]}, {'id': '41514.7', 'pos': array([0.07106634, 0.08936011, 0.75653234, 0.0306175 , 0.01029364,
       0.06107764, 0.2343136 ]), 'seqs': [12, 13, 12, 14, 15, 16, 17]}, {'id': '158822.7', 'pos': array([0.03045775, 0.1834427 , 0.20539586, 0.05412395, 0.05569944,
       0.07173006, 0.07988342]), 'seqs': [18, 19, 20, 21, 2

In [98]:
# Here we initialize a TorchSolver with the genomes and coverages from before

solver = TorchSolver()
solver.set_vals(
    genomes=solver_genomes,
    coverages=otus[0]
)

In [101]:
# Check that a_hat updates 

print(solver.a_hat)
solver.train()
print(solver.a_hat)

tensor([0.7543, 0.4381, 0.1231, 0.4828, 0.1528, 0.3969, 0.1638, 0.8205, 0.9877,
        0.4945], requires_grad=True)
tensor([0.7543, 0.4381, 0.1231, 0.4828, 0.1528, 0.3969, 0.1638, 0.8205, 0.9877,
        0.4945], requires_grad=True)


In [116]:
# Let's try to write out the torch model here
import torch

# Learnable parameters are a_hat and b_hat
a_hat = torch.rand(10, dtype=torch.float32)
b_hat = torch.log(1 + torch.rand(10, dtype=torch.float32)) # log-PTR

# Let's just make copies of the existing C, D, E matrices
C = torch.tensor(solver.members, dtype=torch.float32)
D = torch.tensor(solver.dists, dtype=torch.float32)
E = torch.tensor(solver.gene_to_seq, dtype=torch.float32)
f_true = torch.tensor(otus[0], dtype=torch.float32)

# Some counts from the model
n = solver.n
m = solver.m
k = solver.k

# Forward pass
g = a_hat @ C + 1 - b_hat @ D
f = torch.exp(g) @ E

# Loss
loss = torch.sum((f - f_true)**2)

# Gradients
dL_df = f - f_true
dL_dg = (2 / k) * torch.exp(g) * (E @ dL_df)
dL_da = C @ dL_dg
dL_db = -D @ dL_dg

# Update parameters
print(a_hat)
a_hat = a_hat - 0.1 * dL_da
print(a_hat)

print(b_hat)
b_hat = b_hat - 0.1 * dL_db
print(b_hat)


  C = torch.tensor(solver.members)
  D = torch.tensor(solver.dists)
  E = torch.tensor(solver.gene_to_seq)


RuntimeError: expected scalar type Float but found Double

In [115]:
b_hat.grad

  return self._grad


In [102]:
solver.members

tensor([[1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1.,
         1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0.,