In [1]:
import torch
import torch.nn as nn

import FrEIA.framework as Ff
import FrEIA.modules as Fm

#### 样例代码

In [4]:
input1 = Ff.InputNode(4, name='Input1')
input2 = Ff.InputNode(2, name='Input2')
cond = Ff.ConditionNode(3, name='Condition')

def subnet_fc(in_dim, out_dim):
    return nn.Sequential(
        nn.Linear(in_dim, 16), nn.ReLU(),
        nn.Linear(16, out_dim)
    )

In [7]:
# 定义网络结构
perm = Ff.Node(inputs=input1, module_type=Fm.PermuteRandom, module_args={}, name='PermuteRandom')
split1 = Ff.Node(inputs=perm, module_type=Fm.Split, module_args={}, name='Split 1')
split2 = Ff.Node(inputs=split1.out1, module_type=Fm.Split, module_args={}, name='Split 2')
actnorm = Ff.Node(inputs=split2.out1, module_type=Fm.ActNorm, module_args={}, name='ActNorm')

concat1 = Ff.Node(inputs=[actnorm.out0, input2.out0], module_type=Fm.Concat, module_args={}, name='Concat1')

affine = Ff.Node(
    inputs=concat1, 
    module_type=Fm.AffineCouplingOneSided,
    module_args={ 'subnet_constructor': subnet_fc },
    conditions=cond,
    name='AffineCouplingOneSided'
)

concat2 = Ff.Node(inputs=[split2.out0, affine.out0], module_type=Fm.Concat, module_args={}, name='Concat2')

# 得到两个输出
output1 = Ff.OutputNode(split1.out0, name='Output1')
output2 = Ff.OutputNode(concat2, name='Output2')

INN_Model = Ff.GraphINN(node_list=[
    input1, input2, cond,
    perm, split1, split2,
    actnorm, concat1, affine, concat2,
    output1, output2
])

in1, in2, c = torch.tensor([[2., 4., 6., 8.]]), torch.tensor([[3., 6.]]), torch.tensor([[2., 5., 7.]])

(z1, z2), J = INN_Model([in1, in2], c=c)

(in1_inv, in2_inv), J_inv = INN_Model([z1, z2], c=c, rev=True)

print(in1, in1_inv)
print(in2, in2_inv)
print(J, J_inv)

tensor([[2., 4., 6., 8.]]) tensor([[2., 4., 6., nan]], grad_fn=<IndexBackward0>)
tensor([[3., 6.]]) tensor([[nan, nan]], grad_fn=<SplitWithSizesBackward0>)
tensor([nan], grad_fn=<AddBackward0>) tensor([nan], grad_fn=<AddBackward0>)


### 子网络

In [8]:
def subnet_fc(in_dim, out_dim):
    return nn.Sequential(
        nn.Linear(in_dim, 256), nn.ReLU(),
        nn.Linear(256, out_dim)
    )

def subnet_conv(in_dim, out_dim):
    return nn.Sequential(
        nn.Conv2d(in_dim, 256, kernel_size=3, padding=1), nn.ReLU(),
        nn.Conv2d(256, out_dim, kernel_size=3, padding=1)
    )

def subnet_conv_1x1(in_dim, out_dim):
    return nn.Sequential(
        nn.Conv2d(in_dim, 256, kernel_size=1), nn.ReLU(),
        nn.Conv2d(256, out_dim, kernel_size=1)
    )

### 二维的可逆神经网络

In [12]:
inn = Ff.SequenceINN(2)

for k in range(2):
    inn.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc, permute_soft=True)

print(inn)

SequenceINN(
  (module_list): ModuleList(
    (0): AllInOneBlock(
      (softplus): Softplus(beta=0.5, threshold=20)
      (subnet): Sequential(
        (0): Linear(in_features=1, out_features=256, bias=True)
        (1): ReLU()
        (2): Linear(in_features=256, out_features=2, bias=True)
      )
    )
    (1): AllInOneBlock(
      (softplus): Softplus(beta=0.5, threshold=20)
      (subnet): Sequential(
        (0): Linear(in_features=1, out_features=256, bias=True)
        (1): ReLU()
        (2): Linear(in_features=256, out_features=2, bias=True)
      )
    )
  )
)


### Conditional INN 在Mnist数据集上实验

In [13]:
import FrEIA.framework as FF
import FrEIA.modules as FM

In [16]:
cinn = FF.SequenceINN(28 * 28)
for k in range(2):
    cinn.append(FM.AllInOneBlock, cond=0, cond_shape=(10,), subnet_constructor=subnet_conv_1x1)

print(cinn)

