In [None]:
pip install gym
pip install stable-baselines3[extra]
pip install shimmy>=2.0

In [None]:
import gym
print(f"Gym version: {gym.__version__}")
import shimmy
print(f"Shimmy version: {shimmy.__version__}")

from sklearn.metrics import classification_report, roc_auc_score, roc_curve, confusion_matrix, matthews_corrcoef
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Read the three CSV splits (training, test, external evaluation) into DataFrames.
train_data, test_data, eval_data = (
    pd.read_csv(csv_name)
    for csv_name in (
        "train_Boarderline_smote_B_data.csv",
        "test_B_data.csv",
        "external_eval_B_data.csv",
    )
)

# Report the (rows, columns) shape of every split as a quick sanity check.
for split_label, split_frame in (
    ("Train", train_data),
    ("Test", test_data),
    ("Evaluation", eval_data),
):
    print(f"{split_label} Data Shape: {split_frame.shape}")

In [None]:
#Encode the Epitope Sequences:
from sklearn.preprocessing import LabelEncoder

# The first column of every split holds the sequence; gather those columns once
# so the encoder is fitted on, and later applied to, exactly the same data.
sequence_columns = [split.iloc[:, 0] for split in (train_data, test_data, eval_data)]
all_sequences = pd.concat(sequence_columns)

# One encoder fitted on the union of all splits, so every sequence that occurs
# in any split has an integer code and `transform` cannot hit an unseen value.
label_encoder = LabelEncoder().fit(all_sequences)

# Encode each split and shape the codes as a single-column 2-D array.
X_train_seq, X_test_seq, X_eval_seq = (
    label_encoder.transform(column).reshape(-1, 1) for column in sequence_columns
)

In [None]:
#Combine Encoded Sequences with Numeric Features:
import numpy as np

# Numeric features: every column from positional index 2 to the end of each split.
X_train_numeric, X_test_numeric, X_eval_numeric = (
    split.iloc[:, 2:].values for split in (train_data, test_data, eval_data)
)

# Prepend the encoded-sequence column to the numeric block (column-wise join).
X_train = np.concatenate([X_train_seq, X_train_numeric], axis=1)
X_test = np.concatenate([X_test_seq, X_test_numeric], axis=1)
X_eval = np.concatenate([X_eval_seq, X_eval_numeric], axis=1)

# Target labels are taken from positional column index 1 of each split.
y_train, y_test, y_eval = (
    split.iloc[:, 1].values for split in (train_data, test_data, eval_data)
)

# Print feature/label shapes side by side so row-count mismatches are visible.
for split_tag, split_features, split_labels in (
    ("train", X_train, y_train),
    ("test", X_test, y_test),
    ("eval", X_eval, y_eval),
):
    print(f"X_{split_tag} Shape: {split_features.shape}, y_{split_tag} Shape: {split_labels.shape}")

In [None]:
# Give every split the layout (samples, features, 1).
# The feature count is read from the arrays instead of being hard-coded to 767:
# with a fixed literal, a split whose column count differs could be silently
# re-chunked into wrong rows whenever its total size happened to be divisible
# by 767. Using each array's own leading dimensions also makes the cell safe
# to re-run on arrays that already carry the trailing axis.
assert X_train.shape[1] == X_test.shape[1] == X_eval.shape[1], (
    "train/test/eval splits must have the same number of feature columns"
)
X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)
X_test = X_test.reshape(X_test.shape[0], X_test.shape[1], 1)
X_eval = X_eval.reshape(X_eval.shape[0], X_eval.shape[1], 1)

In [None]:
import numpy as np
from gym import Env
from gym.spaces import Discrete, Box

