# Train a BERT model to identify user intent

### Install libraries

In [None]:
%pip install requests
%pip install pandas
%pip install numpy

In [None]:
import json
import os
import requests
import pandas as pd
import numpy as np
import random

### Collate gene, mutation and disease combinations from the GDC
- Uses only the first 1,000 SSM hits returned by the GDC API

In [None]:
# get data for random genes and dump to file
curl_cmd = "https://api.gdc.cancer.gov/ssms?expand=occurrence.case.project&size=1000&pretty=true"

In [None]:
# Time out instead of hanging forever if the GDC API stops responding, and fail
# loudly on HTTP errors rather than tripping over an error payload later.
response = requests.get(curl_cmd, timeout=120)
response.raise_for_status()

In [None]:
response

In [None]:
result = response.json()

In [None]:
result['data']['hits'][:1]

In [None]:
print(len(result['data']['hits']))

### Dataset1
- Dump cosmic_id, gene, mutation and cancer information from results to a file

In [None]:
# note: this is just looking at first 1000 mutations/genes/cancers and expanding on that info

def write_genes_mutations_cancers_output(result, out_path='../csvs/genes_mutations_cancers.tsv'):
    """Flatten a GDC ``/ssms`` response into a cosmic_id/gene/mutation/cancer TSV.

    For every hit in ``result['data']['hits']``, one row is written per
    combination of COSMIC id, amino-acid change (``"GENE MUTATION"``) and
    occurrence project name.

    Args:
        result: parsed JSON returned by the GDC ``/ssms`` endpoint.
        out_path: destination TSV path (defaults to the notebook's original
            location).

    Hits without COSMIC ids / amino-acid changes, occurrences without a
    project name, and amino-acid changes that are not exactly
    ``"GENE MUTATION"`` are skipped one at a time, so a single malformed
    entry no longer discards the remaining rows of its hit.
    """
    with open(out_path, 'w', encoding='utf-8') as out_file:
        out_file.write('cosmic_id\tgene\tmutation\tcancer\n')
        for info in result['data']['hits']:
            # missing or null lists simply contribute no rows
            cosmic_ids = info.get('cosmic_id') or []
            aa_changes = info.get('gene_aa_change') or []
            for occurrence in info.get('occurrence') or []:
                try:
                    project = occurrence['case']['project']['name']
                except (KeyError, TypeError):
                    continue
                for c_id in cosmic_ids:
                    for mut in aa_changes:
                        parts = mut.split(' ')
                        if len(parts) != 2:
                            continue
                        gene, m = parts
                        out_file.write('{}\t{}\t{}\t{}\n'.format(c_id, gene, m, project))

In [None]:
write_genes_mutations_cancers_output(result)

### Dataset2 
- Dump driver genes, mutation and cancer info from a paper to a file
- driver gene info is taken from supp data table 1 from the following paper (full driver gene list in cancer_driver_genes.txt)
- https://pmc.ncbi.nlm.nih.gov/articles/PMC10406856/#MOESM2
- short version is fig 2g data, in cancer_driver_genes_short.txt file

In [None]:
# driver gene info is taken from supp data table 1 from the following paper (full driver gene list in cancer_driver_genes.txt)
# https://pmc.ncbi.nlm.nih.gov/articles/PMC10406856/#MOESM2
# short version is fig 2g data, in cancer_driver_genes_short.txt file
# The default below is a machine-specific absolute path; set the
# DRIVER_GENES_PATH environment variable to use another copy of the file.
driver_genes_path = os.environ.get(
    'DRIVER_GENES_PATH', '/opt/gpudata/aartiv/rag_rig/cancer_driver_genes_short.txt')
with open(driver_genes_path, encoding='utf-8') as gene_drivers:
    # strip whitespace and skip blank lines so no empty gene symbol gets queried
    driver_gene_list = [line.strip() for line in gene_drivers if line.strip()]

In [None]:
def get_driver_gene_mutation_data(driver_gene_list,
                                  out_path='../csvs/genes_mutations_cancers_driver_genes.tsv',
                                  timeout=120):
    """Query GDC ``/ssms`` per driver gene and write cosmic_id/gene/mutation/cancer rows.

    One request is made per symbol in ``driver_gene_list``, filtered on
    ``consequence.transcript.gene.symbol`` and expanded with
    ``occurrence.case.project``. Each hit is flattened into one TSV row per
    COSMIC id x amino-acid change x occurrence project.

    Args:
        driver_gene_list: gene symbols to query.
        out_path: destination TSV path (defaults to the notebook's original
            location).
        timeout: seconds to wait for each HTTP request.

    Raises:
        requests.HTTPError: if the API answers with an error status.

    Hits without COSMIC ids / amino-acid changes, occurrences without a
    project name, and amino-acid changes that are not exactly
    ``"GENE MUTATION"`` are skipped one at a time.
    """
    endpoint = 'https://api.gdc.cancer.gov/ssms'
    expand = ["occurrence.case.project"]
    with open(out_path, 'w', encoding='utf-8') as out_file:
        out_file.write('cosmic_id\tgene\tmutation\tcancer\n')
        for query_gene in driver_gene_list:
            filters = {
                "op": "=",
                "content": {
                    "field": "consequence.transcript.gene.symbol",
                    "value": query_gene
                }
            }
            params = {
                "filters": json.dumps(filters),
                "expand": expand,
                "size": 10000
            }
            response = requests.get(endpoint, params=params, timeout=timeout)
            # surface HTTP errors instead of a confusing KeyError on 'data' below
            response.raise_for_status()
            result = response.json()
            for info in result['data']['hits']:
                cosmic_ids = info.get('cosmic_id') or []
                aa_changes = info.get('gene_aa_change') or []
                for occurrence in info.get('occurrence') or []:
                    try:
                        project = occurrence['case']['project']['name']
                    except (KeyError, TypeError):
                        continue
                    for c_id in cosmic_ids:
                        for mut in aa_changes:
                            parts = mut.split(' ')
                            if len(parts) != 2:
                                continue
                            # gene_aa_change is a list, so it may name genes
                            # other than query_gene; write what the hit reports
                            gene, m = parts
                            out_file.write('{}\t{}\t{}\t{}\n'.format(c_id, gene, m, project))

In [None]:
get_driver_gene_mutation_data(driver_gene_list)

## Define templates for question variety and BERT model training

In [None]:
# generate synthetic data

# asked phoenix AI to produce 25 different questions for each of the
# templates below

# Original brainstorm of question types. NOTE: this numbering does NOT match
# the templateN_questions variables defined in this notebook.
# template 1 = 'what is the frequency of mutation x in cancer y?'
# template 2 = 'what is the prevalence of MSI High in cancer x?'
# template 3 = 'which gene is most frequently mutated in prostate cancer?'
# template 4 = 'what is the frequency of aneuploidy in breast cancers?'
# cnv question
# template 5 = 'what is the frequency of 1p19q loss in GBM?'
# template 6 = 'what are the top 15 mutated genes in breast cancer?'
#
# Templates actually concatenated into all_template_questions (V1),
# variable -> intent label:
#   template1 -> 'ssm_frequency'             frequency of a specific mutation
#   template2 -> 'msi_h_frequency'           prevalence of MSI-High
#   template3 -> 'freq_cnv_loss_or_gain'     CNV loss / gain / LOH frequency
#   template4 -> 'top_cases_counts_by_genes' cases with/without gene mutations
#   template5 -> 'cnv_and_ssm'               mutation in one gene + CNV in another
# Placeholders [gene_x], [mutation_x], [gene_y], [mutation_y] and [cancer] are
# substituted by generate_q_a.

# Single- and two-mutation phrasings; all of them map to 'ssm_frequency'.
template1_questions = [
    'What percentage of [cancer] cases have the [gene_x] [mutation_x] mutation?',
    'How widespread are [gene_x] [mutation_x] mutations in [cancer]?',
    'How often is the [gene_x] [mutation_x] found in [cancer]?',
    'What is the occurrence rate of [gene_x] [mutation_x] in [cancer] patients?',
    'How prevalent is the [gene_x] [mutation_x] mutation in cases of [cancer]?',
    'What is the incidence of the [gene_x] [mutation_x] mutation in [cancer]?',
    'What is the frequency of [gene_x] [mutation_x] mutation in [cancer]?',
    'How common is the [gene_x] [mutation_x] mutation in [cancer]?',
    'What fraction of [cancer] patients have the [gene_x] [mutation_x] mutation?',
    'How widespread is the [gene_x] [mutation_x] mutation in [cancer]?',
    'What is the rate of occurrence of the [gene_x] [mutation_x] mutation in [cancer]?',
    'How frequently is the [gene_x] [mutation_x] mutation detected in [cancer]?',
    'what percentage of cases have mutations in both [gene_x] [mutation_x] and [gene_y] [mutation_y] mutations in [cancer]?',
    'what proportion of [cancer] patients have mutations in both [gene_x] [mutation_x] and [gene_y] [mutation_y]?',
    'What proportion of [cancer] patients exhibit the [gene_x] [mutation_x] mutation?',
    'Can you tell me the rate of [gene_x] [mutation_x] mutation in [cancer]?',
    'What is the percentage of [cancer] tumors with the [gene_x] [mutation_x] mutation?',
    'What is the prevalence rate of [gene_x] [mutation_x] mutation in [cancer]?',
    'How many [cancer] cases have the [gene_x] [mutation_x] mutation?',
    'What is the prevalence rate of the [gene_x] [mutation_x] mutation in [cancer]?',
    'What is the distribution of the [gene_x] [mutation_x] mutation in [cancer]?',
    'How high is the prevalence of the [gene_x] [mutation_x] mutation in [cancer]?',
    'What is the presence rate of the [gene_x] [mutation_x] mutation in [cancer]?',
    'How pervasively is the [gene_x] [mutation_x] mutation present in [cancer]?',
    'What is the occurrence of [gene_x] [mutation_x] mutation in [cancer]?',
    'What is the [gene_x] [mutation_x] mutation rate in [cancer]?',
    'How frequent is the [gene_x] [mutation_x] mutation in [cancer] cases?'
]
# one intent label per question, aligned by position
template1_answers = ['ssm_frequency'] * len(template1_questions)


In [None]:
# template 2 = 'what is the prevalence of MSI High in cancer x?'
template2_questions = [
  'What is the frequency of microsatellite instability-high in [cancer]?',
  'How common is MSI-H in [cancer] cases?',
  'Can you provide the prevalence of microsatellite instability-high in [cancer] patients?',
  'What percentage of [cancer] patients have MSI-H?',
  'How often is microsatellite instability-high observed in [cancer]?',
  'In [cancer], what is the occurrence rate of MSI-H?',
  'What is the incidence of MSI-H in [cancer]?',
  'How prevalent is microsatellite instability-high in [cancer]?',
  'What proportion of [cancer] cases exhibit MSI-H?',
  'How frequently is MSI High found in [cancer]?',
  'What is the distribution rate of microsatellite instability-high among [cancer] patients?',
  'How widespread is microsatellite instability-high in [cancer]?',
  'Can you tell me how prevalent MSI High is in [cancer]?',
  'What percentage of [cancer] cases are characterized by microsatellite instability-high (MSI-H)?',
  'How common is MSI High status in cases of [cancer]?',
  'What is the rate of microsatellite instability-high in [cancer] diagnoses?',
  'Among [cancer] patients, what is the frequency of msi-h?',
  'How often does msi-h occur in [cancer] patients?',
  'What is the detection rate of msi-h in [cancer]?',
  'How many [cancer] patients exhibit microsatellite instability-high?',
  'What is the percentage occurrence of MSI High in [cancer]?',
  'How frequently do [cancer] patients present with microsatellite instability-high?',
  'What is the proportion of MSI High tumors in [cancer]?',
  'Can you share the incidence rate of msi-high in [cancer]?',
  'What fraction of [cancer] cases involve msi-high?'
]
template2_answers = ['msi_h_frequency'] * len(template2_questions)


In [None]:
"""Below not in V1

# template 3 = 'which gene is most frequently mutated in prostate cancer?'
template3_questions = [
    'What is the most frequently mutated gene in [cancer]?',
    'Can you tell me the gene that is most often mutated in [cancer]?',
    'Which gene is mutated most frequently in [cancer] cases?',
    'What is the predominant gene mutation in [cancer]?',
    'In [cancer], which gene shows the highest mutation frequency?',
    'What gene is most commonly altered in [cancer]?',
    'Which gene has the highest mutation rate in [cancer]?',
    'Can you provide the most frequently mutated gene in [cancer]?',
    'Which gene is mutated the most in [cancer] patients?',
    'What is the top mutated gene in [cancer]?',
    'Which gene experiences the most frequent mutations in [cancer]?',
    'What gene is most frequently changed in [cancer]?',
    'Can you tell me which gene is most frequently mutated in [cancer]?',
    'Which specific gene is most frequently mutated in [cancer]?',
    'What is the gene most often altered in [cancer] patients?',
    'What is the gene with the highest mutation occurrence in [cancer]?',
    'Can you identify the gene most frequently mutated in [cancer]?',
    'In [cancer], which gene has the highest incidence of mutation?',
    'What is the most commonly mutated gene in [cancer]?'
]
template3_answers = ['most_frequently_mutated_gene'] * len(template3_questions)
"""


In [None]:
# template 3 what is the freq of 1p19q or gene loss in GBM?
template3_questions = [
    'What is the incidence of [gene_x] loss in [cancer]?',
    'What is the frequency of somatic [gene_x] heterozygous deletion in [cancer]?',
    'What is the frequency of somatic [gene_x] homozygous deletion in [cancer]?',
    'How common is [gene_x] and [gene_y] codeletion in [cancer] cases?',
    'What percentage of [cancer] patients have somatic [gene_x] LOH?',
    'How often is [gene_x] LOH observed in [cancer]?',
    'In [cancer], what is the occurrence rate of [gene_x] loss?',
    'What is the prevalence of [gene_x] and [gene_y] loss of heterozygosity in [cancer]?',
    'How prevalent is [gene_x] gain in [cancer] patients?',
    'How prevalent is somatic [gene_x] amplification in [cancer] patients?',
    'What proportion of [cancer] cases exhibit [gene_x] and [gene_y] codeletion?',
    'What is the joint frequency of somatic [gene_x] and [gene_y] homozygous deletions in [cancer]?',
    'What is the frequency of somatic [gene_x] and [gene_y] heterozygous co-deletion in [cancer]?',
    'How frequently is [gene_x] LOH found in [cancer]?',
    'What is the distribution rate of [gene_x] and [gene_y] gain among [cancer] patients?',
    'How widespread is [gene_x] LOH in [cancer]?',
    'Can you tell me how prevalent [gene_x] and [gene_y] loss is in [cancer]?',
    'What percentage of [cancer] are characterized by codeletion of [gene_x] and [gene_y]?',
    'How common is [gene_y] LOH in [cancer] cases?',
    'What is the rate of [gene_x] LOH in [cancer] diagnoses?',
    'Among [cancer] patients, what is the frequency of [gene_y] gain?',
    'How often does [gene_x] loss occur in [cancer] patients?',
    'What is the detection rate of [gene_x] LOH in [cancer]?',
    'How many [cancer] patients exhibit cogains in [gene_x] and [gene_y]?',
    'What is the percentage occurrence of [gene_y] loss of heterozygosity in [cancer]?',
    'How frequently do [cancer] patients present with [gene_y] loss of heterozygosity?',
    'What is the proportion of [gene_y] and [gene_x] codeletion in [cancer]?',
    'Can you share the incidence rate of [gene_x] loss in [cancer]?',
    'What fraction of [cancer] cases involve [gene_x] loss?'
]
template3_answers = ['freq_cnv_loss_or_gain'] * len(template3_questions)


In [None]:
"""Not for V1

# template 5 = 'what are the top 15 mutated genes in breast cancer'?
template5_questions = [
    'What are the most commonly mutated 15 genes in [cancer]?',
    'Can you list the top 10 mutated genes in [cancer]?',
    'Which 5 genes are most frequently mutated in [cancer]?',
    'Which 7 genes have the highest mutation rates in [cancer]?'
    'Can you provide the 8 most frequently mutated genes in [cancer]?',
    'What 3 genes are most often mutated in [cancer] patients?',
    'What are the top 4 genes with mutations in [cancer]?',
    'Which 17 genes are most commonly altered in [cancer]?',
    'Can you identify the top 18 mutated genes in [cancer]?',
    'Which 9 genes exhibit the highest mutation frequency in [cancer]?',
    'What are the primary 3 genes mutated in [cancer] cases?',
    'Can you tell me the top 8 mutated genes in [cancer]?',
    'What are the major 7 gene mutations in [cancer]?',
    'Which 6 genes are most frequently altered in [cancer]?',
    'What are the top 12 genes with the highest mutation rates in [cancer]?',
    'What 14 genes are most frequently mutated in cases of [cancer]?',
    'Which 19 genes are most often found to be mutated in [cancer]?'
]
template5_answers = ['top_mutated_genes_by_project'] * len(template5_questions)
"""

In [None]:
"""Not for V1

# template 6 = /analysis/mutated_cases_count_by_project
# template 6 = 'Can you list the counts for the number of cases that have associated simple somatic mutation data in each cancer genome atlas project?'

template6_questions = [
    'Could you provide the counts of cases with simple somatic mutation data for each Cancer Genome Atlas project?',
    'What are the numbers of cases with associated simple somatic mutation data in each TCGA project?',
    'Can you list the number of cases that include simple somatic mutation data across all Cancer Genome Atlas projects?',
    'Please enumerate the case counts having simple somatic mutation data within each TCGA project.',
    'Could you detail the counts of cases that contain simple somatic mutation data for each project in the Cancer Genome Atlas?',
    'What is the number of cases with simple somatic mutation data in each of the TCGA projects?',
    'Can you provide the number of cases with simple somatic mutation data for each project in the Cancer Genome Atlas?',
    'Could you list how many cases involve simple somatic mutation data in each TCGA project?',
    'Please provide the counts of cases featuring simple somatic mutation data in each TCGA project.',
    'Can you give the counts of cases that have simple somatic mutation data within the Cancer Genome Atlas projects?',
    'What are the counts of cases with simple somatic mutation data in each of the Cancer Genome Atlas projects?',
    'Could you share the counts of cases having simple somatic mutation data across the various TCGA projects?',
    'Can you tell me how many cases with simple somatic mutation data exist in each TCGA project?',
    'Could you list the number of cases featuring simple somatic mutation data for each Cancer Genome Atlas project?',
    'What are the numbers of cases with simple somatic mutation data in different TCGA projects?',
    'Could you describe the case counts with simple somatic mutation data for each of the Cancer Genome Atlas projects?',
    'Can you list the counts of cases having simple somatic mutation data in each project under the TCGA?',
    'Could you enumerate how many cases include simple somatic mutation data in the various TCGA projects?',
    'What counts of cases with simple somatic mutation data are there for each TCGA project?',
    'Could you provide the tally of cases that contain simple somatic mutation data in each Cancer Genome Atlas project?'
]
template6_answers = ['mutated_cases_count_by_project'] * len(template6_questions)
"""



In [None]:
# template4 /analysis/top_cases_counts_by_genes used for cohort comparisons/survival analysis
# number of cases w/ mutations in a gene or genes, and number of cases w/o mutations in genes
# this template refers to any mutation in the gene, no specific ssm

template4_questions = [
    'What is the incidence of [gene_x] or [gene_y] mutations in cases of [cancer] in the genomic data commons?',
    'How frequently are [gene_x] or [gene_y] mutations observed in [cancer]?',
    'What percentage of [cancer] cases exhibit mutations in [gene_x] genes?',
    'What proportion of [cancer] cases have mutated [gene_x] or [gene_y]?',
    'How many patients with [cancer] carry mutations in either [gene_x] or [gene_y]?',
    'What is the prevalence of mutations in [gene_x] or [gene_y] among [cancer] cases?',
    'In cases of [cancer], how many exhibit [gene_x] mutations in the genomic data commons?',
    'Are [gene_x] or [gene_y] mutations present in many [cancer] cases?',
    'What is the rate of [gene_x] mutations in [cancer] tumors?',
    'What is the number of [cancer] cases without [gene_x] and [gene_y] mutations?',
    'How many [cancer] patients lack mutations in [gene_y] in the genomic data commons?',
    'What proportion of [cancer] do not exhibit mutations in [gene_x] and [gene_y]?',
    'How frequently are [cancer] without [gene_x] mutations observed?',
    'What percentage of [cancer] cases do not have [gene_y] mutations?',
    'How rare is it to find [cancer] without mutations in [gene_x] and [gene_y]?',
    'How many [cancer] cases lack [gene_x] and [gene_y] gene mutations?',
    'What fraction of [cancer] cases do not possess mutations in [gene_x] and [gene_y]?',
    'Do a substantial number of [cancer] not have [gene_x] and [gene_y] mutations?',
    'What share of [cancer] cases is without [gene_x] and [gene_y] mutations?'
]
template4_answers = ['top_cases_counts_by_genes'] * len(template4_questions)


In [None]:
"""Not for V1

# template8 questions: Can you give me the project summary for TCGA-BRCA project
template8_questions = [
    'Could you provide a brief overview of the [cancer] project?',
    'Can you summarize the data in the [cancer] project for me?',
    'Can you summarize the data in the [cancer] project for me, including data categories and experimental strategies?',
    'What is the project summary for [cancer]?',
    'Can you give me a summary of the [cancer] project, including data and experimental categories?',
    'What are the key summary details of the [cancer] project?',
    'Can you provide a high-level summary of the [cancer] project, in terms of number of files for each experiment and types of data collected?',
    'Can you offer a brief description of the [cancer] project?',
    'Could you give an overview of what the [cancer] project entails?',
    'Can you share a concise summary of the [cancer] project, including file counts and types of data?',
    'Can you break down the [cancer] project for me, in terms of types of data and experiments?',
    'Can you provide the number of cases and files in [cancer] project for different data categories and experiments?',
    'Can you provide the number of cases for biospecimens, clinical data, copy number variation, DNA Methylation in [cancer] project?',
    'Can you detail the number of cases for proteome profiling, sequencing reads, simple nucleotide variation, somatic structural variation, structural variation and transcriptome profiling for [cancer] project?',
    'Can you list the number of files for proteome profiling, sequencing reads, simple nucleotide variation, somatic structural variation, structural variation and transcriptome profiling for [cancer] project?',
    'Can you list cases and file counts by data category for [cancer] project?',
    'Can you provide a breakdown of number of cases for various experimental strategies such as ATAC-Seq, genotyping and methylation in [cancer] project?',
    'Can you provide a breakdown of number of cases for various experimental strategies such as miRNA, reverse phase protein array, RNA-Seq, WGS and WXS in [cancer] project?',
    'Can you give me a summary of number of cases for various experimental strategies such as miRNA, reverse phase protein array, transcriptomics, whole genome and whole exome sequencing in [cancer] project?',
    'How many cases and files are available for diagnostic and tissue slides in [cancer] project?'
]
template8_answers = ['project_summary'] * len(template8_questions)
"""


In [None]:
# cnv and ssm combination: a somatic mutation in [gene_x] together with a copy
# number change (amplification / gain / deletion / loss) in [gene_y].
# Every entry needs a trailing comma: without it Python silently concatenates
# adjacent string literals into one (malformed) training question.
template5_questions = [
    'what is the frequency of cases with somatic mutations in [gene_x] and amplifications in [gene_y] in [cancer] in the genomic data commons?',
    'what is the frequency of cases with amplifications in [gene_x] and mutations in [gene_y] in [cancer] in the genomic data commons?',
    'How often do cases of [cancer] exhibit mutations in [gene_x] and homozygous deletions in [gene_y]?',
    'What is the prevalence of [gene_x] somatic mutations and [gene_y] heterozygous deletions in [cancer] cases?',
    'Can you tell me the frequency of [gene_x] mutations co-occurring with [gene_y] amplifications in [cancer]?',
    'How common are mutations in [gene_x] and amplifications in [gene_y] among patients with [cancer]?',
    'What percentage of [cancer] cases present with mutations in [gene_x] and somatic homozygous deletions in [gene_y]?',
    'To what extent do [cancer] cases show somatic mutations in [gene_x] along with somatic amplifications in [gene_y]?',
    'How frequently are [gene_x] somatic mutations and somatic [gene_y] amplifications found in [cancer] patients?',
    'What is the incidence rate of mutations in [gene_x] paired with somatic heterozygous deletions in [gene_y] in [cancer]?',
    'What portion of [cancer] cases have both [gene_x] mutations and [gene_y] gains?',
    'How prevalent are simultaneous mutations in [gene_x] and gains in [gene_y] in [cancer]?',
    'Could you provide the rate of [gene_x] mutations and [gene_y] deletions among [cancer] cases?',
    'What is the occurrence of [gene_x] mutations alongside [gene_y] amplifications in [cancer]?',
    'How often are concurrent somatic mutations in [gene_x] and deletions in [gene_y] observed in [cancer]?',
    'What is the proportion of [cancer] cases with both [gene_x] mutations and [gene_y] amplifications?',
    'In cases of [cancer], how common are mutations in [gene_x] and amplifications in [gene_y]?',
    'What is the rate at which [gene_x] somatic mutations and somatic [gene_y] homozygous deletions coincide in [cancer]?',
    'How frequently do [gene_x] mutations and [gene_y] amplifications occur together in [cancer] cases?',
    'What share of [cancer] cases display both somatic [gene_x] mutations and somatic [gene_y] losses?',
    'What is the joint frequency of somatic [gene_x] mutations and [gene_y] losses in [cancer]?',
    'How prevalent is the co-occurrence of [gene_x] mutations and [gene_y] amplifications in [cancer]?',
    'What is the frequency of mutations in [gene_x] coupled with gains in [gene_y] in [cancer]?',
    'How often do mutations in [gene_x] and amplifications in [gene_y] happen in [cancer]?',
    'What is the incidence of [gene_x] mutations along with [gene_y] losses in [cancer]?',
    'How common are both [gene_x] mutations and [gene_y] amplifications within [cancer] cases?',
    'Can you specify the frequency at which [gene_x] mutations and [gene_y] gains are seen in [cancer]?'
]
# one intent label per question, aligned by position
template5_answers = ['cnv_and_ssm'] * len(template5_questions)



In [None]:
all_template_questions = template1_questions + template2_questions + template3_questions + template4_questions + template5_questions
all_template_answers = template1_answers + template2_answers + template3_answers + template4_answers + template5_answers

In [None]:
len(all_template_questions), len(all_template_answers)

# Process different types of data for query and intent

In [None]:
# Dataset1
gdc_data = pd.read_csv('../csvs/genes_mutations_cancers.tsv', sep='\t')

# Dataset2
# driver genes -- you could pare this down further perhaps to top 100?
# gdc_data = pd.read_csv('genes_mutations_cancers_driver_genes.tsv', sep='\t')


In [None]:
gdc_data.shape

In [None]:
len(template1_questions), len(template2_questions), len(template3_questions), len(template4_questions), len(template5_questions)

In [None]:
len(all_template_questions)

**Synthetic question answer generation**

**Dataset 1**

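Generate a simpler and smaller df for the first NER test
 - Keep only rows whose mutation notation is purely alphanumeric (this drops entries containing `=`, `*` or other symbols)
 - Remove gene/mutation duplicates (the first row for each gene/mutation pair is kept)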


In [None]:
gdc_data['simple_mutation'] = gdc_data['mutation'].apply(lambda x: x.isalnum())

In [None]:
dataset1 = gdc_data[gdc_data['simple_mutation']].copy()

In [None]:
dataset1.drop_duplicates(['gene', 'mutation'], inplace=True)

In [None]:
dataset1.head()

In [None]:
dataset1.shape

In [None]:
# total number of question/intent pairs
len(all_template_questions)*dataset1.shape[0]

In [None]:
def generate_q_a(row, gene_list, mutation_list, template_index,
                 templates=None, intents=None):
    """Fill one question template with a row's values and return it with its intent.

    Args:
        row: mapping (e.g. a DataFrame row) whose 'gene', 'mutation' and
            'cancer' string values replace [gene_x], [mutation_x] and [cancer].
        gene_list: pool for the second gene ([gene_y]).
        mutation_list: pool for the second mutation ([mutation_y]). When it
            has the same length as ``gene_list`` the two are treated as
            aligned pairs, so [mutation_y] belongs to [gene_y].
        template_index: index into ``templates`` / ``intents``.
        templates: question templates; defaults to the module-level
            ``all_template_questions``.
        intents: intent labels aligned with ``templates``; defaults to the
            module-level ``all_template_answers``.

    Returns:
        ``(question, intent)`` tuple.

    Raises:
        ValueError: if the template cannot be filled (missing key, index out
            of range, or a non-string value in ``row``).
    """
    if templates is None:
        templates = all_template_questions
    if intents is None:
        intents = all_template_answers

    if len(gene_list) == len(mutation_list):
        # the notebook builds both lists from the same dataset rows, so sample
        # one index to keep the second gene and its mutation consistent
        pair_index = random.randrange(len(gene_list))
        gene_y = gene_list[pair_index]
        mutation_y = mutation_list[pair_index]
    else:
        gene_y = random.choice(gene_list)
        mutation_y = random.choice(mutation_list)

    try:
        question = templates[template_index] \
            .replace("[mutation_x]", row['mutation']) \
            .replace("[cancer]", row['cancer']) \
            .replace("[gene_x]", row['gene']) \
            .replace("[gene_y]", gene_y) \
            .replace("[mutation_y]", mutation_y)
        intent = intents[template_index]
    except (KeyError, IndexError, TypeError) as e:
        # previously this only printed and then failed with an unbound `question`
        raise ValueError('unable to generate question {}'.format(str(e))) from e

    return question, intent


In [None]:
# for double gene/mutation questions, get the list to randomly sample second one from
gene_list = list(dataset1['gene'])
mutation_list = list(dataset1['mutation'])


In [None]:
dataset1.shape

In [None]:
questions_intent_df = pd.DataFrame(columns=['gene', 'mutation', 
                                            'cancer', 'questions', 
                                            'intent'])

In [None]:
# generate a question for every row in dataset1
dataset1_copy = dataset1.copy()

for template_index in range(len(all_template_questions)):
  dataset1_copy[['questions', 'intent']] = dataset1_copy.apply(
    lambda row: generate_q_a(row, gene_list, mutation_list, template_index),
    axis=1,result_type='expand'
    )
  questions_intent_df = pd.concat(
    [questions_intent_df, 
     dataset1_copy[['gene', 'mutation', 'cancer', 'questions', 'intent']]
    ]
  )

In [None]:
questions_intent_df.shape

In [None]:
questions_intent_df.head(n=6)

### `questions_intent_df` used to train BERT model

In [None]:
questions_intent_df.to_csv('../csvs/dataset1.query.intent.gdc.csv', sep=',', header=True)

**Train BERT model to identify intent from query**


In [None]:
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
import torch

In [None]:
# map intent to labels
intent_labels = { intent:float(idx) for idx, intent in enumerate(questions_intent_df['intent'].unique())}

In [None]:
intent_labels

In [None]:
questions_intent_df['label'] = questions_intent_df['intent'].map(intent_labels)

In [None]:
len(questions_intent_df['label'])

In [None]:
questions_intent_df.head(n=6)

**split into training and validation sets**

In [None]:
from sklearn.model_selection import train_test_split


In [None]:
# if you want to train on more rows, you need to move to
# our GPU cluster. It takes half a day or so to train on 50k rows in colab notebook
# Stratify on the label so every intent appears in both splits in the same
# proportion as in the full dataset (the templates have different sizes).
train_df, val_df = train_test_split(
    questions_intent_df, train_size=0.7, random_state=42,
    stratify=questions_intent_df['label'])


In [None]:
train_df.shape, val_df.shape

In [None]:
train_df.to_csv('../csvs/train.csv', sep=',', header=True)
val_df.to_csv('../csvs/val.csv', sep=',', header=True)

In [None]:
train_texts = train_df['questions'].tolist()
train_labels = train_df['label'].tolist()
val_texts = val_df['questions'].tolist()
val_labels = val_df['label'].tolist()



In [None]:
print('train_texts {}'.format(len(train_texts)))
print('train_labels {}'.format(len(train_labels)))
print('val_texts {}'.format(len(val_texts)))
print('val_labels {}'.format(len(val_labels)))

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


In [None]:
def convert_labels(labels, num_classes=None):
    """Convert class-index labels into a one-hot float tensor.

    Args:
        labels: sequence of class indices (ints or integral floats, e.g. the
            float values produced by the ``intent_labels`` mapping).
        num_classes: width of the one-hot encoding. Defaults to
            ``max(labels) + 1`` so no index is ever out of range. Pass
            ``len(intent_labels)`` to guarantee the width matches the model's
            ``num_labels`` even if the highest class is absent from ``labels``.

    Returns:
        Tensor of shape ``(len(labels), num_classes)`` with a 1 at each
        label's index and 0 elsewhere.
    """
    label_index = torch.tensor(np.asarray(labels), dtype=torch.int64)
    if num_classes is None:
        # counting unique labels (np.unique) under-counts whenever a class in
        # the middle of the range is missing, making scatter_ go out of range
        num_classes = int(label_index.max().item()) + 1 if label_index.numel() else 0
    one_hot_labels = torch.zeros(len(labels), num_classes)
    one_hot_labels.scatter_(1, label_index.unsqueeze(1), 1)
    return one_hot_labels

In [None]:
train_labels = convert_labels(train_labels)
val_labels = convert_labels(val_labels)

In [None]:
train_labels

In [None]:
# Tokenize the data
train_encodings = tokenizer(train_texts, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)


In [None]:
train_encodings.keys()

In [None]:
len(train_encodings['input_ids'])

In [None]:
class IntentDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer encodings with per-example labels.

    ``encodings`` maps a field name (e.g. ``input_ids``) to a per-example
    sequence; ``labels`` is indexed per example and each element must be a
    tensor, since it is cloned and detached on access.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {}
        for field_name, per_example in self.encodings.items():
            sample[field_name] = torch.tensor(per_example[idx])
        # hand out a copy detached from any autograd graph
        sample['labels'] = self.labels[idx].clone().detach()
        return sample

In [None]:
train_dataset = IntentDataset(train_encodings, train_labels)


In [None]:
train_dataset.__getitem__(1)

In [None]:
train_dataset.__len__()

In [None]:
val_dataset = IntentDataset(val_encodings, val_labels)


**BertSequenceClassification Model**


In [None]:
# One output logit per intent in intent_labels.
# NOTE(review): the labels fed to this model are one-hot float tensors (see
# convert_labels / IntentDataset). Per the transformers docs, when problem_type
# is unset and labels are not an integer dtype, the model infers
# 'multi_label_classification' (sigmoid + BCEWithLogitsLoss), while the
# prediction cell decodes with softmax/argmax as if single-label. Consider
# integer class labels with problem_type="single_label_classification" —
# confirm which setup is intended.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=len(intent_labels))

In [None]:
# Trainer configuration: evaluate once per epoch, batch size 4 per device for
# both training and evaluation, 2 epochs, weight decay 0.01; checkpoints go to
# ./results and logs to ./logs. report_to=['none'] turns off external loggers.
# NOTE(review): newer transformers releases deprecate `evaluation_strategy` in
# favour of `eval_strategy`; switch the keyword if the installed version
# rejects it — confirm against the installed transformers version.
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=2,
    weight_decay=0.01,
    logging_dir='./logs',
    report_to=['none']
)


In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)


In [None]:
trainer.train()


In [None]:
trainer.evaluate()

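**Save the model** (only the model is saved here; the tokenizer is the stock `bert-base-uncased` tokenizer and can be reloaded with `BertTokenizer.from_pretrained`)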

In [None]:
torch.save(model, 'query_intent_model.pt')


**Load and test the model**

In [None]:
# torch.save(model, ...) pickled the whole model object, so loading it needs
# full unpickling; torch>=2.6 defaults to weights_only=True, which rejects such
# files. Unpickling can execute arbitrary code: only load checkpoints you
# created or trust. map_location='cpu' lets a GPU-saved model load on a
# CPU-only machine; move it to the desired device before running inference.
model_load = torch.load('query_intent_model.pt', weights_only=False, map_location='cpu')

In [None]:
test_query_list = [
    'How frequently are JAK2 V617F mutations detected in Lymphoid Leukemia in the genomic data commons?',
    'What is the incidence of simple somatic mutations in JAK2 in the genomic data commons for Lymphoid Leukemia ?',
    'How frequently are MAP2K1 P124S mutations detected in Melanoma in the genomic data commons?',
    "what is the frequency of BRAF V600E in breast cancers?",
    "what is the prevalence of MSI H in colorectal cancers?",
    'what is the frequency of BRCA1 LOH in TCGA-BRCA project?',
    'how many low grade gliomas have mutations in IDH1 or IDH2?',
    'How many low grade gliomas have a mutation in IDH1 or IDH2 in the genomic data commons?',
    'how widespread are KRAS G12D mutations in pancreatic cancer?',
    'what percentage of cases have mutations in both KRAS G12D and BRAF V600E mutations in the TCGA-BRCA project?',
    'what is the frequency of cases with mutations in ATRX and EGFR amplification in the TCGA-LGG project?',
    'Could you provide the rate of BRCA1 mutations and BRCA2 deletions among breast cancer cases?',
    'what is the frequency of cases with somatic mutations in TNS1 and amplifications in STK19 in [cancer] in the genomic data commons?'
]

**predict labels for text**


In [None]:
# set device once and put both the model and each query on it
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# test the model reloaded from disk, so the save/load round trip is exercised
predict_model = model_load.to(device)
# eval() disables dropout so repeated runs give the same predictions
predict_model.eval()
# reverse lookup: class index -> intent name (intent_labels stores float indices)
label_to_intent = {int(v): k for k, v in intent_labels.items()}

with torch.no_grad():  # inference only; skip building the autograd graph
    for test_query in test_query_list:
        print('testing query: {}'.format(test_query))
        inputs = tokenizer(test_query, return_tensors="pt", truncation=True, padding=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # pass tokenized input through the model
        outputs = predict_model(**inputs)
        print('output logits {}'.format(outputs.logits))
        # outputs are logits, need to apply softmax to convert to probs
        probs = torch.nn.functional.softmax(outputs.logits, dim=1)
        print('probs: {}'.format(probs))
        predicted_label = torch.argmax(probs, dim=1).item()
        print('predicted label: {}\n'.format(label_to_intent[predicted_label]))