In [86]:
import os
import re
import zipfile
import ast
import json
import pandas as pd
from datetime import datetime, timedelta
from tqdm import tqdm
import sys

import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from transformers import pipeline


### Loading data

In [175]:
#passport csv file

passport_pd = pd.read_csv('./train/passport_data_with_label.csv')
passport_pd.head(5)


Unnamed: 0,index,first_name,middle_name,last_name,gender,country,country_code,nationality,birth_date,passport_number,passport_mrz,passport_issue_date,passport_expiry_date,label
0,client_0,Freja,Katrine,Christensen,F,Denmark,DNK,Danish,2002-04-18,UE2130779,['P<DNKCHRISTENSEN<<FREJA<KATRINE<<<<<<<<<<<<<...,2017-05-11,2027-05-10,Accept
1,client_1,Thomas,Laurent,Lemaître,M,France,FRA,French,1990-02-07,OT9354543,['P<FRALEMAÎTRE<<THOMAS<LAURENT<<<<<<<<<<<<<<<...,2022-06-28,2032-06-27,Reject
2,client_10,Gauthier,,Bernard,M,France,FRA,French,1974-05-31,XV2857876,['P<FRABERNARD<<GAUTHIER<<<<<<<<<<<<<<<<<<<<<<...,2023-03-18,2033-03-17,Accept
3,client_100,Louna,Ève,Bertrand,F,France,FRA,French,1977-12-25,KI8826467,['P<FRABERTRAND<<LOUNA<ÈVE<<<<<<<<<<<<<<<<<<<<...,2019-11-11,2029-11-10,Accept
4,client_1000,Britt,Daantje,Van Dijk,F,Netherlands,NLD,Dutch,1982-03-13,XA0813292,['P<NLDVAN DIJK<<BRITT<DAANTJE<<<<<<<<<<<<<<<<...,2021-10-01,2031-09-30,Reject


In [176]:
account_form_pd = pd.read_csv('./train/account_form_data_with_label_cleaned.csv')
account_form_pd.head(20)

Unnamed: 0,index,name,first_name,middle_name,last_name,passport_number,currency,country_of_domicile,phone_number,email_address,address.city,address.street_name,address.street_number,address.postal_code,label
0,client_0,Freja Katrine Christensen,Freja,Katrine,Christensen,UE2130779,DKK,Denmark,53 11 20 42,freja.christensen@yousee.dk,Aalborg,Strøget,57.0,2044,Accept
1,client_1,Thomas Laurent Lemaître,Thomas,Laurent,Louis,OT9354543,EUR,France,++4903 52 25 79 49,thomas.lemaitre@yahoo.com,Le Havre,Boulevard Saint-Michel,88.0,63950,Reject
2,client_10,Gauthier Bernard,Gauthier,,Bernard,XV2857876,EUR,France,06 85 81 21 11,gauthier.bernard@numericable.fr,Reims,Rue de la Huchette,36.0,70164,Accept
3,client_100,Louna Ève Bertrand,Louna,Ève,Bertrand,KI8826467,EUR,France,+33 04 18 03 86 04,louna.bertrand@hotmail.com,Montpellier,Rue de Strasbourg,75.0,45774,Accept
4,client_1000,Britt Daantje Van Dijk,Britt,Daantje,Van Dijk,XA0813292,EUR,Netherlands,+31 06 92689079,britt.vandijk@yahoo.com,Eindhoven,Willemsparkweg,80.0,1698 64,Reject
5,client_1001,Francesca Ilaria Ferrari,Francesca,Ilaria,Ferrari,MD8743029,EUR,Italy,+39 342 8847590,francesca.ferrari@tin.it,Siena,Piazza Maggiore,7.0,14045,Reject
6,client_1002,Consuelo Ramos,Consuelo,,Ramos,ZS0877519,EUR,Spain,680 909 777,consuelo.ramos@terra.es,Salamanca,Gran Vía,39.0,77965,Accept
7,client_1003,Sofía Dácil Ramos,Sofía,Dácil,Ramos,YR1788551,EUR,Finland,047 901 94 31,sofia.ramos@yahoo.com,Mikkeli,Puistolanaukio,26.0,21913,Accept
8,client_1004,Schmidt Schwarz Binder,Schmidt,Schwarz,Binder,HK7376680,EUR,Austria,+43 754 060 7626,schmidt.binder@outlook.com,Eisenstadt,Landstraßer Hauptstraße,35.0,6293,Reject
9,client_1005,Gruber Schwarz Hinterleitner,Gruber,Schwarz,Hinterleitner,ZX0148401,EUR,Austria,+43 615 790 1978,gruber.hinterleitner@icloud.com,Klagenfurt,Gasometerstraße,35.0,3046,Reject


In [177]:
description_pd = pd.read_csv('./train/client_description_with_label.csv')

description_pd.head(5)


Unnamed: 0,index,Summary Note,Family Background,Education Background,Occupation History,Wealth Summary,Client Summary,label
0,client_0,Freja Katrine Christensen and the RM were intr...,Freja Katrine Christensen is currently single....,Freja obtained her secondary school diploma fr...,Freja Katrine Christensen is a 22 year old and...,She did not have any savings to invest in fina...,The RM is excited to help Freja navigate the c...,Accept
1,client_1,The RM first encountered Thomas Laurent Lemaît...,Thomas Laurent Lemaître is currently divorced....,Thomas graduated from Lycée International de L...,Thomas Laurent Lemaître is a 35 year old and c...,He managed to save approximately 80000 EUR fro...,Given the client's impressive career history a...,Reject
2,client_10,Gauthier Bernard and the RM met at a financial...,Gauthier Bernard is currently divorced. He doe...,"In 1992, Gauthier graduated from Lycée Gustave...",Gauthier Bernard is a 50 year old Biotech Star...,"Throughout his career, he saved 360000 EUR, in...","Based on the information provided, we are exci...",Accept
3,client_100,The RM has known Louna Ève Bertrand since chil...,Louna Ève Bertrand is currently divorced. Her ...,Louna completed her secondary education at Lyc...,"Having worked for over 24 years, Louna Ève Ber...","Throughout her career, she saved 350000 EUR, i...","In summary, Louna has demonstrated a strong wo...",Accept
4,client_1000,Britt Daantje Van Dijk and the RM crossed path...,Britt Daantje Van Dijk and Schipper have been ...,Britt received her secondary school diploma fr...,"Having worked for over 20 years, Britt Daantje...",She managed to save approximately 210000 EUR f...,"In light of the above, we are optimistic about...",Reject


In [198]:
client_profile_pd = pd.read_csv('./train/client_profile_data_with_label.csv')
client_profile_pd["address.street number"]
client_profile_pd.head(5)

Unnamed: 0,index,name,country_of_domicile,birth_date,nationality,passport_number,passport_issue_date,passport_expiry_date,gender,phone_number,...,address.postal code,secondary_school.name,secondary_school.graduation_year,aum.savings,aum.inheritance,aum.real_estate_value,inheritance_details.relationship,inheritance_details.inheritance year,inheritance_details.profession,label
0,client_0,Freja Katrine Christensen,Denmark,2002-04-18,Danish,UE2130779,2017-05-11,2027-05-10,F,53 11 20 42,...,2044,Holstebro Gymnasium,2022,0,13140000,0,grandmother,2020.0,Oil and Gas Executive,Accept
1,client_1,Thomas Laurent Lemaître,France,1990-02-07,French,OT9354543,2022-06-28,2032-06-27,M,03 52 25 79 49,...,63950,Lycée International de Lyon,2009,80000,2300000,4690000,grandmother,2016.0,Real Estate Developer,Reject
2,client_10,Gauthier Bernard,France,1974-05-31,French,XV2857876,2023-03-18,2033-03-17,M,06 85 81 21 11,...,70164,Lycée Gustave Eiffel,1992,360000,1360000,3200000,father,2006.0,Real Estate Developer,Accept
3,client_100,Louna Ève Bertrand,France,1977-12-25,French,KI8826467,2019-11-11,2029-11-10,F,+33 04 18 03 86 04,...,45774,Lycée Chaptal,1996,350000,2240000,6070000,mother,2014.0,Tech Entrepreneur,Accept
4,client_1000,Britt Daantje Van Dijk,Netherlands,1982-03-13,Dutch,XA0813292,2021-10-01,2031-09-30,F,+31 06 92689079,...,1698 64,Oosterlicht College Groningen,1999,210000,930000,3610000,grandmother,2009.0,Tech Entrepreneur,Reject


### Filtering Rules

In [199]:
def check_mrz(data):
    birth_date = str(data["birth_date"])  # Convert birth_date to string
    # print(birth_date)
    mrz_date = birth_date[2:4] + birth_date[5:7] + birth_date[8:10]

    middle_name = str(data["middle_name"])  # Convert middle_name to string
    if middle_name == '' or middle_name == 'nan' or middle_name == 'None':
        passport_mrz_temp_1 = "P<"+str(data["country_code"])+str(data["last_name"])+"<<"+str(data["first_name"])  # Convert other fields to string
    else:
        passport_mrz_temp_1 = "P<"+str(data["country_code"])+str(data["last_name"])+"<<"+str(data["first_name"])+"<"+middle_name

    passport_mrz_temp_2 = str(data["passport_number"])+str(data["country_code"])+ mrz_date  # Convert other fields to string
    #mrz_arr = ast.literal_eval(data["passport_mrz"])
    mrz_arr = eval(data["passport_mrz"])
    #print(mrz_arr[0], mrz_arr[1])


    passport_mrz_1 = mrz_arr[0].lower().rstrip('<')
    passport_mrz_2 = mrz_arr[1].lower().rstrip('<')


    if passport_mrz_1 != passport_mrz_temp_1.lower() or passport_mrz_2 != passport_mrz_temp_2.lower() :

        return True

    return False

In [200]:
def check_name(account_per, profile_per, passport_per):
    # Conditionally include middle name
    passport_name = str(passport_per["first_name"])
    if not pd.isnull(passport_per["middle_name"]) and passport_per["middle_name"] != "":  # Check for null or empty middle name
        passport_name += str(passport_per["middle_name"])
    passport_name += str(passport_per["last_name"])
    passport_name = passport_name.replace(" ", "")

    account_name_1 = account_per["name"].replace(" ", "")

    account_name_2 = str(passport_per["first_name"])
    if not pd.isnull(passport_per["middle_name"]) and passport_per["middle_name"] != "":  # Check for null or empty middle name
        account_name_2 += str(passport_per["middle_name"])
    account_name_2 += str(passport_per["last_name"])
    account_name_2 = account_name_2.replace(" ", "")

    profile_name = profile_per["name"].replace(" ", "")

    names = [passport_name, account_name_1, account_name_2, profile_name]

    if len(set(names)) != 1:
        return True
    return False

In [201]:
def check_phonenumber(account_per, profile_per):
    phone_number = account_per["phone_number"]
    if pd.isnull(phone_number) or phone_number == "":
        return True
    if phone_number.startswith('++'):
        return True
    if phone_number != profile_per["phone_number"]:
        return True
    return False

In [202]:
def check_address(account_per, profile_per):
    postal_code = account_per["address.postal_code"]
    if pd.isnull(postal_code) or postal_code == "" or account_per["address.street_number"] != profile_per["address.street number"] or account_per["address.city"] != profile_per["address.city"] or account_per["address.street_name"] != profile_per["address.street name"]:
        return True

    return False

In [203]:
def check_passport_dates(df):
    # Corrected the function to handle the case where it receives a single row as input.
    if isinstance(df, pd.Series):
        df = df.to_frame().T  # Convert the Series into a DataFrame
    # print(df['passport_issue_date'])
    # Convert date columns to datetime objects if they aren't already
    df['passport_issue_date'] = pd.to_datetime(df['passport_issue_date'], errors='coerce')
    df['passport_expiry_date'] = pd.to_datetime(df['passport_expiry_date'], errors='coerce')
    # Get today's date and calculate 10 years ago
    today = datetime.today()
    ten_years_ago = today - timedelta(days=10*365)
    # Check for invalid dates (NaT - Not a Time)
    invalid_issue_date = df['passport_issue_date'].isnull()
    invalid_expiry_date = df['passport_expiry_date'].isnull()
    # print(invalid_issue_date, "dd", invalid_expiry_date)
    # Check for invalid date order (expiry date should be after issue date)
    invalid_date_order = df['passport_expiry_date'] < df['passport_issue_date']
    # Check if expiry date is older than 10 years
    expiry_date_older_than_10_years = df['passport_expiry_date'] < ten_years_ago
    # Duration between issue and expiry dates
    duration = df['passport_expiry_date'] - df['passport_issue_date']
    # Check if duration is less than 731 days
    duration_less_than_730_days = duration < pd.Timedelta(days=731)
    # Check if duration is more than 3652 days
    duration_more_than_3652_days = duration > pd.Timedelta(days=3652)
    # Check birth before issue date
    birth_date = pd.to_datetime(df['birth_date'], errors='coerce')
    birth_before_issue_date = df['passport_issue_date'] < birth_date
    if any(invalid_issue_date | invalid_expiry_date | invalid_date_order | expiry_date_older_than_10_years | duration_less_than_730_days | duration_more_than_3652_days | birth_before_issue_date):
        return True
    return False

In [204]:
def check_email(account_per, profile_per):
    if account_per["email_address"] != profile_per["email_address"]:
        return True
    return False

In [205]:
def check_nationality(passport_per, account_per, profile_per):
    if passport_per["nationality"] != profile_per["nationality"]:
        return True
    return False

check_email(account_form_pd.iloc[1], client_profile_pd.iloc[1])

False

In [206]:
def check_country_of_domicile(account_per, profile_per):
    if account_per["country_of_domicile"] != profile_per["country_of_domicile"]:
        return True
    return False

# check_country_of_domicile(account_form_pd.iloc[5], client_profile_pd.iloc[1])

In [207]:
def check_birthdate(passport_per, profile_per):
    if passport_per["birth_date"] != profile_per["birth_date"]:
        return True
    return False

In [208]:
def check_currency(account_per, profile_per):
    valid_currencies = ['CHF','DKK','EUR']
    if account_per["currency"] not in valid_currencies or profile_per["currency"] not in valid_currencies:
        return True
    if account_per["currency"] != profile_per["currency"]:
        return True
    return False

In [209]:
def check_gender(passport_per, profile_per):
    if passport_per["gender"] != profile_per["gender"]:
        return True
    return False

In [210]:
def check_real_estate_value(profile_per):
    aum_value = profile_per["aum.real_estate_value"]
    real_estate_details = profile_per["real_estate_details"]

    # Handle missing values
    if pd.isnull(aum_value) or pd.isnull(real_estate_details):
        return True  # Invalid if either value is missing

    try:
        property_values = [prop["property value"] for prop in eval(real_estate_details)]
        total_real_estate_value = sum(property_values)

        # Check if total real estate value is not equal to the aum_value
        if total_real_estate_value != aum_value:
            return True  # Invalid if real estate value is greater

    except (TypeError, KeyError, ValueError):
        return True  # Invalid if data is not in the expected format

    return False

In [211]:
def check_passport(passport_per, account_per, profile_per):
    if passport_per["passport_number"] != account_per["passport_number"] or passport_per["passport_number"] != profile_per["passport_number"]:
        return True
    if  passport_per["passport_issue_date"] != profile_per["passport_issue_date"]:
        return True
    if  passport_per["passport_expiry_date"] != profile_per["passport_expiry_date"]:
        return True
    return False

In [212]:
def check_multiple_countries(account_per, profile_per):
    # Check if there's more than one country of domicile
    account_mult = account_per["country_of_domicile"]
    profile_mult = profile_per["country_of_domicile"]

    if pd.isnull(account_mult) or pd.isnull(profile_mult):
        return True
    # Split the country strings into lists
    account_countries = [country.strip() for country in account_mult.split(",")]
    profile_countries = [country.strip() for country in profile_mult.split(",")]

    # Check if the lengths of the lists are greater than 1
    if len(account_countries) > 1 or len(profile_countries) > 1:
        return True  # Invalid if more than one country is found

    return False

In [213]:
# Postal code patterns for European countries
europe_postal_code_patterns = {
    'Austria': r'^\d{4}$',
    'Belgium': r'^\d{4}$',
    'Bulgaria': r'^\d{4}$',
    'Switzerland': r'^\d{4}$',
    'Czech Republic': r'^\d{3}\s?\d{2}$',
    'Germany': r'^\d{5}$',
    'Denmark': r'^\d{4}$',
    'Estonia': r'^\d{5}$',
    'Spain': r'^\d{5}$',
    'Finland': r'^\d{5}$',
    'France': r'^\d{5}$',
    'Greece': r'^\d{3}\s?\d{2}$',
    'Croatia': r'^\d{5}$',
    'Hungary': r'^\d{4}$',
    'Ireland': r'^[A-Za-z]\d[\w\d]? ?\d[A-Za-z]{2}$',
    'Italy': r'^\d{5}$',
    'Lithuania': r'^LT-\d{5}$|^\d{5}$',
    'Luxembourg': r'^\d{4}$',
    'Latvia': r'^LV-\d{4}$|^\d{4}$',
    'Malta': r'^[A-Z]{3}\s?\d{4}$',
    'Netherlands': r'^\d{6}$',
    'Norway': r'^\d{4}$',
    'Poland': r'^\d{2}-\d{3}$',
    'Portugal': r'^\d{4}-\d{3}$',
    'Romania': r'^\d{6}$',
    'Sweden': r'^\d{3}\s?\d{2}$',
    'Slovenia': r'^\d{4}$',
    'Slovakia': r'^\d{3}\s?\d{2}$',
    'United Kingdom': r'^[A-Z]{1,2}\d[A-Z\d]?\s?\d[A-Z]{2}$',
}

def check_europe_postal_code(account_per, profile_per):
    country = account_per["country_of_domicile"]
    postal_code = str(account_per["address.postal_code"]).replace(" ", "")
    country_pr = profile_per["country_of_domicile"]
    postal_code_pr = str(profile_per["address.postal code"]).replace(" ", "")
    label_account = account_per["label"]
    if pd.isnull(country) or pd.isnull(postal_code) :
        return True
    pattern = europe_postal_code_patterns.get(country)
    if not pattern:
        print(f"No postal code format available for: {country}")
        return True
    if re.match(pattern, postal_code.strip()) is None :
        print(f"Invalid postal code for {country}: {postal_code}")
        return True

    if pd.isnull(country_pr) or pd.isnull(postal_code_pr) :
        return True
    pattern = europe_postal_code_patterns.get(country)
    if not pattern:
        print(f"No postal code format available for: {country_pr}")
        return True
    if re.match(pattern, postal_code_pr.strip()) is None :
        print(f"Invalid postal code for {country_pr}: {postal_code_pr}")
        return True
    # else:
    #     print(f"Invalid postal code for {country}: {postal_code}")
    return False

In [53]:
def get_description_info(data, model):
    Summary_Note = data["Summary Note"]
    Family_Background = data["Family Background"]
    Education_Background=data["Education Background"]
    Occupation_History = data["Occupation History"]
    Wealth_Summary = data["Wealth Summary"]
    Client_Summary = data["Client Summary"]
    kyc_text = f'''
        "Summary Note" : "{Summary_Note}",
        "Family Background": "{Family_Background}",
        "Education Backgroun": "{Education_Background}",
        "Occupation History": "{Occupation_History}",
        "Wealth Summary": "{Wealth_Summary}",
        "Client Summary": "{Client_Summary}",
       
    '''  

    prompt = f"""
    Extract the following information from the client description below:
    - Full name
    - Marital_status
    - Age
    - Nationality
    - Job title
    - Company
    - List of real estate assets (type, location, value)
    - Total savings

    Text: {kyc_text}
    """

    extracted = model(prompt, max_new_tokens=512)[0]["generated_text"]
    return extracted

def parse_kyc_info(text):
    pattern = r"-\s*(.*?)\s+is\s+(.*?)(?=-|$)"
    matches = re.findall(pattern, text)

    result = {}
    for key, value in matches:
        clean_key = key.strip().lower().replace(" ", "_")
        result[clean_key] = value.strip()

    if "list of real estate assets (type, location, value)" in text:
        result["real_estate_assets"] = None

    return result
def check_description(index, model):

    profile_per = client_profile_pd.loc[index]
    profile_marital_status = profile_per["marital_status"]
    profile_nationality = profile_per["nationality"]
    profile_birth_date = profile_per["birth_date"]

    birth_date = datetime.strptime(profile_birth_date, "%Y-%m-%d")
    today = datetime.today()
    profile_age = today.year - birth_date.year - ((today.month, today.day) < (birth_date.month, birth_date.day))

    profile_savings =  profile_per["aum.savings"]
    profile_inheritance = profile_per["aum.inheritance"]
    profile_real_estate_value = profile_per["aum.real_estate_value"]
    profile_total = profile_savings+profile_inheritance+profile_real_estate_value
    
    description_per = description_pd.loc[index]
    info = get_description_info(description_per, model)
    info_dict = parse_kyc_info(info)

    marital_status = info_dict.get("marital_status", None)
    nationality = info_dict.get("nationality", None)
    age = info_dict.get("age", None)
    total_savings = info_dict.get("list_of_real_estate_assets_(type,_location,_value)-_total_savings", None)

    try:
        if total_savings != None:
            match = re.search(r'([\d,]+)', total_savings)
            number_str = match.group(1).replace(",", "")
            total_savings = float(number_str)
        age = int(age)
    except (ValueError, TypeError):
        pass
    
    try:
        profile_savings = float(profile_savings)
        profile_inheritance = float(profile_inheritance)
        profile_real_estate_value = float(profile_real_estate_value)
        profile_total = float(profile_total)
    except (ValueError, TypeError):
        pass

    if marital_status != profile_marital_status and marital_status!= None:
        print(marital_status)
        print(profile_marital_status)
        return True
    if nationality != profile_nationality and nationality!= None:
        print(nationality)
        print(profile_nationality)
        return True
    if total_savings != profile_total and total_savings!= None and total_savings > 1000:
        print(total_savings)
        print(profile_total)
        return True
    if age != profile_age and age!= None:
        print(age)
        print(profile_age)
        return True




### Filtering Data

In [214]:
def validate_data(passport_pd, account_form_pd, client_profile_pd):
    rejected_ids = []
    labels = []

    for index in account_form_pd.index:
        passport_per = passport_pd.loc[index]
        account_per = account_form_pd.loc[index]
        profile_per = client_profile_pd.loc[index]

        is_rejected = False
        client_id = passport_per['index']
        # label = passport_per['label']

        if check_passport(passport_per, account_per, profile_per):
            print("Sorry! You are being rejected because the passport number is wrong.", client_id)
            rejected_ids.append(client_id)
            # labels.append(label)
            # print(client_id, label)
            continue

        if check_name(account_per, profile_per, passport_per):
            print("Sorry! You are being rejected because the name is wrong.", client_id)
            rejected_ids.append(client_id)
            # labels.append(label)
            # print(client_id, label)
            continue

        if check_mrz(passport_per):
            print("Sorry! You are being rejected because the Passport MRZ is wrong.", client_id)
            rejected_ids.append(client_id)
            # labels.append(label)
            continue

        if check_phonenumber(account_per, profile_per):
            print("Sorry! You are being rejected because the phone number is wrong.", client_id)
            rejected_ids.append(client_id)
            # labels.append(label)
            continue

        if check_passport_dates(passport_per):  # Changed from account_form_pd
            print("Sorry! You are being rejected because the passport dates are wrong.", client_id)
            rejected_ids.append(client_id)
            # labels.append(label)
            continue

        if check_address(account_per, profile_per):
            print("Sorry! You are being rejected because the postal code is wrong.", client_id)
            rejected_ids.append(client_id)
            # labels.append(label)
            continue

        if check_email(account_per, profile_per):
            print("Sorry! You are being rejected because the email id is wrong.", client_id)
            rejected_ids.append(client_id)
            # labels.append(label)
            continue

        if check_nationality(passport_per, account_per, profile_per):
            print("Sorry! You are being rejected because the nationality is wrong.", client_id)
            rejected_ids.append(client_id)
            # labels.append(label)
            continue

        if check_country_of_domicile(account_per, profile_per):
            print("Sorry! You are being rejected because the country of domicile is wrong.", client_id)
            rejected_ids.append(client_id)
            # labels.append(label)
            continue

        if check_birthdate(passport_per, profile_per):
            print("Sorry! You are being rejected because the birthdate is wrong.", client_id)
            rejected_ids.append(client_id)
            # labels.append(label)
            continue

        if check_currency(account_per, profile_per):
            print("Sorry! You are being rejected because the currency is wrong.", client_id)
            rejected_ids.append(client_id)
            # labels.append(label)
            continue

        if check_gender(passport_per, profile_per):
            print("Sorry! You are being rejected because the gender is wrong.", client_id)
            rejected_ids.append(client_id)
            # labels.append(label)
            continue

        if check_real_estate_value(profile_per):
            print("Sorry! You are being rejected because the real estate value is wrong.", client_id)
            rejected_ids.append(client_id)
            # labels.append(label)
            continue
        if check_multiple_countries(account_per, profile_per):
            print("Sorry! You are being rejected because the country of domicile is wrong.", client_id)
            rejected_ids.append(client_id)
            # labels.append(label)
            # print(client_id, label)
            continue
        if check_europe_postal_code(account_per, profile_per):
            print("Sorry! You are being rejected because the postal code is wrong.", client_id)
            rejected_ids.append(client_id)
            # labels.append(label)
            # print(client_id, label)
            continue
        # if check_description(index, LLM_Model):
        #     print("Sorry! You are being rejected because the descrition", client_id)
        #     rejected_ids.append(client_id)
        #     labels.append(label)
        #     print(client_id, label)
        #     continue
        


    return rejected_ids,labels


In [194]:
LLM_Model = pipeline("text2text-generation", model="google/flan-t5-large")

In [216]:

reje_ids,labels = validate_data(passport_pd, account_form_pd, client_profile_pd)

len(reje_ids)

Sorry! You are being rejected because the phone number is wrong. client_1
Sorry! You are being rejected because the name is wrong. client_1004
Sorry! You are being rejected because the Passport MRZ is wrong. client_1005
Invalid postal code for Switzerland: nan
Sorry! You are being rejected because the postal code is wrong. client_1016
Sorry! You are being rejected because the passport number is wrong. client_1017
Sorry! You are being rejected because the phone number is wrong. client_1018
Sorry! You are being rejected because the Passport MRZ is wrong. client_1029
Sorry! You are being rejected because the Passport MRZ is wrong. client_1031
Sorry! You are being rejected because the name is wrong. client_1033
Sorry! You are being rejected because the Passport MRZ is wrong. client_1038
Sorry! You are being rejected because the passport dates are wrong. client_1047
Sorry! You are being rejected because the Passport MRZ is wrong. client_1049
Sorry! You are being rejected because the Passpor

2874

### Training MLP

In [217]:
currency_rate = {"CHF": 1, "EUR": 0.95738, "DKK":0.128}
investment_experience = {'Experienced':0, 'Expert':1, 'Inexperienced':2}
type_of_mandate = {'Discretionary':0, 'Advisory':1, 'Hybrid':2, 'Execution-Only':3}
preferred_markets = {'Germany':0, 'Austria':1, 'France':2, 'Finland':3, 'Italy':4, 'Belgium':5, 'Spain':6, 'Switzerland':7, 'Netherlands':8, 'Denmark':9}
investment_risk_profile = {'Low':0, 'Considerable':1, 'Aggressive':2, 'Conservative':3, 'Moderate':4, 'High':5, 'Balanced':6}
cuurency_list = ["CHF", "EUR", "DKK"]

In [218]:
client_profile_filtered = client_profile_pd[~client_profile_pd["index"].isin(reje_ids)]
passport_filtered = passport_pd[~passport_pd["index"].isin(reje_ids)]
account_form_filtered = account_form_pd[~account_form_pd["index"].isin(reje_ids)]

print(len(client_profile_filtered))

7126


In [219]:
def calculate_property(data, currency_rate):
    money_data = []
    currency = data["currency"]
    
    rate = currency_rate[currency]

    savings = data["aum.savings"] * rate
    inheritance = data["aum.inheritance"] * rate
    real_estate_value = data["aum.real_estate_value"] * rate

    total_money = savings + inheritance + real_estate_value
    return [savings,inheritance,real_estate_value,total_money]

def get_investment_info(data):
    raw_experience = data.get("investment_experience")
    raw_markets = ast.literal_eval(data.get("preferred_markets", []))
    raw_mandate = data.get("type_of_mandate")
    raw_profile = data.get("investment_risk_profile")

    experience = investment_experience[raw_experience] if pd.notnull(raw_experience) else None
    markets = preferred_markets[raw_markets[0]] if pd.notnull(raw_markets[0]) else None
    mandate = type_of_mandate[raw_mandate] if pd.notnull(raw_mandate) else None
    profile = investment_risk_profile[raw_profile] if pd.notnull(raw_profile) else None

    if experience == None or markets==None or mandate==None or profile==None:
        return None
    return [experience,markets,mandate,profile]

In [220]:
money_list = []
label_list = []
for index in client_profile_filtered.index:
    data = client_profile_filtered.loc[index]
    currency = data["currency"]
    if currency not in cuurency_list:
       continue 
    money_per = calculate_property(data,currency_rate)
    investment_per = get_investment_info(data)
    if investment_per == None:
        continue
    label_per = data["label"]
    money_per = money_per + investment_per
    money_per.append(label_per)
    money_list.append(money_per)

In [221]:
print(len(money_list))

7074


In [222]:
df = pd.DataFrame(money_list, columns=['saving', 'inheritance', 'real_estate_value', 'total','investment_experience','preferred_markets','type_of_mandate','investment_risk_profile', 'label'])


In [223]:
class SimpleModel(nn.Module):
    def __init__(self, input_dim, num_classes, dropout_rate=0.2):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 16)
        self.fc4 = nn.Linear(16, num_classes)

        self.dropout = nn.Dropout(p=dropout_rate)
    
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.relu(self.fc3(x))
        x = self.fc4(x)
        return x

In [224]:
# 'investment_experience','preferred_markets','type_of_mandate','investment_risk_profile'
df_subset = df[['saving', 'real_estate_value','total','investment_experience','investment_risk_profile','type_of_mandate','label']]

In [228]:
X = df_subset.drop('label', axis=1)  # 特征数据
y = df_subset['label'] 

# X_encoded = pd.get_dummies(X, columns=['investment_experience','investment_risk_profile','type_of_mandate'])
# print(X_encoded)

le = LabelEncoder()
y_encoded = le.fit_transform(y)

X_train, X_test, y_train, y_test = train_test_split(
    X , y_encoded, test_size=0.2, random_state=42
)

In [229]:
X_train_tensor = torch.tensor(X_train.values, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
X_test_tensor = torch.tensor(X_test.values, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.long)

# 6. 构建 TensorDataset 和 DataLoader
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [230]:
input_dim = X_train_tensor.shape[1]
num_classes = len(torch.unique(y_train_tensor))  # 或者使用 len(le.classes_) 如果用 LabelEncoder
print(input_dim)
print(num_classes)
model = SimpleModel(input_dim, num_classes)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

6
2


In [231]:
num_epochs = 100
for epoch in range(num_epochs):
    model.train()  
    running_loss = 0.0
    for batch_inputs, batch_labels in train_loader:
        optimizer.zero_grad()            
        outputs = model(batch_inputs)      
        loss = criterion(outputs, batch_labels)  
        loss.backward()                    
        optimizer.step()                   
        
        running_loss += loss.item()
        
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

Epoch 1/100, Loss: 3879.1872
Epoch 2/100, Loss: 260.1020
Epoch 3/100, Loss: 116.3071
Epoch 4/100, Loss: 39.3912
Epoch 5/100, Loss: 25.4093
Epoch 6/100, Loss: 18.6515
Epoch 7/100, Loss: 13.1724
Epoch 8/100, Loss: 7.4005
Epoch 9/100, Loss: 6.9114
Epoch 10/100, Loss: 3.4034
Epoch 11/100, Loss: 5.3771
Epoch 12/100, Loss: 4.2848
Epoch 13/100, Loss: 1.2692
Epoch 14/100, Loss: 0.7695
Epoch 15/100, Loss: 1.0007
Epoch 16/100, Loss: 1.4026
Epoch 17/100, Loss: 0.6557
Epoch 18/100, Loss: 1.0943
Epoch 19/100, Loss: 0.6869
Epoch 20/100, Loss: 0.7065
Epoch 21/100, Loss: 2.1086
Epoch 22/100, Loss: 2.1505
Epoch 23/100, Loss: 0.6346
Epoch 24/100, Loss: 2.0263
Epoch 25/100, Loss: 0.8760
Epoch 26/100, Loss: 0.6997
Epoch 27/100, Loss: 2.4623
Epoch 28/100, Loss: 0.7620
Epoch 29/100, Loss: 0.6484
Epoch 30/100, Loss: 1.8522
Epoch 31/100, Loss: 0.6924
Epoch 32/100, Loss: 0.6220
Epoch 33/100, Loss: 0.7319
Epoch 34/100, Loss: 0.9991
Epoch 35/100, Loss: 1.1217
Epoch 36/100, Loss: 0.9906
Epoch 37/100, Loss: 0.6435

In [232]:
correct = 0
total = 0
model.eval()  # 切换到评估模式

with torch.no_grad():
    for batch_inputs, batch_labels in test_loader:
        outputs = model(batch_inputs)  # 输出形状 (batch_size, 2)
        predicted = torch.argmax(outputs, dim=1)
        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()

accuracy = correct / total
print(f"Test Accuracy: {accuracy:.4f}")
# Test Accuracy: 0.7132

Test Accuracy: 0.7145


### Validation

In [92]:
# DIRECTORY WHERE THE DATA IS STORED, IF CURRENT DIR, USE "./"
directory = "./test"

def extract_zip(zip_path, extract_to,log=False):
    """Extracts the contents of a zip file to a specified directory."""
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)
    if log:
        print(f"Extracted {zip_path} to {extract_to}")
# Get zip files in the directory

if not os.path.exists(directory):
    print(f"Directory {directory} does not exist.")
    sys.exit(1)


zip_files = [f for f in os.listdir(directory) if f.endswith('.zip')]
zip_files = [os.path.join(directory, f) for f in zip_files]

if not zip_files:
    print("No zip files found in the current directory.")

else:
    # Create a directory to extract files
    extract_dir = os.path.join(directory, "extracted_files")
    if not os.path.exists(extract_dir):
        os.makedirs(extract_dir)

    # Extract each zip file
    for zip_file in zip_files:
        extract_zip(zip_file, extract_dir,log=True)

    print(f"All zip files have been extracted to {extract_dir}.")

# Extract all zip files in extracted_files

extracted_files = os.listdir(extract_dir)
for file in tqdm(extracted_files):
    if file.endswith('.zip'):
        file_path = os.path.join(extract_dir, file)
        # Create a subdirectory for each zip file
        sub_dir = os.path.join(extract_dir, os.path.splitext(file)[0])
        if not os.path.exists(sub_dir):
            os.makedirs(sub_dir)
        extract_zip(file_path, sub_dir,log=False)
        os.remove(file_path)

        
print("All zip files in extracted_files have been extracted and removed.")

Extracted ./test/datathon_evaluation.zip to ./test/extracted_files
All zip files have been extracted to ./test/extracted_files.


100%|██████████| 1000/1000 [00:06<00:00, 145.04it/s]

All zip files in extracted_files have been extracted and removed.





In [233]:
# Code to load the files, process them and store as csv

# Extract_dir contains all the extracted files in separate folders per client
extract_dir = os.path.join(directory, "extracted_files")

# Get all the folders in the extracted_files directory
folders = [f for f in os.listdir(extract_dir) if os.path.isdir(os.path.join(extract_dir, f))]

# Create a dict for passport data
passport_data_dict = {}

# Create a dict for account form data
account_form_data_dict = {}

# Create a dict for client profile data
client_profile_data_dict = {}

# Create a dict for labels
label_dict = {}

# Create a dict for client description
client_description_dict = {}

print(extract_dir)


# Read all the files in the folders
for client_id in tqdm(folders):

    client_dir = os.path.join(extract_dir, client_id)

    passport_data_file = os.path.join(client_dir, 'passport.json')
    account_form_data_file = os.path.join(client_dir, 'account_form.json')
    client_profile_data_file = os.path.join(client_dir, 'client_profile.json')
    client_description_file = os.path.join(client_dir, 'client_description.json')
    label_dict_file = os.path.join(client_dir, 'label.json')

    # Passport data load
    if os.path.exists(passport_data_file):
        with open(passport_data_file, 'r') as f:
            passport_data = json.load(f)
            passport_data_dict[client_id] = passport_data
    else:
        print(f"Passport data file not found for client {client_id}")
    # Account form data load
    if os.path.exists(account_form_data_file):
        with open(account_form_data_file, 'r') as f:
            account_form_data = json.load(f)
            account_form_data_dict[client_id] = account_form_data
    else:
        print(f"Account form data file not found for client {client_id}")
    # Client profile data load
    if os.path.exists(client_profile_data_file):
        with open(client_profile_data_file, 'r') as f:
            client_profile_data = json.load(f)
            client_profile_data_dict[client_id] = client_profile_data
    else:
        print(f"Client profile data file not found for client {client_id}")

    # Client description data load

    if os.path.exists(client_description_file):
        with open(client_description_file, 'r') as f:
            client_description = json.load(f)
            client_description_dict[client_id] = client_description
    else:
        print(f"Client description data file not found for client {client_id}")
# Convert the dicts to dataframes with dictionary key as a column called client_id
passport_data_df = pd.DataFrame.from_dict(passport_data_dict, orient='index').reset_index()
account_form_data_df = pd.DataFrame.from_dict(account_form_data_dict, orient='index').reset_index()
client_profile_data_df = pd.DataFrame.from_dict(client_profile_data_dict, orient='index').reset_index()
client_description_df = pd.DataFrame.from_dict(client_description_dict, orient='index').reset_index()

# Flatten the dataframe columns which have nested json objects
passport_data_df = pd.json_normalize(passport_data_df.to_dict(orient='records'))
account_form_data_df = pd.json_normalize(account_form_data_df.to_dict(orient='records'))
client_profile_data_df = pd.json_normalize(client_profile_data_df.to_dict(orient='records'))
client_description_df = pd.json_normalize(client_description_df.to_dict(orient='records'))

# Save the dataframes to csv files
passport_data_df.to_csv('passport_data_with_label.csv', index=False)
account_form_data_df.to_csv('account_form_data_with_label.csv', index=False)
client_profile_data_df.to_csv('client_profile_data_with_label.csv', index=False)
client_description_df.to_csv('client_description_with_label.csv', index=False)

# Clean the account_form_data removing passport_number in list form
account_form_data_df['passport_number'] = account_form_data_df['passport_number'].apply(lambda x: x[0] if isinstance(x, list) else x)

# Replace columns with spaces with underscores
account_form_data_df.columns = account_form_data_df.columns.str.replace(' ', '_')

# Save cleaned csv file
account_form_data_df.to_csv('account_form_data_with_label_cleaned.csv', index=False)


    

./test/extracted_files


  5%|▌         | 53/1000 [00:00<00:01, 526.83it/s]

Passport data file not found for client client_14
Account form data file not found for client client_14
Client profile data file not found for client client_14
Client description data file not found for client client_14


 36%|███▌      | 361/1000 [00:00<00:01, 618.93it/s]

Passport data file not found for client client_452
Account form data file not found for client client_452
Client profile data file not found for client client_452
Client description data file not found for client client_452
Passport data file not found for client client_485
Account form data file not found for client client_485
Client profile data file not found for client client_485
Client description data file not found for client client_485


 60%|█████▉    | 598/1000 [00:00<00:00, 704.44it/s]

Passport data file not found for client client_965
Account form data file not found for client client_965
Client profile data file not found for client client_965
Client description data file not found for client client_965


 85%|████████▌ | 853/1000 [00:01<00:00, 797.19it/s]

Passport data file not found for client client_690
Account form data file not found for client client_690
Client profile data file not found for client client_690
Client description data file not found for client client_690


100%|██████████| 1000/1000 [00:01<00:00, 692.59it/s]


Passport data file not found for client client_994
Account form data file not found for client client_994
Client profile data file not found for client client_994
Client description data file not found for client client_994


In [234]:
passport_pd = pd.read_csv('./passport_data_with_label.csv')
passport_pd['label'] = None
passport_pd.head(5)

Unnamed: 0,index,first_name,middle_name,last_name,gender,country,country_code,nationality,birth_date,passport_number,passport_mrz,passport_issue_date,passport_expiry_date,label
0,client_529,Hannelore,Joyce,De Jong,F,Netherlands,NLD,Dutch,1957-03-26,MU3505030,['P<NLDDE JONG<<HANNELORE<JOYCE<<<<<<<<<<<<<<<...,2023-02-28,2033-02-27,
1,client_25,Imke,Henriette,Zimmermann,F,Germany,DEU,German,1990-01-29,JZ6792795,['P<DEUZIMMERMANN<<IMKE<HENRIETTE<<<<<<<<<<<<<...,2018-07-08,2028-07-07,
2,client_971,Olivia,Romane,Peiren,F,Belgium,BEL,Belgian,1988-12-17,KW7402943,['P<BELPEIREN<<OLIVIA<ROMANE<<<<<<<<<<<<<<<<<<...,2022-10-28,2029-10-27,
3,client_985,Rossi,Silvestri,Donati,M,Italy,ITA,Italian,1959-09-05,CF8312708,['P<ITADONATI<<ROSSI<SILVESTRI<<<<<<<<<<<<<<<<...,2022-08-12,2032-08-11,
4,client_172,Belli,Landi,Donati,M,Italy,ITA,Italian,1978-10-01,VK0360169,['P<ITADONATI<<BELLI<LANDI<<<<<<<<<<<<<<<<<<<<...,2021-12-03,2031-12-02,


In [235]:
account_form_pd = pd.read_csv('./account_form_data_with_label_cleaned.csv')
account_form_pd['label'] = None
account_form_pd.head(20)

Unnamed: 0,index,name,first_name,middle_name,last_name,passport_number,currency,country_of_domicile,phone_number,email_address,address.city,address.street_name,address.street_number,address.postal_code,label
0,client_529,Hannelore Joyce De Jong,Hannelore,Joyce,De Jong,MU3505030,EUR,Netherlands,+31 06 84841712,hannelore.dejong@hotspots.nl,Oss,Markt,62.0,6676 80,
1,client_25,Imke Henriette Zimmermann,Imke,Henriette,Zimmermann,JZ6792795,EUR,Germany,0907 860545,imke.zimmermann@hotmail.com,Heidelberg,Rathausstraße,21.0,98917,
2,client_971,Olivia Romane Peiren,Olivia,Romane,Peiren,KW7402943,EUR,Belgium,+32 0474 009 161,olivia.peiren@proximus.be,Ostend,Rue des Deux Portes,51.0,8953,
3,client_985,Rossi Silvestri Donati,Rossi,Silvestri,Donati,CF8312708,EUR,Italy,306 7122234,rossi.donati@outlook.com,Perugia,Via del Governo Vecchio,61.0,18216,
4,client_172,Belli Landi Donati,Belli,Landi,Donati,VK0360169,EUR,Italy,390 7087234,belli.donati@unidata.it,Treviso,Via Condotti,73.0,01009,
5,client_340,Michiels Sanders Peters,Michiels,Sanders,Peters,PY0491637,EUR,Belgium,+32 0462 871 432,michiels.peters@hotmail.com,Mol,Grand Place,54.0,1383,
6,client_724,Conti Leonardi Marino,Conti,Leonardi,Marino,CS9947284,EUR,Italy,++32 338 4801028,conti marino@virgilio it,Florence,Via dei Banchi Vecchi,14.0,94709,
7,client_516,Dekkers Terpstra,Dekkers,,Terpstra,OO3562302,EUR,Netherlands,+31 06 85244605,dekkers.terpstra@hotmail.com,Tilburg,Leidseplein,97.0,9480 01,
8,client_186,Emilia Anniina Mattila,Emilia,Anniina,Mattila,IE4115020,EUR,Finland,044 248 13 69,emilia.mattila@hotmail.com,Kauniainen,Kulosaarentie,48.0,18327,
9,client_982,Lang Koller Egger,Lang,Koller,Egger,UZ3983718,EUR,Austria,+43 057 770 6911,lang.egger@gmx.at,Linz,Leopoldstraße,99.0,2348,


In [236]:
client_profile_pd = pd.read_csv('./client_profile_data_with_label.csv')
client_profile_pd['label'] = None
client_profile_pd.head(5)

Unnamed: 0,index,name,country_of_domicile,birth_date,nationality,passport_number,passport_issue_date,passport_expiry_date,gender,phone_number,...,address.postal code,secondary_school.name,secondary_school.graduation_year,aum.savings,aum.inheritance,aum.real_estate_value,inheritance_details.relationship,inheritance_details.inheritance year,inheritance_details.profession,label
0,client_529,Hannelore Joyce De Jong,Netherlands,1957-03-26,Dutch,MU3505030,2023-02-28,2033-02-27,F,+31 06 84841712,...,6676 80,Bonaventura College Leiden,1976,670000,0,4220000,,,,
1,client_25,Imke Henriette Zimmermann,Germany,1990-01-29,German,JZ6792795,2018-07-08,2028-07-07,F,0907 860545,...,98917,Pestalozzi-Gymnasium Nürnberg,2008,110000,3560000,3596000,grandfather,2019.0,Real Estate Developer,
2,client_971,Olivia Romane Peiren,Belgium,1988-12-17,Belgian,KW7402943,2022-10-28,2029-10-27,F,+32 0474 009 161,...,8953,Heilige Familiecollege,2006,130000,7350000,910000,grandfather,2017.0,Investment Banker,
3,client_985,Rossi Silvestri Donati,Italy,1959-09-05,Italian,CF8312708,2022-08-12,2032-08-11,M,306 7122234,...,18216,Istituto Tecnico Industriale Statale Enrico Ma...,1978,570000,0,1495000,,,,
4,client_172,Belli Landi Donati,Italy,1978-10-01,Italian,VK0360169,2021-12-03,2031-12-02,M,390 7087234,...,01009,Liceo Classico Annibale Mariotti Siena,1996,340000,1630000,1390000,grandfather,2009.0,Hedge Fund Manager,


In [237]:
reje_ids,labels = validate_data(passport_pd, account_form_pd, client_profile_pd)

len(reje_ids)

Sorry! You are being rejected because the passport number is wrong. client_724
Sorry! You are being rejected because the phone number is wrong. client_982
Sorry! You are being rejected because the Passport MRZ is wrong. client_976
Sorry! You are being rejected because the phone number is wrong. client_181
Sorry! You are being rejected because the phone number is wrong. client_940
Sorry! You are being rejected because the email id is wrong. client_712
Sorry! You are being rejected because the Passport MRZ is wrong. client_144
Sorry! You are being rejected because the Passport MRZ is wrong. client_574
Sorry! You are being rejected because the phone number is wrong. client_587
Sorry! You are being rejected because the phone number is wrong. client_741
Sorry! You are being rejected because the passport number is wrong. client_313
Sorry! You are being rejected because the email id is wrong. client_76
Sorry! You are being rejected because the phone number is wrong. client_946
Sorry! You are 

261

In [238]:
client_profile_filtered = client_profile_pd[~client_profile_pd["index"].isin(reje_ids)]
passport_filtered = passport_pd[~passport_pd["index"].isin(reje_ids)]
account_form_filtered = account_form_pd[~account_form_pd["index"].isin(reje_ids)]

print(len(client_profile_filtered))

733


In [248]:
money_list = []
for index in client_profile_filtered.index:
    data = client_profile_filtered.loc[index]
    currency = data["currency"]
    if currency not in cuurency_list:
       continue 
    money_per = calculate_property(data,currency_rate)
    investment_per = get_investment_info(data)
    if investment_per == None:
        continue
    abel_per = data["label"]
    money_per = [data['index']] +money_per + investment_per
    money_per.append(label_per)
    money_list.append(money_per)

In [249]:
df = pd.DataFrame(money_list, columns=['index', 'saving', 'inheritance', 'real_estate_value', 'total','investment_experience','preferred_markets','type_of_mandate','investment_risk_profile', 'label'])
df_subset = df[['saving', 'real_estate_value','total','investment_experience','investment_risk_profile','type_of_mandate','label']]
df_index = df[['index']]

In [244]:
X = df_subset.drop('label', axis=1)  # 特征数据
# X_encoded = pd.get_dummies(X, columns=['investment_experience','investment_risk_profile','type_of_mandate'])

X_test_tensor = torch.tensor(X.values, dtype=torch.float32)

print(X_test_tensor.shape)


torch.Size([727, 6])


index    client_25
Name: 1, dtype: object


In [256]:
model.eval()  # 切换到评估模式
result = {}
df_index = df[['index']]

with torch.no_grad():
    for id, item in enumerate(X_test_tensor):
        item = item.unsqueeze(0)
        outputs = model(item)  # 输出形状 (batch_size, 2)

        predicted = torch.argmax(outputs, dim=1)
        client_id = str(df_index.iloc[id]["index"])
        label = ""
        if int(predicted) == 0:
            label = "Accept"
        else:

            label = "Reject"
        result[client_id] = label

result

{'client_529': 'Accept',
 'client_25': 'Accept',
 'client_971': 'Accept',
 'client_985': 'Accept',
 'client_172': 'Accept',
 'client_340': 'Accept',
 'client_516': 'Accept',
 'client_186': 'Accept',
 'client_378': 'Accept',
 'client_22': 'Accept',
 'client_511': 'Accept',
 'client_723': 'Accept',
 'client_347': 'Accept',
 'client_175': 'Accept',
 'client_949': 'Accept',
 'client_188': 'Accept',
 'client_518': 'Accept',
 'client_385': 'Accept',
 'client_371': 'Accept',
 'client_143': 'Accept',
 'client_715': 'Accept',
 'client_349': 'Accept',
 'client_947': 'Accept',
 'client_13': 'Accept',
 'client_520': 'Accept',
 'client_978': 'Accept',
 'client_376': 'Accept',
 'client_382': 'Accept',
 'client_78': 'Accept',
 'client_746': 'Accept',
 'client_322': 'Accept',
 'client_110': 'Accept',
 'client_580': 'Accept',
 'client_913': 'Accept',
 'client_779': 'Accept',
 'client_47': 'Accept',
 'client_117': 'Accept',
 'client_325': 'Accept',
 'client_573': 'Accept',
 'client_40': 'Accept',
 'clie

In [258]:


for id in reje_ids:
    result[id] = "Reject"

{'client_529': 'Accept',
 'client_25': 'Accept',
 'client_971': 'Accept',
 'client_985': 'Accept',
 'client_172': 'Accept',
 'client_340': 'Accept',
 'client_516': 'Accept',
 'client_186': 'Accept',
 'client_378': 'Accept',
 'client_22': 'Accept',
 'client_511': 'Accept',
 'client_723': 'Accept',
 'client_347': 'Accept',
 'client_175': 'Accept',
 'client_949': 'Accept',
 'client_188': 'Accept',
 'client_518': 'Accept',
 'client_385': 'Accept',
 'client_371': 'Accept',
 'client_143': 'Accept',
 'client_715': 'Accept',
 'client_349': 'Accept',
 'client_947': 'Accept',
 'client_13': 'Accept',
 'client_520': 'Accept',
 'client_978': 'Accept',
 'client_376': 'Accept',
 'client_382': 'Accept',
 'client_78': 'Accept',
 'client_746': 'Accept',
 'client_322': 'Accept',
 'client_110': 'Accept',
 'client_580': 'Accept',
 'client_913': 'Accept',
 'client_779': 'Accept',
 'client_47': 'Accept',
 'client_117': 'Accept',
 'client_325': 'Accept',
 'client_573': 'Accept',
 'client_40': 'Accept',
 'clie

In [264]:
df = pd.DataFrame(list(result.items()), columns=['client', 'label'])

df.to_csv('output.csv', index=False, sep=";")
