-
Notifications
You must be signed in to change notification settings - Fork 365
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【OSCP】 在 SecretFlow 中添加基于torch后端的fed_pac策略 #1276
base: main
Are you sure you want to change the base?
Changes from all commits
a9fe1a9
d9761d4
3a3a800
b92fcf0
bdc9334
396deb9
da49808
745f338
4460a85
12add29
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
#!/usr/bin/env python3 | ||
# *_* coding: utf-8 *_* | ||
|
||
# Copyright 2022 Ant Group Co., Ltd. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import logging | ||
import math | ||
import random | ||
|
||
import torch | ||
from torch.utils.data import DataLoader | ||
from torchvision import datasets, transforms | ||
from torch.utils.data import DataLoader | ||
|
||
|
||
def cifar10(stage='train'): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 数据集处理写到单测,或者测试脚本 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 挪到单测里使用custom databuilder模式,或者直接使用datasets里面的逻辑 |
||
if stage == 'train': | ||
transform = transforms.Compose( | ||
[ | ||
transforms.RandomCrop(32, padding=4), | ||
transforms.RandomHorizontalFlip(), | ||
transforms.ToTensor(), | ||
transforms.Normalize( | ||
(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) | ||
), | ||
] | ||
) | ||
dataset = datasets.CIFAR10( | ||
root='data', train=True, download=True, transform=transform | ||
) | ||
elif stage == 'eval': | ||
transform = transforms.Compose( | ||
[ | ||
transforms.ToTensor(), | ||
transforms.Normalize( | ||
(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) | ||
), | ||
] | ||
) | ||
|
||
dataset = datasets.CIFAR10( | ||
root='data', train=False, download=True, transform=transform | ||
) | ||
print("CIFAR10 Data Loading...") | ||
return dataset | ||
|
||
|
||
def batch_sampler( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. batch_sampler已有实现,直接调用即可 |
||
x, | ||
y, | ||
s_w, | ||
sampling_rate, | ||
buffer_size, | ||
shuffle, | ||
repeat_count, | ||
random_seed, | ||
stage, | ||
**kwargs, | ||
): | ||
""" | ||
implementation of batch sampler | ||
|
||
Args: | ||
x: feature, FedNdArray or HDataFrame | ||
y: label, FedNdArray or HDataFrame | ||
s_w: sample weight of this dataset | ||
sampling_rate: Sampling rate of a batch | ||
buffer_size: shuffle size | ||
shuffle: A bool that indicates whether the input should be shuffled | ||
repeat_count: num of repeats | ||
random_seed: Prg seed for shuffling | ||
Returns: | ||
data_set: tf.data.Dataset | ||
""" | ||
batch_size = kwargs.get("batch_size", math.floor(x.shape[0] * sampling_rate)) | ||
assert batch_size > 0, "Unvalid batch size" | ||
if random_seed is not None: | ||
random.seed(random_seed) | ||
torch.manual_seed(random_seed) # set random seed for cpu | ||
torch.cuda.manual_seed(random_seed) # set random seed for cuda | ||
torch.backends.cudnn.deterministic = True | ||
|
||
dataset = cifar10(stage) | ||
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) | ||
return dataloader | ||
|
||
|
||
def fedpac_sampler_data( | ||
sampler_method="batch", | ||
x=None, | ||
y=None, | ||
s_w=None, | ||
sampling_rate=None, | ||
buffer_size=None, | ||
shuffle=False, | ||
repeat_count=1, | ||
random_seed=1234, | ||
stage="train", | ||
**kwargs, | ||
): | ||
""" | ||
do sample data by sampler_method | ||
|
||
Args: | ||
x: feature, FedNdArray or HDataFrame | ||
y: label, FedNdArray or HDataFrame | ||
s_w: sample weight of this dataset | ||
sampling_rate: Sampling rate of a batch | ||
buffer_size: shuffle size | ||
shuffle: A bool that indicates whether the input should be shuffled | ||
repeat_count: num of repeats | ||
random_seed: Prg seed for shuffling | ||
Returns: | ||
data_set: tf.data.Dataset | ||
""" | ||
if sampler_method == "batch": | ||
data_set = batch_sampler( | ||
x, | ||
y, | ||
s_w, | ||
sampling_rate, | ||
buffer_size, | ||
shuffle, | ||
repeat_count, | ||
random_seed, | ||
stage, | ||
**kwargs, | ||
) | ||
else: | ||
logging.error(f'Unvalid sampler {sampler_method} during building local dataset') | ||
return data_set |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,17 +16,19 @@ | |
|
||
from abc import ABC, abstractmethod | ||
from pathlib import Path | ||
from typing import Callable, Optional, Union | ||
from typing import Callable, Optional, Union, Tuple | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import torch | ||
import torchmetrics | ||
import logging | ||
|
||
from secretflow.ml.nn.core.torch import BuilderType, module | ||
from secretflow.ml.nn.fl.backend.torch.sampler import sampler_data | ||
from secretflow.ml.nn.metrics import Default, Mean, Precision, Recall | ||
from secretflow.utils.io import rows_count | ||
from secretflow.ml.nn.fl.backend.torch.fedpac_sampler import fedpac_sampler_data | ||
|
||
|
||
class BaseTorchModel(ABC): | ||
|
@@ -49,6 +51,7 @@ def __init__( | |
self.train_set = None | ||
self.eval_set = None | ||
self.skip_bn = skip_bn | ||
# self.dataset_size = torch.tensor(len(self.train_set.dataset)).to(self.exe_device) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 无用代码删掉 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
if random_seed is not None: | ||
torch.manual_seed(random_seed) | ||
assert builder_base is not None, "Builder_base cannot be none" | ||
|
@@ -216,9 +219,9 @@ def get_rows_count(self, filename): | |
|
||
def get_weights(self, return_numpy=True): | ||
if self.skip_bn: | ||
return self.model.get_weights_not_bn(return_numpy=return_numpy) | ||
return self.model.get_weights_not_bn(return_numpy=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里 return_numpy 为什么写死? |
||
else: | ||
return self.model.get_weights(return_numpy=return_numpy) | ||
return self.model.get_weights(return_numpy=True) | ||
|
||
def set_weights(self, weights): | ||
"""set weights of client model""" | ||
|
@@ -245,7 +248,6 @@ def wrap_local_metrics(self, stage="train"): | |
|
||
correct = float((tp + tn).numpy().sum()) | ||
total = float((tp + tn + fp + fn).numpy().sum()) | ||
|
||
wraped_metrics.append(Mean(name, correct, total)) | ||
|
||
elif isinstance(m, torchmetrics.Precision): | ||
|
@@ -312,6 +314,7 @@ def evaluate(self, evaluate_steps=0): | |
self.model.validation_step((x, y), step, sample_weight=s_w) | ||
result = {} | ||
self.transform_metrics(result, stage="eval") | ||
|
||
if self.logs is None: | ||
self.wrapped_metrics.extend(self.wrap_local_metrics()) | ||
return self.wrap_local_metrics() | ||
|
@@ -376,7 +379,6 @@ def on_epoch_end(self, epoch): | |
for k, v in self.epoch_logs.items(): | ||
self.history.setdefault(k, []).append(v) | ||
self.training_logs = self.epoch_logs | ||
|
||
return self.epoch_logs | ||
|
||
def transform_metrics(self, logs, stage="train"): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
格式化配置一下,这里不要做修改