### Model assessment
Here we compare the weights learned by several model iterations, to highlight
- Performance under different parameter combinations
- Interpretation based on TF regulators
- Coherence of learned graph-weights when assisted by RNA-weights, or when only using ATAC-weights

In [None]:
%load_ext autoreload
%autoreload 2

: 

In [None]:
cd ~/workspace/theislab/mubind/docs/notebooks/scatac

: 

In [None]:
import torch
import mubind as mb
import scanpy as sc

: 

In [None]:
# load models
model_by_logdynamic = {}
for use_logdynamic in [False, True]:
    p = 'pancreas_multiome_use_logdynamic_%i.pth' % use_logdynamic
    print(p)
    model_by_logdynamic[use_logdynamic] = torch.load(p)

: 

In [None]:
ad = sc.read_h5ad('atac_train.h5ad')
rna_sample = sc.read_h5ad('rna_sample_train.h5ad')

: 

In [None]:
import pickle
# Load the pickled training dataloader. The context manager closes the file
# handle as soon as unpickling finishes.
# NOTE: unpickling runs arbitrary code; only load files written by a trusted run.
with open('train_dataloader.pkl', 'rb') as handle:
    train = pickle.load(handle)

: 

In [None]:
%load_ext line_profiler

: 

In [None]:
# load the pancreas multiome dataset
rna, atac = mb.datasets.pancreas_multiome() # data_directory='../../../annotations/scatac')


: 

In [None]:
# %lprun -f model.forward model.optimize_iterative(train, n_epochs=10, skip_kernels=list([0]) + list(range(2, 500)), opt_kernel_shift=[0, 0] + [0] * (n_kernels), opt_kernel_length=[0, 0] + [0] * (n_kernels))

: 

In [None]:
# %lprun -f model.binding_modes.forward model.optimize_iterative(train, n_epochs=10, skip_kernels=list([0]) + list(range(2, 500)), opt_kernel_shift=[0, 0] + [0] * (n_kernels), opt_kernel_length=[0, 0] + [0] * (n_kernels))

: 

In [None]:
import matplotlib.pyplot as plt

: 

In [None]:

for optimize_log_dynamic in model_by_logdynamic:
    model = model_by_logdynamic[optimize_log_dynamic]
    print(optimize_log_dynamic)
    from matplotlib import rcParams
    rcParams['figure.figsize'] = 20, 5
    rcParams['figure.dpi'] = 100
    mb.pl.logo(model, n_cols=3, show=True, n_rows=6, stop_at=4) #  log=True)
    plt.show()


: 

In [None]:
for optimize_log_dynamic in model_by_logdynamic:
    if not optimize_log_dynamic:
        continue
    model = model_by_logdynamic[optimize_log_dynamic]
    print(optimize_log_dynamic)

    tsum = torch.sum
    texp = torch.exp
    tspa = torch.sparse_coo_tensor
    tsmm = torch.sparse.mm
    t = torch.transpose

    # connectivities
    C = model.graph_module.conn_sparse
    a_ind = C.indices()

    log_dynamic = model.graph_module.log_dynamic
    D = model.graph_module.log_dynamic
    D_tril = tspa(a_ind, D, C.shape)  # .requires_grad_(True).cuda()
    D_triu = tspa(a_ind, -D, C.shape)  # .requires_grad_(True).cuda()
    D = D_tril + t(D_triu, 0, 1)
    # log_dynamic = log_dynamic + -torch.transpose(log_dynamic, 0, 1)
    # triu_indices = torch.triu_indices(row=n_rounds, col=n_rounds, offset=1)
    D

    import seaborn as sns
    mb.pl.set_rcParams({'figure.figsize': [3, 3]})
    sns.heatmap(D.to_dense().detach().cpu(), cmap='RdBu_r')
    plt.show()

: 

In [None]:
model = model_by_logdynamic[1]

: 

In [None]:
mb.pl.set_rcParams({'figure.figsize': [12, 3], 'figure.dpi': 110})
plt.subplot(1, 4, 1)
plt.plot(model.loss_history_log_dynamic)
plt.ylabel('log dynamic loss')
plt.subplot(1, 4, 2)
plt.plot(model.loss_history)
plt.ylabel('overall loss')
plt.subplot(1, 4, 3)
plt.plot(model.loss_history_sym_weights)
plt.ylabel('similar weights loss')
plt.tight_layout()
plt.show()

: 

In [None]:
import pandas as pd
import numpy as np

: 

In [None]:
rcParams['figure.figsize'] = 3, 5
r2_all = []
for optimize_log_dynamic in model_by_logdynamic:
    print(optimize_log_dynamic)
    model = model_by_logdynamic[optimize_log_dynamic]
    # contributions per newly added kernel
    import seaborn as sns
    if len(model.best_r2_by_new_filter) != 0:
        r2 = pd.DataFrame(model.best_r2_by_new_filter, columns=['r2']).reset_index()
        r2['opt_log_dynamic'] = optimize_log_dynamic
        r2_all.append(r2)

if len(r2_all) > 0:
    r2_all = pd.concat(r2_all)
    rcParams['figure.figsize'] = 3, 3
    rcParams['figure.dpi'] = 80
    ax = sns.barplot(data=r2_all, x='index', y='r2', hue='opt_log_dynamic', )
    sns.move_legend(ax, "lower center", bbox_to_anchor=(.4, 1), ncol=3, title=None, frameon=False)

    plt.xlabel('number of filters in model')
    plt.show()

: 

In [None]:

model = model_by_logdynamic[True]

torch.set_printoptions(precision=2)
# Per-cell score: column sums of the dense matrix D built in the heat-map cell,
# converted to a NumPy vector so that ad.obs stores plain floats.
dynamic_score = D.to_dense().detach().cpu().sum(axis=0).numpy()
# min-max scale to [0, 1]
dynamic_score = (dynamic_score - dynamic_score.min()) / (dynamic_score.max() - dynamic_score.min())
ad.obs['dynamic_score'] = dynamic_score

