# Imports and Device Setup
Import math, random, collections, torch, torch.nn, torch.optim, torch.nn.functional, DataLoader, pad_sequence, Bio.SeqIO, and matplotlib.pyplot. Set the device based on cuda/mps availability.

In [None]:
import math
import random
from collections import defaultdict, Counter

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F  # For KL divergence in distillation
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from Bio import SeqIO
import matplotlib.pyplot as plt  # For plotting

# Pick the compute backend: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Data Preparation Functions
Define functions: parse_fasta_with_labels, create_train_test_split, generate_kmers, build_kmer_vocab, encode_sequence, filter_classes, reverse_complement, create_paired_data, and the TwoFastaKmerDataset class with collate_fn_two.

In [None]:
# Data Preparation Functions

def parse_fasta_with_labels(fasta_file):
    """
    Read a FASTA file into a list of (label, sequence) tuples.

    The label is the first whitespace-delimited token of each record's
    description line; the sequence is converted to upper case.
    """
    return [
        (rec.description.strip().split()[0], str(rec.seq).upper())
        for rec in SeqIO.parse(fasta_file, "fasta")
    ]

def create_train_test_split(raw_data):
    """
    Hold out one randomly chosen sequence per label for testing.

    Sequences are grouped by label, each group is shuffled in place with
    ``random.shuffle``, the first element goes to the test list and the
    remainder to the training list. A label with a single sequence therefore
    appears only in the test list.

    Returns:
        (train_data, test_data), each a list of (label, sequence) tuples.
    """
    by_label = defaultdict(list)
    for lbl, sequence in raw_data:
        by_label[lbl].append(sequence)

    train_data, test_data = [], []
    for lbl, group in by_label.items():
        random.shuffle(group)
        held_out, *remainder = group
        test_data.append((lbl, held_out))
        train_data.extend((lbl, s) for s in remainder)
    return train_data, test_data

def generate_kmers(sequence, k=6):
    """
    Return every overlapping window of length ``k`` in ``sequence``,
    in left-to-right order. A sequence shorter than ``k`` yields ``[]``.
    """
    window_count = len(sequence) - k + 1
    return [sequence[start:start + k] for start in range(window_count)]

def build_kmer_vocab(dataset, k=6):
    """
    Build a k-mer -> integer index mapping from (label, sequence) pairs.

    Index 0 is assigned to "<UNK>"; the distinct k-mers found in the
    sequences are numbered from 1 in sorted order.
    """
    unique_kmers = set()
    for _label, sequence in dataset:
        unique_kmers.update(generate_kmers(sequence, k))

    vocab = {"<UNK>": 0}
    vocab.update(
        {kmer: idx for idx, kmer in enumerate(sorted(unique_kmers), start=1)}
    )
    return vocab

def encode_sequence(sequence, vocab, k=6):
    """
    Convert ``sequence`` into a list of vocabulary indices, one per k-mer.

    K-mers missing from ``vocab`` are mapped to the index stored under
    "<UNK>".
    """
    ids = []
    for kmer in generate_kmers(sequence, k):
        ids.append(vocab.get(kmer, vocab["<UNK>"]))
    return ids

def filter_classes(raw_data, min_count=5):
    """
    Keep only the (label, sequence) pairs whose label occurs at least
    ``min_count`` times in ``raw_data``; input order is preserved.
    """
    occurrences = Counter(lbl for lbl, _ in raw_data)
    kept = []
    for lbl, sequence in raw_data:
        if occurrences[lbl] >= min_count:
            kept.append((lbl, sequence))
    return kept

# Maps every IUPAC nucleotide code (upper and lower case) to its complement.
# Characters outside the table (gaps, unexpected symbols) pass through unchanged.
_COMPLEMENT_TABLE = str.maketrans(
    "ACGTURYSWKMBDHVNacgturyswkmbdhvn",
    "TGCAAYRSWMKVHDBNtgcaayrswmkvhdbn",
)


def reverse_complement(seq):
    """
    Return the reverse complement of a nucleotide sequence.

    The sequence is reversed and every base is replaced by its complement
    (A<->T, C<->G, plus the IUPAC ambiguity codes; U is complemented to A).
    Previously this function only reversed the string, which does not
    represent the opposite strand.

    Args:
        seq: Nucleotide sequence as a string.

    Returns:
        The reverse-complemented sequence, same length as ``seq``.
    """
    return seq[::-1].translate(_COMPLEMENT_TABLE)

def create_paired_data(data_list):
    """
    Expand each (label, sequence) pair into a
    (label, sequence, reverse_complement(sequence)) triple.
    """
    return [(lbl, fwd, reverse_complement(fwd)) for lbl, fwd in data_list]

