![NVIDIA Logo](images/nvidia.png)

# PubMedQA Dataset

In this notebook we will familiarize ourselves with the PubMedQA data, in preparation for subsequent experiments.

---

## Learning Objectives

By the time you complete this notebook you will:
- Be familiar with the format of the PubMedQA dataset.

---

## Imports

In [None]:
import json
import random
from itertools import islice

---

## Clone PubMedQA

[PubMedQA](https://pubmedqa.github.io) is a dataset designed for question answering and biomedical natural language processing research. It's based on abstracts from PubMed, a free search engine accessing primarily the MEDLINE database of references and abstracts on life sciences and biomedical topics. For a deeper dive check out the [PubMedQA Paper](https://users.cs.duke.edu/~bdhingra/papers/pubmedqa.pdf).

We will clone the PubMedQA GitHub repo to get the data we will use for P-tuning.

In [None]:
!rm -rf pubmedqa
!git clone https://github.com/pubmedqa/pubmedqa

---

## Load Data

`ori_pqal.json` contains 1000 labeled samples.

In [None]:
# Use a context manager so the file handle is closed after loading
with open('pubmedqa/data/ori_pqal.json') as f:
    pubmed_data = json.load(f)

In [None]:
len(pubmed_data)

---

## View Raw Sample

Below is a single raw sample from the dataset.

In [None]:
# Take only the first (ID, record) pair; `sample_id` avoids shadowing the built-in `id`
for sample_id, info in islice(pubmed_data.items(), 1):
    print(info)
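
A raw record printed whole can be hard to scan. The following is a minimal sketch of one way to summarize a record's structure, assuming each record is a plain dictionary. `summarize_record` is a hypothetical helper, not part of the PubMedQA repo, and the toy record uses made-up values with only the field names discussed in this notebook.

```python
def summarize_record(record):
    """Map each field name to a short description of its value's type and size."""
    summary = {}
    for key, value in record.items():
        if isinstance(value, (list, dict)):
            summary[key] = f"{type(value).__name__} with {len(value)} item(s)"
        else:
            summary[key] = type(value).__name__
    return summary


# Toy record with made-up values, for illustration only.
toy_record = {
    "QUESTION": "Does the treatment help?",
    "CONTEXTS": ["First passage.", "Second passage."],
    "LABELS": ["BACKGROUND", "RESULTS"],
    "final_decision": "yes",
}

for field, description in summarize_record(toy_record).items():
    print(f"{field}: {description}")
```

With the real data loaded, the same helper could be applied to the first record via `summarize_record(next(iter(pubmed_data.values())))`.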

---

## Formatted Samples

For our purposes we are interested, for each sample, in:
- `'QUESTION'`, which is intended to be answered as either "yes", "no" or "maybe".
- `'CONTEXTS'`, which provide information relevant to the question. Each sample may have several contexts, and each context has a corresponding context label, provided in `'LABELS'`.
- `'final_decision'` which is the correct answer.

Here we print one sample for each of the 3 possible answers, including the question, contexts and label.

In [None]:
possible_answers = {'yes', 'no', 'maybe'}
for sample in pubmed_data.values():
    if not possible_answers:
        break
    label = sample['final_decision']
    if label in possible_answers:
        possible_answers.remove(label)

        print('CONTEXTS\n--------\n')
        for context_label, context in zip(sample['LABELS'], sample['CONTEXTS']):
            print(f"{context_label}: {context}\n")
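        # Double quotes outside, single quotes inside: reusing the same quote type
        # within an f-string is a syntax error before Python 3.12
        print(f"QUESTION: {sample['QUESTION']}\n")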
        print(f'LABEL: {label}')
        print('\n---\n')
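
Beyond viewing one sample per answer, it can help to know how the answers are distributed. The following is a minimal sketch, assuming every sample has a `'final_decision'` field as described above. `count_answers` is a hypothetical helper, and the toy data uses made-up IDs and labels.

```python
from collections import Counter


def count_answers(data):
    """Count how many samples carry each 'final_decision' value."""
    return Counter(sample["final_decision"] for sample in data.values())


# Toy data with made-up IDs and labels, for illustration only.
toy_data = {
    "id-1": {"final_decision": "yes"},
    "id-2": {"final_decision": "no"},
    "id-3": {"final_decision": "yes"},
    "id-4": {"final_decision": "maybe"},
}

answer_counts = count_answers(toy_data)
for answer, count in answer_counts.most_common():
    print(f"{answer}: {count}")
```

With the real data loaded, `count_answers(pubmed_data)` would report the counts for the labeled samples.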

---

## Split Data

We have already provided the splits of the PubMedQA dataset for you. If you're curious, the code below shows how we performed them.

```python
# Load the JSON data
with open('pubmedqa/data/ori_pqal.json', 'r') as file:
    data = json.load(file)

# Convert the dictionary to a list of items and shuffle
items = list(data.items())
random.shuffle(items)

# Split the data
train_split = int(0.7 * len(items))  # Index where the first 70% (training) ends
validate_split = int(0.85 * len(items))  # Next 15% is validation; the final 15% is the test set

train_items = items[:train_split]
validate_items = items[train_split:validate_split]
test_items = items[validate_split:]

# Convert the lists back to dictionaries
train_data = dict(train_items)
validate_data = dict(validate_items)
test_data = dict(test_items)

print(len(train_data))
print(len(validate_data))
print(len(test_data))

# Save the splits to new JSON files (this assumes the data/ directory already exists)
with open('data/pubmedqa_train.json', 'w') as file:
    json.dump(train_data, file, indent=4)

with open('data/pubmedqa_validate.json', 'w') as file:
    json.dump(validate_data, file, indent=4)

with open('data/pubmedqa_test.json', 'w') as file:
    json.dump(test_data, file, indent=4)
```
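
Because `random.shuffle` is called without a seed above, rerunning that code would produce a different split each time. The following is a sketch of a reproducible variant, assuming a fixed seed is acceptable. `split_items` and the seed value `42` are hypothetical and are not what was used to produce the provided splits.

```python
import random


def split_items(data, train_end_frac=0.7, validate_end_frac=0.85, seed=42):
    """Shuffle a dict's items with a private, seeded RNG and split them three ways."""
    items = list(data.items())
    random.Random(seed).shuffle(items)  # leaves the global random state untouched

    train_end = int(train_end_frac * len(items))
    validate_end = int(validate_end_frac * len(items))

    return (
        dict(items[:train_end]),
        dict(items[train_end:validate_end]),
        dict(items[validate_end:]),
    )


# Toy data standing in for the labeled samples (made-up IDs and labels).
toy_data = {str(i): {"final_decision": "yes"} for i in range(1000)}
train, validate, test = split_items(toy_data)

print(len(train), len(validate), len(test))

# The three splits should be disjoint and together cover every sample.
assert not (train.keys() & validate.keys())
assert not (train.keys() & test.keys())
assert not (validate.keys() & test.keys())
assert train.keys() | validate.keys() | test.keys() == toy_data.keys()
```

Using a dedicated `random.Random(seed)` instance rather than `random.seed` keeps the split repeatable without affecting other code that relies on the global random state.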