## NCBI STAT Host associations

#### Plan for adding missing runs from toxo dataset

1. Get list of runs in dataset that are missing from sra_stat
1. Fetch missing runs from bigquery STAT
1. Fetch sra_stat_group to dataframe
1. Compute additional values for new rows, (groupby total kmers, percent, label)
1. Write missing values to psql
1. Write all STAT data from toxo runs to csv
1. Import, visualize and compare data in seperate graphistry notebook

#### Plan for full sra_stat ETL

1. Set up notebook in instance with more memory
1. Download all stat rows, excluding existing runs using last updated date 
    - last inserted in sra_stat: (row_id: 85551530, run: "SRR9999999")
1. Use dask dataframe to compute additional values for new rows (groupby total kmers, percent, label)
1. Write new STAT tax values to psql
1. Write new edge values to neo4j using containerized ETL job

### Imports and configs

In [1]:
import sys
if '../' not in sys.path:
    sys.path.append("../")
%reload_ext dotenv
%dotenv

import os

from queries import serratus_queries, graph_queries
from datasources import psql


import psycopg2
import pandas as pd
from google.cloud import bigquery
import dask.dataframe as dd

In [2]:
base_data_path = '../../graph_learning/notebooks/tgav_data/'
toxo_data_path = base_data_path + 'toxo/'

def fetch_one(query, params={}):
    conn = psql.get_connection()
    cursor = conn.cursor()
    resp = cursor.execute(query, params)
    resp = cursor.fetchone()
    cursor.close()
    conn.close()
    return resp

def fetch_count(query, params={}):
    resp = fetch_one(query, params)
    return int(resp[0])

def fetch_all(query, params={}):
    conn = psql.get_connection()
    cursor = conn.cursor()
    resp = cursor.execute(query, params)
    resp = cursor.fetchall()
    cursor.close()
    conn.close()
    return resp

### Inspect missing STAT runs in Toxo dataset

In [21]:


def get_toxo_dfs():
    # Get RunID, TaxID pairs from original datasets
    df1 = pd.read_csv(toxo_data_path + 'toxo_SraRunInfo.csv')
    df2 = pd.read_csv(toxo_data_path + 'txid5810_SraRunInfo.csv')
    df3 = pd.read_csv(toxo_data_path + 'txid5810_statbigquery.csv')
    df3 = df3.rename(columns={'tax_id': 'TaxID', 'acc': 'Run'})
    sra_union = pd.concat([df1[['Run', 'TaxID']], df2[['Run', 'TaxID']], df3[['Run', 'TaxID']]], axis=0)
    sra_union = sra_union.drop_duplicates(subset=['Run'])
    sra_union = sra_union.astype({"Run": str, "TaxID": int})

    # Get additional biosample metadata (missing TaxId)
    sra_union_metadata = pd.read_csv(toxo_data_path + 'tg_set_all_metadata_additional.csv')
    sra_union_metadata = sra_union_metadata.rename(columns={'acc': 'Run'})
    sra_union = sra_union_metadata.merge(
        sra_union,
        left_on='Run',
        right_on='Run',
        how='left',
    ).drop_duplicates(subset=['Run'])

    # Get intersection of original datasets
    sra_intersection = df1[['Run', 'TaxID']].merge(
        df2[['Run', 'TaxID']],
        left_on='Run',
        right_on='Run',
        how='left',
    ).dropna()
    sra_intersection = sra_intersection.merge(
        df3[['Run', 'TaxID']],
        left_on='Run',
        right_on='Run',
        how='left',
    ).dropna()
    sra_intersection = sra_intersection.astype({"Run": str, "TaxID": int, "TaxID_x": int, "TaxID_y": int })

    return sra_union, sra_intersection


sra_union, sra_intersection = get_toxo_dfs()

  df1 = pd.read_csv(toxo_data_path + 'toxo_SraRunInfo.csv')
  sra_union_metadata = pd.read_csv(toxo_data_path + 'tg_set_all_metadata_additional.csv')


In [4]:
unique_runs = list(sra_union.Run.unique())