class EpitopesEnv(Env):
    """Single-pass classification environment over a labelled feature array.

    Each step presents one sample's flattened feature vector as the observation.
    The agent's action (0 or 1) is compared with that sample's label and earns
    a reward of +1 on a match and -1 otherwise. An episode visits the samples
    in index order and ends after the last one.

    The class follows the classic Gym API: ``reset()`` returns only the
    observation and ``step()`` returns ``(obs, reward, done, info)``.
    """

    def __init__(self, features, labels):
        """Store the data and declare the action/observation spaces.

        Args:
            features: NumPy array whose first axis indexes samples; everything
                after the first axis is flattened into one observation vector.
            labels: Per-sample labels, indexed the same way as ``features``.

        Raises:
            ValueError: If ``labels`` does not contain one entry per sample.
        """
        super().__init__()
        self.features = features  # Features of epitopes (X)
        self.labels = labels      # Labels of epitopes (y)
        self.n_samples = features.shape[0]
        self.current_idx = 0      # Pointer to the current sample

        if len(labels) != self.n_samples:
            raise ValueError(
                f"features has {self.n_samples} samples but labels has {len(labels)}"
            )

        # Action space: Classify as 0 (negative) or 1 (positive)
        self.action_space = Discrete(2)

        # Observation space: one flattened sample. The size is the product of all
        # non-sample axes, so both (n, d) and (n, d, 1) inputs describe d values.
        obs_dim = int(np.prod(features.shape[1:]))
        self.observation_space = Box(
            low=np.min(features), high=np.max(features), shape=(obs_dim,), dtype=np.float32
        )

        self.seed_val = None  # To store the seed value

    def _observation(self, idx):
        """Return sample ``idx`` as a 1-D float32 array matching ``observation_space``."""
        return np.asarray(self.features[idx], dtype=np.float32).flatten()

    def reset(self):
        """Rewind to the first sample and return its observation."""
        self.current_idx = 0
        return self._observation(self.current_idx)

    def step(self, action):
        """Score ``action`` against the current label and advance one sample.

        Returns:
            Tuple ``(next_state, reward, done, info)``. ``done`` is True once the
            last sample has been scored. On that final step there is no next
            sample, so the last sample's observation is returned again: callers
            always receive a valid array rather than ``None``.
        """
        # Reward for correct classification
        reward = 1 if action == self.labels[self.current_idx] else -1

        # Move to the next sample
        self.current_idx += 1

        # Check if the dataset is exhausted
        done = self.current_idx >= self.n_samples

        # Next state (repeat the final sample when the episode has just ended)
        next_idx = self.n_samples - 1 if done else self.current_idx
        next_state = self._observation(next_idx)

        return next_state, reward, done, {}

    def render(self, mode="human"):
        """No-op; ``mode`` is accepted so callers may pass the classic Gym argument."""
        pass

    def seed(self, seed=None):
        """Set the seed for reproducibility.

        Stores the seed, seeds NumPy's global RNG, and returns ``[seed]``.
        """
        self.seed_val = seed
        np.random.seed(seed)
        return [seed]

In [None]:
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_vec_env
# Build a vectorized wrapper holding one EpitopesEnv over the training split.
env = make_vec_env(
    lambda: EpitopesEnv(features=X_train, labels=y_train),
    n_envs=1,
)

In [None]:
# Configure the DQN agent (MLP Q-network) on the vectorized training environment.
model = DQN(
    'MlpPolicy',
    env,
    verbose=1,
    learning_rate=0.0001,  # Fine-grained updates
    gamma=0.95,  # Balances immediate vs. future rewards
    batch_size=64,  # Stabilizes updates
    exploration_fraction=0.2,  # Standard exploration
    exploration_final_eps=0.01,  # Reduced exploration at the end
    target_update_interval=1000,  # Frequent updates for responsiveness
    train_freq=(1, 'step'),  # Update after every step
    gradient_steps=1,  # Backpropagation steps per update
    policy_kwargs={"net_arch": [256, 128, 128]}  # Deeper neural network architecture
)

# Train the model
model.learn(total_timesteps=100000)  # Extended training


def _greedy_actions(agent, features, batch_size=4096):
    """Return the agent's greedy action for every sample in ``features``.

    Samples are flattened to one row each and sent through ``agent.predict`` in
    batches. ``deterministic=True`` matters: without it DQN's ``predict`` stays
    epsilon-greedy, so a fraction of evaluation actions would be random.
    """
    flat = np.asarray(features, dtype=np.float32).reshape(len(features), -1)
    chunks = [
        agent.predict(flat[start:start + batch_size], deterministic=True)[0]
        for start in range(0, len(flat), batch_size)
    ]
    return np.concatenate(chunks) if chunks else np.empty(0, dtype=np.int64)


# Evaluate on test data
correct_test = int(np.sum(_greedy_actions(model, X_test) == y_test))
accuracy_test = correct_test / len(X_test)
print(f"Test Accuracy: {accuracy_test * 100:.2f}%")

# Evaluate on validation data
correct_eval = int(np.sum(_greedy_actions(model, X_eval) == y_eval))
accuracy_eval = correct_eval / len(X_eval)
print(f"Validation Accuracy: {accuracy_eval * 100:.2f}%")

In [None]:
# Persist the trained agent to disk under a fixed base name and confirm it.
model_path = "dqn_epitope_bcell_classifier"
model.save(model_path)
print(f"DQN model saved as '{model_path}'")

In [None]:
from sklearn.metrics import confusion_matrix, f1_score

