Skip to content

Commit

Permalink
Fix backend tests for group discounts, new cart API, and more (#675)
Browse files Browse the repository at this point in the history
* Refactor cart summation to helper fn

* Add tests for group discounts and cart totaling

* Minor change to docs

* Only populate `sold_out_tickets` if tickets cannot be replaced

* Refactor tests to use new cart API

* Make openAPI docs happy

* Make `_calculate_cart_total` static method

* Group discount shouldn't activate below threshold

* Fix API docs & improve tests

* Add minor subtest

* Align tests with new API
  • Loading branch information
aviupadhyayula committed Apr 22, 2024
1 parent a9ccf55 commit bd18a03
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 37 deletions.
70 changes: 40 additions & 30 deletions backend/clubs/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2632,6 +2632,7 @@ def create_tickets(self, request, *args, **kwargs):
required: false
group_discount:
type: number
format: float
required: false
order_limit:
type: int
Expand Down Expand Up @@ -4782,6 +4783,30 @@ class TicketViewSet(viewsets.ModelViewSet):
http_method_names = ["get", "post"]
lookup_field = "id"

@staticmethod
def _calculate_cart_total(cart) -> float:
"""
Calculate the total price of all tickets in a cart, applying discounts
where appropriate. Does not validate that the cart is valid.
:param cart: Cart object
:return: Total price of all tickets in the cart
"""
ticket_type_counts = {
item["type"]: item["count"]
for item in cart.tickets.values("type").annotate(count=Count("type"))
}
cart_total = sum(
(
ticket.price * (1 - ticket.group_discount)
if ticket.group_size
and ticket_type_counts[ticket.type] >= ticket.group_size
else ticket.price
)
for ticket in cart.tickets.all()
)
return cart_total

@transaction.atomic
@update_holds
@action(detail=False, methods=["get"])
Expand All @@ -4802,12 +4827,10 @@ def cart(self, request, *args, **kwargs):
allOf:
- $ref: "#/components/schemas/Ticket"
sold_out:
type: list
type: array
items:
type: object
properties:
id:
type: integer
event:
type: object
properties:
Expand Down Expand Up @@ -4843,6 +4866,7 @@ def cart(self, request, *args, **kwargs):
sold_out_tickets = []
replacement_tickets = []
tickets_in_cart = cart.tickets.values_list("id", flat=True)

for ticket_class in tickets_to_replace.values("type", "event").annotate(
count=Count("id")
):
Expand All @@ -4854,18 +4878,18 @@ def cart(self, request, *args, **kwargs):
holder__isnull=True,
).exclude(id__in=tickets_in_cart)[: ticket_class["count"]]

sold_out_tickets += [
{
**ticket_class,
"event": {
"id": ticket_class["event"],
# TODO: use prefetch_related for performance + style,
# Couldn't get it to fetch more than the event id somehow
"name": Event.objects.get(id=ticket_class["event"]).name,
},
"count": ticket_class["count"] - tickets.count(),
}
]
if tickets.count() < ticket_class["count"]:
sold_out_tickets.append(
{
**ticket_class,
"event": {
"id": ticket_class["event"],
"name": Event.objects.get(id=ticket_class["event"]).name,
},
"count": ticket_class["count"] - tickets.count(),
}
)

replacement_tickets.extend(list(tickets))

cart.tickets.remove(*tickets_to_replace)
Expand Down Expand Up @@ -4947,21 +4971,7 @@ def initiate_checkout(self, request, *args, **kwargs):
status=status.HTTP_403_FORBIDDEN,
)

# Calculate cart total, applying group discounts where appropriate
ticket_type_counts = {
item["type"]: item["count"]
for item in cart.tickets.values("type").annotate(count=Count("type"))
}

cart_total = sum(
(
ticket.price * (1 - ticket.group_discount)
if ticket.group_size
and ticket_type_counts[ticket.type] >= ticket.group_size
else ticket.price
)
for ticket in tickets
)
cart_total = self._calculate_cart_total(cart)

