# Two-transformer fusion model (forward + reversed sequence), experiment repeated 10 times

In [None]:
import math
import random
from collections import defaultdict, Counter

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from Bio import SeqIO

###################################
# 1. FASTA Parsing and Filtering
###################################

def parse_fasta_with_labels(fasta_file):
    """
    Read a FASTA file whose header lines carry the class label.

    Expected record layout:
        >ClassLabel
        DNASEQUENCE

    The label is the first whitespace-delimited token of the record
    description; the sequence is converted to an upper-case string.

    Returns:
        list of (label, sequence) tuples, in file order
    """
    pairs = []
    for rec in SeqIO.parse(fasta_file, "fasta"):
        description = rec.description.strip()
        bases = str(rec.seq).upper()
        # Only the first token of the header is used as the class label.
        pairs.append((description.split()[0], bases))
    return pairs


def create_train_test_split(raw_data):
    """
    Hold out one randomly chosen sequence per class for testing.

    raw_data is a list of (label, sequence) pairs. Sequences are grouped
    by label, each group is shuffled with the global `random` state, the
    first shuffled sequence goes to the test set and all remaining ones
    go to the training set.

    Returns:
        (train_data, test_data), both lists of (label, sequence) pairs
    """
    by_label = defaultdict(list)
    for label, seq in raw_data:
        by_label[label].append(seq)

    train_data, test_data = [], []
    for label, group in by_label.items():
        # The shuffle works on the grouped copy, so raw_data keeps its order.
        random.shuffle(group)
        held_out, *remaining = group
        test_data.append((label, held_out))
        train_data.extend((label, s) for s in remaining)

    return train_data, test_data

###################################
# 2. K-mer Processing
###################################

def generate_kmers(sequence, k=6):
    """
    Return every overlapping window of length k in `sequence`, left to right.

    A sequence shorter than k yields an empty list.
    """
    window_count = len(sequence) - k + 1
    return [sequence[start:start + k] for start in range(window_count)]

def build_kmer_vocab(dataset, k=6):
    """
    Build a k-mer -> integer index vocabulary.

    dataset is a list of (label, seq) pairs; labels are ignored.
    Index 0 is reserved for "<UNK>"; the distinct k-mers found in the
    sequences receive indices 1..N in sorted order.
    """
    distinct_kmers = set()
    for _label, seq in dataset:
        distinct_kmers.update(generate_kmers(seq, k))

    vocab = {"<UNK>": 0}
    # Sorting makes the index assignment independent of set iteration order.
    vocab.update(
        (kmer, index) for index, kmer in enumerate(sorted(distinct_kmers), start=1)
    )
    return vocab

def encode_sequence(sequence, vocab, k=6):
    """
    Translate a DNA sequence into a list of vocabulary indices, one per
    overlapping k-mer. K-mers missing from `vocab` map to the "<UNK>" index.
    """
    token_ids = []
    for kmer in generate_kmers(sequence, k):
        token_ids.append(vocab.get(kmer, vocab["<UNK>"]))
    return token_ids

def filter_classes(raw_data, min_count=10):
    """
    Drop every class that has fewer than min_count samples.

    raw_data is a list of (label, sequence) pairs; the surviving pairs
    are returned in their original order.
    """
    samples_per_label = Counter(label for label, _ in raw_data)

    kept = []
    for label, seq in raw_data:
        if samples_per_label[label] < min_count:
            continue
        kept.append((label, seq))
    return kept

###################################
# 3. Create "Paired" Data
###################################
"""
For demonstration, we'll pair each forward sequence with its reverse.
"""

def reverse_complement(seq):
    """
    Return `seq` reversed.

    Despite the name, no nucleotide complementing is performed: this is
    a demonstration stand-in, and a real reverse complement would also
    swap A<->T and C<->G.
    """
    backwards = slice(None, None, -1)
    return seq[backwards]

