# Assessment of the network embeddings

In this notebook we aim to assess the network embeddings on a node classification task. The goal is to ensure that the embeddings are generic enough to infer functional labels given by gene set annotations for each node.


Before we do that however, we will first 

---

## 0. Environmental setup

In [47]:
import pandas as pd
import numpy as np
import sys
import torch
from torch import nn
from torch.optim import Adam
import time
import copy

sys.path.append("../../../")

from src.utils.basic.io import get_genesets_from_gmt_file
from src.models.clf import SimpleClassifier
from sklearn.model_selection import train_test_split
from src.utils.torch.general import get_device

device = "cpu"

%load_ext nb_black

The nb_black extension is already loaded. To reload it, use:
  %reload_ext nb_black


<IPython.core.display.Javascript object>

---

## 1. Read in data

In [104]:
hallmark_geneset_dict = get_genesets_from_gmt_file(
    "../../../data/other/h.all.v7.4.symbols.gmt"
)

<IPython.core.display.Javascript object>

In [263]:
node_embs = pd.read_csv(
    "../../../data/ppi/embedding/node_embeddings_cv_wo_hallmark.csv", index_col=0
)

<IPython.core.display.Javascript object>

In [264]:
labels = np.zeros([len(node_embs), len(hallmark_geneset_dict)])

for i in range(len(labels)):
    for j in range(len(hallmark_geneset_dict)):
        if list(node_embs.index)[i] in list(hallmark_geneset_dict.values())[j]:
            labels[i, j] = 1
labels = pd.DataFrame(
    labels, index=node_embs.index, columns=list(hallmark_geneset_dict.keys())
)

<IPython.core.display.Javascript object>

In [265]:
labels = labels.drop(columns=np.array(labels.columns)[labels.sum(axis=0) < 20])
labels = labels.iloc[:, 0]

<IPython.core.display.Javascript object>

In [266]:
labels

CDCA2     0.0
APPBP2    0.0
TK1       0.0
MMP2      0.0
PRKACA    0.0
         ... 
ATF4      0.0
CENPA     0.0
PLCB2     0.0
CDK2      0.0
TUBB6     0.0
Name: HALLMARK_MITOTIC_SPINDLE, Length: 244, dtype: float64

<IPython.core.display.Javascript object>

In [267]:
from sklearn.preprocessing import StandardScaler

node_embs = StandardScaler().fit_transform(node_embs)

idc = list(range(len(node_embs)))
train_val_idc, test_idc = train_test_split(idc, test_size=0.2)
train_idc, val_idc = train_test_split(train_val_idc, test_size=0.1 / 0.8)

train_embs, train_labels = np.array(node_embs)[train_idc], labels.iloc[train_idc]
val_embs, val_labels = np.array(node_embs)[val_idc], labels.iloc[val_idc]
test_embs, test_labels = np.array(node_embs)[test_idc], labels.iloc[test_idc]

<IPython.core.display.Javascript object>

In [268]:
import torch.utils.data as data_utils

train_tensors = data_utils.TensorDataset(
    torch.FloatTensor(np.array(train_embs)), torch.FloatTensor(np.array(train_labels))
)

train_loader = data_utils.DataLoader(train_tensors, batch_size=32, shuffle=True)

val_tensors = data_utils.TensorDataset(
    torch.FloatTensor(np.array(val_embs)), torch.FloatTensor(np.array(val_labels))
)

val_loader = data_utils.DataLoader(val_tensors, batch_size=32, shuffle=False)

test_tensors = data_utils.TensorDataset(
    torch.FloatTensor(np.array(test_embs)), torch.FloatTensor(np.array(test_labels))
)

test_loader = data_utils.DataLoader(test_tensors, batch_size=32, shuffle=False)

data_loaders_dict = {"train": train_loader, "val": val_loader, "test": test_loader}
data_loaders_dict["train"]

<torch.utils.data.dataloader.DataLoader at 0x7f8576a09dc0>

<IPython.core.display.Javascript object>

## 2. Node classification

In [289]:
model = nn.Sequential(
    nn.Linear(128, 128),
    nn.PReLU(),
    nn.BatchNorm1d(128),
    nn.PReLU(),
    nn.BatchNorm1d(128),
    nn.PReLU(),
    nn.BatchNorm1d(128),
    nn.Linear(128, 1),
    nn.Sigmoid(),
)

<IPython.core.display.Javascript object>

In [290]:
optimizer = Adam(model.parameters(), lr=1e-4)
criterion = nn.BCELoss()

