In [1]:
import onnx
from onnx2pytorch import ConvertModel
from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationLpNorm

from vnnlib.compat import read_vnnlib_simple

import torch
import numpy as np

from collections import OrderedDict

import csv

import os
from timeit import default_timer as timer

  from .autonotebook import tqdm as notebook_tqdm


# Loading ONNX and VNNLib Specifications

In [2]:
def load_onnx_model(onnx_path, input_shape):
    onnx_model = onnx.load(onnx_path)
    torch_model = ConvertModel(onnx_model)
    
    x_concrete = torch.zeros(input_shape)
    model = BoundedModule(torch_model, x_concrete)
    return model

In [3]:
def load_vnnlib_spec(vnnlib_path, input_shape, n_out):
    n_in = np.prod(input_shape)
    res = read_vnnlib_simple(vnnlib_path, n_in, n_out)
    bnds, spec = res[0]
    
    bnds = np.array(bnds)
    lbs = bnds[:,0]
    ubs = bnds[:,1]
    
    data_min = torch.tensor(lbs, dtype=torch.float32).reshape(input_shape)
    data_max = torch.tensor(ubs, dtype=torch.float32).reshape(input_shape)
    center = 0.5*(data_min + data_max)

    ptb = PerturbationLpNorm(x_L=data_min, x_U=data_max)
    x = BoundedTensor(center, ptb)
    
    return x

In [74]:
onnx_path = '../../../NNPoly.jl/eval/cifar10/onnx/cifar_relu_6_100.onnx'
vnnlib_path = '../../../NNPoly.jl/eval/cifar10/vnnlib/prop_7_spiral_1.vnnlib'

In [4]:
onnx_path = 'example_specs/mnist-net_256x4.onnx'
vnnlib_path = 'example_specs/prop_0_spiral_25.vnnlib'

In [75]:
model = load_onnx_model(onnx_path, [1,3,32,32])
x = load_vnnlib_spec(vnnlib_path, [1,3,32,32], 10)

In [7]:
model(torch.zeros(1,3,32,32))

tensor([[  0.0000,  15.7551,   0.0000, 148.0318,   0.0000,   0.0000,  34.0576,
           0.0000,  48.0111,  19.1834]], grad_fn=<ReluBackward0>)

In [50]:
model(x.data.reshape(-1).reshape(32, 32, 3).transpose(0,2))

tensor([[14.4844,  2.5041,  0.0000, 35.9410, 23.0315,  0.0000,  0.0000, 73.9405,
          4.5554, 35.4615]], grad_fn=<ReluBackward0>)

In [53]:
sub_tensor = torch.tensor([0.49140000343322754, 0.482200026512146, 0.4465000033378601]).reshape(1,3,1,1)
div_tensor = torch.tensor([0.20230001211166382, 0.19940000772476196, 0.20100000500679016]).reshape(1,3,1,1)
model((x.data.reshape(-1).reshape(32, 32, 3).transpose(0,2) - sub_tensor) / div_tensor)

tensor([[14.4844,  2.5041,  0.0000, 35.9409, 23.0315,  0.0000,  0.0000, 73.9405,
          4.5554, 35.4615]], grad_fn=<ReluBackward0>)

In [65]:
model(x.data.reshape(-1).reshape(1, 32, 32, 3).permute(0, 3, 1, 2))

tensor([[  0.0000,  61.7497,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000, 151.8912]], grad_fn=<ReluBackward0>)

In [64]:
model(x.data.permute(0, 3, 1, 2))

tensor([[0.0000e+00, 1.1229e+01, 1.0995e+02, 4.6737e+01, 5.2612e-02, 0.0000e+00,
         8.6105e+01, 0.0000e+00, 0.0000e+00, 3.6859e+01]],
       grad_fn=<ReluBackward0>)

In [67]:
torch.arange(1, 3073).reshape(1, 32, 32, 3).permute(0, 3, 1, 2).reshape(-1)

tensor([   1,    4,    7,  ..., 3066, 3069, 3072])

In [72]:
torch.arange(1, 3073).reshape(1, 32, 32, 3).permute(0, 3, 1, 2)[:,:,0,0]

tensor([[1, 2, 3]])

