Skip to content

Commit

Permalink
Added fix for mldata.org bug
Browse files Browse the repository at this point in the history
Added a workaround to download MNIST data since mldata.org keeps going down (scikit-learn/scikit-learn#8588)
  • Loading branch information
soorya19 committed Jul 16, 2018
1 parent 4ad80d7 commit d763d6e
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 3 deletions.
21 changes: 21 additions & 0 deletions linear_svm/mnist_workaround.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""
A workaround to download MNIST data since mldata.org appears to be unstable
Taken from https://github.com/scikit-learn/scikit-learn/issues/8588#issuecomment-292634781
"""

from shutil import copyfileobj
from six.moves import urllib
from sklearn.datasets.base import get_data_home
import os

def fetch_mnist():
mnist_alternative_url = "https://github.com/amplab/datascience-sp14/raw/master/lab7/mldata/mnist-original.mat"
data_home = get_data_home()
data_home = os.path.join(data_home, 'mldata')
if not os.path.exists(data_home):
os.makedirs(data_home)
mnist_save_path = os.path.join(data_home, "mnist-original.mat")
if not os.path.exists(mnist_save_path):
mnist_url = urllib.request.urlopen(mnist_alternative_url)
with open(mnist_save_path, "wb") as matlab_file:
copyfileobj(mnist_url, matlab_file)
4 changes: 1 addition & 3 deletions linear_svm/sp_func_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ def sp_frontend(images, rho=0.02, wavelet='bior4.4', mode='periodization', max_l
:param wavelet: Wavelet to use in the transform. See https://pywavelets.readthedocs.io/ for more details.
:param mode: Signal extension mode. See https://pywavelets.readthedocs.io/ for more details.
:param max_lev: Maximum allowed level of decomposition.
"""
# Input is assumed to be in the range [-1, 1] and of shape [num_samples, 784]
# Projects input onto
"""
num_samples = images.shape[0]
num_features = images.shape[1]
num_features_per_dim = np.int(np.sqrt(num_features))
Expand Down
2 changes: 2 additions & 0 deletions linear_svm/test_defense_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
import pywt
from sklearn import datasets, utils, model_selection, metrics, svm
from sp_func_svm import sp_project, sp_frontend
from mnist_workaround import fetch_mnist

epsilon = 0.25 # L-infinity attack budget. Images are assumed to be in the range [-1, 1].
rho = 0.02 # Sparsity level used in the defense, in the range [0, 1].
digit_1 = 3
digit_2 = 7

fetch_mnist()
mnist = datasets.fetch_mldata("MNIST original")
digit_1_data = 2.0*mnist.data[mnist.target==digit_1]/255.0 - 1.0
digit_2_data = 2.0*mnist.data[mnist.target==digit_2]/255.0 - 1.0
Expand Down

0 comments on commit d763d6e

Please sign in to comment.