In [None]:
import json
import math
import re
from collections import defaultdict, Counter
import nltk
nltk.download('stopwords')
nltk.download('punkt')
# NLTK >= 3.9 loads the 'punkt_tab' resource for word_tokenize/sent_tokenize instead of the
# pickled 'punkt' model; fetching both keeps the notebook working on old and new NLTK versions.
nltk.download('punkt_tab')
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import spacy
from spacy.language import Language
from spacy.tokens import Doc
import tqdm
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore", message=".*beta.*will be renamed.*")
warnings.filterwarnings("ignore", message=".*gamma.*will be renamed.*")

# 1.加载数据集

In [None]:
## Step 1: load the dataset (records are read below via "content", "retrieved" and "label").
DATASET_PATH = "/local/home/sumyao/YSforGIT/dataset/Filtered2Added/picocorpus_nct_filtered_added_withnocluster.json"
# Explicit UTF-8: the records contain non-ASCII text (e.g. "≥"), and the platform
# default encoding is not guaranteed to be UTF-8.
with open(DATASET_PATH, 'r', encoding='utf-8') as f:
    datasets = json.load(f)

In [None]:
# Inspect one record end to end: raw abstract, registry entry, retrieved PIO sentences, gold label.
# NOTE(review): retrived_pio_BM25_BERT is only defined in section 3.2 further down; on a fresh
# kernel run that cell first, or this one raises NameError.
id = 8
print(datasets[id].get("content"))

# Drop the sentence citing the ClinicalTrials.gov registration (NCT followed by 8 digits).
pattern = r'[^.]*?clinicaltrials\.gov[^.]*?NCT\d{8}[^.]*?\.'
cleaned_abstract = re.compile(pattern, re.IGNORECASE).sub('', datasets[id].get("content")).strip()

retrived_label = datasets[id].get("retrieved")
print(retrived_label)

participants, interventions, outcomes = retrived_pio_BM25_BERT(cleaned_abstract, retrived_label)
print(participants, interventions, outcomes)

label = datasets[id].get("label")
label

In [None]:

# Preview record `id` with its ClinicalTrials.gov registration sentence removed.
pattern = r'[^.]*?clinicaltrials\.gov[^.]*?NCT\d{8}[^.]*?\.'
cleaned_abstract = re.compile(pattern, flags=re.IGNORECASE).sub('', datasets[id].get("content"))
cleaned_abstract = cleaned_abstract.strip()
cleaned_abstract

# 2.预处理 pubmed clinical

### pubmed预处理

In [None]:
def clean_sentence(sentence):
    """Normalize a sentence for retrieval and prompting.

    Removes the ``-DOCSTART-`` marker, drops every character that is not an ASCII
    letter, digit, one of ``.,;?!`` or whitespace, collapses whitespace runs to a
    single space and trims the result.

    Args:
        sentence: raw sentence text.

    Returns:
        The cleaned sentence (possibly empty).
    """
    sentence = re.sub(r"-DOCSTART-", "", sentence)
    sentence = re.sub(r"[^a-zA-Z0-9.,;?!\s]", "", sentence)
    sentence = re.sub(r"\s+", " ", sentence)
    # Trim last: removing characters above can expose new leading/trailing whitespace.
    return sentence.strip()
    
def setup_custom_sentencizer(nlp):
    """Add rule-based sentence splitting that never starts a sentence on a lowercase token.

    A ``sentencizer`` is inserted first in the pipeline (if absent). It is followed by
    ``custom_sentencizer``, which clears the sentence-start flag of every sentence whose
    first token begins with a lowercase character. Presumably this undoes splits after
    abbreviations or other mid-sentence periods. The custom step runs right before the
    dependency parser when the pipeline has one; otherwise it is appended at the end.

    Args:
        nlp: a spaCy ``Language`` pipeline; modified in place.

    Returns:
        The same ``nlp`` object, for chaining.
    """
    @Language.component("custom_sentencizer")
    def custom_sentencizer(doc: Doc) -> Doc:
        for sent in doc.sents:
            first_token = doc[sent.start]
            # Slicing avoids an IndexError on an (unexpected) empty token text.
            if first_token.text[:1].islower():
                first_token.is_sent_start = False
        return doc

    if "sentencizer" not in nlp.pipe_names:
        nlp.add_pipe("sentencizer", first=True)
    # Adding a pipe name twice raises, so configuring an already set-up pipeline is a no-op.
    if "custom_sentencizer" not in nlp.pipe_names:
        if "parser" in nlp.pipe_names:
            nlp.add_pipe("custom_sentencizer", before="parser")
        else:
            nlp.add_pipe("custom_sentencizer", last=True)
    return nlp
    
# Shared spaCy pipeline that splits abstracts into sentences for retrieval.
cutter = setup_custom_sentencizer(spacy.load("en_core_web_sm"))

### clinical预处理

