In [1]:
import unittest

import invocations
from sofiproto.moneyfraud import customer_risk_v2_pb2
import pandas as pd
import numpy as np
from io import StringIO
from google.protobuf import descriptor
from read_protobuf import read_protobuf
from google.protobuf.json_format import MessageToDict, ParseDict 

In [None]:
class ProtobufTests(unittest.TestCase):
    """End-to-end tests that drive the model through ``invocations.invoke``.

    Requests are sent either as CSV text (``'text/csv'``) or as a serialized
    ``IncomeVerificationNecessaryV2Input`` protobuf
    (``'application/octet-stream'``). The ``*_match_*`` tests compare the
    returned scores with expected-value columns stored in the test CSV files.

    NOTE(review): ``income_verification_necessary_v2_pb2`` is referenced
    throughout (including in method annotations, which are evaluated when the
    class is defined), but the visible imports only bring in
    ``customer_risk_v2_pb2`` -- confirm the module is imported beforehand.
    """

    # Absolute tolerance for every numeric score comparison.
    _SCORE_DELTA = .001

    # CSV columns that hold the app type or expected outputs rather than
    # request fields; they are never copied onto the protobuf input sample.
    _NON_INPUT_COLUMNS = frozenset([
        'appType', 'target_pred_int', 'target_pred_auprc',
        'target_pred_mlp', 'score_gen2', 'eligible_gen2',
    ])

    def test_receiving_basic_protobuf_returns_healthy_pl(self):
        self.send_basic_request(income_verification_necessary_v2_pb2.IncomeVerificationNecessaryV2Input.AppTypeV2.PL)

    def test_receiving_basic_protobuf_returns_healthy_sl_refi(self):
        self.send_basic_request(income_verification_necessary_v2_pb2.IncomeVerificationNecessaryV2Input.AppTypeV2.REFI)

    def test_receiving_basic_protobuf_returns_healthy_sl_plus(self):
        self.send_basic_request(income_verification_necessary_v2_pb2.IncomeVerificationNecessaryV2Input.AppTypeV2.PLUS)

    def test_receiving_basic_protobuf_returns_healthy_sl_isl(self):
        self.send_basic_request(income_verification_necessary_v2_pb2.IncomeVerificationNecessaryV2Input.AppTypeV2.INSCHOOL)

    def test_pl_csv_match_expected(self):
        self.csv_results_match_expected_results('pl_test_data.csv')

    def test_slr_csv_match_expected(self):
        self.csv_results_match_expected_results('slr_test_data.csv')

    def test_isl_csv_match_expected(self):
        self.csv_results_match_expected_results('isl_test_data.csv')

    def test_pl_protobuf_match_csv_expected(self):
        self.protobuf_match_expected_results('pl_test_data.csv', income_verification_necessary_v2_pb2.IncomeVerificationNecessaryV2Input.AppTypeV2.PL)

    def test_sl_refi_protobuf_match_csv_expected(self):
        self.protobuf_match_expected_results('slr_test_data.csv', income_verification_necessary_v2_pb2.IncomeVerificationNecessaryV2Input.AppTypeV2.REFI)

    def test_sl_plus_protobuf_match_csv_expected(self):
        self.protobuf_match_expected_results('slr_test_data.csv', income_verification_necessary_v2_pb2.IncomeVerificationNecessaryV2Input.AppTypeV2.PLUS)

    def test_sl_isl_protobuf_match_csv_expected(self):
        self.protobuf_match_expected_results('isl_test_data.csv', income_verification_necessary_v2_pb2.IncomeVerificationNecessaryV2Input.AppTypeV2.INSCHOOL)

    def test_null_credit_scores(self):
        # NOTE(review): the helper does not yet clear any credit-score field,
        # so this currently sends the same request as the PL basic test.
        self.send_basic_request_with_null_credits(income_verification_necessary_v2_pb2.IncomeVerificationNecessaryV2Input.AppTypeV2.PL)

    def _assert_optional_score_close(self, expected, actual):
        """Assert *actual* matches *expected*, where *expected* may be NaN.

        A NaN expectation requires a NaN result; otherwise the two values must
        agree within ``_SCORE_DELTA``.
        """
        if np.isnan(expected):
            self.assertTrue(np.isnan(actual))
        else:
            self.assertAlmostEqual(expected, actual, delta=self._SCORE_DELTA)

    def _build_request(self, app_type, input_sample):
        """Wrap *input_sample* in an ``IncomeVerificationNecessaryV2Input`` for *app_type*."""
        request = income_verification_necessary_v2_pb2.IncomeVerificationNecessaryV2Input()
        request.appType = app_type
        request.input.CopyFrom(input_sample)
        return request

    def csv_results_match_expected_results(self, filename):
        """Score *filename* through the CSV content type and check every row.

        Expected columns in the input CSV map to model outputs as follows:
        ``target_pred_int`` -> AUROC score, ``target_pred_auprc`` -> AUPRC
        score, ``target_pred_mlp`` -> MLP score (may be NaN),
        ``score_gen2`` -> final score, ``eligible_gen2`` -> eligibility flag.
        """
        csv_input_df = pd.read_csv(filename, sep=',')
        result, _ = invocations.invoke(csv_input_df.to_csv(header=True), 'text/csv')
        # Column names are supplied explicitly, so the first response line is
        # parsed as data rather than as a header.
        result_df = pd.read_csv(StringIO(result), sep=',',
                                names=['result_score_auroc', 'result_score_auprc', 'result_score_mlp',
                                       'result_score', 'result_eligible'])
        result_df['expected_score_auroc'] = csv_input_df['target_pred_int']
        result_df['expected_score_auprc'] = csv_input_df['target_pred_auprc']
        result_df['expected_score_mlp'] = csv_input_df['target_pred_mlp']
        result_df['expected_score'] = csv_input_df['score_gen2']
        result_df['expected_eligible'] = csv_input_df['eligible_gen2']
        for index, row in result_df.iterrows():
            # subTest reports which row failed and keeps checking the rest.
            with self.subTest(row=index):
                self.assertAlmostEqual(row['expected_score_auroc'], row['result_score_auroc'], delta=self._SCORE_DELTA)
                self.assertAlmostEqual(row['expected_score_auprc'], row['result_score_auprc'], delta=self._SCORE_DELTA)
                # Compare the MLP value itself, not merely whether both are NaN.
                self._assert_optional_score_close(row['expected_score_mlp'], row['result_score_mlp'])
                self.assertAlmostEqual(row['expected_score'], row['result_score'], delta=self._SCORE_DELTA)
                self.assertEqual(row['expected_eligible'], row['result_eligible'])

    def send_basic_request(self, type: income_verification_necessary_v2_pb2.IncomeVerificationNecessaryV2Input.AppTypeV2):
        """Send a request whose sample has every unset field defaulted.

        The test passes if ``invocations.invoke`` returns without raising.
        """
        input_sample = income_verification_necessary_v2_pb2.IncomeVerificationNecessaryV2InputSample()
        self.populateUnsetProtoFields(input_sample)
        request = self._build_request(type, input_sample)
        # Reaching the end without an exception is the success criterion.
        invocations.invoke(request.SerializeToString(), 'application/octet-stream')

    def send_basic_request_with_null_credits(self, type: income_verification_necessary_v2_pb2.IncomeVerificationNecessaryV2Input.AppTypeV2):
        """Send a request intended to have null credit scores.

        TODO(review): no credit-score field is cleared yet, so this is
        currently identical to ``send_basic_request``; clear the relevant
        fields (e.g. via ``ClearField``) once their names are confirmed.
        """
        self.send_basic_request(type)

    def protobuf_match_expected_results(self, csv_filename: str, type: income_verification_necessary_v2_pb2.IncomeVerificationNecessaryV2Input.AppTypeV2):
        """Send each CSV row as a protobuf request and compare with its expected columns.

        Every non-expected, non-NaN column is set by name on the input sample,
        so CSV column names must match the sample message's field names.
        """
        csv_input_df = pd.read_csv(csv_filename, sep=',')
        output_cls = income_verification_necessary_v2_pb2.IncomeVerificationNecessaryV2Output

        for index, row in csv_input_df.iterrows():
            with self.subTest(row=index):
                input_sample = income_verification_necessary_v2_pb2.IncomeVerificationNecessaryV2InputSample()
                for key, value in row.items():
                    # Expected-output columns are not request fields; NaN
                    # values leave the field unset.
                    if key in self._NON_INPUT_COLUMNS or pd.isna(value):
                        continue
                    setattr(input_sample, key, value)
                request = self._build_request(type, input_sample)
                result, _ = invocations.invoke(request.SerializeToString(), 'application/octet-stream')

                output = output_cls.FromString(result).output

                self.assertAlmostEqual(row['target_pred_int'], output.score_auroc_tuned_model, delta=self._SCORE_DELTA)
                self.assertAlmostEqual(row['target_pred_auprc'], output.score_auprc_tuned_model, delta=self._SCORE_DELTA)
                self._assert_optional_score_close(row['target_pred_mlp'], output.score_mlp_tuned_model)
                self.assertAlmostEqual(row['score_gen2'], output.score, delta=self._SCORE_DELTA)
                self.assertEqual(row['eligible_gen2'], output.eligible)

    def populateUnsetProtoFields(self, proto):
        """Set every unset field of *proto* to the zero value of its type.

        Mutates *proto* in place and also returns it. Fields whose type has no
        entry in the table below are reported with ``print`` and left unset.

        NOTE(review): ``HasField`` raises ``ValueError`` for fields without
        presence tracking (presumably proto3 scalars not marked ``optional``,
        and repeated fields); this assumes every field of *proto* supports it.
        """
        field_types = descriptor.FieldDescriptor
        zero_values = {
            field_types.TYPE_BOOL: False,
            field_types.TYPE_DOUBLE: 0.0,
            field_types.TYPE_FLOAT: 0.0,
            field_types.TYPE_STRING: "",
            field_types.TYPE_INT32: 0,
            field_types.TYPE_INT64: 0,
        }
        for field in proto.DESCRIPTOR.fields:
            if proto.HasField(field.name):
                continue
            if field.type in zero_values:
                setattr(proto, field.name, zero_values[field.type])
            else:
                print('need to handle another field type', field.name, field.type)
        return proto



if __name__ == '__main__':
    import sys

    if 'ipykernel' in sys.modules:
        # Presumably running inside a Jupyter kernel (the file carries
        # notebook cell markers): the kernel's own command-line arguments in
        # sys.argv would be rejected by unittest's parser, and sys.exit would
        # abort the kernel cell, so pass only the program name and don't exit.
        unittest.main(argv=sys.argv[:1], exit=False)
    else:
        unittest.main()
