In [1]:
import functions
import matplotlib.pyplot as plt
import pandas as pd
import torch

In [2]:
# Root directory of the dataset; the cells below read
# `processed/adjacency_399_True.csv` relative to it.
DATASET_PATH = "/path/to/LamaH-CE"
# Directory holding the training checkpoints; the cells below load files named
# `{architecture}_{edge_orientation}_learned_{fold}.run` from it.
CHECKPOINT_PATH = "/path/to/checkpoint"

In [3]:
# Read the adjacency table and expose its three edge-attribute columns
# (`dist_hdn`, `elev_diff`, `strm_slope`) as separate float32 tensors,
# one entry per row of the CSV.
adj_df = pd.read_csv(f"{DATASET_PATH}/processed/adjacency_399_True.csv")
weight_cols = adj_df[["dist_hdn", "elev_diff", "strm_slope"]].values

# Column order of `weight_cols` follows the selection above, so the unpacking
# order here must stay in sync with it.
stream_length, elevation_difference, average_slope = (
    torch.tensor(weight_cols[:, col_idx], dtype=torch.float)
    for col_idx in range(3)
)

In [4]:
# Correlate the learned edge weights with the static edge attributes.
#
# For every (edge orientation, architecture) pair, one 4x4 correlation matrix is
# computed per fold over the rows
#   [learned weights, stream length, elevation difference, average slope]
# and the element-wise mean and standard deviation across the folds are printed.
# Only the first row/column varies between folds; the remaining entries relate
# the static attributes to each other and are therefore identical in every fold.

# The static attributes do not change between runs, so stack them once up front.
static_attributes = torch.stack([stream_length, elevation_difference, average_slope])

for edge_orientation in ["downstream", "upstream", "bidirectional"]:
    for architecture in ["ResGCN", "GCNII"]:
        print(architecture, edge_orientation)
        corrmats = []
        for fold in range(3):
            # map_location="cpu" lets checkpoints written during a GPU run be loaded
            # on a machine without CUDA. NOTE: torch.load unpickles the file, so only
            # point CHECKPOINT_PATH at checkpoints from a trusted source.
            chkpt = torch.load(
                f"{CHECKPOINT_PATH}/{architecture}_{edge_orientation}_learned_{fold}.run",
                map_location="cpu",
            )
            # NaN entries are replaced by 0 before correlating.
            # NOTE(review): the statistics cell additionally clamps the weights at 0;
            # here the raw values are correlated -- confirm this difference is intended.
            learned_weights = chkpt["history"]["best_model_params"]["edge_weights"].nan_to_num().cpu()
            # torch.corrcoef treats each row as one variable.
            corrmats.append(torch.corrcoef(torch.cat([learned_weights.unsqueeze(0), static_attributes])))
        # Stack once and reuse for both summary statistics.
        corrmat_stack = torch.stack(corrmats)
        print("correlation matrix mean:")
        print(corrmat_stack.mean(dim=0))
        print("correlation matrix std:")
        print(corrmat_stack.std(dim=0))
        print()

ResGCN downstream
correlation matrix mean:
tensor([[ 1.0000, -0.3748, -0.1484,  0.0748],
        [-0.3748,  1.0000,  0.3284, -0.1774],
        [-0.1484,  0.3284,  1.0000,  0.5832],
        [ 0.0748, -0.1774,  0.5832,  1.0000]])
