In [21]:
import math
import random
from collections import defaultdict, Counter

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from Bio import SeqIO

# Pick the compute backend once, in priority order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [22]:
###################################
# 1. FASTA Parsing and Filtering
###################################

def parse_fasta_with_labels(fasta_file):
    """
    Read a FASTA file into (label, sequence) pairs.

    The label is the first whitespace-separated token of each record's
    description line; the sequence is upper-cased.

    Args:
        fasta_file: path or handle accepted by ``SeqIO.parse``.
    Returns:
        list of tuples (label, sequence), in file order.
    """
    return [
        (entry.description.strip().split()[0], str(entry.seq).upper())
        for entry in SeqIO.parse(fasta_file, "fasta")
    ]


def create_train_test_split(raw_data, seed=None):
    """
    Hold out one random sample per class for testing; everything else trains.

    Args:
        raw_data: iterable of (label, sequence) pairs.
        seed: optional seed for a private random generator so the split is
            reproducible. When None (the default) the module-level ``random``
            state is used, exactly as before.
    Returns:
        (train_data, test_data): two lists of (label, sequence) pairs. Classes
        appear in first-seen order; ``test_data`` has one entry per class.

    Note:
        A class with a single sample contributes only to ``test_data`` and has
        no training examples. Filter such classes out beforehand if that is
        not wanted.
    """
    # A private generator keeps a seeded split from disturbing global RNG state.
    rng = random if seed is None else random.Random(seed)

    label_to_samples = defaultdict(list)
    for label, seq in raw_data:
        label_to_samples[label].append(seq)

    train_data = []
    test_data = []
    for label, seqs in label_to_samples.items():
        rng.shuffle(seqs)
        # After shuffling, the first sequence is the random hold-out.
        held_out, *remaining = seqs
        test_data.append((label, held_out))
        train_data.extend((label, s) for s in remaining)

    return train_data, test_data

###################################
# 2. K-mer Processing
###################################

def generate_kmers(sequence, k=6):
    """
    Slide a window of width k over the sequence one position at a time and
    return every window, in order. Sequences shorter than k yield an empty list.
    """
    window_count = len(sequence) - k + 1
    return [sequence[start:start + k] for start in range(window_count)]

def build_kmer_vocab(dataset, k=6):
    """
    Build a K-mer -> integer index mapping.

    dataset: list of (label, seq) pairs; labels are ignored.
    Index 0 is reserved for "<UNK>"; the distinct K-mers found in the data
    are numbered from 1 in sorted order.
    """
    distinct_kmers = set()
    for _label, seq in dataset:
        distinct_kmers.update(generate_kmers(seq, k))

    vocab = {"<UNK>": 0}
    vocab.update(
        (kmer, index) for index, kmer in enumerate(sorted(distinct_kmers), start=1)
    )
    return vocab

def encode_sequence(sequence, vocab, k=6):
    """
    Turn a DNA sequence into a list of vocabulary indices, one per overlapping
    K-mer. K-mers missing from vocab map to the "<UNK>" index.
    """
    token_ids = []
    for kmer in generate_kmers(sequence, k):
        token_ids.append(vocab.get(kmer, vocab["<UNK>"]))
    return token_ids

# Next: keep only genera (class labels) that have enough samples — see filter_classes.

In [23]:
def filter_classes(raw_data, min_count=10):
    """
    Keep only classes that have at least ``min_count`` samples.

    Args:
        raw_data: iterable of (label, sequence) pairs. It is materialised
            once, so a generator is accepted.
        min_count: minimum number of samples a class needs to be kept
            (inclusive: a class with exactly ``min_count`` samples stays).
    Returns:
        list of (label, sequence) pairs, in the original order, restricted to
        the classes that meet the threshold.
    """
    # The data is traversed twice (count, then filter); materialise it so a
    # one-shot iterator is not silently exhausted by the counting pass.
    samples = list(raw_data)
    label_counts = Counter(label for label, _ in samples)
    return [
        (label, seq)
        for label, seq in samples
        if label_counts[label] >= min_count
    ]

# train test split