class TwoFastaKmerDataset(Dataset):
    """
    Dataset of k-mer-encoded sequence pairs.

    Built from (label, forward_seq, reverse_seq) triples. Each item is
    (encoded_seq_fwd, encoded_seq_rev, label_idx), where the encodings are
    lists of vocabulary indices and label_idx is the position of the label
    in the sorted set of labels present in ``paired_data``.
    """

    def __init__(self, paired_data, vocab, k=6):
        super().__init__()
        self.vocab = vocab
        self.k = k

        # Label indices follow the sorted order of the labels seen in this
        # particular dataset.
        unique_labels = sorted({entry[0] for entry in paired_data})
        self.label2idx = dict(zip(unique_labels, range(len(unique_labels))))

        self.encoded_data = [
            (
                encode_sequence(fwd, self.vocab, k=self.k),
                encode_sequence(rev, self.vocab, k=self.k),
                self.label2idx[lbl],
            )
            for lbl, fwd, rev in paired_data
        ]

    def __len__(self):
        """Number of encoded samples."""
        return len(self.encoded_data)

    def __getitem__(self, idx):
        """Return the (fwd_ids, rev_ids, label_idx) tuple at ``idx``."""
        return self.encoded_data[idx]

    def get_vocab_size(self):
        """Number of entries in the vocabulary, including "<UNK>"."""
        return len(self.vocab)

    def get_num_classes(self):
        """Number of distinct labels in this dataset."""
        return len(self.label2idx)

def collate_fn_two(batch):
    """
    Collate (fwd_ids, rev_ids, label_idx) samples into batch tensors.

    Forward and reverse sequences are each right-padded with 0 to the
    longest sequence in their own group.

    Returns:
        (padded_fwd, padded_rev, labels), all ``torch.long`` tensors; the
        padded tensors are batch-first.
    """
    def _pad(id_lists):
        tensors = [torch.tensor(ids, dtype=torch.long) for ids in id_lists]
        return pad_sequence(tensors, batch_first=True, padding_value=0)

    fwd_lists, rev_lists, targets = zip(*batch)
    return _pad(fwd_lists), _pad(rev_lists), torch.tensor(targets, dtype=torch.long)

# Model Architecture
Define ViTDeepSEAEncoder and TwoViTDeepSEAFusionDNAClassifierWithFC classes to build the two-branch network architecture.

In [None]:
# Model Architecture

class ViTDeepSEAEncoder(nn.Module):
    """
    Encode a batch of k-mer index sequences into fixed-size feature vectors.

    Pipeline: k-mer embedding -> DeepSEA-style block of three Conv1d/ReLU
    stages (max-pooling after the first two, dropout after the third) ->
    BatchNorm1d -> linear projection to ``d_model`` -> learned positional
    embedding -> TransformerEncoder -> LayerNorm -> mean over tokens.

    No padding mask is applied: padded positions (index 0) take part in the
    convolutions, the attention and the final mean.

    Args:
        vocab_size: Number of rows in the embedding table (index 0 is the
            padding index).
        embed_dim: Size of each k-mer embedding.
        d_model: Transformer width and size of the returned features.
        num_conv_filters: Output channels of the three convolutions.
        conv_kernel_sizes: Kernel sizes of the three convolutions.
        pool_kernel_sizes: Kernel sizes of the two max-pooling layers.
        num_transformer_layers: Number of TransformerEncoder layers.
        nhead: Number of attention heads.
        dropout: Dropout probability for the conv block and the transformer.
        max_seq_len: Longest input length (in k-mer tokens) the model must
            support. The positional table is sized to hold the conv-block
            output for this length, with a floor of 150 positions so that
            default-configured models keep their original parameter shapes.
    """

    # Historical size of the positional table; used as a lower bound.
    _MIN_TOKENS = 150

    def __init__(
        self,
        vocab_size,
        embed_dim=128,
        d_model=256,
        num_conv_filters=(320, 480, 960),
        conv_kernel_sizes=(8, 8, 8),
        pool_kernel_sizes=(4, 4),
        num_transformer_layers=2,
        nhead=8,
        dropout=0.2,
        max_seq_len=1000
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.deepsea_conv = nn.Sequential(
            nn.Conv1d(in_channels=embed_dim, out_channels=num_conv_filters[0], kernel_size=conv_kernel_sizes[0]),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=pool_kernel_sizes[0]),
            nn.Conv1d(in_channels=num_conv_filters[0], out_channels=num_conv_filters[1], kernel_size=conv_kernel_sizes[1]),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=pool_kernel_sizes[1]),
            nn.Conv1d(in_channels=num_conv_filters[1], out_channels=num_conv_filters[2], kernel_size=conv_kernel_sizes[2]),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        self.bn = nn.BatchNorm1d(num_conv_filters[2])
        self.proj = nn.Linear(num_conv_filters[2], d_model)
        # Size the positional table from max_seq_len instead of a fixed 150,
        # so inputs up to max_seq_len always fit.
        self.max_tokens = max(
            self._MIN_TOKENS,
            self._conv_output_length(max_seq_len, conv_kernel_sizes, pool_kernel_sizes),
        )
        self.pos_embedding = nn.Parameter(torch.zeros(1, self.max_tokens, d_model))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_transformer_layers)
        self.final_norm = nn.LayerNorm(d_model)

    @staticmethod
    def _conv_output_length(seq_len, conv_kernel_sizes, pool_kernel_sizes):
        """Length of the conv block output for an input of ``seq_len`` tokens."""
        length = seq_len
        for stage, kernel in enumerate(conv_kernel_sizes[:3]):
            # Unpadded, stride-1 convolution.
            length = length - kernel + 1
            if stage < 2:
                # MaxPool1d defaults to stride == kernel_size, floor mode.
                length = length // pool_kernel_sizes[stage]
        return max(length, 0)

    def forward(self, x):
        """
        Args:
            x: Integer tensor of k-mer indices, [B, seq_len].

        Returns:
            Float tensor of pooled features, [B, d_model].

        Raises:
            ValueError: If the conv block yields more tokens than the
                positional table can hold.
        """
        x = self.embedding(x)            # [B, seq_len, embed_dim]
        x = x.transpose(1, 2)            # [B, embed_dim, seq_len]
        x = self.deepsea_conv(x)         # [B, num_conv_filters[-1], L_out]
        x = self.bn(x)
        x = x.transpose(1, 2)            # [B, L_out, num_conv_filters[-1]]
        x = self.proj(x)                 # [B, L_out, d_model]
        num_tokens = x.size(1)
        if num_tokens > self.max_tokens:
            # Fail with an actionable message instead of a broadcast error.
            raise ValueError(
                f"Input produces {num_tokens} tokens after the conv block, but the "
                f"positional embedding only covers {self.max_tokens}; "
                f"increase max_seq_len."
            )
        x = x + self.pos_embedding[:, :num_tokens, :]
        x = self.transformer_encoder(x)
        x = self.final_norm(x)
        x = x.mean(dim=1)                # average over the token axis
        return x

