# In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
import matplotlib.pyplot as plt

# ============================================================================
# 1. DATA PREPARATION
# ============================================================================

class ProteinDataProcessor:
    """Encode amino acid sequences and secondary-structure labels as one-hot arrays.

    Attributes (set in ``__init__``):
        amino_acids: the 20-letter residue alphabet; a letter's position in
            this string is its one-hot column.
        aa_to_idx: residue letter -> column index.
        structures: the 3-state label alphabet ``'HEC'``.
        struct_to_idx: label letter -> column index.
    """

    def __init__(self):
        # 20 standard amino acids
        self.amino_acids = 'ACDEFGHIKLMNPQRSTVWY'
        self.aa_to_idx = {aa: idx for idx, aa in enumerate(self.amino_acids)}

        # 3-state secondary structure: H (helix), E (sheet), C (coil)
        self.structures = 'HEC'
        self.struct_to_idx = {s: idx for idx, s in enumerate(self.structures)}

    @staticmethod
    def _one_hot(symbols, symbol_to_idx, width):
        """Return a (len(symbols), width) one-hot array; unknown symbols stay all-zero."""
        encoded = np.zeros((len(symbols), width))
        for position, symbol in enumerate(symbols):
            column = symbol_to_idx.get(symbol)
            if column is not None:
                encoded[position, column] = 1
        return encoded

    def encode_sequence(self, sequence):
        """One-hot encode an amino acid sequence.

        Returns an array of shape (len(sequence), 20). Letters outside
        ``self.amino_acids`` (e.g. 'X' or lowercase) produce an all-zero row.
        """
        return self._one_hot(sequence, self.aa_to_idx, len(self.amino_acids))

    def encode_structure(self, structure):
        """One-hot encode secondary structure labels.

        Returns an array of shape (len(structure), 3). Labels outside
        ``self.structures`` produce an all-zero row.
        """
        return self._one_hot(structure, self.struct_to_idx, len(self.structures))

    def decode_structure(self, encoded):
        """Convert one-hot (or probability) rows back to a label string.

        Each row is mapped to the label with the largest value. Note that an
        all-zero row (padding or an unknown label) decodes to 'H', because
        argmax returns the first index on ties; slice off padded positions
        before decoding if that matters.
        """
        indices = np.argmax(encoded, axis=-1)
        return ''.join([self.structures[idx] for idx in indices])

    def prepare_dataset(self, sequences, structures, max_length=None):
        """Build zero-padded arrays for training.

        Args:
            sequences: sized collection of amino acid strings.
            structures: label strings, one per sequence, each the same length
                as its sequence.
            max_length: padded length; longer inputs are truncated. Defaults
                to the longest sequence (0 for an empty dataset).

        Returns:
            (X, y, masks) with shapes (n, max_length, 20), (n, max_length, 3)
            and (n, max_length). ``masks`` is 1 on real positions and 0 on
            padding.

        Raises:
            ValueError: if the number of sequences and structures differ, or
                a structure string is not the same length as its sequence.
        """
        structures = list(structures)
        if len(sequences) != len(structures):
            # zip() would otherwise drop the unmatched tail silently.
            raise ValueError(
                f"Got {len(sequences)} sequences but {len(structures)} structures"
            )

        if max_length is None:
            max_length = max((len(seq) for seq in sequences), default=0)

        X = np.zeros((len(sequences), max_length, len(self.amino_acids)))
        y = np.zeros((len(sequences), max_length, len(self.structures)))
        masks = np.zeros((len(sequences), max_length))

        for i, (seq, struct) in enumerate(zip(sequences, structures)):
            if len(seq) != len(struct):
                # Per-residue labels must align one-to-one; without this check
                # a length-1 label string would be broadcast over the whole
                # sequence and a longer one silently truncated.
                raise ValueError(
                    f"Sequence {i} has length {len(seq)} but its structure "
                    f"has length {len(struct)}"
                )
            seq_len = min(len(seq), max_length)
            X[i, :seq_len] = self.encode_sequence(seq[:seq_len])
            y[i, :seq_len] = self.encode_structure(struct[:seq_len])
            masks[i, :seq_len] = 1

        return X, y, masks

# ============================================================================
# 2. MODEL ARCHITECTURE
# ============================================================================

