Skip to content
Permalink
Browse files

Remove misleading datasets folder in examples (#545)

Fix #423 by removing the `examples/datasets` folder.
The benchmark now loads MNIST with sklearn's `fetch_openml` and relies on
its default download path, which also takes care of caching the data.

Also fixes a typo in the printed results: the torch timing was labelled "score torch" instead of "time torch".
  • Loading branch information...
ottonemo authored and BenjaminBossan committed Oct 13, 2019
1 parent 161f28d commit 79b1c70fbf0b58e8f615f06e02c1bad81a0dbfe7
Showing with 5 additions and 6 deletions.
  1. +5 −6 examples/benchmarks/mnist.py
  2. 0 examples/datasets/.gitkeep
@@ -29,7 +29,7 @@
import time

import numpy as np
from sklearn.datasets import fetch_mldata
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
@@ -41,13 +41,12 @@


BATCH_SIZE = 128
DATA_HOME = os.path.join(os.getcwd(), 'examples', 'datasets')
LEARNING_RATE = 0.1
MAX_EPOCHS = 12


def get_data(num_samples, data_home):
mnist = fetch_mldata('MNIST original', data_home=data_home)
def get_data(num_samples):
mnist = fetch_openml('mnist_784')
torch.manual_seed(0)
X = mnist.data.astype('float32').reshape(-1, 1, 28, 28)
y = mnist.target.astype('int64')
@@ -271,7 +270,7 @@ def performance_torch(


def main(device, num_samples):
data = get_data(num_samples, DATA_HOME)
data = get_data(num_samples)
# trigger potential cuda call overhead
torch.zeros(1).to(device)

@@ -299,7 +298,7 @@ def main(device, num_samples):
)
time_torch = time.time() - tic

print("time skorch: {:.4f}, score torch: {:.4f}".format(
print("time skorch: {:.4f}, time torch: {:.4f}".format(
time_skorch, time_torch))
print("score skorch: {:.4f}, score torch: {:.4f}".format(
score_skorch, score_torch))
No changes.

0 comments on commit 79b1c70

Please sign in to comment.
You can’t perform that action at this time.