## NCBI STAT Host associations

1. Fetch list of existing run_ids in sra_stat table in Serratus SQL DB
1. Create a BigQuery dataset and a table with a single column of existing run_ids
1. Use the complement of a join between the STAT table and the existing run_ids to find all missing STAT rows
1. Download missing rows to disk to limit API usage
1. Use dask dataframe to compute additional/enriched values for new rows (groupby total kmers, percent, label)
1. Write new STAT tax values to psql
1. Write new edge values to neo4j using containerized ETL job

### Imports, configs, global helpers

In [1]:
import sys
if '../' not in sys.path:
    sys.path.append("../")
%reload_ext dotenv
%dotenv

import os
import csv
import glob
import gc

from queries import serratus_queries, graph_queries
from datasources import psql


import psycopg2
import psycopg2.extras as extras
import pandas as pd
from google.cloud import bigquery
import dask.dataframe as dd
from dask.diagnostics import ProgressBar

In [2]:
# Tolerate a duplicated OpenMP runtime in this process.
# NOTE(review): KMP_DUPLICATE_LIB_OK is a workaround flag; confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Show a progress bar for every dask computation run in this notebook.
progress_bar = ProgressBar()
progress_bar.register()

# GCP auth relies on application default credentials; if missing, run:
#   !gcloud auth application-default login
project_id = 'rnalab-393418'
client = bigquery.Client(project=project_id)



In [3]:
def fetch_one(query, params=None):
    """Run a query against the Serratus DB and return the first result row.

    Args:
        query: SQL text, optionally containing psycopg2 placeholders.
        params: placeholder values; defaults to an empty mapping.

    Returns:
        The first row as returned by ``cursor.fetchone()`` (``None`` when the
        query yields no rows).
    """
    conn = psql.get_serratus_connection()
    try:
        cursor = conn.cursor()
        try:
            # An empty mapping (not None) keeps the original parameter handling.
            cursor.execute(query, params if params is not None else {})
            return cursor.fetchone()
        finally:
            cursor.close()
    finally:
        # Release the connection even if execute/fetch raises.
        conn.close()

def fetch_count(query, params={}):
    """Run a single-value query (e.g. ``SELECT count(*)``) and return the value as an int."""
    first_row = fetch_one(query, params)
    count_value = first_row[0]
    return int(count_value)

def fetch_all(query, params=None):
    """Run a query against the Serratus DB and return every result row.

    Args:
        query: SQL text, optionally containing psycopg2 placeholders.
        params: placeholder values; defaults to an empty mapping.

    Returns:
        The rows as returned by ``cursor.fetchall()``.
    """
    conn = psql.get_serratus_connection()
    try:
        cursor = conn.cursor()
        try:
            # An empty mapping (not None) keeps the original parameter handling.
            cursor.execute(query, params if params is not None else {})
            return cursor.fetchall()
        finally:
            cursor.close()
    finally:
        # Release the connection even if execute/fetch raises.
        conn.close()

def get_write_connection(host=None):
    """Open a psycopg2 connection to the ``summary`` database with write credentials.

    Credentials come from the ``SQL_WRITE_USER`` and ``SQL_WRITE_PASSWORD``
    environment variables.

    Args:
        host: optional host override. When omitted, ``SQL_WRITE_HOST`` from the
            environment is used, falling back to the original hard-coded host.
    """
    # NOTE(review): the fallback host is an Aurora ``cluster-ro`` (reader)
    # endpoint, which normally only serves read-only transactions; writes such
    # as COPY usually require the writer (``cluster-``) endpoint. Set
    # SQL_WRITE_HOST (or pass host=) once the correct endpoint is confirmed.
    if host is None:
        host = os.environ.get(
            'SQL_WRITE_HOST',
            "serratus-aurora-20210406.cluster-ro-ccz9y6yshbls.us-east-1.rds.amazonaws.com",
        )
    return psycopg2.connect(
        database="summary",
        host=host,
        user=os.environ.get('SQL_WRITE_USER'),
        password=os.environ.get('SQL_WRITE_PASSWORD'),
        port="5432")

