In [56]:
import pandas as pd
import numpy as np
import psycopg2
import os
import csv
from mlxtend.frequent_patterns import apriori
from mlxtend.frequent_patterns import association_rules
import warnings
warnings.filterwarnings('ignore')

In [57]:
# Connection settings for the PostgreSQL database used by this notebook.
# Each value may be overridden through an environment variable of the same
# name so that real credentials never have to be written into the notebook;
# the literals below are only the fallbacks for a local development setup.
DB_NAME = os.environ.get("DB_NAME", "project")
DB_USER = os.environ.get("DB_USER", "postgres")  # default user name
DB_PASSWORD = os.environ.get("DB_PASSWORD", "postgres")  # default password (local development only)
DB_HOST = os.environ.get("DB_HOST", "localhost")
DB_PORT = os.environ.get("DB_PORT", "5433")  # kept as a string, as before

In [58]:
# Directory that receives the CSV files written by this notebook.
output_dir = "csv_debug_mining_results"
# exist_ok=True removes the check-then-create race and keeps the cell safe to
# re-run; it still raises if the path exists but is not a directory.
os.makedirs(output_dir, exist_ok=True)

In [59]:
# Connect to the database
def connect_to_db():
    """Open a psycopg2 connection using the module-level DB_* settings.

    Returns:
        The connection object on success, or ``None`` when any exception is
        raised while connecting (the error is printed, not re-raised).
    """
    try:
        settings = {
            "dbname": DB_NAME,
            "user": DB_USER,
            "password": DB_PASSWORD,
            "host": DB_HOST,
            "port": DB_PORT,
        }
        connection = psycopg2.connect(**settings)
        print("成功连接到数据库")
    except Exception as e:
        # Callers test for None instead of handling an exception.
        print(f"连接数据库时出错: {e}")
        connection = None
    return connection

In [67]:
def extract_data_from_db(conn):
    """Load the joined fact/dimension rows into a DataFrame.

    The query inner-joins ``accident_facts`` with its ten dimension tables and
    drops rows whose ``road_user`` is 'Unknown', 'Undetermined' or 'Other/-9'.

    Args:
        conn: An open database connection accepted by ``pd.read_sql_query``.

    Returns:
        A DataFrame with one column per selected attribute, in SELECT order.
    """
    star_schema_sql = """
    SELECT 
        p.road_user,
        p.gender,
        p.age_group,
        t.day_of_week,
        t.time_of_day,
        v.bus_involvement,
        v.heavy_rigid_truck_involvement,
        v.articulated_truck_involvement,
        r.national_road_type,
        a.crash_type,
        sp.christmas_period,
        sp.easter_period,
        l.state,
        l.national_remoteness_areas,
        s.speed_limit,
        f.number_fatalities,
        d.dwelling_records
    FROM 
        accident_facts af
    JOIN 
        personnel_dimension p ON af.personnelid = p.personnelid
    JOIN 
        time_dimension t ON af.timeid = t.timeid
    JOIN 
        vehicle_type_dimension v ON af.vehicleid = v.vehicleid
    JOIN 
        road_type_dimension r ON af.roadid = r.roadid
    JOIN 
        accident_info_dimension a ON af.crashtypeid = a.crashtypeid
    JOIN 
        special_period_dimension sp ON af.specialperiodid = sp.specialperiodid
    JOIN 
        location_dimension l ON af.locationid = l.locationid
    JOIN 
        speed_limit_dimension s ON af.speedlimitid = s.speedlimitid
    JOIN 
        fatality_dimension f ON af.fatalityid = f.fatalityid
    JOIN
        dwelling_dimension d ON af.dwellingid = d.dwellingid
    WHERE 
        p.road_user NOT IN ('Unknown', 'Undetermined', 'Other/-9');
    """
    result_frame = pd.read_sql_query(sql=star_schema_sql, con=conn)
    return result_frame