In [None]:
import re

def clean_value(value):
    """Split a '|'-separated registry field into cleaned items.

    Parentheses are removed from each item. When an item contains ':', only the text
    after the first ':' is kept (e.g. "DRUG: alendronate sodium" -> "alendronate sodium").
    Items that end up empty are dropped.

    Args:
        value: the raw field string; falsy values (None, "") give an empty list.

    Returns:
        list[str] of non-empty cleaned items.
    """
    if not value:
        return []
    cleaned_parts = []
    for part in value.split('|'):
        cleaned = re.sub(r'[()|]', '', part).strip()
        if ':' in cleaned:
            cleaned = cleaned.split(':', 1)[1].strip()
        # Skip blanks from "a||b", a trailing '|' or a bare "TYPE:" with no name.
        if cleaned:
            cleaned_parts.append(cleaned)
    return cleaned_parts

def create_sentences(retrieved):
    """Turn a trial-registry record into natural-language retrieval queries.

    Args:
        retrieved: dict with optional keys "age", "gender", "conditions",
            "interventions", "primary outcome measures" and
            "secondary outcome measures". A None record, missing keys and
            JSON-null values are all treated as empty.

    Returns:
        list[str] of up to four sentences, in this order: participants
        ("Patient is a ..."), interventions, primary outcomes, secondary outcomes.
        A sentence is omitted when its source fields are empty.
    """
    retrieved = retrieved or {}

    def _text_field(key):
        # JSON null comes back as None; treat it like a missing field instead of crashing on .strip().
        return (retrieved.get(key) or "").strip()

    age = _text_field("age")
    gender = _text_field("gender")
    conditions = clean_value(retrieved.get("conditions", ""))
    interventions = clean_value(retrieved.get("interventions", ""))
    primary_outcome = _text_field("primary outcome measures")
    secondary_outcome = _text_field("secondary outcome measures")

    sentences = []
    if gender or age or conditions:
        condition_text = ", ".join(conditions)
        sentences.append(f"Patient is a {gender} aged {age} with conditions {condition_text}.")

    if interventions:
        interventions_text = ", ".join(interventions)
        sentences.append(f"Interventions may include: {interventions_text}.")

    if primary_outcome:
        sentences.append(f"Primary outcome measures maybe: {primary_outcome}.")

    if secondary_outcome:
        sentences.append(f"Secondary outcome measures maybe: {secondary_outcome}.")

    return sentences

# Example registry record, used to preview the queries produced by create_sentences().

retrieved = {
    "age": "",
    "gender": "FEMALE",
    "conditions": "Breast Cancer|Osteoporosis",
    "interventions": "DIETARY_SUPPLEMENT: calcium carbonate|DIETARY_SUPPLEMENT: calcium citrate|DIETARY_SUPPLEMENT: cholecalciferol|DRUG: alendronate sodium|DRUG: calcium gluconate|DRUG: risedronate sodium|OTHER: laboratory biomarker analysis|PROCEDURE: dual x-ray absorptometry",
    "primary outcome measures": "Percentage change of bone mineral density (BMD) measured at 2 years (from baseline) in the L1-L4 region of the spine and the hip, 5 years",
    "secondary outcome measures": "Percentage change in BMD at 5 years (from baseline), 5 years|Mean percentage change in BMD at 1, 3, and 5 years (from baseline), 5 years|Proportion of patients without osteopenia or osteoporosis (stratum I) who develop BMD below the absolute threshold for osteopenia (< -2.0 standard deviation below the mean), suffer any osteoporotic fracture, or have an asymptomatic fracture revealed ..., 5 years|Percentage of patients with osteopenia or osteoporosis (stratum II) who have ≥ 5% improvement of BMD at 2 years post randomization on protocol CAN-NCIC-MA27 and who have clinically apparent osteoporosis-related fracture of the long bones, 5 years|Pattern of change in bone biomarkers from baseline, 5 years|Clinical safety and tolerability of study medications, 5 years"
}

# Build the queries and print one per line.
sentences = create_sentences(retrieved)
for query_sentence in sentences:
    print(query_sentence)

# 3.检索与知识注入

## 3.1 声明guidelines

In [None]:
# PICO-corpus annotation guidelines, inserted into the "### Guidelines" section of the prompt.
# NOTE: the typo fixes below change the prompt text, so outputs generated with the old
# wording are not directly comparable.
guidelines_picocorpus='''
For Participants, eight entities are included: the total number of participants in the study, the number of participants in the intervention group, the number of participants in the control group, the condition being treated, eligibility criteria, age, ethnicity, and location.
For Intervention and Control, the annotation focuses on specific interventions and control measures characterized by several words used in the study.
For Outcomes, the annotation emphasizes intervention-binary-absolute, intervention-binary-percentage, intervention-continuous-mean, intervention-continuous-standard deviation, intervention-continuous-median, intervention-continuous-first quartile, intervention-continuous-third quartile,
control-binary-absolute, control-binary-percentage, control-continuous-mean, control-continuous-standard deviation, control-continuous-median, control-continuous-first quartile, control-continuous-third quartile.'''

