# Doctor RL - Chatbot Interface
## Based on Qwen3-1.7B-Base fine-tuned with SFT and GRPO

# Environment Setup

In [None]:
try: import numpy, PIL; get_numpy = f'numpy=={numpy.__version__}'; get_pil = f'pillow=={PIL.__version__}'
except: get_numpy = 'numpy'; get_pil = 'pillow'
try: import subprocess; is_t4 = 'Tesla T4' in str(subprocess.check_output(['nvidia-smi']))
except: is_t4 = False
get_vllm, get_triton = ('vllm==0.9.2', 'triton==3.2.0') if is_t4 else ('vllm==0.10.2', 'triton')
!uv pip install -qqq --upgrade unsloth {get_vllm} {get_numpy} {get_pil} torchvision bitsandbytes xformers
!uv pip install -qqq {get_triton}
!uv pip install transformers==4.56.2

[2mUsing Python 3.12.12 environment at: /usr[0m
[37m⠋[0m [2mResolving dependencies...                                                     [0m[2K[37m⠙[0m [2mResolving dependencies...                                                     [0m[2K[37m⠋[0m [2mResolving dependencies...                                                     [0m[2K[37m⠙[0m [2mResolving dependencies...                                                     [0m[2K[37m⠙[0m [2mtransformers==4.56.2                                                          [0m[2K[37m⠙[0m [2mfilelock==3.20.0                                                              [0m[2K[37m⠙[0m [2mhuggingface-hub==0.36.0                                                       [0m[2K[37m⠙[0m [2mnumpy==2.0.2                                                                  [0m[2K[37m⠙[0m [2mpackaging==25.0                                                               [0m[2K[37m⠙[0m [2mpyyaml==6.0.3    

In [None]:
# Mount Google Drive and switch to the project folder; later cells load the saved
# adapters (e.g. ./qwen3-1.7b-base_grpo) relative to this directory.
from google.colab import drive
drive.mount('/content/drive/')

%cd /content/drive/MyDrive/RL/multi-reward-medical-reasoning
!ls

# Alternative project location: use these two lines instead of the two above
# if the project folder lives under this path in your Drive.
#%cd /content/drive/MyDrive/Reinforcement_Learning/RL_Project/multi-reward-medical-reasoning
#!ls

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).
/content/drive/.shortcut-targets-by-id/1OhvtbvB42jF4ERuZvuDpJ8-HFw4kMXsD/multi-reward-medical-reasoning
backup.ipynb		      qwen3-1.7b-base_sft
base.ipynb		      qwen3-1.7b_grpo
grpo_trainer_lora_model       qwen3-1.7b_sft
huggingface_tokenizers_cache  rag_demo.ipynb
instruct.ipynb		      unsloth_compiled_cache
ppo_baselines		      unsloth_training_checkpoints
qwen3-1.7b-base_grpo	      wandb


In [None]:
from unsloth import FastLanguageModel
from vllm import SamplingParams

import gc
import re
import time
import threading
import numpy as np
import pandas as pd
import gradio as gr

import torch
import torch.nn.functional as F
from safetensors import safe_open
from datasets import load_dataset
from transformers import TextIteratorStreamer
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Model Setup

In [None]:
# Hugging Face id of the base checkpoint (alternative: 'unsloth/Qwen3-1.7B').
model_id = 'unsloth/Qwen3-1.7B-Base'
# Repo name without the organisation prefix, lower-cased (e.g. 'qwen3-1.7b-base').
model_name = model_id.rsplit('/', 1)[-1].lower()
# Maximum sequence length in tokens; increase for longer reasoning traces.
max_seq_length = 2048
# LoRA rank budget: larger = smarter, but slower.
lora_rank = 32
# Folder holding the saved GRPO LoRA adapter.
lora_path = './' + model_name + '_grpo'

In [None]:
print(f'{lora_path}')  # confirm which adapter folder will be loaded

./qwen3-1.7b-base_grpo


In [None]:
# Load the saved GRPO checkpoint (16-bit, with a vLLM engine attached), register the LoRA
# adapter, switch the model to inference mode, then load the sentence encoder for retrieval.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=lora_path,            # reload the GRPO weights
    max_lora_rank=lora_rank,
    max_seq_length=max_seq_length,
    fast_inference=True,             # attach a vLLM engine for fast inference
    load_in_4bit=False,              # 16-bit weights for LoRA
    gpu_memory_utilization=0.8,      # lower this if the GPU runs out of memory
)
# NOTE(review): `lora_request` is not referenced by the later cells, which call `model.generate`;
# presumably it was meant for vLLM-side generation (`fast_generate(..., lora_request=...)`) — confirm.
lora_request = model.load_lora(lora_path)
FastLanguageModel.for_inference(model)
embed_model = SentenceTransformer('all-MiniLM-L6-v2')

INFO 11-03 01:16:16 [vllm_utils.py:694] Unsloth: Patching vLLM v1 graph capture
INFO 11-03 01:16:16 [vllm_utils.py:722] Unsloth: Patching vLLM v0 graph capture
==((====))==  Unsloth 2025.10.12: Fast Qwen3 patching. Transformers: 4.56.2. vLLM: 0.10.2.
   \\   /|    NVIDIA A100-SXM4-80GB. Num GPUs = 1. Max memory: 79.318 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu126. CUDA: 8.0. CUDA Toolkit: 12.6. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.32.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/Qwen3-1.7B-Base with actual GPU utilization = 13.83%
Unsloth: Your GPU has CUDA compute capability 8.0 with VRAM = 79.32 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 2048. Num Sequences = 192.
Unsloth: vLLM's KV Cache can use up to 7.7 GB. Also swap space = 6 GB.
Unsloth: Not an error, but `device` is 

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 11-03 01:16:39 [default_loader.py:268] Loading weights took 1.19 seconds
INFO 11-03 01:16:41 [gpu_model_runner.py:2392] Model loading took 3.2841 GiB and 1.968943 seconds
INFO 11-03 01:16:52 [backends.py:539] Using cache directory: /root/.cache/vllm/torch_compile_cache/1d9678ad38/rank_0_0/backbone for vLLM's torch.compile
INFO 11-03 01:16:52 [backends.py:550] Dynamo bytecode transform time: 9.57 s
INFO 11-03 01:16:57 [backends.py:161] Directly load the compiled graph(s) for dynamic shape from the cache, took 4.169 s
INFO 11-03 01:16:59 [monitor.py:34] torch.compile takes 9.57 s in total
INFO 11-03 01:17:02 [gpu_worker.py:298] Available KV cache memory: 6.64 GiB
INFO 11-03 01:17:03 [kv_cache_utils.py:864] GPU KV cache size: 62,128 tokens
INFO 11-03 01:17:03 [kv_cache_utils.py:868] Maximum concurrency for 2,048 tokens per request: 30.34x
INFO 11-03 01:17:03 [vllm_utils.py:699] Unsloth: Running patched vLLM v1 `capture_model`.
INFO 11-03 01:17:03 [vllm_utils.py:699] Unsloth: Running 

Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|██████████| 51/51 [00:07<00:00,  6.86it/s]
Capturing CUDA graphs (decode, FULL): 100%|██████████| 27/27 [00:03<00:00,  7.00it/s]

INFO 11-03 01:17:14 [gpu_model_runner.py:3118] Graph capturing finished in 11 secs, took 0.67 GiB
INFO 11-03 01:17:14 [vllm_utils.py:706] Unsloth: Patched vLLM v1 graph capture finished in 11 secs.
INFO 11-03 01:17:14 [vllm_utils.py:706] Unsloth: Patched vLLM v1 graph capture finished in 11 secs.





INFO 11-03 01:17:16 [gpu_worker.py:391] Free memory on device (14.42/79.32 GiB) on startup. Desired GPU memory utilization is (0.13826704495174866, 10.97 GiB). Actual usage is 3.28 GiB for weight, 1.04 GiB for peak activation, 0.0 GiB for non-torch memory, and 0.67 GiB for CUDAGraph memory. Replace gpu_memory_utilization config with `--kv-cache-memory=6247239372` to fit into requested memory, or `--kv-cache-memory=9958752768` to fully utilize gpu memory. Current kv cache memory in use is 7125946060 bytes.
INFO 11-03 01:17:17 [core.py:218] init engine (profile, create kv cache, warmup model) took 35.43 seconds
INFO 11-03 01:17:18 [llm.py:295] Supported_tasks: ('generate',)
INFO 11-03 01:17:18 [__init__.py:36] No IOProcessor plugins requested by the model
Unsloth: Just some info: will skip parsing ['post_layernorm', 'q_norm', 'post_attention_layernorm', 'attention_norm', 'pre_feedforward_layernorm', 'layer_norm1', 'norm2', 'post_feedforward_layernorm', 'norm1', 'layer_norm2', 'k_norm', '

# Prepare the knowledge base


In [None]:
# Create a knowledge base by generating embeddings for the combined question-answer pairs from the dataset.
# Knowledge base for retrieval: each row's patient question ('input') and doctor answer ('output')
# are joined into one passage and embedded; row i of `embeddings` corresponds to `dataset.iloc[i]`.
dataset = load_dataset('lavita/ChatDoctor-HealthCareMagic-100k', split='train').to_pandas()
# str.cat with na_rep='' keeps a missing question or answer from turning the whole passage into
# NaN (plain `+` would), which the encoder would reject.
dataset['combined'] = dataset['input'].str.cat(dataset['output'], sep=' ', na_rep='')
embeddings = embed_model.encode(dataset['combined'].tolist(), show_progress_bar=True, batch_size=128)
dataset.head()

Batches:   0%|          | 0/877 [00:00<?, ?it/s]

Unnamed: 0,instruction,input,output,combined
0,"If you are a doctor, please answer the medical...",I woke up this morning feeling the whole room ...,"Hi, Thank you for posting your query. The most...",I woke up this morning feeling the whole room ...
1,"If you are a doctor, please answer the medical...",My baby has been pooing 5-6 times a day for a ...,Hi... Thank you for consulting in Chat Doctor....,My baby has been pooing 5-6 times a day for a ...
2,"If you are a doctor, please answer the medical...","Hello, My husband is taking Oxycodone due to a...","Hello, and I hope I can help you today.First, ...","Hello, My husband is taking Oxycodone due to a..."
3,"If you are a doctor, please answer the medical...",lump under left nipple and stomach pain (male)...,HI. You have two different problems. The lump ...,lump under left nipple and stomach pain (male)...
4,"If you are a doctor, please answer the medical...",I have a 5 month old baby who is very congeste...,Thank you for using Chat Doctor. I would sugge...,I have a 5 month old baby who is very congeste...
...,...,...,...,...
112160,"If you are a doctor, please answer the medical...",im 25 years old ..i started using mtp kit on 5...,"Hello, Thanks for letting us know your health ...",im 25 years old ..i started using mtp kit on 5...
112161,"If you are a doctor, please answer the medical...",My 5 year old son has been coughing for a mont...,As you have mentioned in your history that you...,My 5 year old son has been coughing for a mont...
112162,"If you are a doctor, please answer the medical...",My toes on right foot more than left are numb ...,Hi. The numbness and blue discoloration could ...,My toes on right foot more than left are numb ...
112163,"If you are a doctor, please answer the medical...","I was diagnosis with pleurisy last Tuesday, an...",Thanks for your question on Chat Doctor. Treat...,"I was diagnosis with pleurisy last Tuesday, an..."


# Retrieval and Response Generation


In [None]:
def retrieve_relevant_contexts(query: str, k: int = 3) -> list:
    ''' Retrieves the k reference cases most similar to a query, by cosine similarity.

    Uses the module-level `embed_model`, `embeddings` and `dataset`; row i of `embeddings`
    must correspond to `dataset.iloc[i]`.

    Args:
        query (str): The user's medical query.
        k (int): The number of relevant contexts to retrieve. Values <= 0 return an empty
            list; values larger than the knowledge base return every entry.

    Returns:
        list: Up to k dicts with keys 'question', 'answer' and 'similarity' (float),
        ordered from most to least similar.
    '''
    query_embedding = embed_model.encode([query])[0]
    similarities = np.asarray(cosine_similarity([query_embedding], embeddings)[0])

    k = min(k, len(similarities))
    if k <= 0:
        # Guard: a `[-k:]` slice with k == 0 would select every row.
        return []

    # argpartition isolates the k best rows in O(n); only those k are then sorted, best first.
    top_k_indices = np.argpartition(-similarities, k - 1)[:k]
    top_k_indices = top_k_indices[np.argsort(-similarities[top_k_indices], kind='stable')]

    contexts = []
    for idx in top_k_indices:
        row = dataset.iloc[idx]
        contexts.append({
            'question': row['input'],
            'answer': row['output'],
            'similarity': float(similarities[idx]),
        })
    return contexts

In [None]:
def generate_structured_response(query: str, contexts: list, max_completion_length=1024) -> str:
    ''' Generates a detailed response to `query` grounded in the retrieved reference cases.

    Uses the module-level `model` and `tokenizer`.

    Args:
        query (str): The user's medical query.
        contexts (list): Dicts with 'question' and 'answer' keys (e.g. from
            `retrieve_relevant_contexts`).
        max_completion_length (int): Maximum number of new tokens to generate.

    Returns:
        str: Only the newly generated text (the prompt is excluded), stripped of
        surrounding whitespace.
    '''
    context_prompt = '\n'.join(
        f"Reference {i + 1}:\nQuestion: {ctx['question']}\nAnswer: {ctx['answer']}"
        for i, ctx in enumerate(contexts)
    )
    prompt = f'''Based on the following references and your medical knowledge, provide a detailed response:

References:
{context_prompt}

Question: {query}

By considering:
1. The key medical concepts in the question.
2. How the reference cases relate to this question.
3. What medical principles should be applied.
4. Any potential complications or considerations.

Give the final response:
'''
    messages = [{'role': 'user', 'content': prompt}]
    inputs = tokenizer(
        tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False),
        return_tensors='pt',
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_completion_length,
            temperature=0.6,                # Balance creativity and consistency
            top_p=0.95,                     # Nucleus sampling for quality
            top_k=20,
            do_sample=True,                 # Enable sampling for varied reasoning paths
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,         # Reduce repetitive reasoning steps
            # length_penalty / early_stopping are not passed: they only apply to beam search,
            # and this call samples without setting num_beams.
        )

    # Decode only the tokens produced after the prompt. Splitting the full decoded text on
    # 'Give the final response:' could leave chat-template text (e.g. a role header) in the
    # result, and would truncate the answer if the model repeated that phrase.
    prompt_length = inputs['input_ids'].shape[-1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

# Putting It All Together

In [None]:
# Optional smoke test of retrieval + generation outside the chat UI.
# Set RUN_DEMO = True to run it; it performs a full model generation.
RUN_DEMO = False

if RUN_DEMO:
    # Alternative query: "I've been experiencing persistent headaches and dizziness for the past week. What could be the cause?"
    query = "I'm 14 and have really bad acne. A doctor put me on antibiotics for the past 3 weeks but it doesn't seem to help.  What is causing it?"
    print('Query:', query, '\n\n===== Relevant Contexts ===== ')

    contexts = retrieve_relevant_contexts(query)
    for i, ctx in enumerate(contexts, 1):
        print(f"Reference {i} (Similarity: {ctx['similarity']:.3f}):")
        print(f"Q: {ctx['question']}")
        print(f"A: {ctx['answer']}\n")

    print('===== Generated Response =====')
    response = generate_structured_response(query, contexts)
    print(response)

'\nquery = "I\'m 14 and have really bad acne. A doctor put me on antibiotics for the past 3 weeks but it doesn\'t seem to help.  What is causing it?"\nprint(\'Query:\', query, \'\n\n===== Relevant Contexts ===== \')\n\ncontexts = retrieve_relevant_contexts(query)\nfor i, ctx in enumerate(contexts, 1):\n    print(f"Reference {i} (Similarity: {ctx[\'similarity\']:.3f}):")\n    print(f"Q: {ctx[\'question\']}")\n    print(f"A: {ctx[\'answer\']}\n")\n\nprint(\'===== Generated Response =====\')\nresponse = generate_structured_response(query, contexts)\nprint(response)\n'

# Interface

In [None]:
# Tags delimiting the structured medical-reasoning output the system prompt asks the model for.
REASONING_START, REASONING_END = '<THINK>', '</THINK>'  # step-by-step reasoning section
ANSWER_START, ANSWER_END = '<ANSWER>', '</ANSWER>'      # final answer section

# System prompt for RAG + medical reasoning; it names the tags above so the reply can be
# split into its reasoning and answer parts.
SYSTEM_PROMPT = f'''You are a medical reasoning assistant. When given a medical problem and relevant references:
1. Show your step-by-step complex reasoning (including reflection, backtracking, alternative paths, and how references relate) between {REASONING_START} and {REASONING_END}.
2. Provide your final answer between {ANSWER_START} and {ANSWER_END}.
3. Be precise and show all deliberation steps clearly, considering the following:
- Medical aliases/synonyms.
- The key medical concepts in the question.
- How the reference cases relate to this question.
- What medical principles should be applied.
- Any potential complications or considerations.'''
print(SYSTEM_PROMPT)

You are a medical reasoning assistant. When given a medical problem and relevant references:
1. Show your step-by-step complex reasoning (including reflection, backtracking, alternative paths, and how references relate) between <THINK> and </THINK>.
2. Provide your final answer between <ANSWER> and </ANSWER>.
3. Be precise and show all deliberation steps clearly, considering the following:
- Medical aliases/synonyms.
- The key medical concepts in the question.
- How the reference cases relate to this question.
- What medical principles should be applied.
- Any potential complications or considerations.


In [None]:
class DrChat_Manager:
    ''' Gradio chat handler that answers symptom descriptions with retrieval-augmented generation.

    The raw text of each user query and model reply is appended to `query_history`. That list
    lives on the instance, so one instance shared by a Gradio app is shared by every connected
    session. NOTE(review): move it into per-session state (e.g. gr.State) if several users chat
    at once.

    Args:
        mode (str): 'stream' streams the raw generation, then replaces it with collapsible
            "Contexts" / "Thinking" sections followed by the final answer. Any other value
            yields only the final answer plus an offer to show the reasoning.
        max_new_tokens (int): Maximum number of tokens generated per reply.
        k (int): Number of reference cases retrieved for each query.
    '''

    # Short replies that accept / decline the offer to show the reasoning behind the last answer.
    AFFIRMATIVE_REPLIES = ('y', 'yes', 'yep', 'ok')
    NEGATIVE_REPLIES = ('n', 'no', 'nope')

    def __init__(self, mode='stream', max_new_tokens=1024, k=3):
        self.query_history = []
        self.mode = mode
        self.max_new_tokens = max_new_tokens
        self.k = k

    def parse_response(self, full_response):
        ''' Splits a raw generation into (reasoning text, answer text).

        The reasoning is the text before REASONING_END, starting after REASONING_START when
        that opening tag is present (the model may omit it). The answer is the text after
        ANSWER_START up to ANSWER_END, or up to the end of the text so that a reply cut off by
        the token limit still yields an answer. Tags are matched case-insensitively.

        Returns:
            tuple[str, str]: A placeholder replaces a missing reasoning section. A missing
            answer section falls back to the text after the reasoning, else the whole response.
        '''
        flags = re.DOTALL | re.IGNORECASE
        think_open, think_close = re.escape(REASONING_START), re.escape(REASONING_END)
        answer_open, answer_close = re.escape(ANSWER_START), re.escape(ANSWER_END)

        # `(?:.*?<open>)?` consumes any preamble plus the opening tag, so neither leaks into the group.
        think_match = re.search(rf'^(?:.*?{think_open})?\s*(.*?)\s*{think_close}', full_response, flags)
        answer_match = re.search(rf'{answer_open}\s*(.*?)\s*(?:{answer_close}|\Z)', full_response, flags)

        think_text = think_match.group(1).strip() if think_match else ''
        answer_text = answer_match.group(1).strip() if answer_match else ''
        if not answer_text:
            remainder = full_response[think_match.end():] if think_match else full_response
            answer_text = remainder.strip() or full_response.strip()
        return think_text or "No explicit thinking section generated.", answer_text

    def chat_manager(self, message, history):
        ''' Gradio `fn` for gr.ChatInterface: yields successive versions of the reply to `message`.

        Args:
            message (str): The latest user message.
            history: Chat history supplied by Gradio; not used (each query is answered on its own).

        Yields:
            str: The stored reasoning (after a "yes"-style reply), a short acknowledgement
            (after a "no"-style reply), or the streamed / final answer to a new query.
        '''
        # Normalise so that replies such as "Yes." or " no " are recognised.
        reply = message.strip().lower().rstrip('.!')
        if reply in self.AFFIRMATIVE_REPLIES:
            if self.query_history and self.query_history[-1]['role'] == 'assistant':
                think_text, _ = self.parse_response(self.query_history[-1]['content'])
                yield think_text
            else:
                # Nothing to explain yet: no reply has been generated, or the last one never finished.
                yield "There is no previous diagnosis to explain yet. What are your symptoms?"
        elif reply in self.NEGATIVE_REPLIES:
            yield "Ok, anything else you need help with?"
        else:
            yield from self._answer_query(message)

    def _answer_query(self, message):
        ''' Retrieves references for `message`, streams the model's generation and yields the reply. '''
        self.query_history.append({'role': 'user', 'content': message})

        contexts = retrieve_relevant_contexts(message, k=self.k)  # retrieval uses the current message only
        context_prompt = '\n'.join(
            f"Reference {i + 1}:\nQuestion: {ctx['question']}\nAnswer: {ctx['answer']}\n"
            for i, ctx in enumerate(contexts)
        )
        conversation = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': self._build_user_prompt(message, context_prompt)},
        ]
        # Render into a single string; add_generation_prompt opens the assistant turn
        # (no reasoning tag is pre-filled).
        text = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
        inputs = tokenizer(text, return_tensors='pt').to(model.device)

        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            inputs,
            max_new_tokens=self.max_new_tokens,
            temperature=0.6,                # Balance creativity and consistency
            top_p=0.95,                     # Nucleus sampling for quality
            top_k=20,
            do_sample=True,                 # Enable sampling for varied reasoning paths
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,         # Reduce repetitive reasoning steps
            streamer=streamer,
            # length_penalty / early_stopping are not passed: they only apply to beam search.
        )

        # generate() runs in a worker thread while this generator drains the streamer.
        # NOTE(review): if generate() raises, the streamer never finishes; consider a streamer timeout.
        start_time = time.time()
        thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        full_response = ""
        for new_text in streamer:
            full_response += new_text
            if self.mode == 'stream':
                yield full_response  # Stream partial response in UI
        thread.join()
        thinking_time = time.time() - start_time

        self.query_history.append({'role': 'assistant', 'content': full_response})
        think_text, answer_text = self.parse_response(full_response)
        yield self._format_reply(context_prompt, think_text, answer_text, thinking_time)

    @staticmethod
    def _build_user_prompt(message, context_prompt):
        ''' Wraps the query and its retrieved references in the user-turn prompt.

        NOTE(review): the leading blank lines, 16-space indentation and trailing spaces reproduce
        the prompt exactly as this notebook has always sent it (presumably the format the model
        saw during training); confirm before tidying them.
        '''
        indent = ' ' * 16
        return (
            '\n\nBased on the following references and your medical knowledge, provide a detailed response:\n'
            '\n'
            f'{indent}References:\n'
            f'{indent}```\n'
            f'{indent}{context_prompt}\n'
            f'{indent}```\n'
            f'{indent}Question: {message}\n'
            f'{indent}'
        )

    def _format_reply(self, context_prompt, think_text, answer_text, thinking_time):
        ''' Builds the final chat message for the current mode. '''
        import html  # escape text placed inside raw HTML so tags in it show literally

        if self.mode != 'stream':
            return answer_text + '\nWould you like the rationale for the diagnosis?'
        # Collapsible HTML sections for the references and the reasoning.
        return (
            f'<details><summary>Contexts</summary>'
            f'<p>{html.escape(context_prompt)}</p>'
            f'</details>\n\n'
            f'<details><summary>Thinking (took {thinking_time:.2f} seconds)</summary>'
            f'<p>{html.escape(think_text)}</p>'
            f'</details>\n\n'
            f'Final diagnosis\n{answer_text}'
        )



In [None]:
## DoctorRL chatbot interface
# One handler instance serves the app; 'stream' mode shows contexts and reasoning inline.
drRL = DrChat_Manager(mode='stream')

example_qs = [
    'I have pain in my chest and left arm and I feel dizzy.',
    'I have blocked nose, cough, slightly high temperature and am finding it difficult to eat.',
    'I have a fever, stiff neck and sore eyes from bright lights.',
    'I have a burning pain in the back of my right leg and it feels weak.',
]

chat_ui = gr.ChatInterface(
    fn=drRL.chat_manager,
    type="messages",
    textbox=gr.Textbox(placeholder='What seems to be the problem, please describe your symptoms', container=False, scale=7),
    title='Doctor RL',
    description='Doctor RL is for academic purposes only and may make errors.  See a qualified medical professional for health advice.',
    examples=example_qs,
    # type="messages" keeps the Chatbot's message format consistent with the ChatInterface above;
    # the placeholder's <strong> tag is closed so the markup is well-formed.
    chatbot=gr.Chatbot(type="messages", placeholder="<strong>Welcome to Doctor RL.  What are your symptoms?</strong>", height=650),
    flagging_mode="manual",
    flagging_options=["Like", "Wrong", "Inappropriate"],
    save_history=True
)

chat_ui.launch(share=True, debug=True)



  chatbot=gr.Chatbot(placeholder="<strong>Welcome to Doctor RL.  What are your symptoms?", height=650),


Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://73f8f3023ef7d1e79f.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://73f8f3023ef7d1e79f.gradio.live


