# In [None]:
class Flatten(nn.Module):
    """Collapse every dimension after the batch axis into a single one.

    A tensor of shape ``(N, d1, d2, ...)`` becomes ``(N, d1 * d2 * ...)``.
    ``Tensor.view`` is used, so the input must be viewable without a copy
    (e.g. contiguous); otherwise PyTorch raises.
    """

    def forward(self, input):
        batch_size = input.shape[0]
        flat = input.view(batch_size, -1)
        return flat
    
    
class SplitTensor(nn.Module):
    """Slice the input along dimension 1 into consecutive equal-width chunks.

    Args:
        n_tensors: number of chunks to produce.
        tensor_size: width of each chunk along dimension 1.

    ``forward`` returns a list of ``n_tensors`` slices
    ``input[:, i * tensor_size:(i + 1) * tensor_size]``. If dimension 1 is
    shorter than ``n_tensors * tensor_size``, the trailing slices are truncated
    or empty, following ordinary slicing semantics.
    """

    def __init__(self, n_tensors, tensor_size):
        # nn.Module.__init__ must run first: __call__ relies on the hook and
        # parameter registries it creates.
        super().__init__()
        self.n_tensors = n_tensors
        self.tensor_size = tensor_size

    def forward(self, input):
        size = self.tensor_size
        # Iterate over range(n_tensors) -- n_tensors is an int, not an iterable.
        return [input[:, i * size:(i + 1) * size] for i in range(self.n_tensors)]
    
class CNN(nn.Module):
    """Sequential network assembled from a list of layer-spec dicts.

    Each spec must have a ``"type"`` key:

    * ``"conv"``    -- ``nn.Conv3d`` built from ``"in"``, ``"out"``, ``"ker"``,
      ``"str"`` and ``"pad"``.
    * ``"lin"``     -- ``nn.Linear(spec["in"], spec["out"])``.
    * ``"flatten"`` -- ``Flatten()`` (defined elsewhere in this file).
    * ``"split"``   -- adds no module; makes ``forward`` slice the final output
      along dim 1 into ``"n_tensors"`` chunks of width ``"tensor_size"``.

    Optional keys, appended in this order after the spec's main layer:

    * ``"act"``  (conv/lin only): ``"relu"`` -> ``nn.ReLU()``,
      ``"leakyrelu"`` -> ``nn.LeakyReLU(0.2)``; any other value adds nothing.
    * ``"max"``: ``nn.MaxPool3d(max, stride=max, ceil_mode=True)``; skipped when
      absent or ``None``.
    * ``"norm"``: ``"layer"`` -> ``nn.LayerNorm(spec["ln"])``; other values or
      ``None`` add nothing.
    * ``"drop"``: ``nn.Dropout(drop)``; skipped when absent or ``None``.

    Raises:
        ValueError: for an unknown ``"type"``. This is a subclass of
            ``Exception``, so existing ``except Exception`` handlers still apply.
    """

    def __init__(self, conv_layers):
        super().__init__()

        modules = []
        self.split = False

        for layer in conv_layers:
            kind = layer["type"]
            if kind == "conv":
                modules.append(nn.Conv3d(in_channels=layer["in"], out_channels=layer["out"],
                                         kernel_size=layer["ker"], stride=layer["str"],
                                         padding=layer["pad"]))
                modules.extend(self._activation(layer.get("act")))
            elif kind == "lin":
                modules.append(nn.Linear(layer["in"], layer["out"]))
                modules.extend(self._activation(layer.get("act")))
            elif kind == "flatten":
                modules.append(Flatten())
            elif kind == "split":
                # Splitting happens in forward(); remember the chunk layout.
                self.split = True
                self.n_tensors = layer["n_tensors"]
                self.tensor_size = layer["tensor_size"]
            else:
                raise ValueError("CNN layer type must be one of [conv, lin, flatten, split].")

            # Treat "max": None like "norm"/"drop": None, i.e. as absent.
            if layer.get("max") is not None:
                modules.append(nn.MaxPool3d(layer["max"], stride=layer["max"], ceil_mode=True))

            if layer.get("norm") == "layer":
                modules.append(nn.LayerNorm(layer["ln"]))

            if layer.get("drop") is not None:
                modules.append(nn.Dropout(layer["drop"]))

        self.model = nn.Sequential(*modules)

    @staticmethod
    def _activation(name):
        """Return ``[module]`` for a supported activation name, else ``[]``."""
        if name == "leakyrelu":
            return [nn.LeakyReLU(0.2)]
        if name == "relu":
            return [nn.ReLU()]
        return []

    def forward(self, x):
        """Run the stack; return a tensor, or a list of chunks if a split spec was given."""
        out = self.model(x)

        if self.split:
            size = self.tensor_size
            out = [out[:, i * size:(i + 1) * size] for i in range(self.n_tensors)]

        return out

