## Training

In [None]:
from sklearn.preprocessing import LabelEncoder, StandardScaler
from utils.train_utils import prepare_data_for_training, save_models_and_metrics, load_models_and_metrics
from sklearn.metrics import accuracy_score, classification_report
from joblib import dump, load
import os
from sklearn.model_selection import StratifiedKFold
import torch
import random
import numpy as np 
import lightgbm as lgb

def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     np.random.seed(seed)
     random.seed(seed)
     torch.backends.cudnn.deterministic = True

setup_seed(42)

def train(
    dataset_configs,
    mode = "cls",
    model_type = 'lgb',
    model_configs = {
        'boosting_type': 'gbdt',
        'num_leaves': 24,      
        'max_depth': 8,           
        'min_child_samples': 20,  
        'min_child_weight': 10,  

        'reg_alpha': 0.5,
        'reg_lambda': 2.0,
        'n_estimators': 2000,

        'subsample': 0.7,          
        'colsample_bytree': 0.7,   

        'learning_rate': 0.01,
        'scale_pos_weight': 0.3, # 1974/2628=0.75 (negative/positive), # 8B, 70B (cot, 0-cot, l2m, tot, mcts)

        'objective': 'multiclass',
        'metric': 'multi_logloss',
        'random_state': 42,
        'n_jobs': -1,
        'verbose': -1,
        'num_class': 2, # 2 for mode=='reg'; else 5
    },
    n_splits = 5,
    verbose = False,
    random_seed=42,
    start_idx=0, end_idx=20
):  

    print("==> Preprocessing data")
    X, y, x_scaler, y_scaler, acc_infos = prepare_data_for_training(
        dataset_configs,
        verbose=verbose,
        mode=mode
    )

    dump(
        [X, y, x_scaler, y_scaler, acc_infos],
        f"training_data/processed_train_data/randomForest/all_cfgs_start-0_end-20.pkl"
    )
    # X, y, x_scaler, y_scaler, acc_infos = load(f"training_data/processed_train_data/randomForest/{dataset_configs}_cfgs_start-0_end-20.pkl")
    
    print(X.shape)
    print(y.shape)
    val_scores = []
    models = []
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_seed)
    for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
        print(f"Training fold {fold + 1}/{n_splits}")
        
        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]
        
        # Create dataset for LightGBM
        train_data = lgb.Dataset(X_train, label=y_train)
        val_data = lgb.Dataset(X_val, label=y_val, reference=train_data)
        
        # Train model
        model = lgb.train(
            model_configs,
            train_data,
            valid_sets=[train_data, val_data],
            callbacks=[
                lgb.early_stopping(stopping_rounds=50),
                lgb.log_evaluation(period=100)
            ]
        )
        
        # Predict and evaluate
        val_preds = model.predict(X_val)
        val_preds = np.argmax(val_preds, axis=1)
        val_score = np.mean(val_preds == y_val)
        
        print(f"Fold {fold + 1} Validation Accuracy: {val_score:.4f}")
        
        val_scores.append(val_score)
        models.append(model)
    
    print(f"\nMean CV Accuracy: {np.mean(val_scores):.4f} ± {np.std(val_scores):.4f}")
    
    return (
        models, 
        x_scaler, 
        y_scaler, 
        acc_infos['mean_voting_acc'], 
        acc_infos['mean_average_acc'],
        (np.mean(val_scores), np.std(val_scores))
    )

def generate_dataset_configs(
    start_idx=0,
    end_idx=5,
    specific_combinations=None
):
    """
    Generate dataset configurations with predefined options
    """
    # Predefined options
    DATASETS = ['aqua', 'mmlu', 'commonsenseqa', 'strategyqa'] 
    MODELS = [
        # 'Llama3.2-1B-Instruct',
        'Llama-3.2-3B-Instruct',
        'Llama3.1-8B-Instruct',
        'Meta-Llama-3.1-70B-Instruct-Turbo'
    ]
    METHODS = ['cot', 'zero_shot_cot', 'l2m']
    
    configs = []
    
    if specific_combinations:
        # Use specific combinations if provided
        for combo in specific_combinations:
            config = {
                'dataset': combo['dataset'],
                'model': combo['model'],
                'method': combo['method'],
                'start_idx': start_idx,
                'end_idx': end_idx
            }
            configs.append(config)
    else:
        # Generate all possible combinations
        for dataset in DATASETS:
            for model in MODELS:
                for method in METHODS:
                    config = {
                        'dataset': dataset,
                        'model': model,
                        'method': method,
                        'start_idx': start_idx,
                        'end_idx': end_idx
                    }
                    configs.append(config)
    
    return configs

# 1. Generate all combinations (48 combinations total: 4 datasets × 4 models × 3 methods)
dataset_configs = generate_dataset_configs(start_idx=0, end_idx=20)

print("\n ==> Training \n")
(
    randomforest_models, randomforest_X_scaler, 
    randomforest_y_scaler, llm_train_voting_acc, llm_train_avg_acc,
    randomforest_mean_std
) = train(
    dataset_configs=dataset_configs,
    mode="reg",
    model_type='lgb',
    start_idx=0, end_idx=20,
    verbose=True
)

save_models_and_metrics(
    randomforest_models, randomforest_X_scaler, 
    randomforest_y_scaler, llm_train_voting_acc, llm_train_avg_acc, randomforest_mean_std,
    ckpt_dir=f"ckpts/all_start_0_end_20"
)

## Evaluation

In [None]:
from utils.visual_utils import get_sample_distance_matrix
import numpy as np
from scipy.stats import pearsonr
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from itertools import combinations
import plotly.express as px
from joblib import dump, load
import os 

plt.style.use('default')

def split_array_by_weight_percentiles(data, n_segments=5):
    # Calculate percentile thresholds
    percentiles = np.linspace(0, 100, n_segments + 1)
    thresholds = np.percentile(data[:, 0], percentiles)
    
    segments = []
    for i in range(n_segments):
        if i == 0:
            mask = data[:, 0] <= thresholds[1]
        elif i == n_segments - 1:
            mask = data[:, 0] > thresholds[-2]
        else:
            mask = (data[:, 0] > thresholds[i]) & (data[:, 0] <= thresholds[i + 1])
        
        segments.append(data[mask])
    return segments

def get_consistency_data(model, dataset, method, root='exp-data-scale_full'):
    datas = get_sample_distance_matrix(model=model, dataset=dataset, method=method, root=root)
    all_sample_chains = []
    all_sample_answers = []
    all_sample_gts = []
    for _, plot_data in datas.items():
        distance_matrix, coordinates_2d, num_thoughts_each_chain, num_chains, all_answers, answer_gt_short = plot_data.values()
        # Collect points for each chain
        sample_chains = []
        for chain_idx in range(num_chains):
            start_idx = sum(num_thoughts_each_chain[:chain_idx])
            end_idx = sum(num_thoughts_each_chain[:chain_idx+1])
            if end_idx <= start_idx:
                continue
            if dataset == "strategyqa":
                if distance_matrix[start_idx:end_idx, :].shape[1] == 3:
                    chain_distance = distance_matrix[start_idx:end_idx, 1:] # remove the (s, T)
                else:
                    chain_distance = distance_matrix[start_idx:end_idx, :]
            elif dataset == "mmlu":
                if distance_matrix[start_idx:end_idx, :].shape[1] == 5:
                    chain_distance = distance_matrix[start_idx:end_idx, 1:] # remove the (s, T)
                else:
                    chain_distance = distance_matrix[start_idx:end_idx, :]
            else:
                if distance_matrix[start_idx:end_idx, :].shape[1] == 6:
                    chain_distance = distance_matrix[start_idx:end_idx, 1:] # remove the (s, T)
                else:
                    chain_distance = distance_matrix[start_idx:end_idx, :]
            
            sample_chains.append(chain_distance)
        all_sample_answers.append(all_answers)
        all_sample_gts.append(answer_gt_short)
        all_sample_chains.append(sample_chains)
    
    return all_sample_chains, all_sample_answers, all_sample_gts

