In [79]:
import torch
import requests
import pandas as pd

In [80]:
from tqdm import tqdm
tqdm.pandas()

In [134]:
model = torch.load('/opt/gpudata/aartiv/qag/query_intent_model.pt')
num_params = sum(p.numel() for p in model.parameters())
param_size = sum(p.element_size() * p.numel() for p in model.parameters()) / (1024 ** 2)

print(f"Number of parameters: {num_params:,}")
print(f"Approx size in memory: {param_size:.2f} MB")

  model = torch.load('/opt/gpudata/aartiv/qag/query_intent_model.pt')


Number of parameters: 109,486,085
Approx size in memory: 417.66 MB


In [135]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [136]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [140]:
def get_intent(query):    
  # print('testing query: {}'.format(query))
  inputs = tokenizer(query, return_tensors="pt", truncation=True, padding=True)
  inputs = {k:v.to(device) for k, v in inputs.items()}
  # pass tokenized input through the model
  outputs = model(**inputs)
  # print('output logits {}'.format(outputs))
  #outputs are logits, need to apply softmax to convert to probs
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)
  # print('probs: {}'.format(probs))
  predicted_label = torch.argmax(probs, dim=1).item()
  return predicted_label


### create external synthetic dataset

In [57]:
curl_cmd = "https://api.gdc.cancer.gov/ssms?expand=occurrence.case.project&from=1001&size=1000&pretty=true"

In [58]:
response = requests.get(curl_cmd)

In [59]:
response

<Response [200]>

In [60]:
result = response.json()

In [61]:
result['data']['hits'][:1]

[{'id': 'fb33682a-d2f0-554f-bc50-9b73b45de914',
  'start_position': 161391148,
  'gene_aa_change': ['PSMD14 I205M'],
  'cosmic_id': ['COSM257825'],
  'chromosome': 'chr2',
  'ssm_id': 'fb33682a-d2f0-554f-bc50-9b73b45de914',
  'occurrence': [{'case': {'project': {'primary_site': ['Unknown',
       'Connective, subcutaneous and other soft tissues',
       'Rectosigmoid junction',
       'Rectum',
       'Colon'],
      'disease_type': ['Cystic, Mucinous and Serous Neoplasms',
       'Adenomas and Adenocarcinomas'],
      'project_id': 'TCGA-READ',
      'name': 'Rectum Adenocarcinoma'}}}],
  'end_position': 161391148,
  'reference_allele': 'T',
  'ncbi_build': 'GRCh38',
  'mutation_subtype': 'Single base substitution',
  'mutation_type': 'Simple Somatic Mutation',
  'genomic_dna_change': 'chr2:g.161391148T>G',
  'tumor_allele': 'G'}]

In [62]:
print(len(result['data']['hits']))

1000


In [63]:
def write_genes_mutations_cancers_output(result):
  with open('/opt/gpudata/aartiv/qag/data_for_intent_model/genes_mutations_cancers_eval.tsv', 'w') as out_file:
    out_file.write('cosmic_id\tgene\tmutation\tcancer\n')
    for info in result['data']['hits']:
      try:
        cosmic_ids = info['cosmic_id']
        aa_changes = info['gene_aa_change']
        for occurrence in info['occurrence']:
          projects = occurrence['case']['project']['name']
          for c_id in cosmic_ids:
            for mut in aa_changes:
              gene, m = mut.split(' ')
              out_file.write('{}\t{}\t{}\t{}\n'.format(c_id, gene, m, projects ))
      except:
        pass

In [64]:
write_genes_mutations_cancers_output(result)

### define templates

