# Starting

In [1]:
import sys
sys.path.append('../')

In [53]:
from typing import Any, Optional, List, Dict, Tuple, Union
import torch
from torch import Tensor
from torch_geometric.loader import DataLoader
import torch.nn.functional as F
import numpy as np
import tqdm

In [3]:
from helpers.utils import set_seed
from helpers.encoders import PosEncoder
from helpers.metrics import LossesAndMetrics
from helpers.model import ModelType
from helpers.classes import GumbelArgs, EnvArgs, ActionNetArgs, ActivationType
from helpers.classes import Pool
from helpers.dataset_classes.dataset import DatasetBySplit

from models.CoGNN import CoGNN




## Configurations

In [4]:
seed = 0
dataset_string = 'roman_empire'
act_dim = 16 
env_dim = 64
act_num_layers = 1
env_num_layers = 3
act_model_type = 'MEAN_GNN'
env_model_type = 'MEAN_GNN'
pos_enc = PosEncoder.NONE

# Set the seed
set_seed(seed=seed)

In [5]:
from helpers.dataset_classes.dataset import DataSet

## load datasets

In [6]:
pre_dataset = DataSet.from_string(dataset_string)

# download the dataset at this step
dataset = pre_dataset.load(seed=seed, pos_enc=pos_enc)
dataset

[Data(x=[22662, 300], edge_index=[2, 65854], y=[22662], train_mask=[22662, 10], val_mask=[22662, 10], test_mask=[22662, 10])]

### Things relate to datasets

In [7]:
fold = None

In [8]:
metric_type = pre_dataset.get_metric_type()
decimal = pre_dataset.num_after_decimal()
task_loss = metric_type.get_task_loss()
out_dim = metric_type.get_out_dim(dataset=dataset)
gin_mlp_func = pre_dataset.gin_mlp_func()
env_act_type = pre_dataset.env_activation_type()
folds = pre_dataset.get_folds(fold)
metric_type, decimal, task_loss, out_dim, gin_mlp_func, env_act_type, folds

