In [1]:
from pyspark.context import SparkContext
from pyspark.sql import SparkSession, Window
from pyspark.sql.types import FloatType, IntegerType, BooleanType

import pyspark.sql.functions as f

import pandas as pd
from pandas.testing import assert_frame_equal

import numpy as np

import string

import unittest
import logging
import pytest

import sys

from operator import add

In [2]:
spark = SparkSession.builder.getOrCreate()
spark.sparkContext._conf.getAll()

[('spark.driver.host', 'USCND64648K6.CSCMWS.CSCMWS.COM'),
 ('spark.sql.catalogImplementation', 'hive'),
 ('spark.rdd.compress', 'True'),
 ('spark.serializer.objectStreamReset', '100'),
 ('spark.master', 'local[*]'),
 ('spark.executor.id', 'driver'),
 ('spark.submit.deployMode', 'client'),
 ('spark.driver.port', '62420'),
 ('spark.app.name', 'PySparkShell'),
 ('spark.ui.showConsoleProgress', 'true'),
 ('spark.app.id', 'local-1539176621006')]

In [3]:
conf = [('spark.app.name', "Spark AppName Updated")]
conf = spark.sparkContext._conf.setAll(conf)
spark.sparkContext.stop()
spark = spark.builder.config(conf=conf).getOrCreate()
spark

In [4]:
class PySparkTest(unittest.TestCase):
    
    @classmethod
    def suppress_py4j_logging(cls):
        logger = logging.getLogger('py4j')
        logger.setLevel(logging.WARN)

    @classmethod
    def create_testing_pyspark_session(cls):
        return (SparkSession.builder
     .master('local[2]')
     .appName('my-local-testing-pyspark-context')
     .enableHiveSupport()
     .getOrCreate())

    @classmethod
    def setUpClass(cls):
        cls.suppress_py4j_logging()
        cls.spark = cls.create_testing_pyspark_session()
        
    @classmethod
    def tearDownClass(cls):
        cls.spark.stop()

In [5]:
class SimpleTest(PySparkTest):
    
    def test_upper(self):
        self.assertEqual('foo'.upper(), 'FOO')
        
        
    def test_lower(self):
        self.assertEqual("FOO".lower(), 'foo')

In [6]:
class AdvancedTest(PySparkTest):

    def test_rdd(self):
        test_rdd = self.spark.sparkContext.parallelize(
            ['cat dog mouse', 'cat cat dog'], 2)
        results = test_rdd.flatMap(lambda line: line.split()).map(
            lambda word: (word, 1)).reduceByKey(add).collect()
        expected_results = [('cat', 3), ('dog', 2), ('mouse', 1)]
        self.assertEqual(set(results), set(expected_results))
        
    def assert_frame_equal_with_sort(self, results, expected, keycolumns):
        results_sorted = results.sort_values(by=keycolumns).reset_index(drop=True)
        expected_sorted = expected.sort_values(by=keycolumns).reset_index(drop=True)
        assert_frame_equal(results_sorted, expected_sorted)
        
    @staticmethod
    def my_spark_function(df):
        return df[df['make'].isin(["Rover", "Lotus", "MINI"])]
        

    def test_dataFrame(self):
     # Create the test data, with larger examples this can come from a CSV file
     # and we can use pd.read_csv(…)
        data_pandas = pd.DataFrame({'make': ['Jaguar', 'MG', 'MINI', 'Rover', 'Lotus'],\
                                    'registration': ['AB98ABCD', 'BC99BCDF', 'CD00CDE', 'DE01DEF', 'EF02EFG'],\
                                    'year': [1998, 1999, 2000, 2001, 2002]})
    # Turn the data into a Spark DataFrame, self.spark comes from our PySparkTest base class
        data_spark = self.spark.createDataFrame(data_pandas)
    # Invoke the unit we’d like to test
        results_spark = self.my_spark_function(data_spark)
     # Turn the results back to Pandas
        results_pandas = results_spark.toPandas()
    # Our expected results crafted by hand, again, this could come from a CSV
     # in case of a bigger example
        expected_results = pd.DataFrame({'make':['Rover', 'Lotus', 'MINI'],
                                         'registration':['DE01DEF','EF02EFG', 'CD00CDE'],
                                         'year':[2001,2002, 2000]})
    # Assert that the 2 results are the same. We’ll cover this function in a bit
        self.assert_frame_equal_with_sort(results_pandas, expected_results, ['registration'])

In [7]:
if __name__ == '__main__':
    unittest.main(argv=[''], exit=False)

DataFrame[make: string, registration: string, year: bigint]


  self._sock = None
  self._sock = None
...
----------------------------------------------------------------------
Ran 4 tests in 19.216s

OK