In [17]:
template1_questions = [
    'What percentage of [cancer] cases have the [gene_x] [mutation_x] mutation?',
    'How widespread are [gene_x] [mutation_x] mutations in [cancer]?',
    'How often is the [gene_x] [mutation_x] found in [cancer]?',
    'What is the occurrence rate of [gene_x] [mutation_x] in [cancer] patients?',
    'How prevalent is the [gene_x] [mutation_x] mutation in cases of [cancer]?',
    'What is the incidence of the [gene_x] [mutation_x] mutation in [cancer]?',
    'What is the frequency of [gene_x] [mutation_x] mutation in [cancer]?',
    'How common is the [gene_x] [mutation_x] mutation in [cancer]?',
    'What fraction of [cancer] patients have the [gene_x] [mutation_x] mutation?',
    'How widespread is the [gene_x] [mutation_x] mutation in [cancer]?',
    'What is the rate of occurrence of the [gene_x] [mutation_x] mutation in [cancer]?',
    'How frequently is the [gene_x] [mutation_x] mutation detected in [cancer]?',
    'what percentage of cases have mutations in both [gene_x] [mutation_x] and [gene_y] [mutation_y] mutations in [cancer]?',
    'what proportion of [cancer] patients have mutations in both [gene_x] [mutation_x] and [gene_y] [mutation_y]?',
    'What proportion of [cancer] patients exhibit the [gene_x] [mutation_x] mutation?',
    'Can you tell me the rate of [gene_x] [mutation_x] mutation in [cancer]?',
    'What is the percentage of [cancer] tumors with the [gene_x] [mutation_x] mutation?',
    'What is the prevalence rate of [gene_x] [mutation_x] mutation in [cancer]?',
    'How many [cancer] cases have the [gene_x] [mutation_x] mutation?',
    'What is the prevalence rate of the [gene_x] [mutation_x] mutation in [cancer]?',
    'What is the distribution of the [gene_x] [mutation_x] mutation in [cancer]?',
    'How high is the prevalence of the [gene_x] [mutation_x] mutation in [cancer]?',
    'What is the presence rate of the [gene_x] [mutation_x] mutation in [cancer]?',
    'How pervasively is the [gene_x] [mutation_x] mutation present in [cancer]?',
    'What is the occurrence of [gene_x] [mutation_x] mutation in [cancer]?',
    'What is the [gene_x] [mutation_x] mutation rate in [cancer]?',
    'How frequent is the [gene_x] [mutation_x] mutation in [cancer] cases?'
]
template1_answers = ['ssm_frequency'] * len(template1_questions)

In [18]:
# template 2 = 'what is the prevalence of MSI High in cancer x?'
template2_questions = [
  'What is the frequency of microsatellite instability-high in [cancer]?',
  'How common is MSI-H in [cancer] cases?',
  'Can you provide the prevalence of microsatellite instability-high in [cancer] patients?',
  'What percentage of [cancer] patients have MSI-H?',
  'How often is microsatellite instability-high observed in [cancer]?',
  'In [cancer], what is the occurrence rate of MSI-H?',
  'What is the incidence of MSI-H in [cancer]?',
  'How prevalent is microsatellite instability-high in [cancer]?',
  'What proportion of [cancer] cases exhibit MSI-H?',
  'How frequently is MSI High found in [cancer]?',
  'What is the distribution rate of microsatellite instability-high among [cancer] patients?',
  'How widespread is microsatellite instability-high in [cancer]?',
  'Can you tell me how prevalent MSI High is in [cancer]?',
  'What percentage of [cancer] cases are characterized by microsatellite instability-high (MSI-H)?',
  'How common is MSI High status in cases of [cancer]?',
  'What is the rate of microsatellite instability-high in [cancer] diagnoses?',
  'Among [cancer] patients, what is the frequency of msi-h?',
  'How often does msi-h occur in [cancer] patients?',
  'What is the detection rate of msi-h in [cancer]?',
  'How many [cancer] patients exhibit microsatellite instability-high?',
  'What is the percentage occurrence of MSI High in [cancer]?',
  'How frequently do [cancer] patients present with microsatellite instability-high?',
  'What is the proportion of MSI High tumors in [cancer]?',
  'Can you share the incidence rate of msi-high in [cancer]?',
  'What fraction of [cancer] cases involve msi-high?'
]
template2_answers = ['msi_h_frequency'] * len(template2_questions)