def get_max_row_id():
    """Return the largest ``row_id`` in ``sra_stat`` as an int, or 0 if the table is empty.

    ``row_id`` is cast to Int in SQL so the maximum is computed numerically.
    """
    query = """
        SELECT max(CAST(row_id as Int)) FROM sra_stat;
    """
    # fetch_one closes its cursor/connection even when the query fails.
    max_row_id = fetch_one(query)[0]
    # max() over an empty table is NULL -> None; treat that as "no rows yet"
    # so callers computing ``get_max_row_id() + 1`` start at 1.
    return int(max_row_id) if max_row_id is not None else 0

def get_existing_stat_runs():
    """Return each distinct ``run`` already present in ``public.sra_stat``.

    Returns:
        The rows from ``fetch_all``, one single-column row per run.
    """
    return fetch_all("""
        SELECT distinct run 
        FROM public.sra_stat 
    """)

### Enrich stats data with labels and kmer_perc

In [4]:
def compute_kmer_perc(stats_data):
    """Attach per-run k-mer totals and each row's percentage of that total.

    Used with both pandas and dask frames in this notebook, so it sticks to a
    groupby + merge that both libraries support.

    Args:
        stats_data: frame with at least ``run`` and ``kmer`` columns.

    Returns:
        ``stats_data`` inner-merged on ``run`` with two new columns:
        ``total_kmers`` (sum of ``kmer`` per run) and ``kmer_perc``
        (``kmer / total_kmers * 100`` rounded to 2 decimals).
    """
    per_run_totals = (
        stats_data.groupby('run')
        .agg({'kmer': 'sum'})
        .reset_index()
        .rename(columns={'kmer': 'total_kmers'})
    )
    merged = stats_data.merge(per_run_totals[['run', 'total_kmers']], on='run')
    share = merged['kmer'] / merged['total_kmers'] * 100
    merged['kmer_perc'] = share.round(2)
    # Note, existing table contains 15,552,449/85,551,530 (20%) that have kmer_perc == 0.0
    return merged


In [5]:
def get_stats_tax_labels():
    """Return every row of ``public.sra_stat_group``, ordered by ascending ``taxid``."""
    return fetch_all("""
        SELECT * FROM public.sra_stat_group
        ORDER BY taxid ASC
    """)

# Lookup table used to attach ``tax_label`` to STAT rows by ``taxid``.
# NOTE(review): the query uses SELECT *, so these names assume the table's
# column order is (taxid, tax_rank, tax_name, tax_label) — confirm.
stats_tax_labels = pd.DataFrame(
    get_stats_tax_labels(),
    columns=['taxid', 'tax_rank', 'tax_name', 'tax_label'],
)

In [6]:
def enrich_stats_data(stats_data, tax_labels=None):
    """Add per-run k-mer totals/percentages and taxon labels to STAT rows.

    Args:
        stats_data: pandas or dask frame with ``run``, ``kmer`` and ``taxid`` columns.
        tax_labels: frame with ``taxid`` and ``tax_label`` columns. Defaults to
            the notebook-global ``stats_tax_labels`` loaded from ``sra_stat_group``.

    Returns:
        The ``compute_kmer_perc`` output left-merged with ``tax_label`` on
        ``taxid``. Rows whose taxid has no label keep a missing ``tax_label``.
    """
    if tax_labels is None:
        # Resolved at call time so existing callers keep using the global.
        tax_labels = stats_tax_labels
    stats_data_totals = compute_kmer_perc(stats_data)
    return stats_data_totals.merge(
        tax_labels[['taxid', 'tax_label']],
        on='taxid',
        how='left',
    )


### Fetch all missing runs from BigQuery