def get_matching_sra_counts(runs):
    query = """
        SELECT COUNT(distinct run) 
        FROM public.sra_stat 
        WHERE run IN %(runs)s;
    """ 
    params = {
        'runs': tuple(runs),
    }
    return fetch_count(query, params)

counts = get_matching_sra_counts(unique_runs)
print(counts)
print(counts/len(unique_runs))

29039
0.8942506081975795


In [23]:
def get_matching_sra_runs(runs):
    query = """
        SELECT run 
        FROM public.sra_stat 
        WHERE run IN %(runs)s;
    """  
    params = {
        'runs': tuple(runs),
    }
    return fetch_all(query, params)

matched_runs = get_matching_sra_runs(unique_runs)
matched_runs = [x[0] for x in matched_runs]
print(len(matched_runs))
print(matched_runs[:10])

missing_runs = list(set(unique_runs) - set(matched_runs))
print(len(missing_runs))

728819
['DRR001705', 'DRR001705', 'DRR001705', 'DRR001705', 'DRR001705', 'DRR001705', 'DRR001706', 'DRR001706', 'DRR001706', 'DRR001706']
1565


In [30]:
def get_existing_stat_runs():
    query = """
        SELECT distinct run 
        FROM public.sra_stat 
    """
    resp = fetch_all(query)
    return resp

# existing_stat_runs = get_existing_stat_runs()
# print(len(existing_stat_runs))

12483673


### Update sra_stat table with missing Toxo runs

In [7]:
# Authenticate GCP
!gcloud auth application-default login

Your browser has been opened to visit:

    https://accounts.google.com/o/oauth2/auth?response_type=code&client_id=764086051850-6qr4p6gpi6hn506pt8ejuq83di341hur.apps.googleusercontent.com&redirect_uri=http%3A%2F%2Flocalhost%3A8085%2F&scope=openid+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fsqlservice.login+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Faccounts.reauth&state=Xz054jgYLJylWcx95CCKPKqI4uy3ti&access_type=offline&code_challenge=qzYZuRWYBxNGQJSQe_BlXQ2_j8yEC-jNAmlxzqnDCgE&code_challenge_method=S256


Credentials saved to file: [/Users/lukepereira/.config/gcloud/application_default_credentials.json]

These credentials will be used by any library that requests Application Default Credentials (ADC).
Cannot find a quota project to add to ADC. You might receive a "quota exceeded" or "API not enabled" error. Run $ gcloud auth application-default set-quota-project to add a quot

In [24]:
project_id = 'rnalab-393418'
client = bigquery.Client(project_id)



In [25]:
query = """
    SELECT *
    FROM nih-sra-datastore.sra_tax_analysis_tool.tax_analysis
    WHERE rank = 'order'
    AND total_count >= 100
    AND acc in UNNEST(@missing_runs)
    ORDER BY acc
"""

job_config = bigquery.QueryJobConfig(
    query_parameters=[
        bigquery.ArrayQueryParameter("missing_runs", "STRING", missing_runs),
    ]
)

query_job = client.query(query, job_config=job_config) 

In [26]:
stats_data = []
for row in query_job:
    # Row values can be accessed by field name or index.
    stats_data.append({
        'run': row[0],
        'taxid': row[1],
        'rank': row[2],
        'name': row[3],
        'kmer': row[4],
    })

stats_data = pd.DataFrame(stats_data)
print(len(stats_data))
print(stats_data.head())

685
          run   taxid   rank                name     kmer
0  SRR8144089  213115  order  Desulfovibrionales     1796
1  SRR8144089  162474  order       Malasseziales      866
2  SRR8144089  186826  order     Lactobacillales  3669547
3  SRR8144089    4892  order   Saccharomycetales     1684
4  SRR8144089   73020  order      Isochrysidales     5637


In [27]:
total_kmer = stats_data.groupby('run').agg({'kmer': 'sum'}).reset_index()

assert stats_data.run.nunique() == total_kmer.shape[0]

total_kmer = total_kmer.rename(columns={'kmer': 'total_kmers'})