llm_consistency_records = {}
for model in [
    'Llama-3.2-1B-Instruct',
    'Llama-3.2-3B-Instruct',
    'Meta-Llama-3.1-8B-Instruct-Turbo',
    'Meta-Llama-3.1-70B-Instruct-Turbo'
]:
    llm_consistency_records[model] = {}
    for dataset in ['aqua', 'mmlu', 'commonsenseqa', 'strategyqa']: 
        llm_consistency_records[model][dataset] = {}
        methods = ['cot', 'l2m', 'tot', 'mcts']
        for method in methods:
            
            consistency_pkl_path = f"correlation_plot_data/{model}-{dataset}-{method}.pkl"
            if os.path.exists(consistency_pkl_path):
                all_sample_chains, all_sample_answers, all_sample_gts = load(consistency_pkl_path)
            else:
                print(f"==> Processing {model}-{dataset}-{method}")
                if method in ['tot', 'mcts']:
                    all_sample_chains, all_sample_answers, all_sample_gts = get_consistency_data(model, dataset, method, root='exp-data-searching')
                else:
                    all_sample_chains, all_sample_answers, all_sample_gts = get_consistency_data(model, dataset, method, root='exp-data-scale_full')
                dump((all_sample_chains, all_sample_answers, all_sample_gts), consistency_pkl_path)
            
            llm_consistency_records[model][dataset][method] = {}
            last_thought_chunk_recorder = {}

            ANSWER_LIST = np.array(["A", "B", "C", "D", "E"])
            sample_mean_accs = []
            sample_mean_confs = []
            for sample_idx, sample_chains in enumerate(all_sample_chains): # 50 samples
                last_thought_chunk_recorder[sample_idx] = {}

                chain_answers = all_sample_answers[sample_idx]
                sample_gt = all_sample_gts[sample_idx]
                # print(chain_answers) # 10
                # print(sample_gt) # 1
                confidence_list = []
                acc_list = []
                for chain_idx, chain in enumerate(sample_chains): # 10 CoT per sample
                    last_thought_chunk_recorder[sample_idx][chain_idx] = {}
                    model_pred = chain_answers[chain_idx]

                    # * consistency of all thought with last thought
                    if len(chain) <= 1:
                        continue
                    else:
                        chain_without_last_thought, last_thought = chain[:-1], chain[-1]
                        last_thought_answer = ANSWER_LIST[np.argmin(last_thought)]
                        # average
                        list_thought_consistency = []
                        list_thought_conf = []
                        for thought_idx, thought in enumerate(chain_without_last_thought):
                            # thouhgt: (5, ) distance matrix
                            thought_answer_idx = np.argmin(thought)
                            consistency = ANSWER_LIST[thought_answer_idx] == last_thought_answer
                            list_thought_consistency.append(consistency)
                            list_thought_conf.append(thought[thought_answer_idx])

                        last_thought_chunk_recorder[sample_idx][chain_idx] = {
                            'acc': np.mean(list_thought_consistency),
                            'conf': np.mean(list_thought_conf),
                            'correctness': model_pred == sample_gt
                        }

            llm_consistency_records[model][dataset][method] = last_thought_chunk_recorder

In [None]:
from joblib import load, dump

from utils.train_utils import *
import numpy as np 

dataset_configs = "all"
(
    randomforest_models, randomforest_X_scaler, 
    randomforest_y_scaler, llm_train_voting_acc, llm_train_avg_acc, randomforest_mean_std
) = load_models_and_metrics(
    f"ckpts/{dataset_configs}_start_0_end_20"
)

randomforest_records = {}
post_voting_accs = []
for model in [
    'Llama-3.2-1B-Instruct',
    'Llama-3.2-3B-Instruct',
    'Meta-Llama-3.1-8B-Instruct-Turbo',
    'Meta-Llama-3.1-70B-Instruct-Turbo'
]:
    randomforest_records[model] = {}
    for dataset in ['aqua', 'mmlu', 'commonsenseqa', 'strategyqa']: 
        randomforest_records[model][dataset] = {}
        methods = ['cot', 'l2m', 'tot', 'mcts']
        for method in methods:
            dataset_configs = {
                'model': model,
                'dataset': dataset,
                'method': method,
                'start_idx': 0,
                'end_idx': 50
            }

            if method in ['tot', 'mcts']:
                root = "exp-data-searching"
            else:
                root = "exp-data-scale_full"
            
            processed_data_path = f"./testing_data/randomForest/{dataset_configs['model']}-{dataset_configs['dataset']}-{dataset_configs['method']}_start-{dataset_configs['start_idx']}_end-{dataset_configs['end_idx']}.pkl"
            if os.path.exists(processed_data_path):
                print(f"==> Loading {model}--{method}--{dataset}...")
                (list_distance_matrix, list_num_chains, list_num_thoughts_each_chain,
                    list_coordinates_2d, list_answers, list_answer_gt_short, list_normed_A) = load(
                        f"./testing_data/randomForest/{dataset_configs['model']}-{dataset_configs['dataset']}-{dataset_configs['method']}_start-{dataset_configs['start_idx']}_end-{dataset_configs['end_idx']}.pkl"
                    )
            else:
                (
                    list_distance_matrix, list_num_chains, list_num_thoughts_each_chain,
                    list_coordinates_2d, list_answers, list_answer_gt_short, list_normed_A
                ) = load_sample_data(
                    model=dataset_configs['model'],
                    dataset=dataset_configs['dataset'],
                    method=dataset_configs['method'],
                    root=root,
                    start_sample_idx=dataset_configs['start_idx'],
                    end_sample_idx=dataset_configs['end_idx']
                )
                dump((list_distance_matrix, list_num_chains, list_num_thoughts_each_chain,
                    list_coordinates_2d, list_answers, list_answer_gt_short, list_normed_A), processed_data_path)
    
            # Calculate LLM accuracies
            voting_acc = vote_accuracy(list_answers, list_answer_gt_short)
            average_acc = row_wise_accuracy(list_answers, list_answer_gt_short)
            print(f"==> LLM Voting Acc: {voting_acc:.2f}")
            # print(f"==> LLM Average Acc: {average_acc:.2f}")

            # Initialize accumulators for this configuration
            config_voting_accs = []
            config_avg_accs = []
            config_predictions = []
            config_raw_predictions = []
            config_true_labels = []
            config_confidences = []

            # Process each sample
            for sample_idx in range(len(list_num_chains)):
                # Load chain data
                # we have K chain persample
                _, test_chain_matrix = load_chain_data(
                    sample_idx=sample_idx,
                    list_num_chains=list_num_chains,
                    list_num_thoughts_each_chain=list_num_thoughts_each_chain,
                    list_coordinates_2d=list_coordinates_2d,
                    list_distance_matrix=list_distance_matrix,
                    list_answers=list_answers,
                    list_answer_gt_short=list_answer_gt_short,
                    mode='reg',
                )

                # Evaluate using random forest
                results = eval_random_forest(
                    data=test_chain_matrix,
                    meta_info={
                        'dataset': dataset_configs['dataset'],
                        'model': dataset_configs['model'],
                        'method': dataset_configs['method'],
                    },
                    x_scaler=randomforest_X_scaler,
                    y_scaler=randomforest_y_scaler,
                    models=randomforest_models,
                    model_pred_mode='reg'
                )

                # Collect results
                config_voting_accs.append(results['voting_acc'])
                config_avg_accs.append(results['avg_acc'])
                config_true_labels.extend(results['y'])
                config_predictions.extend(results['predictions'])
                config_raw_predictions.append(results['raw_predicitons'])

                # using prediciton as weight/mask for voting with LLM prediction
                post_model_verifys = results['predictions']
                llm_predictions = list_answers[sample_idx]
                ground_truth = list_answer_gt_short[sample_idx]
                result = [x for x, m in zip(llm_predictions, post_model_verifys) if m]
                if len(result):
                    post_voting_accs.append(max(set(result), key=result.count) == ground_truth)
                else:
                    post_voting_accs.append(0)
            
            randomforest_records[model][dataset][method] = {
                'llm voting_acc': voting_acc,
                'llm avg_acc': average_acc,
                'post voting_acc': np.mean(post_voting_accs),
                'post verify_acc': np.mean(config_avg_accs),
                'predictions': config_predictions,
                'raw_predictions': config_raw_predictions,
                'llm_predictions': list_answers,
                'list_answer_gt_short': list_answer_gt_short
            }

            # Print configuration results
            print(f"\nConfiguration Results:")
            # print(f"Voting Accuracy: {np.mean(post_voting_accs):.4f}")
            print(f"RandomForest Binary Accuracy: {np.mean(config_avg_accs):.4f}")

## Process voting

$\text{vote}(A, W) = \underset{c \in C}{\arg\max} \text{ } \lambda \sum_{i=1}^n w_i \cdot \mathbb{1}[a_i = c] + \beta \cdot \text{consistency} + C$

where:

- $A = {a_1, ..., a_n}$ is the set of answers
- $W = {w_1, ..., w_n}$ is the set of Neural Verifier prediction confidence
- $C$ is the set of unique choices in $A$
- $\mathbb{1}[a_i = c]$ is the indicator function that equals 1 if $a_i = c$ and 0 otherwise

