In [None]:
import os
import json
import random
from typing import Optional

import shutil
import pandas as pd
from tqdm import tqdm
from textwrap import dedent
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain.schema import Document
from langchain_chroma import Chroma
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain_core.runnables import RunnableLambda
from langchain_huggingface.llms import HuggingFacePipeline
from pydantic import BaseModel, Field
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoTokenizer, pipeline


from FlagEmbedding import BGEM3FlagModel
from utils.pydantic import Class, RAG_Analyse_Class
from utils.tools import (
    get_vectordb_length, clear, md, green,
    grey, similarity_search, build_reference_examples_chain,
    docl_search, extract_json
)
from utils.prompt import docl_prompt_template, zs_prompt_template, fs_prompt_template, rag_prompt_template
from utils.langsmith import activate_langsmith, remove_langsmith
from utils.embedding import BGEM3Embeddings, BERTEmbeddings
from utils import api_config
from config.llm_setting import llm_config

# activate_langsmith()
# remove_langsmith()


# ---- Experiment configuration: every tunable value used by the cells below ----
dataset = 'LatentHatred' # dataset folder under ./data; test() only runs predictions when this is 'LatentHatred'
method = 'cocl' # values handled by build_chain: zs, fs, rag, cocl, docl
llm_name = "llama3-70b-instruct" # values handled by the LLM cell: any name starting with "gpt", or "llama3-70b-instruct"
train_data_path = f'./data/{dataset}/IH_train.csv'
test_data_path = f'./data/{dataset}/IH_test.csv'
model_path = './model'  # passed as cache_dir when loading the embedding models
persist_directory = './Chroma'  # directory removed before each vector-store build
emb_device = 'cuda:0'  # device the embedding model is moved to
llm_device = 'cuda:4'  # NOTE(review): not referenced by any other visible cell
encode_batch_size=256
# Embedding backend: 'bert', 'bgem3' or 'ftmodel'
embedding_model = 'ftmodel'
output_path = f'./results/{dataset}_{method}_{llm_name}_{embedding_model}.csv'
finetune_data_path = f'./finetune/{method}_ft_data.json'
finetuned_model_path = './finetune/cocl_finetuned_model'  # loaded when embedding_model == 'ftmodel'

In [None]:
# Instantiate the chat model used by every chain.
# Both supported model families are created through ChatOpenAI with identical
# arguments, so a single constructor call covers them. Any other name is rejected
# immediately instead of leaving `llm` undefined and failing later with a NameError.
if llm_name.startswith('gpt') or llm_name == "llama3-70b-instruct":
    llm = ChatOpenAI(
        temperature=0.95,
        model=llm_name,
        openai_api_key=api_config.openai_key,
        openai_api_base=api_config.openai_base
    )
else:
    raise ValueError(
        f"Unsupported llm_name {llm_name!r}; expected a 'gpt*' model or 'llama3-70b-instruct'"
    )

In [None]:
# Load the embedding model used to index and query the training posts.
# Zero-shot ('zs') and few-shot ('fs') runs do no retrieval, so nothing is loaded for them.
if method not in ('zs', 'fs'):
    if embedding_model == 'bgem3':
        # Pretrained BGE-M3, resolved from the local cache only.
        model = AutoModel.from_pretrained("BAAI/bge-m3", cache_dir=model_path, local_files_only=True).to(emb_device)
        tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3", cache_dir=model_path, local_files_only=True)
        embeddings = BGEM3Embeddings(model, tokenizer, encode_batch_size)

    elif embedding_model == 'bert':
        # bert-base-uncased loaded from a local directory.
        model = AutoModel.from_pretrained("./model/bert-base-uncased", cache_dir=model_path, local_files_only=True).to(emb_device)
        tokenizer = AutoTokenizer.from_pretrained("./model/bert-base-uncased", cache_dir=model_path, local_files_only=True)
        embeddings = BERTEmbeddings(model, tokenizer, encode_batch_size)

    elif embedding_model == 'ftmodel':
        # Fine-tuned checkpoint, wrapped with the same BGEM3Embeddings adapter as 'bgem3'.
        model = AutoModel.from_pretrained(finetuned_model_path, cache_dir=model_path).to(emb_device)
        tokenizer = AutoTokenizer.from_pretrained(finetuned_model_path, cache_dir=model_path)
        embeddings = BGEM3Embeddings(model, tokenizer, encode_batch_size)

    else:
        # Fail here rather than with a NameError on `embeddings` when the index is built.
        raise ValueError(
            f"Unsupported embedding_model {embedding_model!r}; expected 'bgem3', 'bert' or 'ftmodel'"
        )


