# In [1]:
import logging
import copy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

# 导入SecretFlow相关库
from secretflow import PYUObject, proxy
import secretflow as sf


# 定义一个简单的CNN模型，可以按块分解
class SimpleBlockModel(nn.Module):
    """Small CNN split into separately trainable blocks.

    Layout: ``block1`` (conv) -> ``block2`` (conv) -> ``block3`` (flatten +
    linear) -> ``classifier``. Hidden widths are multiplied by
    ``1 / sqrt(client_num)``, so each client's model gets narrower as more
    clients take part.

    Args:
        num_blocks: stored on the instance only; the architecture always has
            three feature blocks plus the classifier.
        in_channels: number of channels of the input image.
        num_classes: classifier output size.
        client_num: number of participating clients (>= 1), used for scaling.
        input_size: side length of the square input image (28 for MNIST).

    Raises:
        ValueError: if ``client_num`` is smaller than 1.
    """

    def __init__(self, num_blocks=3, in_channels=1, num_classes=10, client_num=5, input_size=28):
        super().__init__()
        if client_num < 1:
            raise ValueError(f"client_num 必须 >= 1，当前为 {client_num}")
        self.num_blocks = num_blocks

        # Scale hidden widths by 1/sqrt(number of clients).
        scale_factor = 1.0 / math.sqrt(client_num)
        logging.info(f"模型初始化: 隐藏维度缩放系数 = {scale_factor:.4f}")

        # Clamp to 1 so a very large client_num cannot yield a zero-width layer.
        block1_channels = max(1, int(32 * scale_factor))
        block2_channels = max(1, int(64 * scale_factor))
        block3_output = max(1, int(128 * scale_factor))

        # Block 1: first conv block.
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, block1_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(block1_channels),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        logging.info(f"模型初始化: Block 1 输出通道数 = {block1_channels}")

        # Block 2: second conv block.
        self.block2 = nn.Sequential(
            nn.Conv2d(block1_channels, block2_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(block2_channels),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        logging.info(f"模型初始化: Block 2 输出通道数 = {block2_channels}")

        # The 3x3/padding-1 convs keep the spatial size and each MaxPool2d(2)
        # halves it (rounding down), so each side shrinks to input_size // 4
        # (28 -> 7 for MNIST).
        spatial = input_size // 4
        flattened_dim = block2_channels * spatial * spatial

        # Block 3: fully connected block.
        self.block3 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_dim, block3_output),
            nn.ReLU()
        )
        logging.info(f"模型初始化: Block 3 输入特征维度 = {flattened_dim}")
        logging.info(f"模型初始化: Block 3 输出特征维度 = {block3_output}")

        # Classifier head.
        self.classifier = nn.Linear(block3_output, num_classes)
        logging.info(f"模型初始化: 分类器输出类别数 = {num_classes}")

    def forward(self, x, block_idx=None):
        """Run the whole network, or a single stage when ``block_idx`` is given.

        Args:
            x: input to the selected stage (the raw image for the full model).
            block_idx: ``None`` for the full model, 1-3 for a feature block,
                4 for the classifier.

        Returns:
            The output tensor of the selected stage.

        Raises:
            ValueError: for any other ``block_idx``.
        """
        if block_idx is None:
            return self.classifier(self.get_feature(x))

        stages = {1: self.block1, 2: self.block2, 3: self.block3, 4: self.classifier}
        if block_idx not in stages:
            raise ValueError(f"无效的块索引: {block_idx}（应为 1~4 或 None）")
        return stages[block_idx](x)

    def get_feature(self, x):
        """Return the block3 feature vector (the classifier input) for ``x``."""
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return x


# 特征适配器
class FeatureAdapter(nn.Module):
    """Projection layer that maps ``in_channels`` features to ``out_channels``.

    Args:
        in_channels: input channel/feature count.
        out_channels: output channel/feature count.
        adapter_type: ``"conv1x1"`` (a 1x1 ``nn.Conv2d`` for feature maps) or
            ``"linear"`` (an ``nn.Linear`` for flat feature vectors).

    Raises:
        ValueError: for any other ``adapter_type``. Previously this left
            ``self.adapter`` unset and only failed later, inside ``forward``.
    """

    def __init__(self, in_channels, out_channels, adapter_type="conv1x1"):
        super().__init__()
        self.adapter_type = adapter_type
        logging.info(f"创建特征适配器: 类型={adapter_type}, 输入通道={in_channels}, 输出通道={out_channels}")

        if adapter_type == "conv1x1":
            self.adapter = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        elif adapter_type == "linear":
            self.adapter = nn.Linear(in_channels, out_channels)
        else:
            raise ValueError(f"未知的适配器类型: {adapter_type}（应为 'conv1x1' 或 'linear'）")

    def forward(self, x):
        """Apply the projection to ``x``."""
        return self.adapter(x)


