In [5]:
import pandas as pd
# Explicit dtypes for the point-level columns of the charting CSVs. Keys are matched against the CSV header
# case-sensitively, so every key must be spelled exactly like the column it describes.
d_type_dict = {
    "match_id": "string",
    "Pt": "int",
    "Set1": "int",
    "Set2": "int",
    "Gm1": "string",
    "Gm2": "string",
    # The header is "Pts" (see the printed frame below and the later `row.Pts` accesses); the previous lowercase
    # "pts" key matched no column, so this dtype was never applied.
    "Pts": "string",
    "Gm#": "int",
    "TbSet": "string",
    "Svr": "int",
    "1st": "string",
    "2nd": "string",
    "Notes": "string",
    "PtWinner": "int"
}
all_file_names = [
    "https://raw.githubusercontent.com/JeffSackmann/tennis_MatchChartingProject/refs/heads/master/charting-m-points-2010s.csv",
    "https://raw.githubusercontent.com/JeffSackmann/tennis_MatchChartingProject/refs/heads/master/charting-m-points-2020s.csv"]

# Read every points file with the same dtypes and stack them into one frame with a fresh 0..n-1 index.
points_df = pd.concat(
    (pd.read_csv(file_name, dtype=d_type_dict) for file_name in all_file_names),
    ignore_index=True,
)

match_metadata_csv_file_name = "https://raw.githubusercontent.com/JeffSackmann/tennis_MatchChartingProject/refs/heads/master/charting-m-matches.csv"
match_metadata = pd.read_csv(match_metadata_csv_file_name)

# Inner join (the `pd.merge` default, spelled out here): points whose match_id has no metadata row are dropped.
points_df = pd.merge(points_df, match_metadata, on="match_id", how="inner")
print(points_df.head())

                                            match_id  Pt  Set1  Set2 Gm1 Gm2  \
0  20191124-M-Davis_Cup_Finals-F-Rafael_Nadal-Den...   1     0     0   0   0   
1  20191124-M-Davis_Cup_Finals-F-Rafael_Nadal-Den...   2     0     0   0   0   
2  20191124-M-Davis_Cup_Finals-F-Rafael_Nadal-Den...   3     0     0   0   0   
3  20191124-M-Davis_Cup_Finals-F-Rafael_Nadal-Den...   4     0     0   0   0   
4  20191124-M-Davis_Cup_Finals-F-Rafael_Nadal-Den...   5     0     0   1   0   

    Pts  Gm# TbSet  Svr  ...      Date        Tournament Round  Time  \
0   0-0    1  True    1  ...  20191124  Davis Cup Finals     F   NaN   
1  15-0    1  True    1  ...  20191124  Davis Cup Finals     F   NaN   
2  30-0    1  True    1  ...  20191124  Davis Cup Finals     F   NaN   
3  40-0    1  True    1  ...  20191124  Davis Cup Finals     F   NaN   
4   0-0    2  True    2  ...  20191124  Davis Cup Finals     F   NaN   

                    Court Surface         Umpire Best of Final TB? Charted by  
0  Est

In [6]:
# Peek at the surface of the first ten merged point rows.
print(points_df["Surface"].iloc[:10])

0    Hard
1    Hard
2    Hard
3    Hard
4    Hard
5    Hard
6    Hard
7    Hard
8    Hard
9    Hard
Name: Surface, dtype: object


In [7]:
import csv
import io

def copy_from_iter(engine, orm_class, rows_iter, batch_size=50_000):
    """Bulk-load objects into the table mapped by ``orm_class`` using PostgreSQL ``COPY``.

    Every object yielded by ``rows_iter`` must expose one attribute per column of
    ``orm_class.__table__``, named exactly like the column. Rows are serialized as CSV
    into an in-memory buffer that is sent to the server every ``batch_size`` rows; the
    whole load runs inside the single transaction opened by ``engine.begin()``.

    The data is loaded with ``COPY ... WITH (FORMAT csv)`` because it is produced by
    ``csv.writer``: fields containing commas or quotes are quoted by the writer, and
    ``None`` is written as an empty unquoted field, which CSV-format COPY reads as NULL.
    Note that an empty string is written the same way and is therefore loaded as NULL too.

    Args:
        engine: SQLAlchemy engine whose DBAPI cursor provides ``copy_expert`` (psycopg2).
        orm_class: Declarative class; only ``__table__.name`` and ``__table__.columns`` are used.
        rows_iter: Iterable of objects to load.
        batch_size: Number of rows buffered before each COPY call. Must be positive.

    Raises:
        ValueError: If ``batch_size`` is not positive.
        AttributeError: If a row object lacks an attribute named after a column.
    """
    # Imported locally under the module name: this notebook also imports SQLAlchemy's
    # ``Enum`` column type under the bare name ``Enum``, so the global ``Enum`` cannot be
    # trusted to be the standard-library class at call time.
    import enum

    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    table = orm_class.__table__
    cols = [c.name for c in table.columns]

    def quote_identifier(identifier):
        # Double-quote identifiers so names such as "match" or "date" are always taken literally.
        return '"' + identifier.replace('"', '""') + '"'

    copy_sql = "COPY {} ({}) FROM STDIN WITH (FORMAT csv)".format(
        quote_identifier(table.name),
        ", ".join(quote_identifier(col_name) for col_name in cols),
    )

    buf = io.StringIO()
    writer = csv.writer(buf)

    def flush(cur):
        # Send everything buffered so far, then empty the buffer for the next batch.
        buf.seek(0)
        cur.copy_expert(copy_sql, buf)
        buf.seek(0)
        buf.truncate(0)

    with engine.begin() as conn:
        raw = conn.connection
        cur = raw.cursor()
        try:
            pending = 0

            for obj in rows_iter:
                row = []
                for col_name in cols:
                    val = getattr(obj, col_name)
                    # Python enum members are stored by their value, not their repr.
                    if isinstance(val, enum.Enum):
                        val = val.value

                    row.append(val)

                writer.writerow(row)
                pending += 1

                if pending >= batch_size:
                    flush(cur)
                    pending = 0

            # flush remaining rows
            if pending:
                flush(cur)
        finally:
            cur.close()