class TwoViTDeepSEAFusionDNAClassifierWithFC(nn.Module):
    """
    Two-branch classifier over a pair of encoded sequences.

    Each input goes through its own ViTDeepSEAEncoder (the branches are built
    with identical hyperparameters but hold separate weights). The two
    feature vectors are concatenated and mapped to class logits by a single
    linear layer.
    """
    def __init__(
        self,
        vocab_size,
        num_classes,
        embed_dim=128,
        d_model=256,
        num_conv_filters=(320, 480, 960),
        conv_kernel_sizes=(8, 8, 8),
        pool_kernel_sizes=(4, 4),
        num_transformer_layers=2,
        nhead=8,
        dropout=0.2,
        max_seq_len=1000
    ):
        super().__init__()
        # Both branches share one configuration.
        branch_kwargs = dict(
            vocab_size=vocab_size,
            embed_dim=embed_dim,
            d_model=d_model,
            num_conv_filters=num_conv_filters,
            conv_kernel_sizes=conv_kernel_sizes,
            pool_kernel_sizes=pool_kernel_sizes,
            num_transformer_layers=num_transformer_layers,
            nhead=nhead,
            dropout=dropout,
            max_seq_len=max_seq_len,
        )
        self.vit_branch1 = ViTDeepSEAEncoder(**branch_kwargs)
        self.vit_branch2 = ViTDeepSEAEncoder(**branch_kwargs)
        self.fc = nn.Linear(2 * d_model, num_classes)

    def forward(self, x1, x2):
        """Return class logits for the sequence pair (x1, x2)."""
        branch_features = (self.vit_branch1(x1), self.vit_branch2(x2))
        return self.fc(torch.cat(branch_features, dim=1))

# Helper Functions for Distillation
Create get_overlapping_indices to match teacher and student label indices, and distillation_loss to compute KL divergence on overlapping classes.

In [None]:
# Helper Functions for Distillation

def get_overlapping_indices(teacher_label2idx, student_label2idx):
    """
    Find the labels present in both mappings.

    Returns:
        (teacher_indices, student_indices): two parallel lists holding, for
        each shared label, its index in the teacher mapping and in the
        student mapping. Order follows the teacher mapping's iteration order.
    """
    shared = [
        (t_idx, student_label2idx[lbl])
        for lbl, t_idx in teacher_label2idx.items()
        if lbl in student_label2idx
    ]
    teacher_indices = [pair[0] for pair in shared]
    student_indices = [pair[1] for pair in shared]
    return teacher_indices, student_indices

def distillation_loss(student_logits, teacher_logits, student_overlap, teacher_overlap, T):
    """
    Temperature-scaled KL divergence restricted to the shared classes.

    The columns listed in ``student_overlap`` / ``teacher_overlap`` are
    selected from the respective logits, softened by temperature ``T``, and
    compared with ``F.kl_div`` (batchmean reduction). The result is
    multiplied by ``T * T``.
    """
    student_log_probs = F.log_softmax(student_logits[:, student_overlap] / T, dim=1)
    teacher_probs = F.softmax(teacher_logits[:, teacher_overlap] / T, dim=1)
    divergence = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
    return divergence * (T * T)

# Build Vocabulary and Dataset
Load the FASTA file, generate raw_data using parse_fasta_with_labels, build the kmer vocabulary, and prepare dataset splits for training and testing.

In [None]:
# Build Vocabulary and Dataset

# Parse the labelled sequences from disk.
fasta_file = "data2/fungi_ITS_cleaned.fasta"
raw_data = parse_fasta_with_labels(fasta_file)

# The 6-mer vocabulary is built from every parsed sequence.
vocab = build_kmer_vocab(raw_data, k=6)

# One held-out sequence per label forms the test split.
train_data, test_data = create_train_test_split(raw_data)

# Attach the second-strand sequence to each sample.
paired_train_data, paired_test_data = (
    create_paired_data(split) for split in (train_data, test_data)
)

