In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from typing import Tuple

In [3]:
class fp_CNN_Encoder(nn.Module):
    """
    1D-CNN encoder that maps a fingerprint vector to an L2-normalized embedding
    and, optionally, to an L2-normalized projection of that embedding.

    Args:
        fp_dim: Nominal fingerprint length. It is not used to build any layer:
            the adaptive max pool collapses the length axis to 1, so the
            network itself does not depend on the input length.
        hidden_channels: Pair ``(c1, c2)`` with the output channels of the two
            convolution layers.
        embed_dim: Size of the embedding ``g`` produced by the encoder head.
        proj_dim: Size of the projection ``z`` produced by the projection head.
        use_projection: If True, build the projection head and make
            ``forward`` return ``(g, z)``; otherwise ``forward`` returns ``g``.
        batchnorm_safe: If True, the projection head normalizes with LayerNorm
            (works with batch_size=1); otherwise with BatchNorm1d. This flag
            only affects the projection head; the convolution stack always
            uses BatchNorm1d.
    """

    def __init__(self, fp_dim = 2048, hidden_channels = (64, 128), embed_dim = 256, proj_dim = 120, use_projection = True, batchnorm_safe = True):
        super().__init__()
        c1, c2 = hidden_channels

        # convolution stack (sub-module names and positions are kept stable so
        # that existing state_dict checkpoints keep loading)
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels = 1, out_channels = c1, kernel_size = 5, padding = 2),
            nn.BatchNorm1d(num_features = c1),
            nn.ReLU(inplace = True),
            nn.Conv1d(in_channels = c1, out_channels = c2, kernel_size = 5, padding = 2),
            nn.BatchNorm1d(num_features = c2),
            nn.ReLU(inplace = True),
            nn.AdaptiveMaxPool1d(1), # collapse length to 1
            )

        # encoder head
        self.fc = nn.Linear(in_features = c2, out_features = embed_dim)

        # projection head
        self.use_projection = use_projection
        self.batchnorm_safe = batchnorm_safe
        if self.use_projection:
            if self.batchnorm_safe:
                # LayerNorm works with batch_size=1
                norm_layer = nn.LayerNorm(embed_dim)
            else:
                # BatchNorm1d is better if you always train with batch_size > 1
                norm_layer = nn.BatchNorm1d(embed_dim)

            self.proj = nn.Sequential(
                nn.Linear(embed_dim, embed_dim),
                nn.ReLU(inplace=True),
                norm_layer,
                nn.Linear(embed_dim, proj_dim),
            )

    def forward(self, x):
        """
        Encode a batch of fingerprints.

        Args:
            x: Tensor of shape [B, fp_dim] or [B, 1, fp_dim]. Integer or
                boolean tensors (e.g. binary fingerprints) are cast to the
                dtype of the model parameters before the convolutions.

        Returns:
            ``(g, z)`` if ``use_projection`` is True, else ``g``, where ``g`` is
            the L2-normalized embedding [B, embed_dim] and ``z`` is the
            L2-normalized projection [B, proj_dim].
        """
        # x: [B, fp_dim] or [B, 1, fp_dim]
        if x.dim() == 2:
            x = x.unsqueeze(1) # add channel dim, [B, 1, fp_dim]

        # Conv1d rejects integer/bool inputs; cast them to the parameter dtype.
        # Floating-point inputs are passed through untouched.
        if not (x.is_floating_point() or x.is_complex()):
            x = x.to(self.fc.weight.dtype)

        h = self.conv(x).squeeze(-1) # [B, c2, 1] -> [B, c2]
        g = F.normalize(self.fc(h), dim = -1) # [B, embed_dim], normalized embedding

        if self.use_projection:
            # the projection head consumes the already-normalized embedding
            z = F.normalize(self.proj(g), dim = -1)
            return g, z
        else:
            return g


In [4]:
class NPZFingerprints(Dataset):
    """
    Dataset for loading precomputed fingerprints from a .npz file.

    The archive must contain two arrays:
        "fps":    2-D array with one fingerprint per row.
        "labels": array with one label per fingerprint (indexed by row).

    Args:
        npz_path: Path to the .npz archive.
        dtype: torch dtype of the returned fingerprint tensors.
        normalize: If True, standardize each feature with the mean/std
            computed over the whole "fps" array.

    Raises:
        ValueError: If "fps" is not 2-D or "labels" has a different number of
            entries than "fps" has rows.
    """
    def __init__(self, npz_path: str, dtype = torch.float32, normalize = False):
        # np.load ignores mmap_mode for .npz archives, so both arrays are read
        # fully into memory here; the context manager closes the archive's
        # file handle as soon as they are loaded.
        with np.load(npz_path) as archive:
            self.fps = archive["fps"]
            self.labels = archive["labels"]

        if self.fps.ndim != 2:
            raise ValueError(
                f"'fps' must be a 2-D array, got shape {self.fps.shape}"
            )
        self.N, self.D = self.fps.shape
        if len(self.labels) != self.N:
            raise ValueError(
                f"'labels' has {len(self.labels)} entries but 'fps' has {self.N} rows"
            )

        self.dtype = dtype
        self.normalize = normalize
        if normalize:
            # compute per-feature mean/std if requested
            arr = np.asarray(self.fps, dtype=np.float32)
            self.mean = arr.mean(axis=0)
            self.std = arr.std(axis=0) + 1e-8 # avoid div-by-zero

    def __len__(self) -> int:
        """Number of fingerprints in the dataset."""
        return self.N

    def __getitem__(self, idx: int):
        """
        Return ``(fingerprint, label)`` for row ``idx``: the fingerprint as a
        tensor of ``self.dtype`` (standardized if ``normalize`` is set) and the
        label as a ``torch.long`` scalar tensor.
        """
        # NOTE: when "fps" is already float32 and normalize is off, np.asarray
        # and torch.as_tensor may avoid copying, so the returned tensor can
        # share memory with the stored array; avoid in-place ops on it.
        x = np.asarray(self.fps[idx], dtype=np.float32)
        if self.normalize:
            x = (x - self.mean) / self.std
        y = int(self.labels[idx])
        return torch.as_tensor(x, dtype=self.dtype), torch.as_tensor(y, dtype=torch.long)