In [12]:
from dataclasses import dataclass
from enum import Enum
import uuid

# -------------------
# Depth
# -------------------

class Depth(Enum):
    """Depth classification of a shot; each member's value is its own name."""
    # Used by the shot parser when a shot carries no depth code.
    UNKNOWN_DEPTH = "UNKNOWN_DEPTH"
    SHALLOW = "SHALLOW"
    DEEP = "DEEP"
    BASELINE = "BASELINE"

# Single-character depth codes, as they appear in the charted shot strings, mapped to Depth.
# There is deliberately no code for UNKNOWN_DEPTH: it is the parser's default when no code is present.
DEPTH_CHAR_TO_DEPTH: dict[str, Depth] = {
    "7": Depth.SHALLOW,
    "8": Depth.DEEP,
    "9": Depth.BASELINE
}

# -------------------
# Direction
# -------------------

class Direction(Enum):
    """Direction classification of a shot; each member's value is its own name.

    NOTE(review): whose left/right this refers to (hitter or receiver) is not visible here — confirm against
    the charting documentation.
    """
    # Used by the shot parser when a shot carries no direction code.
    UNKNOWN_DIRECTION = "UNKNOWN_DIRECTION"
    RIGHT = "RIGHT"
    CENTER = "CENTER"
    LEFT = "LEFT"

# Single-character direction codes, as they appear in the charted shot strings, mapped to Direction.
DIRECTION_CHAR_TO_DIRECTION: dict[str, Direction] = {
    "1": Direction.RIGHT,
    "2": Direction.CENTER,
    "3": Direction.LEFT,
}

# -------------------
# ShotType
# -------------------

class ShotType(Enum):
    """Stroke classification of a shot; each member's value is its own name."""
    UNKNOWN_SHOT_TYPE = "UNKNOWN_SHOT_TYPE"
    FOREHAND = "FOREHAND"
    BACKHAND = "BACKHAND"
    FOREHAND_SLICE = "FOREHAND_SLICE"
    BACKHAND_SLICE = "BACKHAND_SLICE"
    FOREHAND_VOLLEY = "FOREHAND_VOLLEY"
    BACKHAND_VOLLEY = "BACKHAND_VOLLEY"
    # SERVE has no entry in SHOT_CODE_TO_TYPE; the shot parser assigns it directly to the first shot of a point.
    SERVE = "SERVE"
    SMASH = "SMASH"
    BACKHAND_SMASH = "BACKHAND_SMASH"
    FOREHAND_DROP = "FOREHAND_DROP"
    BACKHAND_DROP = "BACKHAND_DROP"
    FOREHAND_LOB = "FOREHAND_LOB"
    BACKHAND_LOB = "BACKHAND_LOB"
    FOREHAND_HALF_VOLLEY = "FOREHAND_HALF_VOLLEY"
    BACKHAND_HALF_VOLLEY = "BACKHAND_HALF_VOLLEY"
    FOREHAND_SWINGING_VOLLEY = "FOREHAND_SWINGING_VOLLEY"
    BACKHAND_SWINGING_VOLLEY = "BACKHAND_SWINGING_VOLLEY"
    TRICK = "TRICK"

# Single-letter shot codes, as they appear in the charted shot strings, mapped to ShotType.
# The keys of this map are also the characters the shot parser splits a shot string on.
SHOT_CODE_TO_TYPE: dict[str, ShotType] = {
    "f": ShotType.FOREHAND,
    "b": ShotType.BACKHAND,

    "r": ShotType.FOREHAND_SLICE,
    "s": ShotType.BACKHAND_SLICE,

    "v": ShotType.FOREHAND_VOLLEY,
    "z": ShotType.BACKHAND_VOLLEY,

    "o": ShotType.SMASH,
    "p": ShotType.BACKHAND_SMASH,

    "u": ShotType.FOREHAND_DROP,
    "y": ShotType.BACKHAND_DROP,

    "l": ShotType.FOREHAND_LOB,
    "m": ShotType.BACKHAND_LOB,

    "h": ShotType.FOREHAND_HALF_VOLLEY,
    "i": ShotType.BACKHAND_HALF_VOLLEY,

    "j": ShotType.FOREHAND_SWINGING_VOLLEY,
    "k": ShotType.BACKHAND_SWINGING_VOLLEY,

    "t": ShotType.TRICK,
    # "q" is an explicit code for a shot whose type was not identified.
    "q": ShotType.UNKNOWN_SHOT_TYPE,
}

# -------------------
# ServeDirection
# -------------------

class ServeDirection(Enum):
    """Direction classification of a serve; each member's value is its own name."""
    # Used for every non-serve shot and for serves that carry no direction code.
    UNKNOWN_SERVE_DIRECTION = "UNKNOWN_SERVE_DIRECTION"
    T = "T"
    BODY = "BODY"
    WIDE = "WIDE"

# Single-character serve-direction codes, as they appear in the charted serve strings, mapped to ServeDirection.
SERVE_DIRECTION_MAP: dict[str, ServeDirection] = {
    "4": ServeDirection.WIDE,
    "5": ServeDirection.BODY,
    "6": ServeDirection.T
}

# -------------------
# CourtPosition
# -------------------

class CourtPosition(Enum):
    """Court-position classification attached to a shot; each member's value is its own name."""
    # Used by the shot parser when a shot carries no court-position code.
    UNKNOWN_COURT_POSITION = "UNKNOWN_COURT_POSITION"
    APPROACH = "APPROACH"
    NET = "NET"
    BASELINE = "BASELINE"

# Single-character court-position codes, as they appear in the charted shot strings, mapped to CourtPosition.
COURT_POSITION_CHAR_TO_COURT_POSITION: dict[str, CourtPosition] = {
    "+": CourtPosition.APPROACH,
    "-": CourtPosition.NET,
    "=": CourtPosition.BASELINE
}

