In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
import os
from itertools import product
import wandb
import numpy as np
import pandas as pd

In [None]:
# Weights & Biases location of the experiment runs, used as "<entity>/<project>".
wandb_user, wandb_project = 'sisaman', 'GAP'

In [None]:
class RunFactory:
    """Build ``train.py`` command lines for hyperparameter grids.

    When ``check_existing`` is true, the configs of all runs already logged to
    the W&B project ``entity/project`` are fetched once at construction time.
    :meth:`build` then skips every grid point that matches one of them.

    Args:
        entity: W&B entity (user or team) that owns the project.
        project: W&B project name; also appended to every command as ``--project``.
        check_existing: whether to query W&B and skip already-logged configs.
    """

    def __init__(self, entity, project, check_existing=True):
        self.project = project
        self.check_existing = check_existing
        # Empty by default so find_runs() is well-defined even without W&B.
        self.runs_df = pd.DataFrame()

        if check_existing:
            api = wandb.Api()
            runs = api.runs(f"{entity}/{project}", per_page=2000)
            # Keep only config keys that do not start with '_'.
            self.runs_df = pd.DataFrame(
                [{k: v for k, v in run.config.items() if not k.startswith('_')} for run in runs]
            )
            # Cast epsilon to float so it compares numerically with the swept
            # values (ints, floats, np.inf). Guarded because a project with no
            # runs (or none logging epsilon) has no such column.
            if 'epsilon' in self.runs_df.columns:
                self.runs_df['epsilon'] = self.runs_df['epsilon'].astype(float)

    def build(self, subcommand, **params):
        """Return one ``python train.py`` command per point of the grid in ``params``.

        A list or tuple value is swept; any other value is used as a single
        fixed value. The Cartesian product is taken in keyword order, so the
        keyword order also fixes the order of the commands and of their options.
        When ``check_existing`` is set, grid points that match an existing run
        (see :meth:`find_runs`) are skipped. Every command gets
        ``--project <self.project>`` appended.

        Args:
            subcommand: ``train.py`` subcommand, e.g. ``'gap'`` or ``'sage'``.
            **params: option name -> value or list/tuple of values.

        Returns:
            list[str]: command strings with whitespace collapsed to single spaces.
        """
        grid = {
            key: value if isinstance(value, (list, tuple)) else (value,)
            for key, value in params.items()
        }

        cmd_list = []
        for config in self.product_dict(grid):
            if self.check_existing and len(self.find_runs(config)) > 0:
                continue  # already logged to W&B
            config = {**config, 'project': self.project}
            options = ' '.join(f' --{param} {value} ' for param, value in config.items())
            # Collapse all runs of whitespace into single spaces.
            cmd_list.append(' '.join(f'python train.py {subcommand} {options}'.split()))

        return cmd_list

    def find_runs(self, config):
        """Return the rows of ``runs_df`` whose values equal every item of ``config``.

        A key that is not a column of ``runs_df`` (an option no logged run has
        recorded) cannot match any run, just as a NaN cell never matches. In
        that case an empty frame is returned instead of raising ``KeyError``.
        """
        if any(key not in self.runs_df.columns for key in config):
            return self.runs_df.iloc[0:0]
        mask = pd.Series(True, index=self.runs_df.index)
        for key, value in config.items():
            mask &= self.runs_df[key] == value
        return self.runs_df.loc[mask]

    @staticmethod
    def product_dict(params):
        """Yield one dict per element of the Cartesian product of ``params``' values."""
        keys = params.keys()
        vals = params.values()
        for instance in product(*vals):
            yield dict(zip(keys, instance))


# Experiments

In [None]:
# Fetches the configs of existing W&B runs up front so already-run configs are skipped.
run_factory = RunFactory(wandb_user, wandb_project, check_existing=True)

