In [2]:
import os
import time
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler

In [3]:
class fp_CNN_Encoder(nn.Module):
    """
    1-D convolutional encoder for fingerprint vectors.

    Two Conv1d -> BatchNorm1d -> ReLU stages are followed by a global max pool,
    a linear layer that yields the embedding `g`, and (optionally) an MLP
    projection head that yields `z`. Both `g` and `z` are L2-normalized along
    the last dimension.

    Args:
        fp_dim: nominal fingerprint length. Accepted for configuration purposes
            but not used to size any layer (the adaptive pool collapses any
            input length to 1).
        hidden_channels: output channels of the first and second convolution.
        embed_dim: size of the embedding `g`.
        proj_dim: size of the projection `z`.
        use_projection: build and apply the projection head when truthy.
        batchnorm_safe: use LayerNorm inside the projection head when truthy,
            BatchNorm1d otherwise.
    """

    def __init__(self, fp_dim = 2048, hidden_channels = (64, 128), embed_dim = 256, proj_dim = 120, use_projection = True, batchnorm_safe = True):
        super().__init__()
        first_width, second_width = hidden_channels

        # Feature extractor: kernel 5 with padding 2 preserves the sequence
        # length until the final pooling step.
        feature_layers = [
            nn.Conv1d(1, first_width, kernel_size=5, padding=2),
            nn.BatchNorm1d(first_width),
            nn.ReLU(inplace=True),
            nn.Conv1d(first_width, second_width, kernel_size=5, padding=2),
            nn.BatchNorm1d(second_width),
            nn.ReLU(inplace=True),
            nn.AdaptiveMaxPool1d(1),  # [B, c2, L] -> [B, c2, 1]
        ]
        self.conv = nn.Sequential(*feature_layers)

        # Embedding head.
        self.fc = nn.Linear(second_width, embed_dim)

        self.use_projection = use_projection
        self.batchnorm_safe = batchnorm_safe
        if not self.use_projection:
            # No projection head requested; `self.proj` is intentionally not created.
            return

        # Author's note: LayerNorm also works with batch_size=1, whereas
        # BatchNorm1d is preferable when training always uses batch_size > 1.
        norm_cls = nn.LayerNorm if self.batchnorm_safe else nn.BatchNorm1d
        self.proj = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(inplace=True),
            norm_cls(embed_dim),
            nn.Linear(embed_dim, proj_dim),
        )

    def forward(self, x):
        """
        Encode a batch of fingerprints.

        Args:
            x: tensor of shape [B, fp_dim] or [B, 1, fp_dim].

        Returns:
            `(g, z)` when the projection head is enabled, otherwise `g` alone;
            `g` is [B, embed_dim] and `z` is [B, proj_dim], both unit-norm.
        """
        # A 2-D input lacks the channel axis expected by Conv1d.
        inputs = x.unsqueeze(1) if x.dim() == 2 else x

        pooled = self.conv(inputs).squeeze(-1)              # [B, c2]
        embedding = F.normalize(self.fc(pooled), dim=-1)    # [B, embed_dim]

        if not self.use_projection:
            return embedding

        projection = F.normalize(self.proj(embedding), dim=-1)
        return embedding, projection


In [4]:
class NPZFingerprints(Dataset):
    """
    Dataset for loading precomputed fingerprints from a .npz file.

    The archive must contain an array "fps" (2-D, one fingerprint per row) and
    an array "labels" (one entry per row of "fps"). Both arrays are read fully
    into memory when the dataset is constructed.

    Args:
        npz_path: path to the .npz archive.
        dtype: torch dtype of the returned fingerprint tensors.
        normalize: if True, each item is standardized as (x - mean) / std.
        mean, std: optional per-feature statistics to use when `normalize` is
            True (e.g. the `mean` / `std` attributes of a training dataset, so
            that another split is scaled identically). Any statistic left as
            None is computed from this file; a computed std gets +1e-8 added to
            avoid division by zero, a supplied std is used as given.

    Raises:
        ValueError: if "fps" is not 2-D or "labels" does not have one entry per
            fingerprint row.
    """
    def __init__(self, npz_path: str, dtype = torch.float32, normalize = False, mean = None, std = None):
        # Members of an .npz archive are materialized in memory on access
        # (mmap_mode does not apply to archives), so the file handle can be
        # released as soon as both arrays have been read.
        with np.load(npz_path) as z:
            self.fps = z["fps"]
            self.labels = z["labels"]

        self.N, self.D = self.fps.shape
        if np.shape(self.labels)[:1] != (self.N,):
            raise ValueError(
                f"'labels' has shape {np.shape(self.labels)} but 'fps' has {self.N} rows"
            )

        self.dtype = dtype
        self.normalize = normalize
        self.mean = None
        self.std = None
        if normalize:
            if mean is None or std is None:
                # compute whichever per-feature statistic was not supplied
                arr = np.asarray(self.fps, dtype=np.float32)
                if mean is None:
                    mean = arr.mean(axis=0)
                if std is None:
                    std = arr.std(axis=0) + 1e-8 # avoid div-by-zero
            self.mean = np.asarray(mean, dtype=np.float32)
            self.std = np.asarray(std, dtype=np.float32)

    def __len__(self) -> int:
        """Number of fingerprints in the file."""
        return self.N

    def __getitem__(self, idx: int):
        """Return `(fingerprint, label)` as a `dtype` tensor and a long scalar tensor."""
        x = np.asarray(self.fps[idx], dtype=np.float32)
        if self.normalize:
            x = (x - self.mean) / self.std
        y = int(self.labels[idx])
        return torch.as_tensor(x, dtype=self.dtype), torch.as_tensor(y, dtype=torch.long)

