
Implement metric operators (regression) in SML #384

Closed
Candicepan opened this issue Nov 2, 2023 · 9 comments · Fixed by #425
Labels: enhancement (New feature or request), OSCP (SecretFlow Open Source Contribution Plan)

Comments

Candicepan (Contributor) commented Nov 2, 2023

This issue is a task in the third round of the SecretFlow Open Source Contribution Plan (SF OSCP); community developers are welcome to contribute.
If you are interested in claiming this task but have not yet signed up, please complete registration first.

Task introduction

  • Task name: implement metric operators (regression) in SML
  • Technical direction: SPU/SML
  • Difficulty: advanced 🌟🌟

Detailed requirements

  • Add the following regression metric operators to SML:
    ⅰ. explained_variance_score
    ⅱ. mean_squared_error
    ⅲ. mean_poisson_deviance
    ⅳ. mean_gamma_deviance
    ⅴ. d2_tweedie_score
    For the exact semantics, refer to sklearn.
  • Correctness: make sure the submitted code runs as-is.
  • Code style: Python code must be formatted with black + isort (the CI pipeline has a style-check gate); Bazel files must be formatted with buildifier.
  • A single claim must implement all of the operators above.

If you would like to suggest other algorithms, feel free to reply under this issue.

Skill requirements

  • Familiar with classic machine-learning algorithms
  • Familiar with JAX or NumPy; able to implement the algorithms with NumPy

Instructions

Claiming a task

  • After claiming the task, comment your design outline under this issue
  • Design outline: briefly state which algorithm and which technical approach you plan to use
@Candicepan added the enhancement (New feature or request) and OSCP (SecretFlow Open Source Contribution Plan) labels Nov 2, 2023
@Candicepan changed the title twice (Nov 2 and Nov 14, 2023) to fix spacing and the "metirc" → "metric" typo
tarantula-leo (Contributor) commented:

I'd like to claim this task.

tarantula-leo (Contributor) commented:

@deadlywing

from sklearn import metrics
import jax.numpy as jnp

import spu.spu_pb2 as spu_pb2
import spu.utils.simulation as spsim

def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power):
    p = power
    if p < 0:
        # 'Extreme stable', y any real number, y_pred > 0
        dev = 2 * (
            jnp.power(jnp.maximum(y_true, 0), 2 - p) / ((1 - p) * (2 - p))
            - y_true * jnp.power(y_pred, 1 - p) / (1 - p)
            + jnp.power(y_pred, 2 - p) / (2 - p)
        )
    elif p == 0:
        # Normal distribution, y and y_pred any real number
        dev = (y_true - y_pred) ** 2
    elif p == 1:
        # Poisson distribution
        dev = 2 * (y_true * jnp.log(y_true / y_pred) - y_true + y_pred)
    elif p == 2:
        # Gamma distribution
        dev = 2 * (jnp.log(y_pred / y_true) + y_true / y_pred - 1)
    else:
        dev = 2 * (
            jnp.power(y_true, 2 - p) / ((1 - p) * (2 - p))
            - y_true * jnp.power(y_pred, 1 - p) / (1 - p)
            + jnp.power(y_pred, 2 - p) / (2 - p)
        )
    return jnp.average(dev, weights=sample_weight)

def d2_tweedie_score(y_true, y_pred, sample_weight=None, power=0):
    numerator = _mean_tweedie_deviance(y_true, y_pred, sample_weight=sample_weight, power=power)
    y_avg = jnp.average(y_true, weights=sample_weight)
    denominator = _mean_tweedie_deviance(y_true, y_avg, sample_weight=sample_weight, power=power)
    return 1 - numerator / denominator

def test_d2_tweedie_score():
    y_true = jnp.array([0.5, 1, 2.5, 7])
    y_pred = jnp.array([1, 1, 5, 3.5])
    sim = spsim.Simulator.simple(
        3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
    )
    print(metrics.d2_tweedie_score(y_true, y_pred, sample_weight=jnp.array([1, 1, 1, 1]), power=1))
    print(spsim.sim_jax(sim, d2_tweedie_score, static_argnums=(3, ))(y_true, y_pred, jnp.array([1, 1, 1, 1]), 1))


def explained_variance_score(
    y_true,
    y_pred,
    sample_weight=None,
    multioutput="uniform_average",
):
    y_diff_avg = jnp.average(y_true - y_pred, weights=sample_weight, axis=0)
    numerator = jnp.average(
        (y_true - y_pred - y_diff_avg) ** 2, weights=sample_weight, axis=0
    )

    y_true_avg = jnp.average(y_true, weights=sample_weight, axis=0)
    denominator = jnp.average((y_true - y_true_avg) ** 2, weights=sample_weight, axis=0)
    output_scores = 1 - (numerator / denominator)

    if isinstance(multioutput, str):
        if multioutput == "raw_values":
            # return scores individually
            return output_scores
        elif multioutput == "uniform_average":
            # Passing None as weights to jnp.average results in a uniform mean
            avg_weights = None
        elif multioutput == "variance_weighted":
            avg_weights = denominator
    else:
        avg_weights = multioutput

    return jnp.average(output_scores, weights=avg_weights)

def test_explained_variance_score():
    # y_true = jnp.array([3, -0.5, 2, 7])
    # y_pred = jnp.array([2.5, 0.0, 2, 8])
    y_true = jnp.array([[0.5, 1],[-1, 1],[7, -6]])
    y_pred = jnp.array([[0, 2],[-1, 2],[8, -5]])
    print(metrics.explained_variance_score(y_true, y_pred, multioutput="variance_weighted", force_finite=True))
    sim = spsim.Simulator.simple(
        3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
    )
    print(spsim.sim_jax(sim, explained_variance_score, static_argnums=(3, ))(y_true, y_pred, None, "variance_weighted"))

def mean_squared_error(y_true, y_pred, sample_weight=None, multioutput="uniform_average", squared=True):
    output_errors = jnp.average((y_true - y_pred) ** 2, axis=0, weights=sample_weight)

    if not squared:
        output_errors = jnp.sqrt(output_errors)
    if isinstance(multioutput, str):
        if multioutput == "raw_values":
            return output_errors
        elif multioutput == "uniform_average":
            # pass None as weights to np.average: uniform mean
            multioutput = None

    return jnp.average(output_errors, weights=multioutput)

def test_mean_squared_error():
    # y_true = jnp.array([3, -0.5, 2, 7])
    # y_pred = jnp.array([2.5, 0.0, 2, 8])
    y_true = jnp.array([[0.5, 1],[-1, 1],[7, -6]])
    y_pred = jnp.array([[0, 2],[-1, 2],[8, -5]])
    sim = spsim.Simulator.simple(
        3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
    )
    print(metrics.mean_squared_error(y_true, y_pred, sample_weight=None, squared=False))
    print(spsim.sim_jax(sim, mean_squared_error, static_argnums=(3, 4))(y_true, y_pred, None, "uniform_average", False))


def mean_poisson_deviance(y_true, y_pred, sample_weight=None):
    return _mean_tweedie_deviance(y_true, y_pred, sample_weight=sample_weight, power=1)

def test_mean_poisson_deviance():
    y_true = jnp.array([2, 0, 1, 4])
    y_pred = jnp.array([0.5, 0.5, 2., 2.])
    print(metrics.mean_poisson_deviance(y_true, y_pred))
    sim = spsim.Simulator.simple(
        3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
    )
    print(spsim.sim_jax(sim, mean_poisson_deviance)(y_true, y_pred))


def mean_gamma_deviance(y_true, y_pred, sample_weight=None):
    return _mean_tweedie_deviance(y_true, y_pred, sample_weight=sample_weight, power=2)

def test_mean_gamma_deviance():
    y_true = jnp.array([2, 0.5, 1, 4])
    y_pred = jnp.array([0.5, 0.5, 2., 2.])
    print(metrics.mean_gamma_deviance(y_true, y_pred))
    sim = spsim.Simulator.simple(
        3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
    )
    print(spsim.sim_jax(sim, mean_gamma_deviance)(y_true, y_pred))

deadlywing (Contributor) commented:

Hi, correctness looks fine, but performance can be improved. There are a few common patterns of repeated computation worth thinking about; taking d2_tweedie_score as an example:
1) For some p, part of the numerator and denominator computation is identical (the terms that depend only on y_true). I'm not sure the compiler caches this automatically, so it's best to express the reuse explicitly.
2) Patterns like jnp.power(y_pred, 1 - p) and jnp.power(y_pred, 2 - p) clearly don't need two power calls.
3) The average issue: for the equal-weight case (weights=None), the numerator and denominator only need jnp.sum rather than jnp.average (which saves one truncation).

For the sake of performance the code may get a bit uglier; please think carefully about how to structure the implementation.
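Point 2) above can be sketched as follows. This is a hypothetical rewrite for illustration, not the final SML implementation: since y_pred**(2 - p) equals y_pred**(1 - p) * y_pred, one power call plus one multiplication replaces two power calls.

```python
import jax.numpy as jnp


def tweedie_dev_naive(y_true, y_pred, p):
    # Direct transcription of the general branch: two power calls on y_pred.
    return 2 * (
        jnp.power(y_true, 2 - p) / ((1 - p) * (2 - p))
        - y_true * jnp.power(y_pred, 1 - p) / (1 - p)
        + jnp.power(y_pred, 2 - p) / (2 - p)
    )


def tweedie_dev_fused(y_true, y_pred, p):
    # y_pred**(2-p) == y_pred**(1-p) * y_pred, so one power call suffices;
    # the extra multiplication is far cheaper than a power in MPC.
    y_pred_pow = jnp.power(y_pred, 1 - p)
    return 2 * (
        jnp.power(y_true, 2 - p) / ((1 - p) * (2 - p))
        - y_true * y_pred_pow / (1 - p)
        + y_pred_pow * y_pred / (2 - p)
    )
```

Both variants compute the same deviance vector; only the operator count differs.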

deadlywing (Contributor) commented:

One more point, perhaps subtle: for p = 2, computing jnp.log(y_pred / y_true) directly means this pattern performs two divisions, but since y_true is the same in both places, you actually only need one 1 / y_true plus two multiplications.

deadlywing (Contributor) commented:

If you want to compare how many underlying operators different implementations actually invoke, and their performance, you can turn on some of the profiling switches and experiment yourself.

For simulation, you can simply uncomment the two relevant lines in the simple function (screenshot omitted).
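Since the screenshot is omitted, here is a rough equivalent of enabling those switches by building the RuntimeConfig by hand. The two flag names are assumptions from the SPU proto and may differ across SPU versions; treat this as a sketch, not the exact lines from the screenshot.

```python
import spu.spu_pb2 as spu_pb2
import spu.utils.simulation as spsim

# Assumed field names; check the RuntimeConfig proto of your SPU version.
config = spu_pb2.RuntimeConfig(
    protocol=spu_pb2.ProtocolKind.ABY3,
    field=spu_pb2.FieldType.FM64,
)
config.enable_pphlo_profile = True  # per-pphlo-op counts and timing
config.enable_hal_profile = True    # per-HAL-op counts and timing
sim = spsim.Simulator(3, config)    # use instead of Simulator.simple(...)
```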

tarantula-leo (Contributor) commented:

> One more point, perhaps subtle: for p = 2, computing jnp.log(y_pred / y_true) directly means this pattern performs two divisions, but since y_true is the same in both places, you actually only need one 1 / y_true plus two multiplications.

How should I understand this?

@tarantula-leo
Copy link
Contributor

另外average的问题,对equal weight的情况,np库在为None时,调用的是sum函数,应该是不需要在外部显示地对这种情况进行判断?

deadlywing (Contributor) commented:

> How should I understand this?

The numerator and denominator each compute jnp.log(… / y_true), and y_true is the same in both, so 1 / y_true only needs to be computed once (division is comparatively costly).
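This reciprocal reuse can be sketched as below; a hypothetical rewrite for the gamma case (p = 2), not the final implementation. The single reciprocal of y_true is shared by the numerator and denominator deviances, turning the divisions inside the logs into multiplications.

```python
import jax.numpy as jnp


def d2_gamma_score(y_true, y_pred):
    # One costly reciprocal, reused in both deviance evaluations.
    inv_y_true = 1.0 / y_true

    def gamma_dev(pred):
        # log(pred / y_true) rewritten as log(pred * inv_y_true);
        # the remaining y_true / pred division is unavoidable here.
        return jnp.mean(2 * (jnp.log(pred * inv_y_true) + y_true / pred - 1))

    y_avg = jnp.mean(y_true)
    return 1 - gamma_dev(y_pred) / gamma_dev(y_avg)
```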

deadlywing (Contributor) commented:

> Also, on the average question: for the equal-weight case, when weights is None the np library calls the sum function, so there should be no need to explicitly special-case this externally?

Looking at the JAX implementation, it seems to use the mean function, so the division is still there. You can experiment and check the actual operator calls at the MPC layer (screenshot omitted).
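For the equal-weight case there is an even simpler observation worth noting (a sketch, not necessarily how the final PR does it): in a ratio of two means, the 1/n factors cancel, so summing both sides gives the same score while skipping the divisions hidden inside each mean (each a truncation in MPC).

```python
import jax.numpy as jnp


def d2_from_devs_mean(dev_num, dev_den):
    # Two means: each hides a division by n.
    return 1 - jnp.mean(dev_num) / jnp.mean(dev_den)


def d2_from_devs_sum(dev_num, dev_den):
    # The 1/n factors cancel in the ratio, so sums suffice.
    return 1 - jnp.sum(dev_num) / jnp.sum(dev_den)
```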

deadlywing pushed a commit that referenced this issue Dec 7, 2023: …ce/mean_gamma_deviance/d2_tweedie_score (#425)

# Pull Request

## What problem does this PR solve?

Issue Number: Fixed #384 

## Possible side effects?

- Performance: support
explained_variance_score/mean_squared_error/mean_poisson_deviance/mean_gamma_deviance/d2_tweedie_score

- Backward compatibility:
Status: Done