In [None]:
'''
randomforest_records
    ├── Llama-3.2-1B-Instruct
    ├── aqua
    │   ├── cot
    │   │   ├── llm voting_acc
    │   │   ├── llm avg_acc
    │   │   ├── post voting_acc
    │   │   ├── post verify_acc
    │   │   ├── predictions
    │   │   ├── raw_predictions <== (sample, chain, n_anchors, 2 (T/F))
    │   │   ├── llm_predictions 
    │   │   └── list_answer_gt_short
    │   ├── l2m

llm_consistency_records
├── Llama-3.2-1B-Instruct
│   ├── aqua
│   │   ├── cot
│   │   │   ├── 0 sample
│   │   │   │   ├── 0 chain
│   │   │   │   │   ├── acc
│   │   │   │   │   ├── conf
│   │   │   │   │   └── correctness
│   │   │   │   ├── 1
│   │   │   │   │   ├── acc
│   │   │   │   │   ├── conf
│   │   │   │   │   └── correctness
'''


for sample_idx in range(50):
    list_chain_consistency = []
    for chain_idx in range(10):
        list_chain_consistency.append(llm_consistency_records['Llama-3.2-1B-Instruct']['aqua']['cot'][sample_idx].get(chain_idx, {}).get('conf', 0))


In [None]:
from einops import rearrange, reduce

def weighted_vote(answers, weights):
    # processed_answers = ['A' if x == '' else x for x in answers]
    weighted_votes = {}
    for ans, w in zip(answers, weights):
        weighted_votes[ans] = weighted_votes.get(ans, 0) + w
    return max(weighted_votes.items(), key=lambda x: x[1])[0]

def voting_with_consistency(dataset_name, method, randomforest_records, llm_consistency_records, LAMBDA, BETA, C):
    methods = [method]
    for model in sorted(randomforest_records.keys()):
        for method_idx, method in enumerate(methods):
            if method in randomforest_records[model][dataset_name]:
                print(f'==> {dataset_name}-{model}-{method}')
                method_preds = randomforest_records[model][dataset_name][method]['raw_predictions']

                mean_preds = []
                for method_pred in method_preds:
                    mean_pred = reduce(method_pred, 'n_samples n_predictors n_classes -> n_samples n_classes', 'mean')
                    reweighting_pred = mean_pred[:, 1] / (mean_pred[:, 0] + mean_pred[:, 1]) # NOTE: True / (Flase + True)
                    mean_preds.append(reweighting_pred)

                accs = []
                
                for sample_idx, verifier_weights in enumerate(mean_preds):
                    llm_answers = randomforest_records[model][dataset_name][method]['llm_predictions'][sample_idx]
                    # NOTE: here is the voting mechanism

                    chain_consistency = []
                    for chain_idx in range(len(verifier_weights)):
                        chain_consistency.append(llm_consistency_records[model][dataset_name][method][sample_idx].get(chain_idx, {}).get('conf', 0))
                    # print(len(verifier_weights), len(chain_consistency))
                    chain_weight = LAMBDA * np.array(verifier_weights) + BETA * np.array(chain_consistency) + C
                    # print(chain_weight)
                    gt_answer = randomforest_records[model][dataset_name][method]['list_answer_gt_short'][sample_idx]
                    accs.append(weighted_vote(llm_answers, chain_weight) == gt_answer)

                print(np.mean(accs))


dataset_name = 'mmlu'
method = 'tot'
LAMBDA = 0.5
BETA = 0.5
C = 1

voting_with_consistency(dataset_name, method, randomforest_records, llm_consistency_records, LAMBDA, BETA, C)

## Inference scaling

In [None]:
from joblib import load

from utils.train_utils import *
import numpy as np 

def vote_accuracy_K_chain(list_answers, list_answer_gt_short, K):
    truncated_answers = [row[:K] for row in list_answers]
    voted_answers = [Counter(row).most_common(1)[0][0] if row else '' 
                    for row in truncated_answers]
    return sum(v == g for v, g in zip(voted_answers, list_answer_gt_short)) / len(list_answer_gt_short)

def row_wise_accuracy_K_chain(list_answers, list_answer_gt_short, K):
    truncated_answers = [row[:K] for row in list_answers]
    row_accs = [sum(ans == gt for ans in row if ans != '') / len([a for a in row if a != ''])
                if any(a != '' for a in row) else 0.0
                for row, gt in zip(truncated_answers, list_answer_gt_short)]
    return sum(row_accs) / len(row_accs)

dataset_configs = "all"
(
    randomforest_models, randomforest_X_scaler, 
    randomforest_y_scaler, llm_train_voting_acc, llm_train_avg_acc, randomforest_mean_std
) = load_models_and_metrics(
    f"ckpts/{dataset_configs}_start_0_end_20"
)

# ! inference scaling factor
NUM_THOUGHTS = 35

records = {}
post_voting_accs = []
for model in [
    'Llama-3.2-3B-Instruct-Turbo',
    'Meta-Llama-3.1-8B-Instruct-Turbo',
]:
    records[model] = {}
    for dataset in ['strategyqa']: 
        records[model][dataset] = {}
        methods = ['cot']
        for method in methods:
            dataset_configs = {
                'model': model,
                'dataset': dataset,
                'method': method,
                'start_idx': 0,
                'end_idx': 50
            }

            root = "./exp-data-inference-scaling"
            print(f"==> Loading {model}--{method}--{dataset}...")
            # (
            #     list_distance_matrix, list_num_chains, list_num_thoughts_each_chain,
            #     list_coordinates_2d, list_answers, list_answer_gt_short, list_normed_A
            # ) = load_sample_data(
            #     model=dataset_configs['model'],
            #     dataset=dataset_configs['dataset'],
            #     method=dataset_configs['method'],
            #     root=root,
            #     start_sample_idx=dataset_configs['start_idx'],
            #     end_sample_idx=dataset_configs['end_idx']
            # )
            (
                list_distance_matrix, list_num_chains, list_num_thoughts_each_chain,
                list_coordinates_2d, list_answers, list_answer_gt_short, list_normed_A
            ) = load(f'exp-data-inference-scaling/strategyqa/preprocess_data/{model}_{method}_{dataset}.pkl')
            # dump([list_distance_matrix, list_num_chains, list_num_thoughts_each_chain,
            #     list_coordinates_2d, list_answers, list_answer_gt_short, list_normed_A],
            #     f'exp-data-inference-scaling/strategyqa/preprocess_data/{model}_{method}_{dataset}.pkl')

            print(f'==> Inference scaling for {NUM_THOUGHTS} thoughts')
            list_num_chains = [NUM_THOUGHTS] * len(list_num_chains)
            # Calculate LLM accuracies
            voting_acc = vote_accuracy_K_chain(list_answers, list_answer_gt_short, NUM_THOUGHTS)
            average_acc = row_wise_accuracy_K_chain(list_answers, list_answer_gt_short, NUM_THOUGHTS)
            print(f"==> LLM Voting Acc: {voting_acc:.2f}")
            print(f"==> LLM Average Acc: {average_acc:.2f}")

            # Initialize accumulators for this configuration
            config_voting_accs = []
            config_avg_accs = []
            config_predictions = []
            config_raw_predictions = []
            config_true_labels = []
            config_confidences = []

            # Process each sample
            for sample_idx in range(len(list_num_chains)):
                # Load chain data
                # we have K chain persample
                _, test_chain_matrix = load_chain_data(
                    sample_idx=sample_idx,
                    list_num_chains=list_num_chains,
                    list_num_thoughts_each_chain=list_num_thoughts_each_chain,
                    list_coordinates_2d=list_coordinates_2d,
                    list_distance_matrix=list_distance_matrix,
                    list_answers=list_answers,
                    list_answer_gt_short=list_answer_gt_short,
                    mode='reg',
                )

                # Evaluate using random forest
                results = eval_random_forest(
                    data=test_chain_matrix,
                    meta_info={
                        'dataset': dataset_configs['dataset'],
                        'model': dataset_configs['model'],
                        'method': dataset_configs['method'],
                    },
                    x_scaler=randomforest_X_scaler,
                    y_scaler=randomforest_y_scaler,
                    models=randomforest_models,
                    model_pred_mode='reg'
                )

                # Collect results
                config_voting_accs.append(results['voting_acc'])
                config_avg_accs.append(results['avg_acc'])
                config_true_labels.extend(results['y'])
                config_predictions.extend(results['predictions'])
                config_raw_predictions.append(results['raw_predicitons'])

                # using prediciton as weight/mask for voting with LLM prediction
                post_model_verifys = results['predictions']
                llm_predictions = list_answers[sample_idx]
                ground_truth = list_answer_gt_short[sample_idx]
                result = [x for x, m in zip(llm_predictions, post_model_verifys) if m]
                if len(result):
                    post_voting_accs.append(max(set(result), key=result.count) == ground_truth)
                else:
                    post_voting_accs.append(0)
            
            records[model][dataset][method] = {
                'llm voting_acc': voting_acc,
                'llm avg_acc': average_acc,
                'post voting_acc': np.mean(post_voting_accs),
                'post verify_acc': np.mean(config_avg_accs),
                'predictions': config_predictions,
                'raw_predictions': config_raw_predictions,
                'llm_predictions': list_answers,
                'list_answer_gt_short': list_answer_gt_short
            }

            # Print configuration results
            print(f"\nConfiguration Results:")
            print(f"Voting Accuracy: {np.mean(post_voting_accs):.4f}")
            print(f"Average Accuracy: {np.mean(config_avg_accs):.4f}")

