Skip to content

Commit

Permalink
speed up pytest (#484)
Browse files Browse the repository at this point in the history
  • Loading branch information
diego-escobedo committed Jan 20, 2023
1 parent 7eec1ff commit 158035f
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 47 deletions.
1 change: 1 addition & 0 deletions .github/workflows/django-postgres.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ jobs:
STRIPE_SECRET_KEY: ${{ secrets.STRIPE_SECRET_KEY }}
DEBUG: False
KAFKA_URL: "localhost:9092"
PYTHONDONTWRITEBYTECODE: 1

steps:
- uses: actions/checkout@v3
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/postman_workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ jobs:
STRIPE_SECRET_KEY: ${{ secrets.STRIPE_SECRET_KEY }}
DEBUG: False
KAFKA_URL: "localhost:9092"
PYTHONDONTWRITEBYTECODE: 1

steps:
- uses: actions/checkout@v3
Expand Down
44 changes: 22 additions & 22 deletions backend/api/serializers/model_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Literal, Optional, Union

from django.conf import settings
from django.db.models import DecimalField, Q, Sum
from django.db.models import Sum
from metering_billing.invoice import generate_balance_adjustment_invoice
from metering_billing.models import (
CategoricalFilter,
Expand Down Expand Up @@ -543,36 +543,35 @@ def get_integrations(self, obj) -> CustomerIntegrationsSerializer:
return d

def get_subscriptions(self, obj) -> SubscriptionRecordSerializer(many=True):
sr_objs = obj.subscription_records.active().filter(
organization=self.context.get("organization"),
start_date__lte=now_utc(),
end_date__gte=now_utc(),
)
try:
sr_objs = obj.active_subscription_records
except AttributeError:
sr_objs = (
obj.subscription_records.active()
.filter(organization=obj.organization)
.order_by("start_date")
)
return SubscriptionRecordSerializer(sr_objs, many=True).data

def get_invoices(self, obj) -> LightweightInvoiceSerializer(many=True):
timeline = (
obj.invoices.filter(
~Q(payment_status=Invoice.PaymentStatus.DRAFT),
organization=self.context.get("organization"),
)
.order_by("-issue_date")
.prefetch_related("currency", "line_items", "subscription")
)
try:
timeline = obj.active_invoices
except AttributeError:
timeline = obj.invoices.filter(
payment_status__in=[
Invoice.PaymentStatus.PAID,
Invoice.PaymentStatus.UNPAID,
],
organization=obj.organization,
).order_by("-issue_date")
timeline = LightweightInvoiceSerializer(timeline, many=True).data
return timeline

def get_total_amount_due(self, obj) -> Decimal:
try:
return obj.total_amount_due
except AttributeError:
return (
obj.invoices.filter(payment_status=Invoice.PaymentStatus.UNPAID)
.aggregate(
unpaid_inv_amount=Sum("cost_due", output_field=DecimalField())
)
.get("unpaid_inv_amount")
)
return Decimal(0)


class CustomerCreateSerializer(
Expand Down Expand Up @@ -1234,7 +1233,7 @@ def get_plans(self, obj) -> LightweightSubscriptionRecordSerializer(many=True):
class SubscriptionInvoiceSerializer(SubscriptionRecordSerializer):
class Meta(SubscriptionRecordSerializer.Meta):
model = SubscriptionRecord
fields = fields = tuple(
fields = tuple(
set(SubscriptionRecordSerializer.Meta.fields)
- set(
["customer_id", "plan_id", "billing_plan", "auto_renew", "invoice_pdf"]
Expand Down Expand Up @@ -1635,3 +1634,4 @@ class Meta:
plan_version = LightweightPlanVersionSerializer()
metric = MetricSerializer()
plan_version = LightweightPlanVersionSerializer()
plan_version = LightweightPlanVersionSerializer()
39 changes: 33 additions & 6 deletions backend/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,38 @@ class CustomerViewSet(PermissionPolicyMixin, viewsets.ModelViewSet):
queryset = Customer.objects.all()

def get_queryset(self):
now = now_utc()
organization = self.request.organization
qs = Customer.objects.filter(organization=organization).prefetch_related(
"subscription_records",
"invoices",
"default_currency",
qs = Customer.objects.filter(organization=organization)
qs = qs.select_related("default_currency")
qs = qs.prefetch_related(
Prefetch(
"subscriptions",
queryset=SubscriptionRecord.objects.active(now)
.filter(
organization=organization,
)
.select_related("customer", "billing_plan")
.prefetch_related("filters"),
to_attr="active_subscription_records",
),
Prefetch(
"invoices",
queryset=Invoice.objects.filter(
organization=organization,
payment_status__in=[
Invoice.PaymentStatus.UNPAID,
Invoice.PaymentStatus.PAID,
],
)
.order_by("-issue_date")
.select_related("currency", "subscription", "organization")
.prefetch_related("line_items"),
to_attr="active_invoices",
),
)
qs = qs.annotate(
unpaid_inv_amount=Sum(
total_amount_due=Sum(
"invoices__cost_due",
filter=Q(invoices__payment_status=Invoice.PaymentStatus.UNPAID),
output_field=DecimalField(),
Expand Down Expand Up @@ -829,7 +853,6 @@ def dispatch(self, request, *args, **kwargs):


class InvoiceViewSet(PermissionPolicyMixin, viewsets.ModelViewSet):

serializer_class = InvoiceSerializer
http_method_names = ["get", "patch", "head"]
lookup_field = "invoice_id"
Expand Down Expand Up @@ -1494,3 +1517,7 @@ def track_event(request):
else:
return JsonResponse({"success": "all"}, status=status.HTTP_201_CREATED)
return JsonResponse({"success": "all"}, status=status.HTTP_201_CREATED)
return JsonResponse({"success": "all"}, status=status.HTTP_201_CREATED)
return JsonResponse({"success": "all"}, status=status.HTTP_201_CREATED)
return JsonResponse({"success": "all"}, status=status.HTTP_201_CREATED)
return JsonResponse({"success": "all"}, status=status.HTTP_201_CREATED)
6 changes: 4 additions & 2 deletions backend/metering_billing/aggregation/billable_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,8 +630,9 @@ def create_continuous_aggregate(metric: Metric, refresh=False):
cursor.execute(day_drop_query)
cursor.execute(day_query)
cursor.execute(day_refresh_query)
# SECOND QUERY SECOND
if metric.usage_aggregation_type != METRIC_AGGREGATION.UNIQUE:
if metric.usage_aggregation_type != METRIC_AGGREGATION.UNIQUE:
with connection.cursor() as cursor:
# SECOND QUERY SECOND
if refresh is True:
cursor.execute(second_drop_query)
cursor.execute(second_query)
Expand Down Expand Up @@ -660,6 +661,7 @@ def archive_metric(metric: Metric) -> Metric:
second_drop_query = Template(CAGG_DROP).render(**sql_injection_data)
with connection.cursor() as cursor:
cursor.execute(day_drop_query)
with connection.cursor() as cursor:
cursor.execute(second_drop_query)


Expand Down
4 changes: 1 addition & 3 deletions backend/metering_billing/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import posthog
import pytest
from model_bakery import baker

from metering_billing.utils import now_utc
from metering_billing.utils.enums import (
FLAT_FEE_BILLING_TYPE,
Expand All @@ -13,6 +11,7 @@
PRODUCT_STATUS,
USAGE_BILLING_FREQUENCY,
)
from model_bakery import baker


@pytest.fixture(autouse=True)
Expand All @@ -22,7 +21,6 @@ def run_around_tests():
# A test function will be run at this point
yield
# Code that will run after your test, for example:
posthog.disabled = False


@pytest.fixture(autouse=True)
Expand Down
15 changes: 4 additions & 11 deletions backend/metering_billing/views/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from decimal import Decimal

from django.conf import settings
from django.db.models import Count, DecimalField, F, Prefetch, Q, Sum
from django.db.models import Count, F, Prefetch, Q, Sum
from drf_spectacular.utils import extend_schema, inline_serializer
from metering_billing.exceptions import (
ExternalConnectionFailure,
Expand Down Expand Up @@ -56,6 +56,7 @@
PAYMENT_PROVIDERS,
USAGE_CALC_GRANULARITY,
)
from metering_billing.views.model_views import CustomerViewSet
from rest_framework import serializers, status
from rest_framework.exceptions import ValidationError
from rest_framework.permissions import IsAuthenticated
Expand Down Expand Up @@ -458,16 +459,8 @@ def get(self, request, format=None):
"""
Return current usage for a customer during a given billing period.
"""
organization = request.organization
customers = Customer.objects.filter(organization=organization)
customers = customers.prefetch_related("invoices")
customers = customers.annotate(
unpaid_inv_amount=Sum(
"invoices__cost_due",
filter=Q(invoices__payment_status=Invoice.PaymentStatus.UNPAID),
output_field=DecimalField(),
)
)
request.organization
customers = CustomerViewSet.get_queryset(self)
cust = CustomerWithRevenueSerializer(customers, many=True).data
cust = make_all_decimals_floats(cust)
return Response(cust, status=status.HTTP_200_OK)
Expand Down
4 changes: 1 addition & 3 deletions backend/pytest.ini
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
[pytest]
DJANGO_SETTINGS_MODULE = lotus.settings
addopts = --cov=.
--cov-report term-missing:skip-covered
DJANGO_SETTINGS_MODULE = lotus.settings

0 comments on commit 158035f

Please sign in to comment.