stats_data_totals = stats_data.merge(
    total_kmer[['run', 'total_kmers']],
    left_on='run',
    right_on='run',
)
stats_data_totals['kmer_perc'] = (stats_data_totals['kmer'] / stats_data_totals['total_kmers']) * 100
stats_data_totals['kmer_perc'] = stats_data_totals['kmer_perc'].round(2)
# Note, existing table contains 15,552,449/85,551,530 (20%) that have kmer_perc == 0.0

print(stats_data_totals.head())

          run   taxid   rank                name     kmer  total_kmers  \
0  SRR8144089  213115  order  Desulfovibrionales     1796      7259892   
1  SRR8144089  162474  order       Malasseziales      866      7259892   
2  SRR8144089  186826  order     Lactobacillales  3669547      7259892   
3  SRR8144089    4892  order   Saccharomycetales     1684      7259892   
4  SRR8144089   73020  order      Isochrysidales     5637      7259892   

   kmer_perc  
0       0.02  
1       0.01  
2      50.55  
3       0.02  
4       0.08  


In [28]:
def get_stats_tax_labels():
    query = """
        SELECT * FROM public.sra_stat_group
        ORDER BY taxid ASC
    """
    return fetch_all(query)

stats_tax_labels = get_stats_tax_labels()
print(len(stats_tax_labels))
print(stats_tax_labels[:10])

1184
[(29, 'order', 'Myxococcales', 'Bacteria'), (112, 'order', 'Planctomycetales', 'Bacteria'), (136, 'order', 'Spirochaetales', 'Bacteria'), (356, 'order', 'Hyphomicrobiales', 'Bacteria'), (766, 'order', 'Rickettsiales', 'Bacteria'), (1118, 'order', 'Chroococcales', 'Bacteria'), (1150, 'order', 'Oscillatoriales', 'Bacteria'), (1161, 'order', 'Nostocales', 'Bacteria'), (1189, 'order', 'Stigonemataceae', 'Bacteria'), (1212, 'order', 'NA', 'Unclassified')]


In [37]:
stats_tax_labels = pd.DataFrame(stats_tax_labels, columns=['taxid', 'tax_rank', 'tax_name', 'tax_label'])

stats_data_totals_labels = stats_data_totals.merge(
    stats_tax_labels[['taxid', 'tax_label']],
    left_on='taxid',
    right_on='taxid',
    how='left',
)

assert stats_data_totals_labels.shape[0] == stats_data_totals.shape[0]
print(stats_data_totals_labels.head())

          run   taxid   rank                name     kmer  total_kmers  \
0  SRR8144089  213115  order  Desulfovibrionales     1796      7259892   
1  SRR8144089  162474  order       Malasseziales      866      7259892   
2  SRR8144089  186826  order     Lactobacillales  3669547      7259892   
3  SRR8144089    4892  order   Saccharomycetales     1684      7259892   
4  SRR8144089   73020  order      Isochrysidales     5637      7259892   

   kmer_perc  tax_label  
0       0.02   Bacteria  
1       0.01      Fungi  
2      50.55   Bacteria  
3       0.02      Fungi  
4       0.08  Eukaryota  


In [39]:
stats_data_totals_labels.to_csv('tmp.csv', index=False)

In [5]:
stats_data_totals_labels = pd.read_csv('tmp.csv')
print(len(stats_data_totals_labels))

In [7]:
def get_write_connection():
    return psycopg2.connect(
        database="summary",
        host="serratus-aurora-20210406.cluster-ro-ccz9y6yshbls.us-east-1.rds.amazonaws.com",
        user=os.environ.get('SQL_WRITE_USER'),
        password=os.environ.get('SQL_WRITE_PASSWORD'),
        port="5432")


In [8]:
def get_max_row_id():
    conn = psql.get_connection()
    cursor = conn.cursor()
    query = """
        SELECT max(CAST(row_id as Int)) FROM sra_stat;
    """
    cursor.execute(query)
    out = cursor.fetchone()[0]
    cursor.close()
    conn.close()
    return int(out)

