In [1]:
# Absolute, machine-specific Windows paths to the two input FASTA files.
# main() passes them to PromoterDataset, which labels records from the first
# file 1 and records from the second file 0. Edit these before running on
# another machine.
positive_fasta = r'C:\Rishabh\IISER\Semesters\Semester 7\Computational Functional Genomics\Project\FASTA Files\pos_sample.fa'
negative_fasta = r'C:\Rishabh\IISER\Semesters\Semester 7\Computational Functional Genomics\Project\FASTA Files\neg_sample.fa'

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from Bio import SeqIO
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
from scipy import interpolate

# DNA sequence to one-hot encoding
def seq_to_onehot(seq, seq_len=1000):
    """Encode a nucleotide string as a ``(4, seq_len)`` one-hot matrix.

    The sequence is upper-cased, cut to ``seq_len`` characters and, when
    shorter, right-padded with ``'N'``. Rows are ordered A, T, G, C; any
    other character (including the ``'N'`` padding) leaves its column all
    zeros.

    Args:
        seq: Nucleotide sequence as a string.
        seq_len: Number of columns in the returned matrix.

    Returns:
        A float NumPy array of shape ``(4, seq_len)``.
    """
    row_for_base = {'A': 0, 'T': 1, 'G': 2, 'C': 3}
    encoded = np.zeros((4, seq_len))

    # Normalise case and force the sequence to exactly seq_len characters.
    fixed = seq.upper()[:seq_len].ljust(seq_len, 'N')

    for column, base in enumerate(fixed):
        row = row_for_base.get(base)
        if row is not None:
            encoded[row, column] = 1
    return encoded

class PromoterDataset(Dataset):
    """Dataset of labelled DNA sequences read from two FASTA files.

    Records from ``pos_fasta`` receive label 1 and records from ``neg_fasta``
    receive label 0. Sequences are stored as plain strings and are one-hot
    encoded only when an item is requested.
    """

    def __init__(self, pos_fasta, neg_fasta, seq_len=1000):
        """Read both FASTA files into memory.

        Args:
            pos_fasta: Path (or handle) of the FASTA file with label-1 records.
            neg_fasta: Path (or handle) of the FASTA file with label-0 records.
            seq_len: Length every sequence is cut/padded to when encoded.
        """
        self.seq_len = seq_len
        self.sequences = []
        self.labels = []

        # Positives are read first, then negatives, so items keep that order.
        for fasta_source, class_label in ((pos_fasta, 1), (neg_fasta, 0)):
            for record in SeqIO.parse(fasta_source, "fasta"):
                self.sequences.append(str(record.seq))
                self.labels.append(class_label)

    def __len__(self):
        """Return the number of loaded sequences."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return ``(one_hot_float_tensor, long_label_tensor)`` for item ``idx``."""
        sequence, class_label = self.sequences[idx], self.labels[idx]
        features = torch.FloatTensor(seq_to_onehot(sequence, self.seq_len))
        return features, torch.tensor(class_label, dtype=torch.long)