# 在SecretFlow中定义客户端
@proxy(PYUObject)
class CauseClient:
    """A CAUSE federated-learning participant, hosted as a SecretFlow PYU actor.

    Each client owns a width-scaled ``SimpleBlockModel`` and trains it one
    block at a time. Block indices are 1-3 for the feature blocks and 4 for
    the classifier.

    Once every client has trained block k, the orchestrator hands each client
    rebuilt copies of *all* clients' block k (``set_fused_block``). That fused
    block then replaces the local block k as a frozen feature extractor while
    later blocks are trained.
    """

    def __init__(self, client_id, client_num, num_blocks=3, num_classes=10, in_channels=1):
        self.client_id = client_id
        self.client_num = client_num
        self.num_blocks = num_blocks
        self.device = "cpu"  # CPU only, for simplicity

        logging.info(f"初始化客户端 {client_id}/{client_num}, 模型块数={num_blocks}, 类别数={num_classes}")

        # Width-scaled local model.
        self.model = SimpleBlockModel(num_blocks=num_blocks, in_channels=in_channels,
                                      num_classes=num_classes, client_num=client_num)
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)
        logging.info(f"客户端 {client_id} 创建了SGD优化器，学习率=0.01, 动量=0.9")

        # Filled in by load_dataset(); train_block() refuses to run before that.
        self.train_dataset = None
        self.train_loader = None

        # block_idx -> callable(x) producing the fused output of that block.
        self.fused_blocks = {}

        # block_idx -> module applied to the concatenated per-client outputs
        # of that block inside its fused forward.
        self.adapters = {}
        logging.info(f"客户端 {client_id} 初始化完成")

    def _block_module(self, block_idx):
        """Return the local sub-module numbered ``block_idx`` (4 = classifier).

        Raises:
            ValueError: if ``block_idx`` is not in 1..4.
        """
        modules = {
            1: self.model.block1,
            2: self.model.block2,
            3: self.model.block3,
            4: self.model.classifier,
        }
        try:
            return modules[block_idx]
        except KeyError:
            raise ValueError(f"无效的块索引: {block_idx}（应为 1~4）") from None

    def _debug_shape(self, x, name):
        """Log the shape of ``x`` (or that it is not a tensor) and return ``x`` unchanged."""
        if isinstance(x, torch.Tensor):
            logging.info(f"客户端 {self.client_id} - DEBUG: {name} 形状: {x.shape}")
        else:
            logging.info(f"客户端 {self.client_id} - DEBUG: {name} 不是张量")
        # Return on both paths so the helper can be dropped into an expression.
        return x

    def load_dataset(self, dataset_name="mnist", alpha=0.5):
        """Load this client's training shard as a label-skewed (non-IID) subset.

        Each client keeps ``labels_per_client`` consecutive labels, starting at
        ``(client_id * labels_per_client) % 10``.

        Args:
            dataset_name: only ``"mnist"`` is supported.
            alpha: only logged. NOTE(review): it does not influence the split
                (no Dirichlet sampling is performed here).

        Raises:
            NotImplementedError: for any other dataset name.
        """
        logging.info(f"客户端 {self.client_id} 开始加载数据集: {dataset_name}, 非IID系数alpha={alpha}")

        if dataset_name != "mnist":
            raise NotImplementedError(f"数据集 {dataset_name} 尚未实现")

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        logging.info(f"客户端 {self.client_id} 下载MNIST数据集...")
        dataset = datasets.MNIST(f'../data/client_{self.client_id}',
                                 train=True,
                                 download=True,
                                 transform=transform)

        # Simplified non-IID split: each client only sees a fixed label range.
        num_classes = 10
        labels_per_client = 2
        start_label = (self.client_id * labels_per_client) % num_classes
        logging.info(f"客户端 {self.client_id} 使用标签范围: {start_label}~{start_label+labels_per_client-1}")

        # Filter on the label tensor directly. Enumerating `dataset` would
        # decode and normalise every image just to read its label.
        targets = torch.as_tensor(dataset.targets)
        in_range = (targets >= start_label) & (targets < start_label + labels_per_client)
        indices = torch.nonzero(in_range).flatten().tolist()

        self.train_dataset = torch.utils.data.Subset(dataset, indices)
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset, batch_size=64, shuffle=True
        )
        logging.info(f"客户端 {self.client_id} 数据集加载完成，共有{len(self.train_dataset)}个样本，批大小=64")

    def _forward_for_training(self, data, block_idx, log_sources=False):
        """Full forward pass used while training block ``block_idx``.

        Stages before ``block_idx`` use the fused block when one is registered,
        otherwise the local block. The trained block, the later blocks and the
        classifier always come from the local model.
        """
        h = data
        for stage in (1, 2, 3):
            upstream = stage < block_idx
            if upstream and stage in self.fused_blocks:
                h = self.fused_blocks[stage](h)
                if log_sources:
                    logging.info(f"客户端 {self.client_id} 使用融合后的块 {stage}")
            else:
                h = self._block_module(stage)(h)
                if upstream and log_sources:
                    logging.info(f"客户端 {self.client_id} 使用本地块 {stage}")
        return self.model.classifier(h)

    def train_block(self, block_idx, epochs=5):
        """Train only block ``block_idx`` (4 = classifier) on the local shard.

        Every other parameter of ``self.model`` is frozen for this call.

        Raises:
            ValueError: if ``block_idx`` is not in 1..4.
            RuntimeError: if ``load_dataset`` has not been called yet.
        """
        logging.info(f"客户端 {self.client_id} 开始训练块 {block_idx}，训练轮次={epochs}")
        target_block = self._block_module(block_idx)
        if self.train_loader is None:
            raise RuntimeError(f"客户端 {self.client_id} 尚未加载数据集，请先调用 load_dataset")

        self.model.train()

        # Freeze everything, then unfreeze only the block being trained.
        for param in self.model.parameters():
            param.requires_grad = False
        logging.info(f"客户端 {self.client_id} 已冻结所有参数")

        for param in target_block.parameters():
            param.requires_grad = True
        if block_idx == 4:
            logging.info(f"客户端 {self.client_id} 解冻分类器参数")
        else:
            logging.info(f"客户端 {self.client_id} 解冻块 {block_idx} 参数")

        num_batches = len(self.train_loader)
        num_samples = len(self.train_loader.dataset)
        for epoch in range(epochs):
            running_loss = 0.0
            correct = 0
            total = 0
            for batch_idx, (data, target) in enumerate(self.train_loader):
                data, target = data.to(self.device), target.to(self.device)
                self.optimizer.zero_grad()

                # Report which upstream blocks are fused vs. local once per epoch.
                output = self._forward_for_training(data, block_idx, log_sources=(batch_idx == 0))

                loss = F.cross_entropy(output, target)
                loss.backward()
                self.optimizer.step()

                running_loss += loss.item()
                predicted = output.argmax(dim=1)
                total += target.size(0)
                correct += (predicted == target).sum().item()

                if batch_idx % 10 == 0:
                    logging.info(f"客户端 {self.client_id}, 块 {block_idx}, 轮次 {epoch + 1}: "
                                 f"[{batch_idx * len(data)}/{num_samples} "
                                 f"({100. * batch_idx / num_batches:.0f}%)], "
                                 f"损失: {loss.item():.6f}")

            # An empty shard yields no batches; report 0 instead of dividing by zero.
            epoch_loss = running_loss / num_batches if num_batches else 0.0
            epoch_acc = 100. * correct / total if total else 0.0
            logging.info(f"客户端 {self.client_id}, 块 {block_idx}, 轮次 {epoch + 1} "
                         f"平均损失: {epoch_loss:.6f}, 训练准确率: {epoch_acc:.2f}%")

        logging.info(f"客户端 {self.client_id} 完成块 {block_idx} 的训练")

    def get_block_state_dict(self, block_idx):
        """Return a deep copy of the local ``state_dict`` of block ``block_idx`` for fusion.

        Raises:
            ValueError: if ``block_idx`` is not in 1..4.
        """
        logging.info(f"客户端 {self.client_id} 提供块 {block_idx} 的状态字典用于融合")
        return copy.deepcopy(self._block_module(block_idx).state_dict())

    def create_block(self, block_idx, in_channels=None, out_channels=None):
        """Build a freshly initialised module with the architecture of local block ``block_idx``.

        Args:
            block_idx: 1-3 for the feature blocks, 4 for the classifier.
            in_channels: accepted for interface compatibility; ignored. The
                shapes are copied from ``self.model``.
            out_channels: same as ``in_channels``.

        Returns:
            nn.Module: the new (untrained) block.

        Raises:
            ValueError: if ``block_idx`` is not in 1..4.
        """
        logging.info(f"客户端 {self.client_id} 创建块 {block_idx}")

        if block_idx in (1, 2):
            reference = self._block_module(block_idx)
            conv, norm = reference[0], reference[1]
            return nn.Sequential(
                nn.Conv2d(conv.in_channels, conv.out_channels,
                          kernel_size=conv.kernel_size, padding=conv.padding),
                nn.BatchNorm2d(norm.num_features),
                nn.ReLU(),
                nn.MaxPool2d(2)
            )
        if block_idx == 3:
            linear = self.model.block3[1]
            return nn.Sequential(
                nn.Flatten(),
                nn.Linear(linear.in_features, linear.out_features),
                nn.ReLU()
            )
        if block_idx == 4:
            head = self.model.classifier
            return nn.Linear(head.in_features, head.out_features)
        raise ValueError(f"无效的块索引: {block_idx}（应为 1~4）")

    def add_feature_adapter(self, block_idx, in_channels, out_channels, adapter_type="conv1x1"):
        """Create the adapter used by the fused forward of block ``block_idx``.

        The adapter maps the concatenation of all clients' outputs for that
        block (``in_channels``) back to the local width (``out_channels``).
        Blocks 3 and 4 produce flat feature vectors, so they always get a
        linear adapter.

        NOTE(review): adapters are randomly initialised and are not part of
        ``self.optimizer``, so they are never trained.

        Raises:
            ValueError: for an unknown ``adapter_type``.
        """
        if block_idx >= 3:
            adapter_type = "linear"

        logging.info(f"客户端 {self.client_id} 添加块 {block_idx} 的特征适配器: 类型={adapter_type}, 输入={in_channels}, 输出={out_channels}")
        if adapter_type == "conv1x1":
            self.adapters[block_idx] = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        elif adapter_type == "linear":
            self.adapters[block_idx] = nn.Linear(in_channels, out_channels)
        else:
            raise ValueError(f"未知的适配器类型: {adapter_type}（应为 'conv1x1' 或 'linear'）")

    def create_and_load_block(self, block_idx, state_dict):
        """Create a block shaped like local block ``block_idx`` and load ``state_dict`` into it.

        Args:
            block_idx (int): block index (4 = classifier).
            state_dict (dict): weights to load, typically another client's block.

        Returns:
            nn.Module: the loaded block.
        """
        block = self.create_block(block_idx)
        block.load_state_dict(state_dict)
        return block

    def set_block_state_dict(self, block_idx, state_dict):
        """Load ``state_dict`` into local block ``block_idx`` and return True.

        Raises:
            ValueError: if ``block_idx`` is not in 1..4.
        """
        logging.info(f"客户端 {self.client_id} 设置块 {block_idx} 的状态字典")
        self._block_module(block_idx).load_state_dict(state_dict)
        return True

    def set_fused_block(self, block_idx, client_blocks):
        """Register the fusion of ``client_blocks`` as fused block ``block_idx``.

        The fused forward runs every client block on the same input. If an
        adapter is registered for ``block_idx`` at call time, the outputs are
        concatenated (along channels for blocks 1-2, along flattened features
        otherwise) and the adapter is applied. Otherwise the outputs are
        averaged.

        Args:
            block_idx: block index (4 = classifier).
            client_blocks: non-empty sequence of modules, one per client.

        Returns:
            True once the fused block is registered.

        Raises:
            ValueError: if ``client_blocks`` is empty.
        """
        if not client_blocks:
            raise ValueError(f"块 {block_idx} 没有可融合的客户端块")

        concat_channels = block_idx <= 2  # conv blocks: concatenate along channels
        client_blocks = list(client_blocks)
        # Fused blocks are frozen feature extractors: use their stored
        # BatchNorm statistics instead of updating them with every batch.
        for client_block in client_blocks:
            client_block.eval()

        def fused_forward(x):
            # Nothing here is trained (not in self.optimizer), so skip autograd.
            with torch.no_grad():
                outputs = [client_block(x) for client_block in client_blocks]
                if not concat_channels:
                    outputs = [torch.flatten(out, 1) if out.dim() > 2 else out for out in outputs]

                adapter = self.adapters.get(block_idx)
                if adapter is not None:
                    return adapter(torch.cat(outputs, dim=1))
                return sum(outputs) / len(outputs)

        # Previously the closure was built but never stored, so fused blocks were never used.
        self.fused_blocks[block_idx] = fused_forward
        logging.info(f"客户端 {self.client_id} 设置融合块 {block_idx}，包含{len(client_blocks)}个客户端块")
        return True

    def set_classifier_state_dict(self, state_dict):
        """Load the fused classifier weights and return True."""
        self.model.classifier.load_state_dict(state_dict)
        logging.info(f"客户端 {self.client_id} 成功设置融合分类器")
        return True

    def get_feature_dimensions(self):
        """Return ``{block_idx: output width}`` for feature blocks 1-3."""
        feature_dims = {
            1: self.model.block1[0].out_channels,  # block 1 output channels
            2: self.model.block2[0].out_channels,  # block 2 output channels
            3: self.model.block3[1].out_features   # block 3 output features
        }
        logging.info(f"客户端 {self.client_id} 获取特征维度: {feature_dims}")
        return feature_dims

    def add_feature_adapter_pyu(self, block_idx, client_num):
        """After block ``block_idx`` is fused, prepare the adapter for block ``block_idx + 1``.

        The adapter maps ``client_num`` concatenated outputs of the next block
        (``width * client_num``) back to that block's local width, so the
        following local block (or the loss) receives the shape it expects.
        """
        target = block_idx + 1
        logging.info(f"客户端 {self.client_id} 在PYU上添加块 {block_idx+1} 的特征适配器")
        output_dims = self.get_feature_dimensions()
        output_dims[4] = self.model.classifier.out_features  # classifier "width"
        out_dim = output_dims[target]
        in_dim = out_dim * client_num
        adapter_type = "conv1x1" if target <= 2 else "linear"
        self.add_feature_adapter(target, in_dim, out_dim, adapter_type=adapter_type)
        return True