In [None]:
# Create the BigQuery dataset and the table that holds the existing run ids.
# exists_ok=True makes the cell safe to re-run: an existing dataset/table is
# reused instead of raising a Conflict error.
dataset_id = "{}.stat_accessions".format(client.project)
dataset = bigquery.Dataset(dataset_id)
dataset.location = "US"
dataset = client.create_dataset(dataset, timeout=30, exists_ok=True)
print("Created dataset {}.{}".format(client.project, dataset.dataset_id))

tables = client.list_tables(project_id + ".stat_accessions")
for table in tables:
    print("{}.{}.{}".format(table.project, table.dataset_id, table.table_id))

schema = [
    bigquery.SchemaField("acc", "STRING", mode="REQUIRED"),
]
table_id = project_id + ".stat_accessions.existing_accessions"
table = bigquery.Table(table_id, schema=schema)
table = client.create_table(table, exists_ok=True)

Created dataset rnalab-393418.stat_accessions


In [8]:
# Export the run ids already present in sra_stat to a one-column CSV ('acc').
# The CSV is then uploaded manually to Google Cloud Storage
# (gs://rnalab/stat_tax_host.csv) before the BigQuery load job runs.
existing_stat_runs = get_existing_stat_runs()
df = pd.DataFrame.from_records(existing_stat_runs, columns=['acc'])
df.to_csv('stat_tax_host.csv', index=False)

NameError: name 'existing_stat_runs' is not defined

In [21]:
# Load the uploaded CSV of existing run ids into the BigQuery table.
# Load jobs append by default; WRITE_TRUNCATE replaces the table contents so
# re-running this cell does not add a second copy of every accession.
job_config = bigquery.LoadJobConfig(
    schema=[
        bigquery.SchemaField("acc", "STRING", mode="REQUIRED"),
    ],
    skip_leading_rows=1,  # skip the header row written by pandas
    source_format=bigquery.SourceFormat.CSV,
    write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
)
uri = "gs://rnalab/stat_tax_host.csv"
load_job = client.load_table_from_uri(uri, table_id, job_config=job_config)
load_job.result()  # wait for the load job to finish
destination_table = client.get_table(table_id)
print("Loaded {} rows.".format(destination_table.num_rows))

Loaded 12487486 rows.


In [9]:
# Pull order-level STAT hits (total_count >= 100) for runs that are not yet in
# sra_stat, assign new sequential row_ids, and spill them to CSV batches.
#
# The anti-join reads existing_accessions directly: re-joining it with the
# large tax_analysis table (as before) returns the same accessions while
# scanning tax_analysis a second time. existing_accessions.acc is REQUIRED
# (no NULLs), so NOT IN behaves as a plain anti-join.
query = """
    SELECT acc, tax_id, rank, name, total_count
    FROM `nih-sra-datastore.sra_tax_analysis_tool.tax_analysis`
    WHERE rank = 'order'
    AND total_count >= 100
    AND acc NOT IN (
        SELECT acc FROM `rnalab-393418.stat_accessions.existing_accessions`
    )
    ORDER BY acc
"""

query_job = client.query(query)

STAT_DOWNLOAD_DIR = 'stat_download'
STAT_BATCH_SIZE = 1_000_000  # rows held in memory before flushing to disk

cur_row_id = get_max_row_id() + 1
stats_data = []
batch = 0


def write_stat_data_to_csv(stats_data, batch):
    """Write one batch of STAT row dicts to ``stat_download/stat_missing_<batch>.csv``.

    Empty batches are skipped: they would produce a CSV without the STAT
    columns, which breaks reading all batches together with dask.
    """
    if not stats_data:
        return
    os.makedirs(STAT_DOWNLOAD_DIR, exist_ok=True)
    filename = f"{STAT_DOWNLOAD_DIR}/stat_missing_{batch}.csv"
    pd.DataFrame(stats_data).to_csv(filename, index=False)