# -------------------
# Outcome
# -------------------

class Outcome(Enum):
    """Result of a shot; each member's value is its own name."""
    # Database-level default for the `outcome` column (see the Shot model).
    UNKNOWN_OUTCOME = "UNKNOWN_OUTCOME"
    # Parser default: no point-ending code was attached to the shot.
    CONTINUE = "CONTINUE"
    WINNER = "WINNER"
    UNFORCED_ERROR = "UNFORCED_ERROR"
    FORCED_ERROR = "FORCED_ERROR"

# Single-character point-ending codes, as they appear in the charted shot strings, mapped to Outcome.
# CONTINUE and UNKNOWN_OUTCOME have no code of their own.
OUTCOME_CHAR_TO_OUTCOME: dict[str, Outcome] = {
    "*": Outcome.WINNER,
    "@": Outcome.UNFORCED_ERROR,
    "#": Outcome.FORCED_ERROR
}

# Plain (non-ORM) record for one parsed shot. Its attribute names mirror the columns of the `shot` table so that
# instances can be handed straight to `copy_from_iter`.
#
# Equality and hashing are overridden to look only at the database key (number, point_number, point_match_id), so
# that collecting records in a set drops rows that would collide on the primary key. Trade-off: two records with the
# same key but different attributes compare equal, and one of them is silently discarded.
@dataclass
class ShotDetail:
    """One shot of a point, identified by (number, point_number, point_match_id)."""
    number: int
    shot_type: ShotType
    depth: Depth
    direction: Direction
    court_position: CourtPosition
    outcome: Outcome
    serve_direction: ServeDirection
    point_number: int
    point_match_id: uuid.UUID

    def __repr__(self):
        """Short representation showing only the shot number and type."""
        return f"shot_number: {self.number}, type: {self.shot_type}"

    def __eq__(self, other):
        """Compare by database key only; defer to the other operand for foreign types."""
        if not isinstance(other, ShotDetail):
            # Lets Python try the reflected comparison before falling back to "not equal".
            return NotImplemented

        return (self.number == other.number and
                self.point_number == other.point_number and
                self.point_match_id == other.point_match_id)

    def __hash__(self):
        """Hash over the same key fields used by ``__eq__``."""
        return hash((self.number, self.point_number, self.point_match_id))

In [13]:
from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy import Column, Date, Enum, ForeignKey, ForeignKeyConstraint, Integer, String, UniqueConstraint, Uuid
from enum import Enum as PEnum

Base = declarative_base()

class CourtSurface(PEnum):
    """Playing surface of an event; each member's value is its own name.

    Derives from ``PEnum`` (the standard-library ``Enum``) because the bare name ``Enum`` refers to SQLAlchemy's
    column type in this cell.
    """
    # Fallback for missing or unrecognized surface labels.
    UNKNOWN_SURFACE = "UNKNOWN_SURFACE"
    HARD = "HARD"
    GRASS = "GRASS"
    CLAY = "CLAY"
    CARPET = "CARPET"

class Event(Base):
    """ORM mapping for the ``event`` table: one row per event with its playing surface.

    NOTE(review): no other table in this notebook references ``event.id`` yet — confirm whether ``match`` should
    carry a foreign key to it.
    """
    __tablename__ = "event"
    
    # UUID primary key; the notebook derives it deterministically from the event name (see to_event_info).
    id = Column(Uuid, primary_key=True)
    # Stored as a database enum built from CourtSurface; never NULL.
    surface = Column(Enum(CourtSurface), nullable=False)

class Match(Base):
    """ORM mapping for the ``match`` table: one row per charted match."""
    __tablename__ = "match"

    # UUID primary key; the notebook derives it deterministically from the parsed match descriptor (see parse_match).
    id = Column(Uuid, primary_key=True)
    date = Column(Date)
    event_round = Column(String)
    first_player_name = Column(String, nullable=False)
    second_player_name = Column(String, nullable=False)

    # All points played in this match (inverse side of Point.match).
    points = relationship("Point", back_populates="match")

class Point(Base):
    """ORM mapping for the ``point`` table: one row per point of a match.

    The primary key is the composite ``(number, match_id)``. Equality and hashing are overridden to compare
    ``number``, ``game_score`` and ``match_id`` so that instances can be de-duplicated in a set.
    """
    __tablename__ = "point"
    number = Column(Integer, primary_key=True)
    game_score = Column(String)
    match_id = Column(Uuid, ForeignKey("match.id"), primary_key=True)

    # The match this point belongs to (inverse side of Match.points).
    match = relationship("Match", back_populates="points")

    # No separate UniqueConstraint("number", "match_id") is declared: the composite primary key above already
    # enforces uniqueness of that pair, so the extra constraint only added a redundant second index.

    def __repr__(self):
        """Compact ``match_id:number@game_score`` representation."""
        return f"{self.match_id}:{self.number}@{self.game_score}"

    def __eq__(self, other):
        """Compare number, game score and match id; defer to the other operand for foreign types."""
        if not isinstance(other, Point):
            # Lets Python try the reflected comparison before falling back to "not equal".
            return NotImplemented

        return self.number == other.number and self.game_score == other.game_score and self.match_id == other.match_id

    def __hash__(self):
        """Hash over the same fields used by ``__eq__``."""
        return hash((self.number, self.game_score, self.match_id))

