In [1]:
% load_ext autoreload
% autoreload 2

In [87]:
import dgl
import numpy as np
import torch
from torch_geometric.data import HeteroData

from loan_pred.helpers.helper import load_pickle
from loan_pred.preprocessing.embedding import EmbeddingTransformer
from loan_pred.preprocessing.get_data import (
    get_train_dg, get_train_loans, get_train_prevloans
)

In [3]:
# Load the stored demographics scaler and multilabel encoder, then build the
# training demographics frame with them and preview the first rows.
# NOTE(review): `load_pickle` presumably unpickles these files — only load
# artifacts that come from a trusted source.
scaler_dg = load_pickle(file_path="../models_storage/scalers/dg_scaler.pk")
encoder_dg = load_pickle(file_path="../models_storage/encoders/dg_multilabel_encoder.pk")
train_dg = get_train_dg(
    path="../data/preprocessed/train/train_dg.csv",
    encoder=encoder_dg,
    scaler=scaler_dg
)
train_dg.head()

Unnamed: 0,customerid,birthdate,bank_account_type,longitude_gps,latitude_gps,bank_name_clients,employment_status_clients,is_missing_emp_status_clients
0,8a858e135cb22031015cbafc76964ebd,1973-10-10,2,-0.181928,-0.236603,6,1,1
1,8a858e275c7ea5ec015c82482d7c3996,1986-01-21,2,-0.18104,-0.043197,12,1,0
2,8a858e5b5bd99460015bdc95cd485634,1987-04-01,2,0.15589,-0.552651,4,1,1
3,8a858efd5ca70688015cabd1f1e94b55,1991-07-19,2,-0.175854,-0.199322,6,1,0
4,8a858e785acd3412015acd48f4920d04,1982-11-22,2,0.533009,1.545177,6,1,0


In [4]:
# Load the stored loan encoder and scaler first, then pass them to the loader
# and preview the resulting frame.
# NOTE(review): the encoder path has no ".pk" suffix, unlike the other stored
# artifacts in this notebook — confirm this is intentional.
loan_encoder = load_pickle(file_path="../models_storage/encoders/loan_target_encoder")
loan_scaler = load_pickle(file_path="../models_storage/scalers/loan_scaler.pk")
train_perf = get_train_loans(
    path="../data/preprocessed/train/train_perf.csv",
    encoder=loan_encoder,
    scaler=loan_scaler,
)
train_perf.head()

Unnamed: 0,customerid,loannumber,approveddate,loanamount,totaldue,termdays,good_bad_flag
0,8a2a81a74ce8c05d014cfb32a0da1049,1.868965,2017-07-25 08:22:56,1.134202,1.108898,0.06414,1
1,8a85886e54beabf90154c0a29ae757c0,-0.868398,2017-07-05 17:04:41,-0.261346,-0.335566,0.06414,1
2,8a8588f35438fe12015444567666018e,0.500283,2017-07-06 14:52:57,0.203837,0.083119,-1.238939,1
3,8a85890754145ace015429211b513e16,-0.594662,2017-07-27 19:00:41,-0.726529,-0.817054,-1.238939,1
4,8a858970548359cc0154883481981866,1.047756,2017-07-03 23:42:45,2.064568,1.9044,0.06414,1


In [5]:
# Load the stored previous-loan scaler, build the previous-loans training
# frame with it, and preview the first rows.
prevloan_scaler = load_pickle(file_path="../models_storage/scalers/prevloan_scaler.pk")
train_prevloans = get_train_prevloans(
    path="../data/preprocessed/train/train_prevloans.csv",
    scaler=prevloan_scaler,
)
train_prevloans.head()

Unnamed: 0,customerid,loannumber,loanamount,totaldue,termdays,closeddate_days,firstduedate_days,firstrepaiddate_days
0,8a2a81a74ce8c05d014cfb32a0da1049,-0.673771,-0.697536,-0.628776,0.302132,-0.556943,0.618378,-0.543251
1,8a2a81a74ce8c05d014cfb32a0da1049,1.480472,-0.697536,-0.628776,0.302132,0.424067,0.897378,0.365004
2,8a2a81a74ce8c05d014cfb32a0da1049,1.172723,0.375392,0.404325,0.302132,2.159701,0.618378,2.429218
3,8a8588f35438fe12015444567666018e,0.249476,-0.697536,-0.772263,-1.0682,-0.707868,-1.47412,-0.708388
4,8a85890754145ace015429211b513e16,-0.673771,-0.697536,-0.772263,-1.0682,0.273143,-1.33462,0.365004


