
Dropout layer lacks a mode parameter, so Paddle's Dropout(mode="downscale_in_infer") behavior cannot be reproduced #31

Closed
moshizhiyin opened this issue Oct 25, 2022 · 6 comments


@moshizhiyin

No description provided.

@moshizhiyin
Author

Paddle's Dropout(mode="downscale_in_infer") behavior:
mode (str, optional): ['upscale_in_train' (default) | 'downscale_in_infer']

1. upscale_in_train (default): upscale the output at training time

   - train: out = input * mask / (1.0 - p)
   - inference: out = input

2. downscale_in_infer: downscale the output at inference

   - train: out = input * mask
   - inference: out = input * (1.0 - p)
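For concreteness, here is a minimal NumPy sketch of the two conventions (the values of x, mask, and p are illustrative, not taken from Paddle's source):

```python
import numpy as np

p = 0.5
x = np.array([1.0, 2.0, 3.0, 4.0])
mask = np.array([1.0, 0.0, 1.0, 1.0])  # a fixed Bernoulli(1 - p) draw, for illustration

# upscale_in_train: compensate during training, identity at inference
train_up = x * mask / (1.0 - p)   # [2., 0., 6., 8.]
infer_up = x                      # [1., 2., 3., 4.]

# downscale_in_infer: raw mask during training, compensate at inference
train_down = x * mask             # [1., 0., 3., 4.]
infer_down = x * (1.0 - p)        # [0.5, 1., 1.5, 2.]
```

Both conventions preserve the same expected activation; they differ only in whether the 1 / (1.0 - p) factor is applied at training or at inference time.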

@hanjr92
Member

hanjr92 commented Oct 28, 2022

I compared the dropout implementations in TensorFlow, PyTorch, and MindSpore; none of them uses downscale_in_infer. I suggest testing whether changing downscale_in_infer to upscale_in_train, with the model otherwise unchanged, affects the results.

@moshizhiyin
Author


I tested it: the inference results are unaffected, but the layer outputs differ; they are just scaled by a constant factor overall.
Paddle's downscale_in_infer:
Tensor(shape=[1, 2048], dtype=float32, place=Place(cpu), stop_gradient=False,
[[0.38661757, 1.65972424, 0.38451916, ..., 0.25652692, 0.24965537,
0.83536357]])
TLX's:
Tensor(shape=[1, 2048], dtype=float32, place=Place(cpu), stop_gradient=False,
[[0.77323514, 3.31944847, 0.76903832, ..., 0.51305383, 0.49931073,
1.67072713]])
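The scale factor is exactly 1 / (1 - p): Paddle's downscale_in_infer returns input * (1 - p) at inference, while TLX's upscale_in_train returns the input unchanged. A quick check against the first printed element (assuming p = 0.5, which the exact 2x ratio implies but the thread does not state):

```python
# p = 0.5 is an assumption inferred from the 2x ratio above, not stated in the thread
p = 0.5
paddle_first = 0.38661757   # first element of the Paddle (downscale_in_infer) output
tlx_first = 0.77323514      # first element of the TLX (upscale_in_train) output
assert abs(tlx_first * (1.0 - p) - paddle_first) < 1e-7
```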

@hanjr92
Member

hanjr92 commented Oct 31, 2022

Because a change to the interface parameters has to be made together with the backends: if we add a downscale_in_infer parameter to the interface, then TF, PyTorch, and MindSpore also have to implement the corresponding behavior, and that change involves modifying the underlying libraries' source code. Unless you absolutely must use downscale_in_infer, we recommend training with the upscale_in_train mode.

@moshizhiyin
Author


OK, understood. Thanks for the explanation.

@moshizhiyin
Author

I made a modification myself to support downscale_in_infer:

```python
import tensorlayerx as tlx
from tensorlayerx import logging
from tensorlayerx.nn.core import Module


class Dropout(Module):
    """
    During training, randomly zeroes some of the elements of the input tensor
    with probability p using samples from a Bernoulli distribution.
    Each channel will be zeroed out independently on every forward call.

    Parameters
    ----------
    p : float
        probability of an element to be zeroed. Default: 0.5
    seed : int or None
        The seed for random dropout.
    mode : str
        'upscale_in_train' (default) or 'downscale_in_infer'. Selects where the
        1 / (1 - p) compensation is applied, matching Paddle's Dropout modes.
    name : None or str
        A unique layer name.

    Examples
    --------
    >>> net = tlx.nn.Input([10, 200])
    >>> net = tlx.nn.Dropout(p=0.2)(net)

    """

    def __init__(self, p=0.5, seed=0, mode="upscale_in_train", name=None):
        super(Dropout, self).__init__(name)
        self.p = p
        self.seed = seed
        self.mode = mode

        if mode not in ('downscale_in_infer', 'upscale_in_train'):
            raise ValueError(
                "mode argument should be 'downscale_in_infer' or 'upscale_in_train'")
        self.build()
        self._built = True

        logging.info("Dropout %s: p: %f " % (self.name, self.p))

    def __repr__(self):
        s = '{classname}(p={p}'
        if self.name is not None:
            s += ', name=\'{name}\''
        s += ')'
        return s.format(classname=self.__class__.__name__, **self.__dict__)

    def build(self, inputs_shape=None):
        self.dropout = tlx.ops.Dropout(p=self.p, seed=self.seed)

    def forward(self, inputs):
        if self.is_train:
            # tlx.ops.Dropout upscales by 1 / (1 - p); multiplying by (1 - p)
            # cancels that factor, leaving the raw masked output (input * mask).
            outputs = self.dropout(inputs)
            if self.mode != 'upscale_in_train':
                outputs = outputs * (1.0 - self.p)
        else:
            # Inference: identity for upscale_in_train; for downscale_in_infer,
            # scale the input by (1 - p).
            outputs = inputs
            if self.mode != 'upscale_in_train':
                outputs = outputs * (1.0 - self.p)
        if not self._nodes_fixed and self._build_graph:
            self._add_node(inputs, outputs)
            self._nodes_fixed = True
        return outputs
```
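For reference, usage of the patched class might look like this (a sketch assuming the constructor defined above; `Dropout` here is the modified class, not TensorLayerX's released layer):

```python
import tensorlayerx as tlx

ni = tlx.nn.Input([10, 200])
# Paddle-compatible: raw mask in training, output scaled by (1 - p) at inference
out_paddle_style = Dropout(p=0.2, mode="downscale_in_infer")(ni)
# default: matches the TF / PyTorch / MindSpore convention
out_default = Dropout(p=0.2, mode="upscale_in_train")(ni)
```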

@hanjr92 hanjr92 closed this as completed Apr 13, 2023