# Gradients of the matrices A and B
> Updating gradients for matrix version of solver

In [2]:
import sys
# NOTE(review): machine-specific absolute path; presumably the repository root
# that contains the `src` package imported in the next cell. Adjust per machine.
sys.path.append('/home/phil/aptr')
# IPython autoreload in mode 2: re-import modules before each cell executes,
# so edits to the project's source files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [12]:
import torch

from src.simulation_new import make_tables
from src.database import RnaDB
from src.torch_solver import TorchSolver


In [30]:
# Problem size for the simulated data; kept tiny so the gradient tensors
# printed further down can be inspected by eye.
N_GENOMES = 4
N_SAMPLES = 3

# Database handle; used in the next cell to build the genome objects.
rnadb = RnaDB()

# make_tables yields three tables. The shapes of the first two are used to
# size the A_hat / B_hat parameters, and the third is handed to the solver
# as its coverages.
abundances, ptrs, otus = make_tables(n_genomes=N_GENOMES, n_samples=N_SAMPLES)

In [31]:
# Index labels of the simulated abundance table identify which genomes to load.
genome_ids = list(abundances.index)

# generate_genome_objects returns an indexable result; only element 0 is
# consumed here (presumably the genome objects themselves -- confirm in RnaDB).
genome_objects = rnadb.generate_genome_objects(genome_ids)[0]

# The simulated OTU table serves as the solver's coverages.
solver = TorchSolver(genomes=genome_objects, coverages=otus)

Old gradients:
* L = MSE(f_predicted, f_observed)
* dL/dg = (2/k) exp(g) E(f_predicted - f_observed)
* dL/da = C dL/dg
* dL/db = -D dL/dg

In [134]:
# Random starting parameters, shaped like the simulated abundances / ptrs
# tables. requires_grad=True makes them leaf tensors whose .grad is filled
# in by backward().
solver.A_hat = torch.rand(size=abundances.shape, requires_grad=True)
solver.B_hat = torch.rand(size=ptrs.shape, requires_grad=True)

# Forward pass through the solver, then a sum-of-squared-errors loss
# (a sum, not a mean) against the solver's coverages.
f_hat = solver(solver.A_hat, solver.B_hat)
squared_error = (f_hat - solver.coverages) ** 2
loss = squared_error.sum()

# Autograd gradients: these are the reference values that the hand-derived
# expressions in the following cells are compared against.
loss.backward()

print("A", solver.A_hat.grad, "", "B", solver.B_hat.grad, sep="\n")

A
tensor([[100.6545,  50.3325,  66.6091],
        [129.7197, 145.3194, 172.0185],
        [ 60.7192, 294.3960,  86.3913],
        [218.2977, 406.9335,  94.7381]])

B
tensor([[ -35.7292,  -18.9603,  -27.1860],
        [ -70.1071,  -79.7296,  -90.2138],
        [ -24.9466, -121.3563,  -37.9512],
        [ -68.4753, -118.7146,  -30.2694]])


In [190]:
# Dimensions reported by the solver.
# NOTE(review): judging by the shapes used in this notebook, n looks like the
# number of genomes and s the number of samples; m and k presumably count
# genes and distinct sequences -- confirm against TorchSolver.
n = solver.n
s = solver.s
m = solver.m
k = solver.k
print(f"n: {n}, s: {s}, m: {m}, k: {k}")

# Short aliases so the gradient expressions below read like the derivation.
A = solver.A_hat         # parameter sized from abundances.shape
B = solver.B_hat         # parameter sized from ptrs.shape
C = solver.members
D = solver.dists
E = solver.gene_to_seq
F = solver.coverages     # observed values used in the loss
F_hat = f_hat            # solver output from the forward pass

# Exponent term that the analytic gradients differentiate through.
# NOTE(review): presumably mirrors the solver's own forward computation -- confirm.
G = C @ A + 1 - D @ B

n: 4, s: 3, m: 14, k: 11


In [191]:
# Hand-derived gradient of the sum-of-squares loss with respect to A,
# built up step by step.
prediction_error = F_hat - F                 # residual of the forward pass
backprojected_error = E.T @ prediction_error  # residual pulled back through E
dL_dG_over_2 = torch.exp(G) * backprojected_error

# Same grouping as the one-line form: the factor 2 scales C.T before the matmul.
dL_dA = (2 * C.T) @ dL_dG_over_2

# Element-wise ratio to the autograd gradient; all ones means the two agree.
dL_dA / solver.A_hat.grad

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], grad_fn=<DivBackward0>)

In [192]:
# Hand-derived gradient of the sum-of-squares loss with respect to B.
# B enters G with a minus sign (G = C @ A + 1 - D @ B), hence the negated D.T.
prediction_error = F_hat - F
backprojected_error = E.T @ prediction_error
dL_dG_over_2 = torch.exp(G) * backprojected_error

# Same grouping as the one-line form: negate D.T, matmul, then scale by 2.
dL_dB = ((-D.T) @ dL_dG_over_2) * 2

# Element-wise ratio to the autograd gradient; all ones means the two agree.
dL_dB / solver.B_hat.grad


tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], grad_fn=<DivBackward0>)