# Gradients of the matrices A and B
> Updating the gradients for the matrix version of the solver

In [1]:
import sys
sys.path.append('/home/phil/aptr')
%load_ext autoreload
%autoreload 2

In [4]:
import torch
import numpy as np

from src.simulation_new import make_tables
from src.database import RnaDB
from src.torch_solver import TorchSolver


In [5]:
rnadb = RnaDB()

# Seed NumPy's global RNG so the simulated tables are reproducible
# (presumably make_tables samples from np.random — TODO confirm).
np.random.seed(42)
simulated = make_tables(n_genomes=4, n_samples=3)
# abundances/ptrs later fix the shapes of A_hat/B_hat; otus is passed to the
# solver as its coverages.
abundances, ptrs, otus = simulated

In [6]:
# Build the solver from the genomes listed in the simulated abundance table,
# using the simulated OTU table as the coverages to fit. Only the first element
# returned by generate_genome_objects is used.
genome_ids = list(abundances.index)
genome_objects = rnadb.generate_genome_objects(genome_ids)[0]
solver = TorchSolver(genomes=genome_objects, coverages=otus)

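Old gradients (from the previous version of the solver):
* $L = \mathrm{MSE}(f_{\text{predicted}}, f_{\text{observed}})$
* $\dfrac{\partial L}{\partial g} = \dfrac{2}{k}\, \exp(g) \odot E\,(f_{\text{predicted}} - f_{\text{observed}})$
* $\dfrac{\partial L}{\partial a} = C\, \dfrac{\partial L}{\partial g}$
* $\dfrac{\partial L}{\partial b} = -D\, \dfrac{\partial L}{\partial g}$

The matrix version below uses the summed squared error $L = \sum (\hat F - F)^2$ (no $1/k$ factor). With $G = CA + 1 - DB$, this gives
$\dfrac{\partial L}{\partial A} = 2\, C^\top \left[\exp(G) \odot E^\top (\hat F - F)\right]$ and
$\dfrac{\partial L}{\partial B} = -2\, D^\top \left[\exp(G) \odot E^\top (\hat F - F)\right]$,
which the cells below check against autograd.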

In [10]:
# Reference gradients from autograd, computed at random starting estimates.
# Seeding torch makes the initial A_hat/B_hat (and so the printed gradients)
# reproducible.
torch.manual_seed(42)
solver.A_hat = torch.rand(size=abundances.shape, requires_grad=True)
solver.B_hat = torch.rand(size=ptrs.shape, requires_grad=True)

# Forward pass, then summed squared error against the coverages.
f_hat = solver(solver.A_hat, solver.B_hat)
loss = (f_hat - solver.coverages).pow(2).sum()
loss.backward()

print("A", solver.A_hat.grad, "", "B", solver.B_hat.grad, sep="\n")

A
tensor([[  87.9920,  110.8270,   35.3428],
        [ 378.0231,   91.6746,  168.5249],
        [1467.0398, 3741.2458, 5757.2402],
        [  66.1114,  370.0370,  118.5928]])

B
tensor([[  -22.4380,   -36.3739,   -10.0928],
        [ -120.7763,   -26.8508,   -52.3566],
        [ -285.2130,  -703.0383, -1118.9692],
        [  -26.8787,  -155.5475,   -43.9320]])


In [11]:
# Solver dimensions, printed for reference.
n, s, m, k = solver.n, solver.s, solver.m, solver.k
print(f"n: {n}, s: {s}, m: {m}, k: {k}")

# Single-letter aliases so the gradient cells below read like the algebra.
A, B = solver.A_hat, solver.B_hat  # current estimates (leaf tensors with requires_grad)
C = solver.members
D = solver.dists
E = solver.gene_to_seq

# G = C A + 1 - D B; exp(G) appears in both manual gradients below.
G = C @ A + 1 - D @ B

# Coverages passed to the solver vs. the solver's output from the autograd cell.
F, F_hat = solver.coverages, f_hat

n: 4, s: 3, m: 20, k: 12


In [12]:
# Manual gradient: dL/dA = 2 C^T [exp(G) * (E^T (F_hat - F))].
# Computed under no_grad so the check does not extend the autograd graph.
with torch.no_grad():
    dL_dA = 2 * C.T @ (torch.exp(G) * (E.T @ (F_hat - F)))

# Tolerance-based check rather than eyeballing the ratio, which would be
# NaN/inf wherever autograd's gradient is exactly zero.
assert torch.allclose(dL_dA, solver.A_hat.grad, rtol=1e-4, atol=1e-6), \
    "manual dL/dA disagrees with autograd"

# Elementwise ratio kept for display; should be all ones.
dL_dA / solver.A_hat.grad

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], grad_fn=<DivBackward0>)

In [13]:
# Manual gradient: dL/dB = -2 D^T [exp(G) * (E^T (F_hat - F))].
# Computed under no_grad so the check does not extend the autograd graph.
with torch.no_grad():
    dL_dB = -D.T @ (torch.exp(G) * (E.T @ (F_hat - F))) * 2

# Tolerance-based check rather than eyeballing the ratio, which would be
# NaN/inf wherever autograd's gradient is exactly zero.
assert torch.allclose(dL_dB, solver.B_hat.grad, rtol=1e-4, atol=1e-6), \
    "manual dL/dB disagrees with autograd"

# Elementwise ratio kept for display; should be all ones.
dL_dB / solver.B_hat.grad


tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], grad_fn=<DivBackward0>)