- Train the model on the original dataset.
- Test/validate the trained model on the unbiased (resampled) dataset.

In [1]:
# Notebook config
import sys
if '../' not in sys.path:
    sys.path.append("../")
%load_ext dotenv
%dotenv


# Actual imports
from config.config import DIR_CFG, MODEL_CFGS, CURRENT_MODEL_VERSION
from queries import (
    feature_queries,
    gds_queries,
    pyg_queries,
    utils,
)
from queries.feature_queries import (
    ListEncoder,
    IdentityEncoder,
    load_edge_tensor,
    load_node_tensor,
)
import torch_geometric.transforms as T
from torch_geometric import seed_everything


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Baseline run: build the graph, split it, train a model on the training split,
# and evaluate that model on the validation split of the same (original) data.
# Every step is delegated to helpers in `pyg_queries`; their internals are not
# visible in this notebook, so the notes below only describe how results flow.
data = pyg_queries.create_pyg_graph()
train_data, val_data, test_data = pyg_queries.split_data(data)
train_loader = pyg_queries.get_train_loader(train_data)
# The model is constructed from the full graph object, then trained via the loader.
model = pyg_queries.get_model(data)
pyg_queries.train(model, train_loader)
val_loader = pyg_queries.get_val_loader(val_data)
# `data`, `model` and `val_data` are reused by the following cell.
# The recorded run shows 0.7003 after this call — presumably the validation
# score returned by eval; TODO confirm which metric it is.
pyg_queries.eval(model, val_loader)