- 1 thoughts
==> Loading Llama-3.2-3B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 1 thoughts
==> LLM Voting Acc: 0.32
==> LLM Average Acc: 0.32

Configuration Results:
Voting Accuracy: 0.1200
Average Accuracy: 0.7000
==> Loading Meta-Llama-3.1-8B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 1 thoughts
==> LLM Voting Acc: 0.24
==> LLM Average Acc: 0.24

Configuration Results:
Voting Accuracy: 0.1300
Average Accuracy: 0.8000

- 5 thoughts
==> Loading Llama-3.2-3B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 5 thoughts
==> LLM Voting Acc: 0.36
==> LLM Average Acc: 0.37

Configuration Results:
Voting Accuracy: 0.2200
Average Accuracy: 0.6360
==> Loading Meta-Llama-3.1-8B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 5 thoughts
==> LLM Voting Acc: 0.26
==> LLM Average Acc: 0.32

Configuration Results:
Voting Accuracy: 0.3600
Average Accuracy: 0.8040

- 10 thoughts
==> Loading Llama-3.2-3B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 10 thoughts
==> LLM Voting Acc: 0.30
==> LLM Average Acc: 0.37

Configuration Results:
Voting Accuracy: 0.3400
Average Accuracy: 0.6440
==> Loading Meta-Llama-3.1-8B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 10 thoughts
==> LLM Voting Acc: 0.30
==> LLM Average Acc: 0.33

Configuration Results:
Voting Accuracy: 0.4300
Average Accuracy: 0.7880

- 15 thoughts
==> Loading Llama-3.2-3B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 15 thoughts
==> LLM Voting Acc: 0.40
==> LLM Average Acc: 0.39

Configuration Results:
Voting Accuracy: 0.4400
Average Accuracy: 0.6413
==> Loading Meta-Llama-3.1-8B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 15 thoughts
==> LLM Voting Acc: 0.22
==> LLM Average Acc: 0.34

Configuration Results:
Voting Accuracy: 0.4900
Average Accuracy: 0.7907

- 20 thoughts
==> Loading Llama-3.2-3B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 20 thoughts
==> LLM Voting Acc: 0.30
==> LLM Average Acc: 0.38

Configuration Results:
Voting Accuracy: 0.4600
Average Accuracy: 0.6410
==> Loading Meta-Llama-3.1-8B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 20 thoughts
==> LLM Voting Acc: 0.22
==> LLM Average Acc: 0.33

Configuration Results:
Voting Accuracy: 0.5200
Average Accuracy: 0.7970

- 25 thoughts
==> Loading Llama-3.2-3B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 25 thoughts
==> LLM Voting Acc: 0.28
==> LLM Average Acc: 0.37

Configuration Results:
Voting Accuracy: 0.5200
Average Accuracy: 0.6528
==> Loading Meta-Llama-3.1-8B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 25 thoughts
==> LLM Voting Acc: 0.26
==> LLM Average Acc: 0.33

Configuration Results:
Voting Accuracy: 0.5500
Average Accuracy: 0.8032

- 30 thoughts
==> Loading Llama-3.2-3B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 30 thoughts
==> LLM Voting Acc: 0.26
==> LLM Average Acc: 0.37

Configuration Results:
Voting Accuracy: 0.5600
Average Accuracy: 0.6613
==> Loading Meta-Llama-3.1-8B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 30 thoughts
==> LLM Voting Acc: 0.24
==> LLM Average Acc: 0.33

Configuration Results:
Voting Accuracy: 0.5800
Average Accuracy: 0.8047

- 35 thoughts
==> Loading Llama-3.2-3B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 35 thoughts
==> LLM Voting Acc: 0.24
==> LLM Average Acc: 0.37

Configuration Results:
Voting Accuracy: 0.6000
Average Accuracy: 0.6651
==> Loading Meta-Llama-3.1-8B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 35 thoughts
==> LLM Voting Acc: 0.28
==> LLM Average Acc: 0.33

Configuration Results:
Voting Accuracy: 0.6200
Average Accuracy: 0.8017

- 40 thoughts
==> Loading Llama-3.2-3B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 40 thoughts
==> LLM Voting Acc: 0.24
==> LLM Average Acc: 0.37

Configuration Results:
Voting Accuracy: 0.6800
Average Accuracy: 0.6685
==> Loading Meta-Llama-3.1-8B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 40 thoughts
==> LLM Voting Acc: 0.26
==> LLM Average Acc: 0.33

Configuration Results:
Voting Accuracy: 0.6600
Average Accuracy: 0.8010

- 45 thoughts
==> Loading Llama-3.2-3B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 45 thoughts
==> LLM Voting Acc: 0.24
==> LLM Average Acc: 0.37

Configuration Results:
Voting Accuracy: 0.6600
Average Accuracy: 0.6698
==> Loading Meta-Llama-3.1-8B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 45 thoughts
==> LLM Voting Acc: 0.26
==> LLM Average Acc: 0.33

Configuration Results:
Voting Accuracy: 0.6600
Average Accuracy: 0.8022

- 50 thoughts
==> Loading Llama-3.2-3B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 45 thoughts
==> LLM Voting Acc: 0.24
==> LLM Average Acc: 0.37

Configuration Results:
Voting Accuracy: 0.6600
Average Accuracy: 0.6698
==> Loading Meta-Llama-3.1-8B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 45 thoughts
==> LLM Voting Acc: 0.26
==> LLM Average Acc: 0.33

Configuration Results:
Voting Accuracy: 0.6600
Average Accuracy: 0.8022