## 3.2 检索增强

In [None]:
## step4-1 BM25算法
import math
import numpy as np
from collections import defaultdict, Counter
from transformers import BertTokenizer, BertModel
import torch
from nltk.tokenize import sent_tokenize
import numpy as np
np.random.seed(seed=42)  # fix NumPy's legacy global RNG so any NumPy-based randomness repeats across runs


# 1. BM25 实现
class BM25:
    """Okapi BM25 ranking over pre-tokenized documents.

    Args:
        documents: list of token lists (one list per document/sentence).
        k1: term-frequency saturation parameter.
        b: document-length normalization strength (0 = none, 1 = full).
    """

    def __init__(self, documents, k1=1.5, b=0.75):
        self.documents = documents
        self.N = len(documents)  # number of documents
        # Average document length. An empty corpus (e.g. an abstract that yielded no
        # sentences) gets 0.0 instead of raising ZeroDivisionError.
        self.avgdl = sum(len(doc) for doc in documents) / self.N if self.N else 0.0
        self.k1 = k1
        self.b = b
        self.inverted_index = defaultdict(list)  # term -> [(doc_idx, freq), ...]
        self.doc_lengths = []  # token count of each document
        self.doc_term_freqs = []  # per-document Counter, so score() does not scan the whole index
        self.build_index()

    def build_index(self):
        """Build the inverted index plus per-document lengths and term frequencies."""
        for idx, doc in enumerate(self.documents):
            self.doc_lengths.append(len(doc))
            term_counts = Counter(doc)
            self.doc_term_freqs.append(term_counts)
            for term, freq in term_counts.items():
                self.inverted_index[term].append((idx, freq))

    def idf(self, term):
        """Smoothed inverse document frequency: log((N - df + 0.5) / (df + 0.5) + 1)."""
        df = len(self.inverted_index.get(term, []))  # number of documents containing the term
        return math.log((self.N - df + 0.5) / (df + 0.5) + 1)

    def score(self, query, doc_idx):
        """BM25 score of ``query`` (a list of tokens) against document ``doc_idx``.

        Repeated query tokens contribute once per occurrence.
        """
        term_freqs = self.doc_term_freqs[doc_idx]
        doc_length = self.doc_lengths[doc_idx]
        # The length-normalization term depends only on the document, so compute it once.
        # avgdl == 0 means every document is empty; no query term can then match.
        if self.avgdl:
            length_norm = self.k1 * (1 - self.b + self.b * doc_length / self.avgdl)
        else:
            length_norm = self.k1 * (1 - self.b)
        score = 0.0
        for term in query:
            f = term_freqs.get(term, 0)
            if f:
                numerator = f * (self.k1 + 1)
                denominator = f + length_norm
                score += self.idf(term) * (numerator / denominator)
        return score

    def search(self, query, top_n=1):
        """Return the ``top_n`` (doc_idx, score) pairs with the highest BM25 score."""
        scores = {idx: self.score(query, idx) for idx in range(self.N)}
        return sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_n]

# 2. BERT 相似度计算
class BERT:
    """Mean-pooled BERT embeddings for a fixed set of tokenized documents.

    Args:
        documents: list of token lists; each is joined with spaces before encoding.
        bert_model: model name passed to ``from_pretrained``.
    """

    def __init__(self, documents, bert_model="bert-base-uncased"):
        self.documents = documents
        self.tokenizer = BertTokenizer.from_pretrained(bert_model)
        self.model = BertModel.from_pretrained(bert_model)
        self.document_embeddings = self.encode_documents()

    def _embed(self, tokens):
        """Encode one token list; return the squeezed mean of the last hidden layer as numpy."""
        text = " ".join(tokens)
        encoded = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
        with torch.no_grad():
            hidden = self.model(**encoded).last_hidden_state
        return hidden.mean(dim=1).squeeze().numpy()

    def encode_documents(self):
        """Embed every document; returns an array with one entry per document."""
        return np.array([self._embed(doc) for doc in self.documents])

    def encode_query(self, query):
        """Embed a tokenized query exactly like the documents."""
        return self._embed(query)

    def bert_similarity(self, query_embedding, doc_idx):
        """Cosine similarity between ``query_embedding`` and the embedding of document ``doc_idx``."""
        q_vec = np.array(query_embedding)
        d_vec = np.array(self.document_embeddings[doc_idx])
        return np.matmul(q_vec, d_vec) / (np.linalg.norm(q_vec) * np.linalg.norm(d_vec))


