In [1]:
import os 
from Bio import SeqIO 
import numpy as np 
import random
from sklearn.model_selection import train_test_split 
from pathlib import Path
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, MaxPooling1D, Flatten, Dense, Dropout


In [2]:
#1 Collecting ARG Sequences (Positive Samples)
def load_fasta_sequences(filepath):
    sequences = []
    ids = []
    print(f"Loading sequences from: {filepath}")
    try:
        for record in SeqIO.parse(filepath, "fasta"):
            sequences.append(str(record.seq).upper()) #All sequences in uppercase
            ids.append(record.id)
        print(f"Successfully loaded {len(sequences)} sequences.")
    except FileNotFoundError:
        print(f"Error: File not found at {filepath}.")
    except Exception as e:
        print(f"An error occurred while parsing {filepath}: {e}")
    return ids, sequences

#File path
arg_fasta_path = "C:/Users/sspat/OneDrive/Desktop/card-data/nucleotide_fasta_protein_homolog_model.fasta"
arg_ids, arg_sequences = load_fasta_sequences(arg_fasta_path)

if not arg_sequences:
    print("No ARG sequences loaded.")
    avg_arg_len = 1000
    std_arg_len = 500
else:
    print(f"First 3 ARG IDs: {arg_ids[:3]}")
    print(f"Length of first ARG sequence: {len(arg_sequences[0])}")
    print(f"First 50 bases of first ARG sequence: {arg_sequences[0][:50]}...")
    arg_lengths = [len(s) for s in arg_sequences]
    avg_arg_len = np.mean(arg_lengths)
    std_arg_len = np.std(arg_lengths)
    print(f"Approximate Average ARG Length: {avg_arg_len:.2f} bp")
    print(f"Approximate Std Dev ARG Length: {std_arg_len:.2f} bp")

Loading sequences from: C:/Users/sspat/OneDrive/Desktop/card-data/nucleotide_fasta_protein_homolog_model.fasta
Successfully loaded 6052 sequences.
First 3 ARG IDs: ['gb|GQ343019.1|+|132-1023|ARO:3002999|CblA-1', 'gb|HQ845196.1|+|0-861|ARO:3001109|SHV-52', 'gb|AF028812.1|+|392-887|ARO:3002867|dfrF']
Length of first ARG sequence: 891
First 50 bases of first ARG sequence: ATGAAAGCATATTTCATCGCCATACTTACCTTATTCACTTGTATAGCTAC...
Approximate Average ARG Length: 963.81 bp
Approximate Std Dev ARG Length: 303.91 bp


In [3]:
#2 Preparing Non ARG
def generate_non_arg_fragments(genome_root_dir, num_target_fragments, avg_len, std_len):
    all_genomic_dna = []
    
    # Collect all genomic sequences
    for fna_file in Path(genome_root_dir).rglob('*_genomic.fna'):
        try:
            for record in SeqIO.parse(fna_file, "fasta"):
                all_genomic_dna.append(str(record.seq).upper())
        except:
            pass 

    if not all_genomic_dna: return [] # Return empty if no DNA loaded

    non_arg_fragments = []
    max_overall_len = max(len(s) for s in all_genomic_dna) if all_genomic_dna else 3000

    # Generate fragments
    while len(non_arg_fragments) < num_target_fragments:
        selected_seq = random.choice(all_genomic_dna)
        
        # Sample length and clamp
        frag_len = max(100, min(int(np.random.normal(loc=avg_len, scale=std_len)), 3000, len(selected_seq)))

        if len(selected_seq) < frag_len: continue # Skip if sequence too short
        start_pos = random.randint(0, len(selected_seq) - frag_len)
        fragment = selected_seq[start_pos : start_pos + frag_len]
        
        # Simple N filter
        if 'N' not in fragment or fragment.count('N') / len(fragment) <= 0.05:
            non_arg_fragments.append(fragment)

    print(f"Generated {len(non_arg_fragments)} non-ARG fragments.")
    return non_arg_fragments


