In [None]:
import os

import torch

import pytorch_lightning as pl

import clip_graph as cg

In [None]:
# Work from the repository checkout so the relative 'configs/...' and
# 'data/...' paths used in the cells below resolve.
repo_root = os.path.expanduser('~/github/congrat')
os.chdir(repo_root)

In [None]:
# Fixed seed handed to pytorch_lightning's global seeding helper.
SEED = 2969591811
pl.seed_everything(SEED)

# Load data

In [None]:
# Evaluation dataset specs, keyed by dataset name. For each one:
#   'key'  - name of the graph_data attribute exported by the SVD cell below
#   'path' - YAML config the datamodule is built from
# Causal or masked doesn't matter, we're not using text.
datamodules = {
    f'{dataset}{suffix}': {
        'key': key,
        'path': f'configs/eval-datasets/{dataset}/causal{suffix}.yaml',
    }

    # Outer loop: all undirected variants first, then the directed ones.
    for suffix in ('', '-directed')
    for dataset, key in (
        ('twitter-small', 'tsvd'),
        ('pubmed', 'x'),
        ('trex', 'x'),
    )
}

In [None]:
# Build the datamodule for every spec and store it under 'dm', next to the
# spec's original 'key' and 'path'. The name is only rebound once every
# config has loaded, so a failure part-way leaves the specs untouched.
loaded = {}

for name, spec in datamodules.items():
    loaded[name] = {
        'key': spec['key'],
        'path': spec['path'],
        'dm': cg.utils.datamodule_from_yaml(spec['path'])['dm'],
    }

datamodules = loaded

# SVD node embeddings

In [None]:
# For every dataset, write the graph_data attribute named by its 'key' to
# data/svd-initial-vectors/<name>/{full,train,val,test}.pt.
for name, obj in datamodules.items():
    out_dir = f'data/svd-initial-vectors/{name}/'
    os.makedirs(out_dir, exist_ok=True)

    dm = obj['dm']
    attr = obj['key']

    for split in ('full', 'train', 'val', 'test'):
        # 'full' reads the datamodule's own dataset; every other split goes
        # through the `.dataset` of the matching `<split>_dataset`.
        if split == 'full':
            dataset = dm.dataset
        else:
            dataset = getattr(dm, f'{split}_dataset').dataset

        torch.save(
            getattr(dataset.graph_data, attr),
            f'{out_dir}{split}.pt',
        )

# Splits

In [None]:
# For every dataset and every split, write two node-id files under
# data/split-node-ids/<name>/:
#   <split>-text.pt   <- dataset.text_node_ids
#   <split>-graph.pt  <- dataset.graph_data.node_ids
for name, obj in datamodules.items():
    split_dir = f'data/split-node-ids/{name}/'
    os.makedirs(split_dir, exist_ok=True)

    dm = obj['dm']

    # Same order as the files were originally written: all data, then
    # train, val and test.
    for split in ('full', 'train', 'val', 'test'):
        # 'full' reads the datamodule's own dataset; every other split goes
        # through the `.dataset` of the matching `<split>_dataset`.
        if split == 'full':
            dataset = dm.dataset
        else:
            dataset = getattr(dm, f'{split}_dataset').dataset

        torch.save(
            dataset.text_node_ids,
            f'{split_dir}{split}-text.pt'
        )

        torch.save(
            dataset.graph_data.node_ids,
            f'{split_dir}{split}-graph.pt'
        )