# 3. 集成检索（BM25 + BERT）
class BM25_BERT:
    """Hybrid retriever: min-max normalized BM25 blended with rescaled BERT cosine similarity.

    final_score = alpha * bm25_norm + (1 - alpha) * (cosine + 1) / 2

    Args:
        documents: list of token lists shared by both retrievers.
        k1, b: BM25 parameters.
        bert_model: model name for the BERT encoder.
        alpha: weight of the BM25 component (the BERT component gets 1 - alpha).
    """

    def __init__(self, documents, k1=1.5, b=0.75, bert_model="bert-base-uncased", alpha=0.5):
        self.bm25 = BM25(documents, k1, b)
        self.bert = BERT(documents, bert_model)
        self.alpha = alpha

    def search(self, query, top_n=1):
        """Return the ``top_n`` (doc_idx, blended_score) pairs, best first; [] for an empty corpus."""
        n_docs = self.bm25.N
        if n_docs == 0:
            # min()/max() below would raise on an empty corpus.
            return []
        bm25_scores = {idx: self.bm25.score(query, idx) for idx in range(n_docs)}
        query_embedding = self.bert.encode_query(query)

        min_bm25 = min(bm25_scores.values())
        max_bm25 = max(bm25_scores.values())
        spread = max_bm25 - min_bm25

        final_scores = {}
        for idx, bm25_score in bm25_scores.items():
            # Min-max normalize BM25. When all documents tie, BM25 carries no signal, so use 0.
            bm25_score_norm = (bm25_score - min_bm25) / spread if spread else 0
            # Cosine similarity lies in [-1, 1]; map it to [0, 1] to match the BM25 range.
            bert_score_norm = (self.bert.bert_similarity(query_embedding, idx) + 1) / 2
            final_scores[idx] = self.alpha * bm25_score_norm + (1 - self.alpha) * bert_score_norm

        return sorted(final_scores.items(), key=lambda x: x[1], reverse=True)[:top_n]


# 4. 检索计算主函数（处理 PIO 类别）
def retrived_pio_BM25_BERT(content, retrived_label):
    """Retrieve the abstract sentences that best match each trial-registry field.

    The abstract is split into sentences by the module-level spaCy ``cutter``; each
    sentence is cleaned with ``clean_sentence`` and whitespace-tokenized.
    ``create_sentences`` turns the registry record into queries. The single best
    sentence for each query (BM25 + BERT) goes into the P, I or O bucket according
    to which query produced it.

    Args:
        content: abstract text.
        retrived_label: registry record dict (None or empty gives no queries).

    Returns:
        (participants, interventions, outcomes): lists of retrieved sentences.
    """
    doc = cutter(content)
    content_sentences = [clean_sentence(str(sent)) for sent in doc.sents]
    documents = [item.split() for item in content_sentences]  # one token list per sentence

    participants = []
    interventions = []
    outcomes = []

    sentences = create_sentences(retrived_label or {})
    if not documents or not sentences:
        # Nothing to search in (empty abstract) or nothing to search for; this also
        # skips loading BERT, and BM25 cannot be built over zero documents.
        return participants, interventions, outcomes

    rrf = BM25_BERT(documents)
    for sentence in sentences:
        query = sentence.split()
        first_word = query[0].lower()
        for doc_idx, _score in rrf.search(query, top_n=1):
            sim_document_content = " ".join(documents[doc_idx])
            # create_sentences() phrases the participant query as "Patient is a ...",
            # so "patient" must map to participants, not fall through to interventions.
            if first_word in ("patient", "age", "gender", "conditions"):
                participants.append(sim_document_content)
            elif first_word.startswith(("primary", "secondary")):
                outcomes.append(sim_document_content)
            else:
                interventions.append(sim_document_content)

    return participants, interventions, outcomes



# Demo: run retrieval for one record and print the hits per PIO element next to the gold label.
id=11
content = datasets[id].get("content")
# Remove only the sentence citing the ClinicalTrials.gov registration, as the other cells do.
# A lazy `\b.*?ClinicalTrials...` pattern matches from the start of the line, so it would
# delete everything before the NCT id.
pattern = r'[^.]*?clinicaltrials\.gov[^.]*?NCT\d{8}[^.]*?\.'
cleaned_abstract = re.sub(pattern, '', content, flags=re.IGNORECASE)
content = re.sub(r'\n\s*\n', '\n', cleaned_abstract).strip()

retrived_label = datasets[id].get("retrieved")

participants, interventions, outcomes = retrived_pio_BM25_BERT(content, retrived_label)

# dict.fromkeys de-duplicates while keeping retrieval order; set() order varies between
# runs, which made the [:2]/[:3] slices pick different sentences each time.
print("Participants--------------", )
for p1 in dict.fromkeys(participants):
    print(p1)
print("Interventions----------")
for p2 in list(dict.fromkeys(interventions))[:2]:
    print(p2)
print("Outcomes----------")
for p3 in list(dict.fromkeys(outcomes))[:3]:
    print(p3)