class Shot(Base):
    """ORM mapping for the ``shot`` table: one row per shot of a point.

    The primary key is the composite ``(number, point_number, point_match_id)``; the last two columns together
    reference the owning row in ``point``. Equality and hashing use exactly that key.
    """
    __tablename__ = "shot"

    number = Column(Integer, primary_key=True)
    # Every classification column is a non-nullable database enum whose default is the enum's UNKNOWN_* member.
    shot_type = Column(Enum(ShotType), default=ShotType.UNKNOWN_SHOT_TYPE, nullable=False)
    depth = Column(Enum(Depth), default=Depth.UNKNOWN_DEPTH, nullable=False)
    direction = Column(Enum(Direction), default=Direction.UNKNOWN_DIRECTION, nullable=False)
    court_position = Column(Enum(CourtPosition), default=CourtPosition.UNKNOWN_COURT_POSITION, nullable=False)
    outcome = Column(Enum(Outcome), default=Outcome.UNKNOWN_OUTCOME, nullable=False)
    serve_direction = Column(Enum(ServeDirection), default=ServeDirection.UNKNOWN_SERVE_DIRECTION, nullable=False)
    point_number = Column(Integer, primary_key=True)
    point_match_id = Column(Uuid, primary_key=True)

    __table_args__ = (
        # Composite foreign key onto the composite primary key of `point`.
        ForeignKeyConstraint(
            ["point_number", "point_match_id"],
            ["point.number", "point.match_id"],
        ),
    )

    def __repr__(self):
        """Short representation showing only the shot number and type."""
        return f"shot_number: {self.number}, type: {self.shot_type}"

    def __eq__(self, other):
        """Compare by primary key only; defer to the other operand for foreign types."""
        if not isinstance(other, Shot):
            # Lets Python try the reflected comparison before falling back to "not equal".
            return NotImplemented

        return (self.number == other.number and
                self.point_number == other.point_number and
                self.point_match_id == other.point_match_id)

    def __hash__(self):
        """Hash over the same key fields used by ``__eq__``."""
        return hash((self.number, self.point_number, self.point_match_id))

In [14]:
from sqlalchemy import create_engine
from dotenv import load_dotenv
import os

from urllib.parse import quote_plus

load_dotenv()

db_url_schema = "postgresql+psycopg2://"
local_host = "localhost:5432"

# Fail with a readable message instead of a "can only concatenate str (not NoneType)" TypeError.
required_env_vars = ("DB_USER", "DB_PASSWORD", "DB_NAME")
missing_env_vars = [name for name in required_env_vars if os.getenv(name) is None]
if missing_env_vars:
    raise RuntimeError(f"Missing required environment variable(s): {', '.join(missing_env_vars)}")

# URL-escape the credentials so characters such as "@", ":" or "/" in them cannot corrupt the URL.
db_user = quote_plus(os.environ["DB_USER"])
db_password = quote_plus(os.environ["DB_PASSWORD"])
db_name = os.environ["DB_NAME"]

db_url = f"{db_url_schema}{db_user}:{db_password}@{local_host}/{db_name}"
# Never print the real password: notebook outputs are saved and shared together with the code.
print(f"{db_url_schema}{db_user}:***@{local_host}/{db_name}")
engine = create_engine(db_url, echo=True)

postgresql+psycopg2://zifanxiang:***@localhost:5432/tennis


In [17]:
from dataclasses import dataclass
import uuid

# Namespace under which all deterministic (version-5) UUIDs in this notebook are derived.
namespace = uuid.NAMESPACE_DNS

@dataclass(eq=True, unsafe_hash=True)
class MatchInfo:
    """Hashable descriptor of one match, parsed from a charting ``match_id`` string.

    Equality and hashing are generated from all fields, so instances can be de-duplicated in a set.
    """
    # Version-5 UUID derived from the other fields (see parse_match).
    id: uuid.UUID
    # Date exactly as it appears in the match id string (first hyphen-separated part).
    date: str
    event: str
    event_round: str
    first_player_name: str
    second_player_name: str

def parse_match(match_id: str) -> MatchInfo:
    """Split a hyphen-separated charting ``match_id`` into a :class:`MatchInfo`.

    The parts are read positionally: part 0 is the date, part 2 the event, part 3 the round, and the last two
    parts are the first and second player. Part 1 is not used. The UUID is a version-5 UUID of the date, event,
    round and player names concatenated without separators, so the same ``match_id`` always yields the same id.

    Args:
        match_id: For example ``"20191124-M-Davis_Cup_Finals-F-Rafael_Nadal-Denis_Shapovalov"``.

    Returns:
        The parsed match descriptor.

    Raises:
        IndexError: If ``match_id`` has fewer than four hyphen-separated parts.
    """
    match_id_split_by_hyphen = match_id.split("-")
    date_str = match_id_split_by_hyphen[0]
    event = match_id_split_by_hyphen[2]
    event_round = match_id_split_by_hyphen[3]
    # Players are taken from the end of the string rather than from fixed positions.
    first_player_name = match_id_split_by_hyphen[-2]
    second_player_name = match_id_split_by_hyphen[-1]

    unique_uuid_str = f"{date_str}{event}{event_round}{first_player_name}{second_player_name}"

    return MatchInfo(
        id=uuid.uuid5(namespace, unique_uuid_str),
        date=date_str,
        event=event,
        event_round=event_round,
        first_player_name=first_player_name,
        second_player_name=second_player_name)

# Parse every point row's match id once, keep the result on the frame, and collect the distinct matches.
parsed_match_infos = points_df["match_id"].apply(parse_match)
points_df["parsed_matched_info"] = parsed_match_infos
unique_matches = set(parsed_match_infos)

In [19]:
from dataclasses import dataclass
import uuid

# Namespace under which the deterministic (version-5) event UUIDs below are derived.
namespace = uuid.NAMESPACE_DNS

@dataclass(eq=True, unsafe_hash=True)
class EventInfo:
    """Hashable (id, surface) record for one event; attribute names mirror the ``event`` table columns.

    Equality and hashing are generated from both fields, so two records with the same id but different surfaces
    are considered distinct.
    """
    id: uuid.UUID
    surface: CourtSurface

# Upper-cased surface label -> CourtSurface. Labels that are not listed resolve to UNKNOWN_SURFACE.
SURFACE_MAP = {
    "HARD": CourtSurface.HARD,
    "CLAY": CourtSurface.CLAY,
    "GRASS": CourtSurface.GRASS,
    "CARPET": CourtSurface.CARPET,
}