In [24]:
###################################
# 3. Create "Paired" Data
###################################
"""
For demonstration, we'll pair each forward sequence with its reverse.
In practice, you might have two distinct sequences or two feature sets.
"""

def reverse_complement(seq, complement=False):
    """
    Return the reversed sequence, optionally complemented.

    Args:
        seq: DNA sequence string.
        complement: when False (the default) the sequence is only reversed,
            which is what the rest of this notebook uses. When True the
            bases are also complemented (A<->T, C<->G, case preserved),
            giving the true reverse complement. Any other character
            (e.g. N) is left unchanged.
    Returns:
        The transformed sequence as a string.
    """
    reversed_seq = seq[::-1]
    if not complement:
        return reversed_seq
    return reversed_seq.translate(str.maketrans("ACGTacgt", "TGCAtgca"))

def create_paired_data(data_list):
    """
    Expand every (label, seq) pair into a (label, fwd_seq, rev_seq) triple,
    where rev_seq is whatever reverse_complement returns for seq.
    """
    return [
        (label, fwd_seq, reverse_complement(fwd_seq))
        for label, fwd_seq in data_list
    ]

###################################
# 4. PyTorch Dataset for Two Inputs
###################################

class TwoFastaKmerDataset(Dataset):
    """
    Dataset of K-mer-encoded sequence pairs.

    Each item: (encoded_seq_fwd, encoded_seq_rev, label_idx), where the two
    encoded sequences are lists of vocabulary indices.
    """
    def __init__(self, paired_data, vocab, k=6, label2idx=None):
        """
        paired_data: list of (label, fwd_seq, rev_seq)
        vocab: dict mapping k-mers to integer indices
        k: K-mer length used for encoding
        label2idx: optional label -> class-index mapping to reuse. Pass the
            training dataset's ``label2idx`` when building a validation/test
            dataset so both agree on class indices. When None (default) the
            mapping is derived from the sorted labels in ``paired_data``.

        Raises:
            KeyError: if a label in ``paired_data`` is missing from the
                supplied ``label2idx``.
        """
        super().__init__()
        self.vocab = vocab
        self.k = k

        if label2idx is None:
            # Sorted so the mapping is deterministic for a given label set.
            labels = sorted(set(item[0] for item in paired_data))
            label2idx = {lbl: i for i, lbl in enumerate(labels)}
        # Copy so later changes to the caller's dict cannot alter this dataset.
        self.label2idx = dict(label2idx)

        # All sequences are encoded eagerly, once, at construction time.
        self.encoded_data = []
        for label, fwd_seq, rev_seq in paired_data:
            if label not in self.label2idx:
                raise KeyError(f"label {label!r} is not present in label2idx")
            x1 = encode_sequence(fwd_seq, self.vocab, k=self.k)
            x2 = encode_sequence(rev_seq, self.vocab, k=self.k)
            y = self.label2idx[label]
            self.encoded_data.append((x1, x2, y))

    def __len__(self):
        """Number of encoded pairs."""
        return len(self.encoded_data)

    def __getitem__(self, idx):
        """Return the pre-encoded (fwd, rev, label_idx) triple at ``idx``."""
        return self.encoded_data[idx]

    def get_vocab_size(self):
        """Number of entries in the vocabulary (including "<UNK>" if present)."""
        return len(self.vocab)

    def get_num_classes(self):
        """Number of distinct class labels in the label mapping."""
        return len(self.label2idx)

# 2 transformer fusion

In [25]:
###################################
# 5. Collate Function for Two Inputs
###################################

def collate_fn_two(batch):
    """
    Collate a list of (seq_fwd, seq_rev, label) samples into batch tensors.

    The forward and reverse sequences are padded independently (with 0) to
    the longest sequence on their own side of the batch.

    Returns:
        (padded_fwd, padded_rev, labels) as long tensors, batch dimension first.
    """
    def _to_padded(token_lists):
        # One long tensor per sample, right-padded with zeros.
        as_tensors = [torch.tensor(tokens, dtype=torch.long) for tokens in token_lists]
        return pad_sequence(as_tensors, batch_first=True, padding_value=0)

    fwd_side, rev_side, label_ids = zip(*batch)
    fwd_batch = _to_padded(fwd_side)
    rev_batch = _to_padded(rev_side)
    return fwd_batch, rev_batch, torch.tensor(label_ids, dtype=torch.long)