<IPython.core.display.Javascript object>

In [314]:
def train_model(model, dataloaders, criterion, optimizer, num_epochs=100):
    since = time.time()

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch, num_epochs - 1))
        print("-" * 10)

        # Each epoch has a training and validation phase
        for phase in ["train", "val"]:
            if phase == "train":
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.type(torch.FloatTensor).to(device)
                labels = labels.type(torch.FloatTensor).to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == "train"):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)
                    loss = criterion(outputs.view(-1, 1), labels.view(-1, 1))
                    preds = (outputs > 0.5).float()

                    # backward + optimize only if in training phase
                if phase == "train":
                    loss.backward()
                    optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(torch.eq(preds.view(-1), labels.view(-1)))

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / (len(dataloaders[phase].dataset))

            print("{} Loss: {:.6f} Acc: {:.6f}".format(phase, epoch_loss, epoch_acc))

            # deep copy the model if it has the best val accurary
            if phase == "val" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == "val":
                val_acc_history.append(epoch_acc)

        print()

    time_elapsed = time.time() - since
    print(
        "Training complete in {:.0f}m {:.0f}s".format(
            time_elapsed // 60, time_elapsed % 60
        )
    )
    print("Best val Acc: {:4f}".format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history

<IPython.core.display.Javascript object>

In [315]:
fitted_model, val_acc_history = train_model(
    model,
    criterion=criterion,
    optimizer=optimizer,
    dataloaders=data_loaders_dict,
    num_epochs=1000,
)

Epoch 0/999
----------
train Loss: 0.001967 Acc: 1.000000
val Loss: 1.457953 Acc: 0.880000

Epoch 1/999
----------
train Loss: 0.000742 Acc: 1.000000
val Loss: 1.439485 Acc: 0.880000

Epoch 2/999
----------
train Loss: 0.000763 Acc: 1.000000
val Loss: 1.449684 Acc: 0.880000

Epoch 3/999
----------
train Loss: 0.001503 Acc: 1.000000
val Loss: 1.450108 Acc: 0.880000

Epoch 4/999
----------
train Loss: 0.001626 Acc: 1.000000
val Loss: 1.456382 Acc: 0.880000

Epoch 5/999
----------
train Loss: 0.001098 Acc: 1.000000
val Loss: 1.400019 Acc: 0.880000

Epoch 6/999
----------
train Loss: 0.000822 Acc: 1.000000
val Loss: 1.410700 Acc: 0.880000

Epoch 7/999
----------
train Loss: 0.000967 Acc: 1.000000
val Loss: 1.423033 Acc: 0.880000

Epoch 8/999
----------
train Loss: 0.000744 Acc: 1.000000
val Loss: 1.443026 Acc: 0.880000

Epoch 9/999
----------
train Loss: 0.000855 Acc: 1.000000
val Loss: 1.467290 Acc: 0.880000

Epoch 10/999
----------
train Loss: 0.000442 Acc: 1.000000
val Loss: 1.435770 Ac

train Loss: 0.004496 Acc: 1.000000
val Loss: 1.582960 Acc: 0.880000

Epoch 97/999
----------
train Loss: 0.004016 Acc: 1.000000
val Loss: 1.616880 Acc: 0.880000

Epoch 98/999
----------
train Loss: 0.004584 Acc: 1.000000
val Loss: 1.593695 Acc: 0.880000

Epoch 99/999
----------
train Loss: 0.002995 Acc: 1.000000
val Loss: 1.544488 Acc: 0.880000

Epoch 100/999
----------
train Loss: 0.002014 Acc: 1.000000
val Loss: 1.517076 Acc: 0.880000

Epoch 101/999
----------
train Loss: 0.002585 Acc: 1.000000
val Loss: 1.526346 Acc: 0.880000

Epoch 102/999
----------
train Loss: 0.002189 Acc: 1.000000
val Loss: 1.542157 Acc: 0.880000

Epoch 103/999
----------
train Loss: 0.002883 Acc: 1.000000
val Loss: 1.533118 Acc: 0.880000

Epoch 104/999
----------
train Loss: 0.001769 Acc: 1.000000
val Loss: 1.589845 Acc: 0.880000

Epoch 105/999
----------
train Loss: 0.001511 Acc: 1.000000
val Loss: 1.617213 Acc: 0.880000

Epoch 106/999
----------
train Loss: 0.004698 Acc: 1.000000
val Loss: 1.564020 Acc: 0.88

train Loss: 0.000577 Acc: 1.000000
val Loss: 1.410667 Acc: 0.880000

Epoch 196/999
----------
train Loss: 0.000626 Acc: 1.000000
val Loss: 1.421411 Acc: 0.880000

Epoch 197/999
----------
train Loss: 0.000616 Acc: 1.000000
val Loss: 1.448570 Acc: 0.880000

Epoch 198/999
----------
train Loss: 0.000796 Acc: 1.000000
val Loss: 1.435245 Acc: 0.880000

Epoch 199/999
----------
train Loss: 0.001923 Acc: 1.000000
val Loss: 1.428959 Acc: 0.880000

Epoch 200/999
----------
train Loss: 0.001066 Acc: 1.000000
val Loss: 1.489615 Acc: 0.880000

Epoch 201/999
----------
train Loss: 0.001478 Acc: 1.000000
val Loss: 1.480366 Acc: 0.880000

Epoch 202/999
----------
train Loss: 0.000416 Acc: 1.000000
val Loss: 1.492555 Acc: 0.880000

Epoch 203/999
----------
train Loss: 0.001206 Acc: 1.000000
val Loss: 1.423190 Acc: 0.880000

Epoch 204/999
----------
train Loss: 0.001518 Acc: 1.000000
val Loss: 1.423571 Acc: 0.880000

Epoch 205/999
----------
train Loss: 0.000684 Acc: 1.000000
val Loss: 1.404611 Acc: 0

train Loss: 0.000619 Acc: 1.000000
val Loss: 1.481055 Acc: 0.880000

Epoch 288/999
----------
train Loss: 0.000545 Acc: 1.000000
val Loss: 1.483057 Acc: 0.880000

Epoch 289/999
----------
train Loss: 0.001285 Acc: 1.000000
val Loss: 1.503872 Acc: 0.880000

Epoch 290/999
----------
train Loss: 0.000649 Acc: 1.000000
val Loss: 1.499937 Acc: 0.880000

Epoch 291/999
----------
train Loss: 0.000773 Acc: 1.000000
val Loss: 1.499946 Acc: 0.880000

Epoch 292/999
----------
train Loss: 0.000323 Acc: 1.000000
val Loss: 1.479052 Acc: 0.880000

Epoch 293/999
----------
train Loss: 0.000398 Acc: 1.000000
val Loss: 1.474544 Acc: 0.880000

Epoch 294/999
----------
train Loss: 0.001304 Acc: 1.000000
val Loss: 1.519891 Acc: 0.880000

Epoch 295/999
----------
train Loss: 0.000559 Acc: 1.000000
val Loss: 1.531210 Acc: 0.880000

Epoch 296/999
----------
train Loss: 0.000351 Acc: 1.000000
val Loss: 1.528803 Acc: 0.880000

Epoch 297/999
----------
train Loss: 0.001357 Acc: 1.000000
val Loss: 1.498666 Acc: 0

train Loss: 0.000520 Acc: 1.000000
val Loss: 1.571983 Acc: 0.840000

Epoch 382/999
----------
train Loss: 0.000384 Acc: 1.000000
val Loss: 1.558242 Acc: 0.880000

Epoch 383/999
----------
train Loss: 0.000618 Acc: 1.000000
val Loss: 1.583712 Acc: 0.880000

Epoch 384/999
----------
train Loss: 0.000665 Acc: 1.000000
val Loss: 1.574021 Acc: 0.880000

Epoch 385/999
----------
train Loss: 0.017536 Acc: 0.994118
val Loss: 1.593503 Acc: 0.880000

Epoch 386/999
----------
train Loss: 0.001612 Acc: 1.000000
val Loss: 1.610782 Acc: 0.880000

Epoch 387/999
----------
train Loss: 0.003476 Acc: 1.000000
val Loss: 1.593156 Acc: 0.880000

Epoch 388/999
----------
train Loss: 0.020741 Acc: 0.994118
val Loss: 1.567635 Acc: 0.880000

Epoch 389/999
----------
train Loss: 0.011570 Acc: 0.994118
val Loss: 1.534309 Acc: 0.840000

Epoch 390/999
----------
train Loss: 0.000623 Acc: 1.000000
val Loss: 1.472774 Acc: 0.840000

Epoch 391/999
----------
train Loss: 0.002792 Acc: 1.000000
val Loss: 1.468933 Acc: 0

train Loss: 0.002431 Acc: 1.000000
val Loss: 1.511719 Acc: 0.880000

Epoch 473/999
----------
train Loss: 0.000753 Acc: 1.000000
val Loss: 1.517156 Acc: 0.880000

Epoch 474/999
----------
train Loss: 0.001171 Acc: 1.000000
val Loss: 1.528342 Acc: 0.880000

Epoch 475/999
----------
train Loss: 0.000633 Acc: 1.000000
val Loss: 1.515751 Acc: 0.880000

Epoch 476/999
----------
train Loss: 0.000540 Acc: 1.000000
val Loss: 1.548121 Acc: 0.880000

Epoch 477/999
----------
train Loss: 0.001054 Acc: 1.000000
val Loss: 1.541395 Acc: 0.880000

Epoch 478/999
----------
train Loss: 0.001605 Acc: 1.000000
val Loss: 1.550546 Acc: 0.880000

Epoch 479/999
----------
train Loss: 0.000294 Acc: 1.000000
val Loss: 1.555252 Acc: 0.880000

Epoch 480/999
----------
train Loss: 0.000740 Acc: 1.000000
val Loss: 1.565702 Acc: 0.880000

Epoch 481/999
----------
train Loss: 0.000986 Acc: 1.000000
val Loss: 1.537940 Acc: 0.880000

Epoch 482/999
----------
train Loss: 0.000860 Acc: 1.000000
val Loss: 1.537805 Acc: 0


Epoch 564/999
----------
train Loss: 0.000790 Acc: 1.000000
val Loss: 1.493876 Acc: 0.880000

Epoch 565/999
----------
train Loss: 0.000184 Acc: 1.000000
val Loss: 1.471865 Acc: 0.880000

Epoch 566/999
----------
train Loss: 0.000401 Acc: 1.000000
val Loss: 1.511454 Acc: 0.880000

Epoch 567/999
----------
train Loss: 0.000426 Acc: 1.000000
val Loss: 1.516929 Acc: 0.880000

Epoch 568/999
----------
train Loss: 0.000418 Acc: 1.000000
val Loss: 1.499048 Acc: 0.880000

Epoch 569/999
----------
train Loss: 0.000418 Acc: 1.000000
val Loss: 1.529859 Acc: 0.880000

Epoch 570/999
----------
train Loss: 0.000311 Acc: 1.000000
val Loss: 1.519728 Acc: 0.880000

Epoch 571/999
----------
train Loss: 0.000345 Acc: 1.000000
val Loss: 1.492429 Acc: 0.880000

Epoch 572/999
----------
train Loss: 0.000303 Acc: 1.000000
val Loss: 1.486974 Acc: 0.880000

Epoch 573/999
----------
train Loss: 0.000314 Acc: 1.000000
val Loss: 1.467608 Acc: 0.880000

Epoch 574/999
----------
train Loss: 0.001452 Acc: 1.000000

train Loss: 0.000262 Acc: 1.000000
val Loss: 1.603756 Acc: 0.880000

Epoch 652/999
----------
train Loss: 0.000487 Acc: 1.000000
val Loss: 1.628365 Acc: 0.880000

Epoch 653/999
----------
train Loss: 0.000632 Acc: 1.000000
val Loss: 1.620713 Acc: 0.880000

Epoch 654/999
----------
train Loss: 0.000555 Acc: 1.000000
val Loss: 1.630669 Acc: 0.880000

Epoch 655/999
----------
train Loss: 0.000472 Acc: 1.000000
val Loss: 1.651703 Acc: 0.880000

Epoch 656/999
----------
train Loss: 0.000416 Acc: 1.000000
val Loss: 1.655453 Acc: 0.880000

Epoch 657/999
----------
train Loss: 0.000497 Acc: 1.000000
val Loss: 1.644998 Acc: 0.880000

Epoch 658/999
----------
train Loss: 0.000247 Acc: 1.000000
val Loss: 1.636913 Acc: 0.880000

Epoch 659/999
----------
train Loss: 0.000476 Acc: 1.000000
val Loss: 1.655404 Acc: 0.880000

Epoch 660/999
----------
train Loss: 0.000326 Acc: 1.000000
val Loss: 1.629518 Acc: 0.880000

Epoch 661/999
----------
train Loss: 0.000282 Acc: 1.000000
val Loss: 1.641334 Acc: 0

train Loss: 0.001271 Acc: 1.000000
val Loss: 1.473825 Acc: 0.880000

Epoch 745/999
----------
train Loss: 0.000528 Acc: 1.000000
val Loss: 1.478336 Acc: 0.880000

Epoch 746/999
----------
train Loss: 0.000479 Acc: 1.000000
val Loss: 1.438427 Acc: 0.880000

Epoch 747/999
----------
train Loss: 0.000690 Acc: 1.000000
val Loss: 1.433169 Acc: 0.880000

Epoch 748/999
----------
train Loss: 0.000977 Acc: 1.000000
val Loss: 1.452889 Acc: 0.880000

Epoch 749/999
----------
train Loss: 0.000521 Acc: 1.000000
val Loss: 1.453907 Acc: 0.880000

Epoch 750/999
----------
train Loss: 0.000474 Acc: 1.000000
val Loss: 1.484091 Acc: 0.880000

Epoch 751/999
----------
train Loss: 0.001072 Acc: 1.000000
val Loss: 1.502903 Acc: 0.880000

Epoch 752/999
----------
train Loss: 0.001071 Acc: 1.000000
val Loss: 1.496955 Acc: 0.880000

Epoch 753/999
----------
train Loss: 0.000800 Acc: 1.000000
val Loss: 1.507867 Acc: 0.880000

Epoch 754/999
----------
train Loss: 0.000940 Acc: 1.000000
val Loss: 1.485207 Acc: 0

train Loss: 0.000425 Acc: 1.000000
val Loss: 1.494421 Acc: 0.880000

Epoch 844/999
----------
train Loss: 0.000652 Acc: 1.000000
val Loss: 1.505683 Acc: 0.880000

Epoch 845/999
----------
train Loss: 0.000770 Acc: 1.000000
val Loss: 1.525344 Acc: 0.880000

Epoch 846/999
----------
train Loss: 0.000681 Acc: 1.000000
val Loss: 1.543884 Acc: 0.880000

Epoch 847/999
----------
train Loss: 0.001145 Acc: 1.000000
val Loss: 1.563303 Acc: 0.880000

Epoch 848/999
----------
train Loss: 0.000360 Acc: 1.000000
val Loss: 1.575787 Acc: 0.880000

Epoch 849/999
----------
train Loss: 0.001078 Acc: 1.000000
val Loss: 1.550033 Acc: 0.880000

Epoch 850/999
----------
train Loss: 0.000737 Acc: 1.000000
val Loss: 1.543394 Acc: 0.880000

Epoch 851/999
----------
train Loss: 0.000891 Acc: 1.000000
val Loss: 1.581605 Acc: 0.880000

Epoch 852/999
----------
train Loss: 0.001283 Acc: 1.000000
val Loss: 1.576413 Acc: 0.880000

Epoch 853/999
----------
train Loss: 0.000552 Acc: 1.000000
val Loss: 1.572345 Acc: 0

train Loss: 0.000590 Acc: 1.000000
val Loss: 1.581045 Acc: 0.880000

Epoch 944/999
----------
train Loss: 0.000560 Acc: 1.000000
val Loss: 1.605367 Acc: 0.880000

Epoch 945/999
----------
train Loss: 0.001796 Acc: 1.000000
val Loss: 1.620489 Acc: 0.880000

Epoch 946/999
----------
train Loss: 0.001384 Acc: 1.000000
val Loss: 1.627234 Acc: 0.880000

Epoch 947/999
----------
train Loss: 0.003192 Acc: 1.000000
val Loss: 1.591491 Acc: 0.880000

Epoch 948/999
----------
train Loss: 0.000648 Acc: 1.000000
val Loss: 1.596256 Acc: 0.880000

Epoch 949/999
----------
train Loss: 0.005644 Acc: 0.994118
val Loss: 1.576827 Acc: 0.880000

Epoch 950/999
----------
train Loss: 0.003315 Acc: 1.000000
val Loss: 1.581892 Acc: 0.880000

Epoch 951/999
----------
train Loss: 0.000605 Acc: 1.000000
val Loss: 1.592136 Acc: 0.880000

Epoch 952/999
----------
train Loss: 0.001119 Acc: 1.000000
val Loss: 1.583298 Acc: 0.880000

Epoch 953/999
----------
train Loss: 0.004624 Acc: 1.000000
val Loss: 1.581889 Acc: 0

<IPython.core.display.Javascript object>

In [325]:
(fitted_model(torch.FloatTensor(np.array(test_embs))) > 0.5).view(-1) * 1

tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
        0])

<IPython.core.display.Javascript object>

In [326]:
torch.LongTensor(np.array(test_labels))

tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
        0])

<IPython.core.display.Javascript object>