def to_court_surface(surface_value: str) -> CourtSurface:
    """Translate a raw surface label into a :class:`CourtSurface`.

    Matching ignores surrounding whitespace and letter case. Missing values (``None``, ``NaN``, ``pd.NA``),
    empty strings and unrecognized labels all yield ``CourtSurface.UNKNOWN_SURFACE``.

    Args:
        surface_value: Raw label from the match metadata, e.g. ``"Hard"``; may be a missing-value marker.

    Returns:
        The matching surface, or ``CourtSurface.UNKNOWN_SURFACE``.
    """
    # Explicit missing-value check: NaN is truthy (it used to be looked up as the label "NAN"), and the truth
    # value of pd.NA is undefined, so a plain `not surface_value` is not a reliable test.
    if surface_value is None or pd.isna(surface_value):
        return CourtSurface.UNKNOWN_SURFACE

    normalized = str(surface_value).strip().upper()

    return SURFACE_MAP.get(normalized, CourtSurface.UNKNOWN_SURFACE)

def to_event_info(row) -> EventInfo:
    """Build the EventInfo for one merged point row.

    The event id is a version-5 UUID of the event name held by the row's parsed match info; the surface is the
    row's ``Surface`` value run through ``to_court_surface``.
    """
    event_name = str(row["parsed_matched_info"].event)
    surface = to_court_surface(row["Surface"])
    return EventInfo(id=uuid.uuid5(namespace, event_name), surface=surface)

from collections import Counter

# The event name is derived from match_id and Surface comes from the per-match metadata, so one row per
# (match_id, Surface) pair yields exactly the same set of events as scanning every point row, far faster.
rows_per_match = points_df.drop_duplicates(subset=["match_id", "Surface"])
unique_event_infos = list(set(rows_per_match.apply(to_event_info, axis=1)))
print(unique_event_infos[0:10])

# `event.id` is the table's primary key, but EventInfo records are de-duplicated on (id, surface). An event name
# that appears with more than one surface therefore produces several rows with the same id, which would make the
# bulk copy fail. Report such ids instead of discovering them as a database error.
event_id_counts = Counter(event_info.id for event_info in unique_event_infos)
conflicting_event_ids = [event_id for event_id, count in event_id_counts.items() if count > 1]
if conflicting_event_ids:
    print(f"WARNING: {len(conflicting_event_ids)} event id(s) appear with more than one surface; "
          "resolve them before loading the event table")

# copy_from_iter(engine, Event, unique_event_infos, 1000)
# The copy above is commented out, so nothing has been written yet; say so instead of claiming an insert.
print(f"Prepared {len(unique_event_infos)} event rows (not inserted)")

[EventInfo(id=UUID('2f9c9d73-9ee4-5e3d-866e-2ac58d5ddf85'), surface=<CourtSurface.HARD: 'HARD'>), EventInfo(id=UUID('5e940603-51eb-5996-b0eb-25827e7f5358'), surface=<CourtSurface.CLAY: 'CLAY'>), EventInfo(id=UUID('bc3f14ed-cc3a-526c-aae0-75e75fac529f'), surface=<CourtSurface.CLAY: 'CLAY'>), EventInfo(id=UUID('dcd53676-9a0b-50a6-8da2-d5f678b6380b'), surface=<CourtSurface.CLAY: 'CLAY'>), EventInfo(id=UUID('73cdf3be-22ce-532f-a743-587763a627c7'), surface=<CourtSurface.HARD: 'HARD'>), EventInfo(id=UUID('f62e7922-52f7-598c-93c9-8838dfc944f3'), surface=<CourtSurface.HARD: 'HARD'>), EventInfo(id=UUID('b121846a-db8c-5d37-82bc-0a5eece21efb'), surface=<CourtSurface.CLAY: 'CLAY'>), EventInfo(id=UUID('bec9d55f-723e-5f56-a425-88897c12cbeb'), surface=<CourtSurface.CLAY: 'CLAY'>), EventInfo(id=UUID('ad7ab605-8464-5d0b-9ad5-24cc7644a4a7'), surface=<CourtSurface.HARD: 'HARD'>), EventInfo(id=UUID('6d2645e7-fe40-5859-858e-0570c9d80978'), surface=<CourtSurface.HARD: 'HARD'>)]
Inserted 367 event rows


In [None]:
############################################################
# WARNING
# WARNING
# WARNING
# Reset block
# Drops the tables so that we can recreate them
# DO NOT RUN THIS UNLESS YOU INTEND TO BLOW AWAY THE TABLES!

# Drop order matters: shot references point and point references match, so referencing tables go first.
# NOTE(review): the event table is not dropped here — confirm whether that is intentional.
for model in (Shot, Point, Match):
    model.__table__.drop(engine, checkfirst=True)

In [15]:
# Create the tables of every model registered on Base (event, match, point, shot). With checkfirst=True, tables
# that already exist are left untouched, so this cell is safe to re-run.
Base.metadata.create_all(engine, checkfirst=True)