In [19]:
# 输入数据
data = """
- 1 thoughts
==> Loading Llama-3.2-3B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 1 thoughts
==> LLM Voting Acc: 0.32
==> LLM Average Acc: 0.32

Configuration Results:
Voting Accuracy: 0.1200
Average Accuracy: 0.7000
==> Loading Meta-Llama-3.1-8B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 1 thoughts
==> LLM Voting Acc: 0.24
==> LLM Average Acc: 0.24

Configuration Results:
Voting Accuracy: 0.1300
Average Accuracy: 0.8000

- 5 thoughts
==> Loading Llama-3.2-3B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 5 thoughts
==> LLM Voting Acc: 0.36
==> LLM Average Acc: 0.37

Configuration Results:
Voting Accuracy: 0.2200
Average Accuracy: 0.6360
==> Loading Meta-Llama-3.1-8B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 5 thoughts
==> LLM Voting Acc: 0.26
==> LLM Average Acc: 0.32

Configuration Results:
Voting Accuracy: 0.3600
Average Accuracy: 0.8040

- 10 thoughts
==> Loading Llama-3.2-3B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 10 thoughts
==> LLM Voting Acc: 0.30
==> LLM Average Acc: 0.37

Configuration Results:
Voting Accuracy: 0.3400
Average Accuracy: 0.6440
==> Loading Meta-Llama-3.1-8B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 10 thoughts
==> LLM Voting Acc: 0.30
==> LLM Average Acc: 0.33

Configuration Results:
Voting Accuracy: 0.4300
Average Accuracy: 0.7880

- 15 thoughts
==> Loading Llama-3.2-3B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 15 thoughts
==> LLM Voting Acc: 0.40
==> LLM Average Acc: 0.39

Configuration Results:
Voting Accuracy: 0.4400
Average Accuracy: 0.6413
==> Loading Meta-Llama-3.1-8B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 15 thoughts
==> LLM Voting Acc: 0.22
==> LLM Average Acc: 0.34

Configuration Results:
Voting Accuracy: 0.4900
Average Accuracy: 0.7907

- 20 thoughts
==> Loading Llama-3.2-3B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 20 thoughts
==> LLM Voting Acc: 0.30
==> LLM Average Acc: 0.38

Configuration Results:
Voting Accuracy: 0.4600
Average Accuracy: 0.6410
==> Loading Meta-Llama-3.1-8B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 20 thoughts
==> LLM Voting Acc: 0.22
==> LLM Average Acc: 0.33

Configuration Results:
Voting Accuracy: 0.5200
Average Accuracy: 0.7970

- 25 thoughts
==> Loading Llama-3.2-3B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 25 thoughts
==> LLM Voting Acc: 0.28
==> LLM Average Acc: 0.37

Configuration Results:
Voting Accuracy: 0.5200
Average Accuracy: 0.6528
==> Loading Meta-Llama-3.1-8B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 25 thoughts
==> LLM Voting Acc: 0.26
==> LLM Average Acc: 0.33

Configuration Results:
Voting Accuracy: 0.5500
Average Accuracy: 0.8032

- 30 thoughts
==> Loading Llama-3.2-3B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 30 thoughts
==> LLM Voting Acc: 0.26
==> LLM Average Acc: 0.37

Configuration Results:
Voting Accuracy: 0.5600
Average Accuracy: 0.6613
==> Loading Meta-Llama-3.1-8B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 30 thoughts
==> LLM Voting Acc: 0.24
==> LLM Average Acc: 0.33

Configuration Results:
Voting Accuracy: 0.5800
Average Accuracy: 0.8047

- 35 thoughts
==> Loading Llama-3.2-3B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 35 thoughts
==> LLM Voting Acc: 0.24
==> LLM Average Acc: 0.37

Configuration Results:
Voting Accuracy: 0.6000
Average Accuracy: 0.6651
==> Loading Meta-Llama-3.1-8B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 35 thoughts
==> LLM Voting Acc: 0.28
==> LLM Average Acc: 0.33

Configuration Results:
Voting Accuracy: 0.6200
Average Accuracy: 0.8017

- 40 thoughts
==> Loading Llama-3.2-3B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 40 thoughts
==> LLM Voting Acc: 0.24
==> LLM Average Acc: 0.37

Configuration Results:
Voting Accuracy: 0.6800
Average Accuracy: 0.6685
==> Loading Meta-Llama-3.1-8B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 40 thoughts
==> LLM Voting Acc: 0.26
==> LLM Average Acc: 0.33

Configuration Results:
Voting Accuracy: 0.6600
Average Accuracy: 0.8010

- 45 thoughts
==> Loading Llama-3.2-3B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 45 thoughts
==> LLM Voting Acc: 0.24
==> LLM Average Acc: 0.37

Configuration Results:
Voting Accuracy: 0.6600
Average Accuracy: 0.6698
==> Loading Meta-Llama-3.1-8B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 45 thoughts
==> LLM Voting Acc: 0.26
==> LLM Average Acc: 0.33

Configuration Results:
Voting Accuracy: 0.6600
Average Accuracy: 0.8022

- 50 thoughts
==> Loading Llama-3.2-3B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 45 thoughts
==> LLM Voting Acc: 0.24
==> LLM Average Acc: 0.37

Configuration Results:
Voting Accuracy: 0.6600
Average Accuracy: 0.6698
==> Loading Meta-Llama-3.1-8B-Instruct-Turbo--cot--strategyqa...
==> Inference scaling for 45 thoughts
==> LLM Voting Acc: 0.26
==> LLM Average Acc: 0.33

Configuration Results:
Voting Accuracy: 0.6600
Average Accuracy: 0.8022
"""

# 解析数据
import re

# 正则表达式匹配关键信息
pattern = re.compile(
    r"==> Loading (.*?)\n"
    r".*?"
    r"==> LLM Voting Acc: (.*?)\n"
    r"==> LLM Average Acc: (.*?)\n"
    r".*?"
    r"Voting Accuracy: (.*?)\n"
    r"Average Accuracy: (.*?)\n",
    re.DOTALL
)

# 提取数据
results = []
for match in pattern.finditer(data):
    config = match.group(1)  # 配置名称
    llm_voting_acc = float(match.group(2))  # LLM Voting Accuracy
    llm_avg_acc = float(match.group(3))  # LLM Average Accuracy
    voting_acc = float(match.group(4))  # Voting Accuracy
    avg_acc = float(match.group(5))  # Average Accuracy

    results.append({
        "Config": config,
        "LLM Voting Acc": llm_voting_acc,
        "LLM Average Acc": llm_avg_acc,
        "Voting Accuracy": max(voting_acc, llm_voting_acc),
        "Average Accuracy": avg_acc
    })

# 打印整理后的数据
# for result in results:
#     print(f"Config: {result['Config']}")
#     print(f"  LLM Voting Acc: {result['LLM Voting Acc']}")
#     # print(f"  LLM Average Acc: {result['LLM Average Acc']}")
#     print(f"  Voting Accuracy: {result['Voting Accuracy']}")
#     # print(f"  Average Accuracy: {result['Average Accuracy']}")

In [23]:
import plotly.graph_objects as go

# 提取数据
number_of_thoughts = [1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50]  # X 轴数据
# 提取 Llama-3.2-3B-Instruct-Turbo 的数据
llama_voting_acc = [result['Voting Accuracy'] for result in results if 'Llama-3.2' in result['Config']]
llama_llm_voting_acc = [result['LLM Voting Acc'] for result in results if 'Llama-3.2' in result['Config']]

# 提取 Meta-Llama-3.1-8B-Instruct-Turbo 的数据
meta_llama_voting_acc = [result['Voting Accuracy'] for result in results if 'Meta-Llama' in result['Config']]
meta_llama_llm_voting_acc = [result['LLM Voting Acc'] for result in results if 'Meta-Llama' in result['Config']]

# 创建 Plotly 图表
fig = go.Figure()

# 添加 Llama-3.2-3B-Instruct-Turbo 的 Voting Accuracy 曲线
fig.add_trace(go.Scatter(
    x=number_of_thoughts,
    y=llama_voting_acc,
    mode='lines+markers',
    name='3B (with verifier)  ',
    line=dict(color='#7DAEE0', width=2),
    marker=dict(size=10, symbol='circle')
))

# 添加 Llama-3.2-3B-Instruct-Turbo 的 LLM Voting Acc 曲线
fig.add_trace(go.Scatter(
    x=number_of_thoughts,
    y=llama_llm_voting_acc,
    mode='lines+markers',
    name='3B',
    line=dict(color='#7DAEE0', width=2, dash='dash'),
    marker=dict(size=10, symbol='circle')
))

# 添加 Meta-Llama-3.1-8B-Instruct-Turbo 的 Voting Accuracy 曲线
fig.add_trace(go.Scatter(
    x=number_of_thoughts,
    y=meta_llama_voting_acc,
    mode='lines+markers',
    name='8B (with verifier)  ',
    line=dict(color='#EA8379', width=2),
    marker=dict(size=10, symbol='square')
))

# 添加 Meta-Llama-3.1-8B-Instruct-Turbo 的 LLM Voting Acc 曲线
fig.add_trace(go.Scatter(
    x=number_of_thoughts,
    y=meta_llama_llm_voting_acc,
    mode='lines+markers',
    name='8B',
    line=dict(color='#EA8379', width=2, dash='dash'),
    marker=dict(size=10, symbol='square')
))

# 更新布局
fig.update_layout(
    xaxis_title='Number of Reasoning Times',
    yaxis_title='Accuracy',
    xaxis=dict(tickfont=dict(size=35), title_font=dict(size=40), tickvals=number_of_thoughts, ticktext=[str(x) for x in number_of_thoughts]),
    yaxis=dict(tickfont=dict(size=35), title_font=dict(size=40), range=[0, 0.75], tickformat=".0%", dtick=0.15),  # 设置 y 轴范围
    # legend=dict(x=0.02, y=0.98),  # 图例位置
    legend=dict(
        font=dict(size=35),
        orientation="h",  # Make legend horizontal
        yanchor="bottom",
        y=1.03,  # Position above the plot
        xanchor="center",  # Center horizontally
        x=0.5,
        itemsizing='constant',
        itemwidth=80  # Adjust spacing between legend items
    ),
    margin=dict(l=10, r=20, t=5, b=5),  # Remove margins
    width=1000,  # 图表宽度
    height=500  # 图表高度
)
# 模板配置 - 强制所有颜色为黑色
template = dict(
    layout=dict(
        font_color="black",
        paper_bgcolor="white",
        plot_bgcolor="white",
        title_font_color="black",
        legend_font_color="black",
        
        xaxis=dict(
            title_font_color="black",
            tickfont_color="black",
            linecolor="black",
            gridcolor="lightgray",
            zerolinecolor="black",
        ),
        
        yaxis=dict(
            title_font_color="black", 
            tickfont_color="black",
            linecolor="black",
            gridcolor="lightgray",
            zerolinecolor="black",
        ),
        
        hoverlabel=dict(
            font_color="black",
            bgcolor="white"
        ),
        
        annotations=[dict(font_color="black")],
        shapes=[dict(line_color="black")],
        
        coloraxis=dict(
            colorbar_tickfont_color="black",
            colorbar_title_font_color="black"
        ),
    )
)

