# In [21]:
import io

import apache_beam as beam
import pandas as pd
from apache_beam.options.pipeline_options import PipelineOptions

def compute_aggregations(df):
    """Aggregate invoice rows per (legal_entity, counter_party, tier).

    For each group the result holds:
      * ``max(rating by counterparty)`` - the largest ``rating`` in the group;
      * ``sum(value where status=ARAP)`` - total ``value`` of the ARAP rows;
      * ``sum(value where status=ACCR)`` - total ``value`` of the ACCR rows.
    A group with no rows of a given status, or whose ratings are all missing,
    gets 0 instead of NaN in that column.

    Rows with a missing key value (e.g. a counterparty that found no tier in a
    left join) form their own group instead of being silently dropped by
    ``groupby``'s default ``dropna=True``.

    Args:
        df: DataFrame with at least the legal_entity, counter_party, tier,
            rating, status and value columns.

    Returns:
        A new DataFrame with the three key columns followed by the three
        metric columns, one row per group.
    """
    keys = ["legal_entity", "counter_party", "tier"]
    max_col = "max(rating by counterparty)"
    arap_col = "sum(value where status=ARAP)"
    accr_col = "sum(value where status=ACCR)"

    def _sum_value_for(status, label):
        # Per-group sum of 'value' over only the rows with the given status.
        subset = df[df["status"] == status]
        return (
            subset.groupby(keys, dropna=False)
            .agg({"value": "sum"})
            .reset_index()
            .rename(columns={"value": label})
        )

    max_rating = (
        df.groupby(keys, dropna=False)
        .agg({"rating": "max"})
        .reset_index()
        .rename(columns={"rating": max_col})
    )
    # Left merges keep every group that has any rows; a status-specific sum is
    # NaN for groups lacking that status. pandas merge matches NaN keys to each
    # other, so groups with a missing tier still line up.
    result = max_rating.merge(
        _sum_value_for("ARAP", arap_col), on=keys, how="left"
    ).merge(
        _sum_value_for("ACCR", accr_col), on=keys, how="left"
    )
    for column in (max_col, arap_col, accr_col):
        result[column] = result[column].fillna(0)
    return result

def compute_totals(result):
    """Append subtotal rows to a per-group aggregation table.

    For each of legal_entity, counter_party and tier (in that order), the three
    metric columns are summed per value of that column and appended as extra
    rows whose other two key columns are set to "Total".

    Args:
        result: DataFrame with legal_entity, counter_party and tier columns and
            the three metric columns produced by ``compute_aggregations``.

    Returns:
        A new DataFrame: the input rows followed by the legal_entity,
        counter_party and tier subtotal rows, with a fresh integer index.
    """
    key_columns = ("legal_entity", "counter_party", "tier")
    # NOTE(review): the per-group *maximum* rating is summed here as well;
    # confirm that a summed max is the intended subtotal for that column.
    sum_metrics = {
        "max(rating by counterparty)": "sum",
        "sum(value where status=ARAP)": "sum",
        "sum(value where status=ACCR)": "sum",
    }

    frames = [result]
    for key in key_columns:
        # dropna=False: rows whose key is missing (e.g. no tier found for a
        # counterparty) get their own subtotal instead of vanishing, so every
        # dimension's subtotals add up to the same grand total.
        subtotal = result.groupby(key, dropna=False).agg(sum_metrics).reset_index()
        for other in key_columns:
            if other != key:
                subtotal[other] = "Total"
        frames.append(subtotal)
    return pd.concat(frames, ignore_index=True)

class ComputeAggregations(beam.DoFn):
    """Beam DoFn that applies ``compute_aggregations`` to its input element.

    An element may be a single record (``dict`` or ``pandas.Series``), a
    ``pandas.DataFrame``, or an iterable of records (for example the list
    produced by ``beam.combiners.ToList()``). A lone record is aggregated on
    its own, so meaningful per-group sums/maxima need the records batched
    together first.
    """

    def process(self, element):
        if isinstance(element, pd.DataFrame):
            df = element
        elif isinstance(element, (dict, pd.Series)):
            # Original behaviour: one record becomes a one-row frame.
            df = pd.DataFrame([element])
        else:
            # A batch of records is aggregated together, across rows.
            df = pd.DataFrame(list(element))
        return compute_aggregations(df).to_dict('records')

class ComputeTotals(beam.DoFn):
    """Beam DoFn that applies ``compute_totals`` to its input element.

    An element may be a single aggregated record (``dict`` or
    ``pandas.Series``), a ``pandas.DataFrame``, or an iterable of records (for
    example the list produced by ``beam.combiners.ToList()``). A lone record
    only yields "totals" for that one row, so batch the records first to get
    real subtotals.
    """

    def process(self, element):
        if isinstance(element, pd.DataFrame):
            df = element
        elif isinstance(element, (dict, pd.Series)):
            # Original behaviour: one record becomes a one-row frame.
            df = pd.DataFrame([element])
        else:
            # A batch of records is totalled together, across rows.
            df = pd.DataFrame(list(element))
        return compute_totals(df).to_dict('records')
    
    