# Gold annotation; the rest of the notebook reads it from the "label" key.
datasets[id].get("label")

# 4.生成模型

In [None]:
from openai import OpenAI
import random


# arg1: instruction
# Task description placed in the "### Instruction" section of every prompt.
instruction = (
    "Your task is to accurately complete the JSON object with short text phrases and sentences "
    "that describe the keys based on the provided input, which is the title and abstract of a publication.\n"
    "The keys are Population, Interventions, and Outcomes.\n"
    'Your response should look like: {"Population": "", "Interventions": "", "Outcomes": ""}\n'
    "You will be punished if your response is not a JSON object filled with phrases from the input."
)

# arg2: input
#input=datasets[id].get("content")

# arg3: guidelines
#guidelines=guidelines_picocorpus

# arg4: rag
def rag_func(participants, interventions, outcomes):
    """Format the retrieved P/I/O sentences as the "### RAG" section of the prompt.

    Only the first retrieved sentence of each element is used; an empty list leaves
    its slot empty.

    Args:
        participants, interventions, outcomes: lists of retrieved sentences, as
            returned by ``retrived_pio_BM25_BERT``.

    Returns:
        The formatted RAG context string.
    """
    # Take the first hit in retrieval order. list(set(x))[0] would return an arbitrary
    # element, because str hashing is randomized per process (e.g. primary vs. secondary
    # outcome sentence).
    participant_text = participants[0] if participants else ""
    intervention_text = interventions[0] if interventions else ""
    outcome_text = outcomes[0] if outcomes else ""
    rag = '''Here is the category and corresponding sentence numbering (id) based on SpaCy and initial capitalization:
    population :  {},
    intervention: {},
    outcome: {}
    You will get punished if the phrases and values are not extracted from the INPUT.
    '''.format(participant_text, intervention_text, outcome_text)
    return rag


In [None]:
import os
import openai


# Module-level openai configuration, used by the openai.chat-based generate() in this cell.
# Credentials come from the environment rather than the notebook; "xxx" remains the fallback placeholder.
openai.api_key = os.environ.get("OPENAI_API_KEY", "xxx")
openai.base_url = os.environ.get("OPENAI_BASE_URL", "xxx")
openai.default_headers = {"x-foo": "true"}


def generate(instruction, input, rag, guidelines,modelname):
    """Query the chat model once and return its answer as a pretty-printed JSON string.

    The prompt stacks the instruction, the input text, the RAG context and the
    guidelines in "### ..." sections.

    Args:
        instruction: task description.
        input: text to extract from.
        rag: retrieved-context section ("" to leave it empty).
        guidelines: annotation-guidelines section ("" to leave it empty).
        modelname: chat model name.

    Returns:
        Pretty-printed JSON string of the model's answer. Returns "{}" when the answer
        is not valid JSON, so callers that json.loads every result stay aligned per example.
    """
    prompt='''### Instruction: {}. 
            ### Input:{}. 
            ### RAG:{}
            ### Guidelines: {}. 
            ### Response '''.format(instruction,input,rag,guidelines)
    completion = openai.chat.completions.create(
        model=modelname,
        messages=[{"role": "user", "content": prompt}],
    )
    response = completion.choices[0].message.content or ""
    # Remove an optional ```json ... ``` fence. str.strip("```json\n") would strip any of
    # those characters from both ends, not the fence as a unit.
    json_string = re.sub(r"^\s*```(?:json)?\s*|\s*```\s*$", "", response, flags=re.IGNORECASE).strip()
    try:
        json_data = json.loads(json_string)
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")
        return "{}"
    json_response = json.dumps(json_data, ensure_ascii=False, indent=4)
    print(json_response)
    return json_response

## demo4o

In [None]:
from openai import OpenAI
import random
# OpenAI-compatible client used by the client-based generate() below.
# SECURITY: never hard-code API keys in the notebook. Any key that was committed here
# must be revoked; supply the key through the OPENAI_API_KEY environment variable.
import os

client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
    base_url=os.environ.get("OPENAI_BASE_URL", "https://api.chatanywhere.tech/v1"),
)
def generate(instruction, input, rag, guidelines,modelname):
    """Query the chat model through ``client`` and return its answer as pretty-printed JSON.

    NOTE(review): this redefinition shadows the openai-module-based generate() from the
    previous cell; this version is the one active for the experiment cells that follow.

    Args:
        instruction: task description.
        input: text to extract from.
        rag: retrieved-context section ("" to leave it empty).
        guidelines: annotation-guidelines section ("" to leave it empty).
        modelname: chat model name.

    Returns:
        Pretty-printed JSON string of the model's answer. Returns "{}" when the answer
        is not valid JSON, so callers that json.loads every result stay aligned per example.
    """
    prompt='''### Instruction: {}. 
            ### Input:{}. 
            ### RAG:{}
            ### Guidelines: {}. 
            ### Response '''.format(instruction,input,rag,guidelines)
    completion = client.chat.completions.create(
        model=modelname,
        messages=[{"role": "user", "content": prompt}],
    )
    response = completion.choices[0].message.content or ""
    # Remove an optional ```json ... ``` fence. str.strip("```json\n") would strip any of
    # those characters from both ends, not the fence as a unit.
    json_string = re.sub(r"^\s*```(?:json)?\s*|\s*```\s*$", "", response, flags=re.IGNORECASE).strip()
    try:
        json_data = json.loads(json_string)
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")
        return "{}"
    json_response = json.dumps(json_data, ensure_ascii=False, indent=4)
    print(json_response)
    return json_response

