In [1]:
import bisect
from datetime import datetime, timedelta
import os
import sys
from typing import Iterator, Tuple

import numpy as np
import pandas as pd
from pyspark.sql import types as st
from pyspark.sql import functions as sf
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.storagelevel import StorageLevel

import settings as s

In [2]:
# Truncate the `source` / `target` labels before saving (applied in the STRIP_IDS cell).
STRIP_IDS = True
# Add an `amount` column converted to USD with get_usd_amount.
CALCULATE_USD_AMOUNT = True

# Number of partitions (hashed on transaction_id) used for the grouped transactions.
TRX_PARTITIONS = 16

In [3]:
# Guard: refuse to run on any interpreter other than the exact version this was tuned on.
if sys.version_info[:3] != (3, 9, 19):
    raise EnvironmentError(
        "Only runs efficiently on Python 3.9.19 (Tested on: Conda 24.1.2 | Apple M3 Pro)"
    )

In [4]:
# Spark session settings; Arrow speeds up pandas <-> Spark conversions (toPandas / createDataFrame).
# NOTE(review): "spark.worker.memory" does not appear to be a documented Spark application
# property; if executor memory was intended the key is "spark.executor.memory" — confirm.
config = [
    ("spark.driver.memory", "32g"),
    ("spark.worker.memory", "32g"),
    ("spark.driver.maxResultSize", "32g"),
    ("spark.sql.execution.arrow.pyspark.enabled", "true"),
]
spark_conf = SparkConf().setAll(config)
spark = SparkSession.builder.appName("testing").config(conf=spark_conf).getOrCreate()

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/10/04 18:44:06 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [5]:
# Account number -> owning entity id.
# Both columns are read as strings: the IDs are hex-like, so dtype inference can turn
# some of them into ints/floats (e.g. all-digit values or ones containing "E"), which
# would break the string lookup below with a KeyError.
data_account = (
    pd.read_csv(s.ACCOUNT_FILE, dtype={"Account Number": str, "Entity ID": str})
    .set_index("Account Number")["Entity ID"]
    .to_dict()
)
# The label passed in is "<account>-<currency>"; its first 9 characters are the account number.
# Unknown accounts raise KeyError inside the UDF.
data_account_func = sf.udf(lambda x: data_account[x[:9]], st.StringType())

In [6]:
def get_id(value):
    """Return a compact lookup key for ``value``: ``"id-"`` followed by ``hash(value)``.

    Python salts ``str`` hashes per process (PYTHONHASHSEED), so keys are only
    comparable within a single kernel session. Distinct values may collide.
    """
    return "id-" + str(hash(value))

In [7]:
%%time

try:
    os.remove(s.STAGED_DATA_CSV_LOCATION)
except FileNotFoundError:
    pass


mapping = {}
with open(s.DATA_FILE) as in_file:
    cnt = -1
    lines = ""
    for line in in_file:
        cnt += 1
        if cnt == 0:
            continue
        line = line.strip()
        line_id = get_id(line)
        mapping[line_id] = cnt
        lines += f"{cnt},{line}\n"
        if not (cnt % 2e7):
            print(cnt)
            with open(s.STAGED_DATA_CSV_LOCATION, "a") as out_file:
                out_file.write(lines)
                lines = ""
if lines:
    lines = lines.strip()
    with open(s.STAGED_DATA_CSV_LOCATION, "a") as out_file:
        out_file.write(lines)
        del lines

20000000
40000000
60000000
80000000
100000000
120000000
140000000
160000000
CPU times: user 2min 4s, sys: 12.2 s, total: 2min 17s
Wall time: 2min 18s


In [8]:
# Stage the laundering-patterns file: rows whose first 4 characters are digits are treated
# as transactions and prefixed with the row number of the identical line in the staged
# data; all other lines (pattern headers/footers, blank separators) are kept as-is.
staged_pattern_lines = []
with open(s.PATTERNS_FILE) as in_file:
    for line in in_file:
        line = line.strip()
        if line[:4].isnumeric():
            line_id = get_id(line)
            if line_id not in mapping:
                raise KeyError(
                    f"Pattern transaction not found in {s.DATA_FILE}: {line!r}"
                )
            staged_pattern_lines.append(f"{mapping[line_id]},{line}")
        else:
            staged_pattern_lines.append(line)

# strip() removes leading/trailing blank lines so the blank-line-delimited chunks parsed
# later start with a header line and end with a footer line.
with open(s.STAGED_PATTERNS_CSV_LOCATION, "w") as out_file:
    out_file.write("\n".join(staged_pattern_lines).strip())
del staged_pattern_lines