# DEFAULT PARAMS
# Lists are swept by RunFactory.build; scalars are fixed. Top-level 'edge' /
# 'node' keys match the dp_level the corresponding jobs are launched with.
dataset = ['facebook', 'reddit', 'amazon']
# Privacy budgets; each 'extended' grid prepends smaller values to 'standard'.
epsilon = {
    'edge': {
        'standard': list(range(1, 10, 2)),                          # 1, 3, 5, 7, 9
        'extended': [0.2, 0.4, 0.6, 0.8] + list(range(1, 10, 2)),
    },
    'node': {
        'standard': list(range(5, 30, 5)),                          # 5, 10, ..., 25
        'extended': [1, 2, 3, 4] + list(range(5, 30, 5)),
    },
}
hops = [1, 2, 3, 4, 5]
# A single value for edge-level jobs; per-dataset grids for node-level jobs.
# NOTE(review): -1 presumably means "no degree bound" in train.py — confirm.
max_degree = {
    'edge': -1,
    'node': {
        'facebook': {
            'standard': [10, 20, 50, 100, 200],
            'extended': [10, 20, 50, 100, 200],
        },
        'reddit': {
            'standard': [50, 100, 200, 300, 400],
            'extended': [50, 100, 200, 300, 400],
        },
        'amazon': {
            'standard': [10, 20, 50, 100, 200],
            'extended': [10, 20, 50, 100, 200],
        },
    },
}
hidden_dim = [16]
encoder_layers = 2
pre_layers = 1
post_layers = 1
combine = 'cat'
activation = 'selu'
dropout = 0
batch_norm = True
optimizer = 'adam'
# Plain scalars: a trailing comma here would silently make a 1-tuple.
learning_rate = 0.01
weight_decay = 0
pre_epochs = {
    'edge': 100,
    'node': 10,
}
epochs = {
    'edge': 100,
    'node': 10,
}
# NOTE(review): -1 for edge-level jobs presumably means full-batch — confirm in train.py.
batch_size = {
    'edge': -1,
    'node': {
        'facebook': 256,
        'reddit':   2048,
        'amazon':   4096,
    },
}
max_grad_norm = 1
repeats = 10
logger = 'wandb'

cmd = []

# Option order on the generated command lines, per train.py subcommand.
# RunFactory.build expands the grid and formats options in keyword order,
# so this order fixes both the job order and the text of every line.
GAP_OPTION_ORDER = (
    'name', 'dataset', 'dp_level', 'epsilon', 'hops', 'max_degree',
    'hidden_dim', 'encoder_layers', 'pre_layers', 'post_layers', 'combine',
    'activation', 'dropout', 'batch_norm', 'optimizer', 'learning_rate',
    'weight_decay', 'pre_epochs', 'epochs', 'batch_size',
    'max_grad_norm', 'repeats', 'logger',
)
SAGE_OPTION_ORDER = (
    'name', 'dataset', 'dp_level', 'epsilon', 'max_degree',
    'hidden_dim', 'encoder_layers', 'mp_layers', 'post_layers',
    'activation', 'dropout', 'batch_norm', 'optimizer', 'learning_rate',
    'weight_decay', 'epochs', 'batch_size',
    'max_grad_norm', 'repeats', 'logger',
)

# Defaults from the config above; each experiment overrides only what differs.
SHARED_OPTIONS = dict(
    hops=hops,
    hidden_dim=hidden_dim,
    encoder_layers=encoder_layers,
    pre_layers=pre_layers,
    post_layers=post_layers,
    combine=combine,
    activation=activation,
    dropout=dropout,
    batch_norm=batch_norm,
    optimizer=optimizer,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    max_grad_norm=max_grad_norm,
    repeats=repeats,
    logger=logger,
)