In [19]:
# template 3 what is the freq of 1p19q or gene loss in GBM?
template3_questions = [
    'What is the incidence of [gene_x] loss in [cancer]?',
    'What is the frequency of somatic [gene_x] heterozygous deletion in [cancer]?',
    'What is the frequency of somatic [gene_x] homozygous deletion in [cancer]?',
    'How common is [gene_x] and [gene_y] codeletion in [cancer] cases?',
    'What percentage of [cancer] patients have somatic [gene_x] LOH?',
    'How often is [gene_x] LOH observed in [cancer]?',
    'In [cancer], what is the occurrence rate of [gene_x] loss?',
    'What is the prevalence of [gene_x] and [gene_y] loss of heterozygosity in [cancer]?',
    'How prevalent is [gene_x] gain in [cancer] patients?',
    'How prevalent is somatic [gene_x] amplification in [cancer] patients?',
    'What proportion of [cancer] cases exhibit [gene_x] and [gene_y] codeletion?',
    'What is the joint frequency of somatic [gene_x] and [gene_y] homozygous deletions in [cancer]?',
    'What is the frequency of somatic [gene_x] and [gene_y] heterozygous co-deletion in [cancer]?',
    'How frequently is [gene_x] LOH found in [cancer]?',
    'What is the distribution rate of [gene_x] and [gene_y] gain among [cancer] patients?',
    'How widespread is [gene_x] LOH in [cancer]?',
    'Can you tell me how prevalent [gene_x] and [gene_y] loss is in [cancer]?',
    'What percentage of [cancer] are characterized by codeletion of [gene_x] and [gene_y]?',
    'How common is [gene_y] LOH in [cancer] cases?',
    'What is the rate of [gene_x] LOH in [cancer] diagnoses?',
    'Among [cancer] patients, what is the frequency of [gene_y] gain?',
    'How often does [gene_x] loss occur in [cancer] patients?',
    'What is the detection rate of [gene_x] LOH in [cancer]?',
    'How many [cancer] patients exhibit cogains in [gene_x] and [gene_y]?',
    'What is the percentage occurrence of [gene_y] loss of heterozygosity in [cancer]?',
    'How frequently do [cancer] patients present with [gene_y] loss of heterozygosity?',
    'What is the proportion of [gene_y] and [gene_x] codeletion in [cancer]?',
    'Can you share the incidence rate of [gene_x] loss in [cancer]?',
    'What fraction of [cancer] cases involve [gene_x] loss?'
]
template3_answers = ['freq_cnv_loss_or_gain'] * len(template3_questions)


In [20]:
# template4 /analysis/top_cases_counts_by_genes used for cohort comparisons/survival analysis
# number of cases w/ mutations in a gene or genes, and number of cases w/o mutations in genes
# this template refers to any mutation in the gene, no specific ssm

template4_questions = [
    'What is the incidence of [gene_x] or [gene_y] mutations in cases of [cancer] in the genomic data commons?',
    'How frequently are [gene_x] or [gene_y] mutations observed in [cancer]?',
    'What percentage of [cancer] cases exhibit mutations in [gene_x] genes?',
    'What proportion of [cancer] cases have mutated [gene_x] or [gene_y]?',
    'How many patients with [cancer] carry mutations in either [gene_x] or [gene_y]?',
    'What is the prevalence of mutations in [gene_x] or [gene_y] among [cancer] cases?',
    'In cases of [cancer], how many exhibit [gene_x] mutations in the genomic data commons?',
    'Are [gene_x] or [gene_y] mutations present in many [cancer] cases?',
    'What is the rate of [gene_x] mutations in [cancer] tumors?',
    'What is the number of [cancer] cases without [gene_x] and [gene_y] mutations?',
    'How many [cancer] patients lack mutations in [gene_y] in the genomic data commons?',
    'What proportion of [cancer] do not exhibit mutations in [gene_x] and [gene_y]?',
    'How frequently are [cancer] without [gene_x] mutations observed?',
    'What percentage of [cancer] cases do not have [gene_y] mutations?',
    'How rare is it to find [cancer] without mutations in [gene_x] and [gene_y]?',
    'How many [cancer] cases lack [gene_x] and [gene_y] gene mutations?',
    'What fraction of [cancer] cases do not possess mutations in [gene_x] and [gene_y]?',
    'Do a substantial number of [cancer] not have [gene_x] and [gene_y] mutations?',
    'What share of [cancer] cases is without [gene_x] and [gene_y] mutations?'
]
template4_answers = ['top_cases_counts_by_genes'] * len(template4_questions)

