In [1]:
import pandas as pd
import math
import csv
import sys

class SegmentTree:
    """Array-backed segment tree over fixed-size chunks of one CSV column.

    The column is cut into chunks of ``chunk_size`` consecutive rows and each
    leaf summarises one chunk. Every node is a dict holding the
    ``"ordinary_sum"`` (sum of the values) and ``"square_sum"`` (sum of the
    squared values) of the rows it covers; these two moments are enough to
    answer mean and variance queries. The root is stored in ``tree[0]`` and
    the children of node ``i`` are ``2 * i + 1`` and ``2 * i + 2``.
    """

    def __init__(self, file_path, chunk_size, list_name, load_path = None):
        """Read the column, then either build the tree or load a saved one.

        Args:
            file_path: Path of the CSV file holding the raw data.
            chunk_size: Number of consecutive rows summarised by one leaf.
            list_name: Name of the CSV column to index.
            load_path: Optional path of a CSV written by ``save_to_csv``. When
                truthy, the nodes are loaded from it instead of being rebuilt.

        Raises:
            ValueError: If the selected column contains no rows.
        """
        self.chunk_size = chunk_size
        self.list_name = list_name
        self.data_path = file_path

        self.data = self.read_csv()
        if not self.data:
            # Without this guard math.log2(0) below fails with an opaque
            # "math domain error".
            raise ValueError(
                "column %r of %r contains no rows; cannot build a segment tree"
                % (list_name, file_path)
            )

        self.n = math.ceil(len(self.data) / self.chunk_size)  # number of chunks, rounded up
        self.size = 2 ** (math.ceil(math.log2(self.n)) + 1)
        self.tree = [None] * self.size  # the tree is stored as a flat array

        if load_path:
            self.load_from_csv(load_path)
        else:
            self.build(0, 0, self.n - 1)

    def read_csv(self):
        """Load the whole indexed column from ``data_path`` as a Python list."""
        df = pd.read_csv(self.data_path)
        return df[self.list_name].tolist()

    def read_csv_part(self, start, end):
        """Re-read rows ``[start, end)`` of the column from disk (hard scan).

        Args:
            start: 0-based index of the first data row to read.
            end: 0-based index one past the last data row to read.

        Returns:
            The column values of those rows as a list.
        """
        # File line 0 is the header, so data row k sits on file line k + 1.
        # Skipping file lines 1..start makes data row `start` the first row
        # parsed; `nrows` then limits the read to `end - start` rows.
        df = pd.read_csv(self.data_path, skiprows=range(1, start + 1), nrows=end - start)
        return df[self.list_name].tolist()

    def build(self, idx, left, right):
        """Recursively fill ``tree[idx]`` for the chunk range ``left..right``.

        Leaves are computed from the raw data; inner nodes are the element-wise
        sum of their two children, so the tree is built bottom-up.

        Args:
            idx: Position of the node in ``self.tree``.
            left: Index of the first chunk covered by the node.
            right: Index of the last chunk covered by the node (inclusive).
        """
        if left == right:
            start = left * self.chunk_size
            # Clamp so that the last chunk may be shorter than chunk_size.
            end = min(start + self.chunk_size, len(self.data))
            self.tree[idx] = {
                "ordinary_sum": self.caculate_ordinary_sum(start, end),
                "square_sum": self.caculate_square_sum(start, end),
            }
            return

        mid = (left + right) // 2
        left_idx = 2 * idx + 1
        right_idx = 2 * idx + 2
        self.build(left_idx, left, mid)
        self.build(right_idx, mid + 1, right)

        left_node = self.tree[left_idx]
        right_node = self.tree[right_idx]
        self.tree[idx] = {
            "ordinary_sum": left_node["ordinary_sum"] + right_node["ordinary_sum"],
            "square_sum": left_node["square_sum"] + right_node["square_sum"],
        }

    def caculate_ordinary_sum(self, start, end):
        """Return the sum of ``self.data[start:end]``."""
        return sum(self.data[start:end])

    def caculate_square_sum(self, start, end):
        """Return the sum of the squares of ``self.data[start:end]``."""
        return sum(x ** 2 for x in self.data[start:end])

    def save_to_csv(self, save_path):
        """Write every slot of the tree array to ``save_path``.

        The file has the columns ``idx, ordinary_sum, square_sum``. Unused
        slots are written with both value fields empty.
        """
        with open(save_path, "w", newline='') as f:
            writer = csv.writer(f)
            writer.writerow(["idx", "ordinary_sum", "square_sum"])
            for idx, node in enumerate(self.tree):
                if node is not None:
                    writer.writerow([idx, node["ordinary_sum"], node["square_sum"]])
                else:
                    writer.writerow([idx,  None, None])

    def load_from_csv(self, load_path):
        """Fill ``self.tree`` from a file produced by ``save_to_csv``.

        Values are parsed as floats. A row whose two value fields are both
        empty marks an unused slot and is restored as ``None`` so that a
        save/load round trip reproduces the tree array.
        """
        with open(load_path, "r", newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header row
            for row in reader:
                if not row:
                    continue  # tolerate blank lines
                idx = int(row[0])
                ordinary_sum = float(row[1]) if row[1] != '' else None
                square_sum = float(row[2]) if row[2] != '' else None

                if len(self.tree) <= idx:
                    self.tree.extend([None] * (idx - len(self.tree) + 1))

                if ordinary_sum is None and square_sum is None:
                    self.tree[idx] = None
                else:
                    self.tree[idx] = {"ordinary_sum": ordinary_sum, "square_sum": square_sum}
            print("段树"+ self.list_name +"读取完毕！")

    def _scan_sums(self, start, end):
        """Hard-scan rows ``[start, end)`` (0-based) and return (sum, sum of squares)."""
        part = self.read_csv_part(start, end)
        return sum(part), sum(x ** 2 for x in part)

    def query(self, data_start, data_end, query_type):
        """Return the mean or variance of rows ``data_start..data_end``.

        The alignment checks below treat row positions as 1-based with both
        ends inclusive: chunk ``c`` covers rows
        ``c * chunk_size + 1 .. (c + 1) * chunk_size``. Whole chunks inside the
        range are answered from the tree; a partial chunk at either end, or a
        range shorter than one chunk, is hard-scanned from the CSV file. No
        bounds checking is performed on the arguments.

        Args:
            data_start: 1-based position of the first row of the range.
            data_end: 1-based position of the last row of the range.
            query_type: ``'mean'`` or ``'variance'``.

        Returns:
            The requested statistic (population variance for ``'variance'``),
            or ``None`` when ``query_type`` is neither of the two.
        """
        start_chunk = data_start // self.chunk_size
        end_chunk = (data_end - 1) // self.chunk_size

        total_sum = 0.0
        total_square_sum = 0.0

        if data_end - data_start < self.chunk_size:
            # The range is shorter than a chunk: scan it directly.
            total_sum, total_square_sum = self._scan_sums(data_start - 1, data_end)
        else:
            # Head: unless the range starts on the first row of a chunk, scan
            # from data_start up to the end of the chunk that contains it.
            if data_start % self.chunk_size != 1:
                if data_start % self.chunk_size != 0:
                    temp_end = (start_chunk + 1) * self.chunk_size
                    start_chunk += 1
                else:
                    # data_start is the last row of the previous chunk, so
                    # start_chunk already names the first whole chunk and the
                    # head is that single row.
                    temp_end = start_chunk * self.chunk_size
                head_sum, head_square_sum = self._scan_sums(data_start - 1, temp_end)
                total_sum += head_sum
                total_square_sum += head_square_sum

            # Tail: unless the range ends on the last row of a chunk, scan
            # from the first row of the final chunk up to data_end.
            if data_end % self.chunk_size != 0:
                temp_start = end_chunk * self.chunk_size + 1
                end_chunk -= 1
                tail_sum, tail_square_sum = self._scan_sums(temp_start - 1, data_end)
                total_sum += tail_sum
                total_square_sum += tail_square_sum

            # Whole chunks in between come from the tree.
            if start_chunk <= end_chunk:
                sum_chunks, square_sum_chunks = self.range_query(start_chunk, end_chunk)
                total_sum += sum_chunks
                total_square_sum += square_sum_chunks

        total_size = data_end - data_start + 1
        if query_type == 'mean':
            return total_sum / total_size
        elif query_type == 'variance':
            total_mean = total_sum / total_size
            # Var[X] = E[X^2] - (E[X])^2
            variance = (total_square_sum / total_size) - (total_mean ** 2)
            return variance

    def query_util(self, node_idx, node_start, node_end, query_start, query_end):
        """Return ``(sum, square_sum)`` of chunks ``query_start..query_end``.

        The recursion stops at the highest node fully inside the queried
        range, so it does not need to descend to the leaves.

        Args:
            node_idx: Position of the current node in ``self.tree``.
            node_start: First chunk covered by the current node.
            node_end: Last chunk covered by the current node (inclusive).
            query_start: First chunk of the query.
            query_end: Last chunk of the query (inclusive).
        """
        if query_start > node_end or query_end < node_start:
            # No overlap with the query: contributes nothing.
            return 0.0, 0.0
        if query_start <= node_start and node_end <= query_end:
            node = self.tree[node_idx]
            if node is None:
                # Defensive: an unused slot contributes nothing.
                return 0.0, 0.0
            return node['ordinary_sum'], node['square_sum']
        mid = (node_start + node_end) // 2
        left_sum, left_square_sum = self.query_util(2 * node_idx + 1, node_start, mid, query_start, query_end)
        right_sum, right_square_sum = self.query_util(2 * node_idx + 2, mid + 1, node_end, query_start, query_end)
        return left_sum + right_sum, left_square_sum + right_square_sum

    def range_query(self, chunk_start, chunk_end):
        """Return ``(sum, square_sum)`` over chunks ``chunk_start..chunk_end`` (inclusive)."""
        return self.query_util(0, 0, self.n - 1, chunk_start, chunk_end)
    

In [2]:
import os

# Offline step: build one segment tree per column and persist both to disk.
# NOTE(review): the data location is an absolute local path; adjust it when
# running on another machine.
file_path = r"E:/Codes/IDEA codes/BigDataAnalysis/LAB4/data.csv"
# Output folders are created next to the data file.
base_dir = os.path.dirname(file_path)


print("==============================")
chunk_size = int(input("指定段树的chunk_size："))
print("指定的chunk_size为：" + str(chunk_size))
print("以离线模式构建段树……")
st_TS = SegmentTree(file_path, chunk_size, list_name = "TS")
st_PRECTOT = SegmentTree(file_path, chunk_size, list_name = "PRECTOT")
print("段树构建完毕，正在保存……")
# One folder per chunk size, so trees built with different sizes can coexist.
folder_path = base_dir + "/chunk_size_" + str(chunk_size)
# exist_ok avoids the separate exists() check and makes re-running the cell safe.
os.makedirs(folder_path, exist_ok=True)
save_path_TS = folder_path + "/segment_tree_TS.csv"
save_path_PRECTOT = folder_path + "/segment_tree_PRECTOT.csv"
st_TS.save_to_csv(save_path_TS)
st_PRECTOT.save_to_csv(save_path_PRECTOT)
print("保存完毕！")
print("==============================")


指定的chunk_size为：500000
以离线模式构建段树……
段树构建完毕，正在保存……
保存完毕！