In [31]:
def update_stats_table(df):
    cur_row_id = get_max_row_id() + 1
    conn = get_write_connection()
    cursor = conn.cursor()
    errors = []
    
    for row in df.reset_index().to_dict('rows'):
        try:
            query = """
                INSERT into public.sra_stat(row_id, run, taxid, 
                    rank, name, kmer, 
                    total_kmers, kmer_perc, tax_label
                )
                VALUES(%s, %s, %s, %s, %s, %s, %s, %s, %s)
                ON CONFLICT DO NOTHING;
            """ 
            args = (
                cur_row_id, row['run'], row['taxid'],
                row['rank'], row['name'], row['kmer'], 
                row['total_kmers'], row['kmer_perc'], row['tax_label'],
            )
            cursor.execute(query, args)
            if cur_row_id % 1000 == 0:
                print(cur_row_id)
                conn.commit()
            cur_row_id += 1
        except Exception as e:
            errors.append(row)
            print(e)
            conn.rollback()
            break

    cursor.close()
    conn.close()
    return errors


In [32]:
errs = update_stats_table(stats_data_totals_labels)
print(errs)

  for row in df.reset_index().to_dict('rows'):


[]


In [33]:
# Fetch all toxo runs from sra_stat
def get_toxo_runs():
    query = """
        SELECT DISTINCT ON (run)
            run,
            taxid,
            rank,
            name,
            kmer,
            total_kmers,
            kmer_perc,
            tax_label
        FROM public.sra_stat
        WHERE run in %(toxo_runs)s
        ORDER BY run, kmer_perc DESC;
    """
    params = {
        'toxo_runs': tuple(unique_runs),
    }
    return fetch_all(query, params)

toxo_sra_stat = get_toxo_runs()
print(len(toxo_sra_stat))

30908


In [34]:
print(toxo_sra_stat[:10])

