In [None]:
import os

import pandas as pd
import pytorch_lightning as pl

import clip_graph as cg

In [None]:
# Run from the repository root: every config and data path used below is relative.
# The checkout location can be overridden with the CONGRAT_ROOT environment
# variable instead of editing this hard-coded default.
os.chdir(os.path.expanduser(os.environ.get('CONGRAT_ROOT', '~/github/congrat')))

In [None]:
# Turn off HuggingFace tokenizers' internal parallelism for this process
# (presumably to avoid its fork warning once DataLoader workers start — confirm).
os.environ.update(TOKENIZERS_PARALLELISM='false')

In [None]:
# Seed the global RNGs through Lightning so the notebook's random draws repeat across runs.
SEED = 2969591811
pl.seed_everything(SEED)

In [None]:
# Per-dataset size statistics, keyed by dataset name; each value is a dict with
# 'num_edges', 'num_texts' and 'num_nodes'. Filled by the causal cells below and
# written to CSV at the end.
stats = dict()

# Twitter

## Causal

In [None]:
dm = cg.utils.datamodule_from_yaml('configs/eval-datasets/twitter-small/causal.yaml')['dm']

# Size statistics are the same for the masked and causal configs, so they are
# collected only once, here from the causal datamodule.
edge_count = dm._get_edgelist().shape[0]
text_count = len(dm.dataset.text)
node_count = dm.dataset.graph_data.node_ids.shape[0]
stats['twitter-small'] = dict(
    num_edges=edge_count,
    num_texts=text_count,
    num_nodes=node_count,
)

In [None]:
# Display the first validation batch for a quick visual check.
val_batch = next(iter(dm.val_dataloader()))
val_batch

In [None]:
# Drop the datamodule reference before the next config is loaded.
del dm

## Masked

In [None]:
# datamodule_from_yaml returns a mapping; only its 'dm' entry is needed here.
config_path = 'configs/eval-datasets/twitter-small/masked.yaml'
dm = cg.utils.datamodule_from_yaml(config_path)['dm']

In [None]:
# Print the first validation batch for a quick visual check.
first_val_batch = next(iter(dm.val_dataloader()))
print(first_val_batch)

In [None]:
# Drop the datamodule reference before the next config is loaded.
del dm

# Pubmed

## Causal

In [None]:
dm = cg.utils.datamodule_from_yaml('configs/eval-datasets/pubmed/causal.yaml')['dm']

# Size statistics are the same for the masked and causal configs, so they are
# collected only once, here from the causal datamodule.
edge_count = dm._get_edgelist().shape[0]
text_count = len(dm.dataset.text)
node_count = dm.dataset.graph_data.node_ids.shape[0]
stats['pubmed'] = dict(
    num_edges=edge_count,
    num_texts=text_count,
    num_nodes=node_count,
)

In [None]:
# Print the first validation batch for a quick visual check.
first_val_batch = next(iter(dm.val_dataloader()))
print(first_val_batch)

In [None]:
# Drop the datamodule reference before the next config is loaded.
del dm

## Masked

In [None]:
# datamodule_from_yaml returns a mapping; only its 'dm' entry is needed here.
config_path = 'configs/eval-datasets/pubmed/masked.yaml'
dm = cg.utils.datamodule_from_yaml(config_path)['dm']

In [None]:
# Print the first validation batch for a quick visual check.
first_val_batch = next(iter(dm.val_dataloader()))
print(first_val_batch)

In [None]:
# Drop the datamodule reference before the next config is loaded.
del dm

# TRex

## Causal

In [None]:
dm = cg.utils.datamodule_from_yaml('configs/eval-datasets/trex/causal.yaml')['dm']

# Size statistics are the same for the masked and causal configs, so they are
# collected only once, here from the causal datamodule.
edge_count = dm._get_edgelist().shape[0]
text_count = len(dm.dataset.text)
node_count = dm.dataset.graph_data.node_ids.shape[0]
stats['trex'] = dict(
    num_edges=edge_count,
    num_texts=text_count,
    num_nodes=node_count,
)

In [None]:
# Print the first validation batch for a quick visual check.
first_val_batch = next(iter(dm.val_dataloader()))
print(first_val_batch)

In [None]:
# Drop the datamodule reference before the next config is loaded.
del dm

## Masked

In [None]:
# datamodule_from_yaml returns a mapping; only its 'dm' entry is needed here.
config_path = 'configs/eval-datasets/trex/masked.yaml'
dm = cg.utils.datamodule_from_yaml(config_path)['dm']

In [None]:
# Print the first validation batch for a quick visual check.
first_val_batch = next(iter(dm.val_dataloader()))
print(first_val_batch)

In [None]:
# Drop the datamodule reference before the next config is loaded.
del dm

# Write out dataset stats

These numbers are written to `data/dataset-stats.csv` so they can be copied into the LaTeX document.

In [None]:
# Flatten `stats` ({dataset: {statistic: value}}) into a two-column table with
# one row per '<dataset>_<statistic>' (e.g. 'pubmed_num_edges') and its value,
# in the same row order as the nested dicts were filled.
dataset_stats = pd.DataFrame(
    [
        (f'{dataset}_{stat_name}', stat_value)
        for dataset, dataset_values in stats.items()
        for stat_name, stat_value in dataset_values.items()
    ],
    columns=['datavar', 'value'],
)

# Create the output directory if needed so the write doesn't fail on a fresh checkout.
os.makedirs('data', exist_ok=True)
dataset_stats.to_csv('data/dataset-stats.csv', index=False)

In [None]:
# Echo the raw CSV so it can be checked (and copied) inline.
!cat data/dataset-stats.csv