In [21]:
# cnv and ssm combination
template5_questions = [
    'what is the frequency of cases with somatic mutations in [gene_x] and amplifications in [gene_y] in [cancer] in the genomic data commons?',
    'what is the frequency of cases with amplifications in [gene_x] and mutations in [gene_y] in [cancer] in the genomic data commons?',
    'How often do cases of [cancer] exhibit mutations in [gene_x] and homozyous deletions in [gene_y]?',
    'What is the prevalence of [gene_x] somatic mutations and [gene_y] heterozygous deletions in [cancer] cases?',
    'Can you tell me the frequency of [gene_x] mutations co-occurring with [gene_y] amplifications in [cancer]?',
    'How common are mutations in [gene_x] and amplifications in [gene_y] among patients with [cancer]?',
    'What percentage of [cancer] cases present with mutations in [gene_x] and somatic homozygous deletions in [gene_y]?',
    'To what extent do [cancer] cases show somatic mutations in [gene_x] along with somatic amplifications in [gene_y]?',
    'How frequently are [gene_x] somatic mutations and somatic [gene_y] amplifications found in [cancer] patients?',
    'What is the incidence rate of mutations in [gene_x] paired with somatic heterozygous deletions in [gene_y] in [cancer]?',
    'What portion of [cancer] cases have both [gene_x] mutations and [gene_y] gains?',
    'How prevalent are simultaneous mutations in [gene_x] and gains in [gene_y] in [cancer]?',
    'Could you provide the rate of [gene_x] mutations and [gene_y] deletions among [cancer] cases?',
    'What is the occurrence of [gene_x] mutations alongside [gene_y] amplifications in [cancer]?',
    'How often are concurrent somatic mutations in [gene_x] and deletions in [gene_y] observed in [cancer]?',
    'What is the proportion of [cancer] cases with both [gene_x] mutations and [gene_y] amplifications?',
    'In cases of [cancer], how common are mutations in [gene_x] and amplifications in [gene_y]?',
    'What is the rate at which [gene_x] somatic mutations and somatic [gene_y] homozygous deletions coincide in [cancer]?',
    'How frequently do [gene_x] mutations and [gene_y] amplifications occur together in [cancer] cases?',
    'What share of [cancer] cases display both somatic [gene_x] mutations and somatic [gene_y] losses?',
    'What is the joint frequency of somatic [gene_x] mutations and [gene_y] losses in [cancer]?',
    'How prevalent is the co-occurrence of [gene_x] mutations and [gene_y] amplifications in [cancer]?',
    'What is the frequency of mutations in [gene_x] coupled with gains in [gene_y] in [cancer]?'
    'How often do mutations in [gene_x] and amplifications in [gene_y] happen in [cancer]?',
    'What is the incidence of [gene_x] mutations along with [gene_y] losses in [cancer]?',
    'How common are both [gene_x] mutations and [gene_y] amplifications within [cancer] cases?',
    'Can you specify the frequency at which [gene_x] mutations and [gene_y] gains are seen in [cancer]?'
]
template5_answers = ['cnv_and_ssm'] * len(template5_questions)



In [22]:
all_template_questions = template1_questions + template2_questions + template3_questions + template4_questions + template5_questions
all_template_answers = template1_answers + template2_answers + template3_answers + template4_answers + template5_answers

In [23]:
len(all_template_questions), len(all_template_answers)

(126, 126)

In [84]:
# Eval Dataset
gdc_data_eval = pd.read_csv('/opt/gpudata/aartiv/qag/data_for_intent_model/genes_mutations_cancers_eval.tsv', sep='\t')


In [85]:
gdc_data_eval.shape

(1456, 4)

In [86]:
gdc_data_eval.head()

Unnamed: 0,cosmic_id,gene,mutation,cancer
0,COSM257825,PSMD14,I205M,Rectum Adenocarcinoma
1,COSM3958761,ACOX1,A127S,Lung Adenocarcinoma
2,COSM3958761,ACOX1,P180=,Lung Adenocarcinoma
3,COSM3991550,NDUFAF5,E143D,Kidney Renal Papillary Cell Carcinoma
4,COSM4075175,CHERP,L124=,Stomach Adenocarcinoma


In [87]:
gdc_data_eval['simple_mutation'] = gdc_data_eval['mutation'].apply(lambda x: x.isalnum())

In [88]:
dataset1 = gdc_data_eval[gdc_data_eval['simple_mutation']].copy()

In [89]:
dataset1.drop_duplicates(['gene', 'mutation'], inplace=True)

