In [None]:
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import os
import shutil
import zipfile

# Folder where the uploaded workbook is stored. The path is relative, so it is
# resolved against the process's current working directory.
INPUT_DIR = 'input'
# Folder that receives the generated CSV files, plot images and result zip.
OUTPUT_DIR = 'output'

def clear_folder(folder):
    """Reset *folder* so that it exists and contains nothing.

    Anything already at that path is removed recursively, then the
    directory (including missing parents) is created again.
    """
    already_present = os.path.exists(folder)
    if already_present:
        # Remove the whole tree in one go rather than entry by entry.
        shutil.rmtree(folder)
    os.makedirs(folder)

def save_uploaded_file(uploaded_file, folder):
    """Write an uploaded file's bytes into *folder* and return the saved path.

    Args:
        uploaded_file: Object exposing ``name`` (the client-supplied file
            name) and ``getbuffer()`` (the file content as a bytes-like
            object).
        folder: Destination directory; it is created if it does not exist.

    Returns:
        The path of the file that was written, always directly inside
        *folder*.

    Raises:
        ValueError: If the supplied name has no usable final component.
    """
    # The name comes from the client, so keep only its last path component
    # (treating both "/" and "\\" as separators). Otherwise a name such as
    # "../x" would be written outside *folder*.
    safe_name = os.path.basename(uploaded_file.name.replace('\\', '/'))
    if not safe_name or safe_name in ('.', '..'):
        raise ValueError(f"invalid upload file name: {uploaded_file.name!r}")

    os.makedirs(folder, exist_ok=True)
    file_path = os.path.join(folder, safe_name)
    with open(file_path, 'wb') as f:
        f.write(uploaded_file.getbuffer())
    return file_path

def zip_files(file_list, zip_path):
    """Bundle the given files into a new zip archive at *zip_path*.

    Each file is stored under its base name only, so the archive is flat
    regardless of where the source files live. Returns *zip_path*.
    """
    with zipfile.ZipFile(zip_path, 'w') as archive:
        for source_path in file_list:
            member_name = os.path.basename(source_path)
            archive.write(source_path, member_name)
    return zip_path

st.title("Excel数据分析工具")

# Seed the session state with defaults; keys that already hold a value from
# an earlier rerun are left alone.
for state_key, initial_value in (
    ('last_file_name', None),
    ('last_sheet_name', None),
    ('should_clear_output', True),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value

# Empty the output folder only when a clear has been requested, then drop the
# request so later reruns keep whatever results get written.
clear_requested = st.session_state['should_clear_output']
if clear_requested:
    clear_folder(OUTPUT_DIR)
    st.session_state['should_clear_output'] = False
def _save_series_plot(series, kind, title, out_path):
    """Draw *series* as one chart of the given *kind* and save it to *out_path*.

    An explicit figure/axes pair is used instead of pyplot's implicit
    "current figure", and the figure is closed even if drawing or saving
    fails so figures do not accumulate across reruns.
    """
    fig, ax = plt.subplots()
    try:
        if kind == 'hist':
            series.hist(ax=ax)
        else:
            series.plot(kind=kind, ax=ax)
        # NOTE(review): the titles are Chinese; they presumably need a
        # CJK-capable matplotlib font on the host to render properly - confirm.
        ax.set_title(title)
        fig.savefig(out_path)
    finally:
        plt.close(fig)
    return out_path


def analyze_data(file_path, sheet_name):
    """Analyze one sheet of a workbook and write the results to OUTPUT_DIR.

    Args:
        file_path: Path of the Excel workbook to read.
        sheet_name: Name of the sheet to load.

    Returns:
        A ``(csv_files, img_paths)`` tuple of the paths that were written:
        six CSV files holding the first 10, 20, ... 60 rows of the sheet, a
        histogram of the first numeric column (if any) and a line chart of
        the second numeric column (if any).
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    df = pd.read_excel(file_path, sheet_name=sheet_name)

    csv_files = []
    for i in range(1, 7):
        out_path = os.path.join(OUTPUT_DIR, f'result_{i}.csv')
        # utf-8-sig writes a BOM at the start of each CSV.
        df.head(10 * i).to_csv(out_path, index=False, encoding='utf-8-sig')
        csv_files.append(out_path)

    # One chart per numeric column, for at most the first two such columns.
    plot_specs = (
        ('hist', '数据分布直方图', 'plot1.png'),
        ('line', '数据趋势图', 'plot2.png'),
    )
    num_cols = df.select_dtypes(include='number').columns
    img_paths = [
        _save_series_plot(df[column], kind, title, os.path.join(OUTPUT_DIR, filename))
        for column, (kind, title, filename) in zip(num_cols, plot_specs)
    ]
    return csv_files, img_paths


uploaded_file = st.file_uploader("上传Excel文件", type=["xlsx"])

if uploaded_file:
    # A different file name means a new upload: start from empty folders.
    if st.session_state['last_file_name'] != uploaded_file.name:
        clear_folder(INPUT_DIR)
        clear_folder(OUTPUT_DIR)
        st.session_state['last_file_name'] = uploaded_file.name
        st.session_state['last_sheet_name'] = None

    # Persist the upload once per rerun; the same path is reused for listing
    # the sheets and for the analysis itself.
    os.makedirs(INPUT_DIR, exist_ok=True)
    file_path = save_uploaded_file(uploaded_file, INPUT_DIR)
    # The context manager closes the workbook even if reading the names fails.
    with pd.ExcelFile(file_path) as xls:
        sheet_names = xls.sheet_names
    selected_sheet = st.selectbox("请选择要分析的Sheet", sheet_names)

    # Switching to another sheet invalidates the previous results.
    if selected_sheet and st.session_state['last_sheet_name'] != selected_sheet:
        clear_folder(OUTPUT_DIR)
        st.session_state['last_sheet_name'] = selected_sheet

    if selected_sheet and st.button("开始分析"):
        st.success(f"文件上传成功，选择Sheet：{selected_sheet}，开始分析...")

        csv_files, img_paths = analyze_data(file_path, selected_sheet)
        zip_name = f"{os.path.splitext(uploaded_file.name)[0]}_results.zip"
        zip_path = os.path.join(OUTPUT_DIR, zip_name)
        zip_files(csv_files + img_paths, zip_path)

# No file in the uploader (e.g. a page opened in a new browser tab): request
# that the output folder be cleared on the next rerun.
else:
    st.session_state['should_clear_output'] = True

# The result archive and the images stay downloadable on every rerun for as
# long as the files exist in the output folder.
last_file_name = (
    st.session_state['last_file_name']
    if 'last_file_name' in st.session_state
    else None
)
if last_file_name:
    # Same naming rule as the one used when the archive is created.
    zip_name = f"{os.path.splitext(last_file_name)[0]}_results.zip"
    zip_path = os.path.join(OUTPUT_DIR, zip_name)
    if os.path.exists(zip_path):
        with open(zip_path, 'rb') as f:
            st.download_button("下载分析结果（压缩包）", f, file_name=zip_name)
if os.path.isdir(OUTPUT_DIR):
    # os.listdir returns entries in arbitrary order; sort so the images are
    # shown in the same order on every rerun and on every platform.
    img_files = sorted(f for f in os.listdir(OUTPUT_DIR) if f.endswith('.png'))
    if img_files:
        st.write("分析图片：")
        for img in img_files:
            img_path = os.path.join(OUTPUT_DIR, img)
            st.image(img_path)
            with open(img_path, 'rb') as f:
                st.download_button(f"下载 {img}", f, file_name=img)