def create_paired_data(data_list):
    """
    Turn each (label, seq) pair into a (label, fwd_seq, rev_seq) triple,
    where rev_seq is whatever reverse_complement returns for seq.
    """
    return [(label, seq, reverse_complement(seq)) for label, seq in data_list]

###################################
# 4. PyTorch Dataset for Two Inputs
###################################

class TwoFastaKmerDataset(Dataset):
    """
    Dataset of k-mer encoded sequence pairs.

    Each item: (encoded_seq_fwd, encoded_seq_rev, label_idx)
    """
    def __init__(self, paired_data, vocab, k=6, label2idx=None):
        """
        paired_data: list of (label, fwd_seq, rev_seq)
        vocab: dict mapping k-mers to integer indices
        k: k-mer length used for encoding
        label2idx: optional existing label -> index mapping. Pass the
            mapping of the training dataset when building a validation/test
            dataset so both splits are guaranteed to use the same class
            indices. When omitted, the mapping is derived from the sorted
            labels found in paired_data (the original behaviour).

        Raises:
            KeyError: if label2idx is supplied and paired_data contains a
                label that is missing from it.
        """
        super().__init__()
        self.vocab = vocab
        self.k = k

        if label2idx is None:
            # Sorted labels give a deterministic mapping for a given label set.
            labels = sorted({item[0] for item in paired_data})
            label2idx = {lbl: i for i, lbl in enumerate(labels)}
        # Copy so later changes to the caller's dict cannot alter this dataset.
        self.label2idx = dict(label2idx)

        self.encoded_data = [
            (
                encode_sequence(fwd_seq, self.vocab, k=self.k),
                encode_sequence(rev_seq, self.vocab, k=self.k),
                self.label2idx[label],
            )
            for label, fwd_seq, rev_seq in paired_data
        ]

    def __len__(self):
        return len(self.encoded_data)

    def __getitem__(self, idx):
        return self.encoded_data[idx]  # (fwd, rev, label_idx)

    def get_vocab_size(self):
        """Number of entries in the k-mer vocabulary (including <UNK>)."""
        return len(self.vocab)

    def get_num_classes(self):
        """Number of classes in the label -> index mapping."""
        return len(self.label2idx)

###################################
# 5. Collate Function for Two Inputs
###################################

def collate_fn_two(batch):
    """
    Collate a list of (seq_fwd, seq_rev, label) items into batch tensors.

    Forward and reverse sequences are padded independently with index 0
    up to the longest sequence of their own group.

    Returns:
        (padded_fwd, padded_rev, labels) as LongTensors; the padded
        tensors are batch-first.
    """
    def _pad(token_lists):
        as_tensors = [torch.tensor(tokens, dtype=torch.long) for tokens in token_lists]
        return pad_sequence(as_tensors, batch_first=True, padding_value=0)

    fwd_lists, rev_lists, label_ids = zip(*batch)
    return _pad(fwd_lists), _pad(rev_lists), torch.tensor(label_ids, dtype=torch.long)

###################################
# 6. Two-Transformer Fusion Model
###################################

class PositionalEncoding(nn.Module):
    """
    Adds fixed sinusoidal positional encodings to a batch-first input.

    Even feature indices hold sin(position * div_term) and odd indices
    hold cos(position * div_term). The table is precomputed for max_len
    positions and stored as the non-trainable buffer `pe`.
    """
    def __init__(self, d_model, max_len=5000):
        """
        d_model: feature dimension of the inputs (even or odd)
        max_len: number of positions to precompute
        """
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns,
        # so drop the last frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # shape: [1, max_len, d_model] so it broadcasts over the batch
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """
        x: [batch_size, seq_len, d_model]; seq_len must not exceed max_len.
        Returns x with the first seq_len encodings added.
        """
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]