if not cart_total:
return Response(
Expand Down
105 changes: 98 additions & 7 deletions backend/tests/clubs/test_ticketing.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,35 @@ def test_create_ticket_offerings_bad_data(self):
self.assertIn(resp.status_code, [400], resp.content)
self.assertEqual(Ticket.objects.filter(type__contains="_").count(), 0, data)

def test_create_ticket_offerings_group_discounts(self):
self.client.login(username=self.user1.username, password="test")

group_size, group_discount = 2, 0.5
args = {
"quantities": [
{
"type": "_normal",
"count": 20,
"price": 10,
"group_size": group_size,
"group_discount": group_discount,
},
]
}
resp = self.client.put(
reverse("club-events-tickets", args=(self.club1.code, self.event1.pk)),
args,
format="json",
)
self.assertIn(resp.status_code, [200, 201], resp.content)

tickets = Ticket.objects.filter(type="_normal")

# Created tickets should have correct group size and discount
for ticket in tickets:
self.assertEqual(ticket.group_size, group_size)
self.assertEqual(ticket.group_discount, group_discount)

def test_get_tickets_information_no_tickets(self):
# Delete all the tickets
Ticket.objects.all().delete()
Expand Down Expand Up @@ -525,7 +554,56 @@ def test_get_cart(self):
self.assertEqual(len(data["tickets"]), 5, data)
for t1, t2 in zip(data["tickets"], tickets_to_add):
self.assertEqual(t1["id"], str(t2.id))
self.assertEqual(data["sold_out"], 0, data)
self.assertEqual(len(data["sold_out"]), 0, data)

def test_calculate_cart_total(self):
# Add a few tickets to cart
cart, _ = Cart.objects.get_or_create(owner=self.user1)
tickets_to_add = self.tickets1[:5]
for ticket in tickets_to_add:
cart.tickets.add(ticket)
cart.save()

expected_total = sum(t.price for t in tickets_to_add)

from clubs.views import TicketViewSet

actual_total = TicketViewSet._calculate_cart_total(cart)
self.assertEqual(actual_total, expected_total)

def test_calculate_cart_total_with_group_discount(self):
# Create tickets with group discount
tickets = [
Ticket.objects.create(
type="group",
event=self.event1,
price=10.0,
group_size=2,
group_discount=0.2,
)
for _ in range(10)
]

cart, _ = Cart.objects.get_or_create(owner=self.user1)
from clubs.views import TicketViewSet

# Add 1 ticket, shouldn't activate group discount
cart.tickets.add(tickets[0])
cart.save()

total = TicketViewSet._calculate_cart_total(cart)
self.assertEqual(total, 10.0) # 1 * price=10 = 10

# Add 4 more tickets, enough to activate group discount
tickets_to_add = tickets[1:5]
for ticket in tickets_to_add:
cart.tickets.add(ticket)
cart.save()

self.assertEqual(cart.tickets.count(), 5)

total = TicketViewSet._calculate_cart_total(cart)
self.assertEqual(total, 40.0) # 5 * price=10 * (1 - group_discount=0.2) = 40

def test_get_cart_replacement_required(self):
self.client.login(username=self.user1.username, password="test")
Expand All @@ -547,7 +625,7 @@ def test_get_cart_replacement_required(self):

# The cart still has 5 tickets: just replaced with available ones
self.assertEqual(len(data["tickets"]), 5, data)
self.assertEqual(data["sold_out"], 0, data)
self.assertEqual(len(data["sold_out"]), 0, data)

in_cart = set(map(lambda t: t["id"], data["tickets"]))
to_add = set(map(lambda t: str(t.id), tickets_to_add))
Expand Down Expand Up @@ -578,13 +656,24 @@ def test_get_cart_replacement_required_sold_out(self):
# The cart now has 3 tickets
self.assertEqual(len(data["tickets"]), 3, data)

# 2 tickets have been sold out
self.assertEqual(data["sold_out"], 2, data)
# Only 1 type of ticket should be sold out
self.assertEqual(len(data["sold_out"]), 1, data)

in_cart = set(map(lambda t: t["id"], data["tickets"]))
to_add = set(map(lambda t: str(t.id), tickets_to_add))
# 2 normal tickets should be sold out
expected_sold_out = {
"type": self.tickets1[0].type,
"event": {
"id": self.tickets1[0].event.id,
"name": self.tickets1[0].event.name,
},
"count": 2,
}
for key, val in expected_sold_out.items():
self.assertEqual(data["sold_out"][0][key], val, data)

# 0 tickets are the same (we sell all but last 3)
in_cart = set(map(lambda t: t["id"], data["tickets"]))
to_add = set(map(lambda t: str(t.id), tickets_to_add))
self.assertEqual(len(in_cart & to_add), 0, in_cart | to_add)

def test_initiate_checkout(self):
Expand Down Expand Up @@ -720,6 +809,7 @@ def test_complete_checkout(self):
held_tickets = Ticket.objects.filter(holder=self.user1)
self.assertEqual(held_tickets.count(), 2, held_tickets)

# Complete checkout
resp = self.client.post(
reverse("tickets-complete-checkout"),
{"transient_token": "abcdefg"},
Expand All @@ -741,9 +831,10 @@ def test_complete_checkout(self):
self.assertEqual(held_tickets.count(), 0, held_tickets)

# Transaction record created
TicketTransactionRecord.objects.filter(
record_exists = TicketTransactionRecord.objects.filter(
reconciliation_id=MockPaymentResponse().reconciliation_id
).exists()
self.assertTrue(record_exists)

def test_complete_checkout_stale_cart(self):
self.client.login(username=self.user1.username, password="test")
Expand Down

0 comments on commit bd18a03

Please sign in to comment.