In [None]:
ENV["PYTHON"] = "/home/stephenz/base_env/bin/python"
using locaTE
using OptimalTransport
using NPZ
using StatsBase
using SparseArrays
using ProgressMeter
import locaTE as scN
using NearestNeighbors
using Graphs
using GraphSignals
using Printf
using Base.Threads
using LinearAlgebra
using DataFrames
using CSV
using Distances
using NNlib
using Discretizers
using EvalMetrics
using Random
using BenchmarkTools

In [3]:
DATA_PATH="../data/"
cd(DATA_PATH)
X = npzread("X.npy")
X = relu.(X .- 1e-2);

In [4]:
try
    global genes = Array(CSV.read("genes.txt", DataFrame)[:, 2])
catch e
    @info "Exception: $e"
    global genes = ["gene_$i" for i = 1:size(X, 2)];
end
X_pca = npzread("X_pca.npy")
X_umap = npzread("X_umap.npy")
X_fle = npzread("X_fle.npy")
P = npzread("P_statot.npy")
C = npzread("C.npy");
dpt = npzread("dpt.npy");
J = npzread("J.npy");
J[diagind(J)] .= 0;
J_escape = npzread("J_ESCAPE.npy");
J_escape[diagind(J_escape)] .= 0;

In [5]:
# select gene subset
id = Colon()
X = X[:, id]
J = J[id, :][:, id]
J_escape = J_escape[id, :][:, id]
genes = genes[id];

In [6]:
R = quadreg(ones(size(X, 1)), ones(size(X, 1)), C, 2.5*mean(C));
gene_idxs = vcat([[j, i]' for i = 1:size(X, 2) for j = 1:size(X, 2)]...);

k = 1
π_unif = fill(1/size(P, 1), size(P, 1))'
Q = (P' .* π_unif)./(π_unif * P)';
P_sp = sparse((P^k))
QT_sp = sparse((Q^k)')
R_sp = sparse(R);

In [7]:
# construct kNN and Laplacian
kdtree = KDTree(X_pca')
idxs, dists = knn(kdtree, X_pca', 25);
A = spzeros(size(X_pca, 1), size(X_pca, 1));
for (i, j) in enumerate(idxs)
    A[i, j] .= 1.0
end
L = sparse(normalized_laplacian(max.(A, A'), Float64));

In [8]:
alg = DiscretizeBayesianBlocks()
disc = scN.discretizations_bulk(X; alg = alg);

In [9]:
@benchmark get_MI($X, $(compute_coupling(X, 1, P_sp, QT_sp, R_sp)), $gene_idxs[:, 1], $gene_idxs[:, 2]; disc = $disc, alg = $alg)

BenchmarkTools.Trial: 6 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m823.366 ms[22m[39m … [35m857.810 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m3.86% … 5.11%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m847.350 ms               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m4.74%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m845.058 ms[22m[39m ± [32m 12.784 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m4.66% ± 0.43%

  [39m█[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m█[39m [39m [39m [39m [39m [39m [39m [39m█[34m [39m[39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m█[39m [39m [39m [39m [39m [39m [39m█[39m [39m [39m [39m█[39m [39m 
  [39m█[39m▁[39m▁[39m▁

In [10]:
# directed inference using GPU
using CUDA
disc_max_size = maximum(map(x -> length(x[1])-1, disc))
N_blocks = 1
joint_cache = get_joint_cache(size(X, 2) ÷ N_blocks, disc_max_size);
ids_cu = hcat(map(x -> x[2], disc) ...) |> cu;

In [11]:
P_cu = cu(Array(P_sp))
QT_cu = cu(Array(QT_sp))
R_cu = cu(Array(R_sp));

In [15]:
mi_all_gpu = zeros(Float32, size(X, 1), size(X, 2), size(X, 2)) |> cu
i = 1
gamma, idx0, idx1 = scN.getcoupling_dense_trimmed(i, P_cu, QT_cu, R_cu) 
(N_x, N_y), (offset_x, offset_y) = first(scN.getblocks(size(X, 2), N_blocks, N_blocks))

@benchmark get_MI!($(view(mi_all_gpu, i, :, :)), $joint_cache, $gamma, $size(X, 2), 
                    $ids_cu[idx0, :], $ids_cu[idx1, :]; 
                    offset_x = $offset_x, N_x = $N_x, offset_y = $offset_y, N_y = $N_y)

BenchmarkTools.Trial: 858 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m4.475 ms[22m[39m … [35m51.671 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 49.46%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m5.771 ms              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m5.820 ms[22m[39m ± [32m 1.633 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m0.51% ±  1.69%

  [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁[39m [39m▂[39m▁[39m▃[39m▂[39m [39m [39m▁[39m▄[39m█[39m▄[34m▇[39m[39m▄[32m▆[39m[39m▅[39m▆[39m▅[39m▆[39m▅[39m▂[39m [39m [39m [39m [39m [39m 
  [39m▂[39m▁[39m▁[39m▁[39m▁[39m▁[39m▁[39m▁