## multi-LLM이 동시에 같은 작업 수행, 잘 안 됨

In [None]:
from typing import List, Dict, Any, TypedDict, Annotated
from langgraph.graph import StateGraph
import asyncio
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

In [3]:
# Define your state structure
# Graph state for the one-prompt-at-a-time pipeline: the full prompt list,
# the index and payload of the prompt being answered, and the collected
# responses keyed by model name.
State = TypedDict(
    "State",
    {
        "prompt_index": int,
        "prompts": List[Dict[str, Any]],
        "results": Dict[str, List[Any]],
        "current_prompt": Dict[str, Any],
    },
)

In [4]:
# Functions to invoke each model
from functools import lru_cache


@lru_cache(maxsize=None)
def _load_hf_model(model_id: str):
    """Load ``(tokenizer, model)`` for *model_id* once per process and reuse it.

    Without the cache every call would re-read a multi-billion-parameter
    checkpoint from disk before generating a single token.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    return tokenizer, model


def _build_chat_messages(system_prompt: str, user_input: str, supports_system_role: bool) -> List[Dict[str, str]]:
    """Return the chat messages for one prompt.

    An empty system prompt is omitted. When the model's chat template has no
    ``system`` role, the system prompt is prepended to the user turn instead.
    """
    if not system_prompt:
        return [{"role": "user", "content": user_input}]
    if not supports_system_role:
        return [{"role": "user", "content": f"{system_prompt}\n\n{user_input}"}]
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_input},
    ]


def _generate_reply(model_id: str, messages: List[Dict[str, str]], max_new_tokens: int) -> str:
    """Run blocking generation for *messages* and return only the new text."""
    tokenizer, model = _load_hf_model(model_id)
    # add_generation_prompt=True appends the assistant-turn header so the model
    # answers the user instead of continuing the user's message.
    formatted_input = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Chat templates normally render the BOS token themselves, so the tokenizer
    # must not prepend a second one.
    inputs = tokenizer(formatted_input, return_tensors="pt", add_special_tokens=False).to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + completion; decode only the completion.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)


async def _invoke_hf_model(state: State, model_id: str, result_key: str,
                           supports_system_role: bool = True, max_new_tokens: int = 512):
    """Answer ``state["current_prompt"]`` with *model_id* and append the reply.

    Returns ``{"results": ...}``: a copy of ``state["results"]`` in which
    ``results[result_key]`` has the new response appended.
    """
    current_prompt = state["current_prompt"]
    messages = _build_chat_messages(
        current_prompt.get("system_prompt", ""),
        current_prompt.get("input", ""),
        supports_system_role,
    )
    # generate() is synchronous and long-running; run it in a worker thread
    # so this coroutine does not block the event loop.
    response = await asyncio.to_thread(_generate_reply, model_id, messages, max_new_tokens)
    results = state["results"]
    return {"results": {**results, result_key: results.get(result_key, []) + [response]}}


async def invoke_gemma3(state: State):
    """Answer the current prompt with Gemma; the reply goes to ``results["gemma3"]``."""
    # Gemma chat templates reject a "system" role, so the system prompt is
    # folded into the user turn.
    # NOTE(review): this checkpoint is Gemma 2 (base), not Gemma 3 as the
    # name suggests, and base checkpoints may lack a chat template - confirm.
    return await _invoke_hf_model(state, "google/gemma-2-9b", "gemma3", supports_system_role=False)


async def invoke_llama3(state: State):
    """Answer the current prompt with Llama; the reply goes to ``results["llama3"]``."""
    # NOTE(review): the Hub id for Llama 3 is usually "meta-llama/Meta-Llama-3-8B"
    # (or "meta-llama/Llama-3.1-8B" for 3.1) - verify this id resolves.
    return await _invoke_hf_model(state, "meta-llama/Llama-3-8B", "llama3")


async def invoke_deepseek(state: State):
    """Answer the current prompt with DeepSeek; the reply goes to ``results["deepseek"]``."""
    # NOTE(review): this is DeepSeek-Coder base, not DeepSeek-R1; a base
    # checkpoint may not ship a chat template - confirm.
    return await _invoke_hf_model(state, "deepseek-ai/deepseek-coder-7b-base", "deepseek")


In [None]:
# Define a function to select the next prompt or end
def should_continue(state: State) -> str:
    """Routing function: ``"end"`` at or past the last prompt, else ``"continue"``.

    ``"continue"`` means a next prompt exists at ``prompt_index + 1``.
    """
    last_index = len(state["prompts"]) - 1
    return "end" if state["prompt_index"] >= last_index else "continue"

# Define a function to prepare the next prompt
def prepare_next_prompt(state: State) -> State:
    """Return a copy of *state* advanced to the following prompt.

    ``prompt_index`` is incremented and ``current_prompt`` is set to the prompt
    at the new index. The input state is not mutated. Raises ``IndexError``
    when there is no next prompt.
    """
    advanced = dict(state)
    advanced["prompt_index"] = state["prompt_index"] + 1
    advanced["current_prompt"] = state["prompts"][advanced["prompt_index"]]
    return advanced

## 다시 생성한 코드

In [1]:
from langchain.chat_models import init_chat_model
from langgraph.graph import StateGraph, END
from typing import TypedDict, List, Dict, Any
import json


In [2]:
# Define the models

# Chat models used in the comparison. o3-mini is created with the provider's
# default sampling; every other model is pinned to temperature 0. The last
# three are served through the Ollama provider.
o3_mini = init_chat_model("o3-mini", model_provider="openai")
claude_sonnet = init_chat_model("claude-3-5-sonnet-latest", model_provider="anthropic", temperature=0)
gemma3 = init_chat_model("gemma3:12b", model_provider="ollama", temperature=0)
llama3_1 = init_chat_model("llama3.1:latest", model_provider="ollama", temperature=0)
deepseek_r1 = init_chat_model("deepseek-r1:8b", model_provider="ollama", temperature=0)

In [3]:

file_path = "biggen_bench_instruction_idx0.json"

def load_benchmark_data(file_path: str) -> List[Dict]:
    """Read *file_path* as UTF-8 JSON and return the parsed document."""
    with open(file_path, encoding='utf-8') as handle:
        return json.loads(handle.read())

def prepare_prompts(benchmark_data: List[Dict]) -> List[Dict]:
    """Build the model-facing prompt for every benchmark record.

    Only the fields a model is shown are copied (in a fixed order);
    ``reference_answer`` and ``score_rubric`` are omitted. A record missing
    any of the copied fields raises ``KeyError``.
    """
    visible_fields = ("id", "capability", "task", "instance_idx", "system_prompt", "input")
    return [{field: record[field] for field in visible_fields} for record in benchmark_data]

def prepare_rubric(benchmark_data: List[Dict]) -> List[Dict]:
    """Collect the grading material for each benchmark record.

    Returns one dict per record, in input order, holding its ``id``,
    ``reference_answer`` and ``score_rubric``.
    """
    grading_fields = ("id", "reference_answer", "score_rubric")
    return [{field: entry[field] for field in grading_fields} for entry in benchmark_data]

In [4]:
# Load the benchmark once, then split it into what the models see (prompts)
# and what the grader uses (rubric).
benchmark_data = load_benchmark_data(file_path)
prompts, rubric = prepare_prompts(benchmark_data), prepare_rubric(benchmark_data)

In [8]:

import asyncio

from langgraph.graph import START

# Names of the per-model nodes that answer each prompt in parallel.
MODEL_NODES = ("process_gemma3", "process_llama3_1", "process_deepseek_r1")


class State(TypedDict):
    """State of the parallel graph: each round, three models answer one prompt.

    Each ``*_results`` key is written by exactly one model node, and
    ``processed_count`` only by ``merge_and_continue``. As a result, no key
    receives more than one write in the same step; LangGraph rejects multiple
    writes for keys without a reducer.
    """
    prompts: List[Dict[str, Any]]        # prompts to answer, in order
    processed_count: int                 # index of the prompt answered this round
    gemma3_results: List[Dict[str, Any]]
    llama3_1_results: List[Dict[str, Any]]
    deepseek_r1_results: List[Dict[str, Any]]


def _run_model_on_current_prompt(state: State, model, model_name: str) -> Dict[str, Any]:
    """Answer ``prompts[processed_count]`` with *model*.

    Returns a partial state update that contains only
    ``f"{model_name}_results"``, with one record appended:
    ``{"id", "model_name", "response"}`` on success, or
    ``{"id", "model_name", "error"}`` if the call raised. Returns ``{}`` once
    every prompt has been processed.
    """
    if state["processed_count"] >= len(state["prompts"]):
        return {}

    prompt = state["prompts"][state["processed_count"]]
    system_prompt = prompt.get('system_prompt', '')
    user_input = prompt.get('input', '')

    # The system prompt must be sent as a message. ``config=`` is a
    # RunnableConfig (callbacks, tags, metadata, ...), not model input, so a
    # system prompt passed there never reaches the model.
    messages = [("system", system_prompt)] if system_prompt else []
    messages.append(("human", user_input))

    try:
        response = model.invoke(messages)
        result = {
            "id": prompt.get('id', ''),
            "model_name": model_name,
            "response": response
        }
    except Exception as e:
        # Best effort: record the failure for this prompt and keep going.
        result = {
            "id": prompt.get('id', ''),
            "model_name": model_name,
            "error": str(e)
        }

    results_key = f"{model_name}_results"
    return {results_key: state[results_key] + [result]}


def process_gemma3(state: State) -> Dict[str, Any]:
    """Answer the current prompt with ``gemma3``; updates ``gemma3_results`` only."""
    return _run_model_on_current_prompt(state, gemma3, "gemma3")


def process_llama3_1(state: State) -> Dict[str, Any]:
    """Answer the current prompt with ``llama3_1``; updates ``llama3_1_results`` only."""
    return _run_model_on_current_prompt(state, llama3_1, "llama3_1")


def process_deepseek_r1(state: State) -> Dict[str, Any]:
    """Answer the current prompt with ``deepseek_r1``; updates ``deepseek_r1_results`` only."""
    return _run_model_on_current_prompt(state, deepseek_r1, "deepseek_r1")


def merge_and_continue(states) -> Dict:
    """Fan-in node: run after all model nodes and advance to the next prompt.

    Inside the graph this node receives a single state dict that already
    contains every model node's update, so it only increments
    ``processed_count``. A list of per-model state dicts is still accepted for
    direct callers; it is merged into a copy, and the inputs are not mutated.
    """
    if isinstance(states, list):
        merged = dict(states[0])
        for other in states[1:]:
            for key in ("gemma3_results", "llama3_1_results", "deepseek_r1_results"):
                if other.get(key):
                    merged[key] = other[key]
        return {**merged, "processed_count": merged["processed_count"] + 1}
    return {"processed_count": states["processed_count"] + 1}


def check_completion(state: State) -> str:
    """Return ``"end"`` when every prompt has been processed, else ``"continue"``."""
    if state["processed_count"] >= len(state["prompts"]):
        return "end"
    else:
        return "continue"


def _route_next_round(state: State):
    """Fan out to every model node while prompts remain; otherwise go to END."""
    if check_completion(state) == "end":
        return END
    return list(MODEL_NODES)


def create_parallel_processing_graph():
    """Build and compile the graph.

    Each round fans out to the three model nodes, then fans in to
    ``merge_and_continue``, and repeats until every prompt is answered.
    """
    workflow = StateGraph(State)

    workflow.add_node("process_gemma3", process_gemma3)
    workflow.add_node("process_llama3_1", process_llama3_1)
    workflow.add_node("process_deepseek_r1", process_deepseek_r1)
    workflow.add_node("merge_and_continue", merge_and_continue)

    # Edges connect node *names*: START/END and names registered with add_node.
    route_targets = [*MODEL_NODES, END]
    workflow.add_conditional_edges(START, _route_next_round, path_map=route_targets)
    # A list of sources makes the merge node wait until all three models finish.
    workflow.add_edge(list(MODEL_NODES), "merge_and_continue")
    workflow.add_conditional_edges("merge_and_continue", _route_next_round, path_map=route_targets)

    return workflow.compile()


def process_all_prompts_parallel(prompts):
    """Run every prompt through all three models and return their result lists."""
    initial_state = {
        "prompts": prompts,
        "processed_count": 0,
        "gemma3_results": [],
        "llama3_1_results": [],
        "deepseek_r1_results": []
    }

    graph = create_parallel_processing_graph()
    # Each prompt takes two supersteps (model fan-out, then merge). LangGraph's
    # default recursion_limit of 25 would abort the run after ~12 prompts.
    final_state = graph.invoke(initial_state, {"recursion_limit": 2 * len(prompts) + 10})

    return {
        "gemma3_results": final_state["gemma3_results"],
        "llama3_1_results": final_state["llama3_1_results"],
        "deepseek_r1_results": final_state["deepseek_r1_results"]
    }

In [None]:
from langgraph.graph import StateGraph, END
from typing import TypedDict, List, Dict, Any

def create_model_processor(model, model_name):
    """Create a LangGraph node that runs *model* on the current prompt.

    The node answers ``state["prompts"][state["processed_count"]]``. It returns
    a copy of the state with one record appended to
    ``f"{model_name}_results"``: ``{"id", "model_name", "response"}`` on
    success, or ``{"id", "model_name", "error"}`` if the call raised. Once
    every prompt has been processed, the state is returned unchanged.
    """
    def process_model(state: State) -> State:
        if state["processed_count"] >= len(state["prompts"]):
            return state

        prompt = state["prompts"][state["processed_count"]]
        system_prompt = prompt.get('system_prompt', '')
        user_input = prompt.get('input', '')

        # The system prompt must be sent as a message. ``config=`` is a
        # RunnableConfig, not model input, so a system prompt passed there
        # would never reach the model.
        messages = [("system", system_prompt)] if system_prompt else []
        messages.append(("human", user_input))

        try:
            response = model.invoke(messages)

            result = {
                "id": prompt.get('id', ''),
                "model_name": model_name,
                "response": response
            }
        except Exception as e:
            # Best effort: record the failure instead of aborting the run.
            result = {
                "id": prompt.get('id', ''),
                "model_name": model_name,
                "error": str(e)
            }

        # Append to this model's own result list only.
        results_key = f"{model_name}_results"
        return {
            **state,
            results_key: state.get(results_key, []) + [result]
        }

    return process_model


# Build one node per model.
process_gemma3 = create_model_processor(gemma3, "gemma3")
process_llama3_1 = create_model_processor(llama3_1, "llama3_1")
process_deepseek_r1 = create_model_processor(deepseek_r1, "deepseek_r1")

In [9]:
import asyncio

# Define the state structure
State = TypedDict(
    "State",
    {
        "prompts": List[Dict[str, Any]],
        "processed_count": int,
        "gemma3_results": List[Dict[str, Any]],
        "llama3_1_results": List[Dict[str, Any]],
        "deepseek_r1_results": List[Dict[str, Any]],
    },
)

In [10]:

def create_model_processor(model, model_name):
    """Create a LangGraph node that runs *model* over every prompt in the state.

    The node answers ``state["prompts"]`` in order and stores one record per
    prompt under ``f"{model_name}_results"``: ``{"id", "model_name",
    "response"}`` on success, or ``{"id", "model_name", "error"}`` if the call
    raised. It returns a copy of the input state with that key replaced.
    """
    def process_model(state: State) -> State:
        results = []

        for prompt in state["prompts"]:
            system_prompt = prompt.get('system_prompt', '')
            user_input = prompt.get('input', '')

            # The system prompt must be sent as a message. ``config=`` is a
            # RunnableConfig, not model input, so a system prompt passed there
            # would never reach the model.
            messages = [("system", system_prompt)] if system_prompt else []
            messages.append(("human", user_input))

            try:
                response = model.invoke(messages)

                result = {
                    "id": prompt.get('id', ''),
                    "model_name": model_name,
                    "response": response
                }
            except Exception as e:
                # Best effort: record the failure and move on to the next prompt.
                result = {
                    "id": prompt.get('id', ''),
                    "model_name": model_name,
                    "error": str(e)
                }

            results.append(result)

        results_key = f"{model_name}_results"
        return {
            **state,
            results_key: results
        }

    return process_model

In [12]:
# One LangGraph node per local model, each built by create_model_processor.
process_gemma3, process_llama3_1, process_deepseek_r1 = (
    create_model_processor(gemma3, "gemma3"),
    create_model_processor(llama3_1, "llama3_1"),
    create_model_processor(deepseek_r1, "deepseek_r1"),
)

In [13]:
# Notebook display: echo the node object to confirm the factory returned a function.
process_gemma3

<function __main__.create_model_processor.<locals>.process_model(state: __main__.State) -> __main__.State>

In [14]:

# Initial state: all prompts, a zero counter and an empty result list per model.
initial_state = {
    "prompts": prompts,
    "processed_count": 0,
    "gemma3_results": [],
    "llama3_1_results": [],
    "deepseek_r1_results": [],
}

# Build a linear graph: gemma3 -> llama3.1 -> deepseek-r1. Nodes run one
# after another, in insertion order.
node_functions = {
    "process_gemma3": process_gemma3,
    "process_llama3": process_llama3_1,
    "process_deepseek": process_deepseek_r1,
}
node_names = list(node_functions)

workflow = StateGraph(State)
for node_name, node_fn in node_functions.items():
    workflow.add_node(node_name, node_fn)

workflow.set_entry_point(node_names[0])
for upstream, downstream in zip(node_names, node_names[1:]):
    workflow.add_edge(upstream, downstream)

app = workflow.compile()

# Run the whole pipeline once.
result = app.invoke(initial_state)

In [23]:
# Notebook display: inspect the type of `result`; the cells below branch on it.
type(result)

str

In [25]:
# Dump the raw result as text. app.invoke returns the final state as a dict,
# and file.write() only accepts str, so anything else is converted with str().
with open("results_branching.txt", "w", encoding="utf-8") as f:
    f.write(result if isinstance(result, str) else str(result))


In [22]:
def _serialize_result_item(item):
    """Convert one per-model result record into a JSON-safe dict.

    Message objects are reduced to their ``.content``; anything else is
    converted with ``str()``. An ``error`` field, present when the model call
    failed, is preserved rather than being turned into an empty response.
    Non-dict items are stringified whole.
    """
    if not isinstance(item, dict):
        return {"id": "", "model_name": "", "response_content": str(item)}
    response = item.get("response", "")
    serialized = {
        "id": item.get("id", ""),
        "model_name": item.get("model_name", ""),
        "response_content": response.content if hasattr(response, "content") else str(response),
    }
    if "error" in item:
        serialized["error"] = str(item["error"])
    return serialized


serializable_results = []

if isinstance(result, str):
    # A single string: store it as one record.
    serializable_results.append({
        "id": "",
        "model_name": "",
        "response_content": result
    })
elif isinstance(result, list):
    # A flat list of result records.
    serializable_results.extend(_serialize_result_item(item) for item in result)
elif isinstance(result, dict):
    # The graph's final state: expand each "<model>_results" list record by
    # record so ids and message text are kept; other keys are stringified.
    for key, value in result.items():
        if isinstance(key, str) and key.endswith("_results") and isinstance(value, list):
            serializable_results.extend(_serialize_result_item(item) for item in value)
        else:
            serializable_results.append({
                "id": "",
                "model_name": key,
                "response_content": str(value)
            })

# Save as JSON, keeping non-ASCII (e.g. Korean) text readable.
with open("results_branching.json", "w", encoding="utf-8") as f:
    json.dump(serializable_results, f, ensure_ascii=False, indent=2)