# In [0]:

import datetime
import random
from pyspark.sql.functions import lit
from dateutil.relativedelta import relativedelta
from pyspark.sql.utils import AnalysisException
from pyspark.sql.types import DoubleType
from pyspark.sql.types import LongType
from pyspark.sql import functions
from pyspark.sql import Row


# Results accumulator; not referenced by any function visible in this section.
test_result = []

# Routine 'platform' code -> {routine 'device_type' code -> v1 'device_code' label}.
# Looked up as device_code_dict[platform][device_type] by check_routine_v1_accuracy.
device_code_dict = {
    1: {'1': 'Android-phone', '2': 'Android-tablet'},
    2: {'1': 'ios-phone', '2': 'ios-tablet'},
}

# v1 granularity name -> routine 'range_type' path segment.
raw_granularity_dict = {
    'daily': 'DAY',
    'monthly': 'MONTH',
    'weekly': 'WEEK',
}


def last_day_of_month(check_month):
    """Return the last calendar day of the month containing *check_month*.

    Accepts ``date`` or ``datetime`` values; any time-of-day component is kept.
    """
    # Day 28 exists in every month, and 28 + 4 days always lands in the
    # following month regardless of how long the current month is.
    into_next_month = check_month.replace(day=28) + datetime.timedelta(days=4)
    first_of_next_month = into_next_month.replace(day=1)
    return first_of_next_month - datetime.timedelta(days=1)


def get_monthly_date_list(start=datetime.date(2019, 10, 31), end=datetime.date(2020, 2, 29)):
    """Return one Row per month-end date, from *start*'s month through *end*.

    Each Row wraps a single 'YYYY-MM-DD' string for the last day of a month.
    *start* is snapped to the last day of its month, and a month-end later
    than *end* is never emitted. The defaults reproduce the original
    hard-coded range (2019-10-31 .. 2020-02-29).

    Date literals are written without leading zeros: ``02`` is an octal
    literal in Python 2 (``08``/``09`` are syntax errors there) and is
    rejected outright by Python 3.
    """
    result = []
    current = last_day_of_month(start)
    # Compare the already-snapped month end against *end*, so a month-end
    # past *end* cannot slip into the result.
    while current <= end:
        result.append(Row(current.strftime('%Y-%m-%d')))
        # relativedelta clamps and keeps the day number (e.g. Nov 30 + 1 month
        # -> Dec 30), so re-snap to the real month end after stepping.
        current = last_day_of_month(current + relativedelta(months=1))
    return result


def get_weekly_date_list(start=datetime.date(2019, 10, 5), end=datetime.date(2020, 3, 28)):
    """Return one Row per week from *start* to *end* inclusive, stepping 7 days.

    Each Row wraps a single 'YYYY-MM-DD' string. The defaults reproduce the
    original hard-coded range (2019-10-05 .. 2020-03-28).

    Date literals are written without leading zeros: ``05`` is an octal
    literal in Python 2 (``08``/``09`` are syntax errors there) and is
    rejected outright by Python 3.
    """
    result = []
    current = start
    while current <= end:
        result.append(Row(current.strftime('%Y-%m-%d')))
        # A week is a fixed 7-day step, so plain timedelta suffices.
        current += datetime.timedelta(weeks=1)
    return result


def get_daily_date_list(start=datetime.date(2019, 11, 22), end=datetime.date(2020, 3, 28)):
    """Return one Row per calendar day from *start* to *end* inclusive.

    Each Row wraps a single 'YYYY-MM-DD' string. The defaults reproduce the
    original hard-coded range (2019-11-22 .. 2020-03-28).

    Date literals are written without leading zeros: ``03`` is an octal
    literal in Python 2 (``08``/``09`` are syntax errors there) and is
    rejected outright by Python 3.
    """
    result = []
    current = start
    while current <= end:
        result.append(Row(current.strftime('%Y-%m-%d')))
        current += datetime.timedelta(days=1)
    return result


def get_path_date_list(granularity):
    """Group the sample dates for *granularity* by calendar month.

    :param granularity: one of 'daily', 'weekly' or 'monthly'.
    :returns: list of ``(month_key, dates)`` tuples sorted by month ascending,
        where ``month_key`` is 'YYYY-MM' and ``dates`` holds that month's
        'YYYY-MM-DD' strings in the order the date generator produced them.
    :raises ValueError: if *granularity* is not one of the supported values.
    """
    date_builders = {
        'daily': get_daily_date_list,
        'weekly': get_weekly_date_list,
        'monthly': get_monthly_date_list,
    }
    # Fail loudly here; the old if-chain left its result variable unbound and
    # crashed with an UnboundLocalError instead.
    if granularity not in date_builders:
        raise ValueError('Unsupported granularity: {!r}; expected one of {}'.format(
            granularity, ', '.join(sorted(date_builders))))

    by_month = {}
    for row in date_builders[granularity]():
        date_str = row[0]
        # setdefault replaces the Python-2-only dict.has_key check.
        by_month.setdefault(date_str[:7], []).append(date_str)

    # Zero-padded 'YYYY-MM' keys sort lexicographically in chronological
    # order, so no date parsing is needed to order the months.
    return sorted(by_month.items())


def check_not_empty(df, date):
    """Print whether the 'AU' column of *df* contains any nulls.

    :param df: DataFrame with an 'AU' column (in this file, a routine parquet
        dataset read via spark.read.parquet).
    :param date: label echoed in the printed message.
    """
    null_au_rows = df.select('AU').filter("AU is null").count()
    if null_au_rows == 0:
        print("AU is Not Empty Test Pass! date: {}".format(date))
        return
    print("AU is Not Empty Test Fail!!! empty_count: {}, date: {}".format(null_au_rows, date))