if __name__ == "__main__":
    avg_arg_len = 963.81 
    std_arg_len = 303.91
    num_arg_sequences = 6052 
    
    #Path to Non ARG sequences folder
    genomic_data_root = "C:/Users/sspat/OneDrive/Desktop/Data combined/Data_combined/" 
    
    non_arg_sequences = generate_non_arg_fragments(
        genome_root_dir=genomic_data_root,
        num_target_fragments=num_arg_sequences,
        avg_len=avg_arg_len,
        std_len=std_arg_len
    )

    print(f"First 10 generated non-ARGs (lengths): {[len(s) for s in non_arg_sequences[:10]]}")

Generated 6052 non-ARG fragments.
First 10 generated non-ARGs (lengths): [1316, 865, 1021, 223, 1190, 557, 398, 951, 735, 1014]


In [4]:
#3 Data splitting for CNN
# Combine sequences and labels
all_sequences = arg_sequences + non_arg_sequences
all_labels = [1] * len(arg_sequences) + [0] * len(non_arg_sequences)

# Fixed MAX_SEQUENCE_LENGTH
MAX_SEQUENCE_LENGTH = 2000 

#3-way split: 70% Train, 15% Validation, 15% Test
X_temp, X_test, y_temp, y_test = train_test_split(
    all_sequences, all_labels, test_size=0.15, random_state=42, stratify=all_labels
)
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=(0.15 / (1 - 0.15)), random_state=42, stratify=y_temp
)

print("\nDataset splitting complete.")


Dataset splitting complete.


In [5]:
#4 Encoding
# Create a mapping from nucleotide characters to their one-hot vectors
DNA_TO_ONEHOT = {
    'A': [1, 0, 0, 0],
    'T': [0, 1, 0, 0],
    'C': [0, 0, 1, 0],
    'G': [0, 0, 0, 1],
    'N': [0, 0, 0, 0]
}

def one_hot_encode_sequence(sequence, max_len, dna_to_onehot_map):
    # Initialize an array of zeros with the target shape (max_len, 4 for ATCG)
    encoded_seq = np.zeros((max_len, len(dna_to_onehot_map['A'])), dtype=np.int8)
    
    # Iterate up to max_len or the actual sequence length, whichever is smaller
    for i, char in enumerate(sequence[:max_len]):
        if char in dna_to_onehot_map:
            encoded_seq[i] = dna_to_onehot_map[char]
        else:
            # Handle unexpected characters by treating them as 'N' (all zeros)
            encoded_seq[i] = dna_to_onehot_map['N']
            
    return encoded_seq

# Apply one-hot encoding
print(f"Encoding sequences to a fixed length of {MAX_SEQUENCE_LENGTH} bp")
X_train_encoded = np.array([one_hot_encode_sequence(s, MAX_SEQUENCE_LENGTH, DNA_TO_ONEHOT) for s in X_train])
X_val_encoded = np.array([one_hot_encode_sequence(s, MAX_SEQUENCE_LENGTH, DNA_TO_ONEHOT) for s in X_val])
X_test_encoded = np.array([one_hot_encode_sequence(s, MAX_SEQUENCE_LENGTH, DNA_TO_ONEHOT) for s in X_test])

# Convert labels to NumPy arrays
y_train_np = np.array(y_train)
y_val_np = np.array(y_val)
y_test_np = np.array(y_test)

print(f"Shape of X_train_encoded: {X_train_encoded.shape}")
print(f"Shape of X_val_encoded: {X_val_encoded.shape}")
print(f"Shape of X_test_encoded: {X_test_encoded.shape}")
print(f"Shape of y_train_np: {y_train_np.shape}")

Encoding sequences to a fixed length of 2000 bp
Shape of X_train_encoded: (8472, 2000, 4)
Shape of X_val_encoded: (1816, 2000, 4)
Shape of X_test_encoded: (1816, 2000, 4)
Shape of y_train_np: (8472,)


In [6]:
#5 CNN
# Model Definition
model = Sequential([
    Conv1D(128, 8, activation='relu', input_shape=(MAX_SEQUENCE_LENGTH, 4)),
    MaxPooling1D(2),
    Conv1D(64, 8, activation='relu'),
    MaxPooling1D(2),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(1, activation='sigmoid')
])

