In [33]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F


In [34]:
import os
from Bio import SeqIO

def encode_and_save_fasta(fasta_file, data_file):
    """
    Encode every record of a FASTA file as a digit string and save it as text.

    One line is written per record in the form ``<label> <digits>``, where
    ``<label>`` is ``record.description`` and ``<digits>`` is the upper-cased
    sequence with A/C/G/T mapped to 1/2/3/4 and any other character mapped to 0.

    :param fasta_file: path of the FASTA file to read
    :param data_file: path of the text file to (over)write; its parent
                      directory is created if it does not exist yet
    """
    # Digit for each canonical nucleotide; everything else falls back to '0'
    nucleotide_mapping = {'A': '1', 'C': '2', 'G': '3', 'T': '4'}

    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when data_file actually contains a directory component
    output_dir = os.path.dirname(data_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(data_file, 'w') as f:
        for record in SeqIO.parse(fasta_file, "fasta"):
            # NOTE: the later cells take the first whitespace-separated token of
            # each line as the class label, so headers are expected to be a
            # single token (e.g. a genus name)
            label = record.description
            sequence = str(record.seq).upper()

            encoded_sequence = ''.join(nucleotide_mapping.get(nuc, '0') for nuc in sequence)

            f.write(f"{label} {encoded_sequence}\n")

# Input FASTA file and the destination of its digit-encoded form
fasta_file = "data2/fungi_ITS_cleaned.fasta"
data_file = "data2/encoded_data.txt"

# Write one "<label> <digits>" line to data_file for every record in fasta_file
encode_and_save_fasta(fasta_file, data_file)


# Keep only genera with at least x samples

In [35]:
import os
from collections import Counter
def filter_classes_with_more_than_5_samples(encoded_data_file, filtered_data_file, min_samples=10):
    """
    Copy only the records whose class occurs at least `min_samples` times.

    The class label of a record is the first whitespace-separated token of its
    line. Blank lines are ignored and are not copied to the output.

    NOTE: the function name is historical; the threshold actually applied is
    ``count >= min_samples`` (10 by default), not "more than 5".

    :param encoded_data_file: path of the text file to read (one record per line)
    :param filtered_data_file: path of the text file to (over)write
    :param min_samples: minimum number of records a class needs to be kept
    """
    with open(encoded_data_file, 'r') as f:
        # Blank lines have no label; line.split()[0] would raise IndexError on them
        records = [line for line in f if line.strip()]

    # Number of records per class label
    class_counts = Counter(line.split()[0] for line in records)

    valid_classes = {cls for cls, count in class_counts.items() if count >= min_samples}

    # Records are written in their original order
    with open(filtered_data_file, 'w') as filtered_f:
        filtered_f.writelines(line for line in records if line.split()[0] in valid_classes)


# Input: every encoded record; output: only the records of sufficiently frequent classes
encoded_data_file = "data2/encoded_data.txt"
filtered_data_file = "data2/filtered_encoded_data.txt"
# NOTE: despite its name, this call keeps classes with at least 10 samples
filter_classes_with_more_than_5_samples(encoded_data_file, filtered_data_file)

# Train/test split

In [36]:
import os
import random
from collections import defaultdict

def train_test_split(filtered_data_file, train_file, test_file, seed=None):
    """
    Hold out one randomly chosen record per class as test data.

    The class label of a record is the first whitespace-separated token of its
    line. For every class with more than one record, exactly one record goes to
    the test file and all the others go to the train file; a class with a
    single record goes to the train file only. Blank lines are ignored.

    :param filtered_data_file: path of the text file to read (one record per line)
    :param train_file: path of the train file to (over)write
    :param test_file: path of the test file to (over)write
    :param seed: optional seed for a private random generator, making the split
                 reproducible; with None the global `random` state is used
    """
    rng = random.Random(seed) if seed is not None else random

    # Records grouped by class label, in order of first appearance
    class_samples = defaultdict(list)
    with open(filtered_data_file, 'r') as f:
        for line in f:
            if not line.strip():
                continue  # blank lines have no label
            if not line.endswith('\n'):
                # A final line without newline would otherwise be glued to the
                # next record written after it
                line += '\n'
            class_samples[line.split()[0]].append(line)

    train_samples = []
    test_samples = []

    for samples in class_samples.values():
        if len(samples) > 1:
            # Pick by position: filtering by value would also drop every
            # duplicate of the held-out line from the training set
            held_out = rng.randrange(len(samples))
            test_samples.append(samples[held_out])
            train_samples.extend(samples[:held_out])
            train_samples.extend(samples[held_out + 1:])
        else:
            train_samples.extend(samples)

    with open(train_file, 'w') as train_f, open(test_file, 'w') as test_f:
        train_f.writelines(train_samples)
        test_f.writelines(test_samples)

# Input: filtered records; outputs: the train and test record files
filtered_data_file = "data2/filtered_encoded_data.txt"
train_file = "data2/train_data.txt"
test_file = "data2/test_data.txt"

# NOTE: called without a seed, so the held-out records change on every run
train_test_split(filtered_data_file, train_file, test_file)


In [37]:
class SequenceDataset(Dataset):
    """
    Dataset of digit-encoded sequences read from a text file.

    Each non-empty line of the file looks like::

        label  3411134123...

    where `label` is a string (like 'Trichoderma') and the rest is a string of
    digits. Lines without both parts are skipped.

    Attributes:
      - samples: list of (list of int digits, numeric label) pairs
      - label_mapping: dict from label string to numeric label index
    """

    def __init__(self, data_file, label_mapping=None, sequence_dtype=torch.float32):
        """
        :param data_file: path of the text file to read
        :param label_mapping: optional existing {label string: index} mapping to
            reuse, e.g. the training set's mapping when loading a test set, so
            that both sets use the same indices. With None (default) the
            mapping is built from this file in order of first appearance.
        :param sequence_dtype: dtype of the sequence tensors returned by
            __getitem__. The default float32 keeps the previous behaviour;
            nn.Embedding needs an integer dtype such as torch.long.
        :raises ValueError: if `label_mapping` is given and the file contains a
            label that is not in it
        """
        self.samples = []
        self.sequence_dtype = sequence_dtype

        mapping_is_fixed = label_mapping is not None
        # Copy so that the caller's dict is never shared or modified
        self.label_mapping = dict(label_mapping) if mapping_is_fixed else {}

        with open(data_file, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # skip empty lines

                # Split once on whitespace: label first, digits after
                parts = line.split(maxsplit=1)
                if len(parts) < 2:
                    continue  # malformed line without a sequence

                label_str, sequence_str = parts

                if label_str not in self.label_mapping:
                    if mapping_is_fixed:
                        raise ValueError(
                            f"Label {label_str!r} in {data_file} is not part of the supplied label_mapping"
                        )
                    # Next free index == number of labels seen so far
                    self.label_mapping[label_str] = len(self.label_mapping)

                numeric_label = self.label_mapping[label_str]
                # One integer per digit character
                sequence = [int(x) for x in sequence_str]

                self.samples.append((sequence, numeric_label))

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Return sample `idx` as (sequence tensor, label tensor).

        The sequence keeps its own length; padding to a common length is left
        to the DataLoader's collate_fn.
        """
        sequence, label = self.samples[idx]
        return torch.tensor(sequence, dtype=self.sequence_dtype), torch.tensor(label, dtype=torch.long)


# ---------------------------------------------------------------------
# 2) Collate Function: Pads variable-length sequences
# ---------------------------------------------------------------------
def pad_collate(batch):
    """
    Collate a batch of (sequence, label) pairs with variable-length sequences.

    Every sequence is right-padded with zeros up to the length of the longest
    sequence in the batch. The padded tensor keeps the dtype of the input
    sequences (padding with a float `torch.zeros` would turn integer
    sequences into floats).

    :param batch: non-empty iterable of (1-D sequence tensor, label) pairs
    :return: (padded_sequences of shape (batch, max_length),
              labels as a 1-D torch.long tensor)
    """
    # Local import so the function does not depend on a notebook-level import
    from torch.nn.utils.rnn import pad_sequence

    sequences, labels = zip(*batch)

    # batch_first=True => shape (batch, max_length)
    padded_sequences = pad_sequence(list(sequences), batch_first=True, padding_value=0)
    labels = torch.tensor(labels, dtype=torch.long)
    return padded_sequences, labels


# ---------------------------------------------------------------------
# 3) Example: Create Datasets & DataLoaders
# ---------------------------------------------------------------------
# Files written by the train/test split cell
train_file = "data2/train_data.txt"
test_file  = "data2/test_data.txt"

# NOTE(review): each dataset numbers its labels in order of first appearance in
# its own file, so train and test indices only agree if both files list the
# classes in the same order; confirm before trusting test accuracy
train_dataset = SequenceDataset(train_file)
test_dataset  = SequenceDataset(test_file)

# RRCNN-LSTM

In [41]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim



# Label indices run from 0 to (number of distinct labels - 1), e.g. [0..68] for
# 69 labels, so the classifier's output width has to equal the size of the
# training label mapping.
num_classes = len(train_dataset.label_mapping)
print("Number of training samples:", len(train_dataset))
print("Number of distinct classes:", num_classes)
print("Label mapping:", train_dataset.label_mapping)

batch_size = 16

# Both loaders pad each batch to its longest sequence through pad_collate;
# only the training loader reshuffles its samples every epoch.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=pad_collate)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=pad_collate)


# ---------------------------------------------------------------------
# 4) Define the RRCNN + LSTM for 1D Sequences
# ---------------------------------------------------------------------

class RRCNNBlock1D(nn.Module):
    """
    Recurrent residual block for 1D feature maps.

    A single convolution layer (one set of weights) is applied twice per
    recurrent step, each application followed by ReLU, and the input of the
    step is added back onto the result. The step runs `num_recurrent` times.
    Input and output both have shape (batch, channels, length), which requires
    `kernel_size` and `padding` to preserve the length.
    """
    def __init__(self, channels, kernel_size=3, padding=1, num_recurrent=2):
        super().__init__()
        self.num_recurrent = num_recurrent
        self.conv = nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)
        self.relu = nn.ReLU(inplace=True)

    def _conv_relu(self, features):
        """Apply the shared convolution followed by ReLU."""
        return self.relu(self.conv(features))

    def forward(self, x):
        """Run the recurrent residual steps on x of shape (batch, channels, length)."""
        state = x
        step = 0
        while step < self.num_recurrent:
            # conv -> ReLU -> conv -> ReLU with the same weights, plus the skip connection
            state = self._conv_relu(self._conv_relu(state)) + state
            step += 1
        return state

# --------------------------------------------------
# 3) RRCNN for 1D data
#    (Stack multiple RRCNN blocks + optional global pooling)
# --------------------------------------------------
class RRCNN1D(nn.Module):
    """
    Stack of RRCNNBlock1D blocks between an entry and an exit convolution,
    optionally followed by global average pooling over the length dimension.
    """
    def __init__(self,
                 in_channels=4,
                 hidden_channels=16,
                 num_blocks=2,
                 num_recurrent=6,
                 kernel_size=3,
                 dropout_p=0.2,
                 use_global_pool=True):
        """
        :param in_channels: number of input channels, e.g. 4 for one-hot encoded
                            A,C,G,T or the embedding dimension when the input
                            was embedded first
        :param hidden_channels: the number of feature maps in the CNN
        :param num_blocks: how many RRCNNBlock1D to stack
        :param num_recurrent: how many recurrent steps per block
        :param kernel_size: convolution kernel size; use an odd value so that
                            the padding of kernel_size // 2 keeps the sequence
                            length, which the residual additions require
        :param dropout_p: probability of the `dropout` layer
        :param use_global_pool: if True, do global average pool over sequence length
        """
        super().__init__()
        # Length-preserving padding for odd kernels (1 for the default kernel of 3).
        # A fixed padding of 1 would shrink the sequence for larger kernels and
        # break the residual additions inside the blocks.
        same_padding = kernel_size // 2

        self.use_global_pool = use_global_pool
        # NOTE(review): this layer is created but never applied in forward();
        # confirm whether dropout was meant to be used
        self.dropout = nn.Dropout(p=dropout_p)

        # Entry convolution to go from in_channels -> hidden_channels
        self.entry_conv = nn.Conv1d(in_channels, hidden_channels, kernel_size=kernel_size, padding=same_padding)
        self.relu = nn.ReLU(inplace=True)

        # Stack of recurrent residual blocks
        self.blocks = nn.Sequential(*[
            RRCNNBlock1D(hidden_channels, kernel_size=kernel_size, padding=same_padding, num_recurrent=num_recurrent)
            for _ in range(num_blocks)
        ])

        # Exit convolution
        self.exit_conv = nn.Conv1d(hidden_channels, hidden_channels, kernel_size=kernel_size, padding=same_padding)

    def forward(self, x):
        """
        :param x: tensor of shape (batch, in_channels, length)
        :return: (batch, hidden_channels) if use_global_pool,
                 else (batch, hidden_channels, length)
        """
        out = self.entry_conv(x)
        out = self.relu(out)
        out = self.blocks(out)
        out = self.exit_conv(out)
        if self.use_global_pool:
            # Global average pooling over the length dimension
            out = out.mean(dim=-1)
        return out

# --------------------------------------------------
# 4) RRCNN + LSTM Model for DNA classification
# --------------------------------------------------
class RRCNN_LSTM(nn.Module):
    """
    Embedding -> RRCNN1D -> LSTM -> Linear classifier for integer-coded sequences.

    NOTE: `vocab_size` must be larger than the largest code in the input. The
    encoding written by this notebook uses 0 for padding/unknown and 1-4 for
    A/C/G/T, which needs vocab_size=5; the default of 4 only covers codes 0-3.
    """
    def __init__(self,
                 vocab_size=4,            # number of distinct integer codes
                 embed_dim=4,             # dimension to embed each nucleotide
                 hidden_channels=16,
                 rrcnn_blocks=4,
                 rrcnn_recurrent=2,
                 kernel_size=3,
                 lstm_hidden_dim=64,
                 lstm_layers=1,
                 num_classes=2,
                 use_global_pool=True):
        """
        :param vocab_size: number of rows of the embedding table
        :param embed_dim: embedding dimension, used as the CNN's input channels
        :param hidden_channels: number of feature maps in the CNN
        :param rrcnn_blocks: number of RRCNNBlock1D blocks
        :param rrcnn_recurrent: recurrent steps per block
        :param kernel_size: convolution kernel size
        :param lstm_hidden_dim: hidden size of the LSTM
        :param lstm_layers: number of stacked LSTM layers
        :param num_classes: number of output logits
        :param use_global_pool: if True the CNN output is averaged over the
            sequence and the LSTM sees a single time-step; if False the LSTM
            runs over the per-position CNN features
        """
        super().__init__()

        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)

        # The embedding dimension is the "channel" dimension of the CNN
        self.rrcnn = RRCNN1D(
            in_channels=embed_dim,
            hidden_channels=hidden_channels,
            num_blocks=rrcnn_blocks,
            num_recurrent=rrcnn_recurrent,
            kernel_size=kernel_size,
            use_global_pool=use_global_pool
        )

        # In both pooling modes each LSTM time-step has hidden_channels features
        self.lstm = nn.LSTM(
            input_size=hidden_channels,
            hidden_size=lstm_hidden_dim,
            num_layers=lstm_layers,
            batch_first=True
        )

        # Final classification layer
        self.fc = nn.Linear(lstm_hidden_dim, num_classes)

    def forward(self, x):
        """
        :param x: (batch, seq_len) tensor of nucleotide codes
        :return: (batch, num_classes) logits
        """
        # nn.Embedding only accepts integer indices, while the data pipeline of
        # this notebook yields float32 codes; whole-valued floats convert exactly
        if x.is_floating_point():
            x = x.long()

        # (batch, seq_len) -> (batch, seq_len, embed_dim) -> (batch, embed_dim, seq_len)
        x = self.embedding(x).transpose(1, 2)

        feats = self.rrcnn(x)

        if feats.dim() == 2:
            # Pooled features (batch, hidden_channels): one time-step per sequence
            feats = feats.unsqueeze(1)
        else:
            # Unpooled features (batch, hidden_channels, seq_len): put time second,
            # as batch_first LSTMs expect (batch, seq_len, features)
            feats = feats.transpose(1, 2)

        lstm_out, _ = self.lstm(feats)

        # Output of the last time-step => (batch, lstm_hidden_dim)
        last_out = lstm_out[:, -1, :]

        return self.fc(last_out)


# ---------------------------------------------------------------------
# 5) Instantiate Model, Optimizer, Criterion
# ---------------------------------------------------------------------
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# The model class defined in this notebook is RRCNN_LSTM; it embeds integer
# codes and therefore takes vocab_size/embed_dim instead of in_channels.
# vocab_size=5 covers code 0 (padding / unknown base) and codes 1-4 (A, C, G, T).
model = RRCNN_LSTM(
    vocab_size=5,
    embed_dim=4,
    hidden_channels=16,
    rrcnn_blocks=2,
    rrcnn_recurrent=2,
    kernel_size=3,
    lstm_hidden_dim=64,
    lstm_layers=1,
    num_classes=num_classes,    # CRITICAL: matches number of distinct labels
    use_global_pool=True
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3)


# ---------------------------------------------------------------------
# 6) Training & Testing Loop
# ---------------------------------------------------------------------
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """
    Run one pass over `loader` and return (average loss, accuracy in percent).

    With an optimizer the model is put in train mode and updated after every
    batch; without one the model is put in eval mode and gradients are disabled.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.set_grad_enabled(training):
        for data, labels in loader:
            # The embedding layer needs integer indices; the loaders yield float codes
            data = data.long().to(device)
            labels = labels.to(device)

            outputs = model(data)  # (batch, num_classes)
            loss = criterion(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Weight the batch-mean loss by the batch size (the last batch may be smaller).
            # A dedicated name keeps the global `batch_size` setting untouched.
            samples_in_batch = labels.size(0)
            total_loss += loss.item() * samples_in_batch
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
            total_samples += samples_in_batch

    return total_loss / total_samples, 100.0 * total_correct / total_samples


num_epochs = 20
for epoch in range(num_epochs):
    avg_train_loss, train_accuracy = _run_epoch(model, train_loader, criterion, device, optimizer)
    avg_test_loss, test_accuracy = _run_epoch(model, test_loader, criterion, device)

    print(f"Epoch [{epoch+1}/{num_epochs}] "
          f"Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.2f}% | "
          f"Test Loss: {avg_test_loss:.4f}, Test Acc: {test_accuracy:.2f}%")


Number of training samples: 1612
Number of distinct classes: 81
Label mapping: {'Cortinarius': 0, 'Aspergillus': 1, 'Inocybe': 2, 'Trichoderma': 3, 'Talaromyces': 4, 'Amanita': 5, 'Entoloma': 6, 'Orbilia': 7, 'Russula': 8, 'Lactarius': 9, 'Elsinoe': 10, 'Phyllosticta': 11, 'Mucor': 12, 'Candida': 13, 'Apiospora': 14, 'Exophiala': 15, 'Marasmius': 16, 'Hypoxylon': 17, 'Ogataea': 18, 'Tuber': 19, 'Pluteus': 20, 'Scolecobasidium': 21, 'Lactifluus': 22, 'Metschnikowia': 23, 'Leucoagaricus': 24, 'Gymnopus': 25, 'Xylodon': 26, 'Cladophialophora': 27, 'Tomentella': 28, 'Otidea': 29, 'Kazachstania': 30, 'Verrucaria': 31, 'Lipomyces': 32, 'Hygrophorus': 33, 'Geastrum': 34, 'Pseudosperma': 35, 'Boletus': 36, 'Cyberlindnera': 37, 'Absidia': 38, 'Sugiyamaella': 39, 'Wickerhamiella': 40, 'Mortierella': 41, 'Arthroderma': 42, 'Suhomyces': 43, 'Fomitiporia': 44, 'Tremella': 45, 'Xylaria': 46, 'Starmerella': 47, 'Trechispora': 48, 'Cyphellophora': 49, 'Mycena': 50, 'Wickerhamomyces': 51, 'Pichia': 52,