In [None]:

import os

# Settings shared by the three prompting variants.
N_EXAMPLES = 20
MODEL_NAME = "gpt-4o-mini"
OUTPUT_DIR = "/local/home/sumyao/YSforGIT/output"
# Removes the sentence that cites the ClinicalTrials.gov registration (NCT + 8 digits).
pattern = r'[^.]*?clinicaltrials\.gov[^.]*?NCT\d{8}[^.]*?\.'

# Clean every abstract and run retrieval once. The "rag" and "rag + guidelines" runs share
# the result; each retrived_pio_BM25_BERT call builds a new BM25_BERT (and so reloads BERT).
cleaned_contents = []
rag_contexts = []
for idx in tqdm(range(N_EXAMPLES), desc="Retrieving context"):
    content = re.sub(pattern, '', datasets[idx].get("content"), flags=re.IGNORECASE).strip()
    retrived_label = datasets[idx].get("retrieved")
    participants, interventions, outcomes = retrived_pio_BM25_BERT(content, retrived_label)
    cleaned_contents.append(content)
    rag_contexts.append(rag_func(participants, interventions, outcomes))


def run_experiment(output_name, use_rag, guidelines):
    """Generate predictions for the first N_EXAMPLES abstracts and save them as a JSON list.

    Args:
        output_name: file name inside OUTPUT_DIR.
        use_rag: whether to include the precomputed retrieval context in the prompt.
        guidelines: guidelines text for the prompt ("" to omit it).

    Returns:
        The list of JSON strings returned by ``generate`` (one per example).
    """
    outputs = []
    for idx in tqdm(range(N_EXAMPLES), desc=f"Generating {output_name}"):
        outputs.append(generate(instruction=instruction,
                                input=clean_sentence(cleaned_contents[idx]),
                                rag=rag_contexts[idx] if use_rag else "",
                                guidelines=guidelines,
                                modelname=MODEL_NAME))
    with open(os.path.join(OUTPUT_DIR, output_name), 'w') as json_file:
        json.dump(outputs, json_file)
    return outputs


results = run_experiment('4o_picocorpus_rag3.json', use_rag=True, guidelines="")
results = run_experiment('4o_picocorpus_guidelines3.json', use_rag=False, guidelines=guidelines_picocorpus)
results = run_experiment('4o_picocorpus_rag_guidelines3.json', use_rag=True, guidelines=guidelines_picocorpus)

# 5.评价函数

In [None]:
import os
# annotatiopn是列表形式
def preprocess_annotations(annotations):
    """Convert gold PICO annotations into the key layout the model is asked to produce.

    Each annotation is a dict whose optional "participants", "interventions",
    "comparator" and "outcomes" entries are space-joined (they are expected to be
    lists of strings). The results go under "Population", "Interventions"
    (interventions followed by comparator) and "Outcomes".

    Args:
        annotations: list of annotation dicts.

    Returns:
        list of dicts containing whichever of "Population", "Interventions", "Outcomes" apply.
    """
    processed_annotations = []
    for annotation in annotations:
        processed_annotation = {}
        if 'participants' in annotation:
            processed_annotation['Population'] = ' '.join(annotation['participants'])
        # The comparator (control arm) is scored as part of Interventions; keep it even
        # when the annotation has no separate "interventions" entry.
        intervention_parts = [' '.join(annotation[key])
                              for key in ('interventions', 'comparator') if key in annotation]
        if intervention_parts:
            processed_annotation['Interventions'] = ' '.join(intervention_parts)
        if 'outcomes' in annotation:
            processed_annotation['Outcomes'] = ' '.join(annotation['outcomes'])
        processed_annotations.append(processed_annotation)
    return processed_annotations

import functools


@functools.lru_cache(maxsize=1)
def _english_stopwords():
    """NLTK English stop-word set, loaded once instead of on every preprocess_text call."""
    return frozenset(stopwords.words('english'))