# Encode both splits with the shared vocabulary.
train_dataset, test_dataset = (
    TwoFastaKmerDataset(pairs, vocab, k=6)
    for pairs in (paired_train_data, paired_test_data)
)

# Summary of what was built.
for caption, value in (
    ("Number of training samples:", len(train_dataset)),
    ("Number of testing samples:", len(test_dataset)),
    ("Vocabulary size:", train_dataset.get_vocab_size()),
    ("Number of classes:", train_dataset.get_num_classes()),
):
    print(caption, value)

# Stage 1: Student 10 Training
Train Student 10 on the dataset with min_count >= 10, including early stopping, LR scheduling, and saving the best model state. This model will serve as the teacher for subsequent stages.

In [None]:
# Stage 1: Student 10 Training

# Filter classes with min_count >= 10
student10_data = filter_classes(raw_data, min_count=10)

# Create train-test split (one held-out sequence per class)
student10_train_data, student10_test_data = create_train_test_split(student10_data)

# Create paired data
student10_paired_train = create_paired_data(student10_train_data)
student10_paired_test = create_paired_data(student10_test_data)

# Create datasets
student10_dataset = TwoFastaKmerDataset(student10_paired_train, vocab, k=6)
student10_test_dataset = TwoFastaKmerDataset(student10_paired_test, vocab, k=6)

# Print class information
print("Student 10 classes (min_count>=10):")
for cls in student10_dataset.label2idx:
    print(cls)
print("Number of Student 10 classes:", student10_dataset.get_num_classes())

# Create data loaders
student10_train_loader = DataLoader(student10_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn_two)
student10_test_loader = DataLoader(student10_test_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn_two)

# Model parameters (vocab_size is reused by the later stages)
num_classes_student10 = student10_dataset.get_num_classes()
vocab_size = student10_dataset.get_vocab_size()

# Initialize model
student10_model = TwoViTDeepSEAFusionDNAClassifierWithFC(
    vocab_size=vocab_size,
    num_classes=num_classes_student10,
    embed_dim=128,
    d_model=256,
    num_conv_filters=(320, 480, 960),
    conv_kernel_sizes=(8, 8, 8),
    pool_kernel_sizes=(4, 4),
    num_transformer_layers=2,
    nhead=8,
    dropout=0.2,
    max_seq_len=1000
).to(device)