In [5]:
class SupConLoss(nn.Module):
    """
    Supervised Contrastive Learning Loss (Khosla et al., 2020)
    Operates on normalized embeddings; no augmentations needed.
    All samples sharing the same label in a batch are considered positives.

    Anchors that have no positive in the batch contribute 0 to the batch mean.
    """
    def __init__(self, temperature: float = 0.1):
        super().__init__()
        self.tau = temperature

    def forward(self, z: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        z: [B, d] normalized projections
        labels: [B] int labels

        Returns a scalar loss tensor (float32 for float16/bfloat16 inputs,
        otherwise the dtype of the similarity matrix).
        """
        B = z.size(0)
        sim = z @ z.t()  # cosine sims since z normalized
        # Under mixed precision the matmul result can be float16, where the
        # -1e9 mask value below is not representable (and logsumexp loses
        # precision), so reduced-precision similarities are upcast first.
        if sim.dtype in (torch.float16, torch.bfloat16):
            sim = sim.float()
        sim = sim / self.tau

        # masks
        eye = torch.eye(B, dtype=torch.bool, device=z.device)
        labels = labels.view(-1, 1)
        pos_mask = (labels == labels.t()) & (~eye)   # same-class pairs, self excluded

        # log prob over all others: the diagonal (self-similarity) is pushed
        # to a very negative value so it drops out of the softmax denominator
        denom = torch.logsumexp(sim.masked_fill(eye, -1e9), dim=1, keepdim=True)
        log_prob = sim - denom

        # average over positives per anchor; clamping the count keeps anchors
        # without positives at exactly 0 instead of dividing by zero
        pos_count = pos_mask.sum(1).clamp(min=1)
        pos_log_prob = log_prob.masked_fill(~pos_mask, 0.0).sum(1) / pos_count
        loss = -pos_log_prob.mean()
        return loss


In [6]:
def make_weighted_sampler(labels_np: np.ndarray):
    """
    Balances batches by inverse class frequency.

    Every sample gets the weight 1 / (number of samples sharing its label), so
    each class present in `labels_np` has the same total sampling probability.
    Works for any set of label values, not only {0, 1}.

    Args:
        labels_np: array-like of per-sample class labels (flattened to 1-D).

    Returns:
        A `WeightedRandomSampler` drawing `len(labels)` indices with replacement.

    Raises:
        ValueError: if `labels_np` is empty.
    """
    labels = np.asarray(labels_np).reshape(-1)
    if labels.size == 0:
        raise ValueError("cannot build a weighted sampler from an empty label array")

    # `inverse` maps every sample to the index of its class in the sorted
    # unique labels; `counts` holds the size of each of those classes.
    _, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
    weights = 1.0 / counts[inverse.reshape(-1)].astype(np.float64)

    return WeightedRandomSampler(weights=torch.from_numpy(weights).double(),
                                 num_samples=int(labels.size), replacement=True)

In [7]:
def seed_all(seed: int):
    """Seed Python's `random`, NumPy's global RNG and PyTorch (CPU and all CUDA devices) with `seed`."""
    import random

    # Same order as before: random -> numpy -> torch CPU -> torch CUDA.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

In [8]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score

@torch.no_grad()
def evaluate_val(val_loader, encoder, device):
    """
    Score the encoder's embeddings with a linear probe and return the AUPRC.

    The encoder is switched to eval mode (and left in it), every batch of
    `val_loader` is embedded once, and a class-balanced logistic regression is
    fit on those embeddings. The returned value is the average precision of
    the probe's class-1 probabilities on the same embeddings.

    NOTE: the probe is fit and scored on the same data, so the number is an
    optimistic estimate rather than a held-out metric.

    Args:
        val_loader: iterable yielding `(fingerprints, labels)` batches.
        encoder: model exposing `use_projection`; its output is `(g, z)` when
            that flag is truthy and `g` otherwise.
        device: device the input batches are moved to.

    Returns:
        Average precision (AUPRC) computed by `average_precision_score`.
    """
    encoder.eval()
    all_g, all_y = [], []

    for xb, yb in val_loader:
        xb = xb.to(device, non_blocking=True)
        # One forward pass per batch; only the embedding g is used, the
        # projection (if any) is ignored.
        out = encoder(xb)
        g = out[0] if encoder.use_projection else out  # [B, embed_dim]
        all_g.append(g.cpu())
        all_y.append(yb)

    X = torch.cat(all_g, dim=0).numpy()
    y = torch.cat(all_y, dim=0).numpy()

    # simple linear probe
    clf = LogisticRegression(max_iter=1000, class_weight="balanced")
    clf.fit(X, y)
    y_pred = clf.predict_proba(X)[:,1]  # probability of the second class

    auprc = average_precision_score(y, y_pred)
    return auprc


In [None]:
def train(args):
    """
    Train an `fp_CNN_Encoder` with `SupConLoss` on fingerprints stored in .npz files.

    `args` is any object exposing the attributes below (and usable with `vars()`):
    seed, cpu, train_npz, val_npz, normalize, batch_size, num_workers, balance,
    fp_dim, c1, c2, embed_dim, proj_dim, use_projection, batchnorm_safe,
    temperature, lr, weight_decay, grad_clip, epochs, out_dir.

    Per epoch, the mean training loss and wall time are printed; when
    `args.val_npz` points to an existing file, the linear-probe AUPRC returned
    by `evaluate_val` is printed as well. After the last epoch the encoder
    weights and `vars(args)` are written to `<out_dir>/last_encoder.pt`.
    No early stopping or best-checkpoint selection is performed.

    Returns:
        None
    """
    seed_all(args.seed)
    use_cuda = torch.cuda.is_available() and not args.cpu
    device = torch.device('cuda' if use_cuda else "cpu")

    train_ds = NPZFingerprints(args.train_npz, dtype=torch.float32, normalize=args.normalize)
    val_loader = None
    if args.val_npz and os.path.exists(args.val_npz):
        val_ds = NPZFingerprints(args.val_npz, dtype=torch.float32, normalize=False)
        if args.normalize:
            # Scale validation data with the TRAINING statistics so both splits
            # live in the same feature space (instead of re-estimating
            # mean/std on the validation file).
            val_ds.mean, val_ds.std = train_ds.mean, train_ds.std
            val_ds.normalize = True
        val_loader = DataLoader(val_ds,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.num_workers,
                                pin_memory=use_cuda)

    # weighted sample to ensure PKSs appear regularly in batches
    labels_np = train_ds.labels.astype(np.int64)
    sampler = make_weighted_sampler(labels_np) if args.balance else None
    train_loader = DataLoader(train_ds,
                              batch_size=args.batch_size,
                              sampler=sampler,
                              shuffle=(sampler is None),  # sampler and shuffle are mutually exclusive
                              num_workers=args.num_workers,
                              pin_memory=use_cuda,        # pinning only helps host->GPU copies
                              drop_last=True)

    # model & loss
    encoder = fp_CNN_Encoder(fp_dim = args.fp_dim,
                             hidden_channels = (args.c1, args.c2),
                             embed_dim = args.embed_dim,
                             proj_dim = args.proj_dim,
                             use_projection = args.use_projection,
                             batchnorm_safe = args.batchnorm_safe).to(device)

    criterion = SupConLoss(temperature=args.temperature).to(device)
    optimizer = torch.optim.AdamW(encoder.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    use_amp = (device.type == 'cuda')
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    os.makedirs(args.out_dir, exist_ok=True)

    # training loop
    for epoch in range(1, args.epochs + 1):
        encoder.train()
        epoch_loss, steps = 0.0, 0
        t0 = time.time()

        for xb, yb in train_loader:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)

            with torch.cuda.amp.autocast(enabled=use_amp):
                # no augmentations
                out = encoder(xb)

            # The encoder returns (g, z) with a projection head and g alone
            # without one; the contrastive loss uses z when available.
            z = out[1] if encoder.use_projection else out

            # The loss is evaluated in float32 outside autocast: inside it the
            # similarity matmul runs in float16, where the loss's -1e9 mask
            # value cannot be represented.
            loss = criterion(z.float(), yb)

            optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()

            if args.grad_clip > 0:
                # gradients must be unscaled before their norm is clipped
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(encoder.parameters(), args.grad_clip)
            scaler.step(optimizer); scaler.update()

            epoch_loss += loss.item(); steps += 1

        # drop_last=True can leave an epoch with zero steps; avoid 0/0
        epoch_loss /= max(1, steps)
        msg = f"[Epoch {epoch:03d}] train_supcon={epoch_loss:.4f} time={time.time()-t0:.1f}s"

        if val_loader is not None:
            auprc = evaluate_val(val_loader, encoder, device)
            msg += f" val_AUPRC={auprc:.4f}"
        print(msg)

    torch.save({"encoder": encoder.state_dict(), "args": vars(args)},
               os.path.join(args.out_dir, "last_encoder.pt"))
    print(f"Saved to {args.out_dir}")