class PromoterCNN(nn.Module):
    """Three-stage 1-D CNN that classifies one-hot encoded DNA into 2 classes.

    Each stage is Conv1d(kernel 3, padding 1) -> ReLU -> MaxPool1d. The
    convolutions keep the sequence length; the pools shrink it by 4, 4 and 2,
    i.e. by a factor of 32 overall. The flattened feature map feeds a
    32-unit hidden layer, dropout (p=0.7) and a 2-unit output layer.

    Args:
        seq_len: Length of the input sequences. Inputs to ``forward`` must
            have exactly this length so the flattened size matches ``fc1``.

    Raises:
        ValueError: If ``seq_len`` is smaller than the total pooling factor
            (32), because no positions would survive pooling.
    """

    # Combined down-sampling of the three max-pool layers (4 * 4 * 2).
    _POOL_FACTOR = 4 * 4 * 2

    def __init__(self, seq_len=100):
        super().__init__()

        # Fail early and clearly: a shorter input pools down to zero
        # positions, which previously produced a zero-input fc1 layer and an
        # obscure error (or shape mismatch) only at forward time.
        if seq_len < self._POOL_FACTOR:
            raise ValueError(
                f"seq_len must be at least {self._POOL_FACTOR} "
                f"(total pooling factor), got {seq_len}"
            )

        self.conv1 = nn.Sequential(
            nn.Conv1d(4, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(4)
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(4)
        )
        self.conv3 = nn.Sequential(
            nn.Conv1d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(2)
        )

        # Positions left after the three floor-dividing pools; nested floor
        # division by 4, 4 and 2 equals a single floor division by 32.
        reduced_seq_len = seq_len // self._POOL_FACTOR
        self.fc1 = nn.Linear(reduced_seq_len * 64, 32)
        self.fc2 = nn.Linear(32, 2)
        self.dropout = nn.Dropout(0.7)

    def forward(self, x):
        """Map a ``(batch, 4, seq_len)`` tensor to ``(batch, 2)`` class logits."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # Flatten channels x positions into one feature vector per sample.
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)


def train_one_fold(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=20):
    """Train ``model`` for one cross-validation fold and keep its best epoch.

    After every epoch the model is evaluated on ``val_loader``; the epoch
    with the highest validation accuracy is remembered together with a
    snapshot of the weights and the ROC curve measured at that epoch.

    Args:
        model: Network to train (modified in place).
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer bound to ``model``'s parameters.
        device: Device the batches are moved to.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        Tuple ``(best_state_dict, best_val_acc, best_epoch, best_roc_data)``:
        an independent copy of the best weights, the best validation accuracy
        in percent, the zero-based index of that epoch, and
        ``(fpr, tpr, roc_auc)`` from that epoch. With ``num_epochs == 0`` the
        result is ``(None, 0, 0, None)``.
    """
    import copy  # local import: only this function needs it

    best_val_acc = 0
    best_state_dict = None
    best_epoch = 0
    best_roc_data = None

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            train_total += labels.size(0)
            train_correct += predicted.eq(labels).sum().item()

        # Validation phase
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0
        val_probs = []
        val_labels_all = []

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                probabilities = torch.softmax(outputs, dim=1)

                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()

                # Column 1 is the probability assigned to class 1.
                val_probs.extend(probabilities[:, 1].cpu().numpy())
                val_labels_all.extend(labels.cpu().numpy())

        val_acc = 100. * val_correct / val_total

        # Calculate ROC for this epoch
        fpr, tpr, _ = roc_curve(val_labels_all, val_probs)
        roc_auc = auc(fpr, tpr)

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {train_loss/len(train_loader):.4f}, '
              f'Train Acc: {100.*train_correct/train_total:.2f}%')
        print(f'Val Loss: {val_loss/len(val_loader):.4f}, '
              f'Val Acc: {val_acc:.2f}%, Val AUC: {roc_auc:.4f}\n')

        # Keep the best epoch. The first epoch is always recorded so that a
        # fold whose accuracy never rises above 0 still returns a model and
        # ROC data instead of None.
        if best_state_dict is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() hands out references to the live parameters and a
            # shallow .copy() keeps aliasing them, so later optimizer steps
            # would overwrite the "best" weights. Deep-copy to freeze them.
            best_state_dict = copy.deepcopy(model.state_dict())
            best_epoch = epoch
            best_roc_data = (fpr, tpr, roc_auc)

    return best_state_dict, best_val_acc, best_epoch, best_roc_data

def k_fold_cross_validation(dataset, k_folds=5, num_epochs=20, batch_size=32, learning_rate=0.001):
    """Train and evaluate a fresh ``PromoterCNN`` on each of ``k_folds`` splits.

    The split is a shuffled ``KFold`` with a fixed seed (42). For every fold
    the best weights are also written to ``best_model_fold_<n>.pth`` in the
    current working directory.

    Args:
        dataset: Map-style dataset yielding ``(inputs, label)`` pairs. If it
            has a ``seq_len`` attribute, the model is built for that length.
        k_folds: Number of folds.
        num_epochs: Epochs to train per fold.
        batch_size: Batch size for both training and validation loaders.
        learning_rate: Learning rate for the Adam optimizer.

    Returns:
        A list with one dict per fold holding ``'fold'`` (1-based),
        ``'best_acc'``, ``'best_epoch'``, ``'best_roc'`` and ``'state_dict'``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)

    # Size the network from the dataset instead of relying on PromoterCNN's
    # own default, which only matches when the dataset happens to use 100.
    dataset_seq_len = getattr(dataset, 'seq_len', None)
    model_kwargs = {} if dataset_seq_len is None else {'seq_len': dataset_seq_len}

    # Store results for each fold
    fold_results = []

    # Split over plain indices; only the sample count matters to KFold, and
    # this avoids handing the Dataset object itself to scikit-learn.
    sample_indices = np.arange(len(dataset))

    # K-fold Cross Validation
    for fold, (train_ids, val_ids) in enumerate(kfold.split(sample_indices)):
        print(f'\nFOLD {fold+1}/{k_folds}')
        print('-' * 50)

        # Data loaders for this fold; both draw from the same dataset and are
        # restricted to their index subset by the sampler.
        train_loader = DataLoader(dataset, batch_size=batch_size,
                                  sampler=SubsetRandomSampler(train_ids))
        val_loader = DataLoader(dataset, batch_size=batch_size,
                                sampler=SubsetRandomSampler(val_ids))

        # Initialize a new model for this fold so no weights leak across folds
        model = PromoterCNN(**model_kwargs).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        # Train the model for this fold
        best_state_dict, best_acc, best_epoch, best_roc = train_one_fold(
            model, train_loader, val_loader, criterion, optimizer, device, num_epochs)

        # Store results
        fold_results.append({
            'fold': fold + 1,
            'best_acc': best_acc,
            'best_epoch': best_epoch,
            'best_roc': best_roc,
            'state_dict': best_state_dict
        })

        # Save best model for this fold
        torch.save(best_state_dict, f'best_model_fold_{fold+1}.pth')

    return fold_results