def preprocess_text(text):
    """Normalize free text for bag-of-words comparison.

    Lowercases the text, turns '-' and '/' into spaces, drops ( ) , ; :, removes the
    phrase "not specified", tokenizes with NLTK and removes English stop words.

    Args:
        text: input sentence or paragraph.

    Returns:
        The remaining tokens joined by single spaces.
    """
    text = text.lower()
    text = re.sub(r'[-/]', ' ', text)
    text = re.sub(r'[(),;:]', '', text)  # '-' was already turned into a space above
    # The text is lowercase at this point, so this single pattern also covers "Not specified".
    text = re.sub(r'not specified', '', text)
    words = word_tokenize(text)
    stop_words = _english_stopwords()
    return ' '.join(word for word in words if word not in stop_words)

In [None]:
def calculate_metrics(annotations, predictions):
    p_annotation_text = preprocess_text(''.join(annotations.get('Population', [])))
    p_prediction_text = preprocess_text(''.join(predictions.get('Population', [])))
    i_annotation_text = preprocess_text(''.join(annotations.get('Interventions', [])))
    i_prediction_text = preprocess_text(''.join(predictions.get('Interventions', [])))
    o_annotation_text = preprocess_text(''.join(annotations.get('Outcomes', [])))
    o_prediction_text = preprocess_text(''.join(predictions.get('Outcomes', [])))

    #print(p_annotation_text)
    #print(o_prediction_text)

    p_precision, p_recall, p_f1, p_correct, p_annotation_len, p_prediction_len = calculate_element_metrics(p_annotation_text, p_prediction_text)
    i_precision, i_recall, i_f1, i_correct, i_annotation_len, i_prediction_len = calculate_element_metrics(i_annotation_text, i_prediction_text)
    o_precision, o_recall, o_f1, o_correct, o_annotation_len, o_prediction_len = calculate_element_metrics(o_annotation_text, o_prediction_text)

    return {
        'P': {'precision': p_precision, 'recall': p_recall, 'f1': p_f1},
        'I': {'precision': i_precision, 'recall': i_recall, 'f1': i_f1},
        'O': {'precision': o_precision, 'recall': o_recall, 'f1': o_f1},
    }

def calculate_element_metrics(annotation_text, prediction_text):
    annotation_words = set(annotation_text.split())
    prediction_words = set(prediction_text.split())

    intersection = annotation_words.intersection(prediction_words)
    union = annotation_words.union(prediction_words)

    precision = len(intersection) / len(prediction_words) if len(prediction_words) > 0 else 0
    recall = len(intersection) / len(annotation_words) if len(annotation_words) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    return precision, recall, f1, len(intersection), len(annotation_words), len(prediction_words)


def calculate_macro_avg(metrics):
    # 初始化三个要素的累计精度、召回率、F1分数
    p_precision_total, p_recall_total, p_f1_total = 0, 0, 0
    i_precision_total, i_recall_total, i_f1_total = 0, 0, 0
    o_precision_total, o_recall_total, o_f1_total = 0, 0, 0
    count = len(metrics)

    # 累加每个要素的指标
    for metric in metrics:
        p_precision_total += metric['P']['precision']
        p_recall_total += metric['P']['recall']
        p_f1_total += metric['P']['f1']

        i_precision_total += metric['I']['precision']
        i_recall_total += metric['I']['recall']
        i_f1_total += metric['I']['f1']

        o_precision_total += metric['O']['precision']
        o_recall_total += metric['O']['recall']
        o_f1_total += metric['O']['f1']

    # 计算宏平均
    p_precision_avg = p_precision_total / count if count > 0 else 0
    p_recall_avg = p_recall_total / count if count > 0 else 0
    p_f1_avg = p_f1_total / count if count > 0 else 0

    i_precision_avg = i_precision_total / count if count > 0 else 0
    i_recall_avg = i_recall_total / count if count > 0 else 0
    i_f1_avg = i_f1_total / count if count > 0 else 0

    o_precision_avg = o_precision_total / count if count > 0 else 0
    o_recall_avg = o_recall_total / count if count > 0 else 0
    o_f1_avg = o_f1_total / count if count > 0 else 0

    # 返回各个要素的平均结果
    return {
        'P': {'precision': p_precision_avg, 'recall': p_recall_avg, 'f1': p_f1_avg},
        'I': {'precision': i_precision_avg, 'recall': i_recall_avg, 'f1': i_f1_avg},
        'O': {'precision': o_precision_avg, 'recall': o_recall_avg, 'f1': o_f1_avg}
    }



