# Toxoplasma gondii Virome graph investigation

### Graph structure

#### Virus nodes

- Only use sOTUs
- All non-centroid palmprints are collapsed into their associated centroid, so each centroid aggregates the relationships of its non-centroid palmprints
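
The collapsing step above can be sketched in plain Python. This is a minimal illustration, not the notebook's actual implementation: the palmprint ids, run ids, and the `palm_to_sotu` mapping are made up for the example.

```python
from collections import defaultdict

# Hypothetical palmprint -> sOTU centroid mapping (a centroid maps to itself).
palm_to_sotu = {
    "u1": "u1",  # centroid
    "u2": "u1",  # non-centroid, collapsed into u1
    "u3": "u3",  # centroid
}

# Hypothetical (palmprint, SRA run) relationships before collapsing.
palm_runs = [("u1", "SRR1"), ("u2", "SRR2"), ("u2", "SRR1"), ("u3", "SRR3")]

# Re-attach every relationship to the centroid of its palmprint.
sotu_runs = defaultdict(set)
for palm_id, run_id in palm_runs:
    sotu_runs[palm_to_sotu[palm_id]].add(run_id)

print({sotu: sorted(runs) for sotu, runs in sotu_runs.items()})
```

After collapsing, only centroid ids remain as keys, and `u1` carries the runs of both `u1` and `u2`.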

#### Host nodes

- Only include Taxons with 'family' rank (+ some exceptions)
- All taxons with rank `family`, or a rank more specific than family, are collapsed into their ancestor taxon with family rank. Any taxons without a family ancestor are added directly with display label `Unknown`
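
As a rough sketch of the family collapsing rule, the walk up the parent chain might look like the following. The `taxonomy` dictionary is a simplified, hypothetical stand-in for the `HAS_PARENT` relationships in the database, and taxon `999` is invented to show the `Unknown` case.

```python
# Hypothetical taxonomy: taxId -> (rank, parent taxId or None).
taxonomy = {
    "5811": ("species", "5810"),
    "5810": ("genus", "5809"),
    "5809": ("family", None),
    "999": ("species", None),  # invented taxon with no family ancestor
}


def family_ancestor(tax_id, taxonomy):
    """Walk parent links until a family-ranked taxon is found; None if absent."""
    current = tax_id
    while current is not None:
        rank, parent = taxonomy[current]
        if rank == "family":
            return current
        current = parent
    return None


def host_node(tax_id, taxonomy):
    """Return (node id, display label) for the host node a taxon collapses into."""
    family = family_ancestor(tax_id, taxonomy)
    if family is None:
        return tax_id, "Unknown"
    return family, family


print(host_node("5811", taxonomy))
print(host_node("999", taxonomy))
```

A taxon that is already family-ranked maps to itself, which matches the `*0..` variable-length pattern used in the Cypher queries.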


#### HAS_HOST edges 

- Edges include SOTUs and Taxons directly associated with SRAs from Toxo dataset
- Edges aggregate weight by averaging `percentIdentity` from non-centroid Palmprints and non-family ranked Taxons
- (TODO) add indirect HAS_HOST edges or edge attributes, i.e. HAS_HOST relationships from associated SOTUs that are not associated to SRAs in Toxo dataset but exist in Serratus database
- (TODO) add weight property to HAS_HOST edges using Fisher's exact test: Toxo dataset count vs. non-toxo dataset counts (consider incorporating HAS_PALMPRINT avg percentIdentity) 
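
The averaging of `percentIdentity` per collapsed edge can be illustrated as below. The rows are hypothetical; in the notebook the same aggregation is done in Cypher with `count(*)` and `avg(r.percentIdentity)`.

```python
from collections import defaultdict

# Hypothetical rows: (sOTU palmId, family taxId, percentIdentity of one hit).
rows = [
    ("u1", "5809", 100.0),
    ("u1", "5809", 90.0),
    ("u1", "9604", 45.0),
]

# Accumulate [sum, count] for each (sOTU, family) pair.
totals = defaultdict(lambda: [0.0, 0])
for sotu, family, percent_identity in rows:
    acc = totals[(sotu, family)]
    acc[0] += percent_identity
    acc[1] += 1

# One HAS_HOST edge per pair, weighted by the mean percentIdentity.
edges = {
    pair: {"count": n, "weight": total / n}
    for pair, (total, n) in totals.items()
}
print(edges)
```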

#### Nodes of interest
- Toxoplasma gondii has family `Sarcocystidae` with taxId `5809`
- Palmprints associated to Ruby and Cougar strains can be filtered by using node property `point:isRubyOrCougar` or by palmId `u658323` or `u380516`


### Graph Analysis

#### Community Detection
- The Label Propagation Algorithm (LPA) nicely clusters communities of Taxons and associated palmprints (`point:communityId`) and can be used to color the Virome
- Transitional palmprints with edges between different communities are of interest
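
One way to flag transitional palmprints, given a community assignment, is to collect the communities of each palmprint's host neighbors and keep palmprints that touch more than one. The node ids and community labels below are invented for illustration.

```python
from collections import defaultdict

# Hypothetical community assignment per node (e.g. from label propagation).
community = {"5809": 0, "9604": 1, "u1": 0, "u2": 1, "u3": 0}

# Hypothetical HAS_HOST edges as (palmprint, taxon) pairs.
edges = [("u1", "5809"), ("u1", "9604"), ("u2", "9604"), ("u3", "5809")]

# Communities reached by each palmprint through its host edges.
neighbor_communities = defaultdict(set)
for palm_id, tax_id in edges:
    neighbor_communities[palm_id].add(community[tax_id])

# A palmprint is transitional if its hosts span more than one community.
transitional = sorted(
    palm_id
    for palm_id, communities in neighbor_communities.items()
    if len(communities) > 1
)
print(transitional)
```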

#### Centrality
- Centrality algorithms measure the spread of influence of nodes in the network
- PageRank highlights nodes with high importance (`point:pageRank`)
- Degree is similarly useful (`point:degree`)
- CELF takes a seed count k and selects the k most influential nodes in the network, which are typically high-degree Taxons (`point:celf`)
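
To make the degree and PageRank measures concrete, here is a small, unweighted power-iteration sketch on a made-up star graph. It is a simplification: it ignores edge weights, assumes every node has at least one neighbor, and normalizes scores to sum to 1, so absolute values will differ from scores produced by a graph library that scales them differently.

```python
def pagerank(adjacency, damping=0.85, iterations=50):
    """Power iteration on an undirected graph given as {node: set(neighbors)}.

    Assumes every node has at least one neighbor (no dangling nodes).
    """
    n = len(adjacency)
    rank = {node: 1.0 / n for node in adjacency}
    for _ in range(iterations):
        new_rank = {node: (1.0 - damping) / n for node in adjacency}
        for node, neighbors in adjacency.items():
            # Each node shares its damped rank equally among its neighbors.
            share = damping * rank[node] / len(neighbors)
            for neighbor in neighbors:
                new_rank[neighbor] += share
        rank = new_rank
    return rank


# Hypothetical star: one family taxon hosting three sOTUs.
adjacency = {
    "5809": {"uA", "uB", "uC"},
    "uA": {"5809"},
    "uB": {"5809"},
    "uC": {"5809"},
}

degree = {node: len(neighbors) for node, neighbors in adjacency.items()}
ranks = pagerank(adjacency)
print(degree)
print(ranks)
```

In this toy graph the hub taxon has both the highest degree and the highest rank, which mirrors the observation that high-degree Taxons dominate the centrality scores.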


#### Transitionary nodes as potential parasites

- Ruby and Cougar have Palmprints that are transitionary nodes between clusters, i.e. they have edges to two Taxon families: Sarcocystidae (5809), the family of T. gondii, and Hominidae (9604)
- It may be of interest to analyze other transitionary Palmprint nodes with edges between these and other tax families to understand if they have similar vector/parasitic qualities
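
A possible starting point for that analysis is to list, for every pair of families, the palmprints they share. The edge list below is hypothetical (`uA`, `uB`, `uC`, and family `F3` are placeholders); only the taxIds `5809` and `9604` come from the notes above.

```python
from collections import defaultdict
from itertools import combinations

# Hypothetical HAS_HOST edges as (palmprint, family taxId) pairs.
edges = [
    ("uA", "5809"),
    ("uA", "9604"),
    ("uB", "9604"),
    ("uC", "9604"),
    ("uC", "F3"),
]

# Palmprints observed in each family.
palms_by_family = defaultdict(set)
for palm_id, family in edges:
    palms_by_family[family].add(palm_id)

# Palmprints shared by each pair of families (empty intersections dropped).
shared = {}
for family_1, family_2 in combinations(sorted(palms_by_family), 2):
    common = palms_by_family[family_1] & palms_by_family[family_2]
    if common:
        shared[(family_1, family_2)] = common

print(shared)
```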


## Setup

### Imports and config

In [1]:
# Notebook config
import sys
if '../' not in sys.path:
    sys.path.append("../")
%load_ext dotenv
%dotenv

import collections
import os

import pandas as pd
from datasources.neo4j import gds
from queries import utils
import matplotlib.pyplot as plt
import graphistry



In [2]:
graphistry.register(
    api=3,
    username=os.getenv('GRAPHISTRY_USERNAME'),
    password=os.getenv('GRAPHISTRY_PASSWORD'),
)


data_dir_path = '/mnt/graphdata/tgav-data/'


### Parse SRA csvs

In [3]:
df1 = pd.read_csv(data_dir_path + 'toxo_SraRunInfo.csv')
df2 = pd.read_csv(data_dir_path + 'txid5810_SraRunInfo.csv')
df3 = pd.read_csv(data_dir_path + 'txid5810_statbigquery.csv')
df3 = df3.rename(columns={'tax_id': 'TaxID', 'acc': 'Run'})

sra_all = pd.concat([df1[['Run', 'TaxID']], df2[['Run', 'TaxID']], df3[['Run', 'TaxID']]], axis=0)
sra_all = sra_all.drop_duplicates()
sra_all = sra_all.astype({"Run": str, "TaxID": int})
print(sra_all.head())
print(sra_all.shape)

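# Runs present in all three sources: each left merge followed by dropna()
# keeps only rows matched in the right-hand frame (effectively an inner join).
sra_intersection = df1[['Run', 'TaxID']].merge(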
    df2[['Run', 'TaxID']],
    left_on='Run',
    right_on='Run',
    how='left',
).dropna()
sra_intersection = sra_intersection.merge(
    df3[['Run', 'TaxID']],
    left_on='Run',
    right_on='Run',
    how='left',
).dropna()
sra_intersection = sra_intersection.astype({"Run": str, "TaxID": int, "TaxID_x": int, "TaxID_y": int })
print(sra_intersection.head())
print(sra_intersection.shape)

           Run  TaxID
0  SRR23381539   5811
1  SRR23381538   5811
2  SRR23381537   5811
3  SRR23381536   5811
4  SRR23381535   5811
(43726, 2)
           Run  TaxID_x  TaxID_y  TaxID
0  SRR23381539     5811     5811   5810
1  SRR23381538     5811     5811   5810
2  SRR23381537     5811     5811   5810
3  SRR23381536     5811     5811   5810
4  SRR23381535     5811     5811   5810
(845, 4)




### Neo4j Queries

In [4]:
query_sra_palmprint_taxon_counts = """
    MATCH (a:Palmprint)<-[b:HAS_PALMPRINT]-(c:SRA)
        -[d:HAS_HOST]->(e:Taxon)
    WHERE c.runId in $run_ids
    AND NOT (e)-[:HAS_PARENT*]->(:Taxon {taxId: '12908'})
    RETURN COUNT(DISTINCT c) as num_sras, COUNT(DISTINCT a) as num_palmprints, COUNT(DISTINCT e) as num_taxons
"""

query_sra_palmprints_taxons_collections = """
    MATCH (a:Palmprint)<-[b:HAS_PALMPRINT]-(c:SRA)
        -[d:HAS_HOST]->(e:Taxon)
    WHERE c.runId in $run_ids
    AND NOT (e)-[:HAS_PARENT*]->(:Taxon {taxId: '12908'})
    RETURN COLLECT(DISTINCT c.runId) as run_ids, COLLECT(DISTINCT a.palmId) as palm_ids, COLLECT(DISTINCT e.taxId) as tax_ids
"""

query_palmprints = """
    MATCH (n:Palmprint)
    WHERE n.palmId in $palm_ids
    RETURN
        id(n) as nodeId,
        n.palmId as appId,
        n.palmId as palmId,
        labels(n) as labels,
        n.centroid as centroid,
        n.sotu as sotu,
        CASE WHEN n.palmId in ['u658323', 'u380516'] THEN True ELSE False END AS isRubyOrCougar
        //count(r) as numPalmprints
"""

query_sotus = """
    MATCH (n:SOTU)<-[r:HAS_SOTU*0..1]-(p:Palmprint)
    WHERE n.palmId in $palm_ids OR p.palmId in $palm_ids
    RETURN
        id(n) as nodeId,
        n.palmId as appId,
        n.palmId as palmId,
        labels(n) as labels,
        n.centroid as centroid,
        CASE WHEN n.palmId in ['u658323', 'u380516'] THEN True ELSE False END AS isRubyOrCougar,
        count(r) as numPalmprints
"""

query_taxons = """
    MATCH (n:Taxon)
    WHERE n.taxId in $tax_ids
    RETURN
        id(n) as nodeId,
        n.taxId as appId,
        n.taxId as taxId,
        labels(n) as labels,
        n.rank as rank,
        n.taxFamily as taxFamily
"""

query_taxons_family = """
    CALL {
        MATCH (t:Taxon)-[:HAS_PARENT*]->(n:Taxon)
        WHERE t.taxId in $tax_ids
        AND n.rank = 'family'
        RETURN n
        UNION
        MATCH (n:Taxon)
        WHERE n.taxId in $tax_ids
        AND n.rank = 'family'
        RETURN n
    }
    WITH n
    RETURN
        id(n) as nodeId,
        n.taxId as appId,
        n.taxId as taxId,
        labels(n) as labels,
        n.rank as rank,
        n.taxFamily as taxFamily
"""

query_direct_has_host_edges = '''
    CALL {
        MATCH (p:SOTU)<-[:HAS_SOTU]-(:Palmprint)<-[r:HAS_PALMPRINT]-(s:SRA)
            -[:HAS_HOST]->()-[:HAS_PARENT*0..]->(t:Taxon {rank: 'family'})
        WHERE s.runId in $run_ids
        AND p.palmId in $sotus
        AND t.taxId in $family_tax_ids
        RETURN p, t, r
        UNION
        MATCH (p:SOTU)<-[r:HAS_PALMPRINT]-(s:SRA)
            -[:HAS_HOST]->()-[:HAS_PARENT*0..]->(t:Taxon {rank: 'family'})
        WHERE s.runId in $run_ids
        AND p.palmId in $sotus
        AND t.taxId in $family_tax_ids
        RETURN p, t, r
    }
    WITH p, t, r
    RETURN
        id(p) as sourceNodeId,
        p.palmId as sourceAppId,
        id(t) as targetNodeId,
        t.taxId as targetAppId,
        'HAS_HOST' as relationshipType,
        count(*) AS count,
        avg(r.percentIdentity) as avgPercentIdentity,
        avg(r.percentIdentity) as weight,
        'True' as isDirect
'''

query_indirect_has_host_edges = '''
    MATCH (p:Palmprint)<-[r:HAS_PALMPRINT]-(s:SRA)
        -[:HAS_HOST]->(t:Taxon)
    WHERE not (t)-[:HAS_PARENT*]->(:Taxon {taxId: '12908'})
    AND p.palmId in $palm_ids
    RETURN
        id(p) as sourceNodeId,
        p.palmId as sourceAppId,
        id(t) as targetNodeId,
        t.taxId as targetAppId,
        'HAS_HOST' as relationshipType,
        count(*) AS count,
        avg(r.percentIdentity) as avgPercentIdentity,
        avg(r.percentIdentity) as weight,
        'False' as isDirect
'''


In [5]:
def get_run_ids(sra_df):
    return set(sra_df.Run.unique())


def get_tax_ids(sra_df):
    tax_ids = set()
    tax_ids.update(list(sra_df.TaxID_x.unique().astype(int)))
    tax_ids.update(list(sra_df.TaxID_y.unique().astype(int)))
    tax_ids.update(list(sra_df.TaxID.unique().astype(int)))
    return tax_ids


## All SRA matches

### Create dataframes

In [6]:
sra_df = sra_all
run_ids = get_run_ids(sra_df)

def _log_df(df):
    namespace = globals()
    var_name = [name for name in namespace if namespace[name] is df]
    print(var_name, df.shape)
    print(df.head())


result_counts = gds.run_cypher(
    query_sra_palmprint_taxon_counts,
    params={'run_ids': list(run_ids)}
)
_log_df(result_counts)


result_collections = gds.run_cypher(
    query_sra_palmprints_taxons_collections,
    params={'run_ids': list(run_ids)}
)
_log_df(result_collections)


palmprint_nodes = gds.run_cypher(
    query_palmprints,
    params={'palm_ids': result_collections['palm_ids'][0]},
)
_log_df(palmprint_nodes)


sotu_nodes = gds.run_cypher(
    query_sotus,
    params={'palm_ids': result_collections['palm_ids'][0]},
)
_log_df(sotu_nodes)


taxon_nodes = gds.run_cypher(
    query_taxons,
    params={'tax_ids': result_collections['tax_ids'][0]},
)
_log_df(taxon_nodes)


taxon_family_nodes = gds.run_cypher(
    query_taxons_family,
    params={'tax_ids': result_collections['tax_ids'][0]},
)
_log_df(taxon_family_nodes)


# Some taxons are missing family information and have no ancestors with family information
# .copy() avoids pandas' SettingWithCopyWarning when a column is added later
taxon_nodes_missing_family = taxon_nodes[taxon_nodes['taxFamily'].isna()].copy()
_log_df(taxon_nodes_missing_family)


has_host_edges = gds.run_cypher(
    query_direct_has_host_edges,
    params={
        'run_ids': result_collections['run_ids'][0],
        'sotus': sotu_nodes['appId'].tolist(),
        'family_tax_ids': taxon_family_nodes['appId'].tolist(),
    },
)
_log_df(has_host_edges)


indirect_has_host_edges = gds.run_cypher(
    query_indirect_has_host_edges,
    params={'palm_ids': result_collections['palm_ids'][0]},
)
_log_df(indirect_has_host_edges)

['result_counts'] (1, 3)
   num_sras  num_palmprints  num_taxons
0      1727            4864         226
['result_collections'] (1, 3)
                                             run_ids  \
0  [SRR10384547, ERR3588793, SRR10070156, SRR8654...   

                                            palm_ids  \
0  [u373632, u771732, u671355, u427546, u428730, ...   

                                             tax_ids  
0  [9606, 1561973, 29656, 6689, 289980, 10090, 96...  
['palmprint_nodes'] (4864, 7)
    nodeId     appId    palmId             labels  centroid      sotu  \
0  8106558   u201556   u201556        [Palmprint]     False   u247271   
1  8747199  u1063335  u1063335  [Palmprint, SOTU]      True  u1063335   
2  7753102    u29671    u29671  [Palmprint, SOTU]      True    u29671   
3  8573073  u1070390  u1070390        [Palmprint]     False   u843755   
4  8603781   u414454   u414454        [Palmprint]     False   u874078   

   isRubyOrCougar  
0           False  
1           False  


In [7]:
graph_name = 'tgav'

nodes = pd.concat([
    sotu_nodes[['nodeId', 'labels']],
    taxon_family_nodes[['nodeId', 'labels']],
    taxon_nodes_missing_family[['nodeId', 'labels']],
])
print(nodes)

relationships = pd.concat([
    has_host_edges[['sourceNodeId', 'targetNodeId',  'relationshipType', 'weight']],
])
print(relationships)

if gds.graph.exists(graph_name)['exists']:
    gds.graph.drop(gds.graph.get(graph_name))

G = gds.alpha.graph.construct(
    graph_name=graph_name,
    nodes=nodes,
    relationships=relationships,
    concurrency=4,
    undirected_relationship_types=['HAS_HOST'],
)

       nodeId             labels
0     7680426  [Palmprint, SOTU]
1     7683241  [Palmprint, SOTU]
2     7685828  [Palmprint, SOTU]
3     7686547  [Palmprint, SOTU]
4     7687562  [Palmprint, SOTU]
..        ...                ...
4    10964222      [Taxon, Host]
49    8759892      [Taxon, Host]
70    9999282      [Taxon, Host]
105   8875432      [Taxon, Host]
157  10964225      [Taxon, Host]

[4405 rows x 2 columns]
      sourceNodeId  targetNodeId relationshipType  weight
0          8738206       8772220         HAS_HOST    99.0
1          8148295       8772220         HAS_HOST   100.0
2          8532699       8772220         HAS_HOST    45.0
3          8148467       8772220         HAS_HOST   100.0
4          7991777       8772220         HAS_HOST   100.0
...            ...           ...              ...     ...
5198       7781781       8809302         HAS_HOST    42.0
5199       8041022       8809302         HAS_HOST    84.0
5200       8485750       8809302         HAS_HOST   100.0

### Community detection and node centrality

In [8]:
communities = gds.labelPropagation.stream(
    G,
    nodeLabels=['Taxon', 'Palmprint'],
    relationshipWeightProperty='weight',
    maxIterations=30,
)
unique_communities = communities.communityId.unique()
community_counter = collections.Counter(communities.communityId)
print(len(unique_communities))
print(community_counter.most_common(10))

page_ranks = gds.pageRank.stream(
    G,
    nodeLabels=['Taxon', 'Palmprint'],
    relationshipWeightProperty='weight',
    maxIterations=100,
)
page_ranks = page_ranks.rename(columns={'score': 'pageRank'})
page_ranks['pageRank'] = page_ranks['pageRank'].round(0)


celf = gds.beta.influenceMaximization.celf.stream(
    G,
    seedSetSize=25,
)
celf = celf.rename(columns={'spread': 'celf'})
celf['celf'] = celf['celf'].round(0)

G.drop()

127
[(11156590, 1862), (8763616, 389), (8764798, 313), (8759540, 148), (8764001, 139), (8760548, 113), (8758893, 79), (8763875, 77), (8761256, 75), (8759772, 60)]


CELF: 100%|██████████| 100.0/100 [00:01<00:00, 65.84%/s] 


graphName                                                             tgav
database                                                             neo4j
memoryUsage                                                               
sizeInBytes                                                             -1
nodeCount                                                             4405
relationshipCount                                                    10406
configuration            {'jobId': 'ea3b6ac3-e585-4643-b8a6-3fd7338c5a4...
density                                                           0.000536
creationTime                           2023-07-14T00:19:32.793362455+00:00
modificationTime                       2023-07-14T00:19:32.833848776+00:00
schema                   {'graphProperties': {}, 'relationships': {'HAS...
schemaWithOrientation    {'graphProperties': {}, 'relationships': {'HAS...
Name: 0, dtype: object

### Graphistry visualization

In [9]:
# Create node and relationship dataframes with full information

sotu_nodes['displayLabel'] = sotu_nodes['appId']
taxon_family_nodes['displayLabel'] = taxon_family_nodes['taxFamily']
taxon_nodes_missing_family['displayLabel'] = 'Unknown'

nodes = pd.concat([
    sotu_nodes,
    taxon_family_nodes,
    taxon_nodes_missing_family
])

nodes = nodes.merge(
    page_ranks,
    left_on='nodeId',
    right_on='nodeId',
    how='left',
)
nodes = nodes.merge(
    communities,
    left_on='nodeId',
    right_on='nodeId',
    how='left',
)
nodes = nodes.merge(
    celf,
    left_on='nodeId',
    right_on='nodeId',
    how='left',
)

nodes['type'] = nodes['labels']

nodes = nodes[[
    'appId', 'labels', 'type', 'pageRank', 
    'centroid', 'rank', 'communityId',
    'taxFamily', 'isRubyOrCougar', 'celf',
    'displayLabel',
]].astype(str)


relationships = pd.concat([
    has_host_edges,
])
relationships = relationships[[
    'sourceAppId', 'targetAppId', 'relationshipType', 'weight'
]].astype(str)


nodes['communityId'] = nodes['communityId'].astype('int32')
labels = nodes['communityId'].unique()
mapping = {label: i for i, label in enumerate(labels)}
nodes['communityId'] =  nodes['communityId'].replace(mapping)
nodes['communityColorCodes'] = nodes['communityId'].mod(12)



In [10]:
g = graphistry.bind()

g = g.bind(
    source='sourceAppId',
    destination='targetAppId',
).edges(relationships)

g = g.bind(node='appId', point_label='displayLabel').nodes(nodes)

g = g.settings(url_params={
        'play': 2000,
        'menu': True, 
        'info': True,
        'showArrows': True,
        # 'pointSize': 2.0, 
        # 'edgeCurvature': 0.5,
        'edgeOpacity': 0.25, 
        'pointOpacity': 1.0,
        # 'lockedX': False, 'lockedY': False, 'lockedR': False,
        'linLog': True, 
        'compactLayout': True,
        'strongGravity': True,
        'dissuadeHubs': True,
        'edgeInfluence': 0.95,
        # 'precisionVsSpeed': 0, 'gravity': 1.0, 'scalingRatio': 1.0,
        # 'showLabels': True, 'showLabelOnHover': True,
        # 'showPointsOfInterest': True, 'showPointsOfInterestLabel': True, 
        'showLabelPropertiesOnHover': True,
        'pointsOfInterestMax': 15,
      })

g = g.encode_point_color(
    'communityColorCodes',
)

g.plot()

## Indirect associated hosts

### Co-occurrence Triangles
- Viral co-occurrence relates taxons through a common virus found in two different taxon families
- Host co-occurrence relates viruses through a common host shared by two different sOTUs
- Viral co-occurrence across hosts may indicate zoonotic potential, either through parasites or through other viral vectors and causal interactions


Viral co-occurrence

|               | (Virus_1, Host_1) | (Virus_1, Host_2) | Row total          |
|---------------|-------------------|-------------------|--------------------|
| T. gondii     | a                 | b                 | a + b              |
| Not T. gondii | c                 | d                 | c + d              |
| Column total  | a + c             | b + d             | a + b + c + d (=n) |


Host Co-occurrence

|               | (Virus_1, Host_1) | (Virus_2, Host_1) | Row total          |
|---------------|-------------------|-------------------|--------------------|
| T. gondii     | a                 | b                 | a + b              |
| Not T. gondii | c                 | d                 | c + d              |
| Column total  | a + c             | b + d             | a + b + c + d (=n) |
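
As one possible reading of the tables above, each cell counts SRA runs in which a given virus-host pair was observed, split by whether the run belongs to the T. gondii dataset. The sketch below builds such a table from made-up observations; the run ids, the `observations` layout, and the `contingency` helper are assumptions for illustration only.

```python
# Hypothetical set of runs belonging to the T. gondii dataset.
toxo_runs = {"SRR1", "SRR2"}

# Hypothetical observations: (run id, virus sOTU, host family taxId).
observations = [
    ("SRR1", "u1", "5809"),
    ("SRR2", "u1", "9604"),
    ("SRR3", "u1", "5809"),
    ("SRR4", "u1", "9604"),
    ("SRR5", "u1", "9604"),
]


def contingency(observations, toxo_runs, pair_1, pair_2):
    """Return [[a, b], [c, d]] for two (virus, host) pairs."""
    a = b = c = d = 0
    for run_id, virus, host in observations:
        in_toxo = run_id in toxo_runs
        if (virus, host) == pair_1:
            if in_toxo:
                a += 1
            else:
                c += 1
        elif (virus, host) == pair_2:
            if in_toxo:
                b += 1
            else:
                d += 1
    return [[a, b], [c, d]]


# Viral co-occurrence: the same virus in two different host families.
table = contingency(observations, toxo_runs, ("u1", "5809"), ("u1", "9604"))
print(table)
```

The host co-occurrence table follows the same pattern with two different viruses and a single host.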



### Local Clustering Coefficient
- Co-occurrence edges form triangles in the network. These triangles can be used in clustering algorithms, e.g. the [Local Clustering Coefficient](https://neo4j.com/docs/graph-data-science/current/algorithms/local-clustering-coefficient/), which describes the likelihood that the neighbours of node $n$ are also connected.
- To compute $C_n$ we use the number of triangles a node is a part of $T_n$, and the degree of the node $d_n$. The formula to compute the local clustering coefficient is as follows:
$$
C_n = \frac{2 T_n}{d_n(d_n - 1)}
$$
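
The formula can be checked on a small example. The following sketch computes $C_n$ directly from an adjacency dictionary; the four-node graph is invented, and nodes with fewer than two neighbors are assigned a coefficient of 0 by convention here.

```python
from itertools import combinations


def local_clustering_coefficient(adjacency, node):
    """C_n = 2 * T_n / (d_n * (d_n - 1)) for an undirected simple graph."""
    neighbors = adjacency[node]
    degree = len(neighbors)
    if degree < 2:
        return 0.0  # convention used in this sketch
    # T_n: pairs of neighbors that are themselves connected.
    triangles = sum(
        1 for u, v in combinations(neighbors, 2) if v in adjacency[u]
    )
    return 2.0 * triangles / (degree * (degree - 1))


# Hypothetical graph: triangle a-b-c plus a pendant node d attached to a.
adjacency = {
    "a": {"b", "c", "d"},
    "b": {"a", "c"},
    "c": {"a", "b"},
    "d": {"a"},
}

for node in sorted(adjacency):
    print(node, local_clustering_coefficient(adjacency, node))
```

Node `a` is in one triangle and has degree 3, so $C_a = 2 \cdot 1 / (3 \cdot 2) = 1/3$, while `b` and `c` each get 1.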


### Fisher's exact test 
- Fisher's exact test assesses whether the two classifications in a 2x2 contingency table are associated. Its p-value is the probability, under the null hypothesis of no association, of observing a table at least as extreme as the one obtained
- The probability of obtaining a given set of values in the above tables is given by: 
$$
p = \frac{\binom{a+b}{a}\binom{c+d}{c}}{\binom{n}{a+c}} = \frac{\binom{a+b}{b}\binom{c+d}{d}}{\binom{n}{b+d}} = \frac{(a+b)!\,(c+d)!\,(a+c)!\,(b+d)!}{a!\,b!\,c!\,d!\,n!}
$$
- If the marginal totals of the above tables (i.e. $a + b$, $c + d$, $a + c$, and $b + d$) are known, only a single degree of freedom is left: the value e.g. of $a$ suffices to deduce the other values. 
- Now, $p = p(a)$ is the probability that $a$ elements are positive in a random selection (without replacement) of $a + c$ elements from a larger set containing $n$ elements in total, of which $a + b$ are positive. This is precisely the definition of the hypergeometric distribution.
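
The probability above, and a two-sided p-value built from it, can be sketched with the standard library alone. This is an illustrative implementation under the usual convention of summing all tables with the same margins that are no more probable than the observed one; the input counts are arbitrary examples, not results from this dataset.

```python
from math import comb


def table_probability(a, b, c, d):
    """Hypergeometric probability of one 2x2 table with fixed margins."""
    n = a + b + c + d
    return comb(a + b, a) * comb(c + d, c) / comb(n, a + c)


def fisher_exact_two_sided(a, b, c, d):
    """Sum the probabilities of all tables no more likely than the observed one."""
    row_1, row_2, col_1 = a + b, c + d, a + c
    p_observed = table_probability(a, b, c, d)
    # With margins fixed, the top-left cell x determines the whole table.
    lowest = max(0, col_1 - row_2)
    highest = min(row_1, col_1)
    p_value = 0.0
    for x in range(lowest, highest + 1):
        p = table_probability(x, row_1 - x, col_1 - x, row_2 - (col_1 - x))
        if p <= p_observed * (1 + 1e-7):  # tolerance for float comparison
            p_value += p
    return p_value


print(table_probability(1, 1, 1, 2))
print(fisher_exact_two_sided(1, 1, 1, 2))
```

The resulting p-value could serve as the Fisher-based weight mentioned in the HAS_HOST TODO, with the Toxo and non-Toxo counts as the two rows.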


### Hypergeometric series as Inductive Bias
- Hypergeometric series can be interpreted as a generalization of spherical harmonics (which parameterize SO(3) rotations) to other Lie Groups. SO(3) rotations are commonly used in E(3)-equivariant Graph Neural Networks as an inductive bias for graph representation learning
- We can derive a hypergeometric distribution with Fisher's exact test by comparing virus-host co-occurrences, i.e. between T. gondii SRA runs and SRA runs not associated with T. gondii that share the same virus-host co-occurrence
- These probability values can be stored as edge weights and used to learn equivariant representations similar to E(3)-networks from angle and torsion 


### Graphistry visualization

In [57]:

# Create node and relationship dataframes with full information and indirect host associations
nodes = pd.concat([
    palmprint_nodes,
    taxon_nodes,
])
nodes = nodes.merge(
    page_ranks,
    left_on='nodeId',
    right_on='nodeId',
    how='left',
)
nodes = nodes.merge(
    communities,
    left_on='nodeId',
    right_on='nodeId',
    how='left',
)
nodes['type'] = nodes['labels']

relationships = pd.concat([
    has_host_edges,
    indirect_has_host_edges,
])
# page_ranks was renamed above, so the PageRank column is 'pageRank', not 'score'
nodes = nodes[['appId', 'labels', 'pageRank', 'centroid', 'rank', 'communityId', 'type']].astype(str)
relationships = relationships[['sourceAppId', 'targetAppId', 'relationshipType', 'weight']].astype(str)

In [20]:
g = graphistry.bind()

g = g.bind(
    source='sourceAppId',
    destination='targetAppId',
).edges(relationships)

g = g.bind(node='appId', point_label='appId').nodes(nodes)
palette = [
    'red', 'green', 'blue', 'orange',
    # "#32964d", "#a3d2a0", "#155126", 
    # "#64d4fd", "#378dae", "#374475",
]
g = g.encode_point_color(
    'type',
    palette=palette,
    categorical_mapping={
        "['Palmprint', 'SOTU']": 'blue',
        "['Palmprint']": 'purple',
        "['Taxon', 'Host']": 'green',
    },
    as_categorical=True
)

g.plot()