Skip to content

Commit

Permalink
Merge ec1ed40 into c87f1b2
Browse files Browse the repository at this point in the history
  • Loading branch information
dianaclarke committed Aug 16, 2021
2 parents c87f1b2 + ec1ed40 commit 1bcedd8
Show file tree
Hide file tree
Showing 10 changed files with 55 additions and 12 deletions.
1 change: 1 addition & 0 deletions .flaskenv
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ FLASK_APP=conbench
FLASK_ENV=development
REGISTRATION_KEY=code
SECRET_KEY="Person, woman, man, camera, TV"
BENCHMARKS_DATA_PUBLIC=true
19 changes: 19 additions & 0 deletions conbench/api/_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,26 @@
import functools
import os

import flask as f
import flask_login
import flask.views


def as_bool(x):
return x.lower() in ["yes", "y", "1", "on", "true"]


def maybe_login_required(func):
@functools.wraps(func)
def maybe(*args, **kwargs):
public = as_bool(os.getenv("BENCHMARKS_DATA_PUBLIC", "yes"))
if not public:
return flask_login.login_required(func)(*args, **kwargs)
return func(*args, **kwargs)

return maybe


class ApiEndpoint(flask.views.MethodView):
def validate(self, schema):
data = f.request.get_json(silent=True)
Expand Down
4 changes: 3 additions & 1 deletion conbench/api/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ..api import rule
from ..api._docs import spec
from ..api._endpoint import ApiEndpoint
from ..api._endpoint import ApiEndpoint, maybe_login_required
from ..entities._entity import NotFound
from ..entities.case import Case
from ..entities.distribution import set_z_scores
Expand All @@ -26,6 +26,7 @@ def _get(self, benchmark_id):
self.abort_404_not_found()
return summary

@maybe_login_required
def get(self, benchmark_id):
"""
---
Expand Down Expand Up @@ -72,6 +73,7 @@ class BenchmarkListAPI(ApiEndpoint, BenchmarkValidationMixin):
serializer = SummarySerializer()
schema = BenchmarkFacadeSchema()

@maybe_login_required
def get(self):
"""
---
Expand Down
4 changes: 3 additions & 1 deletion conbench/api/commits.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from ..api import rule
from ..api._endpoint import ApiEndpoint
from ..api._endpoint import ApiEndpoint, maybe_login_required
from ..entities._entity import NotFound
from ..entities.commit import Commit, CommitSerializer


class CommitListAPI(ApiEndpoint):
serializer = CommitSerializer()

@maybe_login_required
def get(self):
"""
---
Expand All @@ -31,6 +32,7 @@ def _get(self, commit_id):
self.abort_404_not_found()
return commit

@maybe_login_required
def get(self, commit_id):
"""
---
Expand Down
4 changes: 3 additions & 1 deletion conbench/api/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


from ..api import rule
from ..api._endpoint import ApiEndpoint
from ..api._endpoint import ApiEndpoint, maybe_login_required
from ..entities._comparator import BenchmarkComparator, BenchmarkListComparator
from ..entities._entity import NotFound
from ..entities.distribution import set_z_scores
Expand Down Expand Up @@ -34,6 +34,7 @@ def _get(self, benchmark_id):
set_z_scores([summary])
return summary

@maybe_login_required
def get(self, compare_ids):
"""
---
Expand Down Expand Up @@ -112,6 +113,7 @@ def _get(self, batch_id):
set_z_scores(summaries)
return summaries

@maybe_login_required
def get(self, compare_ids):
"""
---
Expand Down
4 changes: 3 additions & 1 deletion conbench/api/contexts.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from ..api import rule
from ..api._endpoint import ApiEndpoint
from ..api._endpoint import ApiEndpoint, maybe_login_required
from ..entities._entity import NotFound
from ..entities.context import Context, ContextSerializer


class ContextListAPI(ApiEndpoint):
serializer = ContextSerializer()

@maybe_login_required
def get(self):
"""
---
Expand All @@ -31,6 +32,7 @@ def _get(self, context_id):
self.abort_404_not_found()
return context

@maybe_login_required
def get(self, context_id):
"""
---
Expand Down
3 changes: 2 additions & 1 deletion conbench/api/history.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ..api import rule
from ..api._endpoint import ApiEndpoint
from ..api._endpoint import ApiEndpoint, maybe_login_required
from ..entities._entity import NotFound
from ..entities.history import get_history, HistorySerializer
from ..entities.summary import Summary
Expand All @@ -19,6 +19,7 @@ def _get(self, benchmark_id):
summary.run.machine.hash,
)

@maybe_login_required
def get(self, benchmark_id):
"""
---
Expand Down
4 changes: 3 additions & 1 deletion conbench/api/machines.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from ..api import rule
from ..api._endpoint import ApiEndpoint
from ..api._endpoint import ApiEndpoint, maybe_login_required
from ..entities._entity import NotFound
from ..entities.machine import Machine, MachineSerializer


class MachineListAPI(ApiEndpoint):
serializer = MachineSerializer()

@maybe_login_required
def get(self):
"""
---
Expand All @@ -31,6 +32,7 @@ def _get(self, machine_id):
self.abort_404_not_found()
return machine

@maybe_login_required
def get(self, machine_id):
"""
---
Expand Down
4 changes: 3 additions & 1 deletion conbench/api/runs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ..api import rule
from ..api._endpoint import ApiEndpoint
from ..api._endpoint import ApiEndpoint, maybe_login_required
from ..entities._entity import NotFound
from ..entities.run import Run, RunSerializer

Expand All @@ -14,6 +14,7 @@ def _get(self, run_id):
self.abort_404_not_found()
return run

@maybe_login_required
def get(self, run_id):
"""
---
Expand All @@ -37,6 +38,7 @@ def get(self, run_id):
class RunListAPI(ApiEndpoint):
serializer = RunSerializer()

@maybe_login_required
def get(self):
"""
---
Expand Down
20 changes: 15 additions & 5 deletions conbench/tests/api/_asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,13 @@ def test_unauthenticated(self, client):


class ListEnforcer(Enforcer):
def test_unauthenticated(self, client):
def test_unauthenticated(self, client, monkeypatch):
if getattr(self, "public", False):
monkeypatch.setenv("BENCHMARKS_DATA_PUBLIC", "off")
response = client.get(self.url)
self.assert_401_unauthorized(response)

monkeypatch.setenv("BENCHMARKS_DATA_PUBLIC", "on")
response = client.get(self.url)
self.assert_200_ok(response)
else:
Expand All @@ -87,12 +92,17 @@ def test_unauthenticated(self, client):


class GetEnforcer(Enforcer):
def test_unauthenticated(self, client):
def test_unauthenticated(self, client, monkeypatch):
if getattr(self, "public", False):
entity = self._create()
response = client.get(self.url.format(entity.id))
# TODO: compare _expected_entity too
# self.assert_200_ok(response, _expected_entity(entity))
entity_url = self.url.format(entity.id)

monkeypatch.setenv("BENCHMARKS_DATA_PUBLIC", "off")
response = client.get(entity_url)
self.assert_401_unauthorized(response)

monkeypatch.setenv("BENCHMARKS_DATA_PUBLIC", "on")
response = client.get(entity_url)
self.assert_200_ok(response)
else:
response = client.get(self.url.format("id"))
Expand Down

0 comments on commit 1bcedd8

Please sign in to comment.