# Import Essentials:

In [None]:
from PUGNN import SW, HomoDataset, HomoDataReader, BoostedDataLoader, Trainer
import numpy as np

# Specify the `output` and `input` directories:

In [None]:
# Locations handed to `SW` further down: where the input samples are read
# from, and where results are written (the current working directory).
input_directory = "/eos/user/m/mjalalva/Run1/June4/"
output_directory = "./"

# Define a model and some metrics:

In [None]:
from torch_geometric.nn import GraphConv, MLP, global_add_pool, GATv2Conv, LayerNorm, global_mean_pool
import torch
import torch.nn.functional as F


class PUModel(torch.nn.Module):
    """Graph network with two convolution stages and two graph-level MLP heads.

    Each stage applies a graph convolution, a ``LayerNorm`` and a ReLU to the
    node embeddings, mean-pools them per graph and feeds the pooled vector,
    concatenated with the per-graph ``features``, into an MLP.  The second MLP
    additionally receives the output of the first one.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of the node embeddings after each convolution.
        out_channels: Size of the final per-graph output.
        num_features: Number of per-graph features in ``data.features``.
        GNN: Convolution layer class.  It is constructed with the keyword
            arguments ``edge_dim=1`` and ``add_self_loops=False``.
            NOTE(review): confirm that the default ``GraphConv`` accepts these
            keywords; the notebook itself instantiates the model with
            ``GATv2Conv``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_features, GNN=GraphConv):
        super().__init__()
        # Re-seeds torch's *global* RNG on every construction so that the
        # initial weights are reproducible.
        torch.manual_seed(12345)
        self.num_features = num_features

        conv_options = dict(edge_dim=1, add_self_loops=False)
        doubled = 2 * hidden_channels

        # Stage 1: node embedding + first graph-level head.
        # (Layers are created in the same order as before so that the random
        # initialisation sequence is unchanged.)
        self.conv1 = GNN(in_channels, hidden_channels, **conv_options)
        self.norm1 = LayerNorm(hidden_channels)
        self.mlp1 = MLP(
            [hidden_channels + num_features, doubled, doubled],
            norm='layer_norm',
        )

        # Stage 2: the head input is the pooled nodes (hidden_channels), the
        # stage-1 head output (2 * hidden_channels) and the graph features.
        self.conv2 = GNN(hidden_channels, hidden_channels, **conv_options)
        self.norm2 = LayerNorm(hidden_channels)
        self.mlp2 = MLP(
            [
                3 * hidden_channels + num_features,
                doubled,
                hidden_channels,
                hidden_channels // 2,
                out_channels,
            ],
            norm='layer_norm',
        )

    def forward(self, data):
        """Compute the per-graph output for a batch of graphs.

        Args:
            data: Batch object exposing ``x``, ``adj_t``, ``features`` and
                ``batch``.  ``features`` is reshaped to
                ``(-1, self.num_features)``.

        Returns:
            The output of the second MLP head.
        """
        node_emb = data.x
        adjacency = data.adj_t
        graph_feats = torch.reshape(data.features, (-1, self.num_features))
        batch_index = data.batch

        # Stage 1: convolve, normalise, activate, then summarise per graph.
        # NOTE(review): the norm layers are called without `batch`; confirm
        # this is the intended behaviour for batched graphs.
        node_emb = self.norm1(self.conv1(node_emb, adjacency)).relu()
        pooled = global_mean_pool(node_emb, batch_index)
        graph_state = self.mlp1(torch.cat([pooled, graph_feats], dim=1))

        # Stage 2: same node pipeline; the head also sees the stage-1 output.
        node_emb = self.norm2(self.conv2(node_emb, adjacency)).relu()
        pooled = global_mean_pool(node_emb, batch_index)
        head_input = torch.cat([pooled, graph_state, graph_feats], dim=1)

        return self.mlp2(head_input)
    
class Bias(torch.nn.Module):
    """Metric module: mean signed difference between prediction and target."""

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        """Return the mean of ``input - target`` over all elements."""
        residual = input - target
        return torch.mean(residual)
    
class MAPE(torch.nn.Module):
    """Metric module: mean absolute percentage error, as a fraction.

    Computes ``mean(|(input - target) / target|)``.  The result is a fraction
    (0.1 means 10 %), not multiplied by 100.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        """Return the mean absolute relative error of ``input`` w.r.t. ``target``.

        The absolute value is taken of the whole ratio, so negative targets
        contribute a positive error instead of cancelling positive ones.
        For strictly positive targets the result is unchanged.

        Note:
            Elements where ``target`` is zero produce ``inf``/``nan``; no
            epsilon is added here.
        """
        relative_error = (input - target) / target
        return relative_error.abs().mean()

