In [None]:
import os
import json
import argparse

from tot.tasks import get_task
from tot.methods.bfs import solve, naive_solve
from tot.models import usage

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

from logger_config import logger

def run(args):
    """Solve a range of task instances and dump the extracted answers to one JSON file.

    Args:
        args: parsed command-line namespace (see ``parse_args``). The fields read here are
            ``task``, ``backend``, ``temperature``, ``naive_run``, ``prompt_sample``,
            ``n_generate_sample``, ``n_evaluate_sample``, ``method_select``,
            ``n_select_sample``, ``apply_skills``, ``decompose_problem``,
            ``task_start_index`` and ``task_end_index``.

    Side effects:
        * creates ``./logs/<task>/`` if needed and writes the answer list there, once,
          after the whole index range has been processed;
        * binds the module-level names ``tokenizer`` and ``model`` when ``args.backend``
          is one of the three local checkpoints.

    NOTE(review): nothing is loaded for the other ``--backend`` choices (the API models),
    so ``model`` / ``tokenizer`` must already exist as globals in that case, otherwise the
    first solve call raises ``NameError``.
    """
    task = get_task(args.task)
    # Only referenced by the disabled accuracy logging kept at the end of this function.
    logs, cnt_correct = [], 0

    # Both log-file names end with the same run descriptor.
    run_suffix = (
        f'apply_skills_{args.apply_skills}_decompose_problem_{args.decompose_problem}'
        f'_start{args.task_start_index}_end{args.task_end_index}'
    )
    if args.naive_run:
        file = (
            f'./logs/{args.task}/{args.backend}_{args.temperature}_naive_sample_'
            f'{args.prompt_sample}_{args.n_generate_sample}_{run_suffix}.json'
        )
    else:
        file = (
            f'./logs/{args.task}/{args.backend}_{args.temperature}_{args.n_generate_sample}_'
            f'{args.n_evaluate_sample}_{args.method_select}_{args.n_select_sample}_{run_suffix}_retry.json'
        )
    os.makedirs(os.path.dirname(file), exist_ok=True)

    global tokenizer, model
    # 4-bit NF4 quantization settings; only the 8B checkpoint is loaded with them.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_compute_dtype=torch.float16,  # torch.bfloat16 is an alternative
    )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    name = args.backend

    # backend name -> (checkpoint directory, load with the 4-bit quantization config?)
    local_checkpoints = {
        "Llama-3.1-8B-Instruct": ('/scratch/gpfs/jx0800/Llama-3.1-8B-Instruct', True),
        "Llama-3.2-3B-Instruct": ('/scratch/gpfs/jx0800/Llama-3.2-3B-Instruct', False),
        "Qwen2.5-1.5B-Instruct": ("/scratch/gpfs/jx0800/Qwen2.5-1.5B-Instruct", False),
    }
    if name in local_checkpoints:
        checkpoint_dir, quantized = local_checkpoints[name]
        print(f"Loading {name}")
        tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
        if quantized:
            model = AutoModelForCausalLM.from_pretrained(checkpoint_dir, quantization_config=bnb_config)
        else:
            # Unquantized checkpoints are moved to the GPU when one is available.
            model = AutoModelForCausalLM.from_pretrained(checkpoint_dir).to(device)

    model_answers = []
    for idx in range(args.task_start_index, args.task_end_index):
        # solve
        if args.naive_run:
            ys, info = naive_solve(model, tokenizer, name, args, task, idx)
        else:
            logger.info("solve start")
            logger.info(f"Task {idx}")
            ys, info = solve(model, tokenizer, name, args, task, idx)

        # log: keep the text after 'Answer:' when it can be extracted, else the raw output
        for y in ys:
            model_answers.append(task.extract_from_text(y, ['Answer:']) or y)
        logger.info(f"Task {idx} done")

    with open(file, 'w') as f:
        json.dump(model_answers, f, indent=4)

    logger.info(f"Task {args.task_start_index} to {args.task_end_index} done")
    #     infos = [task.test_output(i, y, args.backend) for y in ys]

    #     # log main metric
    #     accs = [_['r'] for _ in infos]

    #     # log main metric
    #     cnt_correct += sum(accs) / len(accs)
    #     cur_acc = cnt_correct / (i - args.task_start_index + 1)
    #     print('current accuracy: ', cur_acc)
    #     if args.backend == 'o1-mini' or args.backend == 'gpt-4o':
    #         info.update({'idx': i, 'ys': ys, 'infos': infos, 'usage_so_far': usage(args.backend), 'current accuracy': cur_acc})
    #     else:
    #         info.update({'idx': i, 'ys': ys, 'infos': infos, 'current accuracy': cur_acc})
    #     logs.append(info)
    #     with open(file, 'w') as f:
    #         json.dump(logs, f, indent=4)

    # n = args.task_end_index - args.task_start_index
    # print(cnt_correct / n)
    # if args.backend == 'o1-mini' or args.backend == 'gpt-4o':
    #     print('usage_so_far', usage(args.backend))