def save_to_csv(dataframe, path='../output_data/beam_output.csv'):
    """Write *dataframe* to a CSV file, omitting the index column.

    Args:
        dataframe: the DataFrame to write.
        path: destination file. Defaults to the previously hard-coded
            location, which is resolved relative to the current working
            directory.
    """
    dataframe.to_csv(path, index=False)
    

def df_to_csv_string(dataframe):
    """Render *dataframe* as CSV text with a header row and no index column."""
    csv_text = dataframe.to_csv(path_or_buf=None, index=False)
    return csv_text

# Column layouts of the two input files (their header lines are skipped on read).
DATASET1_COLUMNS = ["invoice_id", "legal_entity", "counter_party", "rating", "status", "value"]
DATASET2_COLUMNS = ["counter_party", "tier"]
# Row layout after dataset1 has been left-joined with dataset2 on counter_party.
MERGED_COLUMNS = DATASET1_COLUMNS + ["tier"]


def _parse_csv_line(line, columns):
    """Parse one CSV data line into a dict keyed by *columns*.

    pandas does the parsing, so quoted fields are honoured and numeric-looking
    fields get numeric types, as with the original per-line ``read_csv``.
    counter_party is read as str so the join keys coming from both files
    compare equal.
    """
    # io.StringIO: pandas has no public ``pd.StringIO``.
    frame = pd.read_csv(
        io.StringIO(line), header=None, names=columns, dtype={"counter_party": str}
    )
    return frame.to_dict("records")[0]


def _left_join_tiers(keyed_group):
    """Left-join one counter_party's invoices with its dataset2 tier rows.

    *keyed_group* is the ``(counter_party, {"d1": invoices, "d2": tier_rows})``
    pair emitted by CoGroupByKey. Invoices without a tier row get
    ``tier=None``. Several tier rows for the same counter_party yield one
    output row each, mirroring ``pd.merge(..., how="left")``.
    """
    _, grouped = keyed_group
    tier_rows = list(grouped["d2"]) or [{"tier": None}]
    for invoice in grouped["d1"]:
        for tier_row in tier_rows:
            yield {**invoice, "tier": tier_row["tier"]}


def _build_report(rows):
    """Aggregate all merged rows together and append the subtotal rows."""
    # columns= keeps the schema even when there are no rows at all.
    merged = pd.DataFrame(list(rows), columns=MERGED_COLUMNS)
    return compute_totals(compute_aggregations(merged))


with beam.Pipeline(options=PipelineOptions()) as p:
    # CoGroupByKey only accepts (key, value) pairs, so every parsed row is
    # keyed by counter_party, the join column.
    dataset1 = (
        p
        | 'Read dataset1' >> beam.io.ReadFromText('../input_data/dataset1.csv', skip_header_lines=1)
        | 'Drop blank dataset1 lines' >> beam.Filter(lambda line: line.strip())
        | 'Parse dataset1' >> beam.Map(_parse_csv_line, DATASET1_COLUMNS)
        | 'Key dataset1 by counter_party' >> beam.Map(lambda row: (row['counter_party'], row))
    )
    dataset2 = (
        p
        | 'Read dataset2' >> beam.io.ReadFromText('../input_data/dataset2.csv', skip_header_lines=1)
        | 'Drop blank dataset2 lines' >> beam.Filter(lambda line: line.strip())
        | 'Parse dataset2' >> beam.Map(_parse_csv_line, DATASET2_COLUMNS)
        | 'Key dataset2 by counter_party' >> beam.Map(lambda row: (row['counter_party'], row))
    )

    merged_data = (
        {'d1': dataset1, 'd2': dataset2}
        | 'CoGroupByKey' >> beam.CoGroupByKey()
        | 'Merge datasets' >> beam.FlatMap(_left_join_tiers)
    )

    # The aggregations group across records, so they must see the whole
    # dataset at once; a per-element ParDo would only ever aggregate one row.
    # ToList gathers every row into a single element, which matches the
    # in-memory pandas design of compute_aggregations / compute_totals.
    with_totals = (
        merged_data
        | 'Collect merged rows' >> beam.combiners.ToList()
        | 'Compute aggregations and totals' >> beam.Map(_build_report)
    )

    # df_to_csv_string already emits the real header row (to_csv's default),
    # so no hard-coded header= is passed; the trailing newline is stripped
    # because WriteToText appends one after every element.
    output = (
        with_totals
        | 'Convert DF to CSV string' >> beam.Map(lambda df: df_to_csv_string(df).rstrip('\r\n'))
        | 'Save to CSV' >> beam.io.WriteToText('../output_data/beam_output', file_name_suffix='.csv', num_shards=1)
    )



