In [62]:
import math
import random
from collections import defaultdict, Counter

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from Bio import SeqIO

device = (
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)

In [63]:
###################################
# 1. FASTA Parsing and Filtering
###################################

def parse_fasta_with_labels(fasta_file):
    """
    Parses a FASTA file where each header line is assumed to be:
        >ClassLabel
        DNASEQUENCE
    Returns:
        list of tuples (label, sequence)
    """
    data = []
    for record in SeqIO.parse(fasta_file, "fasta"):
        header = record.description.strip()
        sequence = str(record.seq).upper()
        label = header.split()[0]
        data.append((label, sequence))
    return data


def create_train_test_split(raw_data):
    """
    Given a list of (label, sequence) pairs, 
    pick 1 random sample per class for test, 
    and the rest for train.
    """
    label_to_samples = defaultdict(list)
    for label, seq in raw_data:
        label_to_samples[label].append(seq)
    
    train_data = []
    test_data = []
    
    for label, seqs in label_to_samples.items():
        random.shuffle(seqs)
        # pick one for test
        test_seq = seqs[0]
        # remainder for train
        train_seqs = seqs[1:]
        
        test_data.append((label, test_seq))
        for s in train_seqs:
            train_data.append((label, s))
    
    return train_data, test_data

###################################
# 2. K-mer Processing
###################################

def generate_kmers(sequence, k=6):
    """
    Generates overlapping K-mers of length k from a DNA sequence.
    """
    kmers = []
    for i in range(len(sequence) - k + 1):
        kmers.append(sequence[i:i + k])
    return kmers

def build_kmer_vocab(dataset, k=6):
    """
    dataset is a list of (label, seq) pairs
    returns a dict mapping each K-mer to an integer index
    """
    kmer_set = set()
    for _, seq in dataset:
        kmers = generate_kmers(seq, k)
        kmer_set.update(kmers)
    
    vocab = {"<UNK>": 0}
    for i, kmer in enumerate(sorted(kmer_set), start=1):
        vocab[kmer] = i
    return vocab

def encode_sequence(sequence, vocab, k=6):
    """
    Convert a DNA sequence to a list of token indices based on K-mer vocab.
    """
    kmers = generate_kmers(sequence, k)
    encoded = [vocab.get(kmer, vocab["<UNK>"]) for kmer in kmers]
    return encoded

# genera more than x samples

In [64]:
def filter_classes(raw_data, min_count=10):
    """
    Keep only classes that have > min_count samples.
    Discard classes with <= min_count samples.
    """
    label_counts = Counter([label for (label, _) in raw_data])
    filtered_data = [
        (label, seq) 
        for (label, seq) in raw_data
        if label_counts[label] >= min_count
    ]
    return filtered_data

# train test split

In [65]:
###################################
# 3. Create "Paired" Data
###################################
"""
For demonstration, we'll pair each forward sequence with its reverse.
In practice, you might have two distinct sequences or two feature sets.
"""

def reverse_complement(seq):
    """
    For demonstration, we'll just reverse the sequence.
    (If you want a true 'reverse complement,' you'd also map A->T, C->G, etc.)
    We'll keep it simple here: just reversed.
    """
    return seq[::-1]

def create_paired_data(data_list):
    """
    For each (label, seq) in data_list, produce
    (label, fwd_seq, rev_seq).
    """
    paired = []
    for label, seq in data_list:
        rev_seq = reverse_complement(seq)
        paired.append((label, seq, rev_seq))
    return paired

###################################
# 4. PyTorch Dataset for Two Inputs
###################################

class TwoFastaKmerDataset(Dataset):
    """
    Each item: (encoded_seq_fwd, encoded_seq_rev, label_idx)
    """
    def __init__(self, paired_data, vocab, k=6):
        """
        paired_data: list of (label, fwd_seq, rev_seq)
        vocab: dict for k-mers
        """
        super().__init__()
        self.vocab = vocab
        self.k = k
        
        # Gather labels
        labels = sorted(set(item[0] for item in paired_data))
        self.label2idx = {lbl: i for i, lbl in enumerate(labels)}
        
        self.encoded_data = []
        for label, fwd_seq, rev_seq in paired_data:
            x1 = encode_sequence(fwd_seq, self.vocab, k=self.k)
            x2 = encode_sequence(rev_seq, self.vocab, k=self.k)
            y = self.label2idx[label]
            self.encoded_data.append((x1, x2, y))
    
    def __len__(self):
        return len(self.encoded_data)
    
    def __getitem__(self, idx):
        return self.encoded_data[idx]  # (fwd, rev, label_idx)
    
    def get_vocab_size(self):
        return len(self.vocab)
    
    def get_num_classes(self):
        return len(self.label2idx)

