In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import schedulefree

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

import sys

sys.path.append("../auto_LiRPA/")
import auto_LiRPA
from auto_LiRPA.operators.gurobi_maxpool_lp import compute_maxpool_bias

import time

  """


In [2]:
def normalize_bounds(l, u):
    """
    Takes bounds tensors and normalizes them to [0, 1], s.t. within every (batch, channel) window
    the smallest lower bound is mapped to 0 and the largest upper bound is mapped to 1.

    args:
        l (batch x channels x w x h) - concrete lower bounds
        u (batch x channels x w x h) - concrete upper bounds

    returns:
        l_norm (batch x channels x w x h) - normalized concrete lower bounds
        u_norm (batch x channels x w x h) - normalized concrete upper bounds

    note:
        If a window has umax == lmin the division below is a division by zero,
        so the corresponding entries come out as NaN/inf.
    """
    l_flat = l.flatten(-2)
    u_flat = u.flatten(-2)

    # keepdim=True leaves the extrema as (batch, channels, 1), so they broadcast
    # against the flattened (batch, channels, w*h) windows for any number of channels.
    lmin = l_flat.min(dim=-1, keepdim=True)[0]
    umax = u_flat.max(dim=-1, keepdim=True)[0]
    span = umax - lmin

    l_norm = ((l_flat - lmin) / span).reshape(l.shape)
    u_norm = ((u_flat - lmin) / span).reshape(u.shape)

    return l_norm, u_norm

In [3]:
def sort_by_lower_bound(X):
    """
    Reorders the spatial positions of every neuron so that its concrete lower bounds
    (channel 0) are ascending; all other channels are permuted with the same order.

    args:
        X (n_neurons x channels x w x h) - stacked tensor whose first channel holds the lower bounds

    returns:
        tensor of the same shape as X with the last two dims re-ordered per neuron
    """
    flat = X.flatten(-2)

    # Ascending order of the lower bounds, one permutation per neuron.
    order = flat[:, 0].sort(dim=-1).indices

    # Apply the same permutation to every channel of that neuron.
    order = order[:, None, :].expand(-1, X.size(1), -1)
    reordered = flat.gather(2, order)

    return reordered.view(X.shape)

In [4]:
def create_dataset(n_neurons, h, w):
    """
    Samples random, already normalized bounds and slopes for `n_neurons` windows of size h x w
    and computes the matching biases via `compute_maxpool_bias`.

    args:
        n_neurons - number of samples to draw
        h, w - spatial size of every sample

    returns:
        l (n_neurons x 1 x h x w) - normalized lower bounds
        u (n_neurons x 1 x h x w) - normalized upper bounds
        alpha (n_neurons x 1 x h x w) - slopes drawn uniformly from [0, 1)
        biases - result of compute_maxpool_bias(l, u, alpha)
    """
    # since the normalized version suffices, just stick to that
    sample_shape = (n_neurons, 1, h, w)
    first = torch.rand(sample_shape)
    second = torch.rand(sample_shape)

    # The elementwise smaller draw becomes the lower bound, the larger one the upper bound.
    lower = torch.minimum(first, second)
    upper = torch.maximum(first, second)
    lower, upper = normalize_bounds(lower, upper)

    slopes = torch.rand(sample_shape)

    return lower, upper, slopes, compute_maxpool_bias(lower, upper, slopes)

Each input tensor `X` of the dataset has shape `(n_neurons, 3, h, w)`, where
- `X[:,0,:,:]` represents the lower bounds
- `X[:,1,:,:]` represents the upper bounds
- `X[:,2,:,:]` represents the slopes

In [5]:
def create_tensor_dataset(n_neurons_train, n_neurons_val, h, w, sort_by_lb=True):
    """
    Builds a training and a validation TensorDataset from freshly sampled data.

    args:
        n_neurons_train - number of samples in the training split
        n_neurons_val - number of samples in the validation split
        h, w - spatial size of every sample
        sort_by_lb - if True, the inputs are passed through sort_by_lower_bound

    returns:
        dataset_train, dataset_val - TensorDatasets of (X, bias) pairs, where X stacks
        lower bounds, upper bounds and slopes along dim 1
    """
    def _build_split(n_neurons):
        # One split: sample the data, stack the three input channels, optionally sort.
        lower, upper, slopes, targets = create_dataset(n_neurons, h, w)
        features = torch.cat((lower, upper, slopes), dim=1)
        if sort_by_lb:
            features = sort_by_lower_bound(features)
        return TensorDataset(features, targets)

    # The training split is sampled first, then the validation split.
    dataset_train = _build_split(n_neurons_train)
    dataset_val = _build_split(n_neurons_val)

    return dataset_train, dataset_val

Create a small dataset to test training runs.

In [6]:
# Set to False to reuse the datasets stored on disk instead of sampling new ones.
CREATE_DATASET = True

if CREATE_DATASET:
    ds_train, ds_val = create_tensor_dataset(100, 10, 2, 2)
    torch.save(ds_train, './datasets/maxpool2x2_train_100.pth')
    torch.save(ds_val, './datasets/maxpool2x2_val_100.pth')
else:
    # The files contain pickled TensorDataset objects rather than plain tensors, so they
    # cannot be read with weights_only=True (the default of torch.load since PyTorch 2.6).
    # NOTE: full unpickling can execute arbitrary code -- only load files created by this notebook.
    ds_train = torch.load('./datasets/maxpool2x2_train_100.pth', weights_only=False)
    ds_val   = torch.load('./datasets/maxpool2x2_val_100.pth', weights_only=False)

100%|██████████| 100/100 [00:00<00:00, 115.34it/s]
100%|██████████| 10/10 [00:00<00:00, 127.90it/s]


In [7]:
# Mini-batch size shared by the training and the validation loader.
batch_size = 32

# Only the training samples are reshuffled every epoch; validation keeps its stored order.
train_dataloader = DataLoader(
    dataset=ds_train,
    batch_size=batch_size,
    shuffle=True,
)
val_dataloader = DataLoader(
    dataset=ds_val,
    batch_size=batch_size,
)

In [23]:
def train_loop(net, train_dataloader, val_dataloader, patience=10, num_epochs=100, timeout=60, lossfun='mse', opt='adam', l1_weight=0):
    """
    Trains `net` with early stopping and a wall-clock timeout and records per-epoch metrics.

    args:
        net - model to train (its parameters are updated in place)
        train_dataloader - yields (batch_X, batch_y) pairs used for the optimization steps
        val_dataloader - yields (batch_X, batch_y) pairs evaluated after every epoch
        patience - number of consecutive epochs without a new best validation loss after which training stops
        num_epochs - maximal number of epochs
        timeout - time budget in seconds, checked at the start of every epoch
        lossfun - 'mse' (nn.MSELoss) or 'mae' (nn.L1Loss)
        opt - 'adam' (optim.Adam) or 'schedulefree' (schedulefree.AdamWScheduleFree with lr=0.0025)
        l1_weight - weight of the L1 penalty (sum of absolute parameter values) added to the training loss

    returns:
        train_losses - per-epoch training loss, averaged over batches (includes the L1 penalty)
        val_losses - per-epoch validation loss, averaged over batches (criterion only, no L1 penalty)
        train_maes - per-epoch mean absolute training error, averaged over batches
        val_maes - per-epoch mean absolute validation error, averaged over batches
        val_maxs - per-epoch maximal absolute validation error
        best_net_state - independent copy of the state dict at the epoch with the lowest validation loss;
                         the initial state if no epoch ran or none improved on the initial best of infinity

    raises:
        ValueError - if `lossfun` or `opt` is not one of the supported names
    """
    if lossfun == 'mse':
        criterion = nn.MSELoss()
    elif lossfun == 'mae':
        criterion = nn.L1Loss()
    else:
        raise ValueError('Unknown loss function!')

    if opt == 'adam':
        optimizer = optim.Adam(net.parameters())
    elif opt == 'schedulefree':
        optimizer = schedulefree.AdamWScheduleFree(net.parameters(), lr=0.0025)
    else:
        raise ValueError('Uknown optimizer!')

    def snapshot_state():
        # state_dict() hands out tensors that share storage with the live parameters, so the
        # optimizer would keep overwriting a stored reference. Clone to freeze the values.
        return {key: value.detach().clone() if torch.is_tensor(value) else value
                for key, value in net.state_dict().items()}

    train_losses = []
    train_maes = []
    val_losses = []
    val_maes = []
    val_maxs = []
    best_val_loss = float('inf')
    # Defined up front so the return value exists even if no epoch runs or improves.
    best_net_state = snapshot_state()
    early_stopping_cnt = 0
    t_start = time.time()
    for epoch in range(num_epochs):
        t_cur = time.time()
        if t_cur - t_start > timeout:
            print(f"Timeout reached ({t_cur - t_start} sec)")
            break

        net.train()

        # Schedule-free optimizers have to be switched between train and eval mode as well.
        if opt == 'schedulefree':
            optimizer.train()

        train_loss = 0.
        train_mae = 0.
        for batch_X, batch_y in train_dataloader:
            y_hat = net(batch_X)
            # NOTE(review): the criterion broadcasts if y_hat and batch_y differ in shape --
            # confirm that the targets match the network output shape.
            loss = criterion(y_hat, batch_y)

            # A zero weight contributes nothing, so the pass over all parameters is skipped then.
            if l1_weight != 0:
                l1_loss = sum(param.abs().sum() for param in net.parameters())
                loss = loss + l1_weight * l1_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_mae += torch.abs(y_hat.detach() - batch_y).mean().item()

        train_loss /= len(train_dataloader)
        train_mae /= len(train_dataloader)
        train_losses.append(train_loss)
        train_maes.append(train_mae)

        net.eval()

        if opt == 'schedulefree':
            optimizer.eval()

        val_loss = 0
        val_mae = 0
        # Seeded from the first batch so it inherits that batch's dtype and device.
        val_max = None
        with torch.no_grad():
            for batch_X, batch_y in val_dataloader:
                y_hat = net(batch_X)
                abs_err = torch.abs(y_hat - batch_y)
                val_loss += criterion(y_hat, batch_y).item()
                val_mae += abs_err.mean().item()
                batch_max = abs_err.max()
                val_max = batch_max if val_max is None else torch.maximum(val_max, batch_max)

        val_loss /= len(val_dataloader)
        val_mae /= len(val_dataloader)
        val_losses.append(val_loss)
        val_maes.append(val_mae)
        val_maxs.append(val_max.item())

        print(f"Epoch [{epoch + 1}/{num_epochs}] - train_loss: {train_loss:.4f}, val_loss: {val_loss:.4f}, train_mae: {train_mae:.4f}, val_mae: {val_mae:.4f}, val_max: {val_max.item():.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            early_stopping_cnt = 0
            best_net_state = snapshot_state()
        else:
            early_stopping_cnt += 1
            if early_stopping_cnt >= patience:
                print(f"Stopping early (patience of {patience} reached)")
                break

    print("Training completed")
    return train_losses, val_losses, train_maes, val_maes, val_maxs, best_net_state

In [12]:
# NOTE(review): `net` must already exist when this cell runs; in notebook order `make_nn`
# is only defined in the following cell, so this fails on a fresh kernel.
# train_loop returns six values (val_maxs was missing from the unpacking).
train_losses, val_losses, train_maes, val_maes, val_maxs, best_state = train_loop(net, train_dataloader, val_dataloader, timeout=10, patience=5000, num_epochs=1000, l1_weight=0.1)

Epoch [1/1000] - train_loss: 97.5540, val_loss: 0.1034, train_mae: 0.3429, val_mae: 0.2727
Epoch [2/1000] - train_loss: 92.3242, val_loss: 0.1031, train_mae: 0.3380, val_mae: 0.2723
Epoch [3/1000] - train_loss: 87.2752, val_loss: 0.1027, train_mae: 0.3464, val_mae: 0.2717
Epoch [4/1000] - train_loss: 82.3781, val_loss: 0.1023, train_mae: 0.3199, val_mae: 0.2712
Epoch [5/1000] - train_loss: 77.6192, val_loss: 0.1020, train_mae: 0.3232, val_mae: 0.2707
Epoch [6/1000] - train_loss: 72.9903, val_loss: 0.1017, train_mae: 0.2877, val_mae: 0.2702
Epoch [7/1000] - train_loss: 68.5315, val_loss: 0.1014, train_mae: 0.2919, val_mae: 0.2696
Epoch [8/1000] - train_loss: 64.2443, val_loss: 0.1010, train_mae: 0.3488, val_mae: 0.2690
Epoch [9/1000] - train_loss: 60.0314, val_loss: 0.1007, train_mae: 0.3261, val_mae: 0.2685
Epoch [10/1000] - train_loss: 55.9664, val_loss: 0.1005, train_mae: 0.2915, val_mae: 0.2679
Epoch [11/1000] - train_loss: 52.0798, val_loss: 0.1002, train_mae: 0.3074, val_mae: 0.26

In [13]:
def make_nn(n_neurons, h, w):
    """
    Builds a fully connected ReLU network with six hidden layers of `n_neurons` units each
    and a single output.

    args:
        n_neurons - width of every hidden layer
        h, w - spatial size of the input; the flattened input has 3*h*w features

    returns:
        torch.nn.Sequential: Flatten, six (Linear, ReLU) pairs, final Linear to one output
    """
    n_hidden_layers = 6
    # Input width of every hidden layer: the flattened input first, then the hidden width.
    in_widths = [3 * h * w] + [n_neurons] * (n_hidden_layers - 1)

    modules = [torch.nn.Flatten()]
    for in_features in in_widths:
        modules.append(torch.nn.Linear(in_features, n_neurons))
        modules.append(torch.nn.ReLU())
    modules.append(torch.nn.Linear(n_neurons, 1))

    return torch.nn.Sequential(*modules)

In [None]:
# Settings shared by every run of the L1-regularization sweep.
n_neurons = 50
h = w = 2
timeout = patience = 10
num_epochs = 1000
l1_weights = [1e-6, 1e-5, 1e-4, 1e-3]
names = [f"net6x50_{l1_weight}l1" for l1_weight in l1_weights]

# One freshly initialized network per L1 weight; every run stores its best state and metric curves.
for netname, l1_weight in zip(names, l1_weights):
    print(f"\n######## {netname} ########\n")
    net = make_nn(n_neurons, h, w)
    history = train_loop(
        net,
        train_dataloader,
        val_dataloader,
        l1_weight=l1_weight,
        num_epochs=num_epochs,
        patience=patience,
        timeout=timeout,
    )
    train_losses, val_losses, train_maes, val_maes, val_maxs, best_state = history

    for artifact, target in (
        (best_state, f'./l1_experiments/{netname}_best_state.pth'),
        (train_losses, f'./l1_experiments/{netname}_train_losses.pth'),
        (val_losses, f'./l1_experiments/{netname}_val_losses.pth'),
        (train_maes, f'./l1_experiments/{netname}_train_maes.pth'),
        (val_maes, f'./l1_experiments/{netname}_val_maes.pth'),
        (val_maxs, f'./l1_experiments/{netname}_val_maxs.pth'),
    ):
        torch.save(artifact, target)


######## net6x50_1e-06l1 ########

Epoch [1/1000] - train_loss: 0.1511, val_loss: 0.1093, train_mae: 0.3082, val_mae: 0.2794, val_max: 0.6718
Epoch [2/1000] - train_loss: 0.1876, val_loss: 0.1024, train_mae: 0.3587, val_mae: 0.2712, val_max: 0.6303
Epoch [3/1000] - train_loss: 0.1688, val_loss: 0.0989, train_mae: 0.3408, val_mae: 0.2661, val_max: 0.5880
Epoch [4/1000] - train_loss: 0.1158, val_loss: 0.0994, train_mae: 0.2714, val_mae: 0.2662, val_max: 0.5429
Epoch [5/1000] - train_loss: 0.1003, val_loss: 0.1049, train_mae: 0.2395, val_mae: 0.2724, val_max: 0.5560
Epoch [6/1000] - train_loss: 0.1567, val_loss: 0.1128, train_mae: 0.3005, val_mae: 0.2806, val_max: 0.5970
Epoch [7/1000] - train_loss: 0.1295, val_loss: 0.1181, train_mae: 0.2807, val_mae: 0.2886, val_max: 0.6192
Epoch [8/1000] - train_loss: 0.1415, val_loss: 0.1248, train_mae: 0.3105, val_mae: 0.2977, val_max: 0.6418
Epoch [9/1000] - train_loss: 0.1007, val_loss: 0.1220, train_mae: 0.2524, val_mae: 0.2942, val_max: 0.6339
E

In [None]:
#torch.save(train_losses, './train_losses.pth')

# Export to ONNX

In [52]:
# A single random sample of shape (3, 2, 2), passed to the exporter as example input.
dummy_input = torch.randn(1, 3, 2, 2)

# Models and their target file names, paired by position.
models = [net]
onnx_files = ['net6x50_overfit.onnx']

for model, filename in zip(models, onnx_files):
    torch.onnx.export(
        model,
        dummy_input,
        './verification/models/' + filename,
        export_params=True,
        do_constant_folding=True,
        opset_version=7,
        input_names=['X'],
        output_names=['Y'],
    )