# 在SecretFlow中定义服务器
@proxy(PYUObject)
class CauseServer:
    """Aggregation side of CAUSE, hosted as a SecretFlow PYU actor.

    The server only averages classifier weights; feature blocks are fused on
    the clients themselves.
    """

    def __init__(self, client_num, num_classes=10, num_blocks=3):
        self.client_num = client_num
        self.num_classes = num_classes
        self.num_blocks = num_blocks
        self.device = "cpu"  # CPU only, for simplicity
        logging.info(f"初始化服务器：客户端数量={client_num}, 类别数={num_classes}, 块数={num_blocks}")

    def fuse_classifiers(self, classifiers):
        """Return a new ``nn.Linear`` whose parameters are the element-wise mean of ``classifiers``.

        Args:
            classifiers: non-empty sequence of ``nn.Linear`` modules with
                identical shapes.

        Raises:
            ValueError: if ``classifiers`` is empty.
        """
        logging.info(f"服务器开始融合{len(classifiers)}个分类器")
        if not classifiers:
            raise ValueError("没有可融合的分类器")

        # Snapshot every state_dict once instead of once per key.
        state_dicts = [classifier.state_dict() for classifier in classifiers]
        avg_state_dict = {
            key: sum(sd[key] for sd in state_dicts) / len(state_dicts)
            for key in state_dicts[0]
        }

        in_features = classifiers[0].in_features
        out_features = classifiers[0].out_features
        avg_classifier = nn.Linear(in_features, out_features, bias="bias" in avg_state_dict)
        avg_classifier.load_state_dict(avg_state_dict)

        logging.info(f"服务器成功融合分类器: 输入特征={in_features}, 输出特征={out_features}")
        return avg_classifier

    def get_classifier_state_dict(self, classifier):
        """Return ``classifier.state_dict()``."""
        return classifier.state_dict()

    def create_and_load_classifier(self, state_dict):
        """Build an ``nn.Linear`` sized from ``state_dict`` and load it.

        Accepts a plain Linear state dict (``weight``/``bias``) or one whose
        keys share a prefix (e.g. ``"0.weight"``/``"0.bias"``); the prefix is
        stripped before loading.

        Raises:
            ValueError: if no key ending in ``"weight"`` exists.
        """
        if "weight" in state_dict:
            weight_key = "weight"
        else:
            weight_key = next((key for key in state_dict if key.endswith("weight")), None)
        if weight_key is None:
            # The old fallback guessed 73x10 and would always fail in load_state_dict.
            raise ValueError(f"无法从状态字典推断分类器形状，键为: {list(state_dict.keys())}")

        prefix = weight_key[: -len("weight")]
        if prefix:
            state_dict = {key[len(prefix):]: value
                          for key, value in state_dict.items() if key.startswith(prefix)}

        out_features, in_features = state_dict["weight"].shape
        classifier = nn.Linear(in_features, out_features, bias="bias" in state_dict)
        classifier.load_state_dict(state_dict)
        return classifier

    def calculate_average_accuracy(self, accuracy_list):
        """Return the arithmetic mean of ``accuracy_list``.

        Raises:
            ValueError: if ``accuracy_list`` is empty.
        """
        values = list(accuracy_list)
        if not values:
            raise ValueError("accuracy_list 不能为空")
        return sum(values) / len(values)