# 2 transformer fusion

In [None]:
###################################
# 5. Collate Function for Two Inputs
###################################

def collate_fn_two(batch):
    """
    batch: list of (seq_fwd, seq_rev, label)
    We'll pad BOTH forward and reverse sequences.
    """
    seqs_fwd, seqs_rev, labels = zip(*batch)
    
    # Convert to tensors
    seq_fwd_tensors = [torch.tensor(s, dtype=torch.long) for s in seqs_fwd]
    seq_rev_tensors = [torch.tensor(s, dtype=torch.long) for s in seqs_rev]
    
    # Pad forward
    padded_fwd = pad_sequence(seq_fwd_tensors, batch_first=True, padding_value=0)
    # Pad reverse
    padded_rev = pad_sequence(seq_rev_tensors, batch_first=True, padding_value=0)
    
    labels_tensors = torch.tensor(labels, dtype=torch.long)
    return padded_fwd, padded_rev, labels_tensors

###################################
# 6. Modified Two-Transformer Fusion Model
###################################

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # shape: [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]


class TransformerEncoderBlock(nn.Module):
    """
    A helper class for a transformer encoder block with positional encoding and mean pooling.
    """
    def __init__(self, vocab_size, d_model=128, nhead=8, num_layers=2,
                 dim_feedforward=512, dropout=0.1, max_len=5000, pooling='mean'):
        super().__init__()
        self.d_model = d_model
        self.pooling = pooling

        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_len)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x):
        """
        x: [batch_size, seq_len]
        returns a single pooled vector: [batch_size, d_model]
        """
        # Embed
        embedded = self.embedding(x) * math.sqrt(self.d_model)  # [B, seq_len, d_model]
        # Positional encode
        encoded = self.pos_encoder(embedded)
        # Transformer encoder
        out = self.transformer_encoder(encoded)  # [B, seq_len, d_model]

        # Pooling
        if self.pooling == 'mean':
            pooled = out.mean(dim=1)  # [B, d_model]
        else:
            pooled = out.mean(dim=1)  # Default to mean pooling
        return pooled


class TwoTransformerFusionFeatureExtractor(nn.Module):
    """
    Two separate Transformer encoders -> fused representation.
    Outputs features for AdaBoost.
    """
    def __init__(self, vocab_size, d_model=128, nhead=8, num_layers=2, 
                 dim_feedforward=512, dropout=0.1, max_len=5000, pooling='mean'):
        super().__init__()
        
        # Transformer #1
        self.transformer1 = TransformerEncoderBlock(
            vocab_size=vocab_size,
            d_model=d_model,
            nhead=nhead,
            num_layers=num_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            max_len=max_len,
            pooling=pooling
        )
        # Transformer #2
        self.transformer2 = TransformerEncoderBlock(
            vocab_size=vocab_size,
            d_model=d_model,
            nhead=nhead,
            num_layers=num_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            max_len=max_len,
            pooling=pooling
        )
        
    def forward(self, x1, x2):
        """
        x1: [batch_size, seq_len1]
        x2: [batch_size, seq_len2]
        returns: [batch_size, 2*d_model] fused features
        """
        f1 = self.transformer1(x1)  # [B, d_model]
        f2 = self.transformer2(x2)  # [B, d_model]
        
        # Fuse
        fused = torch.cat([f1, f2], dim=1)  # [B, 2*d_model]
        return fused

###################################
# 7. Soft decision tree
###################################