def new_jobs(subcommand, option_order, **overrides):
    """Return the `train.py <subcommand>` commands for one experiment.

    `overrides` are layered on top of SHARED_OPTIONS and the result is handed
    to `run_factory.build` in `option_order`. Shared defaults the subcommand
    does not take (e.g. `hops` for `sage`) are left out.

    Raises:
        ValueError: if an override is not an option of this subcommand.
    """
    unknown = set(overrides) - set(option_order)
    if unknown:
        raise ValueError(f'unknown {subcommand} options: {sorted(unknown)}')
    merged = {**SHARED_OPTIONS, **overrides}
    return run_factory.build(subcommand=subcommand, **{key: merged[key] for key in option_order})


# ---- Edge-level DP: all datasets swept in one grid ----
edge_setup = dict(
    dataset=dataset,
    dp_level='edge',
    max_degree=max_degree['edge'],
    epochs=epochs['edge'],
    batch_size=batch_size['edge'],
)
cmd += new_jobs('gap', GAP_OPTION_ORDER, name='GAP-INF', epsilon=np.inf,
                pre_epochs=pre_epochs['edge'], **edge_setup)
cmd += new_jobs('gap', GAP_OPTION_ORDER, name='GAP-EDP', epsilon=epsilon['edge']['extended'],
                pre_epochs=pre_epochs['edge'], **edge_setup)
# GAP-EDP W/O EM: encoder_layers=0 and no pre-training epochs.
cmd += new_jobs('gap', GAP_OPTION_ORDER, name='GAP-EDP', epsilon=epsilon['edge']['standard'],
                encoder_layers=0, pre_epochs=0, **edge_setup)
# MLP baseline: hops=0, with the encoder depth moved into pre_layers.
cmd += new_jobs('gap', GAP_OPTION_ORDER, name='MLP', epsilon=0, hops=0,
                encoder_layers=0, pre_layers=encoder_layers, pre_epochs=0, **edge_setup)
cmd += new_jobs('sage', SAGE_OPTION_ORDER, name='SAGE-EDP', epsilon=epsilon['edge']['extended'],
                mp_layers=hops, **edge_setup)
cmd += new_jobs('sage', SAGE_OPTION_ORDER, name='SAGE-INF', epsilon=np.inf,
                mp_layers=hops, **edge_setup)

# ---- Node-level DP: per-dataset degree bounds and batch sizes ----
for dataset_name in dataset:
    node_setup = dict(
        dataset=dataset_name,
        dp_level='node',
        batch_norm=False,
        epochs=epochs['node'],
        batch_size=batch_size['node'][dataset_name],
    )
    node_max_degree = max_degree['node'][dataset_name]

    cmd += new_jobs('gap', GAP_OPTION_ORDER, name='GAP-NDP', epsilon=epsilon['node']['extended'],
                    max_degree=node_max_degree['extended'], pre_epochs=pre_epochs['node'],
                    **node_setup)
    # GAP-NDP W/O EM: encoder_layers=0 and no pre-training epochs.
    cmd += new_jobs('gap', GAP_OPTION_ORDER, name='GAP-NDP', epsilon=epsilon['node']['standard'],
                    max_degree=node_max_degree['standard'], encoder_layers=0, pre_epochs=0,
                    **node_setup)
    cmd += new_jobs('sage', SAGE_OPTION_ORDER, name='SAGE-NDP', epsilon=epsilon['node']['extended'],
                    max_degree=node_max_degree['standard'], mp_layers=1, **node_setup)
    cmd += new_jobs('gap', GAP_OPTION_ORDER, name='MLP-DP', epsilon=epsilon['node']['extended'],
                    hops=0, max_degree=-1, encoder_layers=0, pre_layers=encoder_layers,
                    pre_epochs=0, **node_setup)


# Write the generated commands, one per line, to the job script.
filename = 'jobs/experiments.sh'
# Create the script's own parent directory, so `filename` and the created
# directory cannot drift apart ('.' when filename has no directory part).
os.makedirs(os.path.dirname(filename) or '.', exist_ok=True)
with open(filename, 'w', encoding='utf-8') as file:
    for item in cmd:
        print(item, file=file)

print('new:        ', len(cmd))