for row in query_job:
    stats_data.append({
        'row_id': cur_row_id,
        'run': row['acc'],
        'taxid': row['tax_id'],
        'rank': row['rank'],
        'name': row['name'],
        'kmer': row['total_count'],
    })
    cur_row_id += 1
    if len(stats_data) >= STAT_BATCH_SIZE:
        print(batch)
        write_stat_data_to_csv(stats_data, batch)
        batch += 1
        stats_data = []
        gc.collect()  # release the flushed batch before accumulating the next one

# Flush the final partial batch (a no-op when nothing is left over).
write_stat_data_to_csv(stats_data, batch)

### Update sra_stat table with all missing runs

In [8]:
# Load every downloaded STAT batch into a single dask frame indexed by run.
filenames = glob.glob(os.path.join('stat_download', '*.csv'))
ddf_stat = dd.read_csv(filenames).set_index('run')

[########################################] | 100% Completed | 126.50 s


In [9]:
# Add per-run k-mer totals/percentages and tax labels. Reuse enrich_stats_data
# so this path cannot drift from the helper's merge logic.
ddf_stat = enrich_stats_data(ddf_stat)
# persist() materializes the enriched frame so later steps reuse it.
# NOTE(review): on the full download this was slow enough to be interrupted;
# consider dropping persist() if memory is tight.
ddf_stat = ddf_stat.persist()

[####                                    ] | 10% Completed | 25.14 ss


KeyboardInterrupt: 

In [11]:
# One CSV per dask partition; '*' is replaced by a per-partition sequence number.
enriched_csv_pattern = 'stat_enriched/stat_missing_enriched_*.csv'
ddf_stat.to_csv(enriched_csv_pattern, index=False)

[########################################] | 100% Completed | 349.96 s


['/Users/lukepereira/workspace/rna-life/virus-host-graph-db/jobs/etl/notebooks/tmp_enriched/stat_missing_enriched_00.csv',
 '/Users/lukepereira/workspace/rna-life/virus-host-graph-db/jobs/etl/notebooks/tmp_enriched/stat_missing_enriched_01.csv',
 '/Users/lukepereira/workspace/rna-life/virus-host-graph-db/jobs/etl/notebooks/tmp_enriched/stat_missing_enriched_02.csv',
 '/Users/lukepereira/workspace/rna-life/virus-host-graph-db/jobs/etl/notebooks/tmp_enriched/stat_missing_enriched_03.csv',
 '/Users/lukepereira/workspace/rna-life/virus-host-graph-db/jobs/etl/notebooks/tmp_enriched/stat_missing_enriched_04.csv',
 '/Users/lukepereira/workspace/rna-life/virus-host-graph-db/jobs/etl/notebooks/tmp_enriched/stat_missing_enriched_05.csv',
 '/Users/lukepereira/workspace/rna-life/virus-host-graph-db/jobs/etl/notebooks/tmp_enriched/stat_missing_enriched_06.csv',
 '/Users/lukepereira/workspace/rna-life/virus-host-graph-db/jobs/etl/notebooks/tmp_enriched/stat_missing_enriched_07.csv',
 '/Users/lukeper

In [7]:
def copy_from_file(conn, filename):
    """Bulk-load one enriched STAT CSV into ``sra_stat`` with ``COPY ... FROM STDIN``.

    The CSV must have a header row and its columns in the order listed in the
    COPY statement. Each file is committed on its own. On failure the
    transaction is rolled back so the connection can be reused for the next file.

    Args:
        conn: open psycopg2 connection with write access.
        filename: path of the CSV to load.

    Returns:
        ``None`` on success, otherwise the exception that aborted the COPY.
    """
    copy_sql = """
        COPY sra_stat (run,row_id,taxid,rank,name,kmer,total_kmers,kmer_perc,tax_label)
        FROM stdin WITH CSV HEADER
        DELIMITER as ','
    """
    # Open the file before creating the cursor, and let both context managers
    # close their resources on every path (including a failing open()).
    with open(filename, 'r') as f, conn.cursor() as cursor:
        try:
            cursor.copy_expert(sql=copy_sql, file=f)
            conn.commit()
        except Exception as error:  # includes psycopg2.DatabaseError
            # NOTE(review): "error in .read() call" means reading the file in
            # Python raised (e.g. a decode error) — inspect that file directly.
            print("Error: %s" % error)
            conn.rollback()
            return error
    print("copy_from_file() done")
    return None

In [8]:
# Load every enriched CSV into sra_stat, one transaction per file.
# Failures are collected together with their filenames, so failed files can be
# inspected and re-loaded without restarting the whole loop.
errs = []
failed_files = []
conn = get_write_connection()

filelist = glob.glob(os.path.join('stat_enriched', '*.csv'))

try:
    for f in sorted(filelist):
        if not os.path.isfile(f):
            continue
        print(f)
        err = copy_from_file(conn, f)
        if err is not None:
            errs.append(err)
            failed_files.append(f)
finally:
    # The write connection is not needed after this cell.
    conn.close()

print(errs)
print("Failed files:", failed_files)

tmp_enriched/stat_missing_enriched_00.csv
copy_from_file() done
tmp_enriched/stat_missing_enriched_01.csv
Error: COPY from stdin failed: error in .read() call
CONTEXT:  COPY sra_stat, line 774628

tmp_enriched/stat_missing_enriched_02.csv
copy_from_file() done
tmp_enriched/stat_missing_enriched_03.csv
copy_from_file() done
tmp_enriched/stat_missing_enriched_04.csv
copy_from_file() done
tmp_enriched/stat_missing_enriched_05.csv
copy_from_file() done
tmp_enriched/stat_missing_enriched_06.csv
copy_from_file() done
tmp_enriched/stat_missing_enriched_07.csv
copy_from_file() done
tmp_enriched/stat_missing_enriched_08.csv
copy_from_file() done
tmp_enriched/stat_missing_enriched_09.csv
copy_from_file() done
tmp_enriched/stat_missing_enriched_10.csv
copy_from_file() done
tmp_enriched/stat_missing_enriched_11.csv
copy_from_file() done
tmp_enriched/stat_missing_enriched_12.csv
copy_from_file() done
tmp_enriched/stat_missing_enriched_13.csv
copy_from_file() done
tmp_enriched/stat_missing_enriched_

In [None]:
# Clean up STAT download files and enriched csv files.
# NOTE: run this only after the Neo4j edge step, which re-reads stat_enriched/.
# os.rmdir only removes empty directories, so rmtree is needed to delete the CSVs too.
import shutil

for tmp_dir in ('stat_download', 'stat_enriched'):
    if os.path.isdir(tmp_dir):
        shutil.rmtree(tmp_dir)

### Update Neo4j with STAT HAS_HOST_METADATA edges

In [10]:
# Re-load the enriched STAT rows (the same files COPY'd into sra_stat) for Neo4j.
enriched_dir = 'stat_enriched'
filelist = glob.glob(os.path.join(enriched_dir, '*.csv'))
ddf_stat = dd.read_csv(filelist)

In [11]:
# Distinct run accessions and taxids across all enriched rows
# (each .compute() triggers its own pass over the CSVs).
unique_runs = ddf_stat['run'].unique().compute()
unique_taxons = ddf_stat['taxid'].unique().compute()

[########################################] | 100% Completed | 38.45 s
[########################################] | 100% Completed | 16.49 s


In [12]:
# Number of distinct runs, then distinct taxids, about to be checked in Neo4j.
for distinct_values in (unique_runs, unique_taxons):
    print(len(distinct_values))

9251170
812


In [13]:
from datasources.neo4j import get_connection

conn = get_connection()

# Ids sent per Cypher query. Sending every run id (~9M) as one list parameter
# builds a single enormous query; chunking keeps each request small.
SANITY_CHECK_CHUNK_SIZE = 50_000


def sanity_check_sras(run_ids):
    """Query which of *run_ids* exist as SRA nodes.

    Returns the ``conn.query`` result; its first record's ``run_ids`` field
    holds the distinct matching runIds.
    """
    query = '''
        MATCH (s:SRA)
        WHERE s.runId in $run_ids
        RETURN COLLECT(DISTINCT s.runId) as run_ids
    '''
    # list() so pandas/numpy containers are sent as a plain list parameter.
    return conn.query(query, parameters={'run_ids': list(run_ids)})


def sanity_check_taxons(tax_ids):
    """Query which of *tax_ids* exist as Taxon nodes.

    Returns the ``conn.query`` result; its first record's ``tax_ids`` field
    holds the distinct matching taxIds.
    """
    query = '''
        MATCH (t:Taxon)
        WHERE t.taxId in $tax_ids
        RETURN Collect(DISTINCT t.taxId) as tax_ids
    '''
    return conn.query(query, parameters={'tax_ids': list(tax_ids)})


def _find_missing(check_fn, ids, result_key, chunk_size=SANITY_CHECK_CHUNK_SIZE):
    """Return the ids that *check_fn* does not report as present, querying in chunks."""
    ids = list(ids)
    found = set()
    for start in range(0, len(ids), chunk_size):
        out = check_fn(ids[start:start + chunk_size])
        found.update(out[0][result_key])
    return set(ids) - found


missing_runs = _find_missing(sanity_check_sras, unique_runs, 'run_ids')
print("Missing runs", missing_runs)

unique_taxons = [str(x) for x in unique_taxons]
missing_taxons = _find_missing(sanity_check_taxons, unique_taxons, 'tax_ids')
print("Missing taxons", missing_taxons)

# 28883 was merged to 2731619
# 206350 was merged to 32003
# 1212 was merged to 1213
# 40677 was merged to 6132
# 40677 was merged to 6132

KeyboardInterrupt: 

In [14]:
#22933
def add_sra_stat_taxon_edges(rows):
    """Upsert ``HAS_HOST_STAT`` edges from SRA run nodes to STAT Taxon nodes.

    Args:
        rows: STAT rows handed to ``graph_queries.batch_insert_data``, with
            ``run``, ``taxid``, ``kmer``, ``total_kmers`` and ``kmer_perc``
            (a 0-100 percentage) fields.

    Rows whose k-mer fraction rounds to 0 at 4 decimals are dropped before
    any node lookup. Edge properties set:
      - percentIdentity: kmer_perc / 100, rounded to 4 decimals
      - percentIdentityFull: kmer / SRA spots, rounded to 4 decimals
        (0.0 when spots is not positive)
      - kmer, totalKmers, totalSpots: copied from the row / SRA node
    """
    # toFloat() is applied before dividing so an integer kmer_perc is not
    # truncated to 0 by integer division.
    query = '''
            UNWIND $rows as row
            WITH row, round(toFloat(row.kmer_perc) / 100, 4) AS kmerFraction
            WHERE kmerFraction > 0
            MATCH (s:SRA {runId: toString(row.run)})
            MATCH (t:Taxon {taxId: toString(row.taxid)})
            MERGE (s)-[r:HAS_HOST_STAT]->(t)
            SET r += {
                percentIdentity: kmerFraction,
                percentIdentityFull: CASE WHEN s.spots > 0
                    THEN round(toFloat(row.kmer) / toFloat(s.spots), 4)
                    ELSE 0.0 END,
                kmer: row.kmer,
                totalKmers: row.total_kmers,
                totalSpots: s.spots
            }
            '''
    return graph_queries.batch_insert_data(query, rows)

add_sra_stat_taxon_edges(ddf_stat)

[########################################] | 100% Completed | 1.33 ss
[########################################] | 100% Completed | 1.31 ss
[########################################] | 100% Completed | 1.34 ss
[########################################] | 100% Completed | 1.19 ss
[########################################] | 100% Completed | 1.72 ss
[########################################] | 100% Completed | 1.20 ss
[########################################] | 100% Completed | 1.18 ss
[########################################] | 100% Completed | 1.17 ss
[########################################] | 100% Completed | 1.23 ss
[########################################] | 100% Completed | 1.22 ss
[########################################] | 100% Completed | 1.23 ss
[########################################] | 100% Completed | 1.06 ss
[########################################] | 100% Completed | 1.15 ss
[########################################] | 100% Completed | 1.15 ss


In [10]:
from datasources.neo4j import get_connection

conn = get_connection()

def debug_missing_spots(run_ids):
    """Query which of *run_ids* are SRA nodes with ``spots = 0``.

    Returns the ``conn.query`` result; its first record's ``run_ids`` field
    holds the distinct matching runIds.
    """
    query = '''
            MATCH (s:SRA)
            WHERE s.runId in $run_ids
            AND s.spots = 0
            RETURN Collect(DISTINCT s.runId) as run_ids
            '''
    # list() so pandas/numpy containers are sent as a plain list parameter.
    return conn.query(query, parameters={'run_ids': list(run_ids)})


# Compute the distinct runs once. ddf_stat may be a dask frame, whose unique()
# is lazy and must be materialized before being sent to Neo4j.
stat_run_ids = ddf_stat.run.unique()
if hasattr(stat_run_ids, 'compute'):
    stat_run_ids = stat_run_ids.compute()
stat_run_ids = list(stat_run_ids)

out = debug_missing_spots(stat_run_ids)
print(len(out[0]['run_ids']))
print(len(stat_run_ids))
print(len(ddf_stat.run.unique()))

36
30908


In [7]:
# Temp fix in graph to add missing spots
# Long-term fix involves updating sra tables to include missing spots
# NOTE(review): entries with 0 set spots (and spotsWithMates) to 0; confirm
# that overwriting spotsWithMates for those runs is intended.
missing_spots = {
      'ERR3415775': 27250560,
      'ERR4352771': 0,
      'ERR1994964': 23972947,
      'ERR5384467': 8430905,
      'ERR1726732': 0,
      'ERR3415759': 27860746,
      'ERR3274949': 865925712,
      'ERR538188': 0,
      'ERR2003549': 23972947,
      'ERR3415758': 41826118,
      'ERR3415760': 31480217,
      'ERR1726762': 0,
      'ERR1726702': 0,
      'ERR1994972': 23231253,
      'ERR1726688': 0,
      'ERR3415762': 20223464,
      'ERR2003534': 27306569,
      'ERR3978063': 60335336,
      'ERR3806953': 0,
      'ERR3415774': 35643841,
      'ERR2003542': 36240950,
      'ERR3415773': 39138550,
      'ERR538183': 0,
      'ERR3415761': 24380142,
      'ERR2003547': 23729256,
      'ERR1726916': 0,
      'ERR1726629': 0,
      'ERR1726949': 0,
      'ERR2003527': 35611500,
      'ERR1994960': 36240950,
      'ERR1726891': 0,
      'ERR1726944': 0,
      'ERR1726740': 0,
      'ERR1726724': 0,
      'ERR3806950': 0,
      'ERR4352608': 0,
}

# The Cypher below reads row.run, so the run-id column must be named 'run'.
# A column named 'runId' leaves row.run null and the MATCH updates nothing.
missing_spots_df = pd.DataFrame(list(missing_spots.items()), columns=['run', 'spots'])

missing_spots_ddf = dd.from_pandas(missing_spots_df, chunksize=1000)


def add_missing_spots_to_sras(rows):
    """Set ``spots`` and ``spotsWithMates`` on SRA nodes from rows with ``run`` and ``spots``."""
    query = '''
            UNWIND $rows as row
            MATCH (s:SRA)
            WHERE s.runId = row.run
            SET s += {
                spots: toInteger(row.spots),
                spotsWithMates: toInteger(row.spots)
            }
            '''
    return graph_queries.batch_insert_data(query, rows)

add_missing_spots_to_sras(missing_spots_ddf)

[[]]