[('DRR001705', 9989, 'order', 'Rodentia', 988967, Decimal('994240'), Decimal('99.47'), 'Mammalia'), ('DRR001706', 9989, 'order', 'Rodentia', 1438884, Decimal('1447150'), Decimal('99.43'), 'Mammalia'), ('DRR002461', 75739, 'order', 'Eucoccidiorida', 8963466, Decimal('11776961'), Decimal('76.11'), 'Eukaryota'), ('DRR002462', 75739, 'order', 'Eucoccidiorida', 5861322, Decimal('10788973'), Decimal('54.33'), 'Eukaryota'), ('DRR002463', 75739, 'order', 'Eucoccidiorida', 8563615, Decimal('10225673'), Decimal('83.75'), 'Eukaryota'), ('DRR002464', 75739, 'order', 'Eucoccidiorida', 8824039, Decimal('11567974'), Decimal('76.28'), 'Eukaryota'), ('DRR002465', 75739, 'order', 'Eucoccidiorida', 5740362, Decimal('10530519'), Decimal('54.51'), 'Eukaryota'), ('DRR002466', 75739, 'order', 'Eucoccidiorida', 8423467, Decimal('10036212'), Decimal('83.93'), 'Eukaryota'), ('DRR014676', 9443, 'order', 'Primates', 21622218, Decimal('21654403'), Decimal('99.85'), 'Primates'), ('DRR022972', 9443, 'order', 'Primat

In [45]:
matched_runs_v2 = get_matching_sra_runs(unique_runs)
matched_runs_v2 = [x[0] for x in matched_runs_v2]
print(len(matched_runs_v2))

missing_runs_v2 = list(set(unique_runs) - set(matched_runs_v2))
print(len(missing_runs_v2))

print(missing_runs_v2[:10])

# These runs are missing from STAT bigquery table

728819
1565
['SRR18548922', 'SRR11063004', 'SRR18532781', 'SRR19329101', 'ERR5466811', 'SRR19329048', 'SRR20330572', 'SRR11156255', 'SRR20330436', 'SRR19325907']


In [47]:
# write to csv in tgav_data
toxo_sra_stat = pd.DataFrame(toxo_sra_stat, columns=['run', 'taxid', 'rank', 'name', 'kmer', 'total_kmers', 'kmer_perc', 'tax_label'])
toxo_sra_stat.to_csv(toxo_data_path + 'tg_tax_host_stat.csv', index=False)

### Update Neo4j with STAT HAS_HOST edges

In [8]:
toxo_sra_stat = pd.read_csv(toxo_data_path + 'tg_tax_host_stat.csv')
toxo_sra_stat['taxid'] = toxo_sra_stat['taxid'].astype(str)
toxo_sra_stat_ddf = dd.from_pandas(toxo_sra_stat, chunksize=1000)
print(len(toxo_sra_stat))

30908


In [31]:
from datasources.neo4j import get_connection

conn = get_connection()

def sanity_check_sras(run_ids):
    query = '''
            MATCH (s:SRA)
            WHERE s.runId in $run_ids
            RETURN COLLECT(DISTINCT s.runId) as run_ids
            '''
    return conn.query(
            query,
            parameters={
                'run_ids': run_ids,
            }
        )

def sanity_check_taxons(tax_ids):
    query = '''
            MATCH (t:Taxon)
            WHERE t.taxId in $tax_ids
            RETURN Collect(DISTINCT t.taxId) as tax_ids
            '''
    return conn.query(
            query,
            parameters={
                'tax_ids': tax_ids,
            }
        )

out = sanity_check_sras(toxo_sra_stat.run.unique())
print(len(toxo_sra_stat.run.unique()))
print(len(out[0]['run_ids']))
missing_runs = set(toxo_sra_stat.run.unique()) - set(out[0]['run_ids'])

taxons = toxo_sra_stat.taxid.unique()
taxons = [str(x) for x in taxons]
out = sanity_check_taxons(taxons)
missing_taxons = set(taxons) - set(out[0]['tax_ids'])
print(len(taxons))
print(len(out[0]['tax_ids']))
print(missing_taxons)

# 28883 was merged to 2731619
# 206350 was merged to 32003
# 1212 was merged to 1213 
# 40677 was merged to 6132

30908
22942
237
233
{'28883', '206350', '1212', '40677'}


In [7]:
print(toxo_sra_stat_ddf.head())

         run  taxid   rank            name     kmer  total_kmers  kmer_perc  \
0  DRR001705   9989  order        Rodentia   988967       994240      99.47   
1  DRR001706   9989  order        Rodentia  1438884      1447150      99.43   
2  DRR002461  75739  order  Eucoccidiorida  8963466     11776961      76.11   
3  DRR002462  75739  order  Eucoccidiorida  5861322     10788973      54.33   
4  DRR002463  75739  order  Eucoccidiorida  8563615     10225673      83.75   

   tax_label  
0   Mammalia  
1   Mammalia  
2  Eukaryota  
3  Eukaryota  
4  Eukaryota  


In [10]:
from datasources.neo4j import get_connection

conn = get_connection()

def debug_missing_spots(run_ids):
    query = '''
            MATCH (s:SRA)
            WHERE s.runId in $run_ids
            AND s.spots = 0
            RETURN Collect(DISTINCT s.runId) as run_ids
            '''
    return conn.query(
            query,
            parameters={
                'run_ids': run_ids,
            }
        )

out = debug_missing_spots(toxo_sra_stat.run.unique())
print(len(out[0]['run_ids']))
print(len(toxo_sra_stat.run.unique()))

36
30908


In [7]:
# Temp fix in graph to add missing spots
# Long-term fix involves updating sra tables to include missing spots
missing_spots = {
      'ERR3415775': 27250560,
      'ERR4352771': 0,
      'ERR1994964': 23972947,
      'ERR5384467': 8430905,
      'ERR1726732': 0,
      'ERR3415759': 27860746,
      'ERR3274949': 865925712,
      'ERR538188': 0,
      'ERR2003549': 23972947,
      'ERR3415758': 41826118,
      'ERR3415760': 31480217,
      'ERR1726762': 0,
      'ERR1726702': 0,
      'ERR1994972': 23231253,
      'ERR1726688': 0,
      'ERR3415762': 20223464,
      'ERR2003534': 27306569,
      'ERR3978063': 60335336,
      'ERR3806953': 0,
      'ERR3415774': 35643841,
      'ERR2003542': 36240950,
      'ERR3415773': 39138550,
      'ERR538183': 0,
      'ERR3415761': 24380142,
      'ERR2003547': 23729256,
      'ERR1726916': 0,
      'ERR1726629': 0,
      'ERR1726949': 0,
      'ERR2003527': 35611500,
      'ERR1994960': 36240950,
      'ERR1726891': 0,
      'ERR1726944': 0,
      'ERR1726740': 0,
      'ERR1726724': 0,
      'ERR3806950': 0,
      'ERR4352608': 0,
}

missing_spots_df = pd.DataFrame(missing_spots.items(), columns=['runId', 'spots'])

missing_spots_ddf = dd.from_pandas(missing_spots_df, chunksize=1000)


def add_missing_spots_to_sras(rows):
    query = '''
            UNWIND $rows as row
            MATCH (s:SRA)
            WHERE s.runId = row.run 
            SET s += {
                spots: toInteger(row.spots),
                spotsWithMates: toInteger(row.spots)
            }
            '''
    return graph_queries.batch_insert_data(query, rows)

add_missing_spots_to_sras(missing_spots_ddf)

[[]]

In [25]:
def add_sra_stat_taxon_edges(rows):
    query = '''
            UNWIND $rows as row
            MATCH (s:SRA), (t:Taxon)
            WHERE s.runId = row.run AND t.taxId = row.taxid
            MERGE (s)-[r:HAS_HOST_STAT]->(t)
            SET r += {
                percentIdentity: round(toFloat(row.kmer_perc / 100), 4),
                percentIdentityFull: CASE WHEN s.spots > 0 
                    THEN round(toFloat(row.kmer) / toFloat(s.spots), 4)
                    ELSE 0.0 END,
                kmer: row.kmer,
                totalKmers: row.total_kmers,
                totalSpots: s.spots
            }
            

            '''
    return graph_queries.batch_insert_data(query, rows)

add_sra_stat_taxon_edges(toxo_sra_stat_ddf)

[[],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 []]

### Update sra_stat table with all missing runs

In [36]:
query = """
    SELECT MAX(updated)
    FROM nih-sra-datastore.sra_tax_analysis_tool.tax_analysis
    WHERE acc in @existing_runs
    ORDER BY acc
"""

job_config = bigquery.QueryJobConfig(
    query_parameters=[
        bigquery.ArrayQueryParameter("existing_runs", "STRING", existing_stat_runs),
    ]
)

query_job = client.query(query, job_config=job_config) 


KeyboardInterrupt: 

In [None]:
query = """
    SELECT * FROM nih-sra-datastore.sra_tax_analysis_tool.tax_analysis AS tax
    WHERE rank = 'order'
    AND total_count >= 100
    ORDER BY acc
    LIMIT 10
"""

query_job = client.query(query)
print("The query data:")
for row in query_job:
    # Row values can be accessed by field name or index.
    print("name={}, count={}".format(row[0], row))

The query data:
name=DRR000013, count=Row(('DRR000013', 28734, 'order', 'Macroscelidea', 405, 0, 22, 97, 101), {'acc': 0, 'tax_id': 1, 'rank': 2, 'name': 3, 'total_count': 4, 'self_count': 5, 'ilevel': 6, 'ileft': 7, 'iright': 8})
name=DRR000013, count=Row(('DRR000013', 9989, 'order', 'Rodentia', 527, 0, 24, 60, 77), {'acc': 0, 'tax_id': 1, 'rank': 2, 'name': 3, 'total_count': 4, 'self_count': 5, 'ilevel': 6, 'ileft': 7, 'iright': 8})
name=DRR000013, count=Row(('DRR000013', 9443, 'order', 'Primates', 719326, 29171, 23, 0, 60), {'acc': 0, 'tax_id': 1, 'rank': 2, 'name': 3, 'total_count': 4, 'self_count': 5, 'ilevel': 6, 'ileft': 7, 'iright': 8})