# 应用配置
fig.update_layout(template=template)



# 显示图表
fig.show()
import plotly.io as pio
pio.write_image(fig, f'figures/abls/inference_scaling.pdf', scale=6)

In [None]:
import plotly.graph_objects as go

# 数据
number_of_thoughts = [1, 5, 10, 15, 20, 25]

# Llama-3.2-3B-Instruct-Turbo 的 Voting Accuracy
llama_voting_acc = [0.12, 0.22, 0.34, 0.44, 0.46, 0.52]

# Meta-Llama-3.1-8B-Instruct-Turbo 的 Voting Accuracy
meta_llama_voting_acc = [0.13, 0.36, 0.43, 0.49, 0.52, 0.55]

# 创建 Plotly 图表
fig = go.Figure()

# 添加 Llama-3.2-3B-Instruct-Turbo 的曲线
fig.add_trace(go.Scatter(
    x=number_of_thoughts,
    y=llama_voting_acc,
    mode='lines+markers',
    name='Llama-3.2-3B-Instruct-Turbo',
    line=dict(color='blue', width=2),
    marker=dict(size=10, symbol='circle')
))

# 添加 Meta-Llama-3.1-8B-Instruct-Turbo 的曲线
fig.add_trace(go.Scatter(
    x=number_of_thoughts,
    y=meta_llama_voting_acc,
    mode='lines+markers',
    name='Meta-Llama-3.1-8B-Instruct-Turbo',
    line=dict(color='red', width=2),
    marker=dict(size=10, symbol='square')
))

# 更新布局
fig.update_layout(
    title='Voting Accuracy vs Number of Thoughts',
    xaxis_title='Number of Thoughts',
    yaxis_title='Voting Accuracy',
    xaxis=dict(tickvals=number_of_thoughts, ticktext=[str(x) for x in number_of_thoughts]),
    yaxis=dict(range=[0, 0.6]),  # 设置 y 轴范围
    legend=dict(x=0.02, y=0.98),  # 图例位置
    template='plotly_white',  # 使用白色背景模板
    width=800,  # 图表宽度
    height=500  # 图表高度
)

# 显示图表
fig.show()

## Analysis

In [None]:
import plotly.subplots as sp
import plotly.graph_objects as go
import plotly.express as px
import numpy as np
from collections import Counter
def plot_prediction_distribution(records, dataset_name, subplot_titles=None, height=800, width=1200):
    """
    Plot prediction distribution for a specific dataset across different models and methods.
    
    Parameters:
    -----------
    records : dict
        Nested dictionary with structure:
        {model: {dataset: {method: {'predictions': [...]}}}}
    dataset_name : str
        Name of the dataset to plot
    subplot_titles : list, optional
        List of titles for subplots. If None, uses model names
    height : int, optional
        Height of the plot in pixels
    width : int, optional
        Width of the plot in pixels
        
    Returns:
    --------
    plotly.graph_objects.Figure
    """
    
    # Create 2x2 subplots
    if subplot_titles is None:
        subplot_titles = sorted(records.keys())
        
    fig = sp.make_subplots(
        rows=2, cols=2,
        subplot_titles=subplot_titles
    )

    positions = [(1,1), (1,2), (2,1), (2,2)]
    
    # List of all possible methods
    methods = ['cot', 'zero_shot_cot', 'l2m', 'tot', 'mcts']
    
    # Get plotly express qualitative color sequence
    colors = px.colors.qualitative.Set1

    for model, pos in zip(sorted(records.keys()), positions):
        if dataset_name in records[model]:
            for method_idx, method in enumerate(methods):
                if method in records[model][dataset_name]:
                    method_preds = records[model][dataset_name][method]['predictions']
                    
                    print(f'==> {model}-{dataset_name}-{method}')
                    print(Counter(method_preds))
                    if method_preds:
                        hist = np.histogram(method_preds, bins=[0, 0.5, 1])[0]
                        percentages = (hist / len(method_preds)) * 100
                        
                        fig.add_trace(
                            go.Bar(
                                x=['0', '1'],
                                y=percentages,
                                name=method,
                                text=[f'{p:.1f}%' for p in percentages],
                                textposition='auto',
                                marker_color=colors[method_idx],
                                showlegend=(pos == (1,1))  # Only show legend for first subplot
                            ),
                            row=pos[0], col=pos[1]
                        )

    # Update layout
    fig.update_layout(
        height=height,
        width=width,
        barmode='group',
        title_text=f"Distribution of Predictions by Model and Method for {dataset_name}",
        template='simple_white',
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=1.0
        )
    )

    # Update axes labels and ranges
    for i in range(1,3):
        for j in range(1,3):
            fig.update_xaxes(title_text="Prediction", row=i, col=j)
            fig.update_yaxes(title_text="Percentage (%)", row=i, col=j, range=[0, 100])

    return fig

# Example usage:
fig = plot_prediction_distribution(records, dataset_name='mmlu')
fig.show()

# Or with custom titles:
# custom_titles = ['Llama-3.2-1B', 'Llama-3.2-3B', 'Llama-3.1-8B', 'Llama-3.1-70B']
# fig = plot_prediction_distribution(records, dataset_name='gsm8k', subplot_titles=custom_titles)
# fig.show()

In [None]:
def create_markdown_table(records):
    # 创建表头
    markdown = "# Performance Results\n\n"
    
    # 创建主表格
    markdown += "| Model | Dataset | Method | LLM Voting Acc | LLM Avg Acc | Post Voting Acc | Post Verify Acc |\n"
    markdown += "|-------|---------|---------|---------------|-------------|----------------|----------------|\n"
    
    # 添加数据行
    for model in sorted(records.keys()):
        for dataset in sorted(records[model].keys()):
            for method in sorted(records[model][dataset].keys()):
                results = records[model][dataset][method]
                markdown += f"| {model} | {dataset} | {method} | "
                markdown += f"{results['llm voting_acc']:.3f} | "
                markdown += f"{results['llm avg_acc']:.3f} | "
                markdown += f"{results['post voting_acc']:.3f} | "
                markdown += f"{results['post verify_acc']:.3f} |\n"
    
    # 添加平均性能
    markdown += "\n## Average Performance\n\n"
    avg_performance = calculate_average_performance(records)
    markdown += "| Metric | Average Value |\n"
    markdown += "|--------|---------------|\n"
    for metric, value in avg_performance.items():
        markdown += f"| {metric} | {value:.3f} |\n"
    
    return markdown

def calculate_average_performance(records):
    metrics = {
        'llm voting_acc': [],
        'llm avg_acc': [],
        'post voting_acc': [],
        'post verify_acc': []
    }
    
    for model in records:
        for dataset in records[model]:
            for method in records[model][dataset]:
                results = records[model][dataset][method]
                for metric in metrics:
                    metrics[metric].append(results[metric])
    
    return {k: np.mean(v) for k, v in metrics.items()}

# 使用方法：
markdown_table = create_markdown_table(records)
print(markdown_table)

In [None]:
import pandas as pd

# Read the markdown table
# Skip the first 2 rows (header and separator)
df = pd.read_table('randomForest_results.md', sep='|', skiprows=1)

# Clean up the DataFrame
# Remove empty columns (created by leading/trailing |)
df = df.iloc[:, 1:-1]
# Set column names
df.columns = ['Model', 'Dataset', 'Method', 'LLM Voting Acc', 'LLM Avg Acc', 'Post Voting Acc', 'Post Verify Acc']
# Strip whitespace from all string columns
df = df.apply(lambda x: x.str.strip() if x.dtype == "object" else x)

dataset_order = ['aqua', 'mmlu', 'strategyqa', 'commonsenseqa']

# Convert accuracy columns to float
accuracy_cols = ['LLM Voting Acc', 'LLM Avg Acc', 'Post Voting Acc', 'Post Verify Acc']
for col in accuracy_cols:
    df[col] = pd.to_numeric(df[col])

# Function to filter data based on dataset and method
def filter_results(data, datasets=None, methods=None):
    """
    Filter results based on dataset and/or method
    
    Args:
        data: pandas DataFrame containing the results
        datasets: str or list of str, datasets to filter
        methods: str or list of str, methods to filter
    
    Returns:
        Filtered DataFrame
    """
    filtered_df = data.copy()
    
    if datasets:
        if isinstance(datasets, str):
            datasets = [datasets]
        filtered_df = filtered_df[filtered_df['Dataset'].isin(datasets)]
    
    if methods:
        if isinstance(methods, str):
            methods = [methods]
        filtered_df = filtered_df[filtered_df['Method'].isin(methods)]
    
    return filtered_df