2026-02-28 08:36:04,324 INFO sqlalchemy.engine.Engine select pg_catalog.version()
2026-02-28 08:36:04,326 INFO sqlalchemy.engine.Engine [raw sql] {}
2026-02-28 08:36:04,334 INFO sqlalchemy.engine.Engine select current_schema()
2026-02-28 08:36:04,335 INFO sqlalchemy.engine.Engine [raw sql] {}
2026-02-28 08:36:04,343 INFO sqlalchemy.engine.Engine show standard_conforming_strings
2026-02-28 08:36:04,344 INFO sqlalchemy.engine.Engine [raw sql] {}
2026-02-28 08:36:04,360 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2026-02-28 08:36:04,380 INFO sqlalchemy.engine.Engine SELECT pg_catalog.pg_class.relname 
FROM pg_catalog.pg_class JOIN pg_catalog.pg_namespace ON pg_catalog.pg_namespace.oid = pg_catalog.pg_class.relnamespace 
WHERE pg_catalog.pg_class.relname = %(table_name)s AND pg_catalog.pg_class.relkind = ANY (ARRAY[%(param_1)s, %(param_2)s, %(param_3)s, %(param_4)s, %(param_5)s]) AND pg_catalog.pg_table_is_visible(pg_catalog.pg_class.oid) AND pg_catalog.pg_namespace.nspname != %(nspname

In [None]:
# Bulk-load the distinct matches into the match table, 1000 rows per COPY batch.
copy_from_iter(engine, Match, unique_matches, batch_size=1000)

In [None]:
# Lookup from each distinct parsed match to the UUID generated for it.
match_info_to_match_id = {match_info: match_info.id for match_info in unique_matches}

def assign_to_generated_match_id(match_info: MatchInfo) -> uuid.UUID | None:
    """Return the generated UUID for ``match_info``, or ``None`` if the match is not in ``unique_matches``.

    The value is a ``uuid.UUID`` (the ``id`` produced by ``parse_match``), not a string.
    """
    return match_info_to_match_id.get(match_info)

# Attach the generated match UUID to every point row so points and shots can reference their match.
points_df["generated_match_id"] = points_df["parsed_matched_info"].apply(assign_to_generated_match_id)

In [None]:
@dataclass(eq=True, unsafe_hash=True)
class PointInfo:
    """Hashable plain record for one point; attribute names mirror the ``point`` table columns.

    Equality and hashing are generated from all three fields, the same fields the ``Point`` model compares, so
    de-duplicating PointInfo records keeps the same rows as de-duplicating ``Point`` instances would.
    """
    number: int
    game_score: str
    match_id: uuid.UUID

# Build lightweight PointInfo records rather than ORM `Point` instances: `copy_from_iter` only reads the
# attributes named after the table's columns, so the mapped class adds per-row overhead and nothing else.
# Zipping the three columns avoids materialising a Series for every row, as `DataFrame.apply(axis=1)` does.
unique_point_infos = {
    PointInfo(number=number, game_score=game_score, match_id=match_id)
    for number, game_score, match_id in zip(
        points_df["Pt"], points_df["Pts"], points_df["generated_match_id"])
}

In [None]:
# Bulk-load the distinct points into the point table, 1000 rows per COPY batch.
copy_from_iter(engine, Point, unique_point_infos, batch_size=1000)

In [None]:
import re

# Every single-letter shot code, concatenated, for use inside a regex character class.
ALL_SHOT_TYPE_CHAR_CODES = "".join(SHOT_CODE_TO_TYPE)
# One capturing group around the character class, so `re.split` keeps the shot codes in its result.
SPLIT_REGEX = rf"([{ALL_SHOT_TYPE_CHAR_CODES}]){{1}}"

def _parse_shot_modifiers(modifier_str: str) -> tuple[Depth, Direction, CourtPosition, Outcome]:
    """Read depth, direction, court position and outcome from the characters that follow a shot code.

    Characters found in none of the lookup tables are ignored; if a category occurs more than once the last
    occurrence wins. Categories that never occur keep their defaults (the UNKNOWN_* member, and
    ``Outcome.CONTINUE`` for the outcome).
    """
    depth = Depth.UNKNOWN_DEPTH
    direction = Direction.UNKNOWN_DIRECTION
    court_position = CourtPosition.UNKNOWN_COURT_POSITION
    outcome = Outcome.CONTINUE

    for shot_property_char in modifier_str:
        # The four tables have disjoint keys, so at most one branch applies per character.
        if shot_property_char in DEPTH_CHAR_TO_DEPTH:
            depth = DEPTH_CHAR_TO_DEPTH[shot_property_char]
        elif shot_property_char in DIRECTION_CHAR_TO_DIRECTION:
            direction = DIRECTION_CHAR_TO_DIRECTION[shot_property_char]
        elif shot_property_char in OUTCOME_CHAR_TO_OUTCOME:
            outcome = OUTCOME_CHAR_TO_OUTCOME[shot_property_char]
        elif shot_property_char in COURT_POSITION_CHAR_TO_COURT_POSITION:
            court_position = COURT_POSITION_CHAR_TO_COURT_POSITION[shot_property_char]

    return depth, direction, court_position, outcome


def parse_shot_string_into_arr(row) -> list[ShotDetail]:
    """Parse the ``"1st"`` shot string of one point row into an ordered list of ShotDetail.

    The string is split on the shot-type letters. The first piece is always treated as the serve: shot 0 with
    type SERVE, whose direction and outcome are read from that piece. Every following shot-type letter becomes
    one shot, numbered consecutively from 1; the piece right after a letter, if it is not itself a letter, holds
    that shot's modifiers (depth, direction, court position, outcome).

    Args:
        row: A point row providing ``row["1st"]``, ``row.Pt`` and ``row.generated_match_id``.

    Returns:
        The shots in playing order, with ``number`` equal to each shot's index in the list. An empty list is
        returned when the shot string is missing or empty.

    NOTE(review): only the ``"1st"`` column is parsed; confirm whether the ``"2nd"`` column should also be read
    for points played on a second serve.
    """
    shot_str = row["1st"]
    # Missing values (None / NaN / pd.NA) and empty strings describe no shots at all.
    if not isinstance(shot_str, str) or not shot_str:
        return []

    # The capturing group in SPLIT_REGEX keeps the shot letters; drop the empty pieces the split leaves behind.
    tokens = [token for token in re.split(SPLIT_REGEX, shot_str) if token]

    serve_direction = ServeDirection.UNKNOWN_SERVE_DIRECTION
    serve_outcome = Outcome.CONTINUE
    for shot_property_char in tokens[0]:
        if shot_property_char in SERVE_DIRECTION_MAP:
            serve_direction = SERVE_DIRECTION_MAP[shot_property_char]
        # Test the character itself: testing the current outcome value here could never match a code.
        if shot_property_char in OUTCOME_CHAR_TO_OUTCOME:
            serve_outcome = OUTCOME_CHAR_TO_OUTCOME[shot_property_char]

    shot_details: list[ShotDetail] = [
        ShotDetail(
            number=0,
            shot_type=ShotType.SERVE,
            depth=Depth.UNKNOWN_DEPTH,
            direction=Direction.UNKNOWN_DIRECTION,
            court_position=CourtPosition.UNKNOWN_COURT_POSITION,
            outcome=serve_outcome,
            serve_direction=serve_direction,
            point_number=row.Pt,
            point_match_id=row.generated_match_id)
    ]

    rally_tokens = tokens[1:]
    position = 0
    while position < len(rally_tokens):
        token = rally_tokens[position]
        position += 1

        shot_type = SHOT_CODE_TO_TYPE.get(token)
        if shot_type is None:
            # Modifier text that does not follow a shot letter: nothing to attach it to.
            continue

        depth = Depth.UNKNOWN_DEPTH
        direction = Direction.UNKNOWN_DIRECTION
        court_position = CourtPosition.UNKNOWN_COURT_POSITION
        outcome = Outcome.CONTINUE

        # Modifiers are optional. Consume the next piece only when it is not another shot letter; this also
        # covers a final shot that has no modifiers after it.
        if position < len(rally_tokens) and rally_tokens[position] not in SHOT_CODE_TO_TYPE:
            depth, direction, court_position, outcome = _parse_shot_modifiers(rally_tokens[position])
            position += 1

        shot_details.append(
            ShotDetail(
                # One number per shot, modifiers or not, so numbers stay unique and equal to the list index.
                number=len(shot_details),
                shot_type=shot_type,
                depth=depth,
                direction=direction,
                court_position=court_position,
                outcome=outcome,
                serve_direction=ServeDirection.UNKNOWN_SERVE_DIRECTION,
                point_number=row.Pt,
                point_match_id=row.generated_match_id))

    return shot_details

# Parse every point row's shot string and keep the resulting shot lists on the frame.
parsed_out_shots = points_df.apply(parse_shot_string_into_arr, axis=1)
points_df["parsed_out_shots"] = parsed_out_shots
print(parsed_out_shots[0])

In [None]:
# Concatenate the per-point shot lists into one flat list, preserving row order and shot order.
flattened_shot_db_items = [
    shot_db_item
    for shots_of_point in points_df["parsed_out_shots"]
    for shot_db_item in shots_of_point
]

# ShotDetail compares and hashes on its database key, so this drops shots whose keys collide.
unique_flattened_shot_db_items = set(flattened_shot_db_items)

In [None]:
# Number of distinct shots (by database key) that are about to be loaded.
print(len(unique_flattened_shot_db_items))

In [None]:
# Bulk-load the distinct shots into the shot table, using copy_from_iter's default batch size (50_000 rows).
copy_from_iter(engine, Shot, unique_flattened_shot_db_items)

In [None]:
from abc import ABC, abstractmethod

class Operator(Enum):
    EQUAL = "EQUAL"
    GREATER_THAN = "GREATER_THAN"
    LESS_THAN = "LESS_THAN"

def applyOperator(operator: Operator, lhs_operand: int, rhs_operand: int) -> bool:
    match operator:
        case Operator.EQUAL:
            return lhs_operand == rhs_operand
        case Operator.GREATER_THAN:
            return lhs_operand > rhs_operand
        case Operator.LESS_THAN:
            return lhs_operand < rhs_operand

class ValueMatcher[T](ABC):
    @abstractmethod
    def match_value(self, value: T) -> bool:
        pass

class NoOpValueMatcher[T](ABC):
    def match_value(self, value: T) -> bool:
        return True

@dataclass
class SingleValueMatcher[T](ValueMatcher):
    value: T
    
    def match_value(self, value: T) -> bool:
        return self.value == value

@dataclass
class AnyValueMatcher[T](ValueMatcher):
    values: set[T]

    def __init__(self, values: list[T]):
        self.values = set(values)

    def match_value(self, value: T) -> bool:
        return value in self.values

# Shared accept-everything matchers, one per ShotDetail attribute. ShotMatcher uses them as field defaults so
# that a pattern only has to spell out the attributes it actually constrains.
type_no_op_matcher: ValueMatcher[ShotType] = NoOpValueMatcher()
depth_no_op_matcher: ValueMatcher[Depth] = NoOpValueMatcher()
direction_no_op_matcher: ValueMatcher[Direction] = NoOpValueMatcher()
court_position_no_op_matcher: ValueMatcher[CourtPosition] = NoOpValueMatcher()
outcome_no_op_matcher: ValueMatcher[Outcome] = NoOpValueMatcher()
serve_direction_no_op_matcher: ValueMatcher[ServeDirection] = NoOpValueMatcher()

@dataclass
class ShotMatcher:
    """Constraints on one shot of a point, selected by its index in the point's shot list.

    Each attribute of the shot has its own matcher; attributes left at their default accept any value.
    """
    # Index into the point's shot list of the shot these constraints apply to.
    number: int
    type_matcher: ValueMatcher[ShotType] = type_no_op_matcher
    depth_matcher: ValueMatcher[Depth] = depth_no_op_matcher
    direction_matcher: ValueMatcher[Direction] = direction_no_op_matcher
    court_position_matcher: ValueMatcher[CourtPosition] = court_position_no_op_matcher
    outcome_matcher: ValueMatcher[Outcome] = outcome_no_op_matcher
    serve_direction_matcher: ValueMatcher[ServeDirection] = serve_direction_no_op_matcher

    def does_shot_match(self, shot: ShotDetail) -> bool:
        """Return True when every attribute of ``shot`` is accepted by its matcher.

        Attributes are checked in the order type, depth, direction, court position, outcome, serve direction,
        stopping at the first one that is rejected.
        """
        checks = (
            (self.type_matcher, "shot_type"),
            (self.depth_matcher, "depth"),
            (self.direction_matcher, "direction"),
            (self.court_position_matcher, "court_position"),
            (self.outcome_matcher, "outcome"),
            (self.serve_direction_matcher, "serve_direction"),
        )
        # The generator keeps evaluation lazy: later attributes are not read once a check has failed.
        return all(
            attribute_matcher.match_value(getattr(shot, attribute_name))
            for attribute_matcher, attribute_name in checks
        )

class ShotClassifier:
    """Decides whether a point's shots fit a pattern: a shot-count condition plus per-shot constraints."""
    # Value the point's shot count is compared against.
    shots_length_limit: int
    # How the shot count is compared with the limit (count <operator> limit).
    shots_length_limit_operator: Operator
    # Constraints on individual shots, each addressed by its index in the shot list.
    shot_matchers: list[ShotMatcher]

    def __init__(
        self, 
        shots_length_limit: int,
        shots_length_limit_operator: Operator,
        shot_matchers: list[ShotMatcher]):
        """Store the shot-count condition and the per-shot matchers."""
        self.shots_length_limit = shots_length_limit
        self.shots_length_limit_operator = shots_length_limit_operator
        self.shot_matchers = shot_matchers
    
    def shot_pattern_fits(self, shots: list[ShotDetail]) -> bool:
        """Return True when ``shots`` satisfies the shot-count condition and every shot matcher.

        A matcher that refers to a shot the point does not have (its index is out of range) makes the pattern
        not fit, rather than raising ``IndexError``. This matters for the GREATER_THAN and LESS_THAN conditions,
        which admit points of different lengths.
        """
        if not applyOperator(self.shots_length_limit_operator, len(shots), self.shots_length_limit):
            return False

        for shot_matcher in self.shot_matchers:
            try:
                shot_to_match = shots[shot_matcher.number]
            except IndexError:
                # The point is too short to contain the shot this matcher constrains.
                return False

            if not shot_matcher.does_shot_match(shot_to_match):
                return False

        return True

# "T bully": exactly three shots, where shot 0 is a serve directed to the T and shot 2 is a winner.
T_BULLY_SHOT_LENGTH_LIMIT = 3
t_bully_shot_classifier = (ShotClassifier(
        T_BULLY_SHOT_LENGTH_LIMIT,
        Operator.EQUAL,
        [
            ShotMatcher(
                number=0,
                type_matcher=SingleValueMatcher(ShotType.SERVE),
                serve_direction_matcher=SingleValueMatcher(ServeDirection.T),
            ),
            ShotMatcher(
                number=2,
                outcome_matcher=SingleValueMatcher(Outcome.WINNER),
            )
        ]))
# "Wide slice fade": exactly three shots, where shot 0 is a wide serve, shot 1 lands shallow and shot 2 is a winner.
WIDE_SLICE_SHOT_LENGTH_LIMIT = 3
wide_slice_fade_shot_classifier = (ShotClassifier(
    WIDE_SLICE_SHOT_LENGTH_LIMIT,
    Operator.EQUAL,
    [
        ShotMatcher(
            number=0,
            type_matcher=SingleValueMatcher(ShotType.SERVE),
            serve_direction_matcher=SingleValueMatcher(ServeDirection.WIDE),
        ),
        ShotMatcher(
            number=1,
            depth_matcher=SingleValueMatcher(Depth.SHALLOW)
        ),
        ShotMatcher(
            number=2,
            outcome_matcher=SingleValueMatcher(Outcome.WINNER)
        )
    ]))
# "Server jam": exactly three shots, where shot 1 lands deep or on the baseline and shot 2 is an error
# (forced or unforced). Shot 0 is unconstrained.
SERVER_JAM_SHOT_LENGTH_LIMIT = 3
server_jam_shot_classifier = (ShotClassifier(
    SERVER_JAM_SHOT_LENGTH_LIMIT,
    Operator.EQUAL,
    [
        ShotMatcher(
            number=1,
            depth_matcher=AnyValueMatcher([Depth.DEEP, Depth.BASELINE])
        ),
        ShotMatcher(
            number=2,
            outcome_matcher=AnyValueMatcher([Outcome.UNFORCED_ERROR, Outcome.FORCED_ERROR])
        ),
    ]))
    


# Determines if the shot pattern in the point conforms to the T bully style
# def is_t_bully_point(shots: list[ShotDetail], consider_direction: bool) -> bool:
#     if len(shots) != T_BULLY_SHOT_LENGTH_LIMIT:
#         return False
    
#     serve = shots[0]

#     if serve.serve_direction != ServeDirection.T and consider_direction:
#         return False
#     if shots[2].outcome != Outcome.WINNER:
#         return False

#     return True

# One result list per pattern; each collects the shot lists of the points that fit it.
t_bully_shot_patterns: list[list[ShotDetail]] = []
wide_slice_fade_shot_patterns: list[list[ShotDetail]] = []
server_jam_shot_patterns: list[list[ShotDetail]] = []

# Pair every classifier with the list it feeds so each point is tested against all patterns in a single pass.
classifier_buckets = (
    (t_bully_shot_classifier, t_bully_shot_patterns),
    (wide_slice_fade_shot_classifier, wide_slice_fade_shot_patterns),
    (server_jam_shot_classifier, server_jam_shot_patterns),
)
for shots in points_df["parsed_out_shots"]:
    for pattern_classifier, matching_points in classifier_buckets:
        if pattern_classifier.shot_pattern_fits(shots):
            matching_points.append(shots)

print(len(t_bully_shot_patterns))
print(len(wide_slice_fade_shot_patterns))
print(wide_slice_fade_shot_patterns[0])
print(len(server_jam_shot_patterns))

In [None]:
# Show the first point classified as a T bully pattern; raises IndexError if no point matched.
print(t_bully_shot_patterns[0])