# three binary labellings: above the mean, and more than 1 / 2 standard deviations above it
ad.obs['dynamic_score_cluster'] = np.where(dynamic_score > dynamic_score.mean(), 'dynamic', 'static')
# ddof=1 keeps the unbiased estimator that torch.Tensor.std() applied before
z_score = (dynamic_score - dynamic_score.mean()) / dynamic_score.std(ddof=1)
z1 = np.where(z_score > 1, 'dynamic', 'static')
z2 = np.where(z_score > 2, 'dynamic', 'static')

ad.obs['dynamic_score_z1'] = z1
ad.obs['dynamic_score_z2'] = z2


: 

In [None]:
ad.obs['dynamic_score'].describe()

: 

In [None]:
ad.obs['dynamic_score_abs'] = ad.obs['dynamic_score'].abs()
sc.pl.umap(ad, color='dynamic_score_abs', color_map='Reds', vmin=.45)

: 

In [None]:


# contributions per newly added kernel
mb.pl.set_rcParams({'figure.figsize': [5, 5], 'figure.dpi': 90})
sc.pl.umap(ad, color=['dynamic_score'], cmap='RdBu_r', sort_order=True)
sc.pl.umap(ad, color=['dynamic_score_z1'], cmap='RdBu_r', sort_order=True)

sc.tl.embedding_density(ad, basis='umap', groupby='dynamic_score_z1')
sc.pl.embedding_density(ad, basis='umap', key='umap_density_dynamic_score_z1', group='dynamic') # basis='umap', groupby='dynamic_score_cluster')
sc.tl.embedding_density(ad, basis='umap', groupby='dynamic_score_z2')
sc.pl.embedding_density(ad, basis='umap', key='umap_density_dynamic_score_z2', group='dynamic', color_map='viridis') # basis='umap', groupby='dynamic_score_cluster')

: 

In [None]:
import seaborn as sns
umap = ad.obsm['X_umap']
sns.histplot(x=umap[:, 0], y=umap[:, 1], bins=50, cmap='PiYG')

: 

In [None]:
plt.pcolormesh(
    np.histogram2d(umap[:, 0], umap[:, 1], bins=50)[0]
)

: 

In [None]:
x, y = np.meshgrid(umap[:, 0], umap[:, 1])

: 

In [None]:
# UMAP coordinates (first and second embedding column) and the per-cell score.
# x previously read column 1 as well, which made x identical to y.
x = umap[:, 0]
y = umap[:, 1]
z = ad.obs['dynamic_score'].values


: 

In [None]:
sc.pl.umap(ad, color='dynamic_score')

: 

In [None]:
import matplotlib.pyplot as plt
import numpy as np
rcParams['figure.figsize'] = 5, 3

# generate 2 2d grids for the x & y bounds
y, x = np.meshgrid(np.linspace(-3, 3, 100), np.linspace(-3, 3, 100))
z = (1 - x / 2. + x ** 5 + y ** 3) * np.exp(-x ** 2 - y ** 2)
z = z[:-1, :-1]
z_min, z_max = -np.abs(z).max(), np.abs(z).max()
fig, ax = plt.subplots()
c = ax.pcolormesh(x, y, z, cmap='RdBu', vmin=z_min, vmax=z_max)
ax.set_title('pcolormesh')
# set the limits of the plot to the limits of the data
ax.axis([x.min(), x.max(), y.min(), y.max()])
fig.colorbar(c, ax=ax)

plt.show()

: 

In [None]:
import itertools
import numpy as np