# Example usage:
# Filter for specific dataset
# aqua_results = filter_results(df, datasets='aqua')
# print("\nResults for AQUA dataset:")
# print(aqua_results)

# Get the accuracy columns
# acc_columns = ['LLM Voting Acc', 'LLM Avg Acc', 'Post Voting Acc', 'Post Verify Acc']

# # Calculate means
# means = df[acc_columns].mean()

# print("Average scores:")
# for col, mean in means.items():
#     print(f"{col}: {mean:.3f}")


# Filter for specific method
cot_results = filter_results(df, methods='mcts', datasets='commonsenseqa')
cot_results

# Filter for multiple datasets and methods
# specific_results = filter_results(df, 
#                                 datasets=['aqua', 'mmlu'],
#                                 methods=['cot', 'zero_shot_cot'])
# print("\nResults for AQUA and MMLU datasets with CoT and Zero-shot CoT methods:")
# print(specific_results)

## Main Table

In [None]:
# Data structure with all values
data = {
    "CoT": [
        [[12.0, 20.0, 50.0, 20.0], [16.0, 20.0, 52.0, 26.0]],
        [[60.0, 40.0, 30.0, 56.0], [68.0, 46.0, 34.0, 56.0]],
        [[74.0, 66.0, 30.0, 68.0], [76.0, 66.0, 43.0, 68.0]],
        [[94.0, 86.0, 40.0, 68.0], [96.0, 86.0, 42.0, 68.0]]
    ],
    "LeastToMost": [
        [[18.0, 22.0, 32.0, 16.0], [22.0, 22.0, 32.0, 22.0]],
        [[34.0, 14.0, 36.0, 32.0], [38.0, 16.0, 36.0, 34.0]],
        [[76.0, 66.0, 24.0, 66.0], [78.0, 66.0, 38.0, 66.0]],
        [[96.0, 90.0, 32.0, 70.0], [96.0, 90.0, 32.0, 70.0]]
    ],
    "ToT": [
        [[34.0, 30.0, 50.0, 26.0], [34.0, 32.0, 58.0, 26.0]],
        [[38.0, 32.0, 14.0, 47.0], [38.0, 36.6, 14.0, 51.0]],
        [[64.0, 62.0, 12.0, 64.0], [66.0, 68.0, 14.0, 64.0]],
        [[98.0, 88.0, 42.0, 66.0], [98.0, 90.0, 44.0, 66.0]]
    ],
    "MCTS": [
        [[18.0, 22.0, 46.0, 30.0], [24.0, 22.0, 50.0, 30.0]],
        [[44.0, 28.0, 14.0, 54.0], [44.0, 30.0, 18.0, 54.0]],
        [[72.0, 60.0, 14.0, 56.0], [72.0, 60.0, 16.0, 56.0]],
        [[96.0, 84.0, 40.0, 68.0], [96.0, 84.0, 44.0, 68.0]]
    ]
}

def calculate_improvements(section_data):
    improvements = []
    for pair in section_data:
        base, verifier = pair
        # Calculate differences for each metric
        diffs = [v - b for b, v in zip(base, verifier)]
        improvements.extend(diffs)
    return improvements

# Calculate improvements for each section
all_improvements = []
for section, section_data in data.items():
    improvements = calculate_improvements(section_data)
    section_avg = sum(improvements) / len(improvements)
    print(f"{section} average improvement: {section_avg:.2f}%")
    all_improvements.extend(improvements)

# Calculate overall average improvement
overall_avg = sum(all_improvements) / len(all_improvements)
print(f"\nOverall average improvement: {overall_avg:.2f}%")

In [None]:
data = {
    "CoT": [
        [[12.0, 20.0, 50.0, 20.0], [16.0, 20.0, 52.0, 26.0]],
        [[60.0, 40.0, 30.0, 56.0], [68.0, 46.0, 34.0, 56.0]],
        [[74.0, 66.0, 30.0, 68.0], [76.0, 66.0, 43.0, 68.0]],
        [[94.0, 86.0, 40.0, 68.0], [96.0, 86.0, 42.0, 68.0]]
    ],
    "LeastToMost": [
        [[18.0, 22.0, 32.0, 16.0], [22.0, 22.0, 32.0, 22.0]],
        [[34.0, 14.0, 36.0, 32.0], [38.0, 16.0, 36.0, 34.0]],
        [[76.0, 66.0, 24.0, 66.0], [78.0, 66.0, 38.0, 66.0]],
        [[96.0, 90.0, 32.0, 70.0], [96.0, 90.0, 32.0, 70.0]]
    ],
    "ToT": [
        [[34.0, 30.0, 50.0, 26.0], [34.0, 32.0, 58.0, 26.0]],
        [[38.0, 32.0, 14.0, 47.0], [38.0, 36.6, 14.0, 51.0]],
        [[64.0, 62.0, 12.0, 64.0], [66.0, 68.0, 14.0, 64.0]],
        [[98.0, 88.0, 42.0, 66.0], [98.0, 90.0, 44.0, 66.0]]
    ],
    "MCTS": [
        [[18.0, 22.0, 46.0, 30.0], [24.0, 22.0, 50.0, 30.0]],
        [[44.0, 28.0, 14.0, 54.0], [44.0, 30.0, 18.0, 54.0]],
        [[72.0, 60.0, 14.0, 56.0], [72.0, 60.0, 16.0, 56.0]],
        [[96.0, 84.0, 40.0, 68.0], [96.0, 84.0, 44.0, 68.0]]
    ]
}

import plotly.graph_objects as go
import plotly.express as px

def create_performance_plot(data, methods, models):
    fig = go.Figure()
    
    # Use a nice color palette
    colors = px.colors.sequential.Teal[2:]

    
    # Calculate bar positions
    total_bars = len(models) * 2
    bar_width = 1 / (total_bars + 1)
    
    # For each model size
    for model_idx, model in enumerate(models):
        # Calculate offsets for original and verifier bars
        orig_offset = -0.5 + bar_width * (model_idx * 2 + 0.5)
        ver_offset = -0.5 + bar_width * (model_idx * 2 + 1.5)
        
        # Get data for this model size across all methods
        original_data = []
        verifier_data = []
        for method in methods:
            original_data.append(data[method][model_idx][0][0])
            verifier_data.append(data[method][model_idx][1][0])
        
        # Original performance
        fig.add_trace(go.Bar(
            name=model,
            x=methods,
            y=original_data,
            marker_color=colors[model_idx],
            opacity=0.7,
            width=bar_width,
            offset=orig_offset,
            showlegend=True,
            text=[f"{x:.0f}%" for x in original_data],  # Add percentage labels
            textposition='outside',
            textfont=dict(size=16)
        ))
        
        # Verifier performance
        fig.add_trace(go.Bar(
            name=f"{model} Verifier",
            x=methods,
            y=verifier_data,
            marker=dict(
                pattern=dict(
                    shape='/',
                    solidity=0.7,
                    size=10,
                    fgcolor='white',
                    bgcolor=colors[model_idx],
                ),
                color=colors[model_idx],
            ),
            width=bar_width,
            offset=ver_offset,
            showlegend=False,
            text=[f"{x:.0f}%" for x in verifier_data],  # Add percentage labels
            textposition='outside',
            textfont=dict(size=16)
        ))
    
    # Add legend items for LLM and Verifier
    fig.add_trace(go.Bar(
        name="LLM",
        x=[None],
        y=[None],
        marker_color='black',
        showlegend=True,
        visible='legendonly'
    ))
    
    fig.add_trace(go.Bar(
        name="Verifier",
        x=[None],
        y=[None],
        marker=dict(
            pattern=dict(
                shape='/',
                solidity=0.7,
                size=10,
                fgcolor='white',
                bgcolor='black',
            ),
            color='black',
        ),
        showlegend=True,
        visible='legendonly'
    ))

    # Update layout
    fig.update_layout(
        width=1000,
        height=400,
        template='plotly_white',
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            font=dict(
                color='black',
                size=25,
            ),
            xanchor="right",
            x=1
        ),
        margin=dict(l=5, r=5, t=35, b=5),  # Remove margins
        barmode='overlay',
        yaxis=dict(
            tickformat='%',  # Format y-axis ticks as percentages
            ticktext=[f'{x}%' for x in range(0, 101, 20)],  # Custom tick labels
            tickvals=[x for x in range(0, 101, 20)],  # Custom tick positions
        )
    )
    
    # Update axes
    fig.update_yaxes(
        range=[0, 120],  # Set range from 0% to 100%
        zeroline=True,
        zerolinewidth=1,
        tickfont=dict(size=25),
        zerolinecolor='black',
        gridcolor='lightgray',
        dtick=40,  # Show ticks every 20%
    )
    
    fig.update_xaxes(
        tickfont=dict(size=25),
        gridcolor='lightgray'
    )

    return fig