def _greedy_actions(agent, features, batch_size=4096):
    """Return the agent's greedy action for every sample in ``features``.

    Samples are flattened to one row each and sent through ``agent.predict`` in
    batches. ``deterministic=True`` matters: without it DQN's ``predict`` stays
    epsilon-greedy, so a fraction of the predicted labels would be random and
    the metrics computed from them would change from run to run.
    """
    flat = np.asarray(features, dtype=np.float32).reshape(len(features), -1)
    chunks = [
        agent.predict(flat[start:start + batch_size], deterministic=True)[0]
        for start in range(0, len(flat), batch_size)
    ]
    return np.concatenate(chunks) if chunks else np.empty(0, dtype=np.int64)


# Generate predictions for test and validation sets (1-D integer arrays)
y_test_pred = _greedy_actions(model, X_test)
y_eval_pred = _greedy_actions(model, X_eval)

In [None]:
# `labels=[0, 1]` pins the matrix to 2x2 even when the predictions (or the
# ground truth) contain only one class; without it the matrix can collapse to
# 1x1 and the four-way unpacking below raises.
# Flattened order for binary labels [0, 1] is: tn, fp, fn, tp.

# Confusion matrix for Test Data
tn_test, fp_test, fn_test, tp_test = confusion_matrix(
    y_test, y_test_pred, labels=[0, 1]
).ravel()

# Confusion matrix for Validation Data
tn_eval, fp_eval, fn_eval, tp_eval = confusion_matrix(
    y_eval, y_eval_pred, labels=[0, 1]
).ravel()

In [None]:
# --- Test split: F1, sensitivity (recall on positives), specificity (recall on negatives)
f1_test = f1_score(y_test, y_test_pred)
sensitivity_test = tp_test / (fn_test + tp_test)
specificity_test = tn_test / (fp_test + tn_test)

# --- External validation split: same three metrics
f1_eval = f1_score(y_eval, y_eval_pred)
sensitivity_eval = tp_eval / (fn_eval + tp_eval)
specificity_eval = tn_eval / (fp_eval + tn_eval)

# Print every metric with two decimals, in a fixed order.
metric_rows = (
    ("Test F1 Score", f1_test),
    ("External Validation F1 Score", f1_eval),
    ("Test Sensitivity", sensitivity_test),
    ("Test Specificity", specificity_test),
    ("External Validation Sensitivity", sensitivity_eval),
    ("External Validation Specificity", specificity_eval),
)
for metric_name, metric_value in metric_rows:
    print(f"{metric_name}: {metric_value:.2f}")

In [None]:
# (display name, ground truth, predictions) for each evaluated split.
evaluated_splits = (
    ("Test", y_test, y_test_pred),
    ("Validation", y_eval, y_eval_pred),
)

# Per-class precision / recall / F1 tables, one per split.
for split_name, split_truth, split_pred in evaluated_splits:
    print(f"\n{split_name} Data - Classification Report:\n", classification_report(split_truth, split_pred))

# Raw confusion matrices, one per split.
for split_name, split_truth, split_pred in evaluated_splits:
    print(f"\n{split_name} Data - Confusion Matrix:\n", confusion_matrix(split_truth, split_pred))

In [None]:
# Inspect the Q-values the trained network assigns to the first test sample.
# `model.predict` returns an `(action, state)` tuple rather than Q-values, so
# the Q-network is queried directly: the policy converts the observation to a
# batched tensor on the right device, and the network output is copied back to
# NumPy (one row per observation, one column per action).
obs = X_test[0].flatten()
obs_tensor, _ = model.policy.obs_to_tensor(obs)
q_values = model.q_net(obs_tensor).detach().cpu().numpy()
print("Q-values type:", type(q_values))
print("Q-values content:", q_values)

In [None]:
from sklearn.metrics import roc_auc_score


def _q_values(agent, features, batch_size=4096):
    """Return the Q-network outputs for ``features`` (one row per sample).

    Samples are flattened to one row each and evaluated in batches; the result
    has one column per action.
    """
    agent.policy.set_training_mode(False)
    flat = np.asarray(features, dtype=np.float32).reshape(len(features), -1)
    rows = []
    for start in range(0, len(flat), batch_size):
        obs_tensor, _ = agent.policy.obs_to_tensor(flat[start:start + batch_size])
        rows.append(agent.q_net(obs_tensor).detach().cpu().numpy())
    if not rows:
        return np.empty((0, int(agent.action_space.n)), dtype=np.float32)
    return np.concatenate(rows)


q_test = _q_values(model, X_test)
q_eval = _q_values(model, X_eval)