# In [None]:
def exists(val):
    """Report whether *val* is anything other than ``None`` (falsy values count)."""
    if val is None:
        return False
    return True

def leaky_relu(p=0.2):
    """Build a new ``nn.LeakyReLU`` module with negative slope *p* (default 0.2)."""
    activation = nn.LeakyReLU(negative_slope=p)
    return activation

def to_value(t):
    """Return the single element of tensor *t* as a plain Python number.

    ``Tensor.item`` already produces a Python scalar that is outside the
    autograd graph, so no ``clone().detach()`` (an extra tensor copy) is needed.
    Any error from ``Tensor.item`` (e.g. *t* has more than one element)
    propagates unchanged.
    """
    return t.item()

def get_module_device(module):
    """Return the device of *module*'s first parameter (or first buffer).

    Only the first tensor found is inspected, so this assumes all of the
    module's tensors share a device. Modules without parameters (e.g. a
    non-affine norm layer) fall back to their buffers.

    Raises:
        ValueError: if the module has neither parameters nor buffers. The old
            code raised a bare StopIteration, which turns into a confusing
            RuntimeError when raised inside a generator.
    """
    tensor = next(module.parameters(), None)
    if tensor is None:
        tensor = next(module.buffers(), None)
    if tensor is None:
        raise ValueError(
            f"{type(module).__name__} has no parameters or buffers; cannot infer its device"
        )
    return tensor.device

class MappingNetwork(nn.Module):
    """MLP on a ``ARGS.z_dim`` vector followed by two linear heads.

    The trunk is ``depth`` repetitions of
    ``EqualLinear(z_dim, z_dim, lr_mul)`` + ``LeakyReLU(0.2)``. ``EqualLinear``
    is defined elsewhere, and ``lr_mul`` is passed as its third positional
    argument. ``forward`` returns ``(to_gamma(h), to_beta(h))``, both of width
    ``ARGS.dim_hidden``.
    """

    def __init__(self, ARGS, depth=3, lr_mul=0.1):
        super().__init__()

        width = ARGS.z_dim
        trunk = []
        for _ in range(depth):
            # Same construction order as before: linear, then activation.
            trunk += [EqualLinear(width, width, lr_mul), leaky_relu()]
        self.net = nn.Sequential(*trunk)

        self.to_gamma = nn.Linear(width, ARGS.dim_hidden)
        self.to_beta = nn.Linear(width, ARGS.dim_hidden)

    def forward(self, x):
        hidden = self.net(x)
        gamma = self.to_gamma(hidden)
        beta = self.to_beta(hidden)
        return gamma, beta