(<MetricType.ACCURACY: 1>,
 2,
 CrossEntropyLoss(),
 18,
 <function helpers.dataset_classes.dataset.DataSet.gin_mlp_func.<locals>.mlp_func(in_channels: int, out_channels: int, bias: bool)>,
 <ActivationType.GELU: 2>,
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [9]:
device = "cpu"
learn_temp = False
temp_model_type = ModelType.LIN
tau0 = 0.5
temp = 0.01 # temperature
gin_mlp_func=gin_mlp_func

gumbel_args = GumbelArgs(
    learn_temp=learn_temp, temp_model_type=temp_model_type, tau0=tau0, temp=temp, gin_mlp_func=gin_mlp_func)
gumbel_args

GumbelArgs(learn_temp=False, temp_model_type=<ModelType.LIN: 3>, tau0=0.5, temp=0.01, gin_mlp_func=<function DataSet.gin_mlp_func.<locals>.mlp_func at 0x1041254c0>)

In [10]:
max_epochs = 300
batch_size = 32
lr = 1e-3
dropout = 0.2
step_size = None
gamma = None
num_warmup_epochs = None


### Environment net arguements

In [11]:
env_model_type = ModelType.MEAN_GNN
env_num_layers = env_num_layers
env_dim = 64
skip = False
batch_norm = False
layer_norm = False
dec_num_layers = 1
pos_enc = PosEncoder.NONE

env_args = EnvArgs(
    model_type=env_model_type, num_layers=env_num_layers,
    env_dim=env_dim, skip=skip, batch_norm=batch_norm,
    layer_norm=layer_norm, dec_num_layers=dec_num_layers,
    pos_enc=pos_enc, metric_type=metric_type, in_dim=dataset[0].x.shape[1],
    out_dim=out_dim, gin_mlp_func=gin_mlp_func,
    dropout=dropout, act_type=env_act_type, dataset_encoders=pre_dataset.get_dataset_encoders())
env_args

EnvArgs(model_type=<ModelType.MEAN_GNN: 5>, num_layers=3, env_dim=64, layer_norm=False, skip=False, batch_norm=False, dropout=0.2, act_type=<ActivationType.GELU: 2>, dec_num_layers=1, pos_enc=<PosEncoder.NONE: 1>, dataset_encoders=<DataSetEncoders.NONE: 1>, metric_type=<MetricType.ACCURACY: 1>, in_dim=300, out_dim=18, gin_mlp_func=<function DataSet.gin_mlp_func.<locals>.mlp_func at 0x1041254c0>)

### AactionNet arguments

In [12]:
action_args = ActionNetArgs(
    model_type=ModelType.MEAN_GNN,
    num_layers=act_num_layers,
    hidden_dim=act_dim,
    dropout=dropout,
    act_type=ActivationType.RELU,
    env_dim=env_dim,
    gin_mlp_func=gin_mlp_func,)

In [13]:
metrics_list = []
edge_ratios_list = []

## Load model and optimizer

In [14]:
model = CoGNN(gumbel_args=gumbel_args, env_args=env_args, action_args=action_args, pool=Pool.NONE).to(
            device=device)

optimizer = pre_dataset.optimizer(model=model, lr=lr, weight_decay=0)

scheduler = pre_dataset.scheduler(
    optimizer=optimizer,
    step_size=step_size,
    gamma=gamma,
    num_warmup_epochs=num_warmup_epochs,
    max_epochs=max_epochs)

model

CoGNN(
  (env_net): ModuleList(
    (0): EncoderLinear(in_features=300, out_features=64, bias=True)
    (1-3): 3 x WeightedGNNConv(64, 64)
    (4): Linear(in_features=64, out_features=18, bias=True)
  )
  (hidden_layer_norm): Identity()
  (dropout): Dropout(p=0.2, inplace=False)
  (in_act_net): ActionNet(
    (net): ModuleList(
      (0): WeightedGNNConv(64, 2)
    )
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (out_act_net): ActionNet(
    (net): ModuleList(
      (0): WeightedGNNConv(64, 2)
    )
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (pooling): BatchIdentity()
)

### Training Helpers

In [15]:
def train_and_test(dataset_by_split, model, optimizer, scheduler, pbar, num_fold: int):
    train_loader = DataLoader(dataset_by_split.train, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(dataset_by_split.val, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset_by_split.test, batch_size=batch_size, shuffle=True)

    best_losses_n_metrics = metric_type.get_worst_losses_n_metrics()
    for epoch in range(max_epochs):
        train_func(train_loader=train_loader, model=model, optimizer=optimizer)
        train_loss, train_metric, _ = test_func(
            loader=train_loader, model=model, split_mask_name="train_mask", calc_edge_ratio=False
        )
        if pre_dataset.is_expressivity():
            val_loss, val_metric = train_loss, train_metric
            test_loss, test_metric = train_loss, train_metric

        else:
            val_loss, val_metric, _ = test_func(
                loader=val_loader, model=model, split_mask_name="val_mask", calc_edge_ratio=False
            )
            test_loss, test_metric, _ = test_func(
                loader=test_loader, model=model, split_mask_name="test_mask", calc_edge_ratio=False
            )
        
        losses_n_metrics = LossesAndMetrics(
            train_loss=train_loss,
            val_loss=val_loss,
            test_loss=test_loss,
            train_metric=train_metric,
            val_metric=val_metric,
            test_metric=test_metric,
        )
        if scheduler is not None:
            scheduler.step(losses_n_metrics.val_metric)

        # best metrics
        if metric_type.src_better_than_other(
            src=losses_n_metrics.val_metric, other=best_losses_n_metrics.val_metric
        ):
            best_losses_n_metrics = losses_n_metrics
        
        log_str = f"Split: {num_fold}, epoch: {epoch}"
        for name in losses_n_metrics._fields:
            log_str += f",{name}={round(getattr(losses_n_metrics, name), decimal)}"
        log_str += f"({round(best_losses_n_metrics.test_metric, decimal)})"
        pbar.set_description(log_str)
        pbar.update(n=1)

    edge_ratios = None
    if pre_dataset.not_synthetic():
        _, _, edge_ratios = test_func(
            loader=test_loader, model=model, split_mask_name="test_mask", calc_edge_ratio=True
        )
    return best_losses_n_metrics, edge_ratios

def train_func(train_loader, model, optimizer):
    model.train()

    for data in train_loader:
        if batch_norm and (data.x.shape[0] == 1 or data.num_graphs == 1):
            continue
        optimizer.zero_grad()
        node_mask = pre_dataset.get_split_mask(
            data=data, batch_size=data.num_graphs, split_mask_name="train_mask"
        ).to(device=device)
        edge_attr = data.edge_attr
        if data.edge_attr is not None:
            edge_attr = edge_attr.to(device=device)
        
        scores, _ = model(
            data.x.to(device=device),
            edge_index=data.edge_index.to(device=device),
            batch=data.batch.to(device=device),
            edge_attr=edge_attr,
            edge_ratio_node_mask=None,
            pestat=pos_enc.get_pe(data=data, device=device),
        )
        train_loss = task_loss(scores[node_mask], data.y.to(device=device)[node_mask])

        train_loss.backward()
        if pre_dataset.clip_grad():
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()

def test_func(loader, model, split_mask_name: str, calc_edge_ratio: bool) -> Tuple[float, Any, Tensor]:
    model.eval()

    total_loss, total_metric, total_edge_ratios = 0, 0, 0
    total_scores = np.empty(shape=(0, model.env_args.out_dim))
    total_y = None
    for data in loader:
        if batch_norm and (data.x.shape[0] == 1 or data.num_graphs == 1):
            continue
        node_mask = pre_dataset.get_split_mask(
            data=data, batch_size=data.num_graphs, split_mask_name=split_mask_name
        ).to(device=device)
        if calc_edge_ratio:
            edge_ratio_node_mask = pre_dataset.get_edge_ratio_node_mask(
                data=data, split_mask_name=split_mask_name
            ).to(device=device)
        else:
            edge_ratio_node_mask = None
        edge_attr = data.edge_attr
        if data.edge_attr is not None:
            edge_attr = edge_attr.to(device=device)
        
        scores, edge_ratios = model(
            data.x.to(device=device),
            edge_index=data.edge_index.to(device=device),
            edge_attr=edge_attr,
            batch=data.batch.to(device=device),
            edge_ratio_node_mask=edge_ratio_node_mask,
            pestat=pos_enc.get_pe(data=data, device=device),
        )

        eval_loss = task_loss(scores, data.y.to(device=device))

        total_scores = np.concatenate((total_scores, scores[node_mask].detach().cpu().numpy()))
        if total_y is None:
            total_y = data.y.to(device=device)[node_mask].detach().cpu().numpy()
        else:
            total_y = np.concatenate((total_y, data.y.to(device=device)[node_mask].detach().cpu().numpy()))
        
        total_loss += eval_loss.item() * data.num_graphs
        total_edge_ratios += edge_ratios * data.num_graphs
    
    metric = metric_type.apply_metric(scores=total_scores, target=total_y)

    loss = total_loss / len(loader.dataset)
    edge_ratios = total_edge_ratios / len(loader.dataset)
    return loss, metric, edge_ratios

In [16]:
def single_fold(
    dataset_by_split, model, optimizer, scheduler, num_fold: int):
    with tqdm.tqdm(total=max_epochs, file=sys.stdout) as pbar:
        best_losses_n_metrics, edge_ratios = train_and_test(
            dataset_by_split=dataset_by_split,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            pbar=pbar,
            num_fold=num_fold,
        )
    return best_losses_n_metrics, edge_ratios

In [17]:
dataset_by_split = pre_dataset.select_fold_and_split(num_fold=2, dataset=dataset)

## Run training

In [18]:
best_losses_n_metrics, edge_ratios = single_fold(
        dataset_by_split=dataset_by_split,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        num_fold=2,
    )

Split: 2, epoch: 299,train_loss=0.88,val_loss=0.88,test_loss=0.88,train_metric=0.71,val_metric=0.69,test_metric=0.69(0.68): 100%|██████████| 300/300 [02:13<00:00,  2.25it/s]


In [32]:
train_loader = DataLoader(dataset_by_split.train, batch_size=batch_size, shuffle=True)
sample_data = next(iter(train_loader))
sample_data

DataBatch(x=[22662, 300], edge_index=[2, 65854], y=[22662], train_mask=[22662], val_mask=[22662], test_mask=[22662], batch=[22662], ptr=[2])

In [38]:
sample_data.edge_index

tensor([[    0,     0,     1,  ..., 22661, 22661, 22661],
        [    1,     2,     0,  ..., 22653, 22659, 22660]])

In [37]:
sample_data.x.shape, sample_data.edge_index.shape, sample_data.y.shape

(torch.Size([22662, 300]), torch.Size([2, 65854]), torch.Size([22662]))

## Under the hood
> What's in the forward pass

In [44]:
with torch.no_grad():
    x_ = model.hidden_layer_norm(sample_data.x)
x_.shape

torch.Size([22662, 300])

In [47]:
with torch.no_grad():
    x_ = model.env_net[0](x_, pestat=pos_enc.get_pe(data=sample_data, device=device))
x_.shape

torch.Size([22662, 64])

In [51]:
with torch.no_grad():
    in_logits = model.in_act_net(
        x = x_, edge_index=sample_data.edge_index,
        env_edge_attr=None, act_edge_attr=None)

    out_logits = model.out_act_net(
        x = x_, edge_index=sample_data.edge_index,
        env_edge_attr=None, act_edge_attr=None)
in_logits.shape, out_logits.shape

(torch.Size([22662, 2]), torch.Size([22662, 2]))

In [54]:
with torch.no_grad():
    in_probs = F.gumbel_softmax(logits=in_logits, tau=temp, hard=True)
    out_probs = F.gumbel_softmax(logits=out_logits, tau=temp, hard=True)
in_probs, out_probs

(tensor([[1., 0.],
         [0., 1.],
         [0., 1.],
         ...,
         [0., 1.],
         [1., 0.],
         [1., 0.]]),
 tensor([[0., 1.],
         [1., 0.],
         [1., 0.],
         ...,
         [1., 0.],
         [1., 0.],
         [0., 1.]]))

In [57]:
in_probs.sum(-1).max(), in_probs.sum(-1).min()

(tensor(1.), tensor(1.))

In [61]:
u, v = sample_data.edge_index
u, v

(tensor([    0,     0,     1,  ..., 22661, 22661, 22661]),
 tensor([    1,     2,     0,  ..., 22653, 22659, 22660]))

In [64]:
with torch.no_grad():
    u, v = sample_data.edge_index
    edge_in_probs = in_probs[:,0][u]
    edge_out_probs = out_probs[:,0][v]
    edge_weights = edge_in_probs * edge_out_probs


In [67]:
edge_weights.mean()

tensor(0.7359)

In [70]:
model.env_net

ModuleList(
  (0): EncoderLinear(in_features=300, out_features=64, bias=True)
  (1-3): 3 x WeightedGNNConv(64, 64)
  (4): Linear(in_features=64, out_features=18, bias=True)
)

In [20]:
print_str = f"Fold {0}/{len(folds)}"
for name in best_losses_n_metrics._fields:
    print_str += f",{name}={round(getattr(best_losses_n_metrics, name), decimal)}"
print(print_str)
print()
metrics_list.append(best_losses_n_metrics.get_fold_metrics())

if edge_ratios is not None:
    edge_ratios_list.append(edge_ratios)

Fold 0/10,train_loss=0.89,val_loss=0.89,test_loss=0.89,train_metric=0.7,val_metric=0.69,test_metric=0.68



In [23]:
best_losses_n_metrics

LossesAndMetrics(train_loss=0.8871640563011169, val_loss=0.8867784142494202, test_loss=0.8861422538757324, train_metric=0.7033801078796387, val_metric=0.6907325387001038, test_metric=0.6849629282951355)

In [73]:
data_row = dataset[0]

In [76]:
x

NameError: name 'x' is not defined