In [90]:
dataset1.head()

Unnamed: 0,cosmic_id,gene,mutation,cancer,simple_mutation
0,COSM257825,PSMD14,I205M,Rectum Adenocarcinoma,True
1,COSM3958761,ACOX1,A127S,Lung Adenocarcinoma,True
3,COSM3991550,NDUFAF5,E143D,Kidney Renal Papillary Cell Carcinoma,True
14,COSM1413622,KRTAP13-1,R82S,Colon Adenocarcinoma,True
15,COSM1558784,RPS6KA6,R310S,Lung Adenocarcinoma,True


In [91]:
dataset1.shape

(523, 5)

In [92]:
# total number of question/intent pairs
len(all_template_questions)*dataset1.shape[0]

65898

In [94]:
from numpy import random
def generate_q_a(row, gene_list, mutation_list, template_index):
  # generate question
  gene_y = random.choice(gene_list)
  mutation_y = random.choice(mutation_list)

  try:
    question = all_template_questions[template_index]\
      .replace("[mutation_x]", row['mutation'])\
      .replace("[cancer]", row['cancer']) \
      .replace("[gene_x]", row['gene']) \
      .replace("[gene_y]", gene_y)\
      .replace("[mutation_y]", mutation_y)
    intent = all_template_answers[template_index]

  except Exception as e:
    print('unable to generate question {}'.format(str(e)))

  return question, intent


In [95]:
# for double gene/mutation questions, get the list to randomly sample second one from
gene_list = list(dataset1['gene'])
mutation_list = list(dataset1['mutation'])

In [96]:
questions_intent_df = pd.DataFrame(columns=['gene', 'mutation', 
                                            'cancer', 'questions', 
                                            'intent'])

In [97]:
# generate a question for every row in dataset1
dataset1_copy = dataset1.copy()

for template_index in range(len(all_template_questions)):
  dataset1_copy[['questions', 'intent']] = dataset1_copy.apply(
    lambda row: generate_q_a(row, gene_list, mutation_list, template_index),
    axis=1,result_type='expand'
    )
  questions_intent_df = pd.concat(
    [questions_intent_df, 
     dataset1_copy[['gene', 'mutation', 'cancer', 'questions', 'intent']]
    ]
  )

In [98]:
questions_intent_df.shape

(65898, 5)

In [99]:
questions_intent_df.head()

Unnamed: 0,gene,mutation,cancer,questions,intent
0,PSMD14,I205M,Rectum Adenocarcinoma,What percentage of Rectum Adenocarcinoma cases...,ssm_frequency
1,ACOX1,A127S,Lung Adenocarcinoma,What percentage of Lung Adenocarcinoma cases h...,ssm_frequency
3,NDUFAF5,E143D,Kidney Renal Papillary Cell Carcinoma,What percentage of Kidney Renal Papillary Cell...,ssm_frequency
14,KRTAP13-1,R82S,Colon Adenocarcinoma,What percentage of Colon Adenocarcinoma cases ...,ssm_frequency
15,RPS6KA6,R310S,Lung Adenocarcinoma,What percentage of Lung Adenocarcinoma cases h...,ssm_frequency


In [104]:
0.31*65898

20428.38

In [113]:
from sklearn.model_selection import train_test_split
stratified_sample, _ = train_test_split(
    questions_intent_df,
    test_size=0.69,               # Keep ~30%
    stratify=questions_intent_df['intent'],        # stratify on this column
    random_state=42              # For reproducibility
)

In [114]:
questions_intent_df['intent'].value_counts()

intent
freq_cnv_loss_or_gain        15167
ssm_frequency                14121
cnv_and_ssm                  13598
msi_h_frequency              13075
top_cases_counts_by_genes     9937
Name: count, dtype: int64

In [115]:
questions_intent_df['intent'].value_counts(normalize=True)

intent
freq_cnv_loss_or_gain        0.230159
ssm_frequency                0.214286
cnv_and_ssm                  0.206349
msi_h_frequency              0.198413
top_cases_counts_by_genes    0.150794
Name: proportion, dtype: float64

In [116]:
stratified_sample['intent'].value_counts()