# Model Compilation
model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])

# Model Training
model.fit(
    X_train_encoded, y_train_np,
    epochs=50,
    batch_size=32,
    validation_data=(X_val_encoded, y_val_np),
    verbose=0 # Sets verbose to 0 to hide training progress per epoch
)


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


<keras.src.callbacks.history.History at 0x1b5b1f8c8c0>

In [7]:
#6  Model Evaluation
print(f"\nPart 4: Model Evaluation")

# verbose=0 means no detailed output during evaluation, just the final metrics
loss, accuracy = model.evaluate(X_test_encoded, y_test_np, verbose=0)
#Output
print(f"Test Loss: {loss:.4f}")
print(f"Test Accuracy: {accuracy:.4f}")
print("Model Evaluation Complete.")


Part 4: Model Evaluation
Test Loss: 0.1192
Test Accuracy: 0.9829
Model Evaluation Complete.


In [8]:
#7 Saving the model
print("\nSaving")
model_save_path = "C:/Users/sspat/OneDrive/Desktop/ARG Predictor.h5"
model.save(model_save_path)
print(f"Model saved successfully to: {model_save_path}")

#The New DNA sequences must be entered in the variables
new_dna_sequences = X_test[:5]
true_labels = y_test[:5]

#Encoding of the new sequence must be identical to the training of the model: A=[1,0,0,0], T=[0,1,0,0], C=[0,0,1,0], G=[0,0,0,1], N=[0,0,0,0]
new_dna_encoded = np.array([
    one_hot_encode_sequence(s, MAX_SEQUENCE_LENGTH, DNA_TO_ONEHOT) 
    for s in new_dna_sequences
])

#Predictions
predictions = model.predict(new_dna_encoded)

# Converts probabilities to binary classes (0 or 1)
# A common threshold is 0.5: if probability > 0.5, predict 1 (ARG), else 0 (Non-ARG)
predicted_classes = (predictions > 0.5).astype(int)

print(f"\nPredictions on new data:")
for i, (seq, prob, pred_class, true_class) in enumerate(zip(new_dna_sequences, predictions, predicted_classes, true_labels)):
    print(f"Seq {i+1} (truncated): {seq[:50]}...")
    print(f"  Predicted Probability: {prob[0]:.4f}")
    print(f"  Predicted Class: {pred_class[0]} {'(ARG)' if pred_class[0] == 1 else '(Non-ARG)'}")
    print(f"  True Class: {true_class} {'(ARG)' if true_class == 1 else '(Non-ARG)'}\n")

print("Prediction complete.")




Saving
Model saved successfully to: C:/Users/sspat/OneDrive/Desktop/ARG Predictor.h5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 316ms/step

Predictions on new data:
Seq 1 (truncated): ACACGGKTTCGGTCCTCCAGTGCGCTTTCCCGCACCTTCAACCTGGACAT...
  Predicted Probability: 0.0000
  Predicted Class: 0 (Non-ARG)
  True Class: 0 (Non-ARG)

Seq 2 (truncated): ATGTCCGCCACGCTCCACGACACCGCAGCGGATCGTCGGAAGGCCACCCG...
  Predicted Probability: 1.0000
  Predicted Class: 1 (ARG)
  True Class: 1 (ARG)

Seq 3 (truncated): ATGGCTGCAAGAGCGAAAAATGGCGTAATCGGTTGCGGTCCTAACATTCC...
  Predicted Probability: 1.0000
  Predicted Class: 1 (ARG)
  True Class: 1 (ARG)

Seq 4 (truncated): ATGATGAAAAAATCGATATGCTGCGCGCTGCTGCTGACAGCCTCTTTCTC...
  Predicted Probability: 1.0000
  Predicted Class: 1 (ARG)
  True Class: 1 (ARG)

Seq 5 (truncated): CTCGACGCCCCGGCCGCTCGACGGCGTGCCGCCGTTCGTCTGGCACGGCT...
  Predicted Probability: 0.0000
  Predicted Class: 0 (Non-ARG)
  True Class: 0 (Non-ARG)

Prediction complete.
