In [1]:
import math
import random
from collections import defaultdict, Counter

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from Bio import SeqIO


In [2]:
###############################
# Parsing and K-mer functions
###############################

def parse_fasta_with_labels(fasta_file):
    """
    Parses a FASTA file where each header line is assumed to be:
        >ClassLabel
        DNASEQUENCE
    Returns:
        list of tuples (label, sequence)
    """
    data = []
    for record in SeqIO.parse(fasta_file, "fasta"):
        header = record.description.strip()  # e.g. "Cortinarius"
        sequence = str(record.seq).upper()   # DNA sequence in uppercase
        label = header.split()[0]           # If the header is just ">Cortinarius", label = "Cortinarius"
        data.append((label, sequence))
    return data


def generate_kmers(sequence, k=6):
    """
    Generates overlapping K-mers of length k from a DNA sequence.
    """
    kmers = []
    for i in range(len(sequence) - k + 1):
        kmers.append(sequence[i:i + k])
    return kmers

def build_kmer_vocab(dataset, k=6):
    """
    dataset is a list of (label, seq) pairs
    returns:
      - vocab: dict mapping K-mer to index
    """
    kmer_set = set()
    for _, seq in dataset:
        kmers = generate_kmers(seq, k)
        kmer_set.update(kmers)
    
    # Create a vocabulary where each K-mer gets a unique integer ID
    # Let's reserve a special token for unknown words: <UNK>
    vocab = {"<UNK>": 0}
    for i, kmer in enumerate(sorted(kmer_set), start=1):
        vocab[kmer] = i
    
    return vocab

def encode_sequence(sequence, vocab, k=6):
    kmers = generate_kmers(sequence, k)
    encoded = [vocab[kmer] if kmer in vocab else vocab["<UNK>"]
               for kmer in kmers]
    return encoded






# genera more than x samples

In [3]:
def filter_classes(raw_data, min_count=10):
    """
    Keep only classes that have > min_count samples.
    Discard classes with <= min_count samples.
    """
    label_counts = Counter([label for (label, _) in raw_data])
    filtered_data = [
        (label, seq) 
        for (label, seq) in raw_data
        if label_counts[label] >= min_count
    ]
    return filtered_data

# train test split

In [4]:
###############################
# Dataset class
###############################

class FastaKmerDataset(Dataset):
    def __init__(self, data_list, vocab, k=6):
        """
        data_list is a list of (label, seq) pairs
        vocab is a dict mapping K-mers to indices
        """
        super().__init__()
        
        self.vocab = vocab
        self.k = k
        
        # Gather all labels in a set
        labels = sorted(set([item[0] for item in data_list]))
        self.label2idx = {label: i for i, label in enumerate(labels)}
        
        # Pre-encode data
        self.encoded_data = []
        for label, seq in data_list:
            encoded_seq = encode_sequence(seq, self.vocab, k=self.k)
            numeric_label = self.label2idx[label]
            self.encoded_data.append((encoded_seq, numeric_label))
    
    def __len__(self):
        return len(self.encoded_data)
    
    def __getitem__(self, idx):
        return self.encoded_data[idx]
    
    def get_vocab_size(self):
        return len(self.vocab)
    
    def get_num_classes(self):
        return len(self.label2idx)
    
    def idx2label(self, idx):
        # Reverse lookup if needed
        for lbl, lbl_idx in self.label2idx.items():
            if lbl_idx == idx:
                return lbl
        return None


###############################
# Collate function for padding
###############################

from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    """
    batch: list of (encoded_seq, label)
    We'll pad the sequences to the same length and return Tensors
    """
    # Unpack
    sequences, labels = zip(*batch)
    
    # Convert each list of ints into a torch.Tensor
    seq_tensors = [torch.tensor(seq, dtype=torch.long) for seq in sequences]
    
    # Pad them to the same length
    padded_seqs = pad_sequence(seq_tensors, batch_first=True, padding_value=0)
    
    # Convert labels to tensor
    label_tensor = torch.tensor(labels, dtype=torch.long)
    
    return padded_seqs, label_tensor


#################################
# Split data into train and test
#################################

def create_train_test_split(raw_data):
    """
    Given a list of (label, sequence) pairs, 
    pick 1 random sample per class for test, 
    and the rest for train.
    
    Returns:
      train_data, test_data
    """
    # Group data by label
    label_to_samples = defaultdict(list)
    for label, seq in raw_data:
        label_to_samples[label].append(seq)
    
    train_data = []
    test_data = []
    
    for label, seqs in label_to_samples.items():
        # Shuffle sequences
        random.shuffle(seqs)
        # Pick one for test
        test_seq = seqs[0]
        # Remainder for train
        train_seqs = seqs[1:]
        
        test_data.append((label, test_seq))
        for s in train_seqs:
            train_data.append((label, s))
    
    return train_data, test_data

# make pytorch dataset

In [5]:
# 1) Parse FASTA
fasta_file = "data2/fungi_ITS_cleaned.fasta"
raw_data = parse_fasta_with_labels(fasta_file)

# 2) Filter out classes that do NOT have >10 samples
raw_data = filter_classes(raw_data, min_count=5)

# 3) Create the train/test split
train_data, test_data = create_train_test_split(raw_data)

# 4) Build vocab (either from just train or from all)
#    Here we build from train+test to minimize OOV in test.
combined_data = train_data + test_data
vocab = build_kmer_vocab(combined_data, k=6)

# 5) Create dataset objects
train_dataset = FastaKmerDataset(train_data, vocab=vocab, k=6)
test_dataset  = FastaKmerDataset(test_data,  vocab=vocab, k=6)

