Simple pipeline and playground for analyzing temporal graphs using DMGI (Deep Multiplex Graph Infomax)

In [1]:
import sys
import logging
import torch
sys.path.append("../")
from temporal_graphs.src.data.mock_data import get_temporal_mock_graph
from temporal_graphs.src.temporal_graph.build_graph import create_torch_temporal_graph_from_df
from temporal_graphs.src.models.dmgi_trainer import DMGITrainer

# Configure the root logger at INFO level so the trainer's per-epoch loss
# messages (logged via the `logging` module) are shown in the notebook.
logging.basicConfig(level=logging.INFO)

# File-name prefix shared by the saved model and embedding artifacts.
prefix = "dmgi"

# Data

In [2]:
# Build the mock temporal data and convert it into a heterogeneous graph
# (a HeteroData object, as the next cell's output shows).
# NOTE(review): save=False presumably skips writing the graph under path_prefix — confirm.
graph = create_torch_temporal_graph_from_df(
    get_temporal_mock_graph(),
    save=False,
    path_prefix="../",
)

In [3]:
# Display the HeteroData summary: node feature shape and edge count per relation type.
graph

HeteroData(
  node={ x=[6, 4] },
  (node, in_date_group_0, node)={ edge_index=[2, 29] },
  (node, in_date_group_1, node)={ edge_index=[2, 15] },
  (node, in_date_group_2, node)={ edge_index=[2, 15] },
  (node, in_date_group_3, node)={ edge_index=[2, 17] },
  (node, in_date_group_4, node)={ edge_index=[2, 24] }
)

# Train

In [4]:
# Device selection. GPU use is switched off so training runs on CPU (as in the
# recorded run below). Set `use_gpu = True` to use CUDA when it is available.
use_gpu = False
device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [5]:
# Create the DMGI trainer over the graph built above.
# NOTE(review): out_channels=10 is presumably the embedding dimension — confirm in DMGITrainer.
model = DMGITrainer(
    data=graph,
    out_channels=10,
    conv_name="GCNConv",
    normalize_features=False,
    device=device,
)

In [6]:
# Train for 100 epochs; the trainer logs the loss every 10 epochs (see output below).
model.train(
    epochs=100,
    learning_rate=5e-3,
    weight_decay=5e-5,
    print_every_n_epoch=10,
)

INFO:temporal_graphs.src.models.dmgi_trainer:Epoch: 001, Loss: 7.238
INFO:temporal_graphs.src.models.dmgi_trainer:Epoch: 011, Loss: 6.914
INFO:temporal_graphs.src.models.dmgi_trainer:Epoch: 021, Loss: 6.988
INFO:temporal_graphs.src.models.dmgi_trainer:Epoch: 031, Loss: 6.806
INFO:temporal_graphs.src.models.dmgi_trainer:Epoch: 041, Loss: 7.028
INFO:temporal_graphs.src.models.dmgi_trainer:Epoch: 051, Loss: 6.883
INFO:temporal_graphs.src.models.dmgi_trainer:Epoch: 061, Loss: 6.820
INFO:temporal_graphs.src.models.dmgi_trainer:Epoch: 071, Loss: 6.844
INFO:temporal_graphs.src.models.dmgi_trainer:Epoch: 081, Loss: 6.854
INFO:temporal_graphs.src.models.dmgi_trainer:Epoch: 091, Loss: 6.869


In [7]:
# Persist the trained model under ../temporal_graphs/models/, named after `prefix`.
model.save(
    path=f"../temporal_graphs/models/{prefix}_model.pt",
)

In [8]:
# Extract the learned node embeddings from the trained model.
# Later cells treat the result as a torch tensor (torch.save, .cpu(), .detach()).
embeddings = model.get_embeddings()

In [9]:
# Save the raw embedding tensor in the same directory as the model checkpoint.
torch.save(
    obj=embeddings,
    f=f"../temporal_graphs/models/{prefix}_embeddings.pt",
)

In [10]:
# Convert the embeddings to a float64 NumPy array for downstream (non-torch) analysis.
# Detach first so that, if the tensor requires grad, the device copy made by
# .cpu() is not recorded in the autograd graph.
embeddings = embeddings.detach().cpu().numpy().astype("double")