In [1]:
import json
import re
import torch

from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer

In [2]:
print("CUDA GPU backend", torch.cuda.is_available())


CUDA GPU backend True


In [3]:

ds = load_dataset("qiaojin/PubMedQA", "pqa_unlabeled")

print(ds)

DatasetDict({
    train: Dataset({
        features: ['pubid', 'question', 'context', 'long_answer'],
        num_rows: 61249
    })
})


In [4]:
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and given a context, the Assistant solves it. The assistant "
    "first thinks about the context, then runs a reasoning process in the mind and finally provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)

def make_conversation(example):
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": "".join(example["context"]['contexts'])},
            {"role": "user", "content": example["question"]},
        ],
    }

ds = ds.map(make_conversation)
train_dataset = ds.remove_columns(['pubid', 'question', 'context'])

print(train_dataset)

DatasetDict({
    train: Dataset({
        features: ['long_answer', 'prompt'],
        num_rows: 61249
    })
})


In [5]:
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")

In [6]:
def extract_reasoning_and_graph(example):
    prompt_text = example["prompt"]
    input_text = "\n".join([entry["content"] for entry in prompt_text])
    print("input", input_text)
    inputs = tokenizer(input_text, return_tensors="pt")
    print("input tokens", inputs)
    outputs = model.generate(**inputs, max_new_tokens=300)
    print("outputs", outputs)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("response", response)
    
    # Extract reasoning using regex
    think_match = re.search(r"<think>(.*?)</think>", response, re.DOTALL)
    reasoning_process = think_match.group(1) if think_match else ""
    
    # Convert reasoning into a simple graph (JSON format)
    graph_structure = {"nodes": [], "edges": []}
    steps = reasoning_process.split(". ")
    for i, step in enumerate(steps):
        graph_structure["nodes"].append({"id": i, "text": step})
        if i > 0:
            graph_structure["edges"].append({"source": i - 1, "target": i})
    
    return {"reasoning_graph": json.dumps(graph_structure), "model_response": response}

In [7]:
dso = ds['train'].select(range(2)).map(extract_reasoning_and_graph)


Map:   0%|          | 0/2 [00:00<?, ? examples/s]

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


input A conversation between User and Assistant. The user asks a question, and given a context, the Assistant solves it. The assistant first thinks about the context, then runs a reasoning process in the mind and finally provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>
Although the use of alternative medicine in the United States is increasing, no published studies have documented the effectiveness of naturopathy for treatment of menopausal symptoms compared to women receiving conventional therapy in the clinical setting.To compare naturopathic therapy with conventional medical therapy for treatment of selected menopausal symptoms.A retrospective cohort study, using abstracted data from medical charts.One natural medicine and six conventional medical clinics at Community Health Centers of King County, Washington, 

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


outputs tensor([[151646,     32,  10435,   1948,   2657,    323,  21388,     13,    576,
           1196,  17064,    264,   3405,     11,    323,   2661,    264,   2266,
             11,    279,  21388,  67477,    432,     13,    576,  17847,   1156,
          15482,    911,    279,   2266,     11,   1221,   8473,    264,  32711,
           1882,    304,    279,   3971,    323,   5499,   5707,    279,   1196,
            448,    279,   4226,     13,    576,  32711,   1882,    323,   4226,
            525,  43810,   2878,    220, 151648,    220, 151649,    323,    366,
           9217,     29,    690,   9217,     29,   9492,     11,  15576,     11,
            600,   1734,   2572,    220, 151648,  32711,   1882,   1588,    220,
         151649,     27,   9217,     29,   4226,   1588,    690,   9217,    397,
          15802,    279,    990,    315,  10555,  15712,    304,    279,   3639,
           4180,    374,   7703,     11,    902,   4652,   7822,    614,  26372,
            279,  26

In [8]:
print(dso)

Dataset({
    features: ['pubid', 'question', 'context', 'long_answer', 'prompt', 'reasoning_graph', 'model_response'],
    num_rows: 2
})


In [9]:
dso['reasoning_graph']

['{"nodes": [{"id": 0, "text": " "}], "edges": []}',
 '{"nodes": [{"id": 0, "text": " "}], "edges": []}']

In [None]:
dso['model_response']