In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import random
import copy

# --------------------------
# --- FIX SEEDS & DEVICE ---
# --------------------------
# One seed for every RNG source used in this notebook: PyTorch, NumPy's legacy
# global RNG and Python's `random` (plus every GPU when CUDA is present).
SEED = 42
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)

_cuda_ok = torch.cuda.is_available()
if _cuda_ok:
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    # Prefer repeatable cuDNN kernels over autotuned ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

device = torch.device("cuda" if _cuda_ok else "cpu")

# --------------------------
# --- DATASET ---
# --------------------------
class MotifDatasetEmbed8(Dataset):
    """Synthetic binary-classification dataset of noisy sequences with planted motifs.

    Each sample is a ``(seq_len, embed_dim)`` float32 sequence of standard-normal
    noise. With probability 0.5, one of ``num_motifs`` fixed motifs is written over
    a random contiguous window, with N(0, 0.1) jitter, and the sample gets label 1.
    Each motif has shape ``(motif_len, embed_dim)`` and is drawn once from N(1, 0.5).
    All other samples get label 0. A sinusoidal positional encoding is then added
    to every sample.

    All randomness comes from one ``numpy.random.Generator`` seeded with ``seed``.
    This covers the motifs, noise, labels, and motif choice and placement. A
    dataset is therefore a pure function of its constructor arguments and does not
    touch global RNG state. Datasets built with different seeds also get
    *different motif sets*.

    Args:
        num_samples: number of sequences to generate.
        seq_len: sequence length.
        embed_dim: feature dimension per position.
        motif_len: length of each motif; must not exceed ``seq_len``.
        num_motifs: number of distinct motif variants.
        seed: seed for the internal generator.

    Raises:
        ValueError: if ``motif_len > seq_len``.
    """

    def __init__(self, num_samples=5000, seq_len=100, embed_dim=8, motif_len=3, num_motifs=5, seed=SEED):
        super().__init__()
        if motif_len > seq_len:
            raise ValueError(f"motif_len ({motif_len}) must not exceed seq_len ({seq_len})")
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.embed_dim = embed_dim
        self.motif_len = motif_len
        # Single source of randomness for everything generated below.
        self.rng = np.random.default_rng(seed)

        # Create several motif variants
        self.motifs = [self.rng.normal(loc=1.0, scale=0.5, size=(motif_len, embed_dim)) for _ in range(num_motifs)]
        self.pos_encoding = self.get_positional_encoding(seq_len, embed_dim)

        self.data, self.labels = self._generate_dataset()

    def get_positional_encoding(self, seq_len, embed_dim):
        """Return the standard sinusoidal positional encoding as a float32 tensor.

        For each even column ``2k``, ``pe[pos, 2k] = sin(pos / 10000**(2k / embed_dim))``
        and ``pe[pos, 2k + 1]`` is the cosine at the same frequency. With an odd
        ``embed_dim``, the last column holds only the sine term.

        Returns:
            Tensor of shape ``(seq_len, embed_dim)``.
        """
        positions = np.arange(seq_len, dtype=np.float64)[:, None]
        even_dims = np.arange(0, embed_dim, 2, dtype=np.float64)
        # The column index is already 2k, so the exponent is even_dims / embed_dim
        # (not 2 * even_dims / embed_dim).
        angles = positions / (10000.0 ** (even_dims / embed_dim))
        pe = np.zeros((seq_len, embed_dim))
        pe[:, 0::2] = np.sin(angles)
        pe[:, 1::2] = np.cos(angles[:, : embed_dim // 2])
        return torch.tensor(pe, dtype=torch.float32)

    def _generate_dataset(self):
        """Generate all sequences and labels using ``self.rng`` only.

        Returns:
            ``(data, labels)``: a float32 tensor of shape
            ``(num_samples, seq_len, embed_dim)`` and an int64 tensor of 0/1 labels.
        """
        pos_encoding = self.pos_encoding.numpy()  # invariant across samples
        data = np.empty((self.num_samples, self.seq_len, self.embed_dim))
        labels = np.zeros(self.num_samples, dtype=np.int64)
        for idx in range(self.num_samples):
            seq = self.rng.normal(loc=0.0, scale=1.0, size=(self.seq_len, self.embed_dim))

            if self.rng.random() < 0.5:
                motif = self.motifs[self.rng.integers(len(self.motifs))]
                # integers() excludes `high`; +1 keeps the last valid start position.
                start = self.rng.integers(0, self.seq_len - self.motif_len + 1)
                seq[start:start + self.motif_len] = motif + self.rng.normal(0, 0.1, motif.shape)
                labels[idx] = 1

            data[idx] = seq + pos_encoding

        return torch.tensor(data, dtype=torch.float32), torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        """Return the number of generated samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(sequence, label)`` for sample ``idx``."""
        return self.data[idx], self.labels[idx]

# --------------------------
# --- DATALOADERS ---
# --------------------------
# All three splits are carved out of ONE generated dataset. MotifDatasetEmbed8
# draws its motif set from its seed. A separately seeded test dataset would
# therefore plant different motifs than the ones seen during training/validation.
train_size, val_size, test_size = 5000, 1000, 1000
full_dataset = MotifDatasetEmbed8(num_samples=train_size+val_size+test_size, seed=SEED)
train_dataset, val_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SEED),  # split independent of global RNG state
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# --------------------------
# --- MODELS ---
# --------------------------
class MLPClassifierBinary(nn.Module):
    """Two-layer perceptron mapping a whole sequence to a single binary logit.

    The input's non-batch dimensions are flattened, for example
    ``(batch, seq_len, embed_dim)`` -> ``(batch, seq_len * embed_dim)``. The result
    goes through one ReLU hidden layer of width ``hidden_dim`` and is projected to
    one logit per sample, giving shape ``(batch, 1)``.

    Constructing the model reseeds the global ``torch``, ``numpy`` and ``random``
    RNGs with ``seed``, so weight initialisation is reproducible.
    """

    def __init__(self, seq_len=100, embed_dim=8, hidden_dim=64, seed=SEED):
        super().__init__()
        for reseed in (torch.manual_seed, np.random.seed, random.seed):
            reseed(seed)

        in_features = seq_len * embed_dim
        self.flatten = nn.Flatten()
        self.fc = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        """Return raw logits of shape ``(batch, 1)``."""
        flat = self.flatten(x)
        return self.fc(flat)

class TransformerClassifierBinary(nn.Module):
    """Transformer-encoder sequence classifier producing a single binary logit.

    The input is batch-first, ``(batch, seq, embed_dim)``. The encoder output is
    mean-pooled over the sequence axis and linearly projected to one logit per
    sample, giving shape ``(batch, 1)``. The model adds no positional encoding of
    its own. ``seq_len`` is accepted for signature parity with the MLP but is not
    used.

    Constructing the model reseeds the global ``torch``, ``numpy`` and ``random``
    RNGs with ``seed``, so weight initialisation is reproducible.
    """

    def __init__(self, embed_dim=8, num_heads=1, num_layers=1, seq_len=100, seed=SEED):
        super().__init__()
        for reseed in (torch.manual_seed, np.random.seed, random.seed):
            reseed(seed)

        layer_template = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer_template, num_layers=num_layers)
        self.fc = nn.Linear(embed_dim, 1)

    def forward(self, x):
        """Encode ``x``, average over the sequence axis, and return raw logits."""
        encoded = self.transformer(x)
        pooled = encoded.mean(dim=1)
        return self.fc(pooled)

