Skip to content

Commit

Permalink
Add permissions and API for making SQL queries
Browse files Browse the repository at this point in the history
JIRA: RHELWF-9014, RHELWF-9015
  • Loading branch information
hluk committed Sep 11, 2023
1 parent 4a3558f commit c335a65
Show file tree
Hide file tree
Showing 11 changed files with 484 additions and 63 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/gating.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
---
name: Gating

on:
"on":
pull_request:
push:
workflow_dispatch:
Expand Down Expand Up @@ -32,6 +33,9 @@ jobs:
- name: Test with tox
run: tox -e py

- name: Test with mypy
run: tox -e mypy

- name: Run coveralls-python
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
Expand Down
24 changes: 17 additions & 7 deletions product_listings_manager/authorization.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
# SPDX-License-Identifier: GPL-2.0+
import logging
from collections.abc import Generator
from dataclasses import dataclass

import ldap
from werkzeug.exceptions import BadGateway

log = logging.getLogger(__name__)


def get_group_membership(user, ldap_connection, ldap_search):
@dataclass
class LdapConfig:
host: str
searches: list[dict[str, str]]


def get_group_membership(
user: str, ldap_connection, ldap_search: dict[str, str]
) -> list[str]:
results = ldap_connection.search_s(
ldap_search["BASE"],
ldap.SCOPE_SUBTREE,
Expand All @@ -17,13 +27,13 @@ def get_group_membership(user, ldap_connection, ldap_search):
return [group[1]["cn"][0].decode("utf-8") for group in results]


def get_user_groups(user, ldap_host, ldap_searches):
def get_user_groups(
user: str, ldap_config: LdapConfig
) -> Generator[str, None, None]:
try:
ldap_connection = ldap.initialize(ldap_host)
for cur_ldap_search in ldap_searches:
yield from get_group_membership(
user, ldap_connection, cur_ldap_search
)
ldap_connection = ldap.initialize(ldap_config.host)
for ldap_search in ldap_config.searches:
yield from get_group_membership(user, ldap_connection, ldap_search)
except ldap.SERVER_DOWN:
log.exception("The LDAP server is unreachable")
raise BadGateway("The LDAP server is unreachable")
Expand Down
17 changes: 15 additions & 2 deletions product_listings_manager/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
"""PLM configuration module."""
import json
import os
from typing import Any


def read_json_file(filename: str):
if not filename:
return None

with open(filename) as f:
return json.load(f)


ENV_TO_CONFIG = (
("SQLALCHEMY_DATABASE_URI", lambda x: x),
("PLM_LDAP_HOST", lambda x: x),
("PLM_LDAP_SEARCHES", json.loads),
("PLM_PERMISSIONS", lambda x: read_json_file(x) or []),
)


Expand All @@ -16,8 +27,10 @@ class Config:
SQLALCHEMY_ECHO = False
SQLALCHEMY_DATABASE_URI = "postgresql://user:pass@localhost/compose"

LDAP_HOST = None
LDAP_SEARCHES = []
LDAP_HOST: str = ""
LDAP_SEARCHES: list[dict[str, str]] = []

PERMISSIONS: list[dict[str, Any]] = []


class DevConfig(Config):
Expand Down
60 changes: 60 additions & 0 deletions product_listings_manager/db_queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# SPDX-License-Identifier: GPL-2.0+
"""Helper function for managing raw SQL queries"""
from typing import Any

from flask import current_app
from sqlalchemy.exc import ResourceClosedError, SQLAlchemyError
from werkzeug.exceptions import BadRequest

from product_listings_manager.models import db

DB_QUERY_PARAM_ERROR = (
'Parameter must have the following format ("params" is optional): '
'[{"query": QUERY_STRING, "params": {PARAMETER_NAME: PARAMETER_STRING_VALUE}...]'
)


def queries_from_user_input(input: Any) -> list[Any]:
if isinstance(input, dict):
return [input]

if isinstance(input, str):
return [{"query": input}]

if isinstance(input, list):
return [q if isinstance(q, dict) else {"query": q} for q in input]

raise BadRequest(DB_QUERY_PARAM_ERROR)


def validate_queries(queries: list[dict[Any, Any]]):
if not queries:
raise BadRequest(DB_QUERY_PARAM_ERROR)

for query in queries:
query_text = query.get("query")
if not query_text or not isinstance(query_text, str):
raise BadRequest(DB_QUERY_PARAM_ERROR)

params = query.get("params")
if params is not None and not isinstance(params, dict):
raise BadRequest(DB_QUERY_PARAM_ERROR)


def execute_queries(queries: list[dict[str, str]]) -> list[list[Any]]:
with db.session.begin():
for query in queries:
query_text = query.get("query")
params = query.get("params")
try:
result = db.session.execute(db.text(query_text), params=params)
except SQLAlchemyError as e:
current_app.logger.warning("Failed DB query for user %s", e)
raise BadRequest(f"DB query failed: {e}")

db.session.commit()

try:
return [list(row) for row in result]
except ResourceClosedError:
return []
17 changes: 10 additions & 7 deletions product_listings_manager/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
in following definition and it has no side effect to postgresql(composedb).
"""
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.ext.declarative import DeclarativeMeta

db = SQLAlchemy()

BaseModel: DeclarativeMeta = db.Model

class Packages(db.Model):

class Packages(BaseModel):
"""packages table in composedb.
Only needed columns in packages table are defined here.
Expand Down Expand Up @@ -47,7 +50,7 @@ def __repr__(self):
)


