In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import pdb
import scipy
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import time
import cvxpy as cp
import pickle
import torch.functional as F

from lib.dataset_utils import *
from lib.solvers import solve_feas

In [3]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [4]:
np.random.seed(1)
torch.manual_seed(1)

(x_train, y_train), (x_valid, y_valid), (x_test, y_test) = load_mnist_all(
    data_dir='/data', val_size=0.1, shuffle=True, seed=1)

In [6]:
params = {
    'dim': 784,
    'device': 'cpu',
    'max_proj_iters': 1000,
    'tol': 1e-5,
    'tol_abs': 1e-5,
    'tol_rel': 1e-3,
    'early_stop': True,
    'rho': 1,
#     'clip': (0, 1),
    'clip': None,
    'dtype': torch.float32,
    'step_size': 5e-4,
    'num_partitions': 3,
    'upperbound': np.inf,
    'check_obj_steps': 100000
}

In [7]:
def get_polytope(points, idx_intr_point, params):
    """
    Compute polytope Ax <= b of the Voronoi cell of 
    points[idx_intr_point].
    """
    A = torch.zeros((points.size(0), points.size(1)), 
                    dtype=params['dtype'])
    b = torch.zeros(A.size(0), dtype=params['dtype'])
    # for some reason, processing in a loop rather than matrix form
    # gives more precise solutions.
    for i in range(A.size(0)):
        A[i] = F.normalize(points[i] - points[idx_intr_point], 2, 0)
        b[i] = (A[i] @ (points[i] + points[idx_intr_point])) / 2
    A = torch.cat([A[:idx_intr_point], A[idx_intr_point + 1:]], dim=0)
    b = torch.cat([b[:idx_intr_point], b[idx_intr_point + 1:]], dim=0)
    return A.to(params['device']), b.to(params['device'])

In [8]:
np.random.seed(1)
torch.manual_seed(1)

# num_constraints = 10000
# A = torch.zeros((num_constraints, 784), dtype=params['dtype']).normal_()
# A = F.normalize(A, 2, 1)
# A = A.to(params['device'])
# b = torch.ones(num_constraints, device=params['device'], dtype=params['dtype'])
# x_hat = torch.ones(784, device=params['device'], dtype=params['dtype'])

# box = torch.eye(784, device=params['device'], dtype=params['dtype'])
# A = torch.cat([box, - box], dim=0)
# x_hat = torch.ones(784, device=params['device'], dtype=params['dtype']) + 1
# b = torch.cat([torch.ones_like(x_hat), torch.zeros_like(x_hat)], dim=0)

A, b = get_polytope(x_train.view(-1, params['dim']), 0, params)
x_hat = torch.ones(params['dim'], 
                   device=params['device'], 
                   dtype=params['dtype'])
x_hat.uniform_()
# x_hat = x_train[1].view(params['dim'])
assert np.all((A @ x_train.view(-1, params['dim'])[0].to(params['device']) 
               <= b).cpu().numpy())

In [9]:
# box = torch.eye(784, device=params['device'], dtype=params['dtype'])
# A = torch.cat([A, box, - box], dim=0)
# b = torch.cat([b, torch.ones_like(x_hat), torch.zeros_like(x_hat)], dim=0)

In [9]:
AAT = A @ A.T
b_hat = A @ x_hat - b

In [81]:
((A @ x_hat - b) > 0).sum()

tensor(42651, device='cuda:0')

In [88]:
# params['method'] = 'dual_ascent'
# params['method'] = 'gca'
# params['method'] = 'parallel_gca'
params['method'] = 'dykstra'
# params['method'] = 'cvx'

from lib.projection import *
start = time.time()
# x = proj_polytope(x_hat, A, AAT, b, params)
x = proj_polytope(x_hat, A, None, b, params)
print(time.time() - start)

7.232408761978149


In [89]:
res = A @ x - b
print(res[res > 0])

tensor([2.3281e-02, 7.3590e-03, 6.5850e-03, 9.6170e-02, 1.5340e-01, 2.6087e-01,
        1.8810e-02, 5.1887e-02, 9.9162e-02, 1.6780e-01, 4.5843e-02, 4.9195e-02,
        3.0815e-02, 1.1751e-02, 2.5996e-03, 1.0693e-03, 3.6686e-02, 1.0081e-01,
        5.0745e-02, 1.2917e-01, 6.1694e-02, 8.2414e-03, 3.8188e-02, 2.1438e-02,
        1.6334e-01, 5.3193e-02, 1.0666e-01, 1.9802e-02, 2.2411e-01, 8.5125e-03,
        1.2500e-02, 7.9642e-02, 1.2558e-01, 1.8706e-02, 3.4617e-02, 4.2405e-02,
        6.9020e-02, 3.2228e-02, 5.9142e-03, 1.7935e-02, 3.1402e-03, 5.5309e-02,
        3.5913e-02, 3.4989e-02, 8.9275e-02, 1.2681e-01, 2.6863e-01, 1.6228e-02,
        8.0599e-02, 9.6752e-02, 5.4533e-02, 1.2133e-01, 1.9626e-01, 4.7637e-02,
        4.1484e-02, 6.5536e-02, 3.9083e-02, 4.3780e-02, 1.1664e-01, 1.0410e-02,
        4.4007e-02, 6.0897e-03, 6.7472e-03, 2.0986e-02, 2.3079e-02, 3.3366e-01,
        4.9620e-02, 3.0160e-03, 3.5011e-01, 3.3004e-03, 6.1066e-02, 5.4264e-02,
        1.3067e-02, 1.5809e-01, 9.1283e-

In [90]:
((x - x_hat) ** 2).sum()
# ((x - x_hat.numpy()) ** 2).sum()

tensor(1.6708, device='cuda:0')

In [13]:
start = time.time()
out = solve_feas(torch.cat([A, b.unsqueeze(1)], dim=1).cpu().numpy(), 2)
print(out)
print(time.time() - start)

8.259012460708618
True
8.473605871200562


## Runtime on full MNIST (x_hat: uniform random)
- dual_ascent
  - cpu: way too long (step_size 1e-1, tol 1e-5)
  - cuda: s (step_size 1e-1, tol 1e-5)
- gca
  - cpu: 0.5-1s (tol 1e-5)
  - cuda: s (tol 1e-5)
  - very fast if direct projection to one facet is the solution.
- parallel_gca
  - cpu: s (num_partitions 10, tol 1e-5).
  - cuda: s (tol 1e-5).
- admm
  - cpu: 
  - cuda:
- cvx (MOSEK):  
- cvx (default): 

- Solving Farkas takes too long.