class SoftDecisionTree(nn.Module):
    """
    A differentiable (soft) decision tree.
    Given an input vector x (of dimension input_dim), the tree computes a series
    of decisions (via a sigmoid function) to route examples probabilistically 
    to one of 2^depth leaves. Each leaf contains learnable logits for the classes.
    """
    def __init__(self, input_dim, num_classes, depth=3):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.depth = depth
        self.num_leaves = 2 ** depth
        
        # This layer computes one decision per tree level.
        self.decision_layer = nn.Linear(input_dim, depth)
        
        # Each leaf has a learnable logit vector (for num_classes)
        self.leaf_logits = nn.Parameter(torch.randn(self.num_leaves, num_classes))
        
        # Precompute the routing matrix: shape [num_leaves, depth]
        # Each row is the binary representation (0: left, 1: right) of the leaf index.
        decision_matrix = []
        for i in range(self.num_leaves):
            bin_repr = [(i >> (depth - 1 - j)) & 1 for j in range(depth)]
            decision_matrix.append(bin_repr)
        decision_matrix = torch.tensor(decision_matrix, dtype=torch.float32)
        self.register_buffer('decision_matrix', decision_matrix)
    
    def forward(self, x):
        """
        x: [B, input_dim]
        returns: [B, num_classes] logits computed as the weighted sum of leaf logits.
        """
        B = x.size(0)
        decisions = torch.sigmoid(self.decision_layer(x))  # [B, depth]
        # Expand for broadcasting: [B, 1, depth]
        decisions_exp = decisions.unsqueeze(1)
        # Expand decision matrix: [1, num_leaves, depth]
        decision_matrix = self.decision_matrix.unsqueeze(0)
        # Compute routing probabilities at each level:
        # For a given level, use d if decision_matrix==1 else (1-d)
        routing_probs = decisions_exp * decision_matrix + (1 - decisions_exp) * (1 - decision_matrix)
        # The probability of reaching each leaf is the product over the tree depth.
        leaf_probs = torch.prod(routing_probs, dim=2)  # [B, num_leaves]
        # Weighted sum of leaf logits gives the final output.
        output = leaf_probs @ self.leaf_logits  # [B, num_classes]
        return output

###################################
# 8. Bagging Tree Classifier 
###################################
class BaggingTreeClassifier(nn.Module):
    def __init__(self, input_dim, num_classes, num_trees=10, tree_depth=3):
        super().__init__()
        self.trees = nn.ModuleList([
            SoftDecisionTree(input_dim, num_classes, depth=tree_depth)
            for _ in range(num_trees)
        ])
    
    def forward(self, x):
        # x: [B, input_dim]
        outputs = [tree(x) for tree in self.trees]  # list of [B, num_classes]
        outputs = torch.stack(outputs, dim=0)  # [num_trees, B, num_classes]
        outputs = outputs.mean(dim=0)  # [B, num_classes]
        return outputs


###################################
# 9. Two-Transformer Fusion with Tree Classifier
###################################
class TwoTransformerFusionDNAClassifierWithTree(nn.Module):
    def __init__(self, vocab_size, num_classes, d_model=128, nhead=8, num_layers=2, 
                 dim_feedforward=512, dropout=0.1, max_len=5000, pooling='mean',
                 num_trees=10, tree_depth=3):
        super().__init__()
        # Two independent transformer blocks for forward and reverse sequences.
        self.transformer1 = TransformerEncoderBlock(
            vocab_size=vocab_size,
            d_model=d_model,
            nhead=nhead,
            num_layers=num_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            max_len=max_len,
            pooling=pooling
        )
        self.transformer2 = TransformerEncoderBlock(
            vocab_size=vocab_size,
            d_model=d_model,
            nhead=nhead,
            num_layers=num_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            max_len=max_len,
            pooling=pooling
        )
        # Instead of the original FC head, we use a bagging tree classifier.
        self.bagging_tree = BaggingTreeClassifier(
            input_dim=2 * d_model, 
            num_classes=num_classes,
            num_trees=num_trees,
            tree_depth=tree_depth
        )
    
    def forward(self, x1, x2):
        # x1: [B, seq_len1], x2: [B, seq_len2]
        f1 = self.transformer1(x1)  # [B, d_model]
        f2 = self.transformer2(x2)  # [B, d_model]
        fused = torch.cat([f1, f2], dim=1)  # [B, 2*d_model]
        logits = self.bagging_tree(fused)   # [B, num_classes]
        return logits



In [67]:
###################################
# 10. Putting It All Together
###################################

fasta_file = "data2/fungi_ITS_cleaned.fasta"

raw_data = parse_fasta_with_labels(fasta_file)
raw_data = filter_classes(raw_data, min_count=10)
train_data, test_data = create_train_test_split(raw_data)