In [9]:
# Free the line-key -> row-number map (one entry per data row); it was only needed to
# stage the patterns file.
del mapping

In [10]:
# Schema of the staged transactions CSV: the prepended row-number id, then the raw columns.
# Amounts are parsed as DoubleType: FloatType (32-bit) keeps only ~7 significant digits,
# so large amounts would be rounded before being summed (the sums are double either way).
schema = st.StructType(
    [
        st.StructField("transaction_id", st.IntegerType(), False),
        st.StructField("timestamp", st.TimestampType(), False),
        st.StructField("source_bank", st.StringType(), False),
        st.StructField("source", st.StringType(), False),
        st.StructField("target_bank", st.StringType(), False),
        st.StructField("target", st.StringType(), False),
        st.StructField("received_amount", st.DoubleType(), False),
        st.StructField("receiving_currency", st.StringType(), False),
        st.StructField("sent_amount", st.DoubleType(), False),
        st.StructField("sending_currency", st.StringType(), False),
        st.StructField("format", st.StringType(), False),
        st.StructField("is_laundering", st.IntegerType(), False),
    ]
)
columns = [x.name for x in schema]

In [11]:
with open(s.STAGED_PATTERNS_CSV_LOCATION, "r") as fl:
    patterns = fl.read()

# Each blank-line-separated chunk is one case: a header line "<...> - <name>", the
# transaction rows, and a closing footer line. `<name>` may be "<type>: <sub_type>".
case_frames = []
case_id = 0
for pattern in patterns.split("\n\n"):
    case_id += 1
    # Strip the chunk so extra blank lines (or a trailing newline) cannot shift the
    # header/footer lines popped below.
    chunk = pattern.strip()
    if not chunk:
        continue
    pattern_lines = chunk.split("\n")
    name = pattern_lines.pop(0).split(" - ")[1]
    category, sub_category = name, name
    if ": " in name:
        category, sub_category = name.split(": ", 1)
    pattern_lines.pop()  # footer line
    case = pd.DataFrame([row.split(",") for row in pattern_lines], columns=columns)
    case["id"] = case_id
    case["type"] = category.strip().lower()
    case["sub_type"] = sub_category.strip().lower()
    case_frames.append(case)

cases = pd.concat(case_frames, ignore_index=True)[["transaction_id", "id", "type", "sub_type"]]
cases = spark.createDataFrame(cases)
# Parsed values are strings; cast so joins against the IntegerType transaction ids
# compare like types.
cases = cases.withColumn("transaction_id", sf.col("transaction_id").cast(st.IntegerType()))

In [12]:
# Full currency name (as it appears in the raw data) -> short lowercase code.
CURRENCY_MAPPING = {
    "Australian Dollar": "aud",
    "Bitcoin": "btc",
    "Brazil Real": "brl",
    "Canadian Dollar": "cad",
    "Euro": "eur",
    "Mexican Peso": "mxn",
    "Ruble": "rub",
    "Rupee": "inr",
    "Saudi Riyal": "sar",
    "Shekel": "ils",
    "Swiss Franc": "chf",
    "UK Pound": "gbp",
    "US Dollar": "usd",
    "Yen": "jpy",
    "Yuan": "cny",
}


@sf.udf(returnType=st.StringType())
def currency_code(currency_name):
    """Spark UDF: look up the short code for a currency name (KeyError if unmapped)."""
    return CURRENCY_MAPPING[currency_name]

In [13]:
%%time

data = spark.read.csv(
    s.STAGED_DATA_CSV_LOCATION,
    header=False,
    schema=schema,
    timestampFormat=s.TIMESTAMP_FORMAT,
)
group_by = [
    "timestamp",
    "source_bank",
    "source",
    "target_bank",
    "target",
    "receiving_currency",
    "sending_currency",
    "format",
]
data = data.groupby(group_by).agg(
    sf.first("transaction_id").alias("transaction_id"),
    sf.collect_set("transaction_id").alias("transaction_ids"),
    sf.sum("received_amount").alias("received_amount"),
    sf.sum("sent_amount").alias("sent_amount"),
    sf.max("is_laundering").alias("is_laundering"),
)
data = data.withColumn(
    "source_currency", currency_code(sf.col("sending_currency"))
).withColumn(
    "target_currency",
    currency_code(sf.col("receiving_currency")),
)
data = data.join(cases, on="transaction_id", how="left").repartition(
    TRX_PARTITIONS, "transaction_id"
)
data = data.select(
    "transaction_id",
    "transaction_ids",
    "timestamp",
    sf.concat(sf.col("source"), sf.lit("-"), sf.col("source_currency")).alias("source"),
    sf.concat(sf.col("target"), sf.lit("-"), sf.col("target_currency")).alias("target"),
    "source_bank",
    "target_bank",
    "source_currency",
    "target_currency",
    sf.col("sent_amount").alias("source_amount"),
    sf.col("received_amount").alias("target_amount"),
    "format",
    "is_laundering",
)

