In [None]:
# Run on the GPU when CUDA is usable in this process; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def build_model(model_type, num_features, num_seq_features,
                y_hidden_layer_sizes, y_filter_size, y_dilation_rate,
                n_hidden_layer_sizes, n_filter_size, n_dilation_rate, dropout,
                lstm_hidden_layer_size, num_lstm_layers=None, bidirectional=False):
    """Build one of the supported models, initialise it and move it to ``device``.

    The function reads the module-level names ``device``, ``cuda_available``,
    ``num_gpus`` and ``weight_init``; they must be defined before it is called.

    Args:
        model_type: ``'ep_linear'``, ``'ep_seq_linear'``, ``'lstm'`` or ``'cnn'``.
        num_features: Size of the last dimension of ``Y_ji``.
        num_seq_features: Size of the last dimension of ``N_ji``.
        y_hidden_layer_sizes: Output channels of each conv layer applied to
            ``Y_ji`` (``'cnn'`` only).
        y_filter_size: Kernel size for the ``Y_ji`` conv layers. An int is used
            for every layer; a list shorter than the number of layers is
            extended by repeating its last value.
        y_dilation_rate: If > 0, layers use explicit padding and the dilation
            is multiplied by this value after every layer (starting at 1).
            Otherwise layers are undilated and use ``padding='same'``.
        n_hidden_layer_sizes: As ``y_hidden_layer_sizes``, for ``N_ji``.
        n_filter_size: As ``y_filter_size``, for ``N_ji``.
        n_dilation_rate: As ``y_dilation_rate``, for ``N_ji``.
        dropout: Dropout probability applied after every conv + ReLU.
        lstm_hidden_layer_size: Hidden size of the LSTM (``'lstm'``) or of the
            GRU placed after the conv stacks (``'cnn'``).
        num_lstm_layers: Number of recurrent layers. For ``'cnn'``, ``None`` or
            0 disables the GRU and a 1x1 convolution produces the output. For
            ``'lstm'``, ``None`` falls back to a single layer.
        bidirectional: Whether the recurrent layer is bidirectional.

    Returns:
        The model converted to double precision and moved to ``device``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """

    def _per_layer_filter_sizes(filter_size, num_layers):
        """Return a kernel size for each of ``num_layers`` conv layers."""
        if isinstance(filter_size, int):
            return [filter_size] * num_layers
        if isinstance(filter_size, list) and len(filter_size) < num_layers:
            # Extend by repeating the last value to match the number of layers.
            missing = num_layers - len(filter_size)
            return filter_size + [filter_size[-1]] * missing
        return filter_size

    def _build_conv_stack(in_channels, layer_sizes, filter_size, dilation_rate):
        """Return ``(ModuleList of Conv1d layers, channels of the last layer)``."""
        kernel_sizes = _per_layer_filter_sizes(filter_size, len(layer_sizes))
        convs = nn.ModuleList()
        dilation = 1
        for idx, out_channels in enumerate(layer_sizes):
            kernel_size = kernel_sizes[idx]
            if dilation_rate > 0:
                padding = ((kernel_size - 1) * dilation) // 2
                convs.append(
                    nn.Conv1d(in_channels, out_channels, kernel_size,
                              stride=1, padding=padding, dilation=dilation)
                )
                dilation *= dilation_rate
            else:
                convs.append(
                    nn.Conv1d(in_channels, out_channels, kernel_size,
                              stride=1, padding='same')
                )
            in_channels = out_channels
        return convs, in_channels

    class EpLinearModel(nn.Module):
        """Single linear layer over ``Y_ji`` only."""

        def __init__(self, input_size):
            super().__init__()
            self.name = "ep_linear"
            self.linear = nn.Linear(input_size, 1)

        def forward(self, Y_ji):
            return self.linear(Y_ji).squeeze(-1)

    class EpSeqLinearModel(nn.Module):
        """Single linear layer over ``Y_ji`` and ``N_ji`` concatenated on the last dim."""

        def __init__(self, input_size):
            super().__init__()
            self.name = "ep_seq_linear"
            self.linear = nn.Linear(input_size, 1)

        def forward(self, Y_ji, N_ji):
            x = torch.cat((Y_ji, N_ji), dim=-1)
            return self.linear(x).squeeze(-1)

    class LSTMModel(nn.Module):
        """LSTM over the concatenated inputs followed by a per-step linear head."""

        def __init__(self, input_size, hidden_layer_size, num_layers, bidirectional):
            super().__init__()
            self.name = "lstm"
            self.lstm = nn.LSTM(input_size, hidden_layer_size, num_layers,
                                bidirectional=bidirectional, batch_first=True)
            # Both heads are registered regardless of direction, as before, so
            # the set of parameters (and state_dict keys) does not change.
            self.bidirectional_linear = nn.Linear(2 * hidden_layer_size, 1)
            self.linear = nn.Linear(hidden_layer_size, 1)
            self.bidirectional = bidirectional

        def forward(self, Y_ji, N_ji):
            x = torch.cat((Y_ji, N_ji), dim=-1)
            x, _ = self.lstm(x)
            head = self.bidirectional_linear if self.bidirectional else self.linear
            return head(x).squeeze(-1)

    class CNN(nn.Module):
        """Two Conv1d stacks (one per input) merged on channels, then a GRU or 1x1 conv head."""

        def __init__(self, num_features, num_seq_features, y_hidden_layer_sizes, y_filter_size, y_dilation_rate,
                     n_hidden_layer_sizes, n_filter_size, n_dilation_rate, dropout,
                     lstm_hidden_layer_size, num_lstm_layers=None, bidirectional=False):
            super().__init__()
            self.name = "cnn"

            self.y_convs, y_channels = _build_conv_stack(
                num_features, y_hidden_layer_sizes, y_filter_size, y_dilation_rate)
            self.n_convs, n_channels = _build_conv_stack(
                num_seq_features, n_hidden_layer_sizes, n_filter_size, n_dilation_rate)
            merged_channels = y_channels + n_channels

            self.relu = nn.ReLU()
            self.dropout = nn.Dropout(dropout)

            # 1x1 convolution mapping the merged channels to a single output channel.
            self.final_conv = nn.Conv1d(merged_channels, 1, 1)

            # None (the signature default) means "no recurrent layer".
            self.num_lstm_layers = num_lstm_layers or 0
            if self.num_lstm_layers > 0:
                self.gru = nn.GRU(input_size=merged_channels, hidden_size=lstm_hidden_layer_size,
                                  num_layers=self.num_lstm_layers, bidirectional=bidirectional,
                                  batch_first=True)
            self.final_linear = nn.Linear(lstm_hidden_layer_size, 1)
            # Must follow the GRU's direction: a bidirectional GRU emits
            # 2 * hidden features per step, a unidirectional one emits hidden.
            self.bidirectional = bidirectional
            self.final_bidirectional_linear = nn.Linear(lstm_hidden_layer_size * 2, 1)

        def forward(self, Y_ji, N_ji):
            # Conv1d convolves over the last dim, so move features to dim 1.
            Y_ji = Y_ji.permute(0, 2, 1)
            N_ji = N_ji.permute(0, 2, 1)

            for conv in self.y_convs:
                Y_ji = self.dropout(self.relu(conv(Y_ji)))

            for conv in self.n_convs:
                N_ji = self.dropout(self.relu(conv(N_ji)))

            x = torch.cat((Y_ji, N_ji), 1)

            if self.num_lstm_layers > 0:
                # The GRU is batch_first, so move channels back to the last dim.
                x = x.permute(0, 2, 1)
                # nn.GRU returns (output, h_n); there is no cell state to unpack.
                x, _ = self.gru(x)
                head = self.final_bidirectional_linear if self.bidirectional else self.final_linear
                return head(x).squeeze(-1)

            return self.final_conv(x).squeeze(1)

    if model_type == 'lstm':
        # nn.LSTM needs an integer layer count; fall back to one layer for None.
        lstm_layers = 1 if num_lstm_layers is None else num_lstm_layers
        model = LSTMModel(num_features + num_seq_features, lstm_hidden_layer_size,
                          lstm_layers, bidirectional)
    elif model_type == 'ep_seq_linear':
        model = EpSeqLinearModel(num_features + num_seq_features)
    elif model_type == 'ep_linear':
        model = EpLinearModel(num_features)
    elif model_type == 'cnn':
        model = CNN(num_features, num_seq_features, y_hidden_layer_sizes, y_filter_size, y_dilation_rate,
                    n_hidden_layer_sizes, n_filter_size, n_dilation_rate, dropout, lstm_hidden_layer_size,
                    num_lstm_layers, bidirectional)
    else:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of "
            "'lstm', 'ep_seq_linear', 'ep_linear', 'cnn'"
        )

    if cuda_available:
        if num_gpus > 1:
            print("Using", num_gpus, "GPUs")
            # NOTE(review): the previous code built torch.nn.DataParallel(model)
            # here but stored it in an unused name, so the returned model was
            # never wrapped. That behaviour is kept; wrapping would also hide
            # `.name` behind `.module` -- confirm the intent before enabling it.
        model = model.to('cuda')

    print(model)

    # expected weights are close to 0 which is why 0 initializing weights converges much quicker
    if weight_init == 'zero':
        with torch.no_grad():
            for param in model.parameters():
                param.zero_()

    model.double()

    return model.to(device)