In [2]:
# Imports and config for the notebook

## Notebook config
import sys
if '../' not in sys.path:
    sys.path.append("../")
%load_ext dotenv
%reload_ext dotenv
%dotenv

import collections
import os
import csv

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import graphistry

from datasources.neo4j import gds
from queries import utils

  from .autonotebook import tqdm as notebook_tqdm


### Embedding testing

In [2]:
# Configs

# graphistry.register(
#     api=3,
#     username=os.getenv('GRAPHISTRY_USERNAME'),
#     password=os.getenv('GRAPHISTRY_PASSWORD'),
# )

RANDOM_SEED = 42
EMBEDDINGS_DIR = '/mnt/embeddings/'

print(gds)

<graphdatascience.graph_data_science.GraphDataScience object at 0x7f4a12cb81d0>


TODO:


- Create Tissue projection (random generate initial embeddings)
- Create Taxon projection (use rank as initial feature?)

- Create heterogenous projection from dataset?
- run memory estimates for hashgnn and fastRP using hetero projection and 


- HashGNN_homogenous, HashGNN_heterogenous
- FastRP_homogenous, FastRP_heterogenous

In [13]:
TAXON_PROJECTION_NAME = 'taxon-graph'
TISSUE_PROJECTION_NAME = 'tissue-graph'
HOMOGENOUS_PROJECTION_NAME = 'homogenous-graph'
HETERO_PROJECTION_NAME = 'hetero-graph'


In [14]:

def get_homogenous_projection():
    projection = gds.graph.project(
        graph_name=HOMOGENOUS_PROJECTION_NAME,
        node_spec=[
            'Taxon',
            'Tissue',
            'SOTU',
        ],
        relationship_spec={
            'HAS_PARENT': {'orientation': 'UNDIRECTED'},
            'SEQUENCE_ALIGNMENT': {'orientation': 'UNDIRECTED'},
        },
    )
    return projection


# TODO: Use dataset projection
def get_heterogenous_projection():
    projection = gds.graph.project(
        graph_name=HETERO_PROJECTION_NAME,
        node_spec=[
            'Taxon',
            'Tissue',
            'SOTU',
        ],
        relationship_spec={
            'HAS_PARENT': {'orientation': 'UNDIRECTED'},
            'SEQUENCE_ALIGNMENT': {'orientation': 'UNDIRECTED'},
        },
    )
    return projection


def get_taxon_projection():
    projection = gds.graph.project(
        graph_name=TAXON_PROJECTION_NAME,
        node_spec=[
            'Taxon',
        ],
        relationship_spec={
            'HAS_PARENT': {'orientation': 'UNDIRECTED'},
        },
    )
    return projection

def get_tissue_projection():
    projection = gds.graph.project(
        graph_name=TISSUE_PROJECTION_NAME,
        node_spec=['Tissue'],
        relationship_spec={
            'HAS_PARENT': {'orientation': 'UNDIRECTED'},
        },
    )
    return projection


In [15]:
projection = get_homogenous_projection()

In [18]:
# https://neo4j.com/docs/graph-data-science/current/machine-learning/node-embeddings/hashgnn/#algorithms-embeddings-hashgnn-syntax
# https://github.com/neo4j/graph-data-science-client/blob/main/examples/heterogeneous-node-classification-with-hashgnn.ipynb#L18

# one may try to set embeddingDensity to 128, 256, 512, or roughly 25%-50% of the embedding dimension, i.e. the number of binary features.

gds.hashgnn.stream.estimate(
    G=gds.graph.get(HOMOGENOUS_PROJECTION_NAME),
    nodeLabels=['Taxon', 'Tissue', 'SOTU'],
    relationshipTypes=['HAS_PARENT', 'SEQUENCE_ALIGNMENT'],
    randomSeed=RANDOM_SEED,
    generateFeatures={
        'dimension': 512, # dimension of the embedding vector
        'densityLevel': 2, # number of initial values equalling 1
    },
    iterations=10, # maximum number of hops
    embeddingDensity=256,
    neighborInfluence=1.0,
)

