In [8]:
import pandas as pd
import re
import os
import logging
logging.basicConfig(level=logging.INFO)
import json
import requests

In [9]:
def get_hgnc_complete_list(symbol_json_file='./hgnc_complete_set_2020-10-01.json',
                           url="https://ftp.ebi.ac.uk/pub/databases/genenames/hgnc/archive/quarterly/json/hgnc_complete_set_2020-10-01.json",
                           timeout=60):
  '''
  Return the list of HGNC gene symbols, using a local JSON cache.

  If `symbol_json_file` exists it is read as a JSON list of symbol strings.
  Otherwise the HGNC complete-set JSON is downloaded from `url`, the `symbol`
  field of every entry in `response.docs` is extracted, and that list
  (symbols only, not the full HGNC records) is written to `symbol_json_file`
  so later runs skip the download.

  Parameters
  ----------
  symbol_json_file : str
      Path of the local cache (a JSON list of symbol strings).
  url : str
      HGNC complete-set JSON to download when the cache is missing.
  timeout : float
      Seconds to wait on the HTTP request before giving up.

  Returns
  -------
  list of str
      Gene symbols in the order they appear in the HGNC file.

  Raises
  ------
  requests.HTTPError
      If the download returns a non-success status.
  requests.Timeout
      If the server does not respond within `timeout` seconds.
  '''
  if os.path.exists(symbol_json_file):
    logging.info('Reading HGNC complete list from local file')
    with open(symbol_json_file, 'r', encoding='utf-8') as f:
      symbol_list = json.load(f)
  else:
    logging.info('Downloading HGNC complete list')
    # Without a timeout, requests can block indefinitely on a stalled connection.
    response = requests.get(url, timeout=timeout)
    # Fail loudly on HTTP errors rather than trying to parse an error page as HGNC data.
    response.raise_for_status()
    gene_list_json = response.json()
    symbol_list = [doc['symbol'] for doc in gene_list_json['response']['docs']]
    logging.info('HGNC complete list downloaded')
    with open(symbol_json_file, 'w', encoding='utf-8') as f:
      json.dump(symbol_list, f)
  logging.info('length of HGNC complete list: {}'.format(len(symbol_list)))
  return symbol_list

def get_gene_list(gpt_response, hgnc_complete_list, pattern=r'[\s,.!?:;()\[\]{}"\'`*/|]+'):
  '''
  Return the set of HGNC gene symbols mentioned in a free-text GPT response.

  The response is split into tokens on runs of whitespace and common
  punctuation / markdown characters (commas, periods, colons, semicolons,
  brackets, quotes, backticks, asterisks, slashes, pipes). Tokens shorter
  than two characters are discarded. The remaining tokens are matched
  exactly (case-sensitive) against the HGNC symbols.

  Hyphens are deliberately NOT delimiters, because some HGNC symbols contain
  them (e.g. HLA-A).

  Parameters
  ----------
  gpt_response : str or None
      Raw response text. None or "" yields an empty set.
  hgnc_complete_list : iterable of str
      Known gene symbols. Pass a set/frozenset to avoid rebuilding the lookup
      set on every call.
  pattern : str
      Regular expression describing token delimiters.

  Returns
  -------
  set of str
      Symbols that occur as whole tokens in the response.
  '''
  logging.debug('gpt_response: {}'.format(gpt_response))
  if not gpt_response:
    return set()
  # Reuse the caller's set when possible; building a set from ~44k symbols per call is wasteful.
  if isinstance(hgnc_complete_list, (set, frozenset)):
    symbols = hgnc_complete_list
  else:
    symbols = set(hgnc_complete_list)
  # Strip first, then filter, so the length check applies to the cleaned token.
  tokens = [token.strip() for token in re.split(pattern, gpt_response)]
  tokens = [token for token in tokens if len(token) > 1]
  logging.debug('tokens: {}'.format(tokens))
  overlapped_genes = set(tokens) & symbols
  logging.debug('overlapped_genes: {}'.format(overlapped_genes))
  return overlapped_genes

