In [46]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import math

In [47]:
x = torch.tensor([[1,0,0,0]], dtype=torch.float32)

In [48]:
class GeLU(nn.Module):
    def forward(self, x):
        return 0.5 * x * (
            1 + torch.tanh(math.sqrt(2 / math.pi) * (
                x + 0.044715 * torch.pow(x, 3))))

In [159]:
class RevNetBlock(nn.Module):
    
    def __init__(self, d_in, d_out, dropout=0.1, lol=[]):
        super(RevNetBlock, self).__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.dropout = dropout

        layers = []
        if lol == list():
            layers.append(nn.LayerNorm((d_in,d_out)))
            layers.append(nn.Linear(d_in, d_out))
            layers.append(GeLU())
            layers.append(nn.Linear(d_in, d_out))
        else:
            for layer in lol:
                layers.append(layer)
        
        self.bottleneck_block = nn.Sequential(*layers)
    
    def forward(self, x):
        x = torch.cat((x, x), dim=1)
        x1, x2 = self.split(x)
        Fx2 = self.bottleneck_block(x2)
        y1 = Fx2 + x1
        return (x2, y1)
    
    def inverse(self, x):
        x2, y1 = x[0], x[1]
        Fx2 = - self.bottleneck_block(x2)
        x1 = Fx2 + y1
        return (x1, x2)

    @staticmethod
    def split(x):
        n = int(x.size()[1] / 2)
        x1 = x[:, :n].contiguous()
        x2 = x[:, n:].contiguous()
        return (x1, x2)

In [160]:
rb = RevNetBlock(d_in=4, d_out=4)
output = rb(x)
print(f'forward: {output}\n')
input_ = rb.inverse(output)
print(f'inverse: {input_}')

forward: (tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]]), tensor([[ 0.4402,  0.3203,  0.4012,  0.4259],
        [-0.2684,  1.1148,  0.3594, -0.0601],
        [-0.8052,  0.1181,  1.4960, -0.2988],
        [-0.5242,  0.2818,  0.2205,  0.5753]], grad_fn=<AddBackward0>))

inverse: (tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]], grad_fn=<AddBackward0>), tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]]))


In [None]:
class RevNet(nn.Module):
    def __init__(self, d_in, d_out, dropout=0.0):
        super(RevNet, self).__init__()
        self.d_in = d_in 
        self.d_out = d_out
        self.dropout = dropout
    

In [56]:
d_in = d_out = 4

layers = []
layers.append(nn.LayerNorm((d_in,d_out)))
layers.append(nn.Linear(d_in, d_out))
layers.append(GeLU())
layers.append(nn.Linear(d_in, d_out))

model = nn.Sequential(*layers)

In [61]:
x = [[1,0,0,0], [0,1,0,0],
    [0,0,1,0], [0,0,0,1]]
x = torch.tensor(x, dtype=torch.float32)

model(x)

tensor([[-0.1839,  0.4475, -0.4725,  0.0502],
        [-0.4128,  0.6098, -0.6251,  0.4258],
        [ 0.0905,  0.5267, -0.1511, -0.0575],
        [-0.0612,  0.4543, -0.2061, -0.1074]], grad_fn=<AddmmBackward>)

In [38]:
layers = []
layers.append(nn.LayerNorm((4,4)))
layers.append(F.gelu(nn.Linear(4, 4)))
#layers.append(nn.ReLU())
layers.append(nn.Linear(4, 4))

TypeError: gelu(): argument 'input' (position 1) must be Tensor, not Linear

In [None]:

        layers.append(nn.Linear(d_in, d_out))
        layers.append(nn.ReLU())
        layers.append(nn.Linear(4,4)))