class TransformerEncoderBlock(nn.Module):
    """
    Embeds tokens, adds positional encoding, runs a TransformerEncoder,
    and mean-pools the output over the non-padding positions.

    Token index 0 is treated as padding: it is the embedding's
    padding_idx, it is excluded from self-attention through
    src_key_padding_mask, and it is excluded from the pooled average.
    Without the mask, the pooled vector of a sequence would depend on how
    much padding the rest of its batch happened to require.
    NOTE(review): the k-mer vocabulary in this file also assigns index 0
    to "<UNK>", so unknown k-mers are ignored in the same way as padding.
    """
    def __init__(self, vocab_size, d_model=128, nhead=8, num_layers=2,
                 dim_feedforward=512, dropout=0.1, max_len=5000, pooling='mean'):
        """
        vocab_size: number of rows in the embedding table
        d_model: embedding / model dimension
        nhead, num_layers, dim_feedforward, dropout: forwarded to
            nn.TransformerEncoderLayer / nn.TransformerEncoder
        max_len: number of positions precomputed by the positional encoding
        pooling: stored on the instance; only mean pooling is implemented,
            so every value currently results in mean pooling
        """
        super().__init__()

        self.d_model = d_model
        self.pooling = pooling

        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_len)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x):
        """
        x: [batch_size, seq_len] tensor of token indices (0 = padding)
        Returns: [batch_size, d_model] pooled representation. A row made
        only of padding pools to a zero vector.
        """
        pad_mask = x.eq(0)                                # [B, L], True at padding
        fully_padded = pad_mask.all(dim=1, keepdim=True)  # [B, 1]

        # A row with every key masked would produce NaNs in attention, so
        # leave the first position of such rows visible to the encoder.
        # The row is still excluded from pooling via pad_mask below.
        attn_pad_mask = pad_mask.clone()
        attn_pad_mask[:, :1] = pad_mask[:, :1] & ~fully_padded

        embedded = self.embedding(x) * math.sqrt(self.d_model)
        encoded = self.pos_encoder(embedded)
        out = self.transformer_encoder(encoded, src_key_padding_mask=attn_pad_mask)

        # NOTE(review): the encoder's inference fast path may return fewer
        # time steps than it was given in some PyTorch versions; slicing
        # the mask to the output length keeps the shapes aligned - confirm
        # against the installed version.
        keep = (~pad_mask)[:, :out.size(1)].unsqueeze(-1)  # [B, L', 1]
        out = out.masked_fill(~keep, 0.0)
        # clamp avoids a division by zero for rows that are entirely padding
        counts = keep.sum(dim=1).clamp(min=1).to(out.dtype)  # [B, 1]

        # Only mean pooling is implemented; self.pooling is kept for
        # interface compatibility and does not change the result.
        pooled = out.sum(dim=1) / counts
        return pooled

class TwoTransformerFusionDNAClassifier(nn.Module):
    """
    Dual-branch classifier: each input goes through its own
    TransformerEncoderBlock, the two pooled vectors are concatenated, and
    a two-layer MLP maps the fused vector to class logits.
    """
    def __init__(self, vocab_size, num_classes, d_model=128, nhead=8, num_layers=2, 
                 dim_feedforward=512, dropout=0.1, max_len=5000, pooling='mean'):
        super().__init__()

        # Both branches share the same hyper-parameters but not their weights.
        branch_config = dict(
            vocab_size=vocab_size,
            d_model=d_model,
            nhead=nhead,
            num_layers=num_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            max_len=max_len,
            pooling=pooling,
        )
        self.transformer1 = TransformerEncoderBlock(**branch_config)
        self.transformer2 = TransformerEncoderBlock(**branch_config)

        # Concatenating two d_model-sized vectors doubles the feature width.
        fused_dim = d_model * 2
        hidden_layer = nn.Linear(fused_dim, fused_dim)
        output_layer = nn.Linear(fused_dim, num_classes)
        self.classifier = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, x1, x2):
        """
        x1, x2: token index tensors for the first and second branch.
        Returns the class logits produced from the fused representation.
        """
        branch_features = (self.transformer1(x1), self.transformer2(x2))
        return self.classifier(torch.cat(branch_features, dim=1))

###################################
# 7. Putting It All Together
###################################