In [6]:
# Build the embedding transformer from the stored embedding weights.
# The transformation itself is applied to `train_dg` in the next cell.
embeddings_weight = load_pickle(file_path="../models_storage/embeddings/embeddings_weights.pk")
embedder = EmbeddingTransformer(embedding_weights=embeddings_weight)

In [7]:
# Replace the categorical demographic columns with their embedding columns
# (compare the `head()` output before and after this cell).
# NOTE(review): this overwrites `train_dg`, so re-running the cell feeds the
# already-transformed frame back into `embedder.transform` — re-run the
# loading cell first if this one has to be executed again.
train_dg = embedder.transform(train_dg)
train_dg.head()

Unnamed: 0,customerid,birthdate,longitude_gps,latitude_gps,is_missing_emp_status_clients,bank_account_type_0,bank_account_type_1,bank_name_clients_0,bank_name_clients_1,bank_name_clients_2,...,bank_name_clients_4,bank_name_clients_5,bank_name_clients_6,bank_name_clients_7,bank_name_clients_8,bank_name_clients_9,employment_status_clients_0,employment_status_clients_1,employment_status_clients_2,employment_status_clients_3
0,8a858e135cb22031015cbafc76964ebd,1973-10-10,-0.181928,-0.236603,1,-0.028878,-0.855125,0.063461,-0.186302,1.173964,...,-0.154604,0.76662,-0.135906,0.285251,-0.012637,-0.010871,0.486133,0.341335,-1.499897,0.908332
1,8a858e275c7ea5ec015c82482d7c3996,1986-01-21,-0.18104,-0.043197,0,-0.028878,-0.855125,-0.123493,-2.284099,-1.671413,...,0.747014,-0.053674,-3.044296,1.864994,0.822281,2.384361,0.486133,0.341335,-1.499897,0.908332
2,8a858e5b5bd99460015bdc95cd485634,1987-04-01,0.15589,-0.552651,1,-0.028878,-0.855125,1.103828,0.777442,1.248884,...,-1.158112,-0.445147,0.017469,-0.171524,-0.475374,-0.648696,0.486133,0.341335,-1.499897,0.908332
3,8a858efd5ca70688015cabd1f1e94b55,1991-07-19,-0.175854,-0.199322,0,-0.028878,-0.855125,0.063461,-0.186302,1.173964,...,-0.154604,0.76662,-0.135906,0.285251,-0.012637,-0.010871,0.486133,0.341335,-1.499897,0.908332
4,8a858e785acd3412015acd48f4920d04,1982-11-22,0.533009,1.545177,0,-0.028878,-0.855125,0.063461,-0.186302,1.173964,...,-0.154604,0.76662,-0.135906,0.285251,-0.012637,-0.010871,0.486133,0.341335,-1.499897,0.908332


In [8]:
# Graph generation

In [9]:
from loan_pred.preprocessing.graph_processing import generate_graphs_data

# Combine the loan, demographics and previous-loan frames into per-loan graph
# records; the cell below prints the first record to show the layout.
# NOTE(review): presumably this returns a lazy generator — confirm before
# iterating it more than once.
graph_generator = generate_graphs_data(
    loan=train_perf,
    dg=train_dg,
    prev_loan=train_prevloans
)

In [10]:
# Print only the first generated record to inspect its layout.
# The index from `enumerate` was never used, so iterate the records directly.
# NOTE(review): if `graph_generator` is a one-shot generator, this consumes
# its first record — recreate the generator before a full pass.
for record in graph_generator:
    print(record)
    break