def grid(x, y, z, size_x=1, size_y=1):
    """Sum the values ``z`` of 2-D points over a regular rectangular grid.

    Bins are half-open ``[left, right)`` intervals. The first edge on each
    axis is the floor of the smallest coordinate, so points with negative
    non-integer coordinates are kept (plain ``int()`` truncates toward zero
    and silently dropped them). The number of bins follows from the edges,
    so bin sizes other than 1 are supported.

    Parameters
    ----------
    x, y : array-like
        Point coordinates, one entry per point.
    z : array-like
        Value of each point; values falling in the same bin are summed.
    size_x, size_y : number
        Bin width along x and y.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_bins_y, n_bins_x)``; rows are flipped so that the
        largest y bin is the first row.

    Raises
    ------
    ValueError
        If ``x`` or ``y`` is empty.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    z = np.asarray(z)

    def bin_edges(values, size):
        # lower edge: floor of the minimum; upper bound: one past the truncated maximum
        lowest = int(np.floor(values.min()))
        highest = int(values.max()) + 1
        # "+ size" guarantees that the last edge is >= highest for any bin size
        return np.arange(lowest, highest + size, size)

    x_edges = bin_edges(x, size_x)
    y_edges = bin_edges(y, size_y)

    # the y masks do not depend on the x bin, so they are built only once
    y_masks = [np.logical_and(y >= lo, y < hi)
               for lo, hi in zip(y_edges[:-1], y_edges[1:])]

    result = []
    for xleft, xright in zip(x_edges[:-1], x_edges[1:]):
        xmask = np.logical_and(x >= xleft, x < xright)
        result.append([z[np.logical_and(xmask, ymask)].sum() for ymask in y_masks])

    result = np.array(result).reshape((len(x_edges) - 1, len(y_edges) - 1))
    return np.flip(result.T, 0)


: 

In [None]:
# summed dynamic score per UMAP bin
grid_dyn_score = grid(umap[:, 0], umap[:, 1], ad.obs['dynamic_score'], size_x=1, size_y=1)
# number of cells per UMAP bin: a weight of one per cell. Summing the
# categorical cell-type codes (as before) does not give a count.
grid_counts = grid(umap[:, 0], umap[:, 1], np.ones(umap.shape[0]), size_x=1, size_y=1)

: 

In [None]:

sns.heatmap(grid_dyn_score, cmap='Reds')
plt.show()
sns.heatmap(grid_counts, cmap='Reds')
plt.show()

sc.pl.umap(ad, color='celltype')

: 

In [None]:
# for optimize_log_dynamic in model_by_logdynamic:
#     mb.pl.set_rcParams({'figure.figsize': [3, 3], 'figure.dpi': 90})
#     print(optimize_log_dynamic)
#     model = model_by_logdynamic[optimize_log_dynamic]
#     mb.pl.kmer_enrichment(model, train, log_scale=False, style='scatter', ylab='t1', xlab='p1', k=8)
#     plt.show()

#     mb.pl.set_rcParams({'figure.figsize': [10, 7], 'figure.dpi': 90})
#     mb.pl.logo(model,
#                title=False,
#                xticks=False,
#                rowspan_dinuc=0,
#                rowspan_mono=1,
#                n_rows=12,
#                n_cols=3,
#                stop_at=20) # n_cols=len(reduced_groups))
#     plt.show()


: 

In [None]:
model = model_by_logdynamic[True]

: 

In [None]:
G = model.graph_module.conn_sparse.detach().cpu().to_dense() # (C, C)

: 

In [None]:
# number of non_zero weights
len(G[G != 0])

: 

In [None]:
# output = model(**inputs, use_conn=False, return_binding_scores=True)

: 

In [None]:
print('here...')

: 

In [None]:
ad

: 

In [None]:
model = model.cuda()

: 

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device

: 

In [None]:
for optimize_log_dynamic in model_by_logdynamic:
    print(optimize_log_dynamic)
    if not optimize_log_dynamic:
        continue
    model = model_by_logdynamic[optimize_log_dynamic].cuda()

    umap = ad.obsm['X_umap'].copy()
    umap = np.sort(umap, 0)

    x = umap[:,0]
    y = umap[:,1]
    X, Y = np.meshgrid(x, y)

    n_points = x.shape[0]
    # x-component to the right
    u = np.ones((n_points, n_points))
    # y-component zero
    v = np.zeros((n_points, n_points))

    pred = []
    for i, batch in enumerate(train):
        # Get a batch and potentially send it to GPU memory.
        mononuc = batch["mononuc"].to(device)
        # print(i, mononuc.shape)
        b = batch["batch"].to(device) if "batch" in batch else None
        rounds = batch["rounds"].to(device) if "rounds" in batch else None
        countsum = batch["countsum"].to(device) if "countsum" in batch else None
        seq = batch["seq"] if "seq" in batch else None
        residues = batch["residues"].to(device) if "residues" in batch else None
        if residues is not None and train.dataset.store_rev:
            mononuc_rev = batch["mononuc_rev"].to(device)
            inputs = {"mono": mononuc, "mono_rev": mononuc_rev, "batch": b, "countsum": countsum,
                      "residues": residues}
        elif residues is not None:
            inputs = {"mono": mononuc, "batch": b, "countsum": countsum, "residues": residues}
        elif train.dataset.store_rev:
            mononuc_rev = batch["mononuc_rev"].to(device)
            inputs = {"mono": mononuc, "mono_rev": mononuc_rev, "batch": b, "countsum": countsum}
        else:
            inputs = {"mono": mononuc, "batch": b, "countsum": countsum}

        inputs['scale_countsum'] = model.datatype == 'selex'
        output = model(**inputs, use_conn=False, return_binding_scores=True)

        output = output.cpu().detach().numpy()
        print('here...')
        print(output.shape)

        print(output.sum())
        pred.append(output)

    # pred = np.concatenate(pred).T

    binding_scores = np.concatenate(pred).T


    # ad.layers['velocity'] = pred
    
    # conn = model.graph_module.conn_sparse.detach().cpu()
    # conn = model.graph_module.conn_sparse.detach().cpu().to_dense()
    # v = conn.sum(axis=1)
    # ad.layers['velocity'] = torch.stack([v,] * ad.shape[1], axis=1).numpy()
    
    # ad.layers['counts'] = ad.X

    # mb.pl.set_rcParams({'figure.figsize': [5, 4], 'figure.dpi': 90})
    # plt.hist(model.graph_module.conn_sparse.values().detach().cpu().numpy())
    # plt.show()

    # import scvelo as scv

    # sc.pp.neighbors(ad)

    # # scv.tl.velocity_graph(ad, vkey='velocity', xkey='counts')
    # # ad.layers['velocity'] = ad.obs['dynamic_score']
    
    # scv.tl.velocity_graph(ad, vkey='velocity', xkey='counts')
    # ax = scv.pl.velocity_embedding_stream(ad, color='celltype', show=False) #  X_grid='X_umap', V=V)


# These statements run after the loop above, so `binding_scores` holds the
# result of the last model for which the loop body was not skipped.
# NOTE(review): presumably `.A` densifies a sparse ad.X — confirm; X is not used in this cell.
X = ad.X.A
# NOTE(review): the product is neither assigned nor the last expression of the
# cell, so its value is discarded — confirm whether it is still needed.
G @ binding_scores
# np.random.shuffle permutes `binding_scores` in place along its first axis.
# NOTE(review): later cells store this shuffled array in ad.layers['velocity'];
# confirm that the shuffle is an intended control.
np.random.shuffle(binding_scores)

: 

In [None]:
import scvelo as scv

: 

In [None]:
ad
ad.layers['velocity'] = binding_scores
# scv.tl.velocity_graph(ad, vkey='velocity', xkey='counts')
# ax = scv.pl.velocity_embedding_stream(ad, color='celltype', show=False) #  X_grid='X_umap', V=V)

: 

In [None]:
np.random.shuffle(binding_scores)
binding_scores

: 

In [None]:
try:
    scv.pl.velocity_embedding_stream(rna_sample, color='celltype')
except Exception as err:
    # Best-effort plot: keep the notebook running, but also report the actual
    # error so that failures other than a small sample are not hidden.
    print("sample too small.")
    print('%s: %s' % (type(err).__name__, err))

: 

In [None]:
# np.random.shuffle(binding_scores)
# ad.layers['velocity'] = binding_scores
# scv.tl.velocity_graph(ad, vkey='velocity', xkey='counts')
# ax = scv.pl.velocity_embedding_stream(ad, color='celltype', show=False) #  X_grid='X_umap', V=V)

: 

In [None]:
import seaborn as sns
act = model.get_log_activities().detach().cpu().squeeze(0)
sns.heatmap(act, cmap='RdBu_r', cbar_kws={'label': 'activities'})

: 

In [None]:
scv.pl.velocity_graph(rna)

: 

In [None]:

# ax = scv.pl.velocity_embedding_stream(ad,
#                                       color='celltype',
#                                       # density=2,
#                                       arrow_color='black',
#                                       n_neighbors=15) # show=False) #  X_grid='X_umap', V=V)

: 

In [None]:
# ax = scv.pl.velocity_embedding_stream(ad, color='celltype', density=2, arrow_color='black', n_neighbors=15) # show=False) #  X_grid='X_umap', V=V)

: 

In [None]:
# scv.pl.velocity_embedding_stream(ad, color='celltype', n_neighbors=15) #  X_grid='X_umap', V=V)

: 

## Study the associations between obtained weights and cluster-specific transcription factors

Load information from archetypes DB (Vierstra et al 2020)

In [None]:
rna_sample, ad.shape

: 

In [None]:
rna_sel = rna_sample # rna[rna.obs_names.isin(ad.obs_names),:].copy()
rna_sel.shape

: 

In [None]:
pwd

: 

In [None]:
import bindome as bd
bd.constants.ANNOTATIONS_DIRECTORY = 'annotations'

anno = mb.datasets.archetypes_anno()

: 

In [None]:
rna_sel.shape
anno.sort_values('Cluster_ID')

: 

In [None]:
for optimize_log_dynamic in model_by_logdynamic:
    print(optimize_log_dynamic)
    model = model_by_logdynamic[optimize_log_dynamic]
    log_act = torch.stack(list(model.activities.log_activities), dim=1).squeeze(0).T
    log_act = pd.DataFrame(log_act.detach().cpu().numpy())
    # log_act.columns = anno['Seed_motif'][2]
    # log_act.columns = ['intercept', 'dinuc_bias'] + list(anno['Seed_motif'].values)
    log_act.index = ad.obs_names
    ad.obsm['mubind_activities'] = log_act

    mb.pl.set_rcParams({'figure.figsize': [5, 3], 'figure.dpi': 110})
    delta = (log_act.max(axis=0) - log_act.min(axis=0))
    var = log_act.var(axis=0)
    plt.scatter(delta, var, color='gray', edgecolors='black')
    plt.xlabel('effect size')
    plt.ylabel('variability')
    plt.title('TF modules (by score) | GraphLayer = %i' % optimize_log_dynamic )
    plt.show()

: 

In [None]:
# Make the annotation names unique: the first occurrence keeps its name,
# later repeats receive the suffixes _0, _1, ...
names = anno['Name']
repeat_count = {}
unique_names = []
for name in names:
    if name in repeat_count:
        unique_names.append(name + '_%i' % repeat_count[name])
        repeat_count[name] += 1
    else:
        unique_names.append(name)
        repeat_count[name] = 0
anno['Name_unique'] = unique_names


: 

In [None]:
from scipy.stats import spearmanr
res = []
for optimize_log_dynamic in model_by_logdynamic:
    if not optimize_log_dynamic:
        continue

    model = model_by_logdynamic[optimize_log_dynamic]
    log_act = torch.stack(list(model.activities.log_activities), dim=1).squeeze(0).T
    log_act = pd.DataFrame(log_act.detach().cpu().numpy())
    # log_act.columns = anno['Seed_motif'][2]
    log_act.columns = ['intercept', 'dinuc_bias'] + list(range(1, 287))
    log_act.index = ad.obs_names
    ad.obsm['mubind_activities'] = log_act

    mb.pl.set_rcParams({'figure.figsize': [5, 3], 'figure.dpi': 90})
    delta = (log_act.max(axis=0) - log_act.min(axis=0))
    var = log_act.var(axis=0)
    plt.scatter(delta, var)
    plt.xlabel('min-max range')
    plt.ylabel('variability')
    plt.title('TF modules (by score)')
    plt.show()

    for c in log_act:
        a = log_act[c]
        b = ad.obs['dynamic_score'].values
        # print(a.shape, b.shape)
        res.append([optimize_log_dynamic, c] + list(spearmanr(a, b)))

res = pd.DataFrame(res, columns=['opt_log_dynamic', 'archetype_id', 'spearman', 'p_val'])



: 

In [None]:

# add archetypes name
meta = pd.DataFrame(pd.concat([delta, var], axis=1))
meta.columns = ['max_effect', 'variability']
meta['name'] = ['intercept', 'dinuc_bias'] + list(range(1, 287))
clu = mb.datasets.archetypes_clu()
meta['archetypes_name'] = meta['name'].map(anno.set_index('Cluster_ID')['Name_unique'])
meta['archetypes_name'] = np.where(pd.isnull(meta['archetypes_name']), meta['name'], meta['archetypes_name'])

meta['archetypes_seed'] = meta['name'].map(anno.set_index('Cluster_ID')['Seed_motif'])
meta = meta.sort_values('max_effect', ascending=0)
meta

res = res.merge(meta, left_on='archetype_id', right_on='name')
res = res.sort_values('p_val', ascending=True)

: 

In [None]:
name_by_filter_id = meta['archetypes_name'].to_dict()
# name_by_filter_id

: 

## Observe general scores per case

: 

In [None]:
rcParams['figure.figsize'] =3, 5
sns.barplot(data=res.sort_values('max_effect', ascending=False).head(25), x='max_effect', y='archetypes_name', color='orange')

: 

: 

In [None]:
res

: 

In [None]:
# visualize the logos as obtained by the model in each step
mb.pl.set_rcParams({'figure.figsize': [5, 20], 'figure.dpi': 90})
mb.pl.logo(model, title=False, xticks=False, rowspan_dinuc=0, rowspan_mono=1, n_rows=40, n_cols=1, stop_at=20)
           # n_rows=len(res.head(20).index),

: 

In [None]:
mb.pl.set_rcParams({'figure.figsize': [2, 20], 'figure.dpi': 90})
mb.pl.logo(model, title=False, xticks=False, rowspan_dinuc=0, rowspan_mono=1, n_rows=40,
           # n_rows=len(res.head(20).index),
           n_cols=1, order=res.head(20).index) # n_cols=len(reduced_groups))
plt.tight_layout()
plt.show()

: 

In [None]:
import resource
print('total GB used:', resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1e6)


: 

In [None]:
# for k in ad.obsm['log_activities'].iloc[:,2:]:
#     ad.obs[str(k)] = ad.obsm['log_activities'][k]
# sc.pl.umap(ad, color=map(str, ad.obsm['log_activities'].iloc[:,2:]), cmap='Reds')


: 

In [None]:
rna_sel.obsm['X_umap'] = ad.obsm['X_umap']

: 

In [None]:
def find_varname(ad, k):
    """Return the entries of ``ad.var_names`` that contain ``k``, ignoring case.

    ``k`` is matched as a literal substring (``regex=False``), so names that
    contain regular-expression metacharacters such as ``+``, ``(`` or ``.``
    are matched as written instead of being interpreted as a pattern.

    Parameters
    ----------
    ad : object with a ``var_names`` pandas Index of strings
    k : str
        Substring to look for.

    Returns
    -------
    pandas.Index
        The matching variable names, in their original spelling and order.
    """
    return ad.var_names[ad.var_names.str.upper().str.contains(k.upper(), regex=False)]

: 

In [None]:
from scipy.stats import spearmanr, pearsonr

: 

Calculate global correlations between the activities obtained per motif and gene-specific expression

In [None]:

# Collect, over all models and motif clusters, every RNA variable name that
# contains a TF name associated with a cluster (cluster members from `clu`,
# seed motif from `anno`).
all_targets = set()
for optimize_log_dynamic in model_by_logdynamic:
    print(optimize_log_dynamic)
    model = model_by_logdynamic[optimize_log_dynamic]

    log_act = torch.stack(list(model.activities.log_activities), dim=1).squeeze(0).T
    log_act = pd.DataFrame(log_act.detach().cpu().numpy())
    log_act.index = ad.obs_names

    ad.obsm['log_activities'] = log_act
    ad.obsm['log_activities'].columns = ['intercept', 'dinuc_bias'] + list(range(1, 287))

    # collect all targets (the first two columns are intercept and dinuc_bias)
    for k in ad.obsm['log_activities'].iloc[:, 2:]:
        # TF names: motif identifiers reduced to the part before the first '_', '.' and '+'
        names = {g.split('_')[0].split('.')[0].split('+')[0].upper()
                 for g in clu[clu['Cluster_ID'] == k]['Motif']}
        names.update(g.split('_')[0].split('.')[0]
                     for g in anno[anno['Cluster_ID'] == k]['Seed_motif'])
        for name in names:
            all_targets.update(find_varname(rna_sel, name))


: 

In [None]:


print('association between motif activities and related TF targets ')
res = []
print(len(all_targets))

rna_sel_df = rna_sel.to_df()

# For every model and every motif cluster, correlate the per-cell motif
# activity with the expression of each gene whose name matches a TF of that
# cluster. One row per (model, motif, gene) is appended to `res`.
for optimize_log_dynamic in model_by_logdynamic:
    print('use GraphLayer = %i' % optimize_log_dynamic)
    model = model_by_logdynamic[optimize_log_dynamic]

    log_act = torch.stack(list(model.activities.log_activities), dim=1).squeeze(0).T
    log_act = pd.DataFrame(log_act.detach().cpu().numpy())
    log_act.index = ad.obs_names

    ad.obsm['log_activities'] = log_act
    ad.obsm['log_activities'].columns = ['intercept', 'dinuc_bias'] + list(range(1, 287))

    for ki, k in enumerate(ad.obsm['log_activities'].iloc[:, 2:]):
        if ki % 30 == 0:
            print(ki)
        # Activity of the motif tested in THIS iteration. It used to be read
        # once before the loop with a stale `k`, so every correlation was
        # computed against the same column.
        act_k = ad.obsm['log_activities'][k].values

        names = set()
        clu_sel = clu[clu['Cluster_ID'] == k]['Motif']
        for g in clu_sel:
            names.add(g.split('_')[0].split('.')[0].split('+')[0].upper())
        for g in anno[anno['Cluster_ID'] == k]['Seed_motif']:
            names.add(g.split('_')[0].split('.')[0])

        next_targets = set()
        for name in names:
            next_targets.update(find_varname(rna_sel, name))

        # only genes matched to this motif are tested, so 'matched' is always True here
        for t in set(all_targets).intersection(next_targets):
            gex = rna_sel_df[[t]].to_numpy()
            assert gex.shape[1] == 1  # a duplicated gene name would yield several columns
            gex = gex.flatten()
            res.append([ki, optimize_log_dynamic, k, t, t in next_targets] +
                       list(spearmanr(act_k, gex)))
            
res = pd.DataFrame(res, columns=['filter_id', 'opt_log_dynamic', 'archetype_id', 'gene_name', 'matched', 'spearman', 'p_val'])

# p-values
res['module_name'] = res['archetype_id'].map(anno.set_index('Cluster_ID')['Name'].to_dict())
res['p_val'] = np.where(pd.isnull(res['p_val']), 1.0, res['p_val'])

# p-val adjust
from statsmodels.stats.multitest import fdrcorrection
res['p_adj'] = fdrcorrection(res['p_val'])[1]
res[res['p_adj'] < 0.1]

: 

In [None]:
genes_by_module_name = res.groupby(['module_name'])['gene_name'].apply(lambda grp: list(grp.value_counts().index)).to_dict()
# genes_by_module_name

: 

In [None]:
res.sort_values('p_adj')

: 

In [None]:
res['k'] = res['gene_name'] + '_' + res['archetype_id'].astype(str)
df2 = res.pivot(index='k', columns='opt_log_dynamic', values='spearman')
# df2 = res # .pivot(index='k', columns='opt_log_dynamic', values='spearman')

: 

In [None]:
df2

: 

In [None]:
mb.pl.set_rcParams({'figure.figsize': [5, 4], 'figure.dpi': 120})
# df2 = df2.sort_values('matched', ascending=True)
# plt.scatter(df2[True], df2[True],
#             color=np.where(df2['matched'], 'blue', 'gray'),
#             s=np.where(df2['matched'], 30, 5))
# plt.xlabel('TF activity (graph = off)')
# plt.ylabel('TF activity (graph = on)')
# plt.axhline(0, color='gray', ls='--', zorder=0)
# plt.axvline(0, color='gray', ls='--', zorder=0)


: 

In [None]:
# df2[df2['matched'] == True].sort_values(True, ascending=False)

: 

In [None]:
# res['arch_name'] = name_by_filter_id

: 

In [None]:
res

: 

In [None]:
rcParams['figure.figsize'] = 4, 4
rcParams['figure.dpi'] = 90

# One volcano-style scatter per model: Spearman correlation against -log10 of
# the raw p-value; marker size grows with significance.
for optimize_log_dynamic, grp in res.groupby('opt_log_dynamic'):
    # assign() returns a new frame, so no column is written into a groupby slice
    grp = grp.assign(minus_log10_pval=-np.log10(grp['p_val'])).sort_values('matched')
    plt.scatter(grp['spearman'], grp['minus_log10_pval'],
                s=np.power(grp['minus_log10_pval'], 2), color=np.where(grp['matched'], 'red', 'blue'))
    # the plotted quantity is the raw p-value ('p_val'), so the label says so
    plt.ylabel('-log10(p-val)')
    plt.xlabel('spearman')
    plt.title('corr(filter, GEX) | GraphLayer = %i' % optimize_log_dynamic)
    plt.axhline(1, ls='--', color='red', lw=0.6)
    plt.show()

: 

In [None]:
# sc.pl.umap(ad, color=[96], cmap='RdBu_r')
# sc.pl.umap(rna_sel, color=['Ehf', 'Ergic2'], cmap='plasma')

: 

In [None]:
rcParams['figure.figsize'] = 3, 3
rcParams['figure.dpi'] = 90
plt.hist(res['p_val'], color='gray', bins=20, label='raw', alpha=.5, edgecolor = 'black')
plt.hist(res['p_adj'], color='red', bins=20, label='adjusted (BH)', alpha=.5, edgecolor = 'black')
plt.xlabel('p-value')
plt.legend()
plt.ylabel('# associations')

: 

In [None]:
res[res['p_adj'] < 0.05]

: 

In [None]:
pval_thr = 1e-5
sel_genes = set(list(res[res['p_adj'] < pval_thr]['gene_name']))

: 

In [None]:
log_act = ad.obsm['log_activities'].copy()

: 

In [None]:
cols_act = ['intercept', 'dinuc_bias'] + [name_by_filter_id[k] for k in log_act.columns[2:]]
log_act.columns = cols_act

: 

In [None]:
import anndata
ad_act = anndata.AnnData(log_act)
ad_act.obsm['X_umap'] = ad.obsm['X_umap']
ad_act.obs = ad.obs

: 

In [None]:
sc.pl.umap(ad_act, color='celltype')


: 

Rank genes groups using the annotation

In [None]:
sc.tl.rank_genes_groups(ad_act, 'celltype')
rkg_df = []
for ct in ad_act.obs['celltype'].values.unique():
    print(ct)
    rkg_df2 = sc.get.rank_genes_groups_df(ad_act, ct)
    rkg_df2['celltype'] = ct
    rkg_df.append(rkg_df2)
rkg_df = pd.concat(rkg_df)
rkg_df['module_name'] = rkg_df['names'].map(anno.set_index('Cluster_ID')['Name'].to_dict())
rkg_df['module_name'] = np.where(~pd.isnull(rkg_df['module_name']), rkg_df['module_name'], rkg_df['names'])
rkg_df.head()

: 

Get top modules

In [None]:
ad_act.var_names = ad_act.var_names.map(rkg_df.set_index('names')['module_name'].to_dict())

: 

In [None]:
sc.tl.rank_genes_groups(ad_act, 'celltype')

: 

In [None]:
rcParams['figure.figsize'] = 3.5, 3.5
rcParams['figure.dpi'] = 80
sc.pl.rank_genes_groups(ad_act)


: 

In [None]:
set(res[(res['p_adj'] < 1e-5)]['k'])

: 

In [None]:
res[res['module_name'].str.contains('HD')].sort_values('p_adj')

: 

In [None]:
mod_names_best = set(rkg_df.sort_values('scores', ascending=False).groupby('celltype').head(5)['module_name'])
best = rkg_df[rkg_df['module_name'].isin(mod_names_best)]
rcParams['figure.dpi'] = 130
sns.clustermap(best.pivot(index='celltype', columns='module_name', values='scores'),
               cbar_kws={'label': 'activity'}, cmap='RdBu_r',
               vmin=-5, vmax=5,
               figsize=[6.2, 5],
               # dpi=100,
               xticklabels=True)


: 

In [None]:
rna_tfs = rna_sel.to_df()[list(set(res['gene_name']))]
rna_tfs['celltype'] = rna_sel.obs['celltype']
mean_tfs = rna_tfs.groupby('celltype').mean()

act_tfs_df = ad_act.to_df()
act_tfs_df['celltype'] = ad_act.obs['celltype']
mean_act_tf = act_tfs_df.groupby('celltype').mean()

: 

In [None]:
# Pearson correlation, across cell types, between the mean activity of a
# module and the mean expression of each gene associated with that module.
corr_celltype = []
n_modules = mean_act_tf.shape[1]
for i, c1 in enumerate(mean_act_tf):
    if i % 50 == 0:
        print(i, n_modules)
    if c1 not in genes_by_module_name:
        continue
    module_genes = genes_by_module_name[c1]
    for c2 in mean_tfs:
        if c2 not in module_genes:
            continue
        a = mean_act_tf[c1]
        b = mean_tfs[c2]
        # cell type in which the module reaches its highest mean activity
        top_celltype = mean_act_tf.index[np.argmax(a)]
        corr_celltype.append([c1, c2, top_celltype] + list(pearsonr(a, b)))

corr = pd.DataFrame(
    corr_celltype,
    columns=['module_name', 'gene_name', 'cell_type', 'pearsonr', 'p_val'],
).sort_values('pearsonr', ascending=False)


: 

: 

In [None]:
# sc.pl.dotplot(rna_sel, groupby='celltype', var_names=list(set(res['gene_name'])))

: 

In [None]:

for ri, r in corr.sort_values('p_val').groupby('cell_type').head(3).iterrows():
    # ad_act.obs['HD/2'] = log_act['HD/2']
    sc.pl.dotplot(ad_act,
                groupby='celltype',
                cmap='Blues',
                var_names=r['module_name'],
                figsize=[2, 3.2],
                colorbar_title='mean activity in group')
    sc.pl.dotplot(rna_sel,
                  groupby='celltype',
                  var_names=r['gene_name'],
                  figsize=[2, 3.2])

: 

In [None]:
### attempt to show together

for key_interaction in varm_ligrec_by_k:
    # if not 'Resident' in key_interaction:
    #     continue
    
    viz = varm_ligrec_by_k[key_interaction]
        
    # explicit copy: the new 'k2' column must not be written into a slice of `viz`
    obs = viz[['group']].copy()
    # k2 = row label without the 'index_' / 'neighbor' markers, followed by the group
    obs['k2'] = obs.index.str.replace('index_', '').str.replace('neighbor', '') + '_' + obs['group']
    viz = anndata.AnnData(viz[viz.columns[1:]], obs=obs)


    import matplotlib.pyplot as plt
    from matplotlib import rcParams
    from matplotlib import colors
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches

    plt.rcParams['figure.dpi'] = 150
    SMALL_SIZE = 14
    MEDIUM_SIZE = 19
    BIGGER_SIZE = 21
    plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
    plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
    plt.rc('axes', labelsize=SMALL_SIZE)    # fontsize of the x and y labels
    plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
    plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
    plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
    plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title

    viz_df = viz.to_df()
    print(viz_df.shape)

    viz_df.index = viz.obs['k2']
    z = viz_df.copy()
    for c in z:
        z[c] = (z[c] - z[c].mean()) / z[c].std()

    tree_genes_2 = viz_df.idxmax(axis=1).to_dict()

    tree_genes = {}
    for k in tree_genes_2:
        tree_genes[k] = set([tree_genes_2[k]])

    z = z.T
    tree_genes = {}
    n = 3
    for c in z:
        tree_genes[c] = set(z.sort_values(c, ascending=False).index[:n])
    tree_genes

    diagonal_rectangle_heights = { k: 1 for k in tree_genes}
    #     'AT1'                                  :2,
    #     'AT2'                                  :2,
    #     'Club'                                 :2,
    #     'Multiciliated'                        :2,
    #     'Deuterosomal'                         :2,
    # #    'Basal'                                :2,
    # #    'Smooth muscle'                        :2,
    # #    'Pericytes'                            :2,
    # #    'Adventitial Fibroblast'               :1,
    # #    'Alveolar Fibroblast'                  :2,
    # #    '(Vascular) Endothelial cell'          :2,
    # #    'Cappillary (G) endothelial cell'      :2,
    # #    'Cappillary Aerocyte endothelial cell' :1,
    # #    'Lymphatic EC'                         :1,    
    #     'Lymphocyte'                           :2,
    # #    'DC'                                   :1,
    # #   'Monocytes'                            :1,
    # #    'Alv Macrophage'                       :2,
    #     'Interstitial macrophages'             :2,
    #     'Mast cells'                           :2,
    # }

    dot_color_df = viz_df # [list(set(g for k in tree_genes for g in tree_genes[k]))]
    # dot_color_df[dot_color_df < 0] = -dot_color_df[dot_color_df < 0]
    # dot_color_df.index = viz.obs_names

    sc_ct_list = [k for k in set(viz.obs['k2']) if 'immune' in k]
    ST_ct_list_with_space = [k for k in set(viz.obs['k2']) if 'tumor' in k]

    genes_by_group = {}
    for ri, r in cellphone_interactions.iterrows():
        # print(r)
        a, b = str(r['protein_name_a']).split('_')[0], str(r['protein_name_b']).split('_')[0]

        found_a = False
        found_b = False
        for ct in tree_genes:
            if a in tree_genes[ct]:
                found_a = True
            if b in tree_genes[ct]:
                found_b = True

        key_class = ' '.join(str(r['classification']).split(' ')[2:])
        if key_class == '':
            key_class = 'Others'
        # print(key_class)

        if a != 'nan' and found_a:
            if not key_class in genes_by_group:
                genes_by_group[key_class] = set()
            genes_by_group[key_class].add(a)
        if b != 'nan' and found_b:
            if not key_class in genes_by_group:
                genes_by_group[key_class] = set()
            genes_by_group[key_class].add(b)
        # if a in cpdb_names or b in cpdb_names:


    ## remove duplicates and empty
    order = [v for g in genes_by_group for v in genes_by_group[g]]
    order = list(dict.fromkeys(order))
    # order
    found_genes = set()
    for g in genes_by_group:
        values_group = [v for v in genes_by_group[g]]
        for v in values_group:
            if not v in found_genes:
                found_genes.add(v)
            else:
                genes_by_group[g].remove(v)

    # remove empty gruops
    key_groups = list(genes_by_group.keys())
    for g in key_groups:
        if len(genes_by_group[g]) == 0:
            del genes_by_group[g]

    fig = plt.figure(figsize=(10,5),
                    dpi=150)
    ax = plt.gca()

    reduced_tree_genes = {ct: tree_genes[ct] for ct in diagonal_rectangle_heights}

    print(viz.shape, dot_color_df.shape)
    # one can add arbitrary names
    # reduced_tree_genes['hello'] = reduced_tree_genes['AngiogenicTAMs:_MES2like2_immune']

    vmax = 0.2
    dp = sc.pl.DotPlot(viz,
                    genes_by_group, # reduced_tree_genes,
                    groupby="k2",
                    ax=ax,cmap="bwr",
                    vmin=-vmax,vmax=vmax,
                    var_group_rotation=90, # var_group_positions='ha',
                    dot_color_df=dot_color_df)
                    # standard_scale='var')

    # dp.dot_max = .9
    # dp.dot_min = 0.01
    # dp.smallest_dot=.15
    dp.make_figure()

    # assuming no duplicates, this is the matrix to inspect pairs.
    data = dot_color_df[[v for g in genes_by_group for v in genes_by_group[g]]]

    x_coord = 0

    thr_max = .1
    thr_diff = .1
    diffs = []
    # Compare consecutive row pairs (0,1), (2,3), ... of every column and frame
    # the pairs whose values differ clearly.
    for ci, c in enumerate(data):
        # positional access through a plain array: `data[c][i]` on a
        # label-indexed Series depends on deprecated positional fallback
        col = data[c].to_numpy()
        # stop before a trailing unpaired row (an odd row count used to raise IndexError)
        for i in range(0, len(col) - 1, 2):
            a, b = col[i], col[i + 1]
            diffs.append([c, a, b, abs(a - b)])
            if abs(a - b) > thr_diff and abs(max(a, b)) > thr_max and (a > 0 or b > 0):
                print(c, a, b)
                rect = patches.Rectangle((ci, i), 1, 2,
                                        linewidth=2,
                                        linestyle='--',
                                        edgecolor='green',
                                        facecolor='none')#, zorder=1)
                dp.ax_dict["mainplot_ax"].add_patch(rect)

    for ct,height in diagonal_rectangle_heights.items():
        break
        print(ct, height)
        y = dot_color_df.index.get_loc(ct)
        width = len(tree_genes[ct])
        if height != 0:
            rect = patches.Rectangle((x_coord, y), width, height, linewidth=1, edgecolor='black', facecolor='none')#, zorder=1)
            dp.ax_dict["mainplot_ax"].add_patch(rect)
        x_coord += width
        

    old_ytick_labels = dp.ax_dict["mainplot_ax"].get_yticklabels()
    new_ticks = []
    yshift = 0
    for i, lab in enumerate(old_ytick_labels):
        # print(lab)
        x, y = lab.get_position()
        old_text = lab._text
        if lab._text in sc_ct_list:
            lab._text = '_'.join(lab._text.split('_')[:2]) + ' ' + lab._text.split('_')[-1]
            lab.set_color('purple')
            lab.set_x(x)
            lab.set_y(y + yshift)
            # print(x, y)
            new_ticks.append(y + yshift)
        elif lab._text in ST_ct_list_with_space:
            lab._text = old_text.split('_')[-1]
            lab.set_color('green')
            new_ticks.append(y + yshift)
        old_ytick_labels[i] = lab

    # print(old_ytick_labels)
    dp.ax_dict["mainplot_ax"].set_yticklabels(old_ytick_labels)
    print(dp.ax_dict["mainplot_ax"].get_yticklabels())
    print(dp.ax_dict["mainplot_ax"].set_yticks(new_ticks))
    print(dp.ax_dict["mainplot_ax"].get_yticks())


    # plt.show()
    # fig.savefig("./dotplot_sc_and_ST.png", bbox_inches = "tight")

    # break


: 

In [None]:
sc.pl.umap(rna_sel, color='Isl1', cmap='Blues')


: 

In [None]:
ad_act

: 

In [None]:
from matplotlib import rcParams, cm
# Private copy of 'YlOrRd': set_over / set_under then do not alter the shared
# registered colormap. plt.get_cmap replaces cm.get_cmap, which recent
# Matplotlib releases deprecate and remove.
cmap = plt.get_cmap('YlOrRd').copy()
cmap.set_over('black')      # colour for values above vmax
cmap.set_under('lightgray')  # colour for values below vmin

: 

In [None]:
cmap

: 

In [None]:
sc.pl.embedding(rna_sel, basis='X_umap', color='Ehf', color_map=cmap)


: 