# Create and show the plot
methods = ["CoT", "LeastToMost", "ToT", "MCTS"]
models = ["1B", "3B", "8B", "70B"]

fig = create_performance_plot(data, methods, models)
fig.show()
fig.write_image("figures/performance_plot.pdf", scale=6)

In [17]:
import plotly.graph_objects as go
import plotly.express as px
import numpy as np

def calculate_averages(data):
    avg_data = {}
    for method in data:
        avg_data[method] = []
        for model_idx in range(len(data[method])):
            # Calculate average for original and verifier
            orig_avg = np.mean(data[method][model_idx][0])
            ver_avg = np.mean(data[method][model_idx][1])
            avg_data[method].append([[orig_avg], [ver_avg]])
    return avg_data

def create_performance_plot(data, methods, models, dataset_idx=None):
    fig = go.Figure()
    colors = ['#AAD09D', '#AFC7E8', '#FF9896', '#D2BCDE'] # px.colors.sequential.Teal[2:]
    
    # Calculate bar positions
    total_bars = len(models) * 2
    bar_width = 1 / (total_bars + 1) # + 0.005
    
    for model_idx, model in enumerate(models):
        orig_offset = -0.5 + bar_width * (model_idx * 2 + 0.5)
        ver_offset = -0.5 + bar_width * (model_idx * 2 + 1.5)
        
        original_data = []
        verifier_data = []
        for method in methods:
            if dataset_idx is None:
                # Use average data
                original_data.append(data[method][model_idx][0][0])
                verifier_data.append(data[method][model_idx][1][0])
                MAX = 80
            else:
                # Use specific dataset data
                original_data.append(data[method][model_idx][0][dataset_idx])
                verifier_data.append(data[method][model_idx][1][dataset_idx])
                LIST_MAX = [105, 98, 68, 75]
                MAX = LIST_MAX[dataset_idx]

        # Original performance
        fig.add_trace(go.Bar(
            name=model,
            x=methods,
            y=original_data,
            marker_color=colors[model_idx],
            opacity=0.7,
            width=bar_width,
            offset=orig_offset,
            showlegend=True,
            text=[f"{x:.0f}%" for x in original_data],
            textposition='outside',
            textfont=dict(size=16)
        ))
        
        # Verifier performance
        fig.add_trace(go.Bar(
            name=f"{model} Verifier",
            x=methods,
            y=verifier_data,
            marker=dict(
                pattern=dict(
                    shape='/',
                    solidity=0.7,
                    size=10,
                    fgcolor='white',
                    bgcolor=colors[model_idx],
                ),
                color=colors[model_idx],
            ),
            width=bar_width,
            offset=ver_offset,
            showlegend=False,
            text=[f"{x:.0f}%" for x in verifier_data],
            textposition='outside',
            textfont=dict(size=14)
        ))

    # Legend items for LLM and Verifier
    fig.add_trace(go.Bar(
        name="Unweighted Voting",
        x=[None],
        y=[None],
        marker_color='black',
        showlegend=True,
        marker=dict(
                pattern=dict(
                    size=15,
                ),
            ),
        # visible='legendonly'
    ))
    
    fig.add_trace(go.Bar(
        name="Verifier Voting",
        x=[None],
        y=[None],
        marker=dict(
            pattern=dict(
                shape='/',
                solidity=0.7,
                size=15,
                fgcolor='white',
                bgcolor='black',
            ),
            color='black',
        ),
        showlegend=True,
        # visible='legendonly'
    ))

    fig.update_layout(
        # title=dict(
        #     text=title,
        #     font=dict(size=30),
        #     x=0.5,
        #     y=0.95
        # ),
        width=1200,
        height=340,
        # template='plotly_white',
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.05,
            font=dict(
                color='black',
                size=20,
            ),
            xanchor="right",
            x=1
        ),
        margin=dict(l=5, r=5, t=5, b=5),
        barmode='overlay',
        yaxis=dict(
            title='Accuracy',
            tickformat='%',
            ticktext=[f'{x}%' for x in range(0, 101, 20)],
            tickvals=[x for x in range(0, 101, 20)],
            title_font=dict(size=23),
            tickfont=dict(size=20)
        ),
        xaxis=dict(
            title='Method',
            title_font=dict(size=23),
            tickfont=dict(size=20)
        ),
    )

    # 模板配置 - 强制所有颜色为黑色
    template = dict(
        layout=dict(
            font_color="black",
            paper_bgcolor="white",
            plot_bgcolor="white",
            title_font_color="black",
            legend_font_color="black",
            
            xaxis=dict(
                title_font_color="black",
                tickfont_color="black",
                linecolor="black",
                gridcolor="black",
                zerolinecolor="black",
            ),
            
            yaxis=dict(
                title_font_color="black", 
                tickfont_color="black",
                linecolor="black",
                gridcolor="black",
                zerolinecolor="black",
            ),
            
            hoverlabel=dict(
                font_color="black",
                bgcolor="white"
            ),
            
            annotations=[dict(font_color="black")],
            shapes=[dict(line_color="black")],
            
            coloraxis=dict(
                colorbar_tickfont_color="black",
                colorbar_title_font_color="black"
            ),
        )
    )

    fig.update_yaxes(
        range=[0, MAX],
        zeroline=True,
        zerolinewidth=1,
        zerolinecolor='black',
        gridcolor='lightgray',
        dtick=40,
    )
    
    fig.update_xaxes(
        gridcolor='lightgray',
        title_standoff=0,
    )

    # 应用配置
    fig.update_layout(template=template)

    
    # fig.update_layout(
    # xaxis_title='Number of Chains',
    # yaxis_title='Accuracy',
    # xaxis=dict(tickfont=dict(size=20), title_font=dict(size=30), tickvals=number_of_thoughts, ticktext=[str(x) for x in number_of_thoughts]),
    # yaxis=dict(tickfont=dict(size=20), title_font=dict(size=30), range=[0, 0.75], tickformat=".0%", dtick=0.15),  # 设置 y 轴范围


    return fig

# Create plots
methods = ["CoT", "LtM", "ToT", "MCTS"]
models = ["1B", "3B", "8B", "70B"]
dataset_names = ["AQuA", "MMLU", "StrategyQA", "CommonSenseQA"]

data = {
    "CoT": [
        [[12.0, 20.0, 50.0, 20.0], [16.0, 20.0, 52.0, 26.0]],
        [[60.0, 40.0, 30.0, 56.0], [68.0, 46.0, 34.0, 56.0]],
        [[74.0, 66.0, 30.0, 68.0], [76.0, 66.0, 43.0, 68.0]],
        [[94.0, 86.0, 40.0, 68.0], [96.0, 86.0, 42.0, 68.0]]
    ],
    "LtM": [
        [[18.0, 22.0, 32.0, 16.0], [22.0, 22.0, 32.0, 22.0]],
        [[34.0, 14.0, 36.0, 32.0], [38.0, 16.0, 36.0, 34.0]],
        [[76.0, 66.0, 24.0, 66.0], [78.0, 66.0, 38.0, 66.0]],
        [[96.0, 90.0, 32.0, 70.0], [96.0, 90.0, 32.0, 70.0]]
    ],
    "ToT": [
        [[34.0, 30.0, 50.0, 26.0], [34.0, 32.0, 58.0, 26.0]],
        [[38.0, 32.0, 14.0, 47.0], [38.0, 36.6, 14.0, 51.0]],
        [[64.0, 62.0, 12.0, 64.0], [66.0, 68.0, 14.0, 64.0]],
        [[98.0, 88.0, 42.0, 66.0], [98.0, 90.0, 44.0, 66.0]]
    ],
    "MCTS": [
        [[18.0, 22.0, 46.0, 30.0], [24.0, 22.0, 50.0, 30.0]],
        [[44.0, 28.0, 14.0, 54.0], [44.0, 30.0, 18.0, 54.0]],
        [[72.0, 60.0, 14.0, 56.0], [72.0, 60.0, 16.0, 56.0]],
        [[96.0, 84.0, 40.0, 68.0], [96.0, 84.0, 44.0, 68.0]]
    ]
}

# Create average performance plot
avg_data = calculate_averages(data)
avg_fig = create_performance_plot(avg_data, methods, models)
# avg_fig.show()
avg_fig.write_image("figures/average_performance.pdf", scale=6)

# Create individual dataset plots
for i in range(4):
    dataset_fig = create_performance_plot(data, methods, models, dataset_idx=i)
    # dataset_fig.show()
    dataset_fig.write_image(f"figures/{dataset_names[i]}_performance.pdf", scale=6)