###################################
# 6. Modified Two-Transformer Fusion Model
###################################

class PositionalEncoding(nn.Module):
    """
    Adds fixed sinusoidal position encodings to a batch-first input.

    The table has shape [1, max_len, d_model] and is stored as a buffer, so
    it follows the module across devices but is not a trainable parameter.
    """
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column,
        # so trim the frequency vector to match instead of failing on shape.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # shape: [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        x: [batch_size, seq_len, d_model], with seq_len <= max_len.
        returns: x plus the encodings for positions 0..seq_len-1.
        """
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]


class TransformerEncoderBlock(nn.Module):
    """
    Token embedding + positional encoding + Transformer encoder, followed by
    padding-aware mean pooling to a single vector per sequence.

    Token id 0 is the embedding's ``padding_idx`` and is treated as padding:
    such positions are hidden from attention and excluded from the pooled mean.
    (In this notebook id 0 is also the "<UNK>" id, so unknown K-mers are
    ignored in the same way.)
    """
    def __init__(self, vocab_size, d_model=128, nhead=8, num_layers=2,
                 dim_feedforward=512, dropout=0.1, max_len=5000, pooling='mean'):
        super().__init__()
        self.d_model = d_model
        # Stored for callers; only mean pooling is implemented, and any other
        # value falls back to it.
        self.pooling = pooling

        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_len)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        # Nested-tensor conversion is disabled so the encoder output always
        # keeps the input's [B, seq_len, d_model] shape when a padding mask
        # is supplied, which the masked pooling below relies on.
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
            enable_nested_tensor=False,
        )

    def forward(self, x):
        """
        x: [batch_size, seq_len] token ids, right-padded with 0
        returns a single pooled vector: [batch_size, d_model]
        """
        # True where the token is padding.
        pad_mask = x.eq(self.embedding.padding_idx)
        # A row made only of padding would leave attention with no keys
        # (softmax over nothing -> NaN); leave such rows unmasked instead.
        fully_padded = pad_mask.all(dim=1, keepdim=True)
        pad_mask = pad_mask & ~fully_padded

        # Embed (scaled by sqrt(d_model)) and add position information.
        embedded = self.embedding(x) * math.sqrt(self.d_model)  # [B, seq_len, d_model]
        encoded = self.pos_encoder(embedded)
        out = self.transformer_encoder(
            encoded, src_key_padding_mask=pad_mask
        )  # [B, seq_len, d_model]

        # Mean over real tokens only, so the result does not depend on how
        # much padding the batch happened to need.
        keep = (~pad_mask).unsqueeze(-1)  # [B, seq_len, 1]
        summed = out.masked_fill(~keep, 0.0).sum(dim=1)  # [B, d_model]
        counts = keep.sum(dim=1).clamp(min=1).to(out.dtype)  # [B, 1]
        return summed / counts


class TwoTransformerFusionFeatureExtractor(nn.Module):
    """
    Feature extractor built from two independent, identically configured
    Transformer encoders (one per input). Their pooled outputs are
    concatenated into one fused vector intended as input to AdaBoost.
    """
    def __init__(self, vocab_size, d_model=128, nhead=8, num_layers=2, 
                 dim_feedforward=512, dropout=0.1, max_len=5000, pooling='mean'):
        super().__init__()

        # Both branches share the same hyper-parameters but not their weights.
        branch_config = dict(
            vocab_size=vocab_size,
            d_model=d_model,
            nhead=nhead,
            num_layers=num_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            max_len=max_len,
            pooling=pooling,
        )
        self.transformer1 = TransformerEncoderBlock(**branch_config)
        self.transformer2 = TransformerEncoderBlock(**branch_config)

    def forward(self, x1, x2):
        """
        x1: [batch_size, seq_len1] -> first encoder
        x2: [batch_size, seq_len2] -> second encoder
        returns: [batch_size, 2*d_model], first-encoder features followed by
                 second-encoder features
        """
        branch_outputs = [self.transformer1(x1), self.transformer2(x2)]
        return torch.cat(branch_outputs, dim=1)

###################################
# 7. AdaBoost Implementation in PyTorch
###################################

class AdaBoostClassifierTorch(nn.Module):
    """
    Multi-class AdaBoost (SAMME-style) over linear weak learners.

    ``models`` holds ``n_estimators`` linear classifiers and ``alphas`` their
    ensemble weights. ``alphas`` is a registered buffer, so it moves with
    ``.to(device)`` and is saved in ``state_dict``.
    """
    def __init__(self, num_features, num_classes, n_estimators=50):
        super().__init__()
        self.n_estimators = n_estimators
        self.num_classes = num_classes
        self.models = nn.ModuleList([nn.Linear(num_features, num_classes) for _ in range(n_estimators)])
        # A learner whose alpha is 0 contributes nothing to the ensemble.
        self.register_buffer("alphas", torch.zeros(n_estimators))

    def forward(self, x, stage=None):
        """
        x: [batch_size, num_features]
        stage: If specified, only use the first `stage` estimators.
        returns: [batch_size, num_classes] alpha-weighted sum of learner logits
                 (all zeros when no estimator is used).
        """
        if stage is None:
            stage = self.n_estimators

        logits = x.new_zeros(tuple(x.shape[:-1]) + (self.num_classes,))
        for i in range(stage):
            logits = logits + self.alphas[i] * self.models[i](x)
        return logits

    def fit(self, x_train, y_train, epochs=10, lr=0.1):
        """
        Fit the ensemble sequentially with sample re-weighting.

        x_train: [num_samples, num_features]
        y_train: [num_samples] integer class labels
        epochs: gradient steps per weak learner
        lr: Adam learning rate for each weak learner

        Raises:
            ValueError: if ``x_train`` contains no samples.
        """
        # Features are treated as fixed inputs; detaching also prevents
        # repeated backward passes through an upstream graph.
        x_train = x_train.detach()
        y_train = y_train.to(x_train.device)

        n_samples = x_train.size(0)
        if n_samples == 0:
            raise ValueError("fit() needs at least one training sample")

        sample_weights = torch.full((n_samples,), 1.0 / n_samples, device=x_train.device)
        loss_fn = nn.CrossEntropyLoss(reduction='none')

        # A multi-class learner is useful as long as it beats random guessing,
        # i.e. its weighted error is below 1 - 1/K (0.5 only when K == 2).
        max_error = 1.0 - 1.0 / max(self.num_classes, 1)
        class_bonus = math.log(max(self.num_classes - 1, 1))

        with torch.no_grad():
            self.alphas.zero_()  # forget any previous fit

        for t in range(self.n_estimators):
            learner = self.models[t]
            # One optimizer per learner: optimizer state from earlier learners
            # must not keep moving their already-fitted weights.
            optimizer = torch.optim.Adam(learner.parameters(), lr=lr)

            with torch.enable_grad():
                for _ in range(epochs):
                    per_sample_loss = loss_fn(learner(x_train), y_train)
                    # Weights already sum to 1, so the weighted sum is the
                    # weighted average loss.
                    weighted_loss = (per_sample_loss * sample_weights).sum()

                    optimizer.zero_grad()
                    weighted_loss.backward()
                    optimizer.step()

            with torch.no_grad():
                preds = torch.argmax(learner(x_train), dim=1)
                incorrect = (preds != y_train).to(sample_weights.dtype)
                error = (sample_weights * incorrect).sum() / sample_weights.sum()

                if self.num_classes > 1 and error >= max_error:
                    continue  # no better than chance: leave alpha at 0

                # Keep the error a tensor and away from 0 so the log is finite.
                error = error.clamp(min=1e-10)
                alpha = torch.log((1.0 - error) / error) + class_bonus
                self.alphas[t] = alpha

                # Up-weight the samples this learner got wrong, then renormalise.
                sample_weights = sample_weights * torch.exp(alpha * incorrect)
                sample_weights = sample_weights / sample_weights.sum()


In [26]:
###################################
# 7. Putting It All Together
###################################

# 7A) Load & prepare data
# Seed the RNGs used by the split, the DataLoader shuffling and the weight
# initialisation so a Restart & Run All reproduces the same run.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

fasta_file = "data2/fungi_ITS_cleaned.fasta"

raw_data = parse_fasta_with_labels(fasta_file)
raw_data = filter_classes(raw_data, min_count=10)
train_data, test_data = create_train_test_split(raw_data)

# Each dataset below derives its own label -> index mapping from the labels
# it contains, so train and test must cover exactly the same classes or the
# class indices would not line up.
train_label_set = {lbl for lbl, _ in train_data}
test_label_set = {lbl for lbl, _ in test_data}
assert train_label_set == test_label_set, (
    "train and test contain different classes; every class needs at least "
    "two samples so one can be held out"
)

# Now create "paired" forward+reverse
paired_train = create_paired_data(train_data)  # (label, fwd_seq, rev_seq)
paired_test  = create_paired_data(test_data)

# Build vocab from entire set (train and test, forward and reversed)
combined_paired = paired_train + paired_test
# We'll flatten them into (label, seq) to build vocab
tmp_data = []
for (lbl, fwd, rev) in combined_paired:
    tmp_data.append((lbl, fwd))
    tmp_data.append((lbl, rev))

k = 6
vocab = build_kmer_vocab(tmp_data, k=k)

# Make datasets
train_dataset = TwoFastaKmerDataset(paired_train, vocab, k=k)
test_dataset  = TwoFastaKmerDataset(paired_test,  vocab, k=k)

batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, 
                          shuffle=True, collate_fn=collate_fn_two)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, 
                          shuffle=False, collate_fn=collate_fn_two)

num_classes = train_dataset.get_num_classes()
vocab_size = train_dataset.get_vocab_size()


# 7B) Create model
d_model = 128  # width of each encoder; the fused feature vector is twice this

feature_extractor = TwoTransformerFusionFeatureExtractor(
    vocab_size=vocab_size,
    d_model=d_model,
    nhead=8,
    num_layers=2,
    dim_feedforward=512,
    dropout=0.1,
    max_len=5000,
    pooling='mean'
).to(device)

adaboost = AdaBoostClassifierTorch(
    num_features=2 * d_model,  # derived, so it cannot drift from d_model
    num_classes=num_classes,
    n_estimators=50
).to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(feature_extractor.parameters(), lr=1e-3)

###################################
# 7C) Prepare Train and Test Data
###################################

def extract_features_for_adaboost(model, data_loader):
    """
    Run the feature extractor over a loader and collect features and labels.

    Args:
        model: module called as ``model(fwd, rev)``; it is put in eval mode
            and left there.
        data_loader: iterable of (fwd, rev, labels) tensor batches.
    Returns:
        (features, labels): batches concatenated along dim 0, in loader
        order. Both tensors live on the same device (the model's).
    """
    model.eval()
    # Use the device the model actually lives on rather than a global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    features = []
    labels = []
    with torch.no_grad():
        for fwd, rev, lbl in data_loader:
            fwd, rev = fwd.to(run_device), rev.to(run_device)
            feats = model(fwd, rev)  # Extract features
            features.append(feats)
            # Keep labels next to their features so downstream code never
            # mixes CPU labels with accelerator features.
            labels.append(lbl.to(feats.device))
    return torch.cat(features), torch.cat(labels)

# Run the (still untrained) feature extractor once over each split.
print("Extracting train features...")
x_train, y_train = extract_features_for_adaboost(feature_extractor, train_loader)
print("Extracting test features...")
x_test, y_test = extract_features_for_adaboost(feature_extractor, test_loader)

# One independent Adam optimizer per AdaBoost weak learner, in learner order.
adaboost_optimizers = list(
    map(
        lambda learner_idx: torch.optim.Adam(
            adaboost.models[learner_idx].parameters(), lr=1e-3
        ),
        range(adaboost.n_estimators),
    )
)


###################################
# 8. Train & Evaluate
###################################

def evaluate_accuracy(feature_extractor, adaboost, data_loader):
    """
    Evaluate classification accuracy of feature extractor + AdaBoost.

    Args:
        feature_extractor: module called as ``feature_extractor(fwd, rev)``.
        adaboost: module mapping features to class logits.
        data_loader: iterable of (fwd, rev, labels) tensor batches.
    Returns:
        Accuracy as a percentage in [0, 100]; 0.0 if the loader yields no
        samples. Both modules are switched to eval mode and left there.
    """
    feature_extractor.eval()
    adaboost.eval()

    # Use the device the extractor actually lives on rather than a global.
    first_param = next(feature_extractor.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = 0
    with torch.no_grad():
        for fwd, rev, labels in data_loader:
            fwd, rev, labels = fwd.to(run_device), rev.to(run_device), labels.to(run_device)
            features = feature_extractor(fwd, rev)  # Extract features
            logits = adaboost(features)  # Use AdaBoost to predict
            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        return 0.0  # nothing to evaluate; avoid dividing by zero
    return 100.0 * correct / total

# Training loop
#
# Per batch: every weak learner is scored on the same extracted features and
# the averaged loss is back-propagated ONCE, updating both the weak learners
# and the feature extractor. (Calling backward separately per learner would
# re-traverse the extractor's graph after it has been freed.)
# Per epoch: each learner's ensemble weight (alpha) is recomputed from its
# error rate over the epoch using the multi-class (SAMME) formula.
# NOTE(review): all learners see the same unweighted loss, so this is an
# alpha-weighted ensemble rather than sequential boosting with sample
# re-weighting — confirm that is the intent.
epochs = 30
for epoch in range(1, epochs + 1):
    feature_extractor.train()
    adaboost.train()
    total_loss = 0.0
    wrong_per_learner = torch.zeros(adaboost.n_estimators)  # kept on CPU
    n_seen = 0

    for fwd, rev, labels in train_loader:
        fwd, rev, labels = fwd.to(device), rev.to(device), labels.to(device)

        # Extract features
        features = feature_extractor(fwd, rev)  # Features from Transformer

        # Score every weak learner on this batch
        learner_logits = [learner(features) for learner in adaboost.models]
        batch_loss = torch.stack(
            [criterion(logits, labels) for logits in learner_logits]
        ).mean()

        # Single backward pass, then step the extractor and every learner
        optimizer.zero_grad()
        for learner_optimizer in adaboost_optimizers:
            learner_optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        for learner_optimizer in adaboost_optimizers:
            learner_optimizer.step()

        # Count mistakes per learner for the epoch-level error estimate
        with torch.no_grad():
            preds = torch.stack([logits.argmax(dim=1) for logits in learner_logits])  # [T, B]
            wrong_per_learner += (preds != labels.unsqueeze(0)).sum(dim=1).float().cpu()
        n_seen += labels.size(0)

        # Loss averaged over all learners, for logging
        total_loss += batch_loss.item()

    # Update alphas from this epoch's error rates
    with torch.no_grad():
        # Clamp keeps the error a tensor and away from 0 so the log is finite.
        error = (wrong_per_learner / max(n_seen, 1)).clamp(min=1e-10)
        new_alphas = torch.log((1.0 - error) / error) + math.log(max(num_classes - 1, 1))
        # Learners no better than random guessing (error >= 1 - 1/K) get no vote.
        new_alphas[error >= 1.0 - 1.0 / max(num_classes, 1)] = 0.0
        adaboost.alphas.copy_(new_alphas.to(adaboost.alphas.device))

    # Calculate average loss and evaluate accuracy
    avg_loss = total_loss / max(len(train_loader), 1)
    train_acc = evaluate_accuracy(feature_extractor, adaboost, train_loader)
    test_acc = evaluate_accuracy(feature_extractor, adaboost, test_loader)

    # Logging
    print(f"Epoch {epoch}/{epochs} | Loss: {avg_loss:.4f} | "
          f"Train Acc: {train_acc:.2f}% | Test Acc: {test_acc:.2f}%")



Extracting train features...


KeyboardInterrupt: 