# $\mathrm{Net\cdot{Work^2}{\cdot}Shop}$

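# Verifying NaiveAugmentedReLUFunction

Runs the forward and backward passes of `CustomReLUFunction` and `NaiveAugmentedReLUFunction` on the same random input and upstream gradient, and inspects the gradients each backward pass returns.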

In [1]:
import torch

from cgtnnlib.nn.CustomReLUFunction import CustomReLUFunction
from cgtnnlib.nn.MockCtx import MockCtx
from cgtnnlib.nn.NaiveAugmentedReLUFunction import NaiveAugmentedReLUFunction

In [2]:
# Seed torch's global RNG so the random tensors below (and any random masks
# drawn later in the notebook, e.g. the Bernoulli mask in CustomReLUFunction)
# are reproducible under Restart Kernel -> Run All.
SEED = 0
SHAPE = (8, 8)
torch.manual_seed(SEED)

# NOTE: despite its name, `grad_input` is passed as the *forward input* to the
# ReLU functions below; `grad_output` is the upstream gradient fed to backward.
# The names are kept because later cells reference them.
grad_output = torch.rand(*SHAPE)
grad_input = torch.rand(*SHAPE)

(grad_input, grad_output)

(tensor([[0.0693, 0.1635, 0.0511, 0.0341, 0.1995, 0.8795, 0.4098, 0.2278],
         [0.0240, 0.3959, 0.3590, 0.6761, 0.6885, 0.3066, 0.2700, 0.9472],
         [0.4364, 0.2298, 0.8921, 0.1717, 0.6475, 0.3974, 0.6687, 0.0456],
         [0.2797, 0.0583, 0.3934, 0.1055, 0.2190, 0.0703, 0.7915, 0.2422],
         [0.1770, 0.9514, 0.6873, 0.7804, 0.5517, 0.2634, 0.4233, 0.1188],
         [0.1461, 0.6387, 0.8805, 0.9977, 0.9144, 0.7353, 0.4796, 0.4035],
         [0.2575, 0.2331, 0.9117, 0.5944, 0.9993, 0.6453, 0.7162, 0.0062],
         [0.6388, 0.7177, 0.4759, 0.4904, 0.5450, 0.3115, 0.4406, 0.6359]]),
 tensor([[0.5229, 0.8419, 0.4337, 0.1426, 0.4458, 0.5424, 0.4803, 0.6478],
         [0.9268, 0.7476, 0.6653, 0.1772, 0.2031, 0.0559, 0.5314, 0.6403],
         [0.7799, 0.8239, 0.4454, 0.7408, 0.2772, 0.8566, 0.0632, 0.2070],
         [0.3469, 0.7670, 0.1191, 0.6810, 0.1008, 0.5797, 0.3060, 0.4499],
         [0.3226, 0.0960, 0.5601, 0.4586, 0.0738, 0.0412, 0.5843, 0.6373],
         [0.7225, 0.514

In [3]:
ctx = MockCtx()

p = 0.5
# `grad_input` (from the setup cell) is used here as the forward input.
CustomReLUFunction.forward(ctx, grad_input, p)
custom_grads = CustomReLUFunction.backward(ctx, grad_output)

# Sanity check against the traced behavior: each column of the returned
# gradient is either copied unchanged from `grad_output` or zeroed entirely by
# the random Bernoulli column mask. The input comes from torch.rand, so it is
# (almost surely) positive and the ReLU mask zeroes nothing.
custom_grad_input, custom_grad_p = custom_grads
kept_cols = custom_grad_input.ne(0).any(dim=0)
assert torch.equal(custom_grad_input[:, kept_cols], grad_output[:, kept_cols])
assert custom_grad_p is None  # no gradient w.r.t. the scalar p

custom_grads

>>> grad_input[input <= 0] = 0
>>> grad_input
tensor([[0.5229, 0.8419, 0.4337, 0.1426, 0.4458, 0.5424, 0.4803, 0.6478],
        [0.9268, 0.7476, 0.6653, 0.1772, 0.2031, 0.0559, 0.5314, 0.6403],
        [0.7799, 0.8239, 0.4454, 0.7408, 0.2772, 0.8566, 0.0632, 0.2070],
        [0.3469, 0.7670, 0.1191, 0.6810, 0.1008, 0.5797, 0.3060, 0.4499],
        [0.3226, 0.0960, 0.5601, 0.4586, 0.0738, 0.0412, 0.5843, 0.6373],
        [0.7225, 0.5149, 0.1602, 0.4981, 0.0335, 0.3556, 0.8084, 0.2133],
        [0.6975, 0.4354, 0.2636, 0.1562, 0.3315, 0.3882, 0.0950, 0.0219],
        [0.8773, 0.8281, 0.7637, 0.8880, 0.4258, 0.3019, 0.2698, 0.2927]])
>>> bernoulli_mask = torch.bernoulli(torch.ones(grad_input.size(1), device=grad_output.device) * (1 - p.item()))
>>> bernoulli_mask
tensor([1., 0., 0., 1., 1., 1., 0., 1.])
>>> diagonal_mask = torch.diag(bernoulli_mask)
>>> diagonal_mask
tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.

(tensor([[0.5229, 0.0000, 0.0000, 0.1426, 0.4458, 0.5424, 0.0000, 0.6478],
         [0.9268, 0.0000, 0.0000, 0.1772, 0.2031, 0.0559, 0.0000, 0.6403],
         [0.7799, 0.0000, 0.0000, 0.7408, 0.2772, 0.8566, 0.0000, 0.2070],
         [0.3469, 0.0000, 0.0000, 0.6810, 0.1008, 0.5797, 0.0000, 0.4499],
         [0.3226, 0.0000, 0.0000, 0.4586, 0.0738, 0.0412, 0.0000, 0.6373],
         [0.7225, 0.0000, 0.0000, 0.4981, 0.0335, 0.3556, 0.0000, 0.2133],
         [0.6975, 0.0000, 0.0000, 0.1562, 0.3315, 0.3882, 0.0000, 0.0219],
         [0.8773, 0.0000, 0.0000, 0.8880, 0.4258, 0.3019, 0.0000, 0.2927]]),
 None)

In [86]:
ctx = MockCtx()

p = 0.8
# `grad_input` (from the setup cell) is used here as the forward input.
NaiveAugmentedReLUFunction.forward(ctx, grad_input, p)
naive_grads = NaiveAugmentedReLUFunction.backward(ctx, grad_output)

# Expected backward, per the steps this cell traces when run: zero the upstream
# gradient where the forward input is <= 0 (ReLU), then scale the rest by p.
expected_naive_grad = torch.where(
    grad_input > 0, grad_output * p, torch.zeros_like(grad_output)
)
naive_grad_input, naive_grad_p = naive_grads
assert torch.allclose(naive_grad_input, expected_naive_grad)
assert naive_grad_p is None  # no gradient w.r.t. the scalar p

naive_grads

<<< grad_input[input <= 0] = 0
<<< grad_input
tensor([[0.5229, 0.8419, 0.4337, 0.1426, 0.4458, 0.5424, 0.4803, 0.6478],
        [0.9268, 0.7476, 0.6653, 0.1772, 0.2031, 0.0559, 0.5314, 0.6403],
        [0.7799, 0.8239, 0.4454, 0.7408, 0.2772, 0.8566, 0.0632, 0.2070],
        [0.3469, 0.7670, 0.1191, 0.6810, 0.1008, 0.5797, 0.3060, 0.4499],
        [0.3226, 0.0960, 0.5601, 0.4586, 0.0738, 0.0412, 0.5843, 0.6373],
        [0.7225, 0.5149, 0.1602, 0.4981, 0.0335, 0.3556, 0.8084, 0.2133],
        [0.6975, 0.4354, 0.2636, 0.1562, 0.3315, 0.3882, 0.0950, 0.0219],
        [0.8773, 0.8281, 0.7637, 0.8880, 0.4258, 0.3019, 0.2698, 0.2927]])
<<< grad_input = grad_input * p
<<< grad_input
tensor([[0.4183, 0.6735, 0.3469, 0.1140, 0.3566, 0.4339, 0.3843, 0.5182],
        [0.7414, 0.5981, 0.5322, 0.1418, 0.1625, 0.0447, 0.4251, 0.5122],
        [0.6239, 0.6591, 0.3563, 0.5927, 0.2218, 0.6852, 0.0506, 0.1656],
        [0.2775, 0.6136, 0.0953, 0.5448, 0.0806, 0.4637, 0.2448, 0.3599],
        [0.2581, 0

(tensor([[0.4183, 0.6735, 0.3469, 0.1140, 0.3566, 0.4339, 0.3843, 0.5182],
         [0.7414, 0.5981, 0.5322, 0.1418, 0.1625, 0.0447, 0.4251, 0.5122],
         [0.6239, 0.6591, 0.3563, 0.5927, 0.2218, 0.6852, 0.0506, 0.1656],
         [0.2775, 0.6136, 0.0953, 0.5448, 0.0806, 0.4637, 0.2448, 0.3599],
         [0.2581, 0.0768, 0.4481, 0.3669, 0.0590, 0.0330, 0.4675, 0.5098],
         [0.5780, 0.4119, 0.1282, 0.3985, 0.0268, 0.2845, 0.6467, 0.1707],
         [0.5580, 0.3484, 0.2109, 0.1250, 0.2652, 0.3106, 0.0760, 0.0175],
         [0.7018, 0.6625, 0.6110, 0.7104, 0.3406, 0.2415, 0.2158, 0.2341]]),
 None)

In [33]:
# Re-display the tensor that was passed as the forward input above; compare it
# with the setup cell's output to check that it was not modified in place.
grad_input

tensor([[0.0693, 0.1635, 0.0511, 0.0341, 0.1995, 0.8795, 0.4098, 0.2278],
        [0.0240, 0.3959, 0.3590, 0.6761, 0.6885, 0.3066, 0.2700, 0.9472],
        [0.4364, 0.2298, 0.8921, 0.1717, 0.6475, 0.3974, 0.6687, 0.0456],
        [0.2797, 0.0583, 0.3934, 0.1055, 0.2190, 0.0703, 0.7915, 0.2422],
        [0.1770, 0.9514, 0.6873, 0.7804, 0.5517, 0.2634, 0.4233, 0.1188],
        [0.1461, 0.6387, 0.8805, 0.9977, 0.9144, 0.7353, 0.4796, 0.4035],
        [0.2575, 0.2331, 0.9117, 0.5944, 0.9993, 0.6453, 0.7162, 0.0062],
        [0.6388, 0.7177, 0.4759, 0.4904, 0.5450, 0.3115, 0.4406, 0.6359]])