# Map predictions to binary labels (0 or 1): the greedy action is the argmax
# over Q-values, which is what `model.predict(..., deterministic=True)` selects.
y_test_pred = q_test.argmax(axis=1)
y_eval_pred = q_eval.argmax(axis=1)

# Compute ROC-AUC from a continuous score instead of the thresholded labels:
# hard 0/1 predictions reduce the ROC curve to a single operating point. The
# margin Q(s, 1) - Q(s, 0) ranks samples by how strongly class 1 is preferred.
roc_auc_test = roc_auc_score(y_test, q_test[:, 1] - q_test[:, 0])
roc_auc_eval = roc_auc_score(y_eval, q_eval[:, 1] - q_eval[:, 0])

print(f"Test ROC-AUC: {roc_auc_test:.2f}")
print(f"Validation ROC-AUC: {roc_auc_eval:.2f}")

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import auc

# Confusion-matrix counts are computed from the current predictions instead of
# being pasted in as literals, so the figure always matches the model that was
# actually trained in this session. `labels=[0, 1]` keeps the matrix 2x2 even if
# only one class is predicted; flattened order is tn, fp, fn, tp.
tn_test, fp_test, fn_test, tp_test = confusion_matrix(
    y_test, np.asarray(y_test_pred).ravel(), labels=[0, 1]
).ravel()
tn_eval, fp_eval, fn_eval, tp_eval = confusion_matrix(
    y_eval, np.asarray(y_eval_pred).ravel(), labels=[0, 1]
).ravel()

# Calculate TPR (Sensitivity) and FPR for test and validation
tpr_test = float(tp_test / (tp_test + fn_test))
fpr_test = float(fp_test / (fp_test + tn_test))

tpr_eval = float(tp_eval / (tp_eval + fn_eval))
fpr_eval = float(fp_eval / (fp_eval + tn_eval))

# Simulate ROC Curves: hard predictions give one operating point, so each curve
# is the three-point polyline (0, 0) -> (FPR, TPR) -> (1, 1).
fpr_test_sim = [0, fpr_test, 1]
tpr_test_sim = [0, tpr_test, 1]

fpr_eval_sim = [0, fpr_eval, 1]
tpr_eval_sim = [0, tpr_eval, 1]

# Compute AUC for test and validation (trapezoidal area under the polyline)
auc_test = auc(fpr_test_sim, tpr_test_sim)
auc_eval = auc(fpr_eval_sim, tpr_eval_sim)

# Print ROC data for Test
print("Test Data - ROC Curve Values")
print("False Positive Rate (FPR):", fpr_test_sim)
print("True Positive Rate (TPR):", tpr_test_sim)
print("Thresholds: [1, ~midpoint, 0]")  # Simulated thresholds for the three points
print(f"ROC AUC: {auc_test:.2f}\n")

# Print ROC data for Validation
print("Validation Data - ROC Curve Values")
print("False Positive Rate (FPR):", fpr_eval_sim)
print("True Positive Rate (TPR):", tpr_eval_sim)
print("Thresholds: [1, ~midpoint, 0]")  # Simulated thresholds for the three points
print(f"ROC AUC: {auc_eval:.2f}\n")

# Plot ROC Curves side by side (left: test, right: validation)
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
roc_panels = (
    (axes[0], fpr_test_sim, tpr_test_sim, auc_test, 'blue', 'Test Data - ROC Curve'),
    (axes[1], fpr_eval_sim, tpr_eval_sim, auc_eval, 'green', 'Validation Data - ROC Curve'),
)
for panel_ax, panel_fpr, panel_tpr, panel_auc, panel_color, panel_title in roc_panels:
    panel_ax.plot(panel_fpr, panel_tpr, color=panel_color, label=f'ROC AUC = {panel_auc:.2f}')
    panel_ax.plot([0, 1], [0, 1], color='gray', linestyle='--', label='Random Classifier (AUC = 0.50)')
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('False Positive Rate')
    panel_ax.set_ylabel('True Positive Rate')
    panel_ax.legend(loc='lower right')

fig.tight_layout()

# Save BEFORE showing: once `plt.show()` has displayed the figure in a notebook
# it is closed, and a later `plt.savefig` would write an empty canvas.
fig.savefig("dqn_roc_auc_curve.png", dpi=500)
plt.show()


In [None]:
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_curve, auc, accuracy_score
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import softmax

In [None]:
# Load the saved model
model = DQN.load("dqn_epitope_bcell_classifier")
print("DQN model loaded.")

# Initialize StratifiedKFold for 10 splits
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)

