In [1]:
import pandas as pd

# Path to the external dataset CSV; replace with your actual file path.
EXTERNAL_CSV_PATH = "3_E.bugandensis_ISS_2DRDKit_scaled.csv"

# Read the whole external dataset into a DataFrame.
external_df = pd.read_csv(EXTERNAL_CSV_PATH)


In [2]:
# Split the external table: the "Values" column is used as the labels,
# and every remaining column is used as a model input feature.
external_y = external_df["Values"]
external_X = external_df.drop(columns=["Values"])

In [None]:
import torch
import torch.nn as nn

class ImprovedMolecularNN(nn.Module):
    """Fully connected network with a single sigmoid output.

    Architecture: four hidden stages (512, 256, 128, 64 units), each made of
    Linear -> BatchNorm1d -> LeakyReLU(0.1). Dropout (p=0.4) follows the first
    three stages only. A final Linear(64, 1) + Sigmoid produces one value in
    (0, 1) per sample.

    Attribute names (fc1..fc5, bn1..bn4) define the state_dict keys, so they
    must stay as they are for saved checkpoints to load.

    Args:
        input_dim: Number of input features per sample.
    """

    def __init__(self, input_dim):
        super().__init__()

        # Hidden stages: linear projection followed by batch normalisation.
        self.fc1 = nn.Linear(input_dim, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.fc4 = nn.Linear(128, 64)
        self.bn4 = nn.BatchNorm1d(64)

        # Output head: one logit per sample.
        self.fc5 = nn.Linear(64, 1)

        # Parameter-free layers shared by all stages.
        self.leaky_relu = nn.LeakyReLU(0.1)
        self.dropout = nn.Dropout(0.4)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the network on a batch of feature vectors.

        Args:
            x: Float tensor whose last dimension is ``input_dim``
               (expected to be shaped (batch, input_dim)).

        Returns:
            Tensor with one sigmoid-activated value per sample.
        """
        # Stages 1-3: linear -> batch norm -> LeakyReLU -> dropout.
        regularised_stages = (
            (self.fc1, self.bn1),
            (self.fc2, self.bn2),
            (self.fc3, self.bn3),
        )
        for linear, norm in regularised_stages:
            x = self.dropout(self.leaky_relu(norm(linear(x))))

        # Stage 4 has no dropout before the output head.
        x = self.leaky_relu(self.bn4(self.fc4(x)))
        return self.sigmoid(self.fc5(x))

input_dim = 140  # The model was trained with 140 features

# Build the network, then restore the trained weights from disk.
model = ImprovedMolecularNN(input_dim)

# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; inference in this notebook runs on CPU tensors.
state_dict = torch.load("4_2D rdkit best nn_model.pth", map_location="cpu")

# strict=False tolerates key mismatches, but every missing key leaves that
# layer at its random initial values, so report mismatches explicitly.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "WARNING: checkpoint does not fully match the model -",
        "missing keys:", list(load_result.missing_keys),
        "| unexpected keys:", list(load_result.unexpected_keys),
    )

# Keep the load result as the cell's displayed value, as before.
load_result

In [4]:
# Convert the external feature table (a pandas DataFrame) to a float32 tensor.
external_X_tensor = torch.tensor(external_X.values, dtype=torch.float32)

# Fail early with a readable message when the number of feature columns does
# not match the width of the model's first layer.
expected_features = model.fc1.in_features
if external_X_tensor.shape[1] != expected_features:
    raise ValueError(
        f"Model expects {expected_features} features but the external data "
        f"has {external_X_tensor.shape[1]} columns"
    )

model.eval()  # Evaluation mode: dropout off, batch norm uses running statistics
with torch.no_grad():  # No gradients are needed for inference
    # Keep the raw sigmoid outputs: threshold-free metrics such as ROC/AUC
    # need these continuous scores rather than the hard 0/1 labels.
    prediction_probs = model(external_X_tensor).numpy()

# Hard class labels at the 0.5 decision threshold (one column, one row per sample).
predictions = (prediction_probs >= 0.5).astype(int)

In [None]:
# Labels are available for the external set, so score the thresholded predictions.
from sklearn.metrics import accuracy_score, classification_report

external_accuracy = accuracy_score(external_y, predictions)
external_report = classification_report(external_y, predictions)

print("Accuracy:", external_accuracy)
print(external_report)


In [None]:
# prompt: calculate sensitivity

from sklearn.metrics import confusion_matrix

# labels=[0, 1] forces a 2x2 matrix even when only one class appears in the
# data, so the four-way unpacking below cannot fail. `predictions` is 0/1 by
# construction; this assumes `external_y` is also encoded as 0/1 - confirm.
tn, fp, fn, tp = confusion_matrix(external_y, predictions, labels=[0, 1]).ravel()

# Sensitivity (true positive rate) = TP / (TP + FN).
# It is undefined (NaN) when the data contains no positive samples.
actual_positives = tp + fn
sensitivity = tp / actual_positives if actual_positives > 0 else float("nan")

print("Sensitivity:", sensitivity)


In [None]:
# prompt: calculate specificity

# Specificity (true negative rate) = TN / (TN + FP).
# `tn` and `fp` come from the confusion-matrix cell that runs before this one.
actual_negatives = tn + fp
specificity = tn / actual_negatives

print("Specificity:", specificity)

In [None]:
# prompt: calculate accuracy

# `external_y` holds the true labels and `predictions` the thresholded (0/1)
# model outputs. accuracy_score, classification_report and confusion_matrix
# are imported in earlier cells.

print("Accuracy:", accuracy_score(external_y, predictions))
print(classification_report(external_y, predictions))

# labels=[0, 1] forces a 2x2 matrix even when only one class appears in the
# data, so the four-way unpacking cannot fail. This assumes `external_y` is
# encoded as 0/1 like `predictions` - confirm.
tn, fp, fn, tp = confusion_matrix(external_y, predictions, labels=[0, 1]).ravel()

# Sensitivity (true positive rate); NaN when there are no positive samples.
actual_positives = tp + fn
sensitivity = tp / actual_positives if actual_positives > 0 else float("nan")

print("Sensitivity:", sensitivity)

# Specificity (true negative rate); NaN when there are no negative samples.
actual_negatives = tn + fp
specificity = tn / actual_negatives if actual_negatives > 0 else float("nan")

print("Specificity:", specificity)

In [None]:
# prompt: roc auc curve
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

# A ROC curve sweeps the decision threshold, so it needs the continuous
# sigmoid scores. `predictions` has already been thresholded to 0/1, which
# would collapse the curve to a single operating point, so the scores are
# recomputed from the model here.
model.eval()
with torch.no_grad():
    roc_scores = model(external_X_tensor).numpy().ravel()

fpr, tpr, thresholds = roc_curve(external_y, roc_scores)
roc_auc = auc(fpr, tpr)

plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
# Diagonal reference line: the expected curve of a random classifier.
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic')
plt.legend(loc="lower right")
plt.show()