correlation matrix std:
tensor([[6.8826e-08, 1.2207e-02, 5.8971e-03, 7.4733e-03],
        [1.2207e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [5.8971e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [7.4733e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00]])

GCNII downstream
correlation matrix mean:
tensor([[ 1.0000, -0.2848, -0.2142, -0.0340],
        [-0.2848,  1.0000,  0.3284, -0.1774],
        [-0.2142,  0.3284,  1.0000,  0.5832],
        [-0.0340, -0.1774,  0.5832,  1.0000]])
correlation matrix std:
tensor([[6.8826e-08, 1.3545e-02, 1.3289e-02, 1.7768e-02],
        [1.3545e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.3289e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.7768e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00]])

ResGCN upstream
correlation m

In [5]:
# Summarise the distribution of the learned edge weights.
#
# For every (edge orientation, architecture) pair, seven descriptive statistics of
# the weight vector are computed per fold; each table cell then reports the
# mean ± standard deviation of that statistic across the folds, formatted as text.
descriptors = ["mean", "std", "min", "25%", "median", "75%", "max"]

# Collected as {column name: [formatted cell per descriptor]} so the DataFrame can be
# built in a single step instead of being enlarged one cell at a time.
summary_columns = {}
for edge_orientation in ["downstream", "upstream", "bidirectional"]:
    for architecture in ["ResGCN", "GCNII"]:
        print(architecture, edge_orientation)
        fold_stats = []
        for fold in range(3):
            # map_location="cpu" lets checkpoints written during a GPU run be loaded
            # on a machine without CUDA. NOTE: torch.load unpickles the file, so only
            # point CHECKPOINT_PATH at checkpoints from a trusted source.
            chkpt = torch.load(
                f"{CHECKPOINT_PATH}/{architecture}_{edge_orientation}_learned_{fold}.run",
                map_location="cpu",
            )
            # NaN entries become 0 and negative weights are clamped to 0 before
            # summarising. NOTE(review): presumably this mirrors how the model uses
            # the weights -- confirm against the model code.
            learned_weights = chkpt["history"]["best_model_params"]["edge_weights"].nan_to_num().cpu().clamp(min=0)
            fold_stats.append(torch.stack([
                learned_weights.mean(),
                learned_weights.std(),
                learned_weights.min(),
                learned_weights.quantile(0.25),
                # quantile(0.5) interpolates like the 25%/75% rows; Tensor.median()
                # would return the lower of the two middle values for an even count.
                learned_weights.quantile(0.5),
                learned_weights.quantile(0.75),
                learned_weights.max(),
            ]))
        # Shape: (number of folds, number of descriptors).
        stats = torch.stack(fold_stats)
        stat_means = stats.mean(dim=0).tolist()
        stat_stds = stats.std(dim=0).tolist()
        summary_columns[f"{edge_orientation}_{architecture}"] = [
            f"{stat_mean:.3f} ± {stat_std:.3f}"
            for stat_mean, stat_std in zip(stat_means, stat_stds)
        ]

weight_stats_df = pd.DataFrame(summary_columns, index=descriptors)
weight_stats_df

ResGCN downstream
GCNII downstream
ResGCN upstream
GCNII upstream
ResGCN bidirectional
GCNII bidirectional


Unnamed: 0,downstream_ResGCN,downstream_GCNII,upstream_ResGCN,upstream_GCNII,bidirectional_ResGCN,bidirectional_GCNII
mean,0.462 ± 0.082,0.263 ± 0.072,0.670 ± 0.039,0.627 ± 0.056,0.789 ± 0.044,0.618 ± 0.062
std,0.322 ± 0.013,0.281 ± 0.038,0.375 ± 0.004,0.369 ± 0.013,0.329 ± 0.008,0.361 ± 0.005
min,0.000 ± 0.000,0.000 ± 0.000,0.000 ± 0.000,0.000 ± 0.000,0.061 ± 0.045,0.000 ± 0.000
25%,0.191 ± 0.086,0.033 ± 0.033,0.382 ± 0.049,0.345 ± 0.066,0.556 ± 0.039,0.342 ± 0.079
median,0.416 ± 0.084,0.158 ± 0.091,0.708 ± 0.064,0.628 ± 0.075,0.802 ± 0.049,0.592 ± 0.066
75%,0.689 ± 0.102,0.434 ± 0.092,0.959 ± 0.036,0.894 ± 0.078,1.032 ± 0.040,0.901 ± 0.075
max,1.313 ± 0.089,1.376 ± 0.094,1.471 ± 0.052,1.609 ± 0.102,1.565 ± 0.037,1.668 ± 0.083