def build_1d_cnn_model(input_shape, num_classes=3):
    """
    Assemble the 1D CNN used for per-residue secondary structure prediction.

    The network is a stack of four identical-shaped blocks
    (Conv1D + ReLU -> BatchNormalization -> Dropout(0.3)) followed by a
    kernel-size-1 convolution with softmax, which acts as a classifier
    applied independently at every sequence position.

    Args:
        input_shape: shape of one sample passed to ``keras.Input``; the
            caller in this file passes (sequence_length, n_amino_acids).
        num_classes: number of output channels of the final softmax layer.

    Returns:
        An uncompiled ``keras.Model`` named '1D_CNN_SecStruct'.
    """
    # (filters, kernel_size) of each convolutional block, in stacking order.
    block_specs = ((128, 7), (128, 7), (256, 5), (256, 5))

    inputs = keras.Input(shape=input_shape)

    features = inputs
    for n_filters, width in block_specs:
        conv = layers.Conv1D(
            filters=n_filters,
            kernel_size=width,
            padding='same',
            activation='relu',
        )
        features = conv(features)
        features = layers.BatchNormalization()(features)
        features = layers.Dropout(0.3)(features)

    # Position-wise classifier: one softmax over the classes per residue.
    classifier = layers.Conv1D(filters=num_classes, kernel_size=1, activation='softmax')
    outputs = classifier(features)

    return keras.Model(inputs=inputs, outputs=outputs, name='1D_CNN_SecStruct')

# ============================================================================
# 3. GENERATE SAMPLE DATA (Replace with your real dataset)
# ============================================================================

def generate_sample_data(n_samples=1000, seq_length=50):
    """Create a reproducible synthetic dataset for demonstration purposes.

    Reseeds NumPy's global random generator with 42 on every call, so the
    same arguments always yield the same data.

    Args:
        n_samples: number of sequences to generate.
        seq_length: length of every generated sequence.

    Returns:
        (sequences, struct_labels): two lists of strings of equal length;
        each label string has one of 'H', 'E', 'C' per residue.
    """
    np.random.seed(42)

    amino_acids = 'ACDEFGHIKLMNPQRSTVWY'
    structures = 'HEC'

    residue_pool = list(amino_acids)
    states = list(structures)

    # Label probabilities over (H, E, C) per residue class. This is a toy
    # rule, not a realistic one; real labels would come from PDB/DSSP.
    fallback_probs = [0.33, 0.33, 0.34]
    probs_by_residue = {}
    probs_by_residue.update(dict.fromkeys('AVILM', [0.5, 0.2, 0.3]))  # hydrophobic: helix-biased
    probs_by_residue.update(dict.fromkeys('KRE', [0.2, 0.2, 0.6]))    # charged: coil-biased

    sequences = []
    struct_labels = []

    for _ in range(n_samples):
        # Draw the whole sequence first, then one label per residue, so the
        # random stream is consumed in a fixed order.
        seq = ''.join(np.random.choice(residue_pool, size=seq_length))
        labels = [
            np.random.choice(states, p=probs_by_residue.get(residue, fallback_probs))
            for residue in seq
        ]

        sequences.append(seq)
        struct_labels.append(''.join(labels))

    return sequences, struct_labels

# ============================================================================
# 4. TRAINING PIPELINE
# ============================================================================

def _masked_class_indices(y_true, y_pred, masks):
    """Return flat (true, predicted) class indices at positions where masks == 1."""
    valid = masks == 1
    return np.argmax(y_true, axis=-1)[valid], np.argmax(y_pred, axis=-1)[valid]