class Products(db.Model):
class Products(BaseModel):
"""products table in composedb."""

id = db.Column(db.Integer, primary_key=True)
Expand Down Expand Up @@ -96,7 +99,7 @@ def __repr__(self):
)


class Trees(db.Model):
class Trees(BaseModel):
"""trees table in composedb."""

id = db.Column(db.Integer, primary_key=True)
Expand All @@ -121,7 +124,7 @@ def __repr__(self):
)


class Overrides(db.Model):
class Overrides(BaseModel):
"""overrides table in composedb.
Many columns are set as primary key becasue primary key is required
Expand All @@ -148,7 +151,7 @@ def __repr__(self):
)


class MatchVersions(db.Model):
class MatchVersions(BaseModel):
"""match_versions table in composedb."""

name = db.Column(db.String(255), primary_key=True)
Expand All @@ -160,7 +163,7 @@ def __repr__(self):
)


class Modules(db.Model):
class Modules(BaseModel):
"""modules table in composedb."""

id = db.Column(db.Integer, primary_key=True)
Expand All @@ -177,7 +180,7 @@ def __repr__(self):
)


class ModuleOverrides(db.Model):
class ModuleOverrides(BaseModel):
"""module_overrides table in composedb."""

name = db.Column(db.String(255), primary_key=True)
Expand Down
47 changes: 47 additions & 0 deletions product_listings_manager/permissions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# SPDX-License-Identifier: GPL-2.0+
from fnmatch import fnmatch
from typing import Any

from product_listings_manager.authorization import LdapConfig, get_user_groups


def query_matches(query: str, permission: dict[str, Any]) -> bool:
return any(
fnmatch(query.upper(), pattern.upper())
for pattern in permission.get("queries", [])
)


def has_permission(
user: str,
queries: list[dict[str, str]],
permissions: list[dict[str, Any]],
ldap_config: LdapConfig,
) -> bool:
qs = list(q["query"] for q in queries)
qs = [
q
for q in qs
if not any(
user in p.get("users", set()) and query_matches(q, p)
for p in permissions
)
]
if not qs:
return True

# Avoid querying LDAP unnecessarily
if not any(p.get("groups") for p in permissions):
return False

groups = set(get_user_groups(user, ldap_config))
qs = [
q
for q in qs
if not any(
not groups.isdisjoint(p.get("groups", set()))
and query_matches(q, p)
for p in permissions
)
]
return not qs
27 changes: 9 additions & 18 deletions product_listings_manager/products.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class Products:
re.compile(r"^U\d+(-beta)?$", re.I),
]

@staticmethod
def score(release):
map = Products.all_release_types
i = len(map) - 1
Expand All @@ -70,8 +71,7 @@ def score(release):
i = i - 1
return i

score = staticmethod(score)

@staticmethod
def my_sort(x, y):
if len(x) > len(y) and y == x[: len(y)]:
return -1
Expand All @@ -84,8 +84,7 @@ def my_sort(x, y):
else:
return _cmp(x_score, y_score)

my_sort = staticmethod(my_sort)

@staticmethod
def get_product_info(label):
"""Get the latest version of product and it's variants."""
products = models.Products.query.filter_by(label=label).all()
Expand All @@ -105,8 +104,7 @@ def get_product_info(label):
[x.variant for x in products if x.version == versions[0]],
)

get_product_info = staticmethod(get_product_info)

@staticmethod
def get_overrides(product, version, variant=None):
"""
Returns the list of package overrides for the particular product specified.
Expand Down Expand Up @@ -134,8 +132,7 @@ def get_overrides(product, version, variant=None):
)
return overrides

get_overrides = staticmethod(get_overrides)

@staticmethod
def get_match_versions(product):
"""
Returns the list of packages for this product where we must match the version.
Expand All @@ -147,8 +144,7 @@ def get_match_versions(product):
).all()
]

get_match_versions = staticmethod(get_match_versions)

@staticmethod
def get_srconly_flag(product, version):
"""
BREW-260 - Returns allow_source_only field for the product and matching version.
Expand All @@ -158,8 +154,7 @@ def get_srconly_flag(product, version):
)
return models.db.session.query(q.exists()).scalar()

get_srconly_flag = staticmethod(get_srconly_flag)

@staticmethod
def precalc_treelist(product, version, variant=None):
"""Returns the list of trees to consider.
Expand Down Expand Up @@ -190,8 +185,7 @@ def precalc_treelist(product, version, variant=None):
trees[arch] = id
return list(trees.values()) + list(compat_trees.values())

precalc_treelist = staticmethod(precalc_treelist)

@staticmethod
def dest_get_archs(
trees, src_arch, names, cache_entry, version=None, overrides=None
):
Expand Down Expand Up @@ -237,8 +231,7 @@ def dest_get_archs(
del ret[name][tree_arch]
return ret

dest_get_archs = staticmethod(dest_get_archs)

@staticmethod
def get_module_overrides(
product, version, module_name, module_stream, variant=None
):
Expand All @@ -257,8 +250,6 @@ def get_module_overrides(

return [row.product_arch for row in query.all()]

get_module_overrides = staticmethod(get_module_overrides)

@staticmethod
def get_product_labels():
rows = (
Expand Down
Loading

0 comments on commit c335a65

Please sign in to comment.