In [1]:
pip install torch transformers

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
from typing import List, Dict
import random

# Initialize the Llama-3.2-3B-Instruct model
model_name = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Simple medical knowledge base for simulation
MEDICAL_KNOWLEDGE = {
    "fever": {
        "symptoms": ["high temperature", "sweating", "headache"],
        "tests": ["temperature", "blood test"],
        "diagnosis": "Possible viral infection"
    },
    "chest_pain": {
        "symptoms": ["chest discomfort", "shortness of breath"],
        "tests": ["ECG", "chest X-ray"],
        "diagnosis": "Possible cardiac issue"
    }
}

class PatientAgent:
    def __init__(self):
        self.condition = random.choice(list(MEDICAL_KNOWLEDGE.keys()))
        self.symptoms = MEDICAL_KNOWLEDGE[self.condition]["symptoms"]

    def respond(self, question: str) -> str:
        prompt = f"Patient with {self.condition}. Question from doctor: {question}\nRespond as the patient would:"
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(**inputs, max_length=100, temperature=0.7)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response

class MeasurementAgent:
    def perform_test(self, test_name: str, condition: str) -> str:
        if test_name in MEDICAL_KNOWLEDGE[condition]["tests"]:
            return f"{test_name} result: Abnormal findings consistent with {condition}"
        return f"{test_name} result: Normal"

class DoctorAgent:
    def __init__(self):
        self.memory = []  # Experience Records Buffer
        self.medical_records = []  # Medical Records Buffer

    def ask_question(self, patient_response: str) -> str:
        prompt = f"""Doctor in a clinical setting. Patient response: "{patient_response}"
        Based on this, ask a relevant follow-up question to gather more information:"""
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(**inputs, max_length=500, temperature=0.7)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    def request_test(self, symptoms: str) -> str:
        prompt = f"""Doctor analyzing symptoms: "{symptoms}"
        Decide which medical test to request (e.g., temperature, ECG, X-ray):"""
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(**inputs, max_length=30, temperature=0.7)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    def make_diagnosis(self, conversation: List[Dict], test_results: List[str]) -> str:
        # Simple chain-of-thought reasoning
        context = "\n".join([f"{turn['role']}: {turn['content']}" for turn in conversation])
        prompt = f"""Doctor analyzing conversation:
        {context}
        Test results: {', '.join(test_results)}

        Step 1: Review symptoms and test results
        Step 2: Consider possible conditions
        Step 3: Make a diagnosis

        Provide the diagnosis:"""
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(**inputs, max_length=100, temperature=0.7)
        diagnosis = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Store in memory
        self.memory.append({
            "conversation": conversation,
            "test_results": test_results,
            "diagnosis": diagnosis
        })
        return diagnosis

class MedAgentSim:
    def __init__(self):
        self.doctor = DoctorAgent()
        self.patient = PatientAgent()
        self.measurement = MeasurementAgent()
        self.conversation = []

    def run_simulation(self, num_turns: int = 3):
        print(f"Starting simulation. Patient condition: {self.patient.condition}")

        # Initial patient statement
        initial_response = self.patient.respond("What seems to be the problem?")
        self.conversation.append({"role": "patient", "content": initial_response})
        print(f"Patient: {initial_response}")

        # Conversation phase
        for _ in range(num_turns):
            doctor_question = self.doctor.ask_question(initial_response)
            self.conversation.append({"role": "doctor", "content": doctor_question})
            print(f"Doctor: {doctor_question}")

            patient_response = self.patient.respond(doctor_question)
            self.conversation.append({"role": "patient", "content": patient_response})
            print(f"Patient: {patient_response}")

            initial_response = patient_response  # Update for next turn

        # Test request phase
        test_name = self.doctor.request_test(" ".join(self.patient.symptoms))
        print(f"Doctor requests test: {test_name}")
        test_result = self.measurement.perform_test(test_name, self.patient.condition)
        print(f"Measurement Agent: {test_result}")

        # Diagnosis phase
        diagnosis = self.doctor.make_diagnosis(self.conversation, [test_result])
        print(f"Doctor's Diagnosis: {diagnosis}")

# Run the simulation
if __name__ == "__main__":
    sim = MedAgentSim()
    sim.run_simulation()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Starting simulation. Patient condition: chest_pain


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Patient: Patient with chest_pain. Question from doctor: What seems to be the problem?
Respond as the patient would: I've been having chest pain for a few days now, it's been getting worse. At first, I thought it was just indigestion, but then it started to feel like a squeezing sensation in my chest, and now it's more of a sharp pain. It's usually when I exert myself, like climbing stairs or doing yard work, but it's also been happening
