## Run GCN Model

In [1]:
# Enable IPython's autoreload extension; mode 2 reloads all imported modules
# before each cell executes, so edits to local .py files take effect without
# restarting the kernel.
%load_ext autoreload

%autoreload 2

In [None]:
import torch
import numpy as np

In [None]:
# Select the compute device.
# `mps` records whether the Apple-Silicon (MPS) backend is usable, and
# `mps_device` is ALWAYS defined afterwards: it is the MPS device when
# available and the CPU otherwise, so later cells that pass `mps_device`
# around no longer fail with a NameError on machines without MPS.
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    mps = True

    print("MPS is available and enabled on device: {}".format(mps_device))
else:
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
    # The name is kept for compatibility with the rest of the notebook even
    # though it points at the CPU in this branch.
    mps_device = torch.device("cpu")
    mps = False

    print("Falling back to device: {}".format(mps_device))

Load X-chromosome data

In [None]:
from module4_gcn.dataset import GPS_GCN_Dataset
from gcn_model_classes import *
from torch.utils.data import DataLoader

# Paths to the windowed feature array and the matching label array.
windowed_data = "X_sampled_data_31.npy"
labels = "X_sampled_labels_31.npy"

# Initialize the Dataset; the selected device is forwarded to it.
data = GPS_GCN_Dataset(windowed_data, labels, device=mps_device)

# 80/20 train/test split. A seeded generator makes the partition identical on
# every Restart & Run All, so reported metrics are comparable between runs.
SPLIT_SEED = 42
train_size = int(0.8 * len(data))
test_size = len(data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    data,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Initialize the DataLoaders. Only the training set is shuffled; evaluation
# results do not depend on batch order, so the test loader keeps a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=28, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=28, shuffle=False)


Train/test split

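Instantiate model components

**TODO:** no code cell is present under this heading, yet the cells below use `model`. Instantiate the model here (and move it to the same device as the data) so the notebook survives Restart Kernel → Run All.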

Train model

In [None]:
import torch.optim as optim
# Adam over all model parameters. The loss criterion is intentionally not
# created here: a class-weighted CrossEntropyLoss is built in a later cell.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=1e-4,
)

In [None]:
from sklearn.utils import class_weight

# No explicit import of `nn` is visible in this notebook (it presumably leaks
# in through a wildcard import), so bring it into scope explicitly here.
from torch import nn

# 'balanced' weights are inversely proportional to class frequency, which
# counteracts class imbalance in the labels during training.
np_labels = np.load(labels)
class_weights = class_weight.compute_class_weight(
    'balanced',
    classes=np.unique(np_labels),
    y=np_labels,
)

# CrossEntropyLoss requires its weight tensor to live on the same device as the
# logits it receives, so place the weights on the model's device rather than
# leaving them on the CPU.
model_device = next(model.parameters()).device
class_weights = torch.tensor(class_weights, dtype=torch.float, device=model_device)

criterion = nn.CrossEntropyLoss(weight=class_weights)
print(class_weights)

In [None]:
# Train for a fixed number of epochs, evaluating on the held-out set after
# each one.
#
# Loop variables are deliberately named `batch_inputs` / `batch_targets`:
# reusing `data` and `labels` here would overwrite the notebook-level dataset
# object and label-file path, breaking any re-run of the earlier cells.
NUM_EPOCHS = 5
LOG_EVERY = 700   # number of mini-batches between progress prints

for epoch in range(NUM_EPOCHS):  # loop over the dataset multiple times

    # Set training mode at the START of every epoch so the cell is safe to
    # re-run even after another cell has left the model in eval mode.
    model.train()
    train_loss_sum = 0.0
    for i, (batch_inputs, batch_targets) in enumerate(train_dataloader):
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(batch_inputs)
        loss = criterion(outputs, batch_targets)
        loss.backward()
        optimizer.step()

        # print statistics
        train_loss_sum += loss.item()
        if i % LOG_EVERY == 5:    # print every LOG_EVERY mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {loss.item():.3f}')

    # Evaluate on the held-out set without tracking gradients.
    model.eval()
    correct = 0
    test_loss_sum = 0.0
    with torch.no_grad():
        for batch_inputs, batch_targets in test_dataloader:
            outputs = model(batch_inputs)
            test_loss_sum += criterion(outputs, batch_targets).item()
            predicted = outputs.argmax(dim=1)
            correct += (predicted == batch_targets).sum().item()

    # Report per-batch averages so the numbers do not scale with dataset size;
    # max(..., 1) guards against division by zero on an empty loader/dataset.
    train_loss = train_loss_sum / max(len(train_dataloader), 1)
    test_loss = test_loss_sum / max(len(test_dataloader), 1)
    accuracy = correct / max(len(test_dataset), 1)
    print(f'[{epoch + 1}] train loss: {train_loss:.3f}, '
          f'test loss: {test_loss:.3f}, accuracy: {accuracy:.3f}')

In [None]:
from sklearn.metrics import roc_auc_score, roc_curve, auc
import matplotlib.pyplot as plt

# Collect ground-truth labels and per-class scores on the held-out set, then
# plot the ROC curve for the positive class (index 1).
model.eval()
y_true = []
y_score = []
with torch.no_grad():
    # Loop variables avoid the names `data` / `labels`, which are notebook-level
    # variables (dataset object and label-file path) used by earlier cells.
    for batch_inputs, batch_targets in test_dataloader:
        outputs = model(batch_inputs)
        # Convert logits to probabilities: the raw class-1 logit alone is not
        # monotonic in P(class 1) (that depends on the difference between the
        # logits), so ranking by it would distort the ROC curve. argmax over
        # the rows is unchanged by softmax, so predicted classes are the same.
        probabilities = torch.softmax(outputs, dim=1)
        y_true.extend(batch_targets.cpu().numpy())
        y_score.extend(probabilities.cpu().numpy())

y_true = np.array(y_true)
y_score = np.array(y_score)

fpr, tpr, _ = roc_curve((y_true == 1), (y_score[:, 1]))
roc_auc = auc(fpr, tpr)

fig, ax = plt.subplots()
ax.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = {:.2f})'.format(roc_auc))
# Diagonal reference line: performance of a random classifier.
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.legend(loc="lower right")
ax.set_title('ROC Curve for CLAMP Binding Site Prediction')
plt.show()

Confusion Matrix

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Predicted class = index of the highest score in each row of y_score.
y_pred = y_score.argmax(axis=1)

# Count (true, predicted) pairs and render them as a labelled heat map.
cm = confusion_matrix(y_true, y_pred)
disp = ConfusionMatrixDisplay(
    confusion_matrix=cm,
    display_labels=['No Binding', 'Binding'],
)
disp.plot()
disp.ax_.set_title('Confusion Matrix for CLAMP Binding Site Prediction')