# Training parameters
student10_epochs = 100
optimizer_student10 = optim.AdamW(student10_model.parameters(), lr=1e-4, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
scheduler_student10 = optim.lr_scheduler.CosineAnnealingLR(optimizer_student10, T_max=student10_epochs)

# Early stopping parameters
best_student10_acc = 0.0
best_student10_state = None
patience = 10
patience_counter = 0

# List to store test accuracy per epoch
teacher_test_acc_list = []

# Training loop
# NOTE(review): model selection and early stopping both use the test split.
print("Starting Student 10 training (min_count>=10) with early stopping and LR scheduling...")
for epoch in range(1, student10_epochs + 1):
    student10_model.train()
    total_loss = 0.0
    for fwd, rev, labels in student10_train_loader:
        fwd, rev, labels = fwd.to(device), rev.to(device), labels.to(device)
        optimizer_student10.zero_grad()
        logits = student10_model(fwd, rev)
        loss = criterion(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(student10_model.parameters(), max_norm=1.0)
        optimizer_student10.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(student10_train_loader)

    # Evaluate Student 10 accuracy. Autograd is disabled so that no
    # computation graphs are built during inference.
    student10_model.eval()
    train_correct = 0
    train_total = 0
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for fwd, rev, labels in student10_train_loader:
            fwd, rev, labels = fwd.to(device), rev.to(device), labels.to(device)
            preds = torch.argmax(student10_model(fwd, rev), dim=1)
            train_correct += (preds == labels).sum().item()
            train_total += labels.size(0)

        for fwd, rev, labels in student10_test_loader:
            fwd, rev, labels = fwd.to(device), rev.to(device), labels.to(device)
            preds = torch.argmax(student10_model(fwd, rev), dim=1)
            test_correct += (preds == labels).sum().item()
            test_total += labels.size(0)
    train_acc = 100.0 * train_correct / train_total
    test_acc = 100.0 * test_correct / test_total
    teacher_test_acc_list.append(test_acc)

    # CosineAnnealingLR advances one epoch per call. Its optional argument is
    # an epoch index, not a metric, so the accuracy must not be passed in.
    scheduler_student10.step()

    print(f"[Student 10] Epoch {epoch}/{student10_epochs} | Loss: {avg_loss:.4f} | Train Acc: {train_acc:.2f}% | Test Acc: {test_acc:.2f}%")

    if test_acc > best_student10_acc:
        best_student10_acc = test_acc
        # state_dict() returns references to the live tensors; clone them so
        # later optimizer steps cannot overwrite the saved best weights.
        best_student10_state = {
            name: tensor.detach().clone()
            for name, tensor in student10_model.state_dict().items()
        }
        patience_counter = 0
    else:
        patience_counter += 1
    if patience_counter >= patience:
        print(f"Early stopping triggered at epoch {epoch}.")
        break

# Load the best model state (if accuracy never rose above 0, keep the
# final weights instead of failing on a missing snapshot)
if best_student10_state is not None:
    student10_model.load_state_dict(best_student10_state)
# Freeze the model: it acts as the teacher for the next stage.
for param in student10_model.parameters():
    param.requires_grad = False
student10_model.eval()
print(f"Best Student 10 Accuracy: {best_student10_acc:.2f}%")

# Stage 2: Student 8 Training with Distillation
Filter the dataset for classes with min_count >= 8; create paired data; train Student 8 using the teacher logits from Student 10 combined with cross-entropy, applying KD with given temperature and alpha hyperparameters.

In [None]:
# Stage 2: Student 8 Training with Distillation

# Filter classes with min_count >= 8
student8_data = filter_classes(raw_data, min_count=8)

# Create train-test split
# NOTE(review): the split is re-drawn for every stage, so sequences held out
# here may have been part of the teacher's training data.
student8_train_data, student8_test_data = create_train_test_split(student8_data)

# Create paired data
student8_paired_train = create_paired_data(student8_train_data)
student8_paired_test = create_paired_data(student8_test_data)

# Create datasets
student8_dataset = TwoFastaKmerDataset(student8_paired_train, vocab, k=6)
student8_test_dataset = TwoFastaKmerDataset(student8_paired_test, vocab, k=6)

# Print class information
print("\nStudent 8 classes (min_count>=8):")
for cls in student8_dataset.label2idx:
    print(cls)
print("Number of Student 8 classes:", student8_dataset.get_num_classes())

# Create data loaders
student8_train_loader = DataLoader(student8_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn_two)
student8_test_loader = DataLoader(student8_test_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn_two)

# Model parameters
num_classes_student8 = student8_dataset.get_num_classes()

# Initialize model
student8_model = TwoViTDeepSEAFusionDNAClassifierWithFC(
    vocab_size=vocab_size,
    num_classes=num_classes_student8,
    embed_dim=128,
    d_model=256,
    num_conv_filters=(320, 480, 960),
    conv_kernel_sizes=(8, 8, 8),
    pool_kernel_sizes=(4, 4),
    num_transformer_layers=2,
    nhead=8,
    dropout=0.2,
    max_seq_len=1000
).to(device)

# Training parameters
optimizer_student8 = optim.AdamW(student8_model.parameters(), lr=1e-4, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

# Distillation hyperparameters (also used by the Student 7 and Student 6 stages)
temperature = 4.0
alpha = 0.5

# Get overlapping indices between Student 10 (teacher) and Student 8 (student)
teacher_overlap, student_overlap = get_overlapping_indices(student10_dataset.label2idx, student8_dataset.label2idx)
print("\nOverlapping classes for KD between Student 10 and Student 8:")
for label in student10_dataset.label2idx:
    if label in student8_dataset.label2idx:
        print(label)
print("Number of overlapping classes:", len(teacher_overlap))

# List to store test accuracy per epoch
student8_test_acc_list = []

# Training loop
student8_epochs = 100
best_student8_acc = 0.0
best_student8_state = None

print("\nStarting Student 8 training with distillation from Student 10...")
for epoch in range(1, student8_epochs + 1):
    student8_model.train()
    total_loss = 0.0
    for fwd, rev, labels in student8_train_loader:
        fwd, rev, labels = fwd.to(device), rev.to(device), labels.to(device)
        optimizer_student8.zero_grad()
        student_logits = student8_model(fwd, rev)
        # The teacher is frozen; its logits only serve as soft targets.
        with torch.no_grad():
            teacher_logits = student10_model(fwd, rev)
        ce_loss = criterion(student_logits, labels)
        kd_loss = distillation_loss(student_logits, teacher_logits, student_overlap, teacher_overlap, temperature)
        loss = alpha * kd_loss + (1 - alpha) * ce_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(student8_model.parameters(), max_norm=1.0)
        optimizer_student8.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(student8_train_loader)

    # Evaluate Student 8 accuracy with autograd disabled
    student8_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for fwd, rev, labels in student8_test_loader:
            fwd, rev, labels = fwd.to(device), rev.to(device), labels.to(device)
            preds = torch.argmax(student8_model(fwd, rev), dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    test_acc = 100.0 * correct / total
    student8_test_acc_list.append(test_acc)

    print(f"[Student 8] Epoch {epoch}/{student8_epochs} | Loss: {avg_loss:.4f} | Test Acc: {test_acc:.2f}%")

    if test_acc > best_student8_acc:
        best_student8_acc = test_acc
        # state_dict() returns references to the live tensors; clone them so
        # later optimizer steps cannot overwrite the saved best weights.
        best_student8_state = {
            name: tensor.detach().clone()
            for name, tensor in student8_model.state_dict().items()
        }

# Load the best model state (keep the final weights if no epoch improved on 0)
if best_student8_state is not None:
    student8_model.load_state_dict(best_student8_state)
# Freeze the model: it acts as the teacher for the next stage.
for param in student8_model.parameters():
    param.requires_grad = False
student8_model.eval()
print(f"\nHighest Student 8 Test Accuracy: {best_student8_acc:.2f}%")

# Stage 3: Student 7 Training with Distillation
Filter the dataset for min_count >= 7, prepare paired data, and train Student 7 with KD from Student 8, including the computation of overlapping indices.

In [None]:
# Stage 3: Student 7 Training with Distillation

# Filter classes with min_count >= 7
student7_data = filter_classes(raw_data, min_count=7)

# Create train-test split
# NOTE(review): the split is re-drawn for every stage, so sequences held out
# here may have been part of the teacher's training data.
student7_train_data, student7_test_data = create_train_test_split(student7_data)

# Create paired data
student7_paired_train = create_paired_data(student7_train_data)
student7_paired_test = create_paired_data(student7_test_data)

# Create datasets
student7_dataset = TwoFastaKmerDataset(student7_paired_train, vocab, k=6)
student7_test_dataset = TwoFastaKmerDataset(student7_paired_test, vocab, k=6)

# Print class information
print("\nStudent 7 classes (min_count>=7):")
for cls in student7_dataset.label2idx:
    print(cls)
print("Number of Student 7 classes:", student7_dataset.get_num_classes())

# Create data loaders
student7_train_loader = DataLoader(student7_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn_two)
student7_test_loader = DataLoader(student7_test_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn_two)

# Model parameters
num_classes_student7 = student7_dataset.get_num_classes()

# Initialize model
student7_model = TwoViTDeepSEAFusionDNAClassifierWithFC(
    vocab_size=vocab_size,
    num_classes=num_classes_student7,
    embed_dim=128,
    d_model=256,
    num_conv_filters=(320, 480, 960),
    conv_kernel_sizes=(8, 8, 8),
    pool_kernel_sizes=(4, 4),
    num_transformer_layers=2,
    nhead=8,
    dropout=0.2,
    max_seq_len=1000
).to(device)

# Training parameters
# (temperature and alpha are the values set in the Student 8 stage)
optimizer_student7 = optim.AdamW(student7_model.parameters(), lr=1e-4, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

# Get overlapping indices between Student 8 (teacher) and Student 7 (student)
teacher_overlap_7, student_overlap_7 = get_overlapping_indices(student8_dataset.label2idx, student7_dataset.label2idx)
print("\nOverlapping classes for KD between Student 8 and Student 7:")
for label in student8_dataset.label2idx:
    if label in student7_dataset.label2idx:
        print(label)
print("Number of overlapping classes:", len(teacher_overlap_7))

# List to store test accuracy per epoch
student7_test_acc_list = []

# Training loop
student7_epochs = 100
best_student7_acc = 0.0
best_student7_state = None

print("\nStarting Student 7 training with distillation from Student 8...")
for epoch in range(1, student7_epochs + 1):
    student7_model.train()
    total_loss = 0.0
    for fwd, rev, labels in student7_train_loader:
        fwd, rev, labels = fwd.to(device), rev.to(device), labels.to(device)
        optimizer_student7.zero_grad()
        student_logits = student7_model(fwd, rev)
        # The teacher is frozen; its logits only serve as soft targets.
        with torch.no_grad():
            teacher_logits = student8_model(fwd, rev)
        ce_loss = criterion(student_logits, labels)
        kd_loss = distillation_loss(student_logits, teacher_logits, student_overlap_7, teacher_overlap_7, temperature)
        loss = alpha * kd_loss + (1 - alpha) * ce_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(student7_model.parameters(), max_norm=1.0)
        optimizer_student7.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(student7_train_loader)

    # Evaluate Student 7 accuracy with autograd disabled
    student7_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for fwd, rev, labels in student7_test_loader:
            fwd, rev, labels = fwd.to(device), rev.to(device), labels.to(device)
            preds = torch.argmax(student7_model(fwd, rev), dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    test_acc = 100.0 * correct / total
    student7_test_acc_list.append(test_acc)

    print(f"[Student 7] Epoch {epoch}/{student7_epochs} | Loss: {avg_loss:.4f} | Test Acc: {test_acc:.2f}%")

    if test_acc > best_student7_acc:
        best_student7_acc = test_acc
        # state_dict() returns references to the live tensors; clone them so
        # later optimizer steps cannot overwrite the saved best weights.
        best_student7_state = {
            name: tensor.detach().clone()
            for name, tensor in student7_model.state_dict().items()
        }

# Load the best model state (keep the final weights if no epoch improved on 0)
if best_student7_state is not None:
    student7_model.load_state_dict(best_student7_state)
# Freeze the model: it acts as the teacher for the next stage.
for param in student7_model.parameters():
    param.requires_grad = False
student7_model.eval()
print(f"\nHighest Student 7 Test Accuracy: {best_student7_acc:.2f}%")

# Stage 4: Student 6 Training with Distillation
Using min_count >= 6, prepare the dataset and train Student 6 with distillation from Student 7. Implement similar KD loss and cross-entropy loss fusion.

In [None]:
# Stage 4: Student 6 Training with Distillation

# Filter classes with min_count >= 6
student6_data = filter_classes(raw_data, min_count=6)

# Create train-test split
# NOTE(review): the split is re-drawn for every stage, so sequences held out
# here may have been part of the teacher's training data.
student6_train_data, student6_test_data = create_train_test_split(student6_data)

# Create paired data
student6_paired_train = create_paired_data(student6_train_data)
student6_paired_test = create_paired_data(student6_test_data)

# Create datasets
student6_dataset = TwoFastaKmerDataset(student6_paired_train, vocab, k=6)
student6_test_dataset = TwoFastaKmerDataset(student6_paired_test, vocab, k=6)

# Print class information
print("\nStudent 6 classes (min_count>=6):")
for cls in student6_dataset.label2idx:
    print(cls)
print("Number of Student 6 classes:", student6_dataset.get_num_classes())

# Create data loaders
student6_train_loader = DataLoader(student6_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn_two)
student6_test_loader = DataLoader(student6_test_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn_two)

# Model parameters
num_classes_student6 = student6_dataset.get_num_classes()

# Initialize model
student6_model = TwoViTDeepSEAFusionDNAClassifierWithFC(
    vocab_size=vocab_size,
    num_classes=num_classes_student6,
    embed_dim=128,
    d_model=256,
    num_conv_filters=(320, 480, 960),
    conv_kernel_sizes=(8, 8, 8),
    pool_kernel_sizes=(4, 4),
    num_transformer_layers=2,
    nhead=8,
    dropout=0.2,
    max_seq_len=1000
).to(device)

# Training parameters
# (temperature and alpha are the values set in the Student 8 stage)
optimizer_student6 = optim.AdamW(student6_model.parameters(), lr=1e-4, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

# Get overlapping indices between Student 7 (teacher) and Student 6 (student)
teacher_overlap_6, student_overlap_6 = get_overlapping_indices(student7_dataset.label2idx, student6_dataset.label2idx)
print("\nOverlapping classes for KD between Student 7 and Student 6:")
for label in student7_dataset.label2idx:
    if label in student6_dataset.label2idx:
        print(label)
print("Number of overlapping classes:", len(teacher_overlap_6))

# List to store test accuracy per epoch
student6_test_acc_list = []

# Training loop
student6_epochs = 100
best_student6_acc = 0.0
best_student6_state = None

print("\nStarting Student 6 training with distillation from Student 7...")
for epoch in range(1, student6_epochs + 1):
    student6_model.train()
    total_loss = 0.0
    for fwd, rev, labels in student6_train_loader:
        fwd, rev, labels = fwd.to(device), rev.to(device), labels.to(device)
        optimizer_student6.zero_grad()
        student_logits = student6_model(fwd, rev)
        # The teacher is frozen; its logits only serve as soft targets.
        with torch.no_grad():
            teacher_logits = student7_model(fwd, rev)
        ce_loss = criterion(student_logits, labels)
        kd_loss = distillation_loss(student_logits, teacher_logits, student_overlap_6, teacher_overlap_6, temperature)
        loss = alpha * kd_loss + (1 - alpha) * ce_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(student6_model.parameters(), max_norm=1.0)
        optimizer_student6.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(student6_train_loader)

    # Evaluate Student 6 accuracy with autograd disabled
    student6_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for fwd, rev, labels in student6_test_loader:
            fwd, rev, labels = fwd.to(device), rev.to(device), labels.to(device)
            preds = torch.argmax(student6_model(fwd, rev), dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    test_acc = 100.0 * correct / total
    student6_test_acc_list.append(test_acc)

    print(f"[Student 6] Epoch {epoch}/{student6_epochs} | Loss: {avg_loss:.4f} | Test Acc: {test_acc:.2f}%")

    if test_acc > best_student6_acc:
        best_student6_acc = test_acc
        # state_dict() returns references to the live tensors; clone them so
        # later optimizer steps cannot overwrite the saved best weights.
        best_student6_state = {
            name: tensor.detach().clone()
            for name, tensor in student6_model.state_dict().items()
        }

# Load the best model state (keep the final weights if no epoch improved on 0)
if best_student6_state is not None:
    student6_model.load_state_dict(best_student6_state)
# Freeze the model: it acts as the teacher for the next stage.
for param in student6_model.parameters():
    param.requires_grad = False
student6_model.eval()
print(f"\nHighest Student 6 Test Accuracy: {best_student6_acc:.2f}%")

# Stage 5: Improved Student 5 Training with Enhanced KD
For min_count >= 5, improve Student 5 training by tuning distillation hyperparameters (e.g., increasing the temperature, adjusting alpha to better balance KD and CE losses), adding additional training epochs or fine-tuning optimal learning rates, and potentially enhancing the model architecture (e.g., additional dropout adjustments or a modified final classification layer) to achieve >85% accuracy.

In [None]:
# Stage 5: Improved Student 5 Training with Enhanced KD

# Filter classes with min_count >= 5
student5_data = filter_classes(raw_data, min_count=5)

# Create train-test split
# NOTE(review): the split is re-drawn for every stage, so sequences held out
# here may have been part of the teacher's training data.
student5_train_data, student5_test_data = create_train_test_split(student5_data)

# Create paired data
student5_paired_train = create_paired_data(student5_train_data)
student5_paired_test = create_paired_data(student5_test_data)

# Create datasets
student5_dataset = TwoFastaKmerDataset(student5_paired_train, vocab, k=6)
student5_test_dataset = TwoFastaKmerDataset(student5_paired_test, vocab, k=6)

# Print class information
print("\nStudent 5 classes (min_count>=5):")
for cls in student5_dataset.label2idx:
    print(cls)
print("Number of Student 5 classes:", student5_dataset.get_num_classes())

# Create data loaders
student5_train_loader = DataLoader(student5_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn_two)
student5_test_loader = DataLoader(student5_test_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn_two)

# Model parameters
num_classes_student5 = student5_dataset.get_num_classes()

# Initialize model with enhanced architecture
student5_model = TwoViTDeepSEAFusionDNAClassifierWithFC(
    vocab_size=vocab_size,
    num_classes=num_classes_student5,
    embed_dim=128,
    d_model=256,
    num_conv_filters=(320, 480, 960),
    conv_kernel_sizes=(8, 8, 8),
    pool_kernel_sizes=(4, 4),
    num_transformer_layers=3,  # Increased number of transformer layers
    nhead=8,
    dropout=0.3,  # Increased dropout
    max_seq_len=1000
).to(device)

# Training parameters
optimizer_student5 = optim.AdamW(student5_model.parameters(), lr=5e-5, weight_decay=1e-4)  # Adjusted learning rate
criterion = nn.CrossEntropyLoss()

# Distillation hyperparameters
temperature = 6.0  # Increased temperature
alpha = 0.7  # Adjusted alpha for better balance

# Get overlapping indices between Student 6 (teacher) and Student 5 (student)
teacher_overlap_5, student_overlap_5 = get_overlapping_indices(student6_dataset.label2idx, student5_dataset.label2idx)
print("\nOverlapping classes for KD between Student 6 and Student 5:")
for label in student6_dataset.label2idx:
    if label in student5_dataset.label2idx:
        print(label)
print("Number of overlapping classes:", len(teacher_overlap_5))

# List to store test accuracy per epoch
student5_acc_list = []

# Training loop
student5_epochs = 150  # Increased number of epochs
best_student5_acc = 0.0
best_student5_state = None
best_student5_epoch = 0  # epoch at which the best test accuracy was reached

print("\nStarting Student 5 training with enhanced KD from Student 6...")
for epoch in range(1, student5_epochs + 1):
    student5_model.train()
    total_loss = 0.0
    for fwd, rev, labels in student5_train_loader:
        fwd, rev, labels = fwd.to(device), rev.to(device), labels.to(device)
        optimizer_student5.zero_grad()
        student_logits = student5_model(fwd, rev)
        # The teacher is frozen; its logits only serve as soft targets.
        with torch.no_grad():
            teacher_logits = student6_model(fwd, rev)
        ce_loss = criterion(student_logits, labels)
        kd_loss = distillation_loss(student_logits, teacher_logits, student_overlap_5, teacher_overlap_5, temperature)
        loss = alpha * kd_loss + (1 - alpha) * ce_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(student5_model.parameters(), max_norm=1.0)
        optimizer_student5.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(student5_train_loader)

    # Evaluate Student 5 accuracy with autograd disabled
    student5_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for fwd, rev, labels in student5_test_loader:
            fwd, rev, labels = fwd.to(device), rev.to(device), labels.to(device)
            preds = torch.argmax(student5_model(fwd, rev), dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    test_acc = 100.0 * correct / total
    student5_acc_list.append(test_acc)

    print(f"[Student 5] Epoch {epoch}/{student5_epochs} | Loss: {avg_loss:.4f} | Test Acc: {test_acc:.2f}%")

    if test_acc > best_student5_acc:
        best_student5_acc = test_acc
        best_student5_epoch = epoch
        # state_dict() returns references to the live tensors; clone them so
        # later optimizer steps cannot overwrite the saved best weights.
        best_student5_state = {
            name: tensor.detach().clone()
            for name, tensor in student5_model.state_dict().items()
        }

# Load the best model state (keep the final weights if no epoch improved on 0)
if best_student5_state is not None:
    student5_model.load_state_dict(best_student5_state)
for param in student5_model.parameters():
    param.requires_grad = False
student5_model.eval()
# Report the epoch that actually produced the best accuracy; the previous
# message unconditionally claimed it was the final epoch.
print(f"\nHighest Student 5 Test Accuracy: {best_student5_acc:.2f}% at epoch {best_student5_epoch}")

# Plot Accuracy Curves
Plot and compare test accuracy curves for Student 10 through Student 5 using matplotlib to visualize improvements and overall performance.

In [None]:
# Plot Accuracy Curves

# Per-epoch test accuracy histories paired with their legend labels,
# in the order the stages were trained.
accuracy_curves = [
    (teacher_test_acc_list, "Student 10 Test Accuracy"),
    (student8_test_acc_list, "Student 8 Test Accuracy"),
    (student7_test_acc_list, "Student 7 Test Accuracy"),
    (student6_test_acc_list, "Student 6 Test Accuracy"),
    (student5_acc_list, "Student 5 Test Accuracy"),
]

plt.figure(figsize=(10, 6))
for history, legend_label in accuracy_curves:
    # Epochs are numbered from 1 on the x axis.
    plt.plot(range(1, len(history) + 1), history, label=legend_label, marker="o")
plt.xlabel("Epoch")
plt.ylabel("Test Accuracy (%)")
plt.title("Test Accuracy Curves for All Stages")
plt.legend()
plt.show()