# Cause训练函数
def _fuse_and_distribute_classifier(clients, server, server_state_dicts):
    """Average the clients' classifiers on ``server`` and load the result on every client.

    ``server_state_dicts`` must already live on ``server.device``.
    """
    logging.info("开始融合所有客户端的分类器")
    temp_classifiers = [server.create_and_load_classifier(state_dict)
                        for state_dict in server_state_dicts]
    avg_classifier = server.fuse_classifiers(temp_classifiers)
    avg_classifier_state_dict = server.get_classifier_state_dict(avg_classifier)

    logging.info("开始向所有客户端分发融合分类器")
    sf.wait([client.set_classifier_state_dict(avg_classifier_state_dict.to(client.device))
             for client in clients])
    logging.info("所有客户端成功设置融合分类器")


def train_cause(clients, server, num_blocks=3, epochs_per_block=5, dataset_name="mnist", alpha=0.5):
    """Train ``clients`` with the CAUSE block-wise scheme.

    For each block 1..num_blocks, and then the classifier (index num_blocks+1):

    1. every client trains that block locally;
    2. the trained block state dicts are moved to the server device;
    3. every client rebuilds all clients' copies of the block and registers
       them as its fused block;
    4. for non-final blocks, every client prepares the adapter of the next
       block; for the classifier, the server averages the classifiers and
       pushes the result back to every client.
    """
    logging.info(f"开始Cause训练: 客户端数量={len(clients)}, 块数={num_blocks}, 每块训练轮次={epochs_per_block}")

    logging.info("开始加载所有客户端数据集...")
    sf.wait([client.load_dataset(dataset_name=dataset_name, alpha=alpha) for client in clients])
    logging.info("所有客户端数据集加载完成")

    classifier_idx = num_blocks + 1  # feature blocks are 1..num_blocks
    for block_idx in range(1, classifier_idx + 1):
        logging.info(f"开始训练块 {block_idx}")

        # 1. Train the current block on every client.
        logging.info(f"所有客户端开始训练块 {block_idx}")
        sf.wait([client.train_block(block_idx, epochs=epochs_per_block) for client in clients])
        logging.info(f"所有客户端完成块 {block_idx} 的训练")

        # 2. Collect every client's trained block on the server device.
        logging.info(f"开始收集所有客户端的块 {block_idx} 状态字典")
        server_state_dicts = [client.get_block_state_dict(block_idx).to(server.device)
                              for client in clients]
        logging.info(f"成功收集{len(server_state_dicts)}个客户端的块 {block_idx} 状态字典")

        # 3. Each client rebuilds all clients' blocks locally and installs the fusion.
        logging.info(f"开始向所有客户端分发融合块 {block_idx}")
        set_block_tasks = []
        for client in clients:
            local_blocks = [client.create_and_load_block(block_idx, state_dict.to(client.device))
                            for state_dict in server_state_dicts]
            set_block_tasks.append(client.set_fused_block(block_idx, local_blocks))
        sf.wait(set_block_tasks)
        logging.info(f"所有客户端成功设置融合块 {block_idx}")

        if block_idx < classifier_idx:
            # 4a. Prepare the adapter used when the next block is fused.
            logging.info(f"为所有客户端添加块 {block_idx+1} 的特征适配器")
            sf.wait([client.add_feature_adapter_pyu(block_idx, len(clients)) for client in clients])
            logging.info(f"所有客户端完成块 {block_idx+1} 特征适配器的添加")
        else:
            # 4b. Final (classifier) block: average on the server and redistribute.
            _fuse_and_distribute_classifier(clients, server, server_state_dicts)

