[OSCP] Add a fed_pac strategy based on the torch backend in SecretFlow #1276

Open
wants to merge 10 commits into base: main
2 changes: 0 additions & 2 deletions secretflow/ml/nn/core/torch/module.py
@@ -23,7 +23,6 @@

from .mixins import ParametersMixin


Collaborator:
Please set up your formatter config; no changes should be made here.

class BaseModule(ParametersMixin, nn.Module):
"""Lightning style base class for your torch neural network models.

@@ -243,7 +242,6 @@ def update_metrics(self, y_pred, y_true):
else:
m.update(y_pred, y_true.int())


class TorchModel:
def __init__(
self,
144 changes: 144 additions & 0 deletions secretflow/ml/nn/fl/backend/torch/fedpac_sampler.py
@@ -0,0 +1,144 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
import math
import random

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def cifar10(stage='train'):
Contributor:
Move the dataset handling into a unit test or a test script.

Collaborator:
+1. Move it into the unit tests using the custom databuilder pattern, or reuse the logic in secretflow/utils/simulation/datasets.py directly (a sketch follows after this function).

if stage == 'train':
transform = transforms.Compose(
[
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(
(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
),
]
)
dataset = datasets.CIFAR10(
root='data', train=True, download=True, transform=transform
)
elif stage == 'eval':
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(
(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
),
]
)

dataset = datasets.CIFAR10(
root='data', train=False, download=True, transform=transform
)
print("CIFAR10 Data Loading...")
return dataset
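As the reviewers above suggest, the CIFAR-10 handling could be moved out of the sampler, either into a unit test or into a custom databuilder supplied to the federated model. A minimal sketch of the databuilder direction follows; the exact callable signature expected by the dataset_builder argument is not shown in this diff, so create_dataset_builder and its inner callable are hypothetical.

from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def create_dataset_builder(batch_size=128, train=True):
    # Hypothetical custom databuilder: returns a per-party callable that builds
    # a CIFAR-10 DataLoader, keeping dataset handling out of the sampler code.
    def dataset_builder(x):  # x: the party's local input, unused for CIFAR-10
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
                ),
            ]
        )
        data = datasets.CIFAR10(
            root='data', train=train, download=True, transform=transform
        )
        return DataLoader(data, batch_size=batch_size, shuffle=train)

    return dataset_builder

Such a builder would then be passed per party (for example dataset_builder={alice: create_dataset_builder(), bob: create_dataset_builder()}, with alice and bob as placeholder PYUs) when fitting the model, instead of hard-coding CIFAR-10 here.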


def batch_sampler(
Collaborator:
batch_sampler is already implemented; just call the one in secretflow/ml/nn/fl/backend/torch/sampler.py (see the sketch after this function).

x,
y,
s_w,
sampling_rate,
buffer_size,
shuffle,
repeat_count,
random_seed,
stage,
**kwargs,
):
"""
implementation of batch sampler

Args:
x: feature, FedNdArray or HDataFrame
y: label, FedNdArray or HDataFrame
s_w: sample weight of this dataset
sampling_rate: Sampling rate of a batch
buffer_size: shuffle size
shuffle: A bool that indicates whether the input should be shuffled
repeat_count: num of repeats
random_seed: Prg seed for shuffling
stage: which CIFAR-10 split to load, 'train' or 'eval'
Returns:
data_set: torch.utils.data.DataLoader
"""
batch_size = kwargs.get("batch_size", math.floor(x.shape[0] * sampling_rate))
assert batch_size > 0, "Invalid batch size"
if random_seed is not None:
random.seed(random_seed)
torch.manual_seed(random_seed) # set random seed for cpu
torch.cuda.manual_seed(random_seed) # set random seed for cuda
torch.backends.cudnn.deterministic = True

dataset = cifar10(stage)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
return dataloader
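Following the reviewer comment on this function, a leaner variant would delegate to the existing sampler rather than re-implement batching. The sketch below assumes that batch_sampler in secretflow/ml/nn/fl/backend/torch/sampler.py accepts the same leading parameters in the same order; that signature is not shown in this diff, so treat the call as an assumption.

from secretflow.ml.nn.fl.backend.torch.sampler import batch_sampler as base_batch_sampler


def batch_sampler(
    x, y, s_w, sampling_rate, buffer_size, shuffle, repeat_count, random_seed, stage, **kwargs
):
    # Reuse the shared implementation instead of duplicating the sampling logic.
    # Any FedPAC-specific, stage-dependent dataset handling would wrap the
    # returned DataLoader here rather than live inside the sampler.
    return base_batch_sampler(
        x, y, s_w, sampling_rate, buffer_size, shuffle, repeat_count, random_seed, **kwargs
    )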


def fedpac_sampler_data(
sampler_method="batch",
x=None,
y=None,
s_w=None,
sampling_rate=None,
buffer_size=None,
shuffle=False,
repeat_count=1,
random_seed=1234,
stage="train",
**kwargs,
):
"""
sample data according to sampler_method

Args:
x: feature, FedNdArray or HDataFrame
y: label, FedNdArray or HDataFrame
s_w: sample weight of this dataset
sampling_rate: Sampling rate of a batch
buffer_size: shuffle size
shuffle: A bool that indicates whether the input should be shuffled
repeat_count: num of repeats
random_seed: Prg seed for shuffling
stage: which CIFAR-10 split to load, 'train' or 'eval'
Returns:
data_set: torch.utils.data.DataLoader
"""
if sampler_method == "batch":
data_set = batch_sampler(
x,
y,
s_w,
sampling_rate,
buffer_size,
shuffle,
repeat_count,
random_seed,
stage,
**kwargs,
)
else:
logging.error(f'Invalid sampler {sampler_method} during building local dataset')
return data_set
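For context, a usage sketch of the entry point above. The feature array is a hypothetical placeholder: only its row count is read when deriving the default batch size (math.floor(50000 * 0.00256) = 128).

import numpy as np

train_loader = fedpac_sampler_data(
    sampler_method="batch",
    x=np.zeros((50000, 1)),  # placeholder features, only shape[0] is used
    sampling_rate=0.00256,   # 50000 * 0.00256 -> batch size 128
    shuffle=True,
    stage="train",
)
batch_x, batch_y = next(iter(train_loader))  # one CIFAR-10 batch of shape [128, 3, 32, 32]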
12 changes: 7 additions & 5 deletions secretflow/ml/nn/fl/backend/torch/fl_base.py
@@ -16,17 +16,19 @@

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Optional, Union
from typing import Callable, Optional, Union, Tuple

import numpy as np
import pandas as pd
import torch
import torchmetrics
import logging

from secretflow.ml.nn.core.torch import BuilderType, module
from secretflow.ml.nn.fl.backend.torch.sampler import sampler_data
from secretflow.ml.nn.metrics import Default, Mean, Precision, Recall
from secretflow.utils.io import rows_count
from secretflow.ml.nn.fl.backend.torch.fedpac_sampler import fedpac_sampler_data


class BaseTorchModel(ABC):
@@ -49,6 +51,7 @@ def __init__(
self.train_set = None
self.eval_set = None
self.skip_bn = skip_bn
# self.dataset_size = torch.tensor(len(self.train_set.dataset)).to(self.exe_device)
Contributor:
Delete this unused code.

Author:
Done

if random_seed is not None:
torch.manual_seed(random_seed)
assert builder_base is not None, "Builder_base cannot be none"
@@ -216,9 +219,9 @@ def get_rows_count(self, filename):

def get_weights(self, return_numpy=True):
if self.skip_bn:
return self.model.get_weights_not_bn(return_numpy=return_numpy)
return self.model.get_weights_not_bn(return_numpy=True)
Contributor:
Why is return_numpy hard-coded here? (a sketch of the alternative follows below)

else:
return self.model.get_weights(return_numpy=return_numpy)
return self.model.get_weights(return_numpy=True)
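A minimal sketch of the alternative the reviewer is pointing at: keep forwarding the caller's flag, exactly as the deleted lines did, instead of pinning return_numpy to True.

def get_weights(self, return_numpy=True):
    # forward the caller's flag rather than hard-coding True
    if self.skip_bn:
        return self.model.get_weights_not_bn(return_numpy=return_numpy)
    else:
        return self.model.get_weights(return_numpy=return_numpy)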

def set_weights(self, weights):
"""set weights of client model"""
@@ -245,7 +248,6 @@ def wrap_local_metrics(self, stage="train"):

correct = float((tp + tn).numpy().sum())
total = float((tp + tn + fp + fn).numpy().sum())

wraped_metrics.append(Mean(name, correct, total))

elif isinstance(m, torchmetrics.Precision):
@@ -312,6 +314,7 @@ def evaluate(self, evaluate_steps=0):
self.model.validation_step((x, y), step, sample_weight=s_w)
result = {}
self.transform_metrics(result, stage="eval")

if self.logs is None:
self.wrapped_metrics.extend(self.wrap_local_metrics())
return self.wrap_local_metrics()
@@ -376,7 +379,6 @@ def on_epoch_end(self, epoch):
for k, v in self.epoch_logs.items():
self.history.setdefault(k, []).append(v)
self.training_logs = self.epoch_logs

return self.epoch_logs

def transform_metrics(self, logs, stage="train"):
2 changes: 2 additions & 0 deletions secretflow/ml/nn/fl/backend/torch/strategy/__init__.py
@@ -19,6 +19,7 @@
from .fed_scr import PYUFedSCR
from .fed_stc import PYUFedSTC
from .scaffold import PYUScaffold
from .fed_pac import PYUFedPAC

__all__ = [
'PYUFedAvgW',
@@ -28,4 +29,5 @@
'PYUFedSCR',
'PYUFedSTC',
'PYUScaffold',
'PYUFedPAC',
]
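Once exported here, the new strategy would be selected by name when building a federated model. The snippet below is a hypothetical usage sketch: it assumes the strategy registers under the name "fed_pac" (mirroring the naming of the sibling strategies in this package), and alice, bob, charlie and model_def are placeholder PYUs and model builder, not taken from this PR.

from secretflow.ml.nn import FLModel

fl_model = FLModel(
    server=charlie,            # placeholder PYU acting as the server
    device_list=[alice, bob],  # placeholder client PYUs
    model=model_def,           # TorchModel builder for the FedPAC network
    strategy="fed_pac",        # assumed registration name for PYUFedPAC
    backend="torch",
)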