def plot_fold_roc_curves(fold_results, save_path='roc_curves_k_fold.png'):
    """Plot each fold's ROC curve plus the mean curve and save the figure.

    Every fold curve is resampled onto a common grid of 100 false-positive
    rates so the curves can be averaged point-wise. The mean curve is drawn
    with a +/- 1 standard deviation band and a diagonal "Random" reference.

    Args:
        fold_results: List of per-fold dicts with keys ``'fold'`` and
            ``'best_roc'``, where ``'best_roc'`` is ``(fpr, tpr, roc_auc)``
            or ``None``. Folds whose ``'best_roc'`` is ``None`` are skipped.
        save_path: File the figure is written to.

    Raises:
        ValueError: If no fold carries ROC data.
    """
    # Drop folds without ROC data instead of failing on tuple unpacking.
    usable = [result for result in fold_results if result['best_roc'] is not None]
    if not usable:
        raise ValueError('fold_results contains no ROC data to plot')

    fig = plt.figure(figsize=(10, 8))
    try:
        mean_fpr = np.linspace(0, 1, 100)
        interp_tprs = []

        # Plot ROC curve for each fold
        for result in usable:
            fpr, tpr, roc_auc = result['best_roc']
            plt.plot(fpr, tpr, alpha=0.3,
                     label=f'Fold {result["fold"]} (AUC = {roc_auc:.2f})')

            # Resample onto the shared grid. np.interp replaces the legacy
            # scipy interp1d and copes with the repeated FPR values that
            # ROC curves contain along vertical segments.
            interp_tpr = np.interp(mean_fpr, fpr, tpr)
            interp_tpr[0] = 0.0  # every ROC curve starts at the origin
            interp_tprs.append(interp_tpr)

        # Calculate and plot mean ROC curve
        tpr_matrix = np.array(interp_tprs)
        mean_tpr_avg = np.mean(tpr_matrix, axis=0)
        mean_tpr_avg[-1] = 1.0  # every ROC curve ends at (1, 1)
        mean_tpr_std = np.std(tpr_matrix, axis=0)

        mean_auc = auc(mean_fpr, mean_tpr_avg)
        std_auc = np.std([result['best_roc'][2] for result in usable])

        plt.plot(mean_fpr, mean_tpr_avg, 'b-',
                 label=f'Mean ROC (AUC = {mean_auc:.2f} ± {std_auc:.2f})',
                 lw=2)

        # Plot standard deviation band, clipped to the valid TPR range.
        band_low = np.maximum(mean_tpr_avg - mean_tpr_std, 0)
        band_high = np.minimum(mean_tpr_avg + mean_tpr_std, 1)
        plt.fill_between(mean_fpr, band_low, band_high, color='grey', alpha=0.2,
                         label='±1 std. dev.')

        plt.plot([0, 1], [0, 1], 'k--', label='Random')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('ROC Curves for K-Fold Cross Validation')
        plt.legend(loc='lower right')
        plt.grid(True)
        plt.savefig(save_path)
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)

def main():
    """Run the experiment end to end: load data, cross-validate, plot, report."""
    # Length every sequence is cut or padded to before encoding.
    SEQUENCE_LENGTH = 100

    # Settings forwarded to the cross-validation routine.
    cv_settings = {
        'k_folds': 5,
        'num_epochs': 20,
        'batch_size': 16,
        'learning_rate': 0.001,
    }

    # Load both FASTA files, then train/evaluate on every fold.
    dataset = PromoterDataset(positive_fasta, negative_fasta, SEQUENCE_LENGTH)
    fold_results = k_fold_cross_validation(dataset, **cv_settings)

    # One figure with all fold curves and their mean.
    plot_fold_roc_curves(fold_results)

    # Textual summary: overall AUC first, then one section per fold.
    fold_aucs = [entry['best_roc'][2] for entry in fold_results]
    print("\nSummary of k-fold cross validation:")
    print(f"Mean AUC: {np.mean(fold_aucs):.3f} ± {np.std(fold_aucs):.3f}")
    for entry in fold_results:
        fold_auc = entry['best_roc'][2]
        print(f"\nFold {entry['fold']}:")
        print(f"Best Epoch: {entry['best_epoch'] + 1}")
        print(f"Best Accuracy: {entry['best_acc']:.2f}%")
        print(f"AUC: {fold_auc:.3f}")

# Entry point: run the full cross-validation pipeline when this code is
# executed directly rather than imported.
if __name__ == "__main__":
    main()




FOLD 1/5
--------------------------------------------------


KeyboardInterrupt: 