SequenceINN(
  (module_list): ModuleList(
    (0): AllInOneBlock(
      (softplus): Softplus(beta=0.5, threshold=20)
      (subnet): Sequential(
        (0): Conv2d(402, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): ReLU()
        (2): Conv2d(256, 784, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (1): AllInOneBlock(
      (softplus): Softplus(beta=0.5, threshold=20)
      (subnet): Sequential(
        (0): Conv2d(402, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): ReLU()
        (2): Conv2d(256, 784, kernel_size=(1, 1), stride=(1, 1))
      )
    )
  )
)


### Convolutional INN For CIFAR10

In [24]:
nodes = [FF.InputNode(3, 32, 32, name='Input')]
image_len = 3 * 32 * 32

In [25]:
# 较高分辨率卷积部分
for k in range(4):
    nodes.append(FF.Node(
        inputs=nodes[-1], # 取出最后一个结点
        module_type=FM.GLOWCouplingBlock,
        module_args={'subnet_constructor': subnet_conv, 'clamp': 1.2},
        name=f'HigherConv-{k}'
    ))

# 较低分辨率卷积部分
for k in range(12):
    subnet = subnet_conv_1x1 if k % 2 == 0 else subnet_conv

    nodes.append(FF.Node(nodes[-1],FM.GLOWCouplingBlock, module_args={'subnet_constructor': subnet, 'clamp': 1.2}, name=f'LowerConv-{k}'))
    nodes.append(FF.Node(nodes[-1],FM.PermuteRandom, module_args={'seed': k}, name=f'LowerPermute-{k}'))

# 全连接
nodes.append(FF.Node(nodes[-1], FM.Flatten, {}, name='Flatten'))
split_node = FF.Node(
    nodes[-1], 
    FM.Split, 
    {'section_sizes': (image_len // 4, 3 * image_len // 4), 'dim':0},
    name='Split'
)

nodes.append(split_node)

for k in range(12):
    nodes.append(FF.Node(nodes[-1], FM.GLOWCouplingBlock, {'subnet_constructor': subnet_fc, 'clamp': 2.0}, name=f'FC_{k}'))
    nodes.append(FF.Node(nodes[-1], FM.PermuteRandom, {'seed': k }, name=f'FC_Permute{k}'))

nodes.append(FF.Node([nodes[-1].out0, split_node.out1], FM.Concat1d, {'dim': 0}, name='concat'))
nodes.append(FF.OutputNode(nodes[-1], name='output'))

conv_inn = FF.GraphINN(nodes)

### 编写自定义可逆操作

自定义的可逆模块可以写成`FM.InvertibleModule`基类的扩展。有关需求的详细信息，请参阅这类文档。

下面是两个简单的例子，它们说明了自定义模块的定义和使用，并且可以作为基本模板使用。第一种方法是将输入张量的每个维数乘以1或2，以随机但固定的方式选择。第二个是一个条件操作，它接受两个输入，如果条件为正，则交换它们，否则不做任何操作。

注:
- `FM.InvertibleModule`必须用`in_DIM`参数初始化，如果有条件输入，则可以使用`cond_dim`参数初始化。
- `forward()`应该返回一个包含输出的元组(即使只有一个)，当`jac=True`时返回带有额外的`Log_Jacobian_Det`项。

In [41]:
# 定义
class FixedRandomElementwiseMultiply(FM.InvertibleModule):
    def __init__(self, dims_in) -> None:
        super().__init__(dims_in)
        self.random_factor = torch.randint(1, 3, size=(1, dims_in[0][0]))
    
    def forward(self, x, rev=False, jac=True):
        x = x[0]
        if not rev:
            # 前向操作
            x = x * self.random_factor
            log_jacobian_det = self.random_factor.float().log().sum()
        else:
            # 后项操作
            x = x / self.random_factor
            log_jacobian_det = -self.random_factor.float().log().sum()

        return (x,), log_jacobian_det
    
    def output_dims(self, dims_in):
        return dims_in

class ConditionalSwap(FM.InvertibleModule):
    def __init__(self, dims_in, dims_c) -> None:
        super().__init__(dims_in=dims_in, dims_c=dims_c)
        
    def forward(self, x, c, rev=False, jac=True): # c means condition
        x1, x2 = x
        log_jacobian_det = 0

        x1_new = x1 + 0.
        x2_new = x2 + 0.

        for i in range(x1.size(0)):
            x1_new[i] = x1[i] if c[0][i] > 0 else x2[i]
            x2_new[i] = x2[i] if c[0][i] > 0 else x1[i]

        return (x1_new, x2_new), log_jacobian_det
    
    def output_dims(self, dims_in):
        return dims_in
        


In [42]:
# 基础使用
batch_size = 2
in_dimension = 2

net = FF.SequenceINN(in_dimension)
for i in range(2):
    net.append(FixedRandomElementwiseMultiply)

x = torch.randn(batch_size, in_dimension) * 10 // 1

z, det = net(x)

x_rev, det_rev = net(z, rev=True) 

print(x)
print(z)
print(x_rev)
print(det_rev)


tensor([[ 10., -11.],
        [  3., -17.]])
tensor([[ 20., -22.],
        [  6., -34.]])
tensor([[ 10., -11.],
        [  3., -17.]])
tensor(-1.3863)


  x = torch.randn(batch_size, in_dimension) * 10 // 1


In [55]:
# 复杂使用
input1 = FF.InputNode(in_dimension, name='Input1')
input2 = FF.InputNode(in_dimension, name='Input2')

cond = FF.ConditionNode(1, name='ConditionNode')

mult1 = FF.Node(input1.out0, FixedRandomElementwiseMultiply, {}, name='mult1')
cond_swap = FF.Node([mult1.out0, input2.out0], ConditionalSwap, {}, conditions=cond, name='cond_swap')
mult2 = FF.Node(cond_swap.out1, FixedRandomElementwiseMultiply, {}, name='mult2')

output1 = FF.OutputNode(cond_swap.out0, name='output1')
output2 = FF.OutputNode(mult2.out0, name='output2')

inn_net = FF.GraphINN([
    input1, input2, cond,
    mult1, cond_swap, mult2,
    output1, output2
])

x1 = torch.randn(batch_size, in_dimension) * 10 // 1
x2 = torch.randn(batch_size, in_dimension) * 10 // 1
c = torch.randn(batch_size) // 1

(z1, z2), det = inn_net([x1, x2], c=c)
(x1_rev, x2_rev), _ = inn_net([z1, z2], c=c,  rev=True, jac=False)

print(x1)
print(x1_rev)


tensor([[-1., 11.],
        [ 0., -2.]])
tensor([[-1., 11.],
        [ 0., -2.]])


  x1 = torch.randn(batch_size, in_dimension) * 10 // 1
  x2 = torch.randn(batch_size, in_dimension) * 10 // 1
  c = torch.randn(batch_size) // 1
