IMP add a dataframe breakpoint method for testing
Frequently, we have long transformations of DataFrames:

```
df = (
    df
    # some transformations 1
    # some transformations 2
    # some transformations 3
    # some transformations 4
)
```

and stepping through them in an interactive testing session is painful;
we have to break up the code like so:

```
df = (
    df
    # some transformations 1
)
breakpoint()
df = (
    df
    # some transformations 2
)
breakpoint()
df = (
    df
    # some transformations 3
)
breakpoint()
df = (
    df
    # some transformations 4
)
```

This can become cumbersome. The proposal here is to register a
`breakpoint()` method on `DataFrame` objects while testing with
`SparklyGlobalContextTest`, so you can just sprinkle `.breakpoint()`
calls within your chain of transformations:

```
df = (
    df
    # some transformations 1
    .breakpoint()
    # some transformations 2
    .breakpoint()
    # some transformations 3
    .breakpoint()
    # some transformations 4
    .breakpoint()
)
```

You can even pass your own function, for example if you don't want to
jump into a debugger but would rather just print some rows:

```
def show_df(df):
    df.show(10, False)

df = (
    df
    # some transformations 1
    .breakpoint(show_df)
    # some transformations 2
    .breakpoint(show_df)
    # some transformations 3
    .breakpoint(show_df)
    # some transformations 4
    .breakpoint(lambda df: df.show(20, False))
)
```
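
For reference, here is a minimal sketch of how this could look inside a
test case; `MySession`, the `session` attribute, and the sample data are
assumptions for illustration, not part of this change:

```
from pyspark.sql import functions as F

from sparkly.testing import SparklyGlobalContextTest
from myproject.sessions import MySession  # hypothetical SparklySession subclass


class TransformTest(SparklyGlobalContextTest):
    session = MySession  # assumed session attribute for the example

    def test_pipeline(self):
        df = self.spark.createDataFrame(
            [(1, 'a'), (2, 'b')],
            ['id', 'label'],
        )
        result = (
            df
            .where(F.col('id') > 0)
            .breakpoint()  # drops into pdb with the intermediate df in scope
            .withColumn('label_upper', F.upper('label'))
            .breakpoint(lambda df: df.show(10, False))  # prints rows instead of pausing
        )
        self.assertEqual(result.count(), 2)
```

Since the injected method ignores the break function's return value and
returns the DataFrame itself, `.breakpoint()` calls can be added or
removed anywhere in a chain without changing the result.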
srstrickland committed May 18, 2022
1 parent 89d2f1f commit 125a2ff
Showing 1 changed file with 33 additions and 0 deletions.
sparkly/testing.py
@@ -32,6 +32,7 @@
import warnings

from pyspark.context import SparkContext
from pyspark.sql import DataFrame
from pyspark.sql import types as T

from sparkly import SparklySession
@@ -97,6 +98,34 @@ def _ensure_gateway_is_down():
    os.environ.pop('PYSPARK_GATEWAY_SECRET', None)


def dataframe_breakpoint(self, break_function=None):
    """Injected method for the DataFrame class during testing.

    The user may supply their own ``break_function``, which takes a
    single parameter, ``df``. Example::

        output_df = (
            input_df
            .withColumn(...)
            .where(...)
            .groupBy(...)
            .agg(...)
            .breakpoint()  # will bring up the pdb debugger
            .select(...)
            .breakpoint(lambda df: df.show(10, False))  # will print 10 rows to the console
            ...
        )
    """

    def pdb_breakpoint(df):
        # Default behavior: drop into the standard pdb debugger with the
        # intermediate DataFrame available as `df`.
        import pdb
        pdb.set_trace()

    break_function = break_function or pdb_breakpoint
    break_function(self)
    return self


class SparklyTest(TestCase):
    """Base test for Spark script tests.
@@ -158,6 +187,9 @@ def setUpClass(cls):

        cls._init_session()

        # define a `df.breakpoint()` for testing:
        DataFrame.breakpoint = dataframe_breakpoint

        for fixture in cls.class_fixtures:
            fixture.setup_data()

@@ -178,6 +210,7 @@ def setUpClass(cls):
    @classmethod
    def tearDownClass(cls):
        cls.spark.stop()
        # remove the injected method so it doesn't leak outside the test run:
        delattr(DataFrame, 'breakpoint')
        _ensure_gateway_is_down()
        super(SparklyTest, cls).tearDownClass()