# --------------------------
# --- TRAINING FUNCTION ---
# --------------------------
def _collect_logits(model, loader, device):
    """Run ``model`` in eval mode over every batch of ``loader`` without gradients.

    Returns:
        ``(logits, labels)``, each concatenated along the batch axis. Labels are
        cast to float and given a trailing axis (``(N, 1)``) to line up with the
        single-logit model output.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    logits, labels = [], []
    with torch.no_grad():
        for X_batch, y_batch in loader:
            logits.append(model(X_batch.to(device)))
            labels.append(y_batch.to(device).float().unsqueeze(1))
    if not logits:
        raise ValueError("loader yielded no batches")
    return torch.cat(logits), torch.cat(labels)


def _threshold_accuracies(probs, labels, thresholds):
    """Return the accuracy of ``probs > t`` against ``labels`` for every ``t``.

    ``probs`` and ``labels`` must each hold one value per sample (N elements,
    any shape). The result is a 1-D tensor with one accuracy per threshold.
    """
    preds = probs.reshape(1, -1) > thresholds.reshape(-1, 1)  # (T, N) decision grid
    hits = preds.long() == labels.reshape(1, -1).long()
    return hits.float().mean(dim=1)


def train(model, train_loader, val_loader, test_loader, epochs=10, lr=1e-3, device=None):
    """Train a single-logit binary classifier, tune its threshold, and test it.

    Each epoch does one Adam pass over ``train_loader`` with ``BCEWithLogitsLoss``.
    It then scores 81 evenly spaced decision thresholds in [0.1, 0.9] on
    ``val_loader``. The weights and threshold with the highest validation
    accuracy over all epochs are loaded back into ``model`` and evaluated once
    on ``test_loader``.

    Args:
        model: module mapping an input batch to ``(batch, 1)`` logits. It is
            moved to ``device`` in place.
        train_loader, val_loader, test_loader: loaders yielding ``(X, y)`` with
            0/1 integer labels ``y`` of shape ``(batch,)``.
        epochs: number of training epochs. With 0, the untrained model is
            tested at threshold 0.5.
        lr: Adam learning rate (1e-3 is Adam's default).
        device: where to run; defaults to CUDA when available, else CPU.

    Returns:
        Test accuracy as a float in [0, 1].

    Raises:
        ValueError: if ``test_loader`` yields no batches, or if ``val_loader``
            yields none while ``epochs > 0``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    thresholds = torch.linspace(0.1, 0.9, steps=81)  # loop-invariant, kept on CPU

    best_val_acc = 0.0
    best_threshold = 0.5
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(1, epochs + 1):
        # Train
        model.train()
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device).float().unsqueeze(1)
            optimizer.zero_grad()
            loss = criterion(model(X_batch), y_batch)
            loss.backward()
            optimizer.step()

        # Validate: apply sigmoid once, then score every threshold in one pass.
        val_logits, val_labels = _collect_logits(model, val_loader, device)
        val_accs = _threshold_accuracies(torch.sigmoid(val_logits).cpu(), val_labels.cpu(), thresholds)
        best_idx = val_accs.argmax()  # ties resolve to the lowest threshold
        best_epoch_acc = val_accs[best_idx].item()
        best_epoch_threshold = thresholds[best_idx].item()

        if best_epoch_acc > best_val_acc:
            best_val_acc = best_epoch_acc
            best_threshold = best_epoch_threshold
            best_model_wts = copy.deepcopy(model.state_dict())

        print(f"Epoch {epoch}: Val Accuracy = {best_epoch_acc:.3f}, Best Threshold = {best_epoch_threshold:.2f}")

    # Test with the best weights and threshold found on validation.
    model.load_state_dict(best_model_wts)
    test_logits, test_labels = _collect_logits(model, test_loader, device)
    preds = (torch.sigmoid(test_logits) > best_threshold).long()
    test_acc = (preds == test_labels.long()).sum().item() / test_labels.size(0)
    print(f"\nTest Accuracy = {test_acc:.3f} using threshold {best_threshold:.2f}\n")
    return test_acc

# --------------------------
# --- RUN TRAINING ---
# --------------------------
# Both models are built before any training starts. Each constructor reseeds the
# global RNGs, so keeping this construction order keeps runs repeatable.
mlp = MLPClassifierBinary(seq_len=100, embed_dim=8, hidden_dim=64)
transformer = TransformerClassifierBinary(embed_dim=8, num_heads=1, num_layers=1, seq_len=100)

for model_name, model_to_train in (("MLP", mlp), ("Transformer", transformer)):
    print(f"Training {model_name}:")
    train(model_to_train, train_loader, val_loader, test_loader, epochs=10)