# 7A) Read the labelled FASTA records and keep only classes with >= 10 samples
fasta_file = "data2/fungi_ITS_cleaned.fasta"
raw_data = filter_classes(parse_fasta_with_labels(fasta_file), min_count=10)

# The vocabulary is built over every retained sequence in both orientations
# (forward and reversed), i.e. before any train/test split is made.
k = 6
tmp_data = []
for lbl, seq in raw_data:
    tmp_data.extend(((lbl, seq), (lbl, reverse_complement(seq))))
vocab = build_kmer_vocab(tmp_data, k=k)

# Device: use the GPU when CUDA is available, otherwise the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

###################################
# 8. Run the Experiment 10 Times
###################################

num_runs = 10
epochs = 100
best_accuracies = []


# Defined once instead of being re-created inside every run of the loop.
def evaluate_accuracy(model, data_loader):
    """
    Return the top-1 accuracy of `model` on `data_loader`, in percent.

    The model is switched to eval mode and gradients are disabled.
    Batches are moved to the module-level `device`. An empty loader
    yields 0.0 instead of raising ZeroDivisionError.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for fwd, rev, labels in data_loader:
            fwd, rev, labels = fwd.to(device), rev.to(device), labels.to(device)
            logits = model(fwd, rev)
            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    if total == 0:
        return 0.0
    return 100.0 * correct / total


for run in range(num_runs):
    print(f"\n=== Run {run+1}/{num_runs} ===")
    # Create a new train-test split for each run
    train_data, test_data = create_train_test_split(raw_data)
    paired_train = create_paired_data(train_data)
    paired_test  = create_paired_data(test_data)

    train_dataset = TwoFastaKmerDataset(paired_train, vocab, k=k)
    test_dataset  = TwoFastaKmerDataset(paired_test,  vocab, k=k)

    # Each dataset derives its own label -> index mapping from the labels it
    # sees. If the two splits do not contain the same classes, the indices
    # disagree and the reported accuracy would be meaningless, so fail loudly.
    if train_dataset.label2idx != test_dataset.label2idx:
        raise ValueError(
            "Train and test splits contain different label sets; "
            "their class indices are not comparable."
        )

    batch_size = 8
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, collate_fn=collate_fn_two)
    test_loader  = DataLoader(test_dataset,  batch_size=batch_size,
                              shuffle=False, collate_fn=collate_fn_two)

    num_classes = train_dataset.get_num_classes()
    vocab_size = train_dataset.get_vocab_size()

    # 7B) Create the model for this run (fresh weights every run)
    model = TwoTransformerFusionDNAClassifier(
        vocab_size=vocab_size,
        num_classes=num_classes,
        d_model=512,
        nhead=8,
        num_layers=2,
        dim_feedforward=512,
        dropout=0.1,
        max_len=5000,
        pooling='mean'
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # NOTE: the "best" accuracy is the maximum test accuracy over all epochs,
    # i.e. the test set is also used to pick the epoch being reported.
    best_run_acc = 0.0
    # Training loop for current run
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0

        for fwd, rev, labels in train_loader:
            fwd, rev, labels = fwd.to(device), rev.to(device), labels.to(device)
            optimizer.zero_grad()
            logits = model(fwd, rev)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # max(..., 1) guards against an empty training loader
        avg_loss = total_loss / max(len(train_loader), 1)
        test_acc  = evaluate_accuracy(model, test_loader)

        best_run_acc = max(best_run_acc, test_acc)

        print(f"Epoch {epoch}/{epochs} | Loss: {avg_loss:.4f} | Test Acc: {test_acc:.2f}%")

    best_accuracies.append(best_run_acc)
    print(f"Run {run+1} Best Test Accuracy: {best_run_acc:.2f}%")

avg_accuracy = sum(best_accuracies) / len(best_accuracies)
print(f"\nAverage Highest Test Accuracy over {num_runs} runs: {avg_accuracy:.2f}%")


# Single transformer (concatenated forward + reversed sequence) feeding a fully connected classifier, experiment repeated 10 times

In [None]:
import math
import random
from collections import defaultdict, Counter

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from Bio import SeqIO

###################################
# 1. FASTA Parsing and Filtering
###################################

def parse_fasta_with_labels(fasta_file):
    """
    Read a FASTA file whose header lines carry the class label.

    Expected record layout:
        >ClassLabel
        DNASEQUENCE

    The label is the first whitespace-delimited token of the record
    description; the sequence is converted to an upper-case string.

    Returns:
        list of (label, sequence) tuples, in file order
    """
    pairs = []
    for rec in SeqIO.parse(fasta_file, "fasta"):
        description = rec.description.strip()
        bases = str(rec.seq).upper()
        # Only the first token of the header is used as the class label.
        pairs.append((description.split()[0], bases))
    return pairs


def create_train_test_split(raw_data):
    """
    Hold out one randomly chosen sequence per class for testing.

    raw_data is a list of (label, sequence) pairs. Sequences are grouped
    by label, each group is shuffled with the global `random` state, the
    first shuffled sequence goes to the test set and all remaining ones
    go to the training set.

    Returns:
        (train_data, test_data), both lists of (label, sequence) pairs
    """
    by_label = defaultdict(list)
    for label, seq in raw_data:
        by_label[label].append(seq)

    train_data, test_data = [], []
    for label, group in by_label.items():
        # The shuffle works on the grouped copy, so raw_data keeps its order.
        random.shuffle(group)
        held_out, *remaining = group
        test_data.append((label, held_out))
        train_data.extend((label, s) for s in remaining)

    return train_data, test_data

###################################
# 2. K-mer Processing
###################################

def generate_kmers(sequence, k=6):
    """
    Return every overlapping window of length k in `sequence`, left to right.

    A sequence shorter than k yields an empty list.
    """
    window_count = len(sequence) - k + 1
    return [sequence[start:start + k] for start in range(window_count)]

def build_kmer_vocab(dataset, k=6):
    """
    Build a k-mer -> integer index vocabulary.

    dataset is a list of (label, seq) pairs; labels are ignored.
    Index 0 is reserved for "<UNK>"; the distinct k-mers found in the
    sequences receive indices 1..N in sorted order.
    """
    distinct_kmers = set()
    for _label, seq in dataset:
        distinct_kmers.update(generate_kmers(seq, k))

    vocab = {"<UNK>": 0}
    # Sorting makes the index assignment independent of set iteration order.
    vocab.update(
        (kmer, index) for index, kmer in enumerate(sorted(distinct_kmers), start=1)
    )
    return vocab

def encode_sequence(sequence, vocab, k=6):
    """
    Translate a DNA sequence into a list of vocabulary indices, one per
    overlapping k-mer. K-mers missing from `vocab` map to the "<UNK>" index.
    """
    token_ids = []
    for kmer in generate_kmers(sequence, k):
        token_ids.append(vocab.get(kmer, vocab["<UNK>"]))
    return token_ids

def filter_classes(raw_data, min_count=10):
    """
    Drop every class that has fewer than min_count samples.

    raw_data is a list of (label, sequence) pairs; the surviving pairs
    are returned in their original order.
    """
    samples_per_label = Counter(label for label, _ in raw_data)

    kept = []
    for label, seq in raw_data:
        if samples_per_label[label] < min_count:
            continue
        kept.append((label, seq))
    return kept

###################################
# 3. Create "Paired" Data
###################################
"""
For demonstration, we'll pair each forward sequence with its reverse.
"""

def reverse_complement(seq):
    """
    Return `seq` reversed.

    Despite the name, no nucleotide complementing is performed: this is
    a demonstration stand-in, and a real reverse complement would also
    swap A<->T and C<->G.
    """
    backwards = slice(None, None, -1)
    return seq[backwards]

def create_paired_data(data_list):
    """
    Turn each (label, seq) pair into a (label, fwd_seq, rev_seq) triple,
    where rev_seq is whatever reverse_complement returns for seq.
    """
    return [(label, seq, reverse_complement(seq)) for label, seq in data_list]

###################################
# 4. PyTorch Dataset for Two Inputs
###################################

class TwoFastaKmerDataset(Dataset):
    """
    Dataset of k-mer encoded sequence pairs.

    Each item: (encoded_seq_fwd, encoded_seq_rev, label_idx)
    """
    def __init__(self, paired_data, vocab, k=6, label2idx=None):
        """
        paired_data: list of (label, fwd_seq, rev_seq)
        vocab: dict mapping k-mers to integer indices
        k: k-mer length used for encoding
        label2idx: optional existing label -> index mapping. Pass the
            mapping of the training dataset when building a validation/test
            dataset so both splits are guaranteed to use the same class
            indices. When omitted, the mapping is derived from the sorted
            labels found in paired_data (the original behaviour).

        Raises:
            KeyError: if label2idx is supplied and paired_data contains a
                label that is missing from it.
        """
        super().__init__()
        self.vocab = vocab
        self.k = k

        if label2idx is None:
            # Sorted labels give a deterministic mapping for a given label set.
            labels = sorted({item[0] for item in paired_data})
            label2idx = {lbl: i for i, lbl in enumerate(labels)}
        # Copy so later changes to the caller's dict cannot alter this dataset.
        self.label2idx = dict(label2idx)

        self.encoded_data = [
            (
                encode_sequence(fwd_seq, self.vocab, k=self.k),
                encode_sequence(rev_seq, self.vocab, k=self.k),
                self.label2idx[label],
            )
            for label, fwd_seq, rev_seq in paired_data
        ]

    def __len__(self):
        return len(self.encoded_data)

    def __getitem__(self, idx):
        return self.encoded_data[idx]  # (fwd, rev, label_idx)

    def get_vocab_size(self):
        """Number of entries in the k-mer vocabulary (including <UNK>)."""
        return len(self.vocab)

    def get_num_classes(self):
        """Number of classes in the label -> index mapping."""
        return len(self.label2idx)

###################################
# 5. Collate Function for Two Inputs
###################################

def collate_fn_two(batch):
    """
    Collate a list of (seq_fwd, seq_rev, label) items into batch tensors.

    Forward and reverse sequences are padded independently with index 0
    up to the longest sequence of their own group.

    Returns:
        (padded_fwd, padded_rev, labels) as LongTensors; the padded
        tensors are batch-first.
    """
    def _pad(token_lists):
        as_tensors = [torch.tensor(tokens, dtype=torch.long) for tokens in token_lists]
        return pad_sequence(as_tensors, batch_first=True, padding_value=0)

    fwd_lists, rev_lists, label_ids = zip(*batch)
    return _pad(fwd_lists), _pad(rev_lists), torch.tensor(label_ids, dtype=torch.long)

###################################
# 6. Single Transformer Model
###################################

class PositionalEncoding(nn.Module):
    """
    Adds fixed sinusoidal positional encodings to a batch-first input.

    Even feature indices hold sin(position * div_term) and odd indices
    hold cos(position * div_term). The table is precomputed for max_len
    positions and stored as the non-trainable buffer `pe`.
    """
    def __init__(self, d_model, max_len=5000):
        """
        d_model: feature dimension of the inputs (even or odd)
        max_len: number of positions to precompute
        """
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns,
        # so drop the last frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # shape: [1, max_len, d_model] so it broadcasts over the batch
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """
        x: [batch_size, seq_len, d_model]; seq_len must not exceed max_len.
        Returns x with the first seq_len encodings added.
        """
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]

class TransformerEncoderBlock(nn.Module):
    """
    Embeds tokens, adds positional encoding, runs a TransformerEncoder,
    and mean-pools the output over the non-padding positions.

    Token index 0 is treated as padding: it is the embedding's
    padding_idx, it is excluded from self-attention through
    src_key_padding_mask, and it is excluded from the pooled average.
    Without the mask, the pooled vector of a sequence would depend on how
    much padding the rest of its batch happened to require.
    NOTE(review): the k-mer vocabulary in this file also assigns index 0
    to "<UNK>", so unknown k-mers are ignored in the same way as padding.
    """
    def __init__(self, vocab_size, d_model=128, nhead=8, num_layers=2,
                 dim_feedforward=512, dropout=0.1, max_len=5000, pooling='mean'):
        """
        vocab_size: number of rows in the embedding table
        d_model: embedding / model dimension
        nhead, num_layers, dim_feedforward, dropout: forwarded to
            nn.TransformerEncoderLayer / nn.TransformerEncoder
        max_len: number of positions precomputed by the positional encoding
        pooling: stored on the instance; only mean pooling is implemented,
            so every value currently results in mean pooling
        """
        super().__init__()

        self.d_model = d_model
        self.pooling = pooling

        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_len)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x):
        """
        x: [batch_size, seq_len] tensor of token indices (0 = padding).
           Padding may appear anywhere in the row (e.g. between two
           concatenated, individually padded sequences).
        Returns: [batch_size, d_model] pooled representation. A row made
        only of padding pools to a zero vector.
        """
        pad_mask = x.eq(0)                                # [B, L], True at padding
        fully_padded = pad_mask.all(dim=1, keepdim=True)  # [B, 1]

        # A row with every key masked would produce NaNs in attention, so
        # leave the first position of such rows visible to the encoder.
        # The row is still excluded from pooling via pad_mask below.
        attn_pad_mask = pad_mask.clone()
        attn_pad_mask[:, :1] = pad_mask[:, :1] & ~fully_padded

        embedded = self.embedding(x) * math.sqrt(self.d_model)
        encoded = self.pos_encoder(embedded)
        out = self.transformer_encoder(encoded, src_key_padding_mask=attn_pad_mask)

        # NOTE(review): the encoder's inference fast path may return fewer
        # time steps than it was given in some PyTorch versions; slicing
        # the mask to the output length keeps the shapes aligned - confirm
        # against the installed version.
        keep = (~pad_mask)[:, :out.size(1)].unsqueeze(-1)  # [B, L', 1]
        out = out.masked_fill(~keep, 0.0)
        # clamp avoids a division by zero for rows that are entirely padding
        counts = keep.sum(dim=1).clamp(min=1).to(out.dtype)  # [B, 1]

        # Only mean pooling is implemented; self.pooling is kept for
        # interface compatibility and does not change the result.
        pooled = out.sum(dim=1) / counts
        return pooled

class SingleTransformerDNAClassifier(nn.Module):
    """
    One-branch classifier operating on the joined input.

    The two token sequences are concatenated along the sequence axis,
    encoded and pooled by a single TransformerEncoderBlock, and the pooled
    vector is mapped to class logits by a two-layer MLP.
    """
    def __init__(self, vocab_size, num_classes, d_model=128, nhead=8, num_layers=2, 
                 dim_feedforward=512, dropout=0.1, max_len=5000, pooling='mean'):
        super().__init__()

        encoder_config = dict(
            vocab_size=vocab_size,
            d_model=d_model,
            nhead=nhead,
            num_layers=num_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            max_len=max_len,
            pooling=pooling,
        )
        self.encoder = TransformerEncoderBlock(**encoder_config)

        hidden_layer = nn.Linear(d_model, d_model)
        output_layer = nn.Linear(d_model, num_classes)
        self.classifier = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, x1, x2):
        """
        x1: [B, seq_len1] token indices
        x2: [B, seq_len2] token indices
        Returns the class logits for the concatenated input.
        """
        # Join along the sequence axis -> [B, seq_len1 + seq_len2];
        # the combined length must fit the encoder's max_len.
        joined = torch.cat((x1, x2), dim=1)
        pooled = self.encoder(joined)  # [B, d_model]
        return self.classifier(pooled)

###################################
# 7. Putting It All Together
###################################

# 7A) Read the labelled FASTA records and keep only classes with >= 10 samples
fasta_file = "data2/fungi_ITS_cleaned.fasta"
raw_data = filter_classes(parse_fasta_with_labels(fasta_file), min_count=10)

# The vocabulary is built over every retained sequence in both orientations
# (forward and reversed), i.e. before any train/test split is made.
k = 6
tmp_data = []
for lbl, seq in raw_data:
    tmp_data.extend(((lbl, seq), (lbl, reverse_complement(seq))))
vocab = build_kmer_vocab(tmp_data, k=k)

# Device: use the GPU when CUDA is available, otherwise the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

###################################
# 8. Run the Experiment 10 Times
###################################

num_runs = 10
epochs = 30
best_accuracies = []


# Defined once instead of being re-created inside every run of the loop.
def evaluate_accuracy(model, data_loader):
    """
    Return the top-1 accuracy of `model` on `data_loader`, in percent.

    The model is switched to eval mode and gradients are disabled.
    Batches are moved to the module-level `device`. An empty loader
    yields 0.0 instead of raising ZeroDivisionError.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for fwd, rev, labels in data_loader:
            fwd, rev, labels = fwd.to(device), rev.to(device), labels.to(device)
            logits = model(fwd, rev)
            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    if total == 0:
        return 0.0
    return 100.0 * correct / total