In [61]:
def preprocess_data(df):
    """Turn the numeric crash attributes into categorical labels.

    ``speed_limit``, ``number_fatalities`` and ``dwelling_records`` are parsed
    as numbers (unparseable values become NaN), replaced by the columns
    ``speed_limit_category``, ``fatalities_category`` and ``dwelling_category``
    (appended in that order), and every column except ``road_user`` is cast to
    ``str``.

    The input frame is not modified; a new DataFrame is returned.

    Args:
        df: DataFrame containing at least the three numeric columns above.

    Returns:
        A new DataFrame with the numeric columns replaced by their categories.

    Raises:
        KeyError: If one of the three numeric columns is missing.
    """
    # Work on a copy: assigning columns below would otherwise leak into the
    # caller's frame (coerced speed_limit plus an extra category column).
    df = df.copy()

    # Speed limit -> category; missing or non-positive values are undetermined.
    speed = pd.to_numeric(df['speed_limit'], errors='coerce')
    df['speed_limit_category'] = np.select(
        [speed.isna() | (speed <= 0), speed <= 50, speed <= 80],
        ['Undetermined', 'Low (≤50)', 'Medium (51-80)'],
        default='High (>80)',
    )

    # Fatality count -> category. Anything below one (including 0, which used
    # to fall through to 'High') is treated as unknown.
    fatalities = pd.to_numeric(df['number_fatalities'], errors='coerce')
    df['fatalities_category'] = np.select(
        [fatalities.isna() | (fatalities < 1), fatalities == 1, fatalities.between(2, 3)],
        ['Unknown', 'Normal', 'Mid'],
        default='High',
    )

    # Dwelling records -> category; missing or negative values are unknown.
    dwellings = pd.to_numeric(df['dwelling_records'], errors='coerce')
    df['dwelling_category'] = np.select(
        [dwellings.isna() | (dwellings < 0), dwellings <= 10000, dwellings <= 50000],
        ['Unknown', 'Low', 'Mid'],
        default='High',
    )

    df = df.drop(columns=['speed_limit', 'number_fatalities', 'dwelling_records'])

    # Every feature column is handled as a string label; road_user (the mining
    # target) keeps its original dtype, as before.
    string_columns = {col: str for col in df.columns if col != 'road_user'}
    return df.astype(string_columns)

In [62]:
def mine_1to1_rules(df, target_column='road_user', min_support=0.1, min_confidence=0.6):
    """Mine one-to-one association rules between feature values and target values.

    For every (target value, feature, feature value) combination a two-column
    boolean indicator frame is built and handed to ``apriori`` and
    ``association_rules``. Every returned rule that involves the target item is
    recorded, in either direction (feature -> target or target -> feature); the
    ``antecedent`` / ``consequent`` columns tell the two directions apart.

    Combinations for which mining raises are skipped (best effort), but the
    number of skipped combinations and the first error are printed at the end
    instead of being discarded silently.

    Args:
        df: DataFrame of categorical columns, including ``target_column``.
        target_column: Name of the column whose values are the rule targets.
        min_support: Minimum support passed to ``apriori``.
        min_confidence: Minimum confidence passed to ``association_rules``.

    Returns:
        DataFrame with columns ``feature``, ``feature_value``, ``target``,
        ``support``, ``confidence``, ``lift``, ``antecedent``, ``consequent``;
        an empty DataFrame (no columns) when no rule was found.
    """
    rules_list = []
    failed_combinations = 0
    first_error = None

    feature_columns = [col for col in df.columns if col != target_column]
    # The distinct values of a feature do not depend on the target value, so
    # compute them once rather than once per target.
    values_by_feature = {col: df[col].unique() for col in feature_columns}

    for target in df[target_column].unique():
        target_item = f"{target_column}_{target}"
        target_mask = df[target_column] == target

        for feature in feature_columns:
            for value in values_by_feature[feature]:
                feature_item = f"{feature}_{value}"
                # Boolean indicators rather than 0/1 integers: mlxtend warns
                # that non-bool input is slower and may lose support.
                temp_df = pd.DataFrame({
                    feature_item: df[feature] == value,
                    target_item: target_mask,
                })

                try:
                    frequent_itemsets = apriori(temp_df,
                                                min_support=min_support,
                                                use_colnames=True)
                    if frequent_itemsets.empty:
                        continue

                    rules = association_rules(frequent_itemsets,
                                              metric="confidence",
                                              min_threshold=min_confidence)

                    # With two indicator columns every rule is 1 -> 1, so the
                    # single element of each side identifies it.
                    for _, rule in rules.iterrows():
                        ant = next(iter(rule['antecedents']))
                        con = next(iter(rule['consequents']))

                        if target_item in (ant, con):
                            rules_list.append({
                                'feature': feature,
                                'feature_value': value,
                                'target': target,
                                'support': rule['support'],
                                'confidence': rule['confidence'],
                                'lift': rule['lift'],
                                'antecedent': ant,
                                'consequent': con,
                            })
                except Exception as e:
                    # Keep going, but remember that something went wrong so a
                    # systematic failure is not mistaken for "no rules".
                    failed_combinations += 1
                    if first_error is None:
                        first_error = e
                    continue

    if failed_combinations:
        print(f"警告: {failed_combinations} 个特征/目标组合在挖掘规则时出错, 已跳过。首个错误: {first_error!r}")

    return pd.DataFrame(rules_list)