def get_gpt_response(file, encoding='utf-8'):
  '''
  Read and return the full text of a saved GPT response file.

  Parameters
  ----------
  file : str or path-like
      Path to the response file.
  encoding : str
      Encoding used to decode the file. This is explicit so the result does
      not depend on the platform's locale default (which is not UTF-8 on
      some systems).

  Returns
  -------
  str
      The file contents.
  '''
  with open(file, 'r', encoding=encoding) as f:
    return f.read()
  
def update_count_dict(overlapped_genes, count_dict):
  '''
  Add one to `count_dict[gene]` for every item in `overlapped_genes`.

  Genes not yet present start at 1. Every occurrence is counted, so a gene
  repeated in `overlapped_genes` is incremented once per repeat. The dict is
  modified in place, and the same object is returned for convenience.
  '''
  for symbol in overlapped_genes:
    count_dict[symbol] = count_dict.get(symbol, 0) + 1
  return count_dict

In [14]:
# Tally, over every saved GPT response in `output_dir`:
#   - how often each HGNC symbol is mentioned in top-10 responses,
#   - how often it is mentioned in all other (top-50) responses,
#   - how often each gene is the true gene encoded in the file name.
# Expected file name layout:
#   sample_id__true_gene__top_n__prompt__gpt_version__input_type__iteration.gpt.response

output_dir = './Experiment_003subset'
hgnc_complete_list = get_hgnc_complete_list()
# Build the symbol lookup set once, instead of once per response file.
hgnc_symbols = set(hgnc_complete_list)
count_top10_dict = {}
count_top50_dict = {}
count_true_gene_dict = {}
# Dots escaped so they match literally; no end anchor, so suffixes after ".gpt.response" are still accepted.
response_name_re = re.compile(r'(.+?)\.gpt\.response')
for file in sorted(os.listdir(output_dir)):
    path = os.path.join(output_dir, file)
    m = response_name_re.match(file)
    # Skip stray entries (e.g. checkpoint dirs, hidden files) instead of crashing on m.group.
    if m is None or not os.path.isfile(path):
        logging.warning('Skipping non-response entry: %s', file)
        continue
    fields = m.group(1).split('__')
    if len(fields) != 7:
        logging.warning('Skipping %s: expected 7 "__"-separated fields, got %d', file, len(fields))
        continue
    sample_id, true_gene, top_n, prompt, gpt_version, input_type, iteration = fields
    if top_n not in ('10', '50'):
        logging.warning('Unexpected top_n %r in %s; counting it with top50', top_n, file)
    count_true_gene_dict = update_count_dict([true_gene], count_true_gene_dict)
    gpt_response = get_gpt_response(path)
    overlapped_genes = get_gene_list(gpt_response, hgnc_symbols)
    # Use the parsed top_n field rather than a substring search over the whole name.
    if top_n == '10':
        count_top10_dict = update_count_dict(overlapped_genes, count_top10_dict)
    else:
        count_top50_dict = update_count_dict(overlapped_genes, count_top50_dict)

INFO:root:Reading HGNC complete list from local file
INFO:root:length of HGNC complete list: 43984


In [19]:
def _counts_to_frame(counts, count_column):
    '''Turn a {gene: count} dict into a DataFrame with columns ['gene', count_column].'''
    return (
        pd.DataFrame.from_dict(counts, orient='index', columns=[count_column])
        .reset_index()
        .rename(columns={'index': 'gene'})
    )

top10_df = _counts_to_frame(count_top10_dict, 'top10_count')
top50_df = _counts_to_frame(count_top50_dict, 'top50_count')
true_df = _counts_to_frame(count_true_gene_dict, 'true_gene_count')
count_columns = ['top50_count', 'top10_count', 'true_gene_count']
# The outer merges put NaN where a gene is absent from a table, which upcasts
# the counts to float. Fill with 0 and cast back so counts stay integers (3, not 3.0).
merged_df = (
    top50_df
    .merge(top10_df, on='gene', how='outer')
    .merge(true_df, on='gene', how='outer')
    .fillna(0)
    .astype({column: int for column in count_columns})
)

In [22]:
# Write the merged counts to CSV, most frequently predicted top-50 genes first.
ranked_df = merged_df.sort_values('top50_count', ascending=False)
ranked_df.to_csv('count_of_gpt_prediction.csv', index=False)