{'id': 0, 'user_id': '8a2a81a74ce8c05d014cfb32a0da1049', 'label': 1, 'node_loans': [[1.868964712900149, 1.1342024378353102, 1.1088983044984644, 0.06413962696284115, 0.9105374790315618]], 'node_dg': [[-0.16622738032006648, -0.2678823639894638, 0.0, 2.2026453018188477, 1.5976506471633911, -1.425376296043396, 0.36965134739875793, 1.0473814010620117, 1.9874283075332642, 0.24556586146354675, 1.1803208589553833, 0.15472577512264252, -1.920318365097046, -0.8908156752586365, -0.6241723299026489, 0.4861326813697815, 0.34133508801460266, -1.4998971223831177, 0.9083321690559387]], 'node_prevloans': [[-0.673771321016832, -0.6975358677579336, -0.6287764937273944, 0.30213165900863986, -0.5569429447405659, 0.6183781593635554, -0.5432506713248281], [1.4804723804160946, -0.6975358677579336, -0.6287764937273944, 0.30213165900863986, 0.4240673261638455, 0.897377937823334, 0.36500371102529966], [1.172723280211391, 0.3753919808058758, 0.4043250555812386, 0.30213165900863986, 2.159700882379343, 0.6183781593

In [16]:
# Hard-coded sample record used for the graph experiments below; it appears to
# be a pasted copy of the record printed by the previous cell.
# Layout: 'node_loans' holds 1 row x 5 features, 'node_dg' 1 row x 19 features,
# 'node_prevloans' 11 rows x 7 features.
data = {'id': 0, 'user_id': '8a2a81a74ce8c05d014cfb32a0da1049', 'label': 1, 'node_loans': [
    [1.868964712900149, 1.1342024378353102, 1.1088983044984644, 0.06413962696284115, 0.9105374790315618]], 'node_dg': [
    [-0.16622738032006648, -0.2678823639894638, 0.0, 2.2026453018188477, 1.5976506471633911, -1.425376296043396,
     0.36965134739875793, 1.0473814010620117, 1.9874283075332642, 0.24556586146354675, 1.1803208589553833,
     0.15472577512264252, -1.920318365097046, -0.8908156752586365, -0.6241723299026489, 0.4861326813697815,
     0.34133508801460266, -1.4998971223831177, 0.9083321690559387]], 'node_prevloans': [
    [-0.673771321016832, -0.6975358677579336, -0.6287764937273944, 0.30213165900863986, -0.5569429447405659,
     0.6183781593635554, -0.5432506713248281],
    [1.4804723804160946, -0.6975358677579336, -0.6287764937273944, 0.30213165900863986, 0.4240673261638455,
     0.897377937823334, 0.36500371102529966],
    [1.172723280211391, 0.3753919808058758, 0.4043250555812386, 0.30213165900863986, 2.159700882379343,
     0.6183781593635554, 2.429218216366499],
    [1.7882214806207986, 0.3753919808058758, 0.47128534118457593, 0.30213165900863986, 0.04675568350830268,
     0.897377937823334, 0.11729797038435574],
    [2.0959705808255022, 0.3753919808058758, 0.47128534118457593, 0.30213165900863986, 0.1976803405705198,
     1.0368778270532233, 0.28243513081165167],
    [0.5572250798019833, 0.3753919808058758, 0.4043250555812386, 0.30213165900863986, 3.36709813887708,
     0.6183781593635554, 3.7503154997848664],
    [-0.058273120607424365, 0.3753919808058758, 0.47128534118457593, 0.30213165900863986, 0.4995296546949541,
     0.7578780485934448, 0.6127094516662436],
    [0.8649741800066871, 0.3753919808058758, 0.4043250555812386, 0.30213165900863986, -1.6888778727071945,
     0.6183781593635554, -1.7817793745295476],
    [-0.9815204212215358, -0.6975358677579336, -0.6287764937273944, 0.30213165900863986, -1.3870285585827602,
     0.6183781593635554, -1.4515050536749556],
    [-0.3660222208121282, 0.3753919808058758, 0.47128534118457593, 0.30213165900863986, 0.4240673261638455,
     0.6183781593635554, 0.5301408714525956],
    [0.24947597959727946, 1.4483198293696853, 1.4278608498036804, 0.30213165900863986, 1.556002254130474,
     0.6183781593635554, 1.7686695746573151]]}

In [17]:
heterograph = HeteroData()

# Store node features as float tensors rather than numpy arrays: PyG layers
# and utilities such as `.to(device)` operate on tensor attributes. The shapes
# are unchanged ([1, 5], [1, 19] and [11, 7] for the sample record).
for node_type in ("node_loans", "node_dg", "node_prevloans"):
    heterograph[node_type].x = torch.tensor(data[node_type], dtype=torch.float)

In [18]:
# Show the node stores and their feature shapes.
heterograph

HeteroData(
  [1mnode_loans[0m={ x=[1, 5] },
  [1mnode_dg[0m={ x=[1, 19] },
  [1mnode_prevloans[0m={ x=[11, 7] }
)

In [19]:
def set_edges(data: dict, graph):
    """Connect loan node 0 to every previous-loan and demographic node.

    For each of the node types "node_prevloans" and "node_dg" that has at
    least one row in ``data``, a ("node_loans", "has", <node type>) edge
    store is written on ``graph`` with one edge from loan node 0 to each node
    of that type. Node types with no rows get no edge store.

    The edge index is stored as a ``torch.long`` tensor of shape
    ``[2, num_edges]``, the connectivity format PyG works with; the original
    numpy array held the same values.

    Args:
        data: Graph record whose "node_prevloans" and "node_dg" entries are
            sequences with one item per node.
        graph: Object supporting ``graph[src, rel, dst].edge_index = ...``
            (a ``HeteroData`` in this notebook). It is modified in place.

    Returns:
        The same ``graph`` object that was passed in.
    """
    for node in ["node_prevloans", "node_dg"]:
        num_nodes = len(data[node])
        if num_nodes > 0:
            targets = torch.arange(num_nodes, dtype=torch.long)
            # Every edge starts at the single loan node (index 0).
            sources = torch.zeros_like(targets)
            graph["node_loans", "has", node].edge_index = torch.stack([sources, targets])
    return graph

In [20]:
# `set_edges` writes the edge stores onto `heterograph` in place and returns
# the same object, so `graph` and `heterograph` refer to one graph.
graph = set_edges(data, heterograph)

In [131]:
# Edge lists for the DGL heterograph, keyed by canonical edge type
# (source type, relation, destination type). Each value is a pair of
# (source ids, destination ids); every edge starts at loan node 0.
# The dict was previously initialised under the unused name `feats`, which
# made the assignment below fail with NameError on a fresh kernel.
test_ = {}
for node in ["node_prevloans", "node_dg"]:
    node_ids = list(range(len(data[node])))
    if len(node_ids) > 0:
        rel = "has" if node == "node_prevloans" else "lives"
        test_[("node_loans", rel, node)] = (np.array([0] * len(node_ids)), np.array(node_ids))


def get_node_feat(data):
    """Turn the per-node-type feature lists of a graph record into tensors.

    Args:
        data: Mapping with "node_loans", "node_dg" and "node_prevloans"
            entries, each convertible by ``torch.tensor``.

    Returns:
        Dict with those three keys, in that order, mapped to the tensors.
    """
    node_types = ("node_loans", "node_dg", "node_prevloans")
    return {node_type: torch.tensor(data[node_type]) for node_type in node_types}

In [132]:
# Inspect the edge lists: one (source ids, destination ids) pair per edge type.
test_

{('node_loans',
  'has',
  'node_prevloans'): (array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])),
 ('node_loans', 'lives', 'node_dg'): (array([0]), array([0]))}

In [133]:
# Build the DGL heterograph from the edge lists keyed by canonical edge type.
rt = dgl.heterograph(test_)

In [134]:
# Attach each node type's feature rows from the sample record as a "feat" tensor.
for node_type in ("node_loans", "node_dg", "node_prevloans"):
    rt.nodes[node_type].data["feat"] = torch.tensor(data[node_type])

In [135]:
# Check the features stored on the previous-loan nodes (11 rows x 7 features).
rt.nodes["node_prevloans"]

NodeSpace(data={'feat': tensor([[-0.6738, -0.6975, -0.6288,  0.3021, -0.5569,  0.6184, -0.5433],
        [ 1.4805, -0.6975, -0.6288,  0.3021,  0.4241,  0.8974,  0.3650],
        [ 1.1727,  0.3754,  0.4043,  0.3021,  2.1597,  0.6184,  2.4292],
        [ 1.7882,  0.3754,  0.4713,  0.3021,  0.0468,  0.8974,  0.1173],
        [ 2.0960,  0.3754,  0.4713,  0.3021,  0.1977,  1.0369,  0.2824],
        [ 0.5572,  0.3754,  0.4043,  0.3021,  3.3671,  0.6184,  3.7503],
        [-0.0583,  0.3754,  0.4713,  0.3021,  0.4995,  0.7579,  0.6127],
        [ 0.8650,  0.3754,  0.4043,  0.3021, -1.6889,  0.6184, -1.7818],
        [-0.9815, -0.6975, -0.6288,  0.3021, -1.3870,  0.6184, -1.4515],
        [-0.3660,  0.3754,  0.4713,  0.3021,  0.4241,  0.6184,  0.5301],
        [ 0.2495,  1.4483,  1.4279,  0.3021,  1.5560,  0.6184,  1.7687]])})

In [104]:
# List the (source type, relation, destination type) triples of the graph.
rt.canonical_etypes

[('node_loans', 'has', 'node_prevloans'), ('node_loans', 'lives', 'node_dg')]

In [151]:
# Pull the per-node-type "feat" tensors back out of the DGL graph.
dg, loan, prev = (
    rt.ndata["feat"][node_type]
    for node_type in ("node_dg", "node_loans", "node_prevloans")
)

In [175]:
# Flatten the three feature tensors and concatenate them into one 1-D vector,
# in the order demographics, loan, previous loans.
torch.cat((dg.flatten(), loan.flatten(), prev.flatten()))

tensor([-0.1662, -0.2679,  0.0000,  2.2026,  1.5977, -1.4254,  0.3697,  1.0474,
         1.9874,  0.2456,  1.1803,  0.1547, -1.9203, -0.8908, -0.6242,  0.4861,
         0.3413, -1.4999,  0.9083,  1.8690,  1.1342,  1.1089,  0.0641,  0.9105,
        -0.6738, -0.6975, -0.6288,  0.3021, -0.5569,  0.6184, -0.5433,  1.4805,
        -0.6975, -0.6288,  0.3021,  0.4241,  0.8974,  0.3650,  1.1727,  0.3754,
         0.4043,  0.3021,  2.1597,  0.6184,  2.4292,  1.7882,  0.3754,  0.4713,
         0.3021,  0.0468,  0.8974,  0.1173,  2.0960,  0.3754,  0.4713,  0.3021,
         0.1977,  1.0369,  0.2824,  0.5572,  0.3754,  0.4043,  0.3021,  3.3671,
         0.6184,  3.7503, -0.0583,  0.3754,  0.4713,  0.3021,  0.4995,  0.7579,
         0.6127,  0.8650,  0.3754,  0.4043,  0.3021, -1.6889,  0.6184, -1.7818,
        -0.9815, -0.6975, -0.6288,  0.3021, -1.3870,  0.6184, -1.4515, -0.3660,
         0.3754,  0.4713,  0.3021,  0.4241,  0.6184,  0.5301,  0.2495,  1.4483,
         1.4279,  0.3021,  1.5560,  0.61

In [355]:
# NOTE(review): this re-defines `data` with what appears to be the same sample
# record as the earlier hard-coded `data` cell; if the two are meant to be
# identical, one definition is enough.
data = {'id': 0, 'user_id': '8a2a81a74ce8c05d014cfb32a0da1049', 'label': 1, 'node_loans': [
    [1.868964712900149, 1.1342024378353102, 1.1088983044984644, 0.06413962696284115, 0.9105374790315618]], 'node_dg': [
    [-0.16622738032006648, -0.2678823639894638, 0.0, 2.2026453018188477, 1.5976506471633911, -1.425376296043396,
     0.36965134739875793, 1.0473814010620117, 1.9874283075332642, 0.24556586146354675, 1.1803208589553833,
     0.15472577512264252, -1.920318365097046, -0.8908156752586365, -0.6241723299026489, 0.4861326813697815,
     0.34133508801460266, -1.4998971223831177, 0.9083321690559387]], 'node_prevloans': [
    [-0.673771321016832, -0.6975358677579336, -0.6287764937273944, 0.30213165900863986, -0.5569429447405659,
     0.6183781593635554, -0.5432506713248281],
    [1.4804723804160946, -0.6975358677579336, -0.6287764937273944, 0.30213165900863986, 0.4240673261638455,
     0.897377937823334, 0.36500371102529966],
    [1.172723280211391, 0.3753919808058758, 0.4043250555812386, 0.30213165900863986, 2.159700882379343,
     0.6183781593635554, 2.429218216366499],
    [1.7882214806207986, 0.3753919808058758, 0.47128534118457593, 0.30213165900863986, 0.04675568350830268,
     0.897377937823334, 0.11729797038435574],
    [2.0959705808255022, 0.3753919808058758, 0.47128534118457593, 0.30213165900863986, 0.1976803405705198,
     1.0368778270532233, 0.28243513081165167],
    [0.5572250798019833, 0.3753919808058758, 0.4043250555812386, 0.30213165900863986, 3.36709813887708,
     0.6183781593635554, 3.7503154997848664],
    [-0.058273120607424365, 0.3753919808058758, 0.47128534118457593, 0.30213165900863986, 0.4995296546949541,
     0.7578780485934448, 0.6127094516662436],
    [0.8649741800066871, 0.3753919808058758, 0.4043250555812386, 0.30213165900863986, -1.6888778727071945,
     0.6183781593635554, -1.7817793745295476],
    [-0.9815204212215358, -0.6975358677579336, -0.6287764937273944, 0.30213165900863986, -1.3870285585827602,
     0.6183781593635554, -1.4515050536749556],
    [-0.3660222208121282, 0.3753919808058758, 0.47128534118457593, 0.30213165900863986, 0.4240673261638455,
     0.6183781593635554, 0.5301408714525956],
    [0.24947597959727946, 1.4483198293696853, 1.4278608498036804, 0.30213165900863986, 1.556002254130474,
     0.6183781593635554, 1.7686695746573151]]}

In [356]:

import numpy as np
import torch

# Edge lists for the DGL heterograph: one (source ids, destination ids) pair
# per canonical edge type, with every edge leaving loan node 0.
relation_by_node = {"node_prevloans": "has"}
test_ = {}
for node in ["node_prevloans", "node_dg"]:
    node_ids = list(range(len(data[node])))
    if not node_ids:
        # No nodes of this type: leave the edge type out entirely.
        continue
    rel = relation_by_node.get(node, "lives")
    sources = np.array([0] * len(node_ids))
    test_[("node_loans", rel, node)] = (sources, np.array(node_ids))

In [357]:
# Build the DGL heterograph and attach each node type's rows as a "feat" tensor.
g = dgl.heterograph(test_)
for node_type in ("node_loans", "node_dg", "node_prevloans"):
    g.nodes[node_type].data["feat"] = torch.tensor(data[node_type])

In [358]:
# `features` maps every node type to its "feat" tensor; `feat` keeps only the
# two destination node types.
features = g.ndata["feat"]
feat = {
    node_type: g.ndata['feat'][node_type]
    for node_type in ('node_dg', 'node_prevloans')
}

In [392]:
import dgl.nn.pytorch as dglnn
# `F.relu` is used below but `F` was never imported anywhere in the notebook,
# so this cell raised NameError on a fresh kernel.
import torch.nn.functional as F


def _make_hetero_conv(in_feats: int = 5, out_feats: int = 3):
    """Build one heterogeneous conv layer with a GraphConv per relation.

    Both relations ('has' and 'lives') start at the loan node, whose feature
    width is 5 in the sample record, hence the default ``in_feats``. Each
    call creates fresh, independently initialised GraphConv modules.

    Args:
        in_feats: Feature width fed into each GraphConv.
        out_feats: Feature width produced by each GraphConv.

    Returns:
        A ``dglnn.HeteroGraphConv`` that sums results per destination type.
    """
    return dglnn.HeteroGraphConv(
        {
            rel: dglnn.GraphConv(
                in_feats=in_feats,
                out_feats=out_feats,
                norm='both',
                weight=True,
                bias=True,
                activation=F.relu,
            )
            for rel in ('has', 'lives')
        },
        aggregate='sum',
    )


conv1 = _make_hetero_conv()
conv2 = _make_hetero_conv()

In [393]:
# First heterogeneous conv layer over the raw node features.
res2 = conv1(g, features)
# NOTE(review): the GraphConv modules are already built with
# activation=F.relu, so this second ReLU looks redundant — confirm.
h = {k: F.relu(v) for k, v in res2.items()}
# The conv output has no 'node_loans' entry (see `res3` below; the loan node
# only appears as an edge source), so re-attach its raw features as input for
# the second layer.
h['node_loans'] = g.ndata["feat"]['node_loans']
res3 = conv2(g, h)

In [394]:
# Average the "feat" rows of all previous-loan nodes into one row.
dgl.mean_nodes(g, 'feat', ntype='node_prevloans')

tensor([[0.5572, 0.1803, 0.2400, 0.3021, 0.4584, 0.7198, 0.5527]])

In [395]:
# Inspect the raw input features per node type.
features

{'node_dg': tensor([[-0.1662, -0.2679,  0.0000,  2.2026,  1.5977, -1.4254,  0.3697,  1.0474,
           1.9874,  0.2456,  1.1803,  0.1547, -1.9203, -0.8908, -0.6242,  0.4861,
           0.3413, -1.4999,  0.9083]]),
 'node_loans': tensor([[1.8690, 1.1342, 1.1089, 0.0641, 0.9105]]),
 'node_prevloans': tensor([[-0.6738, -0.6975, -0.6288,  0.3021, -0.5569,  0.6184, -0.5433],
         [ 1.4805, -0.6975, -0.6288,  0.3021,  0.4241,  0.8974,  0.3650],
         [ 1.1727,  0.3754,  0.4043,  0.3021,  2.1597,  0.6184,  2.4292],
         [ 1.7882,  0.3754,  0.4713,  0.3021,  0.0468,  0.8974,  0.1173],
         [ 2.0960,  0.3754,  0.4713,  0.3021,  0.1977,  1.0369,  0.2824],
         [ 0.5572,  0.3754,  0.4043,  0.3021,  3.3671,  0.6184,  3.7503],
         [-0.0583,  0.3754,  0.4713,  0.3021,  0.4995,  0.7579,  0.6127],
         [ 0.8650,  0.3754,  0.4043,  0.3021, -1.6889,  0.6184, -1.7818],
         [-0.9815, -0.6975, -0.6288,  0.3021, -1.3870,  0.6184, -1.4515],
         [-0.3660,  0.3754,  0.471

In [396]:
# Inspect the first-layer output, with the raw loan features re-attached.
h

{'node_dg': tensor([[0.0000, 0.9212, 1.6309]], grad_fn=<ReluBackward0>),
 'node_prevloans': tensor([[0.5591, 0.0000, 0.0000],
         [0.5591, 0.0000, 0.0000],
         [0.5591, 0.0000, 0.0000],
         [0.5591, 0.0000, 0.0000],
         [0.5591, 0.0000, 0.0000],
         [0.5591, 0.0000, 0.0000],
         [0.5591, 0.0000, 0.0000],
         [0.5591, 0.0000, 0.0000],
         [0.5591, 0.0000, 0.0000],
         [0.5591, 0.0000, 0.0000],
         [0.5591, 0.0000, 0.0000]], grad_fn=<ReluBackward0>),
 'node_loans': tensor([[1.8690, 1.1342, 1.1089, 0.0641, 0.9105]])}

In [397]:
# Inspect the second-layer output; it holds only the destination node types.
res3

{'node_dg': tensor([[0., 0., 0.]], grad_fn=<SumBackward1>),
 'node_prevloans': tensor([[0.3872, 0.0000, 0.0000],
         [0.3872, 0.0000, 0.0000],
         [0.3872, 0.0000, 0.0000],
         [0.3872, 0.0000, 0.0000],
         [0.3872, 0.0000, 0.0000],
         [0.3872, 0.0000, 0.0000],
         [0.3872, 0.0000, 0.0000],
         [0.3872, 0.0000, 0.0000],
         [0.3872, 0.0000, 0.0000],
         [0.3872, 0.0000, 0.0000],
         [0.3872, 0.0000, 0.0000]], grad_fn=<SumBackward1>)}