In [None]:
def build_chain(method):
    """Build the runnable chain that classifies a single post with the given method.

    Args:
        method: One of 'zs', 'fs' (prompt-only chains) or 'rag', 'cocl', 'docl'
            (retrieval chains that index the training CSV in a Chroma store).

    Returns:
        A runnable invoked with ``{"input_example": <post>, "k": <optional int>}``.
        ``k`` is the number of retrieved examples and defaults to 6; the prompt-only
        chains ignore it.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    supported_methods = ('zs', 'fs', 'rag', 'cocl', 'docl')
    if method not in supported_methods:
        # Previously an unknown method fell through every branch and crashed on `return chain`.
        raise ValueError(f"Unsupported method {method!r}; expected one of {supported_methods}")

    def llm_step(schema):
        # Request structured output only when the model's config entry enables it;
        # otherwise the raw LLM is used and the caller parses the JSON itself.
        return llm.with_structured_output(schema) if llm_config[llm_name]['structured_output'] else llm

    if method in ('zs', 'fs'):
        prompt_template = zs_prompt_template if method == 'zs' else fs_prompt_template
        chain = (
            RunnableLambda(lambda x: x['input_example'])
            | prompt_template
            | llm_step(Class)
        )
        print(f"{method}_chain built")
        return chain

    # Retrieval-based methods: index every training post together with its class label.
    df = pd.read_csv(train_data_path)
    documents = [
        Document(page_content=post, metadata={'class': label})
        for post, label in zip(df['post'], df['class'])
    ]
    shutil.rmtree(persist_directory, ignore_errors=True)
    # NOTE(review): persist_directory is removed above but is not passed to
    # Chroma.from_documents - confirm where the collection actually lives.
    vectordb = Chroma.from_documents(documents, embeddings)
    print(get_vectordb_length(vectordb))

    # 'docl' uses its own retrieval helper; 'rag' and 'cocl' share plain similarity search.
    # (The old test `method == "rag" or "cocl"` was always true because "cocl" is truthy.)
    search = docl_search if method == 'docl' else similarity_search

    # To inspect the retrieved examples, insert
    #   RunnableLambda(lambda x: md(x['reference_examples'], f"{green('Input_post')}: {x['input_example']}") or x)
    # just before the prompt template.
    return (
        RunnableLambda(lambda x: {"sim_docs": search(vectordb, x["input_example"], x.get('k', 6)), "input_example": x["input_example"]})  # retrieve similar posts
        | RunnableLambda(lambda x: {"reference_examples": build_reference_examples_chain(x["sim_docs"]), "input_example": x["input_example"]})  # build reference examples
        | rag_prompt_template  # prompt template
        | llm_step(RAG_Analyse_Class)
    )

# Build the chain for the configured method once; the evaluation cells below reuse `chain`.
chain = build_chain(method)


In [None]:
def test(test_data_path, output_path, chain, num=None, k=None, max_workers=20, save_every=100):
    """Run `chain` over the test CSV in parallel and collect one prediction per row.

    Args:
        test_data_path: CSV file with a 'post' column.
        output_path: CSV path used for periodic checkpoints of the partial results.
        chain: Runnable invoked with ``{"input_example": post}`` (plus ``"k"`` when given).
        num: If not None, only the first ``num`` rows are evaluated.
        k: Optional number of retrieved examples forwarded to the chain.
        max_workers: Size of the thread pool issuing chain calls.
        save_every: Write a checkpoint after this many completed rows (0/None disables).

    Returns:
        A DataFrame copy of the (possibly truncated) test data with a 'pred' column.
        Rows whose chain call or output parsing raised are marked with the string "error".
        When ``dataset`` is not 'LatentHatred' no predictions are made and 'pred' stays None.
    """
    test_df = pd.read_csv(test_data_path)
    if num is not None:
        test_df = test_df.head(num)
    # Work on an explicit copy so the per-row `.at` writes never target a slice view.
    test_df = test_df.copy()
    test_df['pred'] = None

    if dataset != 'LatentHatred':
        return test_df

    def process_row(idx, post):
        # Call the prediction chain; any failure is reported as an "error" label
        # so one bad row cannot abort the whole run.
        try:
            input_data = {"input_example": post}
            if k is not None:
                input_data["k"] = k
            result = chain.invoke(input_data)
            if llm_config[llm_name]['structured_output']:
                predict_label = result.dict()['result']
            else:
                # Raw LLM message: pull the JSON object out of the text and read its 'result'.
                content = result.dict()['content']
                predict_label = json.loads(extract_json(content))['result']
            return idx, predict_label
        except Exception:
            return idx, "error"

    # Make sure checkpoints can be written before spending time on LLM calls.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit one task per row.
        futures = [executor.submit(process_row, idx, post) for idx, post in test_df['post'].items()]

        progress = tqdm(as_completed(futures), total=len(futures), mininterval=2.0)
        for completed, future in enumerate(progress, start=1):
            idx, predict_label = future.result()
            # Store the prediction first so the checkpoint below includes it.
            test_df.at[idx, 'pred'] = predict_label
            if save_every and completed % save_every == 0:
                test_df.to_csv(output_path, index=False)

    failed = int((test_df['pred'] == "error").sum())
    if failed:
        # Surface swallowed failures instead of leaving them silent.
        print(f"{failed} of {len(test_df)} rows failed and were marked as 'error'")
    return test_df

def calculate_metrics(test_result):
    """Print and return binary classification metrics for a prediction frame.

    Values in the 'class' and 'pred' columns are collapsed to two classes:
    any value whose lower-cased text contains "not" becomes 0, everything else 1.

    Args:
        test_result: DataFrame with string columns 'class' (gold) and 'pred' (predicted).

    Returns:
        Tuple ``(precision, recall, f1, accuracy, confusion_matrix)``.
    """
    def binarize(names):
        # 0 for "not ..." labels, 1 otherwise.
        return [int('not' not in name.lower()) for name in names]

    label = binarize(test_result['class'].tolist())
    pred = binarize(test_result['pred'].tolist())

    precision = precision_score(label, pred)
    recall = recall_score(label, pred)
    f1 = f1_score(label, pred)
    accuracy = accuracy_score(label, pred)
    conf_matrix = confusion_matrix(label, pred)

    scalar_report = (
        ("Precision", precision),
        ("Recall", recall),
        ("F1 Score", f1),
        ("Accuracy", accuracy),
    )
    for title, value in scalar_report:
        print(f"{title}: {value:.3f}")
    print(f"Confusion Matrix: \n{conf_matrix}")

    return precision, recall, f1, accuracy, conf_matrix

In [None]:
# Evaluate the first 100 test posts with 20 retrieved examples each.
test_result = test(test_data_path, output_path, chain, num=100, k=20)
# Save the predictions before scoring so a failure in the metric code cannot lose them.
test_result.to_csv(output_path, index=False)
# Rows marked "error" carry no prediction and are left out of the metrics.
p, r, f1, acc, cm = calculate_metrics(test_result[test_result.loc[:, 'pred'] != "error"])

In [None]:
# Score a previously saved run: keep rows that have a prediction, cap at the first 1000,
# and leave rows marked "error" out of the metrics.
temp_result = (
    pd.read_csv("/home/liyuanyong2022/DOCL/results/LatentHatred_cocl_llama3-70b-instruct_ftmodel.csv")
    .dropna(subset=['pred'])
    .head(1000)
)
p, r, f1, acc, cm = calculate_metrics(temp_result.loc[temp_result['pred'] != "error"])

In [None]:
# Score a previously saved zero-shot run: keep rows that have a prediction, cap at the
# first 1000, and leave rows marked "error" out of the metrics.
temp_result = (
    pd.read_csv("/home/liyuanyong2022/DOCL/results/LatentHatred_zs_llama3-70b-instruct_bgem3.csv")
    .dropna(subset=['pred'])
    .head(1000)
)
p, r, f1, acc, cm = calculate_metrics(temp_result.loc[temp_result['pred'] != "error"])

In [None]:
# Display the rows of the latest run whose prediction is not the "error" marker.
test_result.loc[test_result['pred'] != "error"]

In [None]:
# Debugging aid: switch on LangChain's global verbose and debug output, then run the
# chain on a single example.
# NOTE(review): both flags stay enabled for every later cell; nothing visible resets them.
from langchain.globals import set_verbose, set_debug
set_verbose(True)
set_debug(True)

# NOTE(review): the sample text ends with "..." and trailing whitespace, which looks like
# a truncated copy from a DataFrame display - confirm this is the intended input.
result = chain.invoke({
    "input_example": "white americans are committing suicide in reco...	",
})

## Build Train Data(COCL)

In [None]:
# Build the contrastive fine-tuning file (query / pos / neg triples); only runs for 'cocl'.
# NOTE(review): `vectordb` is assigned at module level only by the cell in the later
# "Test" section (inside build_chain it is a local), so this cell fails on a fresh
# Restart & Run All unless that cell has been executed first.
if method == 'cocl':
    def build_data(x):
        """Turn one labelled post into a fine-tuning record.

        `x` is a dict with keys "input_example" (the post text) and "class".
        Both 'implicit_hate' and 'explicit_hate' are collapsed to 'hate' for the
        query and for every retrieved example. Retrieved posts whose collapsed
        class equals the query's go to "pos", all others to "neg".

        Returns a dict with keys "query", "pos", "neg", or None when either
        list would be empty.
        """
        cla = 'hate' if x['class'] in ['implicit_hate', 'explicit_hate'] else x['class']
        
        # NOTE(review): `similarity_search_chain` is not among the names imported from
        # utils.tools and is not defined in the visible cells - confirm where it comes from.
        # The code below indexes each result with ['class'] and ['post'].
        reference_examples = similarity_search_chain(vectordb, x["input_example"], k=50)
        reference_examples = [
            {'class': 'hate' if e['class'] in ['implicit_hate', 'explicit_hate'] else e['class'], 'post': e['post']}
            for e in reference_examples
        ]
        query = x['input_example']
        pos_list = []
        neg_list = []
        for example in reference_examples:
            if example['class'] == cla:
                pos_list.append(example['post'])
            else:
                neg_list.append(example['post'])
        
        # A record needs at least one positive and one negative example.
        if len(pos_list) == 0 or len(neg_list) == 0:
            return None
        return {
            "query": query,
            "pos": pos_list,
            "neg": neg_list
        }
    
    # NOTE(review): the queries and their labels are read from test_data_path, while
    # the results are written to the fine-tuning data file - confirm that using the
    # test split here is intended (possible train/test leakage).
    test_data = pd.read_csv(test_data_path)
    output_data = []
    for idx, data in tqdm(test_data.iterrows(), total=len(test_data), mininterval=2.0):
        result = build_data({"input_example": data['post'], "class": data['class']})
        if result:
            output_data.append(result)
    with open(finetune_data_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, ensure_ascii=False, indent=4)



## Test

In [None]:
# Rebuild the retrieval index at module level so `vectordb` is available to later cells,
# then rebuild `chain` for the configured retrieval method.
df = pd.read_csv(train_data_path)
documents = [
    Document(page_content=row['post'], metadata={'class': row['class']})
    for _, row in df.iterrows()
]
categories = {doc.metadata['class'] for doc in documents}
shutil.rmtree(persist_directory, ignore_errors=True)
# NOTE(review): persist_directory is removed above but is not passed to
# Chroma.from_documents - confirm where the collection actually lives.
vectordb = Chroma.from_documents(documents, embeddings)
print(get_vectordb_length(vectordb))

# NOTE(review): unlike build_chain, these chains always request structured output and
# ignore llm_config[llm_name]['structured_output'] - confirm that is intended.
# The old test `method == "rag" or "cocl"` was always true (a non-empty string is truthy),
# so `chain` was overwritten with the similarity-search chain for every method.
if method in ("rag", "cocl"):
    chain = (
        RunnableLambda(lambda x: {"sim_docs": similarity_search(vectordb, x["input_example"], x.get('k', 6)), "input_example": x["input_example"]})  # retrieve similar posts
        | RunnableLambda(lambda x: {"reference_examples": build_reference_examples_chain(x["sim_docs"]), "input_example": x["input_example"]})  # build reference examples
        | rag_prompt_template  # prompt template
        | llm.with_structured_output(RAG_Analyse_Class)  # pass the prompt to the LLM
    )
elif method == 'docl':
    chain = (
        RunnableLambda(lambda x: {"sim_docs": docl_search(vectordb, x["input_example"], x.get('k', 6)), "input_example": x["input_example"]})  # retrieve with the docl search helper
        | RunnableLambda(lambda x: {"reference_examples": build_reference_examples_chain(x["sim_docs"]), "input_example": x["input_example"]})  # build reference examples
        | rag_prompt_template  # prompt template
        | llm.with_structured_output(RAG_Analyse_Class)  # pass the prompt to the LLM
    )

In [None]:
def _make_inspection_chain(search_fn):
    """Return a chain that retrieves examples with `search_fn`, displays them via `md`,
    and stops at the formatted prompt (no LLM call is made)."""
    return (
        RunnableLambda(lambda x: {"sim_docs": search_fn(vectordb, x["input_example"], x.get('k', 6)), "input_example": x["input_example"]})  # retrieve similar posts
        | RunnableLambda(lambda x: {"reference_examples": build_reference_examples_chain(x["sim_docs"]), "input_example": x["input_example"]})  # build reference examples
        | RunnableLambda(lambda x: md(x['reference_examples'], f"{green('Input_post')}: {x['input_example']}") or x)  # show the retrieved examples and the input
        | rag_prompt_template  # prompt template
    )


# Same pipeline twice; only the retrieval helper differs.
docl_chain = _make_inspection_chain(docl_search)
cocl_chain = _make_inspection_chain(similarity_search)

In [None]:
# Load a saved docl run and keep only the rows that have a prediction.
temp_result = pd.read_csv("/home/liyuanyong2022/DOCL/results/LatentHatred_docl_gpt-3.5-turbo_ftmodel.csv")
temp_result = temp_result.dropna(subset=['pred'])

# For the first row whose prediction differs from its label, show what each
# retrieval strategy returns for that post (k=12), then stop.
for idx, data in temp_result.iterrows():
    if data['class'] == data['pred']:
        continue
    payload = {"input_example": data['post'], "k": 12}
    result = docl_chain.invoke(payload)
    md(f"{grey('label: ')} {data['class']}")
    result = cocl_chain.invoke(payload)
    break