In [None]:
# Build the network to be trained, using GATv2Conv as the convolution layer.
model = PUModel(
    in_channels=12,
    hidden_channels=150,
    num_features=7,
    out_channels=1,
    GNN=GATv2Conv,
)

# Make an Instance of `SW` Class with Your `DataSet` and a `DataLoader` Class:

In [None]:
# Create the `SW` instance from the input/output directories and a run name.
software1 = SW(
    input_directory,
    output_directory,
    name='pu-test',
)
# Map every PU value in 20..80 (as a string key) to the number of events to
# use at that PU.
# NOTE(review): the counts are NumPy float64 values (10.0), not ints; confirm
# that the dataset accepts non-integer counts.
sample_metadata = {
    str(pu): n_events
    for pu, n_events in zip(range(20, 81), np.ones(61) * 10)
}

# Register the dataset: its type, the per-PU sample counts and a reader
# instance.
software1.set_dataset(HomoDataset, sample_metadata, HomoDataReader())

# Register the data loader type together with its loading options.
software1.set_loader(
    BoostedDataLoader,
    loading_workers=4,
    batch_size=4,
    num_workers=16,
)

# Enter into the `TrainerScope` of Your `SW`:

In [None]:
with software1.trainer_scope(Trainer) as pu_trainer:
    # Hand the model to the trainer, then run the training loop.
    pu_trainer.set(model)
    res = pu_trainer.train(
        max_epochs=50,
        # Optimizer, loss and scheduler are passed as classes (not instances)
        # together with their constructor arguments.
        optimizer=torch.optim.RAdam,
        optimizer_args={"lr": 5e-3},
        loss_fn=torch.nn.L1Loss,
        # Extra metrics to evaluate in addition to the loss.
        metrics=[MAPE(), Bias()],
        select_topk=5,
        # MultiStepLR multiplies the learning rate by `gamma` at each
        # milestone (presumably counted in epochs; confirm how the trainer
        # steps the scheduler).
        lr_scheduler=torch.optim.lr_scheduler.MultiStepLR,
        lr_scheduler_args={"milestones": [7, 15, 25, 35], "gamma": 0.06},
    )

# After Training, Let's Enter into the `AnalyzerScope`:

In [None]:
# Fresh instance of the same architecture, passed to the analyzer together
# with the state dicts and the loss.
eval_model = PUModel(in_channels=12, hidden_channels=150, num_features=7, out_channels=1, GNN=GATv2Conv)

with software1.analyzer_scope() as pu_analyzer:
    # NOTE(review): `state_dicts` is not defined in the visible part of this
    # notebook; it must be provided before this cell runs (presumably from the
    # training output; confirm against the Trainer API).
    pu_analyzer(eval_model, state_dicts, torch.nn.L1Loss())

    best_model = pu_analyzer.model
    res_plot   = pu_analyzer.residual_plot()
    # `MAE` / `MSE` were never imported or defined here; torch's L1Loss and
    # MSELoss (default mean reduction) compute the mean absolute error and the
    # mean squared error respectively.
    metrics    = pu_analyzer.apply_metrics([Bias(), MAPE(), torch.nn.L1Loss(), torch.nn.MSELoss()])
    dist_plots = pu_analyzer.distribution_plots()
    LLERes     = pu_analyzer.rangeLLE(30, 60)

    no_verts   = pu_analyzer.extract_feature(0, 7)
    comparing  = pu_analyzer.compare(
        # Each keyword names one alternative estimate as a `(y, yhat)` pair.
        # Further models can be added the same way, e.g.
        #   model2=(model2_summary.y, model2_summary.yhat)
        NV=(pu_analyzer.y, no_verts)  # was the undefined name `nv`
    )

## Distribution plots

In [None]:
dist_plots.heatmap  # Display the heatmap from the distribution plots

In [None]:
dist_plots.histogram  # Display the histogram from the distribution plots

In [None]:
dist_plots.kdeplot  # Display the KDE plot from the distribution plots

## Residual plot

In [None]:
res_plot  # Display the residual plot returned by `residual_plot()`

## Compare the model with other models

In [None]:
comparing.plot   # Visualize the comparison

In [None]:
comparing.R2     # R2 score of the model itself and with the other models, if any exist

## Log-likelihood Estimation through a Given Range of $\langle PU \rangle = L\sigma$

In [None]:
LLERes.plot  # Plot of the log-likelihood estimation result

In [None]:
LLERes.estimated_pu       # <PU> estimated by the model

In [None]:
LLERes.true_pu            # Expected (true) <PU>

In [None]:
LLERes.lower_bond_error   # Lower-bound error of the estimation

In [None]:
LLERes.upper_bond_error   # Upper-bound error of the estimation