# Variables to store results
tprs = []
aucs = []
mean_fpr = np.linspace(0, 1, 100)
accuracies = []


def _q_values(agent, features, batch_size=4096):
    """Return the Q-network outputs for ``features`` (one row per sample).

    Samples are flattened to one row each and evaluated in batches; the result
    has one column per action.
    """
    agent.policy.set_training_mode(False)
    flat = np.asarray(features, dtype=np.float32).reshape(len(features), -1)
    rows = []
    for start in range(0, len(flat), batch_size):
        obs_tensor, _ = agent.policy.obs_to_tensor(flat[start:start + batch_size])
        rows.append(agent.q_net(obs_tensor).detach().cpu().numpy())
    if not rows:
        return np.empty((0, int(agent.action_space.n)), dtype=np.float32)
    return np.concatenate(rows)


# Score the loaded model on each stratified fold of the training data.
# NOTE: the model is not re-trained per fold, so this measures one fixed model
# on ten subsets of the data it was trained on, not out-of-fold performance.
for i, (train_idx, test_idx) in enumerate(cv.split(X_train, y_train)):
    y_true = y_train[test_idx]

    # The second value returned by `model.predict` is the recurrent state
    # (None here), not Q-values, so reading it as Q-values always fell through
    # to a constant fallback. Query the Q-network directly instead.
    q_fold = _q_values(model, X_train[test_idx])

    # Normalize Q-values per sample using softmax; column 1 is the class-1 score
    y_prob = softmax(q_fold, axis=1)[:, 1]

    # Compute accuracy (score > 0.5 is the same decision as the greedy action)
    y_pred = (y_prob > 0.5).astype(int)
    accuracy = accuracy_score(y_true, y_pred)
    accuracies.append(accuracy)

In [None]:
def _q_values(agent, features, batch_size=4096):
    """Return the Q-network outputs for ``features`` (one row per sample).

    Samples are flattened to one row each and evaluated in batches; the result
    has one column per action.
    """
    agent.policy.set_training_mode(False)
    flat = np.asarray(features, dtype=np.float32).reshape(len(features), -1)
    rows = []
    for start in range(0, len(flat), batch_size):
        obs_tensor, _ = agent.policy.obs_to_tensor(flat[start:start + batch_size])
        rows.append(agent.q_net(obs_tensor).detach().cpu().numpy())
    if not rows:
        return np.empty((0, int(agent.action_space.n)), dtype=np.float32)
    return np.concatenate(rows)


# Start from empty accumulators so re-running this cell does not append a
# second set of folds to the previous run's results.
tprs = []
aucs = []

fig, ax = plt.subplots()

# Compute ROC curve and AUC per fold.
# NOTE: the model is not re-trained per fold; each fold is a stratified subset
# of the training data scored by the same fixed model.
for i, (train_idx, test_idx) in enumerate(cv.split(X_train, y_train)):
    y_true = y_train[test_idx]

    # The second value returned by `model.predict` is the recurrent state
    # (None here), not Q-values; reading it as Q-values gave every sample the
    # same fallback score. Query the Q-network directly instead.
    q_fold = _q_values(model, X_train[test_idx])

    # Normalize Q-values per sample using softmax; column 1 is the class-1 score
    y_prob = softmax(q_fold, axis=1)[:, 1]

    # Compute ROC curve and AUC, and resample the curve onto the shared FPR grid
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    tprs.append(np.interp(mean_fpr, fpr, tpr))
    tprs[-1][0] = 0.0
    roc_auc = auc(fpr, tpr)
    aucs.append(roc_auc)

    # Plot each fold's ROC curve
    ax.plot(fpr, tpr, lw=1, alpha=0.3, label=f'Fold {i+1} (AUC = {roc_auc:.2f})')

# Compute mean ROC curve and AUC (pin the end point to TPR = 1 at FPR = 1)
mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0
mean_auc = auc(mean_fpr, mean_tpr)

# Plot mean ROC curve
ax.plot(mean_fpr, mean_tpr, color='b', label=f'Mean ROC (AUC = {mean_auc:.2f})', lw=2)
ax.plot([0, 1], [0, 1], color='gray', linestyle='--')
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('10-Fold Cross-Validation ROC Curve')
ax.legend(loc="lower right")

# Save the ROC curve (before showing, so the saved file is not blank)
fig.savefig("cv_roc_curve_bcell_dqn.png", dpi=500)
plt.show()

In [None]:
# Average the per-fold accuracies and report the result as a percentage.
mean_accuracy = np.asarray(accuracies).mean()
print(f"Mean Accuracy: {mean_accuracy * 100:.2f}%")