requiredMemory                                [1338 MiB ... 12865 MiB]
treeView             Memory Estimation: [1338 MiB ... 12865 MiB]\n|...
mapView              {'memoryUsage': '[1338 MiB ... 12865 MiB]', 'n...
bytesMin                                                    1403659432
bytesMax                                                   13490131432
nodeCount                                                      3021618
relationshipCount                                             57702540
heapPercentageMin                                                  0.1
heapPercentageMax                                                  0.6
Name: 0, dtype: object

In [20]:
filename = 'HashGNN_homogenous.csv'
df = gds.hashgnn.stream(
    G=gds.graph.get(HOMOGENOUS_PROJECTION_NAME),
    nodeLabels=['Taxon', 'Tissue', 'SOTU'],
    relationshipTypes=['HAS_PARENT', 'SEQUENCE_ALIGNMENT'],
    randomSeed=RANDOM_SEED,
    generateFeatures={
        'dimension': 512, # dimension of the embedding vector
        'densityLevel': 2, # number of initial values equalling 1
    },
    iterations=10, # maximum number of hops
    embeddingDensity=256,
    neighborInfluence=1.0,
)

df.to_csv(EMBEDDING_DIR + filename, index=False)

HashGNN:  48%|████▊     | 47.58/100 [1:53:11<2:18:41, 158.75s/%]

In [None]:
gds.fastrp.stream.estimate(
    G=gds.graph.get(HOMOGENOUS_PROJECTION_NAME),
    nodeLabels=['Taxon', 'Tissue', 'SOTU'],
    relationshipTypes=['HAS_PARENT', 'SEQUENCE_ALIGNMENT'],
    randomSeed=RANDOM_SEED,
    embeddingDimension=256,
    relationshipWeightProperty="weight",
)

In [None]:
filename = 'FastRP_homogenous.csv'
df = gds.fastrp.stream(
    G=gds.graph.get(HOMOGENOUS_PROJECTION_NAME),
    nodeLabels=['Taxon', 'Tissue', 'SOTU'],
    relationshipTypes=['HAS_PARENT', 'SEQUENCE_ALIGNMENT'],
    randomSeed=RANDOM_SEED,
    embeddingDimension=256,
    relationshipWeightProperty='weight',
)
df.to_csv(EMBEDDING_DIR + filename, index=False)

In [2]:
from queries import feature_queries
from config.base import (
    DIR_CFG,
    MODEL_CFG,
    DATASET_CFG,
)

dir_name = '/mnt/graphdata/query_cache/neo4j/'


dataset_cfg = DATASET_CFG
nodes = feature_queries.get_all_node_features(
    dir_name=dir_name,
    dataset_cfg=dataset_cfg,
)

relationships = feature_queries.get_all_relationship_features(
    dir_name=dir_name,
    dataset_cfg=dataset_cfg,
)

undirected_relationship_types = list(
    map(
        (lambda cfg: cfg['TYPES'][0]),
        dataset_cfg['REL_TYPES']
    )
)

Reading local file:  /mnt/graphdata/query_cache/neo4j/sotu_nodes.csv
Reading local file:  /mnt/graphdata/query_cache/neo4j/taxon_nodes.csv
Reading local file:  /mnt/graphdata/query_cache/neo4j/tissue_nodes.csv
Reading local file:  /mnt/graphdata/query_cache/neo4j/sotu_has_host_stat_edges.csv
Reading local file:  /mnt/graphdata/query_cache/neo4j/taxon_has_parent_edges.csv
Reading local file:  /mnt/graphdata/query_cache/neo4j/sotu_sequence_alignment_edges.csv
Reading local file:  /mnt/graphdata/query_cache/neo4j/sotu_has_inferred_taxon_edges.csv
Reading local file:  /mnt/graphdata/query_cache/neo4j/sotu_has_tissue_metadata_edges.csv


In [6]:
print(nodes.shape)
print(relationships.shape)

model_cfg = MODEL_CFG
sampling_ratio = 1

graph_name = \
        f"{model_cfg['PROJECTION_NAME']}_{sampling_ratio}"

print(graph_name)

(3021618, 3)
(24096965, 4)
incl-best-1_1


### ML testing

In [11]:
from queries import pyg_queries, utils
from config.base import MODEL_CFG, DATASET_CFG

from queries.feature_queries import (
    IdentityEncoder,
    ListEncoder,
    load_edge_tensor,
    load_node_tensor,
)
from queries.utils import read_ddf_from_disk
from models.models_v3 import Model
from config.base import (
    DIR_CFG,
    MODEL_CFG,
    DATASET_CFG,
)

import numpy as np
import pandas as pd
from sklearn.metrics import (
    average_precision_score,
    roc_auc_score,
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
)
import torch
from torch_geometric import seed_everything
from torch_geometric.data import HeteroData
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.utils import to_networkx
import torch_geometric.transforms as T
import torch.nn.functional as F


seed_everything(MODEL_CFG['RANDOM_SEED'])

class RandomValueEncoder(object):
    def __init__(self, dim=1):
        self.dim = dim

    def __call__(self, df):
        return torch.rand(len(df), self.dim)

In [1]:
def create_pyg_graph(
    sampling_rate=MODEL_CFG['SAMPLING_RATIO'],
    dataset_cfg=DATASET_CFG,
):
    data = HeteroData()
    mappings = {}
    dir_name = f"{DIR_CFG['DATASETS_DIR']}{sampling_rate}"
    node_file_paths = list(
        map(
            (lambda cfg: cfg['FILE_NAME']),
            dataset_cfg['NODE_TYPES']
        )
    )
    rel_file_paths = list(
        map(
            (lambda cfg: cfg['FILE_NAME']),
            dataset_cfg['REL_TYPES']
        )
    )

    if 'taxon_nodes.csv' in node_file_paths:
        taxon_x, taxon_mapping = load_node_tensor(
            filename=f'{dir_name}/taxon_nodes.csv',
            index_col='appId',
            encoders={
                # 'rankEncoded': IdentityEncoder(
                #     dtype=torch.long, is_tensor=True),
                'features': RandomValueEncoder(),
                # 'FastRP_embedding': ListEncoder(),
            }
        )
        data['taxon'].x = taxon_x
        mappings['taxon'] = taxon_mapping

    if 'sotu_nodes.csv' in node_file_paths:
        sotu_x, sotu_mapping = load_node_tensor(
            filename=f'{dir_name}/sotu_nodes.csv',
            index_col='appId',
            encoders={
                # 'centroidEncoded': IdentityEncoder(
                #   dtype=torch.long, is_tensor=True),
                'features': RandomValueEncoder(),
                # 'FastRP_embedding': ListEncoder(),
            }
        )
        data['sotu'].x = sotu_x #torch.arange(0, len(sotu_mapping))
        mappings['sotu'] = sotu_mapping

    # if 'tissue_nodes.csv' in node_file_paths:
    #     tissue_x, tissue_mapping = load_node_tensor(
    #         filename=f'{dir_name}/tissue_nodes.csv',
    #         index_col='appId',
    #         encoders={
    #             # 'centroidEncoded': IdentityEncoder(
    #             #   dtype=torch.long, is_tensor=True),
    #             'features': ListEncoder(),
    #             # 'FastRP_embedding': ListEncoder(),
    #         }
    #     )
    #     data['tissue'].x = tissue_x # torch.arange(0, len(tissue_mapping))
    #     mappings['tissue'] = tissue_mapping

    if 'sotu_has_host_stat_edges.csv' in rel_file_paths:
        edge_index, edge_label = load_edge_tensor(
            filename=f'{dir_name}/sotu_has_host_stat_edges.csv',
            src_index_col='sourceAppId',
            src_mapping=sotu_mapping,
            dst_index_col='targetAppId',
            dst_mapping=taxon_mapping,
            # encoders={
            #     'weight': IdentityEncoder(dtype=torch.float, is_tensor=True),
            #     'weight': BinaryEncoder(dtype=torch.long),
            # },
        )
        # edge_label = torch.div(edge_label, 100)
        data['sotu', 'has_host', 'taxon'].edge_index = edge_index
        data['sotu', 'has_host', 'taxon'].edge_label = edge_label

    # if 'taxon_has_parent_edges.csv' in rel_file_paths:
    #     edge_index, edge_label = load_edge_tensor(
    #         filename=f'{dir_name}/taxon_has_parent_edges.csv',
    #         src_index_col='sourceAppId',
    #         src_mapping=taxon_mapping,
    #         dst_index_col='targetAppId',
    #         dst_mapping=taxon_mapping,
    #         encoders={
    #             'weight': IdentityEncoder(dtype=torch.float, is_tensor=True)
    #         },
    #     )
    #     data['taxon', 'has_parent', 'taxon'].edge_index = edge_index
    #     data['taxon', 'has_parent', 'taxon'].edge_label = edge_label

    # if 'tissue_has_parent_edges.csv' in rel_file_paths:
    #     edge_index, edge_label = load_edge_tensor(
    #         filename=f'{dir_name}/tissue_has_parent_edges.csv',
    #         src_index_col='sourceAppId',
    #         src_mapping=tissue_mapping,
    #         dst_index_col='targetAppId',
    #         dst_mapping=tissue_mapping,
    #         encoders={
    #             'weight': IdentityEncoder(dtype=torch.float, is_tensor=True)
    #         },
    #     )
    #     data['tissue', 'has_parent', 'tissue'].edge_index = edge_index
    #     data['tissue', 'has_parent', 'tissue'].edge_label = edge_label

    # if 'sotu_sequence_alignment_edges.csv' in rel_file_paths:
    #     edge_index, edge_label = load_edge_tensor(
    #         filename=f'{dir_name}/sotu_sequence_alignment_edges.csv',
    #         src_index_col='sourceAppId',
    #         src_mapping=sotu_mapping,
    #         dst_index_col='targetAppId',
    #         dst_mapping=sotu_mapping,
    #         encoders={
    #             'weight': IdentityEncoder(dtype=torch.float, is_tensor=True)
    #         },
    #     )
    #     data['sotu', 'sequence_alignment', 'sotu'].edge_index = edge_index
    #     data['sotu', 'sequence_alignment', 'sotu'].edge_label = edge_label

    # if 'sotu_has_inferred_taxon_edges.csv' in rel_file_paths:
    #     edge_index, edge_label = load_edge_tensor(
    #         filename=f'{dir_name}/sotu_has_inferred_taxon_edges.csv',
    #         src_index_col='sourceAppId',
    #         src_mapping=sotu_mapping,
    #         dst_index_col='targetAppId',
    #         dst_mapping=taxon_mapping,
    #         encoders={
    #             'weight': IdentityEncoder(dtype=torch.float, is_tensor=True)
    #         },
    #     )
    #     data['sotu', 'has_inferred_taxon', 'taxon'].edge_index = edge_index
    #     data['sotu', 'has_inferred_taxon', 'taxon'].edge_label = edge_label

    node_types, edge_types = data.metadata()
    data = T.ToUndirected()(data)
    # if not ('taxon', 'rev_has_host', 'sotu') in edge_types:
    #     data = T.ToUndirected()(data)
        # Remove "reverse" label. (redundant if using link loader)
        # del data['taxon', 'rev_has_host', 'sotu'].edge_label
    return data, mappings



def split_data(data):
    num_test = (1 - MODEL_CFG['TRAIN_FRACTION']) * MODEL_CFG['TEST_FRACTION']
    num_val = 1 - MODEL_CFG['TRAIN_FRACTION'] - num_test
    # labels = data[('sotu', 'has_host', 'taxon')]['edge_label']
    # print(labels)
    # print(torch.min(labels))
    # print(torch.max(labels))

    transform = T.RandomLinkSplit(
        # Link-level split train (80%), validate (10%), and test edges (10%)
        # num_val=num_val,
        num_val=0.1,
        # num_test=num_test,
        num_test=0.1,
        # Of training edges, use 70% for message passing (edge_index)
        # and 30% for supervision (edge_label_index)
        disjoint_train_ratio=0.3,
        # Generate fixed negative edges for evaluation with a ratio of 2-1.
        # Negative edges during training will be generated on-the-fly.
        neg_sampling_ratio=2.0,
        add_negative_train_samples=False,
        # is_undirected=True,
        edge_types=('sotu', 'has_host', 'taxon'),
        rev_edge_types=('taxon', 'rev_has_host', 'sotu'),
    )
    train_data, val_data, test_data = transform(data)
    return train_data, val_data, test_data




def get_train_loader(data_obj, batch_size=128):
    # Define mini-batch loaders
    edge_label_index = data_obj[(
        'sotu', 'has_host', 'taxon')].edge_label_index

    edge_label = data_obj[(
        'sotu', 'has_host', 'taxon')].edge_label

    train_loader = LinkNeighborLoader(
        data=data_obj,
        # In the first hop, we sample at most 20 neighbors.
        # In the second hop, we sample at most 10 neighbors.
        num_neighbors=[20, 10],
        neg_sampling_ratio=3.0, #MODEL_CFG['NEGATIVE_SAMPLING_RATIO']
        # neg_sampling='binary',
        # let 'binary' setting handle this
        # edge_label=train_data[('sotu', 'has_host', 'taxon')].edge_label,
        edge_label_index=(('sotu', 'has_host', 'taxon'),
                          edge_label_index),
        edge_label=edge_label,
        batch_size=batch_size,
        shuffle=True,
        # num_workers=4,
    )
    return train_loader


def get_val_loader(val_data, batch_size=128*3):
    # Define the validation seed edges:
    edge_label_index = val_data['sotu', 'has_host', 'taxon'].edge_label_index
    edge_label = val_data['sotu', 'has_host', 'taxon'].edge_label

    val_loader = LinkNeighborLoader(
        data=val_data,
        num_neighbors=[20, 10],
        edge_label_index=(('sotu', 'has_host', 'taxon'),
                          edge_label_index),
        edge_label=edge_label,
        batch_size=batch_size,
        shuffle=False,
        # num_workers=4,
    )
    return val_loader

NameError: name 'MODEL_CFG' is not defined

In [35]:
data, mappings = create_pyg_graph()
train_data, val_data, test_data = split_data(data)

train_loader = get_train_loader(train_data)
val_loader = get_val_loader(val_data)
test_loader = get_val_loader(train_data)


In [36]:
print(train_data["sotu", "has_host", "taxon"].num_edges)
print(val_data["sotu", "has_host", "taxon"].num_edges)
print(val_data["sotu", "has_host", "taxon"].num_edges)

print(train_data["sotu", "has_host", "taxon"].edge_label.min())
print(train_data["sotu", "has_host", "taxon"].edge_label.max())

print(train_data["sotu", "has_host", "taxon"].edge_label.long().bincount().tolist())
print(val_data["sotu", "has_host", "taxon"].edge_label.long().bincount().tolist())


print('----')
sampled_data = next(iter(train_loader))
print(sampled_data)
print(sampled_data["sotu", "has_host", "taxon"].edge_label_index.size(1))
print(sampled_data["sotu", "has_host", "taxon"].edge_label.max())
print(sampled_data["sotu", "has_host", "taxon"].edge_label.min())

412424
589177
589177
tensor(1.)
tensor(1.)
[0, 176753]
[147294, 73647]
----
HeteroData(
  taxon={
    x=[738, 1],
    n_id=[738],
    num_sampled_nodes=[3],
  },
  sotu={
    x=[2842, 1],
    n_id=[2842],
    num_sampled_nodes=[3],
  },
  (sotu, has_host, taxon)={
    edge_index=[2, 2510],
    edge_label=[512],
    edge_label_index=[2, 512],
    e_id=[2510],
    num_sampled_edges=[2],
    input_id=[128],
  },
  (taxon, rev_has_host, sotu)={
    edge_index=[2, 7597],
    e_id=[7597],
    num_sampled_edges=[2],
  }
)
512
tensor(1.)
tensor(0.)


In [15]:
print(data.num_nodes)
# for batch in train_loader:
#     print(batch.validate())
#     # print(batch.edge_index_dict[('taxon', 'has_parent', 'taxon')].max())
#     print(batch.edge_index_dict[('sotu', 'sequence_alignment', 'sotu')].max())
#     print(batch.edge_index_dict[('taxon', 'has_parent', 'taxon')].shape )
# #     print(edge_index.max())

# print(data.num_node_features)
print(data.validate())

3015049
True


In [16]:
import torch
from torch_geometric.nn import SAGEConv, to_hetero
import torch.nn.functional as F

class GNNEncoder(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x


class EdgeDecoder(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.lin1 = torch.nn.Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        row, col = edge_label_index
        z = torch.cat([z_dict['sotu'][row], z_dict['taxon'][col]], dim=-1)

        z = self.lin1(z).relu()
        z = self.lin2(z)
        return z.view(-1)


class Model(torch.nn.Module):
    def __init__(self, hidden_channels, data):
        super().__init__()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.encoder = GNNEncoder(hidden_channels, hidden_channels)
        self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)


def get_model(data):
    model = Model(
        hidden_channels=64,
        data=data,
    )
    return model

In [17]:
class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False


def train(model, train_loader, optimizer, device):
    model.train()
    total_loss = total_examples = 0
    total_neg = 0
    total = 0
    for batch in train_loader:
        labels = batch['sotu', 'has_host', 'taxon'].edge_label
        total_neg += torch.count_nonzero(labels).item()
        total += labels.numel()

        batch = batch.to(device)
        optimizer.zero_grad()

        pred = model(
            batch.x_dict,
            batch.edge_index_dict,
            batch['sotu', 'has_host', 'taxon'].edge_label_index
        )

        target = batch['sotu', 'has_host', 'taxon'].edge_label.float()
        loss = F.binary_cross_entropy_with_logits(pred, target)
        loss.backward()
        optimizer.step()

        total_loss += float(loss)
        total_examples += pred.numel()

    return total_loss / total_examples


@torch.no_grad()
def test(model, loader, device):
    model.eval()

    preds, targets = [], []
    total_examples = total_correct = 0

    for batch in loader:
        batch = batch.to(device)

        pred = model(
            batch.x_dict,
            batch.edge_index_dict,
            batch['sotu', 'has_host', 'taxon'].edge_label_index
        )
        # pred = pred.sigmoid().view(-1).cpu()
        pred = pred.clamp(min=0, max=1)
        pred = (pred>0.5).float()

        target = batch['sotu', 'has_host', 'taxon'].edge_label.float() #.cpu()
        preds.append(pred)
        targets.append(target)

    pred = torch.cat(preds, dim=0).numpy()
    target = torch.cat(targets, dim=0).numpy()

    accuracy = accuracy_score(target, pred)
    # print(f"Accuracy: {accuracy:.4f}")

    # auc_roc = roc_auc_score(target, pred)
    # print(f"Validation AUC-ROC: {auc_roc:.4f}")

    auc_pr = average_precision_score(target, pred)
    # print(f"Validation AUC-PR: {auc_pr:.4f}")

    return accuracy


def update_stats(training_stats, epoch_stats):
    if training_stats is None:
        training_stats = {}
        for key in epoch_stats.keys():
            training_stats[key] = []
    for key, val in epoch_stats.items():
        training_stats[key].append(val)
    return training_stats



def train_and_eval_loop(model, train_loader, val_loader, test_loader):
    early_stopper = EarlyStopper(patience=3, min_delta=10)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    training_stats = None

    for epoch in range(1, MODEL_CFG['MAX_EPOCHS']):
        train_loss = train(model, train_loader, optimizer, device)
        train_acc = test(model, test_loader, device)
        val_acc = test(model, val_loader, device)
        epoch_stats = {'train_acc': train_acc, 'val_acc': val_acc,
                       'train_loss': train_loss, 'epoch': epoch}
        training_stats = update_stats(training_stats, epoch_stats)
        if epoch % 10 == 0:
            print(f"Epoch: {epoch:03d}")
            print(f"Train loss: {train_loss:.4f}")
            print(f"Train accuracy: {train_acc:.4f}")
            print(f"Validation accuracy: {val_acc:.4f}")

        if epoch > MODEL_CFG['MIN_EPOCHS'] \
                and early_stopper.early_stop(val_acc):
            break
    return training_stats

In [37]:
model = get_model(data)
stats = train_and_eval_loop(
    model, train_loader, val_loader, test_loader)

Epoch: 010
Train loss: 0.0000
Train accuracy: 0.9999
Validation accuracy: 0.9998
Epoch: 020
Train loss: 0.0000
Train accuracy: 0.9999
Validation accuracy: 0.9998


KeyboardInterrupt: 

In [32]:
print(stats)

{'train_acc': [0.9007774140752864, 0.935556464811784, 0.25572831423895254, 0.9392389525368249, 0.25736497545008186, 0.938011456628478, 0.9404664484451718, 0.2569558101472995, 0.2561374795417349, 0.9376022913256956, 0.2569558101472995, 0.9400572831423896, 0.25736497545008186, 0.9384206219312602, 0.9449672667757774, 0.25777414075286414, 0.9433306055646481, 0.9373977086743044, 0.2571603927986907, 0.9435351882160393, 0.2569558101472995, 0.9437397708674304, 0.2579787234042553, 0.978518821603928, 0.2608428805237316, 0.9478314238952537, 0.948240589198036, 0.8788870703764321, 0.2839607201309329, 0.9463993453355155, 0.9445581014729951, 0.9652209492635024, 0.26268412438625205, 0.26063829787234044, 0.9453764320785597, 0.2624795417348609, 0.983633387888707, 0.25941080196399346, 0.9449672667757774, 0.2590016366612111, 0.993657937806874, 0.2600245499181669, 0.9883387888707038, 0.9838379705400983, 0.2590016366612111, 0.9429214402618658, 0.2579787234042553, 0.9468085106382979, 0.2608428805237316, 0.95

In [21]:
model = get_model(data)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# weight = torch.bincount(train_data['sotu', 'has_host', 'taxon'].edge_label)
# weight = weight.max() / weight

weight = None

def weighted_mse_loss(pred, target, weight=None):
    weight = 1. if weight is None else weight[target].to(pred.dtype)
    return (weight * (pred - target.to(pred.dtype)).pow(2)).mean()

def train():
    model.train()
    optimizer.zero_grad()
    pred = model(
        train_data.x_dict,
        train_data.edge_index_dict,
        train_data['sotu', 'has_host', 'taxon'].edge_label_index
    )
    target = train_data['sotu', 'has_host', 'taxon'].edge_label
    # loss = weighted_mse_loss(pred, target, weight)
    # loss = F.binary_cross_entropy_with_logits(pred, target)
    loss = F.cross_entropy(pred, target)
    loss.backward()
    optimizer.step()
    return loss 


@torch.no_grad()
def test(data):
    model.eval()
    pred = model(
        data.x_dict,
        data.edge_index_dict,
        data['sotu', 'has_host', 'taxon'].edge_label_index
    )
    pred = pred.clamp(min=0, max=5)
    target = data['sotu', 'has_host', 'taxon'].edge_label.float()
    # rmse = F.mse_loss(pred, target).sqrt()
    auc_pr = average_precision_score(target, pred)
    return float(auc_pr)


for epoch in range(1, 301):
    loss = train()
    train_rmse = test(train_data)
    val_rmse = test(val_data)
    test_rmse = test(test_data)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_rmse:.4f}, '
          f'Val: {val_rmse:.4f}, Test: {test_rmse:.4f}')

Epoch: 001, Loss: 62423.2227, Train: 1.0000, Val: 0.7471, Test: 0.7601
Epoch: 002, Loss: 3526901.5000, Train: 1.0000, Val: 0.8859, Test: 0.8855
Epoch: 003, Loss: 1013483.3125, Train: 1.0000, Val: 0.4918, Test: 0.4982
Epoch: 004, Loss: 895018.1250, Train: 1.0000, Val: 0.8454, Test: 0.8556
Epoch: 005, Loss: 976447.0625, Train: 1.0000, Val: 0.7759, Test: 0.7992
Epoch: 006, Loss: 787798.8750, Train: 1.0000, Val: 0.7489, Test: 0.7837
Epoch: 007, Loss: 451444.2188, Train: 1.0000, Val: 0.7469, Test: 0.7364
Epoch: 008, Loss: 165784.0625, Train: 1.0000, Val: 0.4272, Test: 0.4235
Epoch: 009, Loss: 258998.6406, Train: 1.0000, Val: 0.7814, Test: 0.7692
Epoch: 010, Loss: 137901.4375, Train: 1.0000, Val: 0.8090, Test: 0.7883
Epoch: 011, Loss: 175027.7344, Train: 1.0000, Val: 0.8116, Test: 0.7865


KeyboardInterrupt: 