HeteroData(
  [1mpalmprint[0m={ x=[107546, 1] },
  [1mtaxon[0m={ x=[250187, 1] },
  [1m(palmprint, has_host, taxon)[0m={ edge_index=[2, 135975] },
  [1m(palmprint, has_sotu, palmprint)[0m={
    edge_index=[2, 56162],
    edge_label=[56162]
  },
  [1m(taxon, has_parent, taxon)[0m={
    edge_index=[2, 248206],
    edge_label=[248206]
  }
)
Node types: ['palmprint', 'taxon']
Number of nodes: 357733
Dimension of node features: {'palmprint': 1, 'taxon': 1}
Edge types: [('palmprint', 'has_host', 'taxon'), ('palmprint', 'has_sotu', 'palmprint'), ('taxon', 'has_parent', 'taxon')]
Number of edges: 440343
Dimension of edge features: {('palmprint', 'has_host', 'taxon'): 0, ('palmprint', 'has_sotu', 'palmprint'): 0, ('taxon', 'has_parent', 'taxon'): 0}
Graph has isolated nodes: True
Graph has self loops: False
Graph is undirected: False
Edge types: [('palmprint', 'has_host', 'taxon'), ('palmprint', 'has_sotu', 'palmprint'), ('taxon', 'has_parent', 'taxon'), ('taxon', 'rev_has_host', 'pal

0.7002938354840662

In [3]:
# Evaluate the already-trained `model` on the resampled ("unbiased") has_host
# edges: swap the has_host edges of `data` for the resampled edge list, rebuild
# the reverse edges, re-split, and score the model on the new validation split.
#
# NOTE(review): `model` was trained on a different random split of the original
# edges, so some validation edges here may have been seen during training —
# confirm this is acceptable for the comparison.
MODEL_CFG = MODEL_CFGS[CURRENT_MODEL_VERSION]

# Fixed seed so the random link split and its sampled negative edges are the
# same on every Restart & Run All; without it the score below varies per run.
SPLIT_SEED = 42

sampling_rate = MODEL_CFG['SAMPLING_RATIO']
graph_name = MODEL_CFG['PROJECTION_NAME']
dir_name = f"{DIR_CFG['DATASETS_DIR']}{graph_name}_{sampling_rate}"

# The node files are re-read for their id mappings, which translate the
# source/target ids of the edge file below. The returned feature tensors
# (taxon_x, palmprint_x) are not used in this cell.
taxon_x, taxon_mapping = load_node_tensor(
    filename=f'{dir_name}/taxon_nodes.csv',
    index_col='nodeId',
)
palmprint_x, palmprint_mapping = load_node_tensor(
    filename=f'{dir_name}/palmprint_nodes.csv',
    index_col='nodeId',
)

has_host_edge_index, has_host_edge_label = load_edge_tensor(
    filename=f'{dir_name}/has_host_edges_resampled.csv',
    src_index_col='sourceNodeId',
    src_mapping=palmprint_mapping,
    dst_index_col='targetNodeId',
    dst_mapping=taxon_mapping,
)

# The new edges are attached to the node tensors already stored in `data`, so
# every index must address an existing node. Fail loudly here rather than deep
# inside the loader/model if the CSV mappings and the graph disagree.
if has_host_edge_index.numel() > 0:
    assert int(has_host_edge_index[0].max()) < data['palmprint'].num_nodes, \
        "has_host source index exceeds the number of palmprint nodes in `data`"
    assert int(has_host_edge_index[1].max()) < data['taxon'].num_nodes, \
        "has_host target index exceeds the number of taxon nodes in `data`"

data['palmprint', 'has_host', 'taxon'].edge_index = has_host_edge_index
data['palmprint', 'has_host', 'taxon'].edge_label = has_host_edge_label
# Drop the reverse edges that still mirror the old has_host edges;
# ToUndirected() below recreates rev_has_host from the resampled ones.
del data['taxon', 'rev_has_host', 'palmprint']
data = T.ToUndirected()(data)
print(data)

seed_everything(SPLIT_SEED)
transform = T.RandomLinkSplit(
    # Link-level split: 80% train, 10% validation, 10% test edges.
    num_val=0.1,
    num_test=0.1,

    # Of the training edges, 30% are held out as supervision edges
    # (edge_label_index); the remaining 70% stay in edge_index and are used
    # for message passing.
    disjoint_train_ratio=0.3,

    # Negative edges are sampled once by the transform at the configured
    # negatives-per-positive ratio. Because add_negative_train_samples=True
    # this also applies to the training split (they are not re-drawn
    # on-the-fly during training).
    neg_sampling_ratio=MODEL_CFG['NEGATIVE_SAMPLING_RATIO'],
    add_negative_train_samples=True,

    edge_types=('palmprint', 'has_host', 'taxon'),
    rev_edge_types=('taxon', 'rev_has_host', 'palmprint'),
)
train_data, val_data, test_data = transform(data)

# The model is intentionally NOT retrained here: only a validation loader over
# the new split is built and the model from the previous cell is scored on it.
val_loader = pyg_queries.get_val_loader(val_data)
print(len(val_loader))
print(val_loader)

pyg_queries.eval(model, val_loader)

HeteroData(
  [1mpalmprint[0m={ x=[107546, 1] },
  [1mtaxon[0m={ x=[250187, 1] },
  [1m(palmprint, has_host, taxon)[0m={ edge_index=[2, 135877] },
  [1m(palmprint, has_sotu, palmprint)[0m={
    edge_index=[2, 112324],
    edge_label=[112324]
  },
  [1m(taxon, has_parent, taxon)[0m={
    edge_index=[2, 496412],
    edge_label=[496412]
  },
  [1m(taxon, rev_has_host, palmprint)[0m={ edge_index=[2, 135877] }
)
Sampled validation mini-batch:
HeteroData(
  [1mpalmprint[0m={
    x=[1490, 1],
    n_id=[1490]
  },
  [1mtaxon[0m={
    x=[601, 1],
    n_id=[601]
  },
  [1m(palmprint, has_host, taxon)[0m={
    edge_index=[2, 1],
    edge_label=[384],
    edge_label_index=[2, 384],
    e_id=[1],
    input_id=[384]
  },
  [1m(palmprint, has_sotu, palmprint)[0m={
    edge_index=[2, 1322],
    edge_label=[1322],
    e_id=[1322]
  },
  [1m(taxon, has_parent, taxon)[0m={
    edge_index=[2, 683],
    edge_label=[683],
    e_id=[683]
  },
  [1m(taxon, rev_has_host, palmprint)[0m={

0.7471867965349512

In [8]:
# Sanity check: dimensions of the palmprint node-feature matrix in the validation split.
print(val_data['palmprint'].x.size())

torch.Size([107546, 1])
