In [1]:
# Import necessary libraries
from maraboupy import Marabou, MarabouCore
import numpy as np
import pandas as pd
import csv
import os
import sys

Instructions for updating:
non-resource variables are not supported in the long term


In [2]:
def write_values_to_csv(values_dict, filename, directory=None):
    """
    Append the values of a dictionary as a single row of a CSV file.

    Only the dictionary's values are written, in the dictionary's iteration
    order; keys are not written, so no header row is produced. The file is
    opened in append mode, so each call adds one row.

    Args:
    values_dict (dict): Dictionary whose values form the row to append.
    filename (str): Name of the CSV file, resolved relative to ``directory``
        (an absolute path is used as-is).
    directory (str, optional): Directory holding the CSV file. Defaults to the
        directory of the defining script when ``__file__`` is available, and to
        the current working directory otherwise (e.g. inside a Jupyter kernel,
        where ``__file__`` is not defined).

    Returns:
    None

    Example:
    write_values_to_csv({'name': 'John', 'age': 30, 'city': 'New York'}, 'data.csv')
    # prints: Values written to <directory>/data.csv
    """
    if directory is None:
        # `__file__` does not exist in a notebook kernel; referencing it directly
        # raises NameError, so look it up defensively and fall back to the CWD.
        script_path = globals().get('__file__')
        if script_path:
            directory = os.path.dirname(os.path.abspath(script_path))
        else:
            directory = os.getcwd()

    full_path = os.path.join(directory, filename)

    # newline='' is required by the csv module to avoid blank lines on Windows
    with open(full_path, 'a', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(list(values_dict.values()))

    print(f"Values written to {full_path}")

In [3]:
# Parse the ONNX file into a Marabou network object; `file_name` is reused
# later to re-read a clean copy of the network.
file_name = "model_without_softmax.onnx"
network = Marabou.read_onnx(file_name)

In [4]:
# Marabou variable indices for the first input tensor and the first output tensor.
# NOTE(review): the extra [0] on inputVars assumes the input tensor has a leading
# batch dimension of size 1. outputVars keeps that dimension, so if the output
# tensor is (1, n_classes), a single class variable is outputVars[0][k] — confirm.
inputVars = network.inputVars[0][0]
outputVars = network.outputVars[0]

# Positions of the fatigue levels in the network's output layer
low_fatigue_idx, medium_fatigue_idx, high_fatigue_idx = 0, 1, 2

In [5]:
# Load the high fatigue statistics.
# NOTE(review): the last row is dropped from both 'mean' and 'std' — presumably it
# describes a non-input column (e.g. the label); confirm against the CSV.
high_fatigue = pd.read_csv('statistic_analysis/dataset_statistics_fatigue_level_2.csv')
mean_values = high_fatigue['mean'].iloc[:-1].to_numpy()

# Initial half-width of the input box around each mean. It has one entry per
# feature, derived from mean_values so the two can never disagree in length.
initial_range = [0.01] * len(mean_values)
# Amount added to each feature's half-width at every iteration: 10% of its std
step_size = high_fatigue['std'].iloc[:-1].to_numpy() * 0.1

# Create the options for the Marabou solver
options = Marabou.createOptions(numWorkers=20, initialTimeout=5, initialSplits=100, onlineSplits=100,
                                    timeoutInSeconds=1800, timeoutFactor=1.5,
                                    verbosity=2, snc=True, splittingStrategy='auto',
                                    sncSplittingStrategy='auto', restoreTreeStates=False,
                                    splitThreshold=20, solveWithMILP=True, dumpBounds=True)

In [None]:
# The iterative process: widen the input box around the high-fatigue mean until
# Marabou finds an input that the network classifies as the desired class (SAT).
# Work on a copy so re-running this cell restarts from the configured
# initial_range instead of the range left over from a previous run.
search_range = np.array(initial_range, dtype=float)
step = np.asarray(step_size, dtype=float)
desired_output_class = high_fatigue_idx  # High fatigue class index
last_unsat_range = None  # largest range proven UNSAT so far (None until one is found)

while True:
    # Re-read the network so bounds/inequalities added in the previous
    # iteration are discarded.
    network = Marabou.read_onnx(file_name)

    # Re-derive the variable indices from the fresh network. Flattening removes
    # any batch dimension: with an output tensor of shape (1, n_classes),
    # len(network.outputVars[0]) is 1, which would constrain only class 0.
    in_vars = np.asarray(network.inputVars[0]).flatten()
    out_vars = np.asarray(network.outputVars[0]).flatten()
    assert len(in_vars) == len(mean_values) == len(search_range), (
        f"size mismatch: {len(in_vars)} inputs, {len(mean_values)} means, "
        f"{len(search_range)} ranges")

    # Constrain every input feature to [mean - r, mean + r]
    for var, mean_val, radius in zip(in_vars, mean_values, search_range):
        network.setLowerBound(var, mean_val - radius)
        network.setUpperBound(var, mean_val + radius)

    # Output property: the desired class's output is >= every other output,
    # encoded as out_i - out_desired <= 0 for each i != desired.
    for i, out_var in enumerate(out_vars):
        if i != desired_output_class:
            network.addInequality([out_var, out_vars[desired_output_class]], [1, -1], 0)

    # solve() returns the exit code ("sat", "unsat", "TIMEOUT", ...) first
    exit_code = network.solve(verbose=True, options=options)[0]

    if exit_code == "unsat":
        # No input in this box yields the desired class: remember it and widen
        last_unsat_range = search_range.copy()
        search_range = search_range + step
    else:
        break

if exit_code != "sat":
    # TIMEOUT / ERROR / UNKNOWN: the current range was neither proven SAT nor UNSAT
    print(f"Search stopped with solver result {exit_code!r}; "
          f"the current range was not proven SAT.")

if last_unsat_range is None:
    print("The initial range was not proven UNSAT; no UNSAT range was found.")
else:
    # The last range proven UNSAT before the search stopped
    print(f"Largest verified UNSAT range: {last_unsat_range}")


unsat
unsat
Engine::processInputQuery: Input query (before preprocessing): 196 equations, 453 variables
Engine::processInputQuery: Input query (before preprocessing): 196 equations, 453 variables
Engine::processInputQuery: Input query (after preprocessing): 391 equations, 476 variables

Input bounds:
	x0: [ -4.1704,   4.7455] 
	x1: [951.5378, 1128.2606] 
	x2: [950.1608, 1126.7233] 
	x3: [3344.3706, 3843.5994] 
	x4: [ -4.9346,   2.2596] 
	x5: [330.5839, 374.7581] 
	x6: [334.4529, 378.2972] 
	x7: [1294.9784, 1438.9421] 
	x8: [ -1.8636,   4.1583] 
	x9: [270.1266, 303.7708] 
	x10: [272.4738, 306.0258] 
	x11: [1092.0954, 1198.4870] 
	x12: [ -5.0486,  -3.1618] 
	x13: [ 26.8174,  28.5702] 
	x14: [ 28.0573,  30.1262] 
	x15: [115.0926, 122.0011] 
	x16: [ -6.4110,  -4.8105] 
	x17: [ 34.3732,  37.3169] 
	x18: [ 35.3163,  38.3657] 
	x19: [139.0480, 150.2032] 
	x20: [  7.8436,   9.9187] 
	x21: [ 37.1313,  40.1064] 
	x22: [ 38.6336,  41.9146] 
	x23: [149.8892, 159.9233] 
	x24: [ -0.7464,   0.5029] 