intent
freq_cnv_loss_or_gain        4702
ssm_frequency                4378
cnv_and_ssm                  4215
msi_h_frequency              4053
top_cases_counts_by_genes    3080
Name: count, dtype: int64

In [117]:
stratified_sample['intent'].value_counts(normalize=True)

intent
freq_cnv_loss_or_gain        0.230174
ssm_frequency                0.214314
cnv_and_ssm                  0.206334
msi_h_frequency              0.198404
top_cases_counts_by_genes    0.150773
Name: proportion, dtype: float64

In [118]:
stratified_sample.shape

(20428, 5)

In [129]:
intent_labels = {
    'msi_h_frequency': 1.0,
    'ssm_frequency': 0.0,
    'freq_cnv_loss_or_gain': 2.0,
    'top_cases_counts_by_genes': 3.0,
    'cnv_and_ssm': 4.0
}

In [130]:
def attach_labels(intent):
    return intent_labels[intent]

In [131]:
stratified_sample['true_label'] = stratified_sample['intent'].apply(lambda x: attach_labels(x))

In [132]:
stratified_sample.head()

Unnamed: 0,gene,mutation,cancer,questions,intent,true_label
1292,BRF1,Q517K,Uterine Corpus Endometrial Carcinoma,"In Uterine Corpus Endometrial Carcinoma, what ...",freq_cnv_loss_or_gain,2.0
477,RAD54L2,R706C,Uterine Corpus Endometrial Carcinoma,"In Uterine Corpus Endometrial Carcinoma, what ...",msi_h_frequency,1.0
704,CDH16,P224H,Uterine Corpus Endometrial Carcinoma,Can you tell me the frequency of CDH16 mutatio...,cnv_and_ssm,4.0
293,ADSS1,R79K,Breast Invasive Carcinoma,Can you tell me the frequency of ADSS1 mutatio...,cnv_and_ssm,4.0
1048,SLC4A5,R634H,Bladder Urothelial Carcinoma,How many patients with Bladder Urothelial Carc...,top_cases_counts_by_genes,3.0


In [141]:
stratified_sample['predicted_label'] = stratified_sample['questions'].progress_apply(
    lambda x: get_intent(x)
)

100%|██████████| 20428/20428 [03:29<00:00, 97.73it/s] 


In [142]:
stratified_sample.head()

Unnamed: 0,gene,mutation,cancer,questions,intent,true_label,predicted_label
1292,BRF1,Q517K,Uterine Corpus Endometrial Carcinoma,"In Uterine Corpus Endometrial Carcinoma, what ...",freq_cnv_loss_or_gain,2.0,2
477,RAD54L2,R706C,Uterine Corpus Endometrial Carcinoma,"In Uterine Corpus Endometrial Carcinoma, what ...",msi_h_frequency,1.0,1
704,CDH16,P224H,Uterine Corpus Endometrial Carcinoma,Can you tell me the frequency of CDH16 mutatio...,cnv_and_ssm,4.0,4
293,ADSS1,R79K,Breast Invasive Carcinoma,Can you tell me the frequency of ADSS1 mutatio...,cnv_and_ssm,4.0,4
1048,SLC4A5,R634H,Bladder Urothelial Carcinoma,How many patients with Bladder Urothelial Carc...,top_cases_counts_by_genes,3.0,3


In [143]:
model_predictions = stratified_sample['predicted_label']
true_labels = stratified_sample['true_label']

In [144]:
len(model_predictions), len(true_labels)

(20428, 20428)

In [145]:
from sklearn.metrics import classification_report, confusion_matrix

In [146]:
print(classification_report(true_labels, model_predictions))

              precision    recall  f1-score   support

         0.0       1.00      1.00      1.00      4378
         1.0       1.00      1.00      1.00      4053
         2.0       1.00      1.00      1.00      4702
         3.0       1.00      1.00      1.00      3080
         4.0       1.00      1.00      1.00      4215

    accuracy                           1.00     20428
   macro avg       1.00      1.00      1.00     20428
weighted avg       1.00      1.00      1.00     20428



In [147]:
print(confusion_matrix(true_labels, model_predictions))


[[4378    0    0    0    0]
 [   0 4053    0    0    0]
 [   0    0 4702    0    0]
 [   0    0    0 3080    0]
 [   0    0    0    0 4215]]