def check_percentage_accuracy(df, date):
    """Print whether any share/percentage column of *df* exceeds 1.

    Only the upper bound is checked; rows where a column is null do not
    match the filter and are not counted.

    :param df: DataFrame holding the routine percentage columns.
    :param date: label echoed in the printed message.
    """
    percentage_columns = ('IP', 'MBWFT', 'OR', 'PAD', 'UP', 'SOI', 'SOU')
    # Builds "IP>1 or MBWFT>1 or ... or SOU>1".
    any_above_one = ' or '.join('{}>1'.format(name) for name in percentage_columns)
    offending_rows = df.select(*percentage_columns).filter(any_above_one).count()
    if offending_rows == 0:
        print("Percentage<1 Test Pass! date: {}".format(date))
        return
    print("Percentage<1 Test Fail!!! illegal_percentage_count: {}, date: {}".format(offending_rows, date))


def check_routine_v1_accuracy(date_list, _granularity):
    """Compare one randomly sampled date per month between routine and v1 outputs.

    For each ``(month_key, dates)`` entry of *date_list* this:
      1. picks one date at random from ``dates``;
      2. reads the routine parquet and the v1 parquet for that date;
      3. runs check_not_empty / check_percentage_accuracy on the routine data;
      4. reshapes the routine data to the v1 column names;
      5. prints whether the two datasets match in both directions.

    :param date_list: list of ``(month_key, dates)`` pairs, as produced by
        get_path_date_list.
    :param _granularity: 'daily', 'weekly' or 'monthly'; used directly in the
        v1 path and mapped through raw_granularity_dict for the routine path.

    NOTE(review): relies on a notebook-provided global ``spark`` session.
    """
    v1_template = ('s3://b2c-prod-data-pipeline-unified-usage/'
                   'unified/usage.basic-kpi.v1/fact/granularity={unified_granularity}/date={unified_date}/')
    routine_template = ('s3://aardvark-prod-pdx-mdm-to-int/basic_kpi/'
                        'version=v3.0.0/range_type={raw_granularity}/date={raw_date}/')

    # Routine metric abbreviation -> v1 column name, applied in this order.
    metric_renames = [
        ('AU', 'est_average_active_users'),
        ('AFU', 'est_average_session_per_user'),
        ('ADU', 'est_average_session_duration'),
        ('IP', 'est_install_penetration'),
        ('AAD', 'est_average_active_days'),
        ('PAD', 'est_percentage_active_days'),
        ('MBPU', 'est_average_bytes_per_user'),
        ('ATU', 'est_average_time_per_user'),
        ('UP', 'est_usage_penetration'),
        ('OR', 'est_open_rate'),
        ('MBPS', 'est_average_bytes_per_session'),
        ('MBWFT', 'est_percent_of_wifi_total'),
        ('MBS', 'est_mb_per_second'),
        ('IS', 'est_installs'),
        ('SOU', 'est_average_active_users_country_share'),
        ('SOI', 'est_installs_country_share'),
    ]
    # v1 columns the routine data has no source for; filled with typed nulls.
    placeholder_columns = [
        'est_share_of_category_time',
        'est_share_of_category_session',
        'est_share_of_category_bytes',
        'est_panel_size',
    ]

    for month in date_list:
        month_dates = month[1]
        date = month_dates[random.randint(0, len(month_dates) - 1)]

        routine_location = routine_template.format(raw_granularity=raw_granularity_dict[_granularity], raw_date=date)
        v1_location = v1_template.format(unified_granularity=_granularity, unified_date=date)
        raw_routine = spark.read.parquet(routine_location)
        v1_df = spark.read.parquet(v1_location).drop('_identifier')

        check_not_empty(raw_routine, date)
        check_percentage_accuracy(raw_routine, date)

        # Combine (platform, device_type) into the v1 'device_code' label.
        to_device_code = functions.UserDefinedFunction(
            lambda platform, device_type: device_code_dict[platform][device_type])
        reshaped = raw_routine.withColumn(
            'device_code', to_device_code(raw_routine['platform'], raw_routine['device_type']))
        reshaped = reshaped.withColumnRenamed('country', 'country_code')
        reshaped = reshaped.withColumn('app_id', raw_routine['app_id'].cast(LongType()))
        for routine_name, v1_name in metric_renames:
            reshaped = reshaped.withColumnRenamed(routine_name, v1_name)
        for column_name in placeholder_columns:
            reshaped = reshaped.withColumn(column_name, lit(None).cast(DoubleType()))
        routine_df = reshaped.drop('device_type').drop('platform')

        # subtract is a set difference (EXCEPT DISTINCT); each side is first
        # projected onto the other's column order so rows line up positionally.
        routine_only = routine_df.select(v1_df.columns).subtract(v1_df).count()
        v1_only = v1_df.select(routine_df.columns).subtract(routine_df).count()
        if routine_only == 0 and v1_only == 0:
            print('Accuracy Test Pass! date={}'.format(date))
        else:
            print('Accuracy Test Fail!!!! subtract_count: {}, date={}'.format(
                max(routine_only, v1_only), date))


# Run the routine-vs-v1 comparison for every supported granularity.
# (The misspelled module-level name is kept in case other cells reference it.)
graularity_list = ["daily", "weekly", "monthly"]
for granularity in graularity_list:
    grouped_dates = get_path_date_list(granularity)
    check_routine_v1_accuracy(grouped_dates, granularity)