# In [2]:
import logging
import os
import datetime
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

import secretflow as sf

# from secretfl_Cause import CauseClient, CauseServer, train_cause


def setup_logging(log_dir="logs"):
    """Route root-logger INFO output to both the console and a new timestamped file.

    Any handlers already attached to the root logger are removed and closed,
    so calling this again (e.g. re-running a notebook cell) neither duplicates
    output nor leaks open log files.

    Args:
        log_dir: directory for the log file; created (with parents) if missing.

    Returns:
        Path of the log file, ``<log_dir>/cause_run_<YYYYmmdd_HHMMSS>.log``.
    """
    os.makedirs(log_dir, exist_ok=True)

    # Timestamped file name so each run gets its own log.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"cause_run_{timestamp}.log")

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    # Close, not just detach, old handlers so their file descriptors are released.
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)
        handler.close()

    # Log messages are largely Chinese; don't depend on the locale's default encoding.
    file_handler = logging.FileHandler(log_file, encoding="utf-8")
    file_handler.setFormatter(
        logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    root_logger.addHandler(file_handler)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
    root_logger.addHandler(console_handler)

    logging.info(f"日志将被保存到: {log_file}")
    return log_file


def main():
    """Run the CAUSE demo: 3 MNIST clients plus 1 server, simulated locally in SecretFlow.

    ``sf.shutdown()`` runs in a ``finally`` so the local Ray instance is
    released even if training raises.
    """
    setup_logging()

    client_parties = ["alice", "bob", "charlie"]
    client_num = len(client_parties)
    num_blocks = 3  # 3 feature blocks + classifier
    num_classes = 10

    # Record the experiment configuration.
    logging.info("====== Cause在SecretFlow中的实现 ======")
    logging.info(f"配置: MNIST数据集, {client_num}个客户端, {num_blocks}个模型块")

    logging.info("初始化SecretFlow...")
    sf.init(client_parties + ["server"], address="local")
    try:
        # One PYU (Party Computing Unit) per party.
        client_pyus = [sf.PYU(party) for party in client_parties]
        server_pyu = sf.PYU("server")

        logging.info("创建客户端和服务器...")
        clients = [
            CauseClient(
                client_id,
                client_num,
                num_blocks=num_blocks,
                num_classes=num_classes,
                in_channels=1,
                device=pyu,
            )
            for client_id, pyu in enumerate(client_pyus)
        ]
        server = CauseServer(
            client_num, num_classes=num_classes, num_blocks=num_blocks, device=server_pyu
        )

        logging.info("使用Cause算法训练模型...")
        # 2 epochs per block to keep the demo fast.
        train_cause(
            clients, server, num_blocks=num_blocks, epochs_per_block=2, dataset_name="mnist"
        )
    finally:
        sf.shutdown()
        logging.info("SecretFlow资源已释放。")


if __name__ == "__main__":
    # Only launch the demo when run as a script, never on import.
    main()

# INFO: 日志将被保存到: logs/cause_run_20250424_220836.log
# INFO: 配置: MNIST数据集, 3个客户端, 3个模型块
# INFO: 初始化SecretFlow...
# INFO: Try init sf in SIMULATION mode
# INFO: set distribution mode to DISTRIBUTION_MODE.SIMULATION
#   self.pid = _posixsubprocess.fork_exec(
#   self.pid = _posixsubprocess.fork_exec(
# 2025-04-24 22:08:39,551	INFO worker.py:1841 -- Started a local Ray instance.
# INFO: 创建客户端和服务器...
# INFO: 使用Cause算法训练模型...
# INFO: shutdown is called, barrier_on_shutdown True, on_error None
# INFO: SecretFlow资源已释放。