In [148]:
# write to CSV
stratified_sample.to_csv('/opt/gpudata/aartiv/qag/data_for_intent_model/BERT_model_eval.csv', index=0)

### validation dataset
- used during training to monitor overfitting

In [149]:
# load validation dataset
validation_dataset = pd.read_csv('/opt/gpudata/aartiv/qag/eval_query_intent/val.csv', index_col=0)

In [122]:
validation_dataset.head()

Unnamed: 0,gene,mutation,cancer,questions,intent,label
127,HTR3B,S206I,Uterine Corpus Endometrial Carcinoma,"In Uterine Corpus Endometrial Carcinoma, what ...",msi_h_frequency,1.0
72,PAK3,E334K,Skin Cutaneous Melanoma,How many Skin Cutaneous Melanoma cases have th...,ssm_frequency,0.0
1348,MAST4,P1991S,Uterine Corpus Endometrial Carcinoma,How widespread is MAST4 LOH in Uterine Corpus ...,freq_cnv_loss_or_gain,2.0
999,PTPRO,E67K,Pancreatic Adenocarcinoma,What percentage of Pancreatic Adenocarcinoma c...,top_cases_counts_by_genes,3.0
1326,NRAP,S808L,Uterine Corpus Endometrial Carcinoma,What fraction of Uterine Corpus Endometrial Ca...,top_cases_counts_by_genes,3.0


In [128]:
validation_dataset[['intent', 'label']].drop_duplicates().set_index(['intent']).to_dict()

{'label': {'msi_h_frequency': 1.0,
  'ssm_frequency': 0.0,
  'freq_cnv_loss_or_gain': 2.0,
  'top_cases_counts_by_genes': 3.0,
  'cnv_and_ssm': 4.0}}

In [None]:
validation_dataset['predicted_label'] = validation_dataset['questions'].progress_apply(
    lambda x: get_intent(x)
)

In [17]:
model_predictions = validation_dataset['predicted_label']
true_labels = validation_dataset['label']

In [18]:
len(model_predictions), len(true_labels)

(19127, 19127)

In [21]:
print(classification_report(true_labels, model_predictions))


              precision    recall  f1-score   support

         0.0       1.00      1.00      1.00      4108
         1.0       1.00      1.00      1.00      3756
         2.0       1.00      1.00      1.00      4442
         3.0       1.00      1.00      1.00      2888
         4.0       1.00      1.00      1.00      3933

    accuracy                           1.00     19127
   macro avg       1.00      1.00      1.00     19127
weighted avg       1.00      1.00      1.00     19127



In [22]:
print(confusion_matrix(true_labels, model_predictions))


[[4108    0    0    0    0]
 [   0 3756    0    0    0]
 [   0    0 4442    0    0]
 [   0    0    0 2888    0]
 [   0    0    0    0 3933]]


In [150]:
train_dataset = pd.read_csv('/opt/gpudata/aartiv/qag/eval_query_intent/train.csv', index_col=0)

In [151]:
train_dataset.head()

Unnamed: 0,gene,mutation,cancer,questions,intent,label
694,RIMS2,R619W,Skin Cutaneous Melanoma,How many Skin Cutaneous Melanoma cases have th...,ssm_frequency,0.0
1090,OR13C3,M147I,Skin Cutaneous Melanoma,Can you provide the prevalence of microsatelli...,msi_h_frequency,1.0
341,TENT5A,S401F,Skin Cutaneous Melanoma,What is the occurrence rate of TENT5A S401F in...,ssm_frequency,0.0
996,PTPRO,E249K,Pancreatic Adenocarcinoma,What percentage of Pancreatic Adenocarcinoma c...,cnv_and_ssm,4.0
201,GRTP1,A110V,Skin Cutaneous Melanoma,What percentage of Skin Cutaneous Melanoma cas...,top_cases_counts_by_genes,3.0


In [None]:
### check for overlap between train_dataset and test dataset and ensure none
overlap = pd.merge(
    train_dataset,
    stratified_sample,
    on=['gene', 'mutation'],
    how='inner'
)


In [153]:
overlap.shape

(0, 11)

In [154]:
overlap = pd.merge(
    validation_dataset,
    stratified_sample,
    on=['gene', 'mutation'],
    how='inner'
)

In [155]:
overlap.shape

(0, 11)