def parse_args(argv=None):
    """Build the experiment command-line parser and parse the options.

    Args:
        argv: optional list of argument strings. When ``None`` (the default) argparse
            reads ``sys.argv[1:]``, exactly as before. Passing an explicit list lets the
            function be called from a notebook kernel or a test, where ``sys.argv`` holds
            unrelated arguments.

    Returns:
        argparse.Namespace with the parsed options.

    Raises:
        SystemExit: raised by argparse on invalid or missing arguments (``--task`` is required).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--backend', type=str, choices=['o1-mini', 'gpt-4o', 'Llama-3.1-8B-Instruct', 'Llama-3.2-3B-Instruct', 'Qwen2.5-1.5B-Instruct', 'gpt-4o-mini'], default='gpt-4o-mini')
    parser.add_argument('--temperature', type=float, default=0.7)  # only used for proposal; for value prompt, temperature is set as 0.1

    parser.add_argument('--task', type=str, required=True, choices=['MATH', "MATH2"])
    parser.add_argument('--task_start_index', type=int, default=0)
    parser.add_argument('--task_end_index', type=int, default=100)

    parser.add_argument('--naive_run', action='store_true')
    parser.add_argument('--prompt_sample', type=str, choices=['standard', 'cot'])  # only used when method_generate = sample, or naive_run

    parser.add_argument('--method_select', type=str, choices=['sample', 'greedy'], default='greedy')
    parser.add_argument('--apply_skills', action='store_true')
    # apply_skills and decompose_problem being true at the same time is not implemented; only used for math^2
    parser.add_argument('--decompose_problem', action='store_true')
    parser.add_argument('--n_generate_sample', type=int, default=1)
    parser.add_argument('--n_evaluate_sample', type=int, default=1)
    parser.add_argument('--n_select_sample', type=int, default=1)

    return parser.parse_args(argv)


# Entry point: parse the command-line options, echo them, then run the experiment.
if __name__ == '__main__':
    args = parse_args()
    print(args)  # show the parsed options before the run starts
    run(args)

## Semmed graph

In [64]:
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd
import csv

# Path of the CSV file the graph is loaded from
graph_path = './semmed/graph_data_whole.csv'

# Initialize an empty graph
G = nx.MultiDiGraph()  # directed multigraph: parallel edges between the same pair of nodes are kept

def get_all_rel(G):
    """Return the set of distinct ``rel`` attribute values found on the edges of ``G``.

    Args:
        G: a networkx graph of any kind (simple or multi, directed or undirected).

    Returns:
        set: every value stored under the ``'rel'`` key of an edge's attribute dict.
        Edges without a ``'rel'`` attribute are ignored; an edgeless graph gives ``set()``.
    """
    # edges(data=True) yields one tuple per edge (parallel edges included) whose last item
    # is the attribute dict, for simple graphs and multigraphs alike, so no per-type
    # branching is needed.
    return {edge[-1]['rel'] for edge in G.edges(data=True) if 'rel' in edge[-1]}

# Load the graph: column 0 of every row is the head node, followed by (relation, tail) pairs.
with open(graph_path, 'r', newline='') as file:  # newline='' as recommended for the csv module
    reader = csv.reader(file)
    for line in reader:
        if not line:
            continue  # csv.reader yields [] for a blank line; there is no head to read
        head = line[0]
        # Walk the remaining columns in pairs; zip() ignores a dangling unpaired last cell
        # instead of raising IndexError.
        for relation, tail in zip(line[1::2], line[2::2]):
            # csv.reader returns every cell as str, so padding shows up as '' (never NaN):
            # stop at the first empty cell rather than adding an edge to an empty node.
            if not relation or not tail:
                break
            # Add an edge to the graph with the relation as an edge attribute
            G.add_edge(head, tail, rel=relation)


In [42]:
# Size summary of the graph: nodes, edges and distinct relation labels.
node_count = G.number_of_nodes()
edge_count = G.number_of_edges()
relation_count = len(get_all_rel(G))
print(f"Number of nodes: {node_count}")
print(f"Number of edges: {edge_count}")
print(f"number of relations: {relation_count}")

Number of nodes: 311707
Number of edges: 21193134
number of relations: 32


In [65]:
import re
# Select every node whose label contains a double-quote character, echo each one,
# then delete them all from the graph.
nodes_to_remove = [node for node in G.nodes() if '"' in node]
for node in nodes_to_remove:
    print(node)
G.remove_nodes_from(nodes_to_remove)

#print(f"Number of nodes: {G.number_of_nodes()}")
#print(f"Number of edges: {G.number_of_edges()}")
#print(f"number of relations: {len(get_all_rel(G))}")

In [66]:
# Export every edge whose two endpoints share at least one whitespace-separated token,
# and remove that edge from the graph.
filename = './semmed/overlap.csv'
with open(filename, 'w', newline='') as f:
    # csv.writer quotes fields that contain commas or quotes, so a node label can no longer
    # break the three-column layout; lineterminator keeps the plain '\n' line endings.
    overlap_writer = csv.writer(f, lineterminator='\n')
    overlap_writer.writerow(['head', 'relation', 'tail'])
    # Iterate over a snapshot of the edges because edges are removed inside the loop.
    for node1, node2, key, data in list(G.edges(keys=True, data=True)):
        if set(node1.split()) & set(node2.split()):
            overlap_writer.writerow([node1, data['rel'], node2])
            # Removes only this keyed edge; parallel edges are handled on their own iteration.
            G.remove_edge(node1, node2, key)

# Drop the nodes that are left without any edge.
nodes_to_remove = [node for node in G.nodes() if G.degree(node) == 0]
G.remove_nodes_from(nodes_to_remove)

print(f"Number of nodes: {G.number_of_nodes()}")
print(f"Number of edges: {G.number_of_edges()}")
print(f"number of relations: {len(get_all_rel(G))}")

Number of nodes: 309045
Number of edges: 20610169
number of relations: 32


In [39]:
# Prints the set of relation labels itself (despite the "number of" wording in the message).
relation_labels = get_all_rel(G)
print(f"number of relations: {relation_labels}")

number of relations: {'process of', 'part of', 'conceptually related to', 'associated with', 'measurement of', 'conceptual part of', 'property of', 'analyzes', 'interacts with', 'ingredient of', 'disrupts', 'brings about', 'location of', 'temporally related to', 'method of', 'uses', 'physically related to', 'carries out', 'occurs in', 'treats', 'causes', 'measures', 'isa', 'functionally related to', 'precedes', 'affects'}


In [48]:
def save_to_csv(G, save_path):
    """Write ``G`` to ``save_path`` as an adjacency-style CSV.

    Each row is ``head, rel_1, tail_1, rel_2, tail_2, ...`` with the tails ordered by
    their string form. Nodes without neighbours get no row.

    Args:
        G: a networkx graph whose edges all carry a ``'rel'`` attribute. Directed and
            undirected graphs, simple or multi, are supported.
        save_path: path of the CSV file to (over)write.

    Raises:
        KeyError: if an edge has no ``'rel'`` attribute.

    For undirected graphs each edge is written once, on the row of the endpoint that sorts
    first as a string (self-loops included). A node whose neighbours all sort before it
    therefore gets a row holding only its own name.
    """
    import csv

    # Ask the graph for its kind instead of isinstance checks: nx.MultiGraph is a subclass
    # of nx.Graph, so the isinstance chain sent undirected multigraphs down the simple-graph
    # branch, where the 'rel' lookup fails.
    multigraph = G.is_multigraph()
    directed = G.is_directed()

    data_for_csv = []
    for node in G.nodes():
        adjacency = G[node]
        if not adjacency:
            continue
        row = [node]  # Start row with the node (Head)
        # Sort by string form but keep the real node objects, so indexing the adjacency
        # also works when node labels are not strings.
        for neighbor in sorted(adjacency, key=str):
            if not directed and str(neighbor) < str(node):
                continue  # already written on the neighbour's own row
            if multigraph:
                edge_attrs = adjacency[neighbor].values()  # one attribute dict per parallel edge
            else:
                edge_attrs = [adjacency[neighbor]]
            for attrs in edge_attrs:
                row.extend([attrs['rel'], neighbor])
        data_for_csv.append(row)

    # Writing to the CSV file
    with open(save_path, 'w', newline='') as file:
        csv.writer(file).writerows(data_for_csv)
    print(f'{save_path} saved')

# Write the current state of the graph to a separate CSV file.
cleaned_graph_path = './semmed/graph_data_whole_cleaned.csv'
save_to_csv(G, cleaned_graph_path)

./semmed/graph_data_whole_cleaned.csv saved


## PubMed graph

In [67]:
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd
import csv

# Path of the CSV file the graph is loaded from
graph_path = 'graph_data_umls.csv'

# Initialize an empty graph
G = nx.MultiDiGraph()  # directed multigraph: parallel edges between the same pair of nodes are kept

def get_all_rel(G):
    """Return the set of distinct ``rel`` attribute values found on the edges of ``G``.

    Args:
        G: a networkx graph of any kind (simple or multi, directed or undirected).

    Returns:
        set: every value stored under the ``'rel'`` key of an edge's attribute dict.
        Edges without a ``'rel'`` attribute are ignored; an edgeless graph gives ``set()``.
    """
    # edges(data=True) yields one tuple per edge (parallel edges included) whose last item
    # is the attribute dict, for simple graphs and multigraphs alike, so no per-type
    # branching is needed.
    return {edge[-1]['rel'] for edge in G.edges(data=True) if 'rel' in edge[-1]}

# Load the graph: column 0 of every row is the head node, followed by (relation, tail) pairs.
with open(graph_path, 'r', newline='') as file:  # newline='' as recommended for the csv module
    reader = csv.reader(file)
    for line in reader:
        if not line:
            continue  # csv.reader yields [] for a blank line; there is no head to read
        head = line[0]
        # Walk the remaining columns in pairs; zip() ignores a dangling unpaired last cell
        # instead of raising IndexError.
        for relation, tail in zip(line[1::2], line[2::2]):
            # csv.reader returns every cell as str, so padding shows up as '' (never NaN):
            # stop at the first empty cell rather than adding an edge to an empty node.
            if not relation or not tail:
                break
            # Add an edge to the graph with the relation as an edge attribute
            G.add_edge(head, tail, rel=relation)


In [68]:
# Size summary of the graph: nodes, edges and distinct relation labels.
node_count = G.number_of_nodes()
edge_count = G.number_of_edges()
relation_count = len(get_all_rel(G))
print(f"Number of nodes: {node_count}")
print(f"Number of edges: {edge_count}")
print(f"number of relations: {relation_count}")

Number of nodes: 220174
Number of edges: 801265
number of relations: 27


In [69]:
import re
# Select every node whose label contains a comma, echo each one,
# then delete them all from the graph.
nodes_to_remove = [node for node in G.nodes() if ',' in node]
for node in nodes_to_remove:
    print(node)
G.remove_nodes_from(nodes_to_remove)

#print(f"Number of nodes: {G.number_of_nodes()}")
#print(f"Number of edges: {G.number_of_edges()}")
#print(f"number of relations: {len(get_all_rel(G))}")

aspiration of trachea, percutaneous
pulmonary valve prosthesis, device
education, guidance and counseling
posterior vestibuloplasty, bilateral
diagnostic radiography with contrast media, bilateral
roentgenography, negative contrast
positive contrast bronchography, bilateral
venography of adrenal, bilateral
radionuclide venogram, unilateral
adult respiratory distress syndrome, ctcae
anaphylaxis, ctcae
bronchospasm, ctcae
disseminated intravascular coagulation, ctcae
fever, ctcae
heart failure, ctcae
herpesvirus 5, human
hyperkalemia, ctcae
hypomagnesemia, ctcae
hypotension, ctcae
myalgia, ctcae 5
phlebitis, ctcae
pruritus, ctcae
urticaria, ctcae
millard operation, cleft lip repair
thompson operation, cleft lip repair
illegal termination of pregnancy, incomplete
personal periodontal care, plaque control education
functional endoscopic sinus surgery, limited
functional endoscopic sinus surgery, total
injection of neurolytic substance, subarachnoid
nerve blocks, infiltrations and injection

In [70]:
# How many nodes the current `nodes_to_remove` list holds.
removed_count = len(nodes_to_remove)
print(removed_count)

4976


In [71]:
# Export every edge whose two endpoints share at least one whitespace-separated token,
# and remove that edge from the graph.
filename = 'overlap.csv'
with open(filename, 'w', newline='') as f:
    # csv.writer quotes fields that contain commas or quotes, so a node or relation label
    # cannot break the three-column layout; lineterminator keeps the plain '\n' line endings.
    overlap_writer = csv.writer(f, lineterminator='\n')
    overlap_writer.writerow(['head', 'relation', 'tail'])
    # Iterate over a snapshot of the edges because edges are removed inside the loop.
    for node1, node2, key, data in list(G.edges(keys=True, data=True)):
        if set(node1.split()) & set(node2.split()):
            overlap_writer.writerow([node1, data['rel'], node2])
            # Removes only this keyed edge; parallel edges are handled on their own iteration.
            G.remove_edge(node1, node2, key)

# Drop the nodes that are left without any edge.
nodes_to_remove = [node for node in G.nodes() if G.degree(node) == 0]
G.remove_nodes_from(nodes_to_remove)

print(f"Number of nodes: {G.number_of_nodes()}")
print(f"Number of edges: {G.number_of_edges()}")
print(f"number of relations: {len(get_all_rel(G))}")

Number of nodes: 132962
Number of edges: 361576
number of relations: 26


In [57]:
def save_to_csv(G, save_path):
    """Write ``G`` to ``save_path`` as an adjacency-style CSV.

    Each row is ``head, rel_1, tail_1, rel_2, tail_2, ...`` with the tails ordered by
    their string form. Nodes without neighbours get no row.

    Args:
        G: a networkx graph whose edges all carry a ``'rel'`` attribute. Directed and
            undirected graphs, simple or multi, are supported.
        save_path: path of the CSV file to (over)write.

    Raises:
        KeyError: if an edge has no ``'rel'`` attribute.

    For undirected graphs each edge is written once, on the row of the endpoint that sorts
    first as a string (self-loops included). A node whose neighbours all sort before it
    therefore gets a row holding only its own name.
    """
    import csv

    # Ask the graph for its kind instead of isinstance checks: nx.MultiGraph is a subclass
    # of nx.Graph, so the isinstance chain sent undirected multigraphs down the simple-graph
    # branch, where the 'rel' lookup fails.
    multigraph = G.is_multigraph()
    directed = G.is_directed()

    data_for_csv = []
    for node in G.nodes():
        adjacency = G[node]
        if not adjacency:
            continue
        row = [node]  # Start row with the node (Head)
        # Sort by string form but keep the real node objects, so indexing the adjacency
        # also works when node labels are not strings.
        for neighbor in sorted(adjacency, key=str):
            if not directed and str(neighbor) < str(node):
                continue  # already written on the neighbour's own row
            if multigraph:
                edge_attrs = adjacency[neighbor].values()  # one attribute dict per parallel edge
            else:
                edge_attrs = [adjacency[neighbor]]
            for attrs in edge_attrs:
                row.extend([attrs['rel'], neighbor])
        data_for_csv.append(row)

    # Writing to the CSV file
    with open(save_path, 'w', newline='') as file:
        csv.writer(file).writerows(data_for_csv)
    print(f'{save_path} saved')

# Write the current state of the graph to a separate CSV file.
cleaned_graph_path = 'graph_data_umls_cleaned.csv'
save_to_csv(G, cleaned_graph_path)

graph_data_umls_cleaned.csv saved