data = data.withColumn("source_entity", data_account_func(sf.col("source")))
data = data.withColumn("target_entity", data_account_func(sf.col("target")))

data = data.persist(StorageLevel.DISK_ONLY)
data.count()



CPU times: user 591 ms, sys: 203 ms, total: 794 ms
Wall time: 6min 16s


                                                                                

175887982

In [14]:
# Number of raw rows absorbed by the groupby: total member ids minus grouped rows.
data.select(sf.explode("transaction_ids")).count() - data.count()

                                                                                

178575

In [15]:
# Attach each case row to the grouped transaction that contains it: explode every grouped
# row's member ids so a case's original row number can match, then expose the grouped
# row's representative id (temporarily "x") as `transaction_id`.
exploded_data = (
    data.withColumnRenamed("transaction_id", "x")
    .drop(*cases.columns)
    .select(sf.explode("transaction_ids").alias("transaction_id"), "*")
)
cases_data = cases.join(exploded_data, on="transaction_id", how="left")
cases_data = cases_data.drop("is_laundering", "transaction_id", "transaction_ids")
cases_data = cases_data.withColumnRenamed("x", "transaction_id")
cases_data.toPandas().to_parquet(s.STAGED_CASES_DATA_LOCATION)
# Reload from disk so the in-memory frame is the saved file's content.
cases_data = pd.read_parquet(s.STAGED_CASES_DATA_LOCATION)

                                                                                

In [16]:
# Multiplier converting one unit of each currency code to USD ("usd" is 1.0), stored as
# float32; consumed by get_usd_amount.
# NOTE(review): these are fixed snapshot rates with no recorded source or date — document
# where they came from.
currency_rates = {
    "jpy": np.float32(0.009487665410827868),
    "cny": np.float32(0.14930721887033868),
    "cad": np.float32(0.7579775434031815),
    "sar": np.float32(0.2665884611958837),
    "aud": np.float32(0.7078143121927827),
    "ils": np.float32(0.29612081311363503),
    "chf": np.float32(1.0928961554056371),
    "usd": np.float32(1.0),
    "eur": np.float32(1.171783425225877),
    "rub": np.float32(0.012852809604990688),
    "gbp": np.float32(1.2916554735187644),
    "btc": np.float32(11879.132698717296),
    "inr": np.float32(0.013615817231245796),
    "mxn": np.float32(0.047296753463246695),
    "brl": np.float32(0.1771008654705292),
}

@sf.pandas_udf(st.FloatType())
def get_usd_amount(iterator: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
    """Convert amounts to USD, one Arrow batch at a time.

    Each batch is a (currency code, amount) pair of Series; every amount is multiplied
    by ``currency_rates[code]``. An unknown currency code raises KeyError.
    """
    for currencies, amounts in iterator:
        usd_per_unit = np.array([currency_rates[code] for code in currencies], dtype=np.float32)
        yield amounts * usd_per_unit

In [17]:
# Optional post-processing driven by the flags in the config cell.
if STRIP_IDS:
    # Truncate the "<account>-<currency>" labels to their first 8 characters.
    # NOTE(review): data_account_func treats the first 9 characters as the account number;
    # keeping only 8 here also drops the account's last character — confirm this is intended.
    for label_column in ("source", "target"):
        data = data.withColumn(label_column, sf.substring(label_column, 1, 8))
if CALCULATE_USD_AMOUNT:
    data = data.withColumn("amount", get_usd_amount("source_currency", "source_amount"))

In [18]:
# Persist the staged transactions, replacing any previous output.
data.write.mode("overwrite").parquet(s.STAGED_DATA_LOCATION)

                                                                                

In [19]:
# Re-read from parquet so later cells work from the saved output, not the cached plan.
data = spark.read.format("parquet").load(s.STAGED_DATA_LOCATION)

In [20]:
# Every grouped transaction must have a distinct representative id.
assert data.select("transaction_id").distinct().count() == data.count()

                                                                                

In [21]:
# Summary: (grouped transactions, case rows, distinct transaction ids in cases, distinct case ids).
data.count(), cases_data.shape[0], cases_data["transaction_id"].nunique(), cases_data["id"].nunique()
# NOTE(review): the tuple below does not match this cell's recorded output; it presumably
# comes from a different input file or run — update or remove.
# (179504480, 137936, 137933, 16467)

(175887982, 19461, 19461, 2218)