## Load a model!

In [None]:
import time
import os
import numpy as np

import torch, torch_geometric.transforms as T, torch.nn.functional as F
from torch.utils.data.sampler import SubsetRandomSampler

from torch_geometric.loader import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint

import matplotlib.pyplot as plt

from sklearn.metrics import (
    roc_auc_score,
    roc_curve,
    auc,
    average_precision_score,
    f1_score,
    accuracy_score,
    precision_score,
    recall_score,
)

import wandb
import pickle

from cancernet.arch import GCNNet, GATNet, InteractionNet
from cancernet.util import ProgressBar, InMemoryLogger, get_roc
from cancernet import PnetDataSet

In [None]:
# Root directory that holds the datasets.
base_data_string = "/mnt/home/cpedersen/ceph/Data/data"

# Prostate graph dataset; the pre-transform chains GCNNorm (no self-loops
# added) with ToSparseTensor (edge_index kept alongside the sparse tensor).
dataset = PnetDataSet(
    root=os.path.join(base_data_string, "prostate"),
    name="prostate_graph_humanbase",
    # files={'graph_file': "global.geneSymbol.gz"},
    edge_tol=0.5,
    pre_transform=T.Compose(
        [
            T.GCNNorm(add_self_loops=False),
            T.ToSparseTensor(remove_edge_index=False),
        ]
    ),
)

# Assign the train / validation / test indices from the CSV split files.
splits_root = os.path.join(base_data_string, "prostate", "splits")
dataset.split_index_by_file(
    **{
        f"{part}_fp": os.path.join(splits_root, csv_name)
        for part, csv_name in (
            ("train", "training_set_0.csv"),
            ("valid", "validation_set.csv"),
            ("test", "test_set.csv"),
        )
    }
)

# Seed the global RNGs only after the dataset has been built and split.
pl.seed_everything(42, workers=True)

n_epochs, batch_size = 10, 10

# Loader restricted to the test indices.
# NOTE(review): the sampler is a SubsetRandomSampler created without its own
# generator, so batch order presumably follows the global RNG rather than the
# generator handed to DataLoader below -- confirm which one is intended.
test_loader = DataLoader(
    dataset,
    sampler=SubsetRandomSampler(dataset.test_idx),
    batch_size=batch_size,
    generator=torch.Generator().manual_seed(43),
)


In [None]:
# Rebuild each architecture with the hyper-parameters recorded for its best
# sweep run, restore the saved weights, and put the model in evaluation mode.
#
# * map_location="cpu" lets the checkpoints deserialize on a machine without
#   the device they were saved from; load_state_dict then copies the values
#   into the freshly constructed model's own parameters.
# * eval() is called because the following cells only run inference; it turns
#   off training-only layer behaviour (e.g. dropout), if the architectures
#   use any.

## Our best GCN
## https://wandb.ai/cancer-net/hyperparam_sweeps_May/runs/bv8vmf70/
model_gcn = GCNNet(dims=[3, 256, 256, 256], lr=0.0061575)
model_gcn.load_state_dict(
    torch.load(
        "/mnt/home/cpedersen/ceph/cancer-net-models/hyperparam_sweeps/GCN_no_early/wandb/run-20230518_010814-bv8vmf70/files/model_weights.pt",
        map_location="cpu",
    )
)
model_gcn.eval()

## Our best GAT
## https://wandb.ai/cancer-net/hyperparam_sweeps_May/runs/qf3s2ze6
model_gat = GATNet(dims=[3, 32, 32, 32], lr=0.0001017)
model_gat.load_state_dict(
    torch.load(
        "/mnt/home/cpedersen/ceph/cancer-net-models/hyperparam_sweeps/GAT_no_early/wandb/run-20230518_001508-qf3s2ze6/files/model_weights.pt",
        map_location="cpu",
    )
)
model_gat.eval()

## Our best MetaLayer
## https://wandb.ai/cancer-net/hyperparam_sweeps_May/runs/kg827ae8
model_metalayer = InteractionNet(layers=5, hidden=64, lr=0.000068689)
model_metalayer.load_state_dict(
    torch.load(
        "/mnt/home/cpedersen/ceph/cancer-net-models/hyperparam_sweeps/MetaLayer_no_early/wandb/run-20230518_001513-kg827ae8/files/model_weights.pt",
        map_location="cpu",
    )
)
model_metalayer.eval()

In [None]:
# Pull one batch from the test loader to use as input for the model calls in
# the following cells. The loader draws through a SubsetRandomSampler, so the
# samples in this batch depend on the RNG state at the time of the call.
test_data=next(iter(test_loader))

In [None]:
# Forward the single test batch through the loaded GCN; as the last
# expression of the cell, the returned value is what the notebook shows.
model_gcn(test_data)

In [None]:
# Forward the single test batch through the loaded GAT; as the last
# expression of the cell, the returned value is what the notebook shows.
model_gat(test_data)

In [None]:
# Forward the single test batch through the loaded MetaLayer model; as the
# last expression of the cell, the returned value is what the notebook shows.
model_metalayer(test_data)

In [None]:
# Evaluate the MetaLayer model over the whole test loader with the project
# helper get_roc. The first three results feed the ROC plot in the next cell.
# NOTE(review): judging by the names, the outputs are false/true positive
# rates, the AUC, the labels and the raw model outputs -- confirm against
# cancernet.util.get_roc.
fpr_test, tpr_test, test_auc, ys, outs = get_roc(model_metalayer, test_loader)

In [None]:
# ROC curve of the model evaluated above, with the diagonal drawn as a dashed
# reference line.
fig, ax = plt.subplots()
ax.plot(
    fpr_test,
    tpr_test,
    linewidth=2,
    label="test (area = %0.3f)" % test_auc,
)
ax.plot([0, 1], [0, 1], color="black", linewidth=1, linestyle="--")

# Axis limits, labels and title applied in one call.
ax.set(
    xlim=[0.0, 1.0],
    ylim=[0.0, 1.05],
    xlabel="False Positive Rate",
    ylabel="True Positive Rate",
    title="Receiver operating characteristic",
)
ax.legend(loc="lower right", frameon=False)