In [63]:
def save_rules_to_csv(rules_df, min_support=0.05, min_confidence=0.6, min_lift=1.0,
                      output_directory=None):
    """Filter mined rules, write them to ``one_to_one_rules.csv`` and print a summary.

    A rule is kept when ``support >= min_support``, ``confidence >=
    min_confidence`` and ``lift > min_lift`` (lift is a strict comparison).
    Kept rules are sorted by lift, highest first. The file is written even if
    no rule survives the filter.

    Args:
        rules_df: DataFrame with at least ``support``, ``confidence`` and
            ``lift`` columns. If it is empty, nothing is written.
        min_support: Minimum support a rule needs to be kept.
        min_confidence: Minimum confidence a rule needs to be kept.
        min_lift: Lift a rule has to exceed to be kept.
        output_directory: Directory for the CSV file; defaults to the
            module-level ``output_dir``. It is created when missing.

    Returns:
        None.
    """
    if rules_df.empty:
        print("未找到符合条件的规则")
        return

    target_dir = output_dir if output_directory is None else output_directory
    # Do not rely on an earlier notebook cell having created the directory.
    os.makedirs(target_dir, exist_ok=True)

    # Apply the filter thresholds
    keep = (
        (rules_df['support'] >= min_support)
        & (rules_df['confidence'] >= min_confidence)
        & (rules_df['lift'] > min_lift)
    )

    # Sort by lift, descending
    filtered_rules = rules_df[keep].sort_values('lift', ascending=False)

    # Save to CSV
    output_file = os.path.join(target_dir, "one_to_one_rules.csv")
    filtered_rules.to_csv(output_file, index=False)

    # Print filter statistics
    print("\n规则筛选统计:")
    print(f"筛选前规则数量: {len(rules_df)}")
    print(f"筛选后规则数量: {len(filtered_rules)}")
    print(f"规则已保存到: {output_file}")

    # Print a few basic statistics of the kept rules
    if not filtered_rules.empty:
        print("\n筛选后规则统计:")
        print(f"平均支持度: {filtered_rules['support'].mean():.4f}")
        print(f"平均置信度: {filtered_rules['confidence'].mean():.4f}")
        print(f"平均提升度: {filtered_rules['lift'].mean():.4f}")

In [68]:
def main():
    """Run the pipeline: connect, extract, preprocess, mine rules, save them.

    Does nothing when the database connection cannot be established. Any
    exception raised by a pipeline step is printed rather than propagated, and
    the connection is closed in every case.
    """
    connection = connect_to_db()
    if connection is not None:
        try:
            # Extract and preprocess the data
            prepared = preprocess_data(extract_data_from_db(connection))

            # Mine with loose thresholds; saving applies its own filter
            mined = mine_1to1_rules(prepared, min_support=0.005, min_confidence=0.1)

            # Save the rules
            save_rules_to_csv(mined)
        except Exception as e:
            print(f"错误: {e}")
        finally:
            connection.close()


if __name__ == "__main__":
    main()

成功连接到数据库

规则筛选统计:
筛选前规则数量: 309
筛选后规则数量: 30
规则已保存到: csv_debug_mining_results/one_to_one_rules.csv

筛选后规则统计:
平均支持度: 0.1977
平均置信度: 0.8505
平均提升度: 1.1101