def _save_history_plot(history, path):
    """Plot loss and accuracy curves from a Keras History object and save them to path."""
    panels = (
        ('loss', 'Loss', 'Training and Validation Loss'),
        ('accuracy', 'Accuracy', 'Training and Validation Accuracy'),
    )
    fig = plt.figure(figsize=(12, 4))
    for position, (key, label, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(history.history[key], label=f'Training {label}')
        plt.plot(history.history[f'val_{key}'], label=f'Validation {label}')
        plt.xlabel('Epoch')
        plt.ylabel(label)
        plt.legend()
        plt.title(title)
    plt.tight_layout()
    plt.savefig(path)
    # Release the figure so repeated runs do not accumulate open figures.
    plt.close(fig)


def _print_example_predictions(processor, X, y_true, y_pred, masks, n_examples=3):
    """Print sequence, true labels and predicted labels for the first few samples."""
    for i in range(min(n_examples, len(X))):
        # Only the first seq_len positions are real; the rest is zero padding,
        # which would otherwise decode as spurious 'A' residues / 'H' labels.
        seq_len = int(masks[i].sum())
        residue_idx = np.argmax(X[i, :seq_len], axis=-1)
        seq_decoded = ''.join(processor.amino_acids[idx] for idx in residue_idx)
        true_struct = processor.decode_structure(y_true[i, :seq_len])
        pred_struct = processor.decode_structure(y_pred[i, :seq_len])

        print(f"\n   Example {i+1}:")
        print(f"   Sequence:  {seq_decoded[:40]}...")
        print(f"   True:      {true_struct[:40]}...")
        print(f"   Predicted: {pred_struct[:40]}...")


def train_model():
    """Run the complete demo pipeline: data, model, training, evaluation.

    Generates synthetic data, trains the 1D CNN for 20 epochs, prints
    per-residue accuracy and a classification report on the held-out split,
    saves the training curves to 'training_history.png' and prints a few
    example predictions.

    Returns:
        (model, processor, history): the trained Keras model, the
        ProteinDataProcessor used for encoding, and the Keras History
        returned by ``model.fit``.
    """

    print("=" * 60)
    print("1D CNN SECONDARY STRUCTURE PREDICTION")
    print("=" * 60)

    # Generate sample data (replace with real data loading)
    print("\n1. Loading data...")
    sequences, structures = generate_sample_data(n_samples=1000, seq_length=50)
    print(f"   Loaded {len(sequences)} sequences")
    print(f"   Example sequence: {sequences[0][:30]}...")
    print(f"   Example structure: {structures[0][:30]}...")

    # Prepare data
    print("\n2. Preparing dataset...")
    processor = ProteinDataProcessor()
    X, y, masks = processor.prepare_dataset(sequences, structures)
    print(f"   X shape: {X.shape}")
    print(f"   y shape: {y.shape}")

    # Split data
    X_train, X_test, y_train, y_test, mask_train, mask_test = train_test_split(
        X, y, masks, test_size=0.2, random_state=42
    )
    print(f"   Training samples: {len(X_train)}")
    print(f"   Test samples: {len(X_test)}")

    # Build model
    print("\n3. Building model...")
    model = build_1d_cnn_model(input_shape=(X.shape[1], X.shape[2]))
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    # summary() prints the table itself and returns None, so it must not be
    # wrapped in print().
    model.summary()

    # Train model
    # NOTE(review): the test split doubles as validation data, and the masks
    # are not applied during fitting, so padded positions (if any) count
    # towards the Keras loss/accuracy. The demo data has no padding.
    print("\n4. Training model...")
    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=20,
        batch_size=32,
        verbose=1
    )

    # Evaluate
    print("\n5. Evaluating model...")
    y_pred = model.predict(X_test)

    # Calculate per-residue accuracy (considering only non-padded positions)
    y_true_flat, y_pred_flat = _masked_class_indices(y_test, y_pred, mask_test)

    accuracy = accuracy_score(y_true_flat, y_pred_flat)
    print(f"\n   Overall accuracy: {accuracy:.4f}")

    # Classification report
    # Passing labels explicitly keeps the report aligned with target_names
    # even when one class never occurs in the test split.
    target_names = ['Helix (H)', 'Sheet (E)', 'Coil (C)']
    print("\n   Classification Report:")
    print(classification_report(
        y_true_flat, y_pred_flat,
        labels=list(range(len(target_names))),
        target_names=target_names,
    ))

    # Plot training history
    _save_history_plot(history, 'training_history.png')
    print("\n   Training plots saved to 'training_history.png'")

    # Example predictions (reuse y_pred instead of predicting again)
    print("\n6. Example predictions:")
    _print_example_predictions(processor, X_test, y_test, y_pred, mask_test)

    return model, processor, history

# ============================================================================
# 5. RUN TRAINING
# ============================================================================

if __name__ == "__main__":
    # Run the whole demo; the results stay bound at module level so they can
    # be inspected afterwards in an interactive session.
    model, processor, history = train_model()

    print("\n" + "=" * 60)
    print("Training complete!")
    print("=" * 60)
    print("\nNext steps:")
    # Follow-up suggestions, printed one per line.
    for suggestion in (
        "- Replace generate_sample_data() with real PDB/DSSP data",
        "- Experiment with different architectures (kernel sizes, filters)",
        "- Try data augmentation or class weighting",
        "- Add bidirectional context with dilated convolutions",
        "- Compare with BiLSTM or Transformer models",
    ):
        print(suggestion)