# In [None]:
def load_cnn(ARGS):
    """Build the CNN encoder selected by ``ARGS.cnn_setup`` (an int in 1..21).

    Every setup has the same overall layout:
    conv stack -> flatten -> linear head -> split of the 512-wide output into
    two 256-wide tensors. The ``# output cnn`` comments give the pre-flatten
    conv output shape. Its 128 * 2 * 4 * 4 = 4096 features match the first
    linear layer's input.

    The LayerNorm shapes (e.g. ``(24, 64, 64)`` after a stride-1 first conv)
    suggest a single-channel input with spatial size 24 x 64 x 64.
    NOTE(review): confirm this against the data loader.

    Head names used in the per-setup comments:

    * small:  4096 -> 128 -> 128 -> 128 -> 128 -> 512
    * medium: 4096 -> 512 -> 512 -> 512 -> 512 -> 512
    * large:  4096 -> 2048 -> 1024 -> 512 -> 512 -> 512
    * short:  4096 -> 2048 -> 1024 -> 512

    Args:
        ARGS: object with an integer ``cnn_setup`` attribute.

    Returns:
        CNN: a freshly constructed network.

    Raises:
        ValueError: if ``ARGS.cnn_setup`` is not a supported setup. Previously
            this surfaced as an UnboundLocalError on ``cnn``.
    """

    if ARGS.cnn_setup == 1:
        # output cnn torch.Size([1, 128, 2, 4, 4])
        # maxpool, small kernel, small linear

        cnn = CNN([{"type": "conv", "in": 1, "out": 16, "ker": 3, "str": 1, "pad": 1, "max": 2, "act": "relu"}
                   , {"type": "conv", "in": 16, "out": 32, "ker": 3, "str": 1, "pad": 1, "max": 2, "act": "relu"}
                   , {"type": "conv", "in": 32, "out": 64, "ker": 3, "str": 1, "pad": 1, "max": 2, "act": "relu"}
                   , {"type": "conv", "in": 64, "out": 128, "ker": 3, "str": 1, "pad": 1, "max": 2, "act": "relu"}
                   , {"type": "flatten"}
                   , {"type": "lin", "in": 4096, "out": 128, "act": "leakyrelu"}
                   , {"type": "lin", "in": 128, "out": 128, "act": "leakyrelu"}
                   , {"type": "lin", "in": 128, "out": 128, "act": "leakyrelu"}
                   , {"type": "lin", "in": 128, "out": 128, "act": "leakyrelu"}
                   , {"type": "lin", "in": 128, "out": 512}
                   , {"type": "split", "n_tensors": 2, "tensor_size": 256}
                  ])

    elif ARGS.cnn_setup == 2:
        # output cnn torch.Size([1, 128, 2, 4, 4])
        # maxpool, small kernel, medium linear

        cnn = CNN([{"type": "conv", "in": 1, "out": 16, "ker": 3, "str": 1, "pad": 1, "max": 2, "act": "relu"}
                   , {"type": "conv", "in": 16, "out": 32, "ker": 3, "str": 1, "pad": 1, "max": 2, "act": "relu"}
                   , {"type": "conv", "in": 32, "out": 64, "ker": 3, "str": 1, "pad": 1, "max": 2, "act": "relu"}
                   , {"type": "conv", "in": 64, "out": 128, "ker": 3, "str": 1, "pad": 1, "max": 2, "act": "relu"}
                   , {"type": "flatten"}
                   , {"type": "lin", "in": 4096, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512}
                   , {"type": "split", "n_tensors": 2, "tensor_size": 256}
                  ])

    elif ARGS.cnn_setup == 3:
        # output cnn torch.Size([1, 128, 2, 4, 4])
        # maxpool, small kernel, large linear

        cnn = CNN([{"type": "conv", "in": 1, "out": 16, "ker": 3, "str": 1, "pad": 1, "max": 2, "act": "relu"}
                   , {"type": "conv", "in": 16, "out": 32, "ker": 3, "str": 1, "pad": 1, "max": 2, "act": "relu"}
                   , {"type": "conv", "in": 32, "out": 64, "ker": 3, "str": 1, "pad": 1, "max": 2, "act": "relu"}
                   , {"type": "conv", "in": 64, "out": 128, "ker": 3, "str": 1, "pad": 1, "max": 2, "act": "relu"}
                   , {"type": "flatten"}
                   , {"type": "lin", "in": 4096, "out": 2048, "act": "leakyrelu"}
                   , {"type": "lin", "in": 2048, "out": 1024, "act": "leakyrelu"}
                   , {"type": "lin", "in": 1024, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512}
                   , {"type": "split", "n_tensors": 2, "tensor_size": 256}
                  ])

    elif ARGS.cnn_setup == 4:
        # output cnn torch.Size([1, 128, 2, 4, 4])
        # maxpool, small kernel, short linear (wide like "large" but only three layers)

        cnn = CNN([{"type": "conv", "in": 1, "out": 16, "ker": 3, "str": 1, "pad": 1, "max": 2, "act": "relu"}
                   , {"type": "conv", "in": 16, "out": 32, "ker": 3, "str": 1, "pad": 1, "max": 2, "act": "relu"}
                   , {"type": "conv", "in": 32, "out": 64, "ker": 3, "str": 1, "pad": 1, "max": 2, "act": "relu"}
                   , {"type": "conv", "in": 64, "out": 128, "ker": 3, "str": 1, "pad": 1, "max": 2, "act": "relu"}
                   , {"type": "flatten"}
                   , {"type": "lin", "in": 4096, "out": 2048, "act": "leakyrelu"}
                   , {"type": "lin", "in": 2048, "out": 1024, "act": "leakyrelu"}
                   , {"type": "lin", "in": 1024, "out": 512}
                   , {"type": "split", "n_tensors": 2, "tensor_size": 256}
                  ])

    elif ARGS.cnn_setup == 5:
        # output cnn torch.Size([1, 128, 2, 4, 4])
        # stride of 2, small kernel, medium linear (same head as setup 2)

        cnn = CNN([{"type": "conv", "in": 1, "out": 16, "ker": 3, "str": 2, "pad": 1, "act": "relu"}
                   , {"type": "conv", "in": 16, "out": 32, "ker": 3, "str": 2, "pad": 1, "act": "relu"}
                   , {"type": "conv", "in": 32, "out": 64, "ker": 3, "str": 2, "pad": 1, "act": "relu"}
                   , {"type": "conv", "in": 64, "out": 128, "ker": 3, "str": 2, "pad": 1, "act": "relu"}
                   , {"type": "flatten"}
                   , {"type": "lin", "in": 4096, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512}
                   , {"type": "split", "n_tensors": 2, "tensor_size": 256}
                  ])

    elif ARGS.cnn_setup == 6:
        # output cnn torch.Size([1, 128, 2, 4, 4])
        # strides 1 and 2 interleaved (four stride-2 convs), small kernel, medium linear

        cnn = CNN([{"type": "conv", "in": 1, "out": 16, "ker": 3, "str": 1, "pad": 1, "act": "relu"}
                   , {"type": "conv", "in": 16, "out": 16, "ker": 3, "str": 2, "pad": 1, "act": "relu"}
                   , {"type": "conv", "in": 16, "out": 32, "ker": 3, "str": 1, "pad": 1, "act": "relu"}
                   , {"type": "conv", "in": 32, "out": 32, "ker": 3, "str": 2, "pad": 1, "act": "relu"}
                   , {"type": "conv", "in": 32, "out": 64, "ker": 3, "str": 1, "pad": 1, "act": "relu"}
                   , {"type": "conv", "in": 64, "out": 64, "ker": 3, "str": 2, "pad": 1, "act": "relu"}
                   , {"type": "conv", "in": 64, "out": 128, "ker": 3, "str": 2, "pad": 1, "act": "relu"}
                   , {"type": "flatten"}
                   , {"type": "lin", "in": 4096, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512}
                   , {"type": "split", "n_tensors": 2, "tensor_size": 256}
                  ])

    elif ARGS.cnn_setup == 7:
        # stride of 2, kernel 7, medium linear
        cnn = CNN([{"type": "conv", "in": 1, "out": 16, "ker": 7, "str": 2, "pad": 3, "act": "relu"}
                   , {"type": "conv", "in": 16, "out": 32, "ker": 7, "str": 2, "pad": 3, "act": "relu"}
                   , {"type": "conv", "in": 32, "out": 64, "ker": 7, "str": 2, "pad": 3, "act": "relu"}
                   , {"type": "conv", "in": 64, "out": 128, "ker": 7, "str": 2, "pad": 3, "act": "relu"}
                   , {"type": "flatten"}
                   , {"type": "lin", "in": 4096, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512}
                   , {"type": "split", "n_tensors": 2, "tensor_size": 256}
                  ])

    elif ARGS.cnn_setup == 8:
        # stride of 2, kernel 7, short linear
        cnn = CNN([{"type": "conv", "in": 1, "out": 16, "ker": 7, "str": 2, "pad": 3, "act": "relu"}
                   , {"type": "conv", "in": 16, "out": 32, "ker": 7, "str": 2, "pad": 3, "act": "relu"}
                   , {"type": "conv", "in": 32, "out": 64, "ker": 7, "str": 2, "pad": 3, "act": "relu"}
                   , {"type": "conv", "in": 64, "out": 128, "ker": 7, "str": 2, "pad": 3, "act": "relu"}
                   , {"type": "flatten"}
                   , {"type": "lin", "in": 4096, "out": 2048, "act": "leakyrelu"}
                   , {"type": "lin", "in": 2048, "out": 1024, "act": "leakyrelu"}
                   , {"type": "lin", "in": 1024, "out": 512}
                   , {"type": "split", "n_tensors": 2, "tensor_size": 256}
                  ])

    elif ARGS.cnn_setup == 9:
        # stride of 2, kernel 7, LayerNorm after the first three convs, medium linear
        cnn = CNN([{"type": "conv", "in": 1, "out": 16, "ker": 7, "str": 2, "pad": 3, "act": "relu",
                    "norm": "layer", "ln": (12, 32, 32)}
                   , {"type": "conv", "in": 16, "out": 32, "ker": 7, "str": 2, "pad": 3, "act": "relu",
                     "norm": "layer", "ln": (6, 16, 16)}
                   , {"type": "conv", "in": 32, "out": 64, "ker": 7, "str": 2, "pad": 3, "act": "relu",
                     "norm": "layer", "ln": (3, 8, 8)}
                   , {"type": "conv", "in": 64, "out": 128, "ker": 7, "str": 2, "pad": 3, "act": "relu"}
                   , {"type": "flatten"}
                   , {"type": "lin", "in": 4096, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512}
                   , {"type": "split", "n_tensors": 2, "tensor_size": 256}
                  ])

    elif ARGS.cnn_setup == 10:
        # stride of 2, kernel 3, LayerNorm after the first three convs, medium linear
        cnn = CNN([{"type": "conv", "in": 1, "out": 16, "ker": 3, "str": 2, "pad": 1, "act": "relu",
                   "norm": "layer", "ln": (12, 32, 32)}
                   , {"type": "conv", "in": 16, "out": 32, "ker": 3, "str": 2, "pad": 1, "act": "relu",
                     "norm": "layer", "ln": (6, 16, 16)}
                   , {"type": "conv", "in": 32, "out": 64, "ker": 3, "str": 2, "pad": 1, "act": "relu",
                     "norm": "layer", "ln": (3, 8, 8)}
                   , {"type": "conv", "in": 64, "out": 128, "ker": 3, "str": 2, "pad": 1, "act": "relu"}
                   , {"type": "flatten"}
                   , {"type": "lin", "in": 4096, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512}
                   , {"type": "split", "n_tensors": 2, "tensor_size": 256}
                  ])

    elif ARGS.cnn_setup == 11:
        # NOTE: currently identical to setup 10.
        cnn = CNN([{"type": "conv", "in": 1, "out": 16, "ker": 3, "str": 2, "pad": 1, "act": "relu",
                   "norm": "layer", "ln": (12, 32, 32)}
                   , {"type": "conv", "in": 16, "out": 32, "ker": 3, "str": 2, "pad": 1, "act": "relu",
                     "norm": "layer", "ln": (6, 16, 16)}
                   , {"type": "conv", "in": 32, "out": 64, "ker": 3, "str": 2, "pad": 1, "act": "relu",
                     "norm": "layer", "ln": (3, 8, 8)}
                   , {"type": "conv", "in": 64, "out": 128, "ker": 3, "str": 2, "pad": 1, "act": "relu"}
                   , {"type": "flatten"}
                   , {"type": "lin", "in": 4096, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512, "act": "leakyrelu"}
                   , {"type": "lin", "in": 512, "out": 512}
                   , {"type": "split", "n_tensors": 2, "tensor_size": 256}
                  ])

    elif ARGS.cnn_setup == 12:
        # stride of 2, kernel 3, LayerNorm after the first three convs, short linear
        cnn = CNN([{"type": "conv", "in": 1, "out": 16, "ker": 3, "str": 2, "pad": 1, "act": "relu",
                   "norm": "layer", "ln": (12, 32, 32)}
                   , {"type": "conv", "in": 16, "out": 32, "ker": 3, "str": 2, "pad": 1, "act": "relu",
                     "norm": "layer", "ln": (6, 16, 16)}
                   , {"type": "conv", "in": 32, "out": 64, "ker": 3, "str": 2, "pad": 1, "act": "relu",
                     "norm": "layer", "ln": (3, 8, 8)}
                   , {"type": "conv", "in": 64, "out": 128, "ker": 3, "str": 2, "pad": 1, "act": "relu"}
                   , {"type": "flatten"}
                   , {"type": "lin", "in": 4096, "out": 2048, "act": "leakyrelu"}
                   , {"type": "lin", "in": 2048, "out": 1024, "act": "leakyrelu"}
                   , {"type": "lin", "in": 1024, "out": 512}
                   , {"type": "split", "n_tensors": 2, "tensor_size": 256}
                  ])

    elif ARGS.cnn_setup == 13:
        # stride of 2, kernel 5, LayerNorm after the first three convs, short linear
        cnn = CNN([{"type": "conv", "in": 1, "out": 16, "ker": 5, "str": 2, "pad": 2, "act": "relu",
                   "norm": "layer", "ln": (12, 32, 32)}
                   , {"type": "conv", "in": 16, "out": 32, "ker": 5, "str": 2, "pad": 2, "act": "relu",
                     "norm": "layer", "ln": (6, 16, 16)}
                   , {"type": "conv", "in": 32, "out": 64, "ker": 5, "str": 2, "pad": 2, "act": "relu",
                     "norm": "layer", "ln": (3, 8, 8)}
                   , {"type": "conv", "in": 64, "out": 128, "ker": 5, "str": 2, "pad": 2, "act": "relu"}
                   , {"type": "flatten"}
                   , {"type": "lin", "in": 4096, "out": 2048, "act": "leakyrelu"}
                   , {"type": "lin", "in": 2048, "out": 1024, "act": "leakyrelu"}
                   , {"type": "lin", "in": 1024, "out": 512}
                   , {"type": "split", "n_tensors": 2, "tensor_size": 256}
                  ])

    elif ARGS.cnn_setup == 14:
        # stride of 2, kernel 5, no normalization, short linear
        cnn = CNN([{"type": "conv", "in": 1, "out": 16, "ker": 5, "str": 2, "pad": 2, "act": "relu"}
                   , {"type": "conv", "in": 16, "out": 32, "ker": 5, "str": 2, "pad": 2, "act": "relu"}
                   , {"type": "conv", "in": 32, "out": 64, "ker": 5, "str": 2, "pad": 2, "act": "relu"}
                   , {"type": "conv", "in": 64, "out": 128, "ker": 5, "str": 2, "pad": 2, "act": "relu"}
                   , {"type": "flatten"}
                   , {"type": "lin", "in": 4096, "out": 2048, "act": "leakyrelu"}
                   , {"type": "lin", "in": 2048, "out": 1024, "act": "leakyrelu"}
                   , {"type": "lin", "in": 1024, "out": 512}
                   , {"type": "split", "n_tensors": 2, "tensor_size": 256}
                  ])

    elif ARGS.cnn_setup == 15:
        # strides 1/2 alternating, kernel 5, LayerNorm after the 2nd, 4th and 6th conv, short linear
        cnn = CNN([{"type": "conv", "in": 1, "out": 16, "ker": 5, "str": 1, "pad": 2, "act": "relu"}
                   , {"type": "conv", "in": 16, "out": 16, "ker": 5, "str": 2, "pad": 2, "act": "relu",
                   "norm": "layer", "ln": (12, 32, 32)}
                   , {"type": "conv", "in": 16, "out": 32, "ker": 5, "str": 1, "pad": 2, "act": "relu"}
                   , {"type": "conv", "in": 32, "out": 32, "ker": 5, "str": 2, "pad": 2, "act": "relu",
                     "norm": "layer", "ln": (6, 16, 16)}
                   , {"type": "conv", "in": 32, "out": 64, "ker": 5, "str": 1, "pad": 2, "act": "relu"}
                   , {"type": "conv", "in": 64, "out": 64, "ker": 5, "str": 2, "pad": 2, "act": "relu",
                     "norm": "layer", "ln": (3, 8, 8)}
                   , {"type": "conv", "in": 64, "out": 128, "ker": 5, "str": 1, "pad": 2, "act": "relu"}
                   , {"type": "conv", "in": 128, "out": 128, "ker": 5, "str": 2, "pad": 2, "act": "relu"}
                   , {"type": "flatten"}
                   , {"type": "lin", "in": 4096, "out": 2048, "act": "leakyrelu"}
                   , {"type": "lin", "in": 2048, "out": 1024, "act": "leakyrelu"}
                   , {"type": "lin", "in": 1024, "out": 512}
                   , {"type": "split", "n_tensors": 2, "tensor_size": 256}
                  ])

    elif ARGS.cnn_setup == 16:
        # stride of 2, kernel 5, LayerNorm after all four convs, short linear
        cnn = CNN([{"type": "conv", "in": 1, "out": 16, "ker": 5, "str": 2, "pad": 2, "act": "relu",
                   "norm": "layer", "ln": (12, 32, 32)}
                   , {"type": "conv", "in": 16, "out": 32, "ker": 5, "str": 2, "pad": 2, "act": "relu",
                     "norm": "layer", "ln": (6, 16, 16)}
                   , {"type": "conv", "in": 32, "out": 64, "ker": 5, "str": 2, "pad": 2, "act": "relu",
                     "norm": "layer", "ln": (3, 8, 8)}
                   , {"type": "conv", "in": 64, "out": 128, "ker": 5, "str": 2, "pad": 2, "act": "relu",
                     "norm": "layer", "ln": (2, 4, 4)}
                   , {"type": "flatten"}
                   , {"type": "lin", "in": 4096, "out": 2048, "act": "leakyrelu"}
                   , {"type": "lin", "in": 2048, "out": 1024, "act": "leakyrelu"}
                   , {"type": "lin", "in": 1024, "out": 512}
                   , {"type": "split", "n_tensors": 2, "tensor_size": 256}
                  ])

    elif ARGS.cnn_setup == 17:
        # strides 1/2 alternating, kernel 5, LayerNorm after every conv, short linear
        cnn = CNN([{"type": "conv", "in": 1, "out": 16, "ker": 5, "str": 1, "pad": 2, "act": "relu",
                    "norm": "layer", "ln": (24, 64, 64)}
                   , {"type": "conv", "in": 16, "out": 16, "ker": 5, "str": 2, "pad": 2, "act": "relu",
                    "norm": "layer", "ln": (12, 32, 32)}
                   , {"type": "conv", "in": 16, "out": 32, "ker": 5, "str": 1, "pad": 2, "act": "relu",
                    "norm": "layer", "ln": (12, 32, 32)}
                   , {"type": "conv", "in": 32, "out": 32, "ker": 5, "str": 2, "pad": 2, "act": "relu",
                    "norm": "layer", "ln": (6, 16, 16)}
                   , {"type": "conv", "in": 32, "out": 64, "ker": 5, "str": 1, "pad": 2, "act": "relu",
                    "norm": "layer", "ln": (6, 16, 16)}
                   , {"type": "conv", "in": 64, "out": 64, "ker": 5, "str": 2, "pad": 2, "act": "relu",
                    "norm": "layer", "ln": (3, 8, 8)}
                   , {"type": "conv", "in": 64, "out": 128, "ker": 5, "str": 1, "pad": 2, "act": "relu",
                    "norm": "layer", "ln": (3, 8, 8)}
                   , {"type": "conv", "in": 128, "out": 128, "ker": 5, "str": 2, "pad": 2, "act": "relu",
                    "norm": "layer", "ln": (2, 4, 4)}
                   , {"type": "flatten"}
                   , {"type": "lin", "in": 4096, "out": 2048, "act": "leakyrelu"}
                   , {"type": "lin", "in": 2048, "out": 1024, "act": "leakyrelu"}
                   , {"type": "lin", "in": 1024, "out": 512}
                   , {"type": "split", "n_tensors": 2, "tensor_size": 256}
                  ])

    elif ARGS.cnn_setup == 18:
        # as setup 17, but no LayerNorm after the last conv
        cnn = CNN([{"type": "conv", "in": 1, "out": 16, "ker": 5, "str": 1, "pad": 2, "act": "relu",
                    "norm": "layer", "ln": (24, 64, 64)}
                   , {"type": "conv", "in": 16, "out": 16, "ker": 5, "str": 2, "pad": 2, "act": "relu",
                    "norm": "layer", "ln": (12, 32, 32)}
                   , {"type": "conv", "in": 16, "out": 32, "ker": 5, "str": 1, "pad": 2, "act": "relu",
                    "norm": "layer", "ln": (12, 32, 32)}
                   , {"type": "conv", "in": 32, "out": 32, "ker": 5, "str": 2, "pad": 2, "act": "relu",
                    "norm": "layer", "ln": (6, 16, 16)}
                   , {"type": "conv", "in": 32, "out": 64, "ker": 5, "str": 1, "pad": 2, "act": "relu",
                    "norm": "layer", "ln": (6, 16, 16)}
                   , {"type": "conv", "in": 64, "out": 64, "ker": 5, "str": 2, "pad": 2, "act": "relu",
                    "norm": "layer", "ln": (3, 8, 8)}
                   , {"type": "conv", "in": 64, "out": 128, "ker": 5, "str": 1, "pad": 2, "act": "relu",
                    "norm": "layer", "ln": (3, 8, 8)}
                   , {"type": "conv", "in": 128, "out": 128, "ker": 5, "str": 2, "pad": 2, "act": "relu"}
                   , {"type": "flatten"}
                   , {"type": "lin", "in": 4096, "out": 2048, "act": "leakyrelu"}
                   , {"type": "lin", "in": 2048, "out": 1024, "act": "leakyrelu"}
                   , {"type": "lin", "in": 1024, "out": 512}
                   , {"type": "split", "n_tensors": 2, "tensor_size": 256}
                  ])

    elif ARGS.cnn_setup == 19:
        # as setup 18, but kernel 7
        cnn = CNN([{"type": "conv", "in": 1, "out": 16, "ker": 7, "str": 1, "pad": 3, "act": "relu",
                    "norm": "layer", "ln": (24, 64, 64)}
                   , {"type": "conv", "in": 16, "out": 16, "ker": 7, "str": 2, "pad": 3, "act": "relu",
                    "norm": "layer", "ln": (12, 32, 32)}
                   , {"type": "conv", "in": 16, "out": 32, "ker": 7, "str": 1, "pad": 3, "act": "relu",
                    "norm": "layer", "ln": (12, 32, 32)}
                   , {"type": "conv", "in": 32, "out": 32, "ker": 7, "str": 2, "pad": 3, "act": "relu",
                    "norm": "layer", "ln": (6, 16, 16)}
                   , {"type": "conv", "in": 32, "out": 64, "ker": 7, "str": 1, "pad": 3, "act": "relu",
                    "norm": "layer", "ln": (6, 16, 16)}
                   , {"type": "conv", "in": 64, "out": 64, "ker": 7, "str": 2, "pad": 3, "act": "relu",
                    "norm": "layer", "ln": (3, 8, 8)}
                   , {"type": "conv", "in": 64, "out": 128, "ker": 7, "str": 1, "pad": 3, "act": "relu",
                    "norm": "layer", "ln": (3, 8, 8)}
                   , {"type": "conv", "in": 128, "out": 128, "ker": 7, "str": 2, "pad": 3, "act": "relu"}
                   , {"type": "flatten"}
                   , {"type": "lin", "in": 4096, "out": 2048, "act": "leakyrelu"}
                   , {"type": "lin", "in": 2048, "out": 1024, "act": "leakyrelu"}
                   , {"type": "lin", "in": 1024, "out": 512}
                   , {"type": "split", "n_tensors": 2, "tensor_size": 256}
                  ])

    elif ARGS.cnn_setup == 20:
        # as setup 17, but kernel 7
        cnn = CNN([{"type": "conv", "in": 1, "out": 16, "ker": 7, "str": 1, "pad": 3, "act": "relu",
                    "norm": "layer", "ln": (24, 64, 64)}
                   , {"type": "conv", "in": 16, "out": 16, "ker": 7, "str": 2, "pad": 3, "act": "relu",
                    "norm": "layer", "ln": (12, 32, 32)}
                   , {"type": "conv", "in": 16, "out": 32, "ker": 7, "str": 1, "pad": 3, "act": "relu",
                    "norm": "layer", "ln": (12, 32, 32)}
                   , {"type": "conv", "in": 32, "out": 32, "ker": 7, "str": 2, "pad": 3, "act": "relu",
                    "norm": "layer", "ln": (6, 16, 16)}
                   , {"type": "conv", "in": 32, "out": 64, "ker": 7, "str": 1, "pad": 3, "act": "relu",
                    "norm": "layer", "ln": (6, 16, 16)}
                   , {"type": "conv", "in": 64, "out": 64, "ker": 7, "str": 2, "pad": 3, "act": "relu",
                    "norm": "layer", "ln": (3, 8, 8)}
                   , {"type": "conv", "in": 64, "out": 128, "ker": 7, "str": 1, "pad": 3, "act": "relu",
                    "norm": "layer", "ln": (3, 8, 8)}
                   , {"type": "conv", "in": 128, "out": 128, "ker": 7, "str": 2, "pad": 3, "act": "relu",
                    "norm": "layer", "ln": (2, 4, 4)}
                   , {"type": "flatten"}
                   , {"type": "lin", "in": 4096, "out": 2048, "act": "leakyrelu"}
                   , {"type": "lin", "in": 2048, "out": 1024, "act": "leakyrelu"}
                   , {"type": "lin", "in": 1024, "out": 512}
                   , {"type": "split", "n_tensors": 2, "tensor_size": 256}
                  ])

    elif ARGS.cnn_setup == 21:
        # as setup 17, but the 2nd conv downsamples with a 2x max-pool (stride 1) instead of stride 2
        cnn = CNN([{"type": "conv", "in": 1, "out": 16, "ker": 5, "str": 1, "pad": 2, "act": "relu",
                    "norm": "layer", "ln": (24, 64, 64)}
                   , {"type": "conv", "in": 16, "out": 16, "ker": 5, "str": 1, "pad": 2, "act": "relu",
                    "max": 2, "norm": "layer", "ln": (12, 32, 32)}
                   , {"type": "conv", "in": 16, "out": 32, "ker": 5, "str": 1, "pad": 2, "act": "relu",
                    "norm": "layer", "ln": (12, 32, 32)}
                   , {"type": "conv", "in": 32, "out": 32, "ker": 5, "str": 2, "pad": 2, "act": "relu",
                    "norm": "layer", "ln": (6, 16, 16)}
                   , {"type": "conv", "in": 32, "out": 64, "ker": 5, "str": 1, "pad": 2, "act": "relu",
                    "norm": "layer", "ln": (6, 16, 16)}
                   , {"type": "conv", "in": 64, "out": 64, "ker": 5, "str": 2, "pad": 2, "act": "relu",
                    "norm": "layer", "ln": (3, 8, 8)}
                   , {"type": "conv", "in": 64, "out": 128, "ker": 5, "str": 1, "pad": 2, "act": "relu",
                    "norm": "layer", "ln": (3, 8, 8)}
                   , {"type": "conv", "in": 128, "out": 128, "ker": 5, "str": 2, "pad": 2, "act": "relu",
                    "norm": "layer", "ln": (2, 4, 4)}
                   , {"type": "flatten"}
                   , {"type": "lin", "in": 4096, "out": 2048, "act": "leakyrelu"}
                   , {"type": "lin", "in": 2048, "out": 1024, "act": "leakyrelu"}
                   , {"type": "lin", "in": 1024, "out": 512}
                   , {"type": "split", "n_tensors": 2, "tensor_size": 256}
                  ])

    else:
        # Without this branch an unknown setup fails later with an opaque
        # UnboundLocalError on `cnn`.
        raise ValueError(
            f"Unknown ARGS.cnn_setup={ARGS.cnn_setup!r}; expected an integer from 1 to 21."
        )

    return cnn