def calculate_micro_avg(metrics):
    # 初始化三个要素的累计统计
    p_correct, p_annotation_total, p_prediction_total = 0, 0, 0
    i_correct, i_annotation_total, i_prediction_total = 0, 0, 0
    o_correct, o_annotation_total, o_prediction_total = 0, 0, 0

    # 累计每个要素的统计值
    for metric in metrics:
        # P
        p_correct += metric['P']['precision'] * metric['P']['f1']  # 假设使用交集数量作为正确预测
        p_annotation_total += metric['P']['recall']  # 注释的单词数
        p_prediction_total += metric['P']['f1']  # 预测的单词数

        # I
        i_correct += metric['I']['precision'] * metric['I']['f1']
        i_annotation_total += metric['I']['recall']
        i_prediction_total += metric['I']['f1']

        # O
        o_correct += metric['O']['precision'] * metric['O']['f1']
        o_annotation_total += metric['O']['recall']
        o_prediction_total += metric['O']['f1']

    # 分别计算三个要素的微观平均
    def calculate_avg(correct, annotation_total, prediction_total):
        precision = correct / prediction_total if prediction_total > 0 else 0
        recall = correct / annotation_total if annotation_total > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        return precision, recall, f1

    p_precision, p_recall, p_f1 = calculate_avg(p_correct, p_annotation_total, p_prediction_total)
    i_precision, i_recall, i_f1 = calculate_avg(i_correct, i_annotation_total, i_prediction_total)
    o_precision, o_recall, o_f1 = calculate_avg(o_correct, o_annotation_total, o_prediction_total)

    # 返回结果
    return {
        'P': {'precision': p_precision, 'recall': p_recall, 'f1': p_f1},
        'I': {'precision': i_precision, 'recall': i_recall, 'f1': i_f1},
        'O': {'precision': o_precision, 'recall': o_recall, 'f1': o_f1}
    }

def calculate_total_macro_avg(macro_avg):
    """Unweighted mean over P, I and O of the macro precision, recall and F1.

    Returns:
        (precision, recall, f1) tuple.
    """
    precision, recall, f1 = (
        sum(macro_avg[element][metric] for element in ('P', 'I', 'O')) / 3
        for metric in ('precision', 'recall', 'f1')
    )
    return precision, recall, f1

def calculate_total_micro_avg(micro_avg):
    """Unweighted mean over P, I and O of the micro precision, recall and F1.

    Returns:
        (precision, recall, f1) tuple.
    """
    precision, recall, f1 = (
        sum(micro_avg[element][metric] for element in ('P', 'I', 'O')) / 3
        for metric in ('precision', 'recall', 'f1')
    )
    return precision, recall, f1


# Gold annotations for the scored examples. N_EVAL must equal the number of examples that
# were sent to the model when the prediction files were generated.
N_EVAL = 20
annotations = [item['label'] for item in datasets[:N_EVAL]]
processed_annotations = preprocess_annotations(annotations)

PREDICTION_PATHS = [
    # '/local/home/sumyao/YSforGIT/output/4o_picocorpus.json',  # currently disabled
    '/local/home/sumyao/YSforGIT/output/4o_picocorpus_rag3.json',
    '/local/home/sumyao/YSforGIT/output/4o_picocorpus_guidelines3.json',
    '/local/home/sumyao/YSforGIT/output/4o_picocorpus_rag_guidelines3.json',
]

for path in PREDICTION_PATHS:
    # Each entry of the file is itself a JSON string returned by generate().
    with open(path) as prediction_file:
        predictions = [json.loads(item) for item in json.load(prediction_file)]

    annotations_list = processed_annotations
    predictions_list = predictions
    print(f"\n===== {path} =====")
    print(len(annotations_list), len(predictions_list))
    if len(annotations_list) != len(predictions_list):
        print(f"WARNING: {len(annotations_list)} annotations vs {len(predictions_list)} predictions; "
              "scoring only the aligned prefix")

    # Rebuilt for every file so the averages are not polluted by previously scored files.
    individual_metrics = [calculate_metrics(annotation, prediction)
                          for annotation, prediction in zip(annotations_list, predictions_list)]

    # Macro average
    macro_avg = calculate_macro_avg(individual_metrics)
    print("\nMacro Average Results:")
    for key in macro_avg:
        print(f"{key} - Precision: {macro_avg[key]['precision']:.4f}, Recall: {macro_avg[key]['recall']:.4f}, F1: {macro_avg[key]['f1']:.4f}")

    # Micro average
    micro_avg = calculate_micro_avg(individual_metrics)
    print("\nMicro Average Results:")
    for key in micro_avg:
        print(f"{key} - Precision: {micro_avg[key]['precision']:.4f}, Recall: {micro_avg[key]['recall']:.4f}, F1: {micro_avg[key]['f1']:.4f}")

    # Mean of the P/I/O macro averages
    total_macro_precision, total_macro_recall, total_macro_f1 = calculate_total_macro_avg(macro_avg)
    print("\nTotal Macro Average:")
    print(f"Precision: {total_macro_precision:.4f}, Recall: {total_macro_recall:.4f}, F1: {total_macro_f1:.4f}")

    # Mean of the P/I/O micro averages
    total_micro_precision, total_micro_recall, total_micro_f1 = calculate_total_micro_avg(micro_avg)
    print("\nTotal Micro Average:")
    print(f"Precision: {total_micro_precision:.4f}, Recall: {total_micro_recall:.4f}, F1: {total_micro_f1:.4f}")