for run in range(num_runs):
    print(f"\n=== Run {run+1}/{num_runs} ===")
    # Create a new train-test split for each run
    train_data, test_data = create_train_test_split(raw_data)
    paired_train = create_paired_data(train_data)
    paired_test  = create_paired_data(test_data)

    train_dataset = TwoFastaKmerDataset(paired_train, vocab, k=k)
    test_dataset  = TwoFastaKmerDataset(paired_test,  vocab, k=k)

    # Each dataset derives its own label -> index mapping from the labels it
    # sees. If the two splits do not contain the same classes, the indices
    # disagree and the reported accuracy would be meaningless, so fail loudly.
    if train_dataset.label2idx != test_dataset.label2idx:
        raise ValueError(
            "Train and test splits contain different label sets; "
            "their class indices are not comparable."
        )

    batch_size = 8
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, collate_fn=collate_fn_two)
    test_loader  = DataLoader(test_dataset,  batch_size=batch_size,
                              shuffle=False, collate_fn=collate_fn_two)

    num_classes = train_dataset.get_num_classes()
    vocab_size = train_dataset.get_vocab_size()

    # 7B) Create the single transformer model for this run (fresh weights every run)
    model = SingleTransformerDNAClassifier(
        vocab_size=vocab_size,
        num_classes=num_classes,
        d_model=512,
        nhead=8,
        num_layers=2,
        dim_feedforward=512,
        dropout=0.1,
        max_len=5000,
        pooling='mean'
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # NOTE: the "best" accuracy is the maximum test accuracy over all epochs,
    # i.e. the test set is also used to pick the epoch being reported.
    best_run_acc = 0.0
    # Training loop for current run
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0

        for fwd, rev, labels in train_loader:
            fwd, rev, labels = fwd.to(device), rev.to(device), labels.to(device)
            optimizer.zero_grad()
            logits = model(fwd, rev)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # max(..., 1) guards against an empty training loader
        avg_loss = total_loss / max(len(train_loader), 1)
        test_acc  = evaluate_accuracy(model, test_loader)

        best_run_acc = max(best_run_acc, test_acc)

        print(f"Epoch {epoch}/{epochs} | Loss: {avg_loss:.4f} | Test Acc: {test_acc:.2f}%")

    best_accuracies.append(best_run_acc)
    print(f"Run {run+1} Best Test Accuracy: {best_run_acc:.2f}%")

avg_accuracy = sum(best_accuracies) / len(best_accuracies)
print(f"\nAverage Highest Test Accuracy over {num_runs} runs: {avg_accuracy:.2f}%")