In [78]:
x = torch.tensor([235,235,235,231,231,231,232,232,232,232,232,232,232,232,232,232,232,232,232,232,232,232,232,232,232,232,232,232,232,232,233,233,233,233,233,233,233,233,233,233,233,233,233,233,233,233,232,233,233,231,233,232,231,233,231,233,233,230,233,232,232,232,234,232,231,234,232,232,232,233,233,230,232,233,231,233,233,233,232,232,232,232,232,232,232,232,232,233,233,233,233,233,233,232,232,232,238,238,238,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,236,236,236,236,236,236,236,236,236,236,236,236,236,236,236,236,236,236,237,234,233,236,234,233,236,236,234,234,236,234,234,235,237,234,234,238,235,236,237,236,236,235,236,236,234,236,236,236,235,235,235,235,235,235,235,235,235,236,236,236,236,236,236,235,235,235,237,237,237,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,235,235,235,235,234,234,236,233,231,236,234,231,235,235,234,234,235,236,227,230,233,231,235,238,231,233,235,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,235,235,235,235,235,235,234,234,234,238,238,238,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,234,235,235,235,235,235,234,233,233,230,232,232,231,228,230,232,223,226,231,186,192,197,209,216,219,207,210,213,228,228,230,236,235,235,234,234,234,234,234,234,234,234,234,234,234,234,235,235,235,235,235,235,235,235,235,237,237,237,234,234,234,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,235,234,234,234,234,234,234,235,235,235,235,235,235,234,234,234,234,234,234,235,235,235,235,235,235,236,238,236,233,237,237,219,225,230,203,210,219,163,172,179,195,205,208,214,218,221,230,229,232,237,235,237,235,235,235,235,235,235,235,235,235,235,236,236,236,236,236,236,236,236,236,236,236,239,239,238,236,235,235,236,235,235,236,235,235,236,235,235,236,235,235,235,236,235,235,235,235,234,234,234,235,235,235,237,236,236,237,236,236,234,235,236,232,233,234,235,237,237,229,231,232,208,216,218,194,205,210,185,198,207,174,188,200,165,179,189,184,196,202,207,215,220,226,228,232,236,235,237,236,236,235,236,236,235,236,236,235,236,236,236,237,237,237,237,237,237,237,237,237,228,229,229,228,227,228,232,230,231,231,228,230,234,232,233,237,236,236,237,237,235,236,237,235,237,235,236,237,235,236,239,236,237,239,237,238,225,229,230,224,228,229,233,237,238,221,226,228,183,197,204,161,180,190,159,180,191,154,176,190,144,163,177,143,159,171,156,169,177,198,206,211,233,238,239,236,237,234,235,236,233,235,235,235,235,236,236,236,238,237,237,237,237,239,237,238,212,220,222,224,230,233,230,234,238,227,232,234,229,234,234,234,237,236,237,238,235,238,237,236,239,237,238,239,237,238,239,236,237,240,238,239,201,204,203,219,222,221,233,236,235,214,218,218,193,204,210,185,201,210,184,201,211,173,191,203,165,182,196,159,174,187,162,176,185,186,199,204,229,239,240,234,239,238,233,238,237,233,238,238,234,239,238,236,239,238,237,239,238,238,238,238,216,234,241,221,236,243,225,238,246,225,239,243,227,240,240,231,238,237,236,237,235,238,236,235,238,236,237,238,236,237,237,237,237,239,239,239,197,198,196,220,221,218,233,234,231,230,231,229,209,213,217,209,216,222,219,228,235,208,218,227,209,221,234,210,224,235,217,233,240,218,235,241,225,240,243,228,238,240,228,239,240,230,240,240,230,240,239,235,240,239,237,240,239,238,238,238,118,140,149,119,138,148,124,142,153,136,155,161,172,188,191,225,234,233,235,236,233,237,234,232,236,233,234,235,235,235,235,237,236,233,237,235,214,216,214,226,228,226,232,234,232,236,237,236,228,230,232,227,230,235,231,236,241,225,232,239,225,237,247,217,233,243,201,219,226,185,204,211,172,189,195,167,179,186,167,180,185,186,199,201,223,235,235,235,241,239,236,240,239,238,240,239,109,130,141,103,121,133,108,125,137,111,127,137,146,159,165,222,229,231,227,228,225,229,226,224,236,232,233,234,234,234,231,236,234,230,237,235,229,234,235,231,235,236,232,237,238,230,235,236,231,236,238,231,237,240,229,237,241,223,232,238,191,206,213,164,184,191,146,165,172,137,156,163,134,149,159,128,140,153,121,133,143,149,162,166,216,228,229,234,241,239,235,240,238,237,240,239,195,212,224,188,202,215,199,211,224,200,211,223,209,217,227,223,227,231,213,213,211,211,209,206,216,213,214,220,222,222,219,226,225,210,221,219,209,219,223,211,221,225,216,225,230,220,229,233,225,234,237,226,236,239,225,237,241,218,231,237,183,204,208,175,198,203,181,200,207,178,194,202,186,197,211,170,178,196,142,151,164,185,195,202,219,230,233,231,240,238,234,241,239,236,240,239,193,207,222,191,202,217,202,211,224,214,217,234,223,225,241,214,219,227,203,208,208,171,174,174,177,180,183,207,213,214,174,184,188,98,112,121,93,114,126,101,121,132,111,129,139,122,138,147,137,152,161,153,167,174,202,216,220,223,236,237,218,232,235,220,233,238,223,234,240,217,226,233,221,228,237,212,219,229,196,203,212,222,230,237,219,227,234,221,230,233,232,239,242,235,241,242,113,130,152,111,125,147,113,125,141,125,131,151,138,145,165,170,182,193,191,201,205,190,199,204,208,219,226,216,230,234,158,172,183,54,71,92,45,70,91,49,73,91,53,73,90,66,84,98,102,114,129,159,168,179,221,227,233,234,239,241,233,237,241,227,231,237,223,228,233,207,211,217,202,208,212,211,218,220,212,219,223,199,206,214,179,186,196,188,197,205,211,221,227,221,231,234,61,81,108,69,86,114,63,79,100,68,85,102,123,141,155,139,155,164,151,157,164,195,200,207,214,228,234,206,223,228,163,180,190,103,121,138,95,112,131,101,117,135,138,151,168,181,192,207,207,212,223,221,222,232,219,219,227,205,203,212,183,186,195,158,166,174,147,154,163,131,138,147,125,133,140,130,139,144,136,146,152,133,142,151,128,137,147,138,153,160,182,197,203,197,212,216,40,53,77,58,70,94,85,98,116,127,144,153,132,151,156,96,107,110,119,115,118,163,158,161,173,180,182,184,194,197,182,194,198,181,193,200,183,194,202,198,209,217,218,228,236,200,210,217,174,181,186,159,165,172,145,150,159,132,136,149,116,125,138,98,111,123,94,106,118,99,111,123,105,118,128,107,121,130,122,135,145,138,151,161,150,164,174,157,174,184,188,206,213,185,203,208,13,15,35,26,29,47,134,140,151,206,216,220,138,150,150,118,123,123,141,133,134,172,162,162,181,181,180,207,209,211,220,224,225,228,234,233,224,234,232,230,241,240,226,238,238,176,189,190,144,159,163,138,154,162,142,158,170,145,163,177,154,171,187,149,165,182,149,165,182,154,171,187,157,174,189,160,177,191,173,190,204,187,204,217,190,207,218,178,196,208,165,183,193,157,175,183,5,5,24,58,62,79,200,207,217,225,232,239,197,205,212,199,207,211,212,212,218,226,224,229,229,230,237,233,236,246,232,238,245,230,238,239,209,221,220,223,238,239,221,238,241,210,228,234,198,217,228,180,200,214,193,216,230,188,213,229,189,212,231,194,214,234,192,212,232,184,204,224,172,193,212,171,191,209,161,181,197,144,165,179,136,156,169,131,146,161,128,143,158,138,154,165,39,45,71,145,155,179,190,204,222,186,196,216,184,197,217,192,211,229,194,211,230,194,208,227,194,206,227,191,203,228,192,207,228,190,207,221,177,193,207,180,198,215,154,176,193,147,169,188,145,161,184,156,171,195,146,163,186,113,133,156,114,137,161,132,157,180,126,150,173,111,135,158,92,115,138,91,112,135,93,114,133,94,116,131,105,125,140,121,133,151,129,141,158,129,142,156,122,135,161,162,179,207,143,160,194,137,154,189,131,152,187,128,152,190,127,150,192,130,150,193,131,150,192,128,147,190,127,147,189,129,149,189,129,149,188,124,145,186,104,126,163,100,122,154,102,120,154,118,134,170,112,128,163,94,109,145,94,112,148,94,117,153,87,112,144,83,103,136,80,97,130,83,103,134,93,111,139,101,117,141,108,121,144,115,125,146,121,133,148,130,144,156,73,87,109,76,90,113,77,90,122,80,93,127,84,98,134,87,102,142,87,102,147,90,105,150,94,111,152,102,119,160,107,124,165,113,131,172,115,137,181,118,136,186,118,132,180,120,133,175,115,136,172,110,133,168,106,127,163,100,119,155,95,109,148,85,101,139,79,97,132,80,92,127,80,94,129,77,100,133,80,100,129,82,98,122,92,104,126,113,119,138,125,135,146,136,149,156,13,25,41,3,11,25,9,16,35,18,26,48,18,26,52,21,25,56,20,25,58,22,30,61,26,36,62,34,43,70,42,51,77,48,59,87,52,69,106,60,75,121,66,77,126,70,79,126,71,87,127,72,88,126,67,81,120,60,72,112,55,67,106,53,68,104,53,69,103,57,69,102,57,71,105,57,78,110,72,89,115,87,100,119,104,113,128,120,124,136,130,136,141,137,146,149,36,46,55,11,16,20,8,13,19,32,44,53,36,45,58,22,25,41,8,11,30,3,8,24,1,4,17,0,2,15,0,2,15,0,4,20,6,13,42,5,18,56,1,19,60,3,23,62,13,29,71,24,38,81,21,33,77,21,31,76,21,38,78,22,44,79,30,50,83,39,58,90,57,70,101,85,90,118,113,115,138,123,123,138,116,115,125,122,123,128,134,139,137,153,160,158,35,41,45,26,27,26,13,19,18,27,41,41,71,81,84,70,70,76,49,50,57,27,31,37,15,15,21,5,5,11,2,2,7,0,0,7,17,17,35,57,64,91,31,50,78,10,36,62,4,30,60,4,30,62,7,30,63,14,35,69,25,43,74,41,55,83,62,71,99,86,97,123,122,124,146,144,131,149,132,120,135,114,105,114,117,111,116,132,134,133,146,152,146,172,179,175,16,15,17,13,10,9,4,10,8,3,12,11,45,44,46,65,52,57,54,43,47,36,33,35,18,18,20,4,4,7,2,2,4,0,1,3,7,8,15,118,117,134,161,158,179,131,128,148,112,112,131,105,105,125,105,103,124,109,105,127,118,107,126,138,115,133,154,126,144,151,126,141,127,106,116,105,86,91,106,94,97,120,116,116,129,130,129,142,147,144,164,172,165,184,194,190,40,40,35,12,10,7,0,3,3,0,4,4,12,6,7,30,12,17,32,12,17,21,10,12,7,6,7,2,1,3,2,1,2,3,2,3,0,0,2,68,58,64,182,128,146,205,130,148,196,127,144,194,123,141,195,119,137,187,113,129,172,110,122,150,96,106,123,75,83,103,66,69,95,71,70,104,93,88,122,118,113,129,132,126,132,141,135,152,162,158,171,182,176,185,197,194,69,77,64,26,29,21,1,1,1,1,1,2,4,1,0,12,2,5,18,3,9,12,2,5,4,1,2,2,0,0,2,0,0,4,0,1,1,1,1,32,12,11,153,45,59,203,47,68,195,46,67,191,48,69,179,50,67,155,49,59,119,42,49,91,38,42,81,48,46,94,77,71,117,110,102,125,126,116,125,128,120,129,135,128,144,153,147,162,176,171,173,187,183,184,198,196,83,94,82,47,52,43,1,1,1,2,1,2,2,0,0,5,1,2,7,1,5,4,0,2,1,0,0,1,0,0,1,0,0,3,0,0,1,2,0,27,3,2,142,25,38,205,32,54,198,25,46,169,25,43,121,25,36,85,29,34,74,41,39,85,66,56,102,92,82,121,113,105,128,124,115,122,126,115,121,127,118,132,139,131,147,157,150,165,179,174,176,191,187,186,201,199,92,102,93,54,60,50,6,7,3,3,2,1,2,2,0,1,3,1,1,3,3,1,2,2,1,1,1,1,0,0,1,0,0,1,1,1,0,3,2,15,1,0,102,19,28,157,31,47,117,17,23,74,13,12,56,27,22,74,58,55,99,90,81,115,115,99,122,126,111,124,124,112,123,123,113,125,130,119,128,135,126,136,145,137,148,159,151,162,176,171,177,192,188,188,202,201,87,99,89,43,51,37,19,23,11,11,12,4,8,10,2,5,11,4,2,10,4,2,7,2,3,4,1,3,4,1,3,4,1,2,3,2,0,6,6,4,5,2,42,13,13,71,21,24,53,27,25,57,50,41,80,77,62,113,98,82,132,113,101,134,126,113,123,126,112,116,125,111,120,128,115,131,138,126,139,148,137,143,154,145,156,168,161,169,184,179,182,197,193,188,202,201,82,96,82,46,57,36,36,44,22,31,35,17,27,30,15,22,28,15,17,26,13,16,23,12,18,21,12,19,21,13,20,22,14,19,23,15,19,27,20,23,31,21,37,40,27,64,55,45,87,70,67,104,88,81,116,102,85,128,112,88,139,121,105,131,122,110,117,122,107,115,127,112,123,133,119,131,139,127,139,149,138,148,160,151,159,172,164,174,189,183,185,200,196,187,202,200,85,101,83,62,75,48,58,67,38,55,61,37,51,56,35,47,53,33,46,53,34,48,55,38,49,55,40,51,56,41,53,58,44,55,62,46,59,67,45,68,71,48,81,84,59,104,96,74,116,103,83,127,109,92,133,116,97,127,121,97,127,127,107,118,124,106,114,125,108,122,131,117,129,136,123,136,145,133,141,152,141,149,162,153,158,171,163,168,183,178,180,195,191,186,200,199])
x = x / 255
x = x.reshape(1, 32, 32, 3).permute(0, 3, 1, 2)