# Now create "paired" forward+reverse sequences
paired_train = create_paired_data(train_data)  # (label, fwd_seq, rev_seq)
paired_test  = create_paired_data(test_data)

# Build vocab from entire set (using both forward and reverse)
combined_paired = paired_train + paired_test
tmp_data = []
for (lbl, fwd, rev) in combined_paired:
    tmp_data.append((lbl, fwd))
    tmp_data.append((lbl, rev))

k = 6
vocab = build_kmer_vocab(tmp_data, k=k)

# Create datasets
train_dataset = TwoFastaKmerDataset(paired_train, vocab, k=k)
test_dataset  = TwoFastaKmerDataset(paired_test,  vocab, k=k)

batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, 
                          shuffle=True, collate_fn=collate_fn_two)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, 
                          shuffle=False, collate_fn=collate_fn_two)

num_classes = train_dataset.get_num_classes()
vocab_size = train_dataset.get_vocab_size()

print(f"Number of classes: {num_classes}")
print(f"Number of training samples: {len(train_dataset)}")

device = "cuda" if torch.cuda.is_available() else "cpu"

# 7B) Create model using the new bagging tree classifier head
model = TwoTransformerFusionDNAClassifierWithTree(
    vocab_size=vocab_size,
    num_classes=num_classes,
    d_model=128,
    nhead=8,
    num_layers=2,#attention block
    dim_feedforward=512,
    dropout=0.1,
    max_len=5000,
    pooling='mean',
    num_trees=10,    # Adjust number of trees as desired
    tree_depth=3     # Adjust tree depth as desired
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

###################################
# 11. Train & Evaluate
###################################
def evaluate_accuracy(model, data_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for fwd, rev, labels in data_loader:
            fwd, rev, labels = fwd.to(device), rev.to(device), labels.to(device)
            logits = model(fwd, rev)
            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return 100.0 * correct / total

epochs = 100
for epoch in range(1, epochs + 1):
    model.train()
    total_loss = 0.0
    
    for fwd, rev, labels in train_loader:
        fwd, rev, labels = fwd.to(device), rev.to(device), labels.to(device)
        
        optimizer.zero_grad()
        logits = model(fwd, rev)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    avg_loss = total_loss / len(train_loader)
    train_acc = evaluate_accuracy(model, train_loader)
    test_acc  = evaluate_accuracy(model, test_loader)
    
    print(f"Epoch {epoch}/{epochs} | Loss: {avg_loss:.4f} | "
          f"Train Acc: {train_acc:.2f}% | Test Acc: {test_acc:.2f}%")



Number of classes: 81
Number of training samples: 1612
Epoch 1/100 | Loss: 4.2793 | Train Acc: 16.50% | Test Acc: 2.47%
Epoch 2/100 | Loss: 4.0325 | Train Acc: 23.76% | Test Acc: 6.17%
Epoch 3/100 | Loss: 3.8730 | Train Acc: 24.13% | Test Acc: 6.17%
Epoch 4/100 | Loss: 3.7293 | Train Acc: 24.57% | Test Acc: 8.64%
Epoch 5/100 | Loss: 3.6161 | Train Acc: 29.16% | Test Acc: 8.64%
Epoch 6/100 | Loss: 3.5066 | Train Acc: 32.38% | Test Acc: 13.58%
Epoch 7/100 | Loss: 3.4201 | Train Acc: 32.44% | Test Acc: 11.11%
Epoch 8/100 | Loss: 3.3104 | Train Acc: 32.94% | Test Acc: 13.58%
Epoch 9/100 | Loss: 3.2256 | Train Acc: 36.04% | Test Acc: 17.28%
Epoch 10/100 | Loss: 3.1419 | Train Acc: 36.85% | Test Acc: 14.81%
Epoch 11/100 | Loss: 3.0644 | Train Acc: 35.24% | Test Acc: 16.05%
Epoch 12/100 | Loss: 2.9809 | Train Acc: 40.26% | Test Acc: 17.28%
Epoch 13/100 | Loss: 2.9158 | Train Acc: 41.00% | Test Acc: 19.75%
Epoch 14/100 | Loss: 2.8503 | Train Acc: 42.31% | Test Acc: 22.22%
Epoch 15/100 | Loss: 