# Import Packages

In [1]:
import os
import yaml

import numpy as np
import pandas as pd

import torch
import torch.optim as optim
import torch.utils.data as data

import dgl
import dgl.nn as dglnn
import dgl.function as fn
import dgl.nn.functional as F

from tqdm import tqdm

In [2]:
# Cap the number of CPU threads PyTorch uses for its own operations at 8.
torch.set_num_threads(8)

In [3]:
# Root directory holding one sub-directory per dataset; the relative path is
# resolved against the notebook's working directory.
data_root = '../datasets'

In [4]:
def save_config(data, config_path):
    """Serialize `data` as YAML into the file at `config_path`.

    The file is opened in text write mode, so any existing content is
    replaced.

    Args:
        data: Object to serialize with `yaml.dump`.
        config_path: Path of the YAML file to write.

    Note:
        `load_config` reads files back with `yaml.safe_load`, so `data`
        should presumably consist of plain YAML-representable types if the
        file is meant to round-trip through it.
    """
    with open(config_path, 'w') as f:
        # With a stream argument yaml.dump writes directly to the file, so
        # there is nothing useful to capture from its return value.
        yaml.dump(data, f)

def load_config(config_path):
    """Read the YAML file at `config_path` and return its parsed content.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        Whatever `yaml.safe_load` produces for the file's content.
    """
    with open(config_path, 'r') as stream:
        return yaml.safe_load(stream)

# Functions for Node Embeddings

In [5]:
def get_node_embeddings(
    graph,
    embedding_dim,
    device,
    walk_length=40,
    window_size=5,
    negative_size=1,
    batch_size=256,
    num_workers=8,
    learning_rate=1e-3,
    num_epochs=16,
):
    """Train a DeepWalk model on `graph` and return its node embedding table.

    Random walks are sampled on a CPU copy of the graph by the data loader
    (via `model.sample`), while the model itself is optimized on `device`.

    Args:
        graph: DGL graph to embed.
        embedding_dim: Size of each node embedding vector.
        device: Torch device on which the model is trained.
        walk_length: Length of each sampled random walk.
        window_size: Context window size passed to DeepWalk.
        negative_size: Number of negative samples passed to DeepWalk.
        batch_size: Number of start nodes per training batch.
        num_workers: Number of data loader worker processes.
        learning_rate: Learning rate of the Adam optimizer.
        num_epochs: Number of passes over the node set.

    Returns:
        The detached `node_embed.weight` tensor of the trained model, located
        on `device` (presumably one row per node with `embedding_dim`
        columns -- confirm against the DGL DeepWalk implementation).
    """
    model = dglnn.DeepWalk(
        g=graph.cpu(),
        emb_dim=embedding_dim,
        walk_length=walk_length,
        window_size=window_size,
        negative_size=negative_size,
        fast_neg=False,
        # Dense gradients, so the regular Adam optimizer below can be used.
        sparse=False,
    )

    num_nodes = graph.num_nodes()
    dataloader = data.DataLoader(
        dataset=torch.arange(num_nodes),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        collate_fn=model.sample,
        # Dropping the trailing partial batch is only safe when at least one
        # full batch exists; otherwise the loader would yield nothing and the
        # function would silently return untrained embeddings.
        drop_last=num_nodes >= batch_size,
    )
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()
    model = model.to(device)

    for epoch in range(1, num_epochs + 1):
        pbar = tqdm(dataloader, leave=False)
        pbar.set_description(f'epoch = {epoch}')
        for batch in pbar:
            # Walks are sampled on the CPU by the loader; move them next to
            # the model before computing the loss.
            batch = batch.to(device)
            loss = model(batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            pbar.set_postfix_str(f'loss = {loss.item():.4f}')

    node_embeds = model.node_embed.weight.detach()
    return node_embeds

# Prepare Graph Augmentation

In [6]:
# Datasets to process. Uncomment the entries that should be (re)computed.
# NOTE: with every entry commented out, this list is empty and the loop that
# iterates over `dataset_names` does nothing.
dataset_names = [
    # 'tolokers-tab',
    # 'questions-tab',
    # 'city-reviews',
    # 'browser-games',
    # 'hm-categories',
    # 'web-fraud',
    # 'city-roads-M',
    # 'city-roads-L',
    # 'avazu-devices',
    # 'hm-prices',
    # 'web-traffic'
]

In [7]:
# Embedding size to use for each dataset, keyed by dataset name.
dataset_name_to_embedding_dim = dict([
    ('tolokers-tab', 64),
    ('questions-tab', 64),
    ('city-reviews', 128),
    ('browser-games', 128),
    ('hm-categories', 96),
    ('web-fraud', 256),
    ('city-roads-M', 64),
    ('city-roads-L', 128),
    ('avazu-devices', 96),
    ('hm-prices', 96),
    ('web-traffic', 256),
])

In [8]:
# Index of the preferred CUDA device.
gpu_index = 1
# Use the requested GPU when it actually exists; otherwise fall back to the
# CPU so the notebook still runs on machines with fewer (or no) CUDA devices
# instead of failing later when tensors are moved to `device`.
if torch.cuda.is_available() and gpu_index < torch.cuda.device_count():
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')

In [None]:
# For every selected dataset: build the graph from its edge list, train
# DeepWalk embeddings, and store the embeddings plus a small metadata file
# inside the dataset directory.
for dataset_name in dataset_names:
    print(dataset_name)
    dataset_path = f"{data_root}/{dataset_name}"

    # Each column of the CSV becomes one tensor of node ids.
    # NOTE(review): assumes edgelist.csv has exactly two columns
    # (source, target) -- confirm against the dataset format.
    edge_list = pd.read_csv(f"{dataset_path}/edgelist.csv").values
    graph = dgl.graph(tuple(torch.tensor(indices) for indices in edge_list.T))
    # Convert to a bidirected graph before sampling random walks.
    graph = dgl.to_bidirected(graph)

    embedding_dim = dataset_name_to_embedding_dim[dataset_name]
    node_embeds = get_node_embeddings(graph, embedding_dim, device)
    
    # Stored as a compressed .npz archive under the key `node_embeds`.
    node_embeddings_path = f'{dataset_path}/node_embeddings.npz'
    np.savez_compressed(node_embeddings_path, node_embeds=node_embeds.cpu().numpy())

    # Record the embedding size next to the embeddings.
    node_embeddings_info = {'embedding_dim': embedding_dim}
    node_embeddings_info_path = f'{dataset_path}/node_embeddings_info.yaml'
    save_config(node_embeddings_info, node_embeddings_info_path)