# 6) Create DataLoaders
batch_size = 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, 
                            shuffle=True, collate_fn=collate_fn)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, 
                            shuffle=False, collate_fn=collate_fn)

print(f"Train size: {len(train_dataset)}, Test size: {len(test_dataset)}")
print(f"Vocab size: {train_dataset.get_vocab_size()}")
print(f"Num classes: {train_dataset.get_num_classes()}")

# Example: examine one batch of training data
for batch_seqs, batch_labels in train_loader:
    print("Train batch sequences shape:", batch_seqs.shape)
    print("Train batch labels:", batch_labels)
    break

# Example: examine one batch of test data
for batch_seqs, batch_labels in test_loader:
    print("Test batch sequences shape:", batch_seqs.shape)
    print("Test batch labels:", batch_labels)
    break

Train size: 2500, Test size: 247
Vocab size: 6017
Num classes: 247
Train batch sequences shape: torch.Size([4, 1324])
Train batch labels: tensor([196,  43,  70,  29])
Test batch sequences shape: torch.Size([4, 745])
Test batch labels: tensor([ 50, 150,  11,  97])


# transformer based model

In [6]:
###################################
# 5. Transformer Model Definition
###################################

class PositionalEncoding(nn.Module):
    """
    Standard sine/cosine positional encoding from the original Transformer paper.
    """
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                             (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0)  # shape: [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        x: [batch_size, seq_len, d_model]
        returns x + positional_enc up to seq_len
        """
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]

class TransformerDNAClassifier(nn.Module):
    """
    Simple Transformer-based classifier for DNA sequences.
    """
    def __init__(self, 
                 vocab_size, 
                 num_classes, 
                 d_model=128, 
                 nhead=8, 
                 num_layers=2, 
                 dim_feedforward=512, 
                 dropout=0.1,
                 max_len=5000,
                 pooling='mean'):
        super().__init__()
        self.d_model = d_model
        self.pooling = pooling

        # Embedding
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)

        # Positional Encoding
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_len)

        # Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )

        # Classifier
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """
        x: [batch_size, seq_len]
        returns logits: [batch_size, num_classes]
        """
        embedded = self.embedding(x)                # [batch, seq_len, d_model]
        embedded = embedded * math.sqrt(self.d_model)
        encoded = self.pos_encoder(embedded)        # [batch, seq_len, d_model]
        transformer_out = self.transformer_encoder(encoded)
        
        # Pool across sequence dimension
        if self.pooling == 'mean':
            pooled = transformer_out.mean(dim=1)    # [batch, d_model]
        else:
            pooled = transformer_out[:, 0, :]       # [CLS]-like approach

        logits = self.classifier(pooled)            # [batch, num_classes]
        return logits


###################################
# 6. Putting It All Together
###################################

# Replace this with the path to your FASTA file
fasta_file = "data2/fungi_ITS_cleaned.fasta"

# Step A: Parse and Filter
raw_data = parse_fasta_with_labels(fasta_file)
raw_data = filter_classes(raw_data, min_count=10)
train_data, test_data = create_train_test_split(raw_data)

# Step B: Build vocab from train + test combined (to avoid OOV in test)
combined_data = train_data + test_data
k = 6
vocab = build_kmer_vocab(combined_data, k=k)

# Step C: Create Datasets
train_dataset = FastaKmerDataset(train_data, vocab, k=k)
test_dataset  = FastaKmerDataset(test_data,  vocab, k=k)

# Step D: Create DataLoaders
batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, 
                          shuffle=True, collate_fn=collate_fn)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, 
                          shuffle=False, collate_fn=collate_fn)

# Step E: Define Model, Loss, Optimizer
device = "cuda" if torch.cuda.is_available() else "cpu"

num_classes = train_dataset.get_num_classes()
vocab_size = train_dataset.get_vocab_size()

model = TransformerDNAClassifier(
    vocab_size=vocab_size,
    num_classes=num_classes,
    d_model=128,
    nhead=8,
    num_layers=2,
    dim_feedforward=512,
    dropout=0.1,
    max_len=5000,
    pooling='mean'
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

###################################
# 7. Training Loop
#    (Printing Train & Test Acc. each epoch)
###################################

def evaluate_accuracy(model, data_loader, device):
    """
    Evaluate the classification accuracy of the model on a given DataLoader.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_seqs, batch_labels in data_loader:
            batch_seqs = batch_seqs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_seqs)
            preds = torch.argmax(logits, dim=1)

            correct += (preds == batch_labels).sum().item()
            total += batch_labels.size(0)
    return 100.0 * correct / total

epochs = 20
for epoch in range(1, epochs + 1):
    # TRAIN
    model.train()
    total_loss = 0.0
    for batch_seqs, batch_labels in train_loader:
        batch_seqs = batch_seqs.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_seqs)
        loss = criterion(logits, batch_labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)

    # EVALUATE on TRAIN and TEST for accuracy
    train_acc = evaluate_accuracy(model, train_loader, device)
    test_acc  = evaluate_accuracy(model, test_loader,  device)

    print(f"Epoch {epoch}/{epochs} | Loss: {avg_loss:.4f} | "
          f"Train Acc: {train_acc:.2f}% | Test Acc: {test_acc:.2f}%")

  from .autonotebook import tqdm as notebook_tqdm


Epoch 1/20 | Loss: 3.3945 | Train Acc: 35.30% | Test Acc: 9.88%


KeyboardInterrupt: 