y = model(x)
print(y)
print(torch.argmax(y))

tensor([[33.7842, 79.7529,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         44.3727, 52.3093]], grad_fn=<ReluBackward0>)
tensor(1)


In [41]:
model(x.data)

tensor([[  0.0000, 194.2841,  11.4353,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,  15.2122,  73.4646]], grad_fn=<ReluBackward0>)

In [18]:
model.nodes['/31']

TypeError: 'method' object is not subscriptable

In [31]:
layers[1]

BoundBuffers(name=/15, inputs=[], perturbed=False)

BoundSub(name=/31, inputs=[/0, /15], perturbed=False)

In [24]:
layers = [l for l in model.nodes()]
layers

[BoundInput(name=/0, inputs=[], perturbed=False),
 BoundBuffers(name=/15, inputs=[], perturbed=False),
 BoundBuffers(name=/16, inputs=[], perturbed=False),
 BoundParams(name=/17, inputs=[], perturbed=False),
 BoundParams(name=/18, inputs=[], perturbed=False),
 BoundParams(name=/19, inputs=[], perturbed=False),
 BoundParams(name=/20, inputs=[], perturbed=False),
 BoundParams(name=/21, inputs=[], perturbed=False),
 BoundParams(name=/22, inputs=[], perturbed=False),
 BoundParams(name=/23, inputs=[], perturbed=False),
 BoundParams(name=/24, inputs=[], perturbed=False),
 BoundParams(name=/25, inputs=[], perturbed=False),
 BoundParams(name=/26, inputs=[], perturbed=False),
 BoundParams(name=/27, inputs=[], perturbed=False),
 BoundParams(name=/28, inputs=[], perturbed=False),
 BoundParams(name=/29, inputs=[], perturbed=False),
 BoundParams(name=/30, inputs=[], perturbed=False),
 BoundSub(name=/31, inputs=[/0, /15], perturbed=False),
 BoundFlatten(name=/33, inputs=[/32/mul], perturbed=False),


In [5]:
model = load_onnx_model(onnx_path, [1,1,1,784])
x = load_vnnlib_spec(vnnlib_path, [1,1,1,784], 10)

  layer.weight.data = torch.from_numpy(numpy_helper.to_array(weight))
  if not self.experimental and inputs[0].shape[self.batch_dim] > 1:


# Helper Methods

In [6]:
def get_layers(model):
    return [l for l in model.nodes() if l.perturbed]

In [7]:
def get_intermediate_bounds(model):
    """
    Returns a dictionary containing the concrete lower and upper bounds of each layer.
    
    Implemented own method to filter out bounds for weight matrices.
    
    Only call this method after compute_bounds()!
    """
    od = OrderedDict()
    for l in get_layers(model):
        if hasattr(l, 'lower'):
            od[l.name] = (l.lower, l.upper)
            
    return od

# Get Intermediate Bounds

In [8]:
model.compute_bounds(x=(x,), method='ibp')

(tensor([[-447.2314, -373.2225, -492.4267, -520.7729, -479.3799, -425.8518,
          -413.6321, -447.6174, -532.3359, -408.0772]], grad_fn=<AddBackward0>),
 tensor([[356.4533, 318.0782, 320.3526, 274.7251, 371.1610, 338.5822, 395.8803,
          377.8232, 375.9247, 400.2977]], grad_fn=<AddBackward0>))

In [9]:
bounds_dict = get_intermediate_bounds(model)
for k, (lbs, ubs) in bounds_dict.items():
    print(f"{k}: {lbs.shape}")

/0: torch.Size([1, 1, 1, 784])
/21: torch.Size([1, 784])
/input: torch.Size([1, 256])
/23: torch.Size([1, 256])
/input.3: torch.Size([1, 256])
/25: torch.Size([1, 256])
/input.7: torch.Size([1, 256])
/27: torch.Size([1, 256])
/input.11: torch.Size([1, 256])
/29: torch.Size([1, 256])
/30: torch.Size([1, 10])


In [10]:
bounds_dict['/30']

(tensor([[-447.2314, -373.2225, -492.4267, -520.7729, -479.3799, -425.8518,
          -413.6321, -447.6174, -532.3359, -408.0772]], grad_fn=<AddBackward0>),
 tensor([[356.4533, 318.0782, 320.3526, 274.7251, 371.1610, 338.5822, 395.8803,
          377.8232, 375.9247, 400.2977]], grad_fn=<AddBackward0>))

In [11]:
model.compute_bounds(x=(x,), method='crown')

(tensor([[-0.4672, -0.4103, -0.4026, -0.3573, -0.4667, -0.4770, -0.3886, -0.4031,
          -0.5523, -0.4308]], grad_fn=<ViewBackward0>),
 tensor([[0.6201, 0.4692, 0.6307, 0.4833, 1.8194, 0.5396, 0.9479, 0.4172, 0.6483,
          0.6467]], grad_fn=<ViewBackward0>))

In [13]:
bounds_dict['/30']

(tensor([[-447.2314, -373.2225, -492.4267, -520.7729, -479.3799, -425.8518,
          -413.6321, -447.6174, -532.3359, -408.0772]], grad_fn=<AddBackward0>),
 tensor([[356.4533, 318.0782, 320.3526, 274.7251, 371.1610, 338.5822, 395.8803,
          377.8232, 375.9247, 400.2977]], grad_fn=<AddBackward0>))

In [14]:
bounds_dict_crown = get_intermediate_bounds(model)
bounds_dict_crown['/30']

(tensor([[-0.4672, -0.4103, -0.4026, -0.3573, -0.4667, -0.4770, -0.3886, -0.4031,
          -0.5523, -0.4308]], grad_fn=<ViewBackward0>),
 tensor([[0.6201, 0.4692, 0.6307, 0.4833, 1.8194, 0.5396, 0.9479, 0.4172, 0.6483,
          0.6467]], grad_fn=<ViewBackward0>))

**Attention**: CROWN-bounds are only saved for pre-activation nodes and the output!
(in contrast to interval propagation bounds, that are saved for every layer)

In [15]:
bounds_dict_crown.keys()

odict_keys(['/0', '/input', '/input.3', '/input.7', '/input.11', '/30'])

In [16]:
lbs11_ibp, ubs11_ibp = bounds_dict['/input.11']
lbs11_crown, ubs11_crown = bounds_dict_crown['/input.11']

print(lbs11_ibp[:,:10])
print(lbs11_crown[:,:10])

tensor([[-171.4623, -160.4540, -304.7760, -128.2996, -153.7493, -117.9111,
         -123.5061, -278.8083, -158.6687, -207.4034]],
       grad_fn=<SliceBackward0>)
tensor([[-1.8260, -3.1913, -2.1607, -2.8250, -4.0157, -2.1657, -0.7982, -1.6259,
         -1.0763, -5.0089]], grad_fn=<SliceBackward0>)


# Sampling via CROWN

In order to use CROWN to calculate bounds for the sampled directions, we make use of the possibility to supply
- a constraint matrix (which we use to represent the sampled directions) and
- to specify the output layer (which we just set to the layer, for which we want to sample)

The shape of the constraint matrix is `(batch, n_directions, n_neurons)`, where we just set `batch=1`.

The output layer is specified via the node names in the node dictionary.

In [17]:
n_batch = 1
n_dirs = 5
n_neurons = 256
C = torch.randn(n_batch, n_dirs, n_neurons)

model.compute_bounds(x=(x,), final_node_name='/input.11', C=C, method='crown')

(tensor([[ -3.5966,   3.4826,   0.1199,  -2.3320, -73.6734]],
        grad_fn=<ViewBackward0>),
 tensor([[ 13.9218, 105.1290,  19.2300,  36.2667,   3.3111]],
        grad_fn=<ViewBackward0>))

We can also use $\alpha$-CROWN to optimize the bounds of the directions.

In [18]:
model.compute_bounds(x=(x,), final_node_name='/input.11', C=C, method='alpha-crown')

(tensor([[ -2.8655,  18.5538,   0.2205,   5.4271, -68.6798]]),
 tensor([[ 4.0163, 98.9213,  7.0998, 34.7713, -8.0158]]))

When using more iterations, the bounds may get slightly better.

In [19]:
model.bound_opts

{'conv_mode': 'patches',
 'sparse_intermediate_bounds': True,
 'sparse_conv_intermediate_bounds': True,
 'sparse_intermediate_bounds_with_ibp': True,
 'sparse_features_alpha': True,
 'sparse_spec_alpha': True,
 'minimum_sparsity': 0.9,
 'enable_opt_interm_bounds': False,
 'crown_batch_size': inf,
 'forward_refinement': False,
 'dynamic_forward': False,
 'forward_max_dim': 1000000000,
 'use_full_conv_alpha': True,
 'disabled_optimization': [],
 'use_full_conv_alpha_thresh': 512,
 'verbosity': 0,
 'optimize_graph': {'optimizer': None},
 'optimize_bound_args': {'enable_alpha_crown': True,
  'enable_beta_crown': False,
  'apply_output_constraints_to': None,
  'iteration': 20,
  'use_shared_alpha': False,
  'optimizer': 'adam',
  'keep_best': True,
  'fix_interm_bounds': True,
  'lr_alpha': 0.5,
  'lr_beta': 0.05,
  'lr_cut_beta': 0.005,
  'init_alpha': True,
  'lr_coeffs': 0.01,
  'intermediate_refinement_layers': [-1],
  'loss_reduction_func': <function auto_LiRPA.utils.<lambda>(x)>,
  's

In [20]:
def set_params(model, use_shared_alpha=False, iteration=20, early_stop_patience=10):
    model.bound_opts['optimize_bound_args']['use_shared_alpha'] = use_shared_alpha
    model.bound_opts['optimize_bound_args']['iteration'] = iteration
    model.bound_opts['optimize_bound_args']['early_stop_patience'] = early_stop_patience

In [21]:
set_params(model, iteration=100)
model.compute_bounds(x=(x,), final_node_name='/input.11', C=C, method='alpha-crown')

(tensor([[ -2.8637,  18.6105,   0.2209,   5.4440, -68.6639]]),
 tensor([[ 3.9676, 98.8920,  7.0489, 34.7632, -8.0771]]))

# Converting Bounds to Points

CROWN only gives us **bounds** for linear combinations of neuron inputs that we specified in the matrix `C`. 
However, we need **points** - not the bounds.
Therefore, we also save the parameters of the backsubstituted inequalities and obtain the points in the input space that maximize/minimize the corresponding linear inequalities.
These maximizers/minimizers are then substituted into the inequalities for the neuron-inputs.

To obtain the coefficients of the inequalities, we need to set 
- `return_A = True` and also specify which coefficients we need by setting
- `needed_A_dict = {<layer_i> : [<layer_j>, <layer_k>]}` which will return the matrix of coefficients when substituting back from `layer_i` to `layer_j` and the matrix when substituting back from `layer_i` to `layer_k`

In [77]:
lbs11crown, ubs11crown, A_dict = model.compute_bounds(x=(x,), final_node_name='/input.11', method='alpha-crown', return_A=True, needed_A_dict={'/input.11': ['/0']})

The stored coefficients have shape `(batch, spec, *input_size)`. 

So if the input had shape `(1,1,1,784)` (which is `(batch, input_size1, input_size2, input_size3)`) and the matrix of specifications had `5` inequalities, the stored coefficients will have shape `(1, 5, 1, 1, 784)` (which is `(batch, spec, input_size1, input_size2, input_size3)`.

If no matrix of specifications is given, the `spec` dimension just has the shape of the layer (i.e. if it has 256 neurons, we have `spec = 256`)

In [78]:
lA_neurons = A_dict['/input.11']['/0']['lA']
lb_neurons = A_dict['/input.11']['/0']['lbias']

uA_neurons = A_dict['/input.11']['/0']['uA']
ub_neurons = A_dict['/input.11']['/0']['ubias']

print("A.shape = ", lA_neurons.shape)

A.shape =  torch.Size([1, 256, 1, 1, 784])


After obtaining the coefficients for the **neurons**, we now obtain the coefficients for the sampled **directions**.

In [62]:
lbs11crown, ubs11crown, A_dict = model.compute_bounds(x=(x,), final_node_name='/input.11', C=C, method='alpha-crown', return_A=True, needed_A_dict={'/input.11': ['/0']})

In [64]:
lA = A_dict['/input.11']['/0']['lA']
lb = A_dict['/input.11']['/0']['lbias']

uA = A_dict['/input.11']['/0']['uA']
ub = A_dict['/input.11']['/0']['ubias']

print("A.shape = ", lA.shape)

A.shape =  torch.Size([1, 5, 1, 1, 784])


In [65]:
def flatten2matrix(A_tensor, batch_id=0):
    """
    Returns matrix of shape (spec, input_flat) corresponding to batch_id of the A_tensor.
    """
    # reshape to (batch, spec, input_flat), then take specific batch (0th batch)
    A_mat = A_tensor.reshape(A_tensor.shape[0], A_tensor.shape[1], -1)[batch_id, :]    
    return A_mat

def flatten2vector(b_tensor, batch_id=0):
    """
    Returns flat input vector corresponding to the specified batch.
    """
    b_vec = b_tensor.reshape(b_tensor.shape[0], -1)[batch_id, :]
    return b_vec

The following cell is just to demonstrate that we get the same bounds as for the specification, if we calculate them by hand.
(minor differences are expected due to randomization in the SGD procedure)

In [79]:
lA_mat = flatten2matrix(lA)
lb_vec = flatten2vector(lb)

uA_mat = flatten2matrix(uA)
ub_vec = flatten2vector(ub)

x_L_vec = flatten2vector(x.ptb.x_L)
x_U_vec = flatten2vector(x.ptb.x_U)

lA_neg = torch.minimum(torch.zeros(1), lA_mat)
lA_pos = torch.maximum(torch.zeros(1), lA_mat)

uA_neg = torch.minimum(torch.zeros(1), uA_mat)
uA_pos = torch.maximum(torch.zeros(1), uA_mat)

lo = lA_pos.matmul(x_L_vec) + lA_neg.matmul(x_U_vec) + lb_vec

print(lo)

tensor([ -2.8642,  18.5939,   0.2207,   5.4452, -68.6601])


Now get the inputs that are used to calculate the lower bound of the output.

For computation of lower bounds:
- if the coefficient of input $x_i$ is negative, we take the *upper* bound of $x_i$
- if the coefficient of input $x_i$ is positive, we take the *lower* bound of the input

to calculate the lower bound of the specification output

Vice-versa for computation of the upper bound on the specification output.

In [81]:
x_lo = x_L_vec * (lA_mat > 0) + x_U_vec * (lA_mat < 0)
x_up = x_U_vec * (uA_mat > 0) + x_L_vec * (uA_mat < 0)

Then calculate bounds of the neuron inputs using these obtained values of the $x_i$

In [82]:
lA_neurons_mat = flatten2matrix(lA_neurons)
lb_neurons_vec = flatten2vector(lb_neurons)

uA_neurons_mat = flatten2matrix(uA_neurons)
ub_neurons_vec = flatten2vector(ub_neurons)

In [83]:
l_neurons_subs = x_lo.matmul(lA_neurons_mat.T) + lb_neurons_vec
u_neurons_subs = x_lo.matmul(uA_neurons_mat.T) + ub_neurons_vec

In [84]:
l_neurons_subs

tensor([[-0.8823, -1.9052, -1.3219,  ..., -6.1794, -2.7905, -0.5578],
        [-0.8823, -1.9052, -1.3219,  ..., -6.1794, -2.7905, -0.5578],
        [-0.8823, -1.9052, -1.3219,  ..., -6.1794, -2.7905, -0.5578],
        [-0.8807, -1.8992, -1.3147,  ..., -6.1559, -2.7797, -0.5579],
        [-0.9500, -2.4437, -1.9653,  ..., -8.3334, -3.7501, -0.4112]])

In [85]:
u_neurons_subs

tensor([[-0.3018, -0.5380,  0.0490,  ..., -1.9711, -0.6990, -0.1242],
        [-0.3018, -0.5380,  0.0490,  ..., -1.9711, -0.6990, -0.1242],
        [-0.3018, -0.5380,  0.0490,  ..., -1.9711, -0.6990, -0.1242],
        [-0.3014, -0.5367,  0.0505,  ..., -1.9666, -0.6968, -0.1245],
        [-0.4520, -1.0317, -0.4799,  ..., -3.6531, -1.4987, -0.1012]])

In [101]:
C_mat = flatten2matrix(C)

# multiply each row of the specification (i.e. each direction) with each column of l_neurons_subs.T
# (i.e. the rows of l_neurons_subs, which are the minimizers of the respective directions).
torch.maximum(torch.zeros(1), C_mat).matmul(l_neurons_subs.T) + torch.minimum(torch.zeros(1), C_mat).matmul(u_neurons_subs.T)

tensor([[-154.4405, -154.4405, -154.4405, -153.7955, -165.0589],
        [-122.0603, -122.0603, -122.0603, -121.5373, -109.7246],
        [-171.8250, -171.8250, -171.8250, -171.0766, -185.8952],
        [-154.2048, -154.2048, -154.2048, -153.5711, -156.8310],
        [-200.7205, -200.7205, -200.7205, -199.8855, -230.7391]])

In [97]:
C_mat.shape

torch.Size([5, 256])

In [98]:
l_neurons_subs.shape

torch.Size([5, 256])

In [56]:
model.roots()[0].perturbation.concretize(x, A, sign=-1) + A_dict['/input.11']['/0']['lbias']

tensor([[ -4.4035, -44.7540, -81.8630, -27.9624,   2.5924]])

# Building LP Model

There is at least some code available to build LP and MILP models, but it doesn't seem to be maintained/is broken now. Maybe we can repair and use it.

In [89]:
model.build_solver_module(model_type='lp')

AttributeError: 'function' object has no attribute 'CONTINUOUS'

In [86]:
import gurobipy as grb

In [88]:
model.model = grb.Model()

Set parameter Username
Academic license - for non-commercial use only - expires 2024-11-27
