In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import pickle
import obonet

In [2]:
import sys
def print_size(data):
    """Print the size of `data` in MB (1 MB = 1024 * 1024 bytes).

    The figure comes from `sys.getsizeof`, which measures the object itself;
    for containers it does not include the objects they reference.
    """
    size_in_bytes = sys.getsizeof(data)
    bytes_per_mb = 1024 * 1024
    print(f'{size_in_bytes / bytes_per_mb:.2f} MB')

In [3]:
def _neighbor_candidates(seed_terms, go_graph, go_embeds, exclude):
    """Return the sorted graph neighbors of `seed_terms` that are usable as padding.

    A neighbor is any ID listed in a seed node's 'is_a' attribute (a single
    string or a list) or reachable over one edge in either direction. Terms in
    `exclude` and terms without an entry in `go_embeds` are dropped. Seeds
    that are not nodes of `go_graph` are ignored.
    """
    neighbors = set()
    for term in seed_terms:
        if term not in go_graph:
            continue
        parents = go_graph.nodes[term].get('is_a', [])
        if isinstance(parents, str):
            parents = [parents]
        neighbors.update(parents)
        neighbors.update(go_graph.successors(term))
        neighbors.update(go_graph.predecessors(term))
    # Sorting removes the dependence on set iteration order, which varies
    # between interpreter runs for strings (hash randomization).
    return sorted(c for c in neighbors - exclude if c in go_embeds)


def pad_terms_with_neighbors(terms, go_graph, go_embeds, max_terms=256, second_level_limit=100):
    """
    Pad a list of GO terms with neighboring terms from the GO graph.

    The input terms are kept in their original order. Direct neighbors are
    appended first; if the list is still short, neighbors of (at most
    `second_level_limit` of) those neighbors are appended. Candidates are
    added in sorted order so the result is reproducible across runs.

    Args:
        terms: List of GO term IDs
        go_graph: networkx graph of GO ontology
        go_embeds: Dictionary of GO embeddings (to check if term has embedding)
        max_terms: Maximum number of terms to return
        second_level_limit: Number of first-level neighbors whose own
            neighbors are explored, to bound the amount of computation

    Returns:
        List of GO terms with at most max_terms entries. It can be shorter
        than max_terms when two levels of neighbors do not provide enough
        terms with embeddings.
    """
    padded_terms = list(terms)

    if len(padded_terms) >= max_terms:
        return padded_terms[:max_terms]

    # Everything already in the output; used to avoid adding a term twice.
    seen = set(padded_terms)

    first_level = _neighbor_candidates(padded_terms, go_graph, go_embeds, seen)
    padded_terms.extend(first_level[:max_terms - len(padded_terms)])
    seen.update(first_level)

    # If still not enough, try neighbors of neighbors.
    if len(padded_terms) < max_terms:
        second_level = _neighbor_candidates(
            first_level[:second_level_limit], go_graph, go_embeds, seen
        )
        padded_terms.extend(second_level[:max_terms - len(padded_terms)])

    return padded_terms

def pad_dataframe_terms(seq_2_terms_df, go_graph, go_embeds, max_terms=256):
    """
    Overwrite the 'terms_predicted' column with neighbor-padded term lists.

    Every row's list is passed through pad_terms_with_neighbors. The column is
    replaced on the DataFrame that was passed in, and that same DataFrame is
    returned. Two progress messages are printed.

    Args:
        seq_2_terms_df: DataFrame with 'terms_predicted' column
        go_graph: networkx graph of GO ontology
        go_embeds: Dictionary of GO embeddings
        max_terms: Maximum number of terms per row

    Returns:
        The input DataFrame, with 'terms_predicted' replaced.
    """
    def pad_row(row_terms):
        return pad_terms_with_neighbors(row_terms, go_graph, go_embeds, max_terms)

    print("Padding terms_predicted with GO graph neighbors...")
    padded_column = seq_2_terms_df['terms_predicted'].apply(pad_row)
    seq_2_terms_df['terms_predicted'] = padded_column

    mean_terms = seq_2_terms_df['terms_predicted'].apply(len).mean()
    print(f"Padding complete. Average terms per row: {mean_terms:.2f}")
    return seq_2_terms_df

def prepare_data(data_paths, max_terms=256, aspect=None):
    """
    Load all inputs and build the table used for training.

    Steps: read the predicted/true term table, optionally keep only terms of
    one aspect, pad 'terms_predicted' with GO graph neighbors, keep only rows
    whose padded list has exactly `max_terms` entries, and keep only rows
    whose 'qseqid' has a feature embedding.

    Args:
        data_paths: Dict of file paths with the keys
            'knn_terms_df' (parquet with 'qseqid', 'terms_predicted', 'terms_true'),
            'train_terms_df' (tab-separated file with 'term' and 'aspect' columns),
            'go_obo_path' (GO ontology in OBO format),
            'features_embeds_path' / 'features_ids_path' (.npy arrays, aligned by position),
            'go_embeds_paths' (pickle holding a dict with an 'embeddings' entry).
        max_terms: Number of predicted terms every kept row must have after padding.
        aspect: If given, only terms whose aspect in the training terms file
            equals this value are kept, and rows left without predicted or
            true terms are dropped.

    Returns:
        Dict with keys 'seq_2_terms' (DataFrame), 'train_terms' (DataFrame),
        'features_embeds' (dict: sequence ID -> embedding), 'go_embeds'
        (dict: GO term -> embedding) and 'go_graph'.
    """
    seq_2_terms = pd.read_parquet(data_paths['knn_terms_df'], engine='fastparquet')
    train_terms = pd.read_csv(data_paths['train_terms_df'], sep='\t')

    go_graph = obonet.read_obo(data_paths['go_obo_path'])

    # NOTE(security): pickle.load can execute arbitrary code; only point this
    # at files you produced yourself or otherwise trust.
    with open(data_paths['go_embeds_paths'], 'rb') as f:
        embeddings_dict = pickle.load(f)['embeddings']

    # Filter to keep only terms from a specific aspect if aspect is provided
    if aspect is not None:
        term_to_aspect = train_terms.groupby('term')['aspect'].first().to_dict()

        def keep_aspect(terms):
            return [t for t in terms if term_to_aspect.get(t) == aspect]

        seq_2_terms['terms_predicted'] = seq_2_terms['terms_predicted'].apply(keep_aspect)
        seq_2_terms['terms_true'] = seq_2_terms['terms_true'].apply(keep_aspect)

        # Remove rows where terms_predicted or terms_true is now empty.
        # .copy() makes the result an independent frame, so the column
        # assignment done during padding does not write into a slice.
        has_terms = (
            (seq_2_terms['terms_predicted'].apply(len) > 0)
            & (seq_2_terms['terms_true'].apply(len) > 0)
        )
        seq_2_terms = seq_2_terms[has_terms].copy()

    # allow_pickle=True is kept from the original call; the same trust caveat
    # as for the pickle file above applies to these arrays.
    features_embeds = np.load(data_paths['features_embeds_path'], allow_pickle=True)
    features_ids = np.load(data_paths['features_ids_path'], allow_pickle=True)

    # Embeddings and IDs are matched by position.
    features_embeds_dict = dict(zip(features_ids, features_embeds))

    # Pad terms_predicted in the dataframe with GO graph neighbors
    seq_2_terms = pad_dataframe_terms(seq_2_terms, go_graph, embeddings_dict, max_terms=max_terms)

    # Only rows that reached exactly max_terms predicted terms are used for
    # now; shorter rows are dropped (to be revisited later).
    term_lengths = seq_2_terms['terms_predicted'].apply(len)
    seq_2_terms = seq_2_terms[term_lengths == max_terms]

    # Keep only sequences that have a feature embedding.
    train_ids = pd.DataFrame(features_ids, columns=['qseqid'])
    seq_2_terms = seq_2_terms.merge(train_ids, on='qseqid', how='inner')

    return {
        'seq_2_terms': seq_2_terms,
        'train_terms': train_terms,
        'features_embeds': features_embeds_dict,
        'go_embeds': embeddings_dict,
        'go_graph': go_graph,
    }
    

In [4]:
# Input files consumed by prepare_data(). These are absolute local paths;
# adjust them for the environment the notebook runs in.
data_paths = {
    'knn_terms_df':         '/mnt/d/ML/Kaggle/CAFA6-new/uniprot/diamond_knn_predictions.parquet',
    'train_terms_df':       '/mnt/d/ML/Kaggle/CAFA6/cafa-6-protein-function-prediction/Train/train_terms.tsv',
    'go_obo_path':          '/mnt/d/ML/Kaggle/CAFA6/cafa-6-protein-function-prediction/Train/go-basic.obo',
    'features_embeds_path': '/mnt/d/ML/Kaggle/CAFA6-new/Dataset/esm_t33_650M/train_embeds.npy',
    'features_ids_path':    '/mnt/d/ML/Kaggle/CAFA6-new/Dataset/esm_t33_650M/train_ids.npy',
    'go_embeds_paths':      '/mnt/d/ML/Kaggle/CAFA6-new/uniprot/go_embeddings.pkl'
}
# Restrict to aspect 'P' and keep only rows whose predicted-term list could be
# padded to exactly 192 entries (prepare_data drops shorter rows).
data = prepare_data(data_paths, max_terms=192, aspect='P')


Padding terms_predicted with GO graph neighbors...
Padding complete. Average terms per row: 185.86


In [None]:
# Largest per-sequence term count observed for each aspect:
#   C: 42
#   F: 34
#   P: 188

In [5]:
# (rows, columns) of the prepared table, i.e. after padding, the exact-length
# filter and the inner merge with the feature IDs.
data['seq_2_terms'].shape

(50404, 3)

In [5]:
import networkx as nx

# Number of weakly connected components of the GO graph (edge direction ignored).
num_weak = nx.number_weakly_connected_components(data['go_graph'])
print(num_weak)

3


In [8]:
# Pairwise dot products between the embeddings of three GO terms, one per
# aspect according to the inline tags.
# NOTE(review): these are raw dot products; they equal cosine similarity only
# if the embeddings are unit-normalized — TODO confirm.
embed1 = data['go_embeds']['GO:0000001'] #P
embed2 = data['go_embeds']['GO:0000006'] #F

embed3 = data['go_embeds']['GO:0000785'] #C
print(np.dot(embed1, embed2))
print(np.dot(embed2, embed3))
print(np.dot(embed1, embed3))

-0.6220611
0.67442715
-0.42725486


In [6]:
# Check the distribution of terms_predicted lengths after padding
term_lengths = data['seq_2_terms']['terms_predicted'].apply(len)
print(f"Min terms: {term_lengths.min()}")
print(f"Max terms: {term_lengths.max()}")
print(f"Mean terms: {term_lengths.mean():.2f}")
print(f"Median terms: {term_lengths.median():.2f}")
print(f"No. empty predicted terms: {(term_lengths == 0).sum()}")
print(f"\nSample row with padded terms:")
print(f"Number of predicted terms: {len(data['seq_2_terms'].iloc[0]['terms_predicted'])}")

Min terms: 192
Max terms: 192
Mean terms: 192.00
Median terms: 192.00
No. empty predicted terms: 0

Sample row with padded terms:
Number of predicted terms: 192


In [7]:
# Check the distribution of terms_predicted lengths after padding
term_lengths = data['seq_2_terms']['terms_true'].apply(len)
print(f"Min terms: {term_lengths.min()}")
print(f"Max terms: {term_lengths.max()}")
print(f"Mean terms: {term_lengths.mean():.2f}")
print(f"Median terms: {term_lengths.median():.2f}")
print(f"No. empty predicted terms: {(term_lengths == 0).sum()}")
print(f"\nSample row with padded terms:")
print(f"Number of predicted terms: {len(data['seq_2_terms'].iloc[0]['terms_predicted'])}")

Min terms: 1
Max terms: 188
Mean terms: 4.57
Median terms: 3.00
No. empty predicted terms: 0

Sample row with padded terms:
Number of predicted terms: 192


In [9]:
# Build the inputs for the first row by hand: its feature embedding and the
# stacked embeddings of its predicted GO terms.
# NOTE(review): the recorded output (256, 512) does not match max_terms=192
# used in the prepare_data call above; presumably it comes from an earlier
# run — re-run to confirm.
row = data['seq_2_terms'].iloc[0]

feature_embed = data['features_embeds'][row['qseqid']]

go_embeds = np.array([data['go_embeds'][go_id] for go_id in row['terms_predicted']])
go_embeds.shape

(256, 512)

In [10]:
class EmbeddingsDataset(Dataset):
    """Dataset that yields raw embeddings; tokenization is done in collate_fn for batching.

    Args:
        data: Dict with the entries
            'seq_2_terms' (DataFrame with 'qseqid', 'terms_predicted', 'terms_true'),
            'features_embeds' (dict: sequence ID -> feature embedding) and
            'go_embeds' (dict: GO term -> embedding; must not be empty).
        max_go_embeds: Upper bound on the number of predicted terms returned
            per sample; longer lists are truncated. None disables the bound.
        oversample_indices: Optional list of row positions (repeats allowed)
            defining which rows the dataset serves and in what order.
            Defaults to every row once.
    """
    def __init__(self,
                 data,
                 max_go_embeds = 256,
                 oversample_indices=None
                ):

        self.data = data
        self.max_go_embeds = max_go_embeds
        if oversample_indices is None:
            oversample_indices = list(range(len(self.data['seq_2_terms'])))
        self.oversample_indices = oversample_indices
        # All-zero stand-in for terms without an embedding; it takes its shape
        # from an arbitrary existing GO embedding.
        self.mask_embed = np.zeros(next(iter(self.data['go_embeds'].values())).shape, dtype=np.float32)

    def __len__(self):
        return len(self.oversample_indices)

    def __getitem__(self, idx):
        """Return the sample at position `idx` of `oversample_indices`.

        Returns a dict with 'entryID', 'feature', 'go_embed' (one row per
        predicted term), 'label' (1.0 where the predicted term is a true
        term, else 0.0), 'predicted_terms' and 'true_terms'.
        """
        sample_idx = self.oversample_indices[idx]

        row = self.data['seq_2_terms'].iloc[sample_idx]
        qseqid = row['qseqid']

        feature_embed = self.data['features_embeds'][qseqid]

        true_terms_set = set(row['terms_true'])
        # Enforce the max_go_embeds bound (a [:None] slice keeps everything).
        valid_terms = row['terms_predicted'][:self.max_go_embeds]

        embed_lookup = self.data['go_embeds']
        if len(valid_terms) > 0:
            go_embeds = np.array([embed_lookup.get(term, self.mask_embed) for term in valid_terms])
        else:
            # np.array([]) would have shape (0,); keep the embedding axes so
            # downstream code always sees (num_terms, *embedding_shape).
            go_embeds = np.empty((0, *self.mask_embed.shape), dtype=self.mask_embed.dtype)
        label = np.array([term in true_terms_set for term in valid_terms], dtype=np.float32)

        return {
            'entryID'   : qseqid,
            'feature'   : feature_embed,
            'go_embed'  : go_embeds,
            'label'     : label,
            'predicted_terms': valid_terms,
            'true_terms': row['terms_true']
        }


In [11]:
# Build the dataset with its defaults and pull one arbitrary sample (row
# position 1200); `aa` is reused by a later tokenizer smoke-test cell.
dataset = EmbeddingsDataset(data)

aa = dataset[1200]

In [12]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [19]:
# Spot-check the feature embedding stored for one hard-coded sequence ID.
data['features_embeds']['A0A023FBW4']

array([ 0.07045657, -0.00120772,  0.02064155, ..., -0.09901655,
       -0.0983785 ,  0.11195755], shape=(1280,))

In [23]:
def generate_orthogonal_vectors_qr(dimension, num_vectors=None, rng=None):
    """
    Generates a set of orthonormal vectors using QR factorization.

    A random (dimension, num_vectors) matrix is factorized and the Q factor,
    whose columns are orthonormal, is returned.

    Args:
        dimension (int): The dimensionality of the vectors.
        num_vectors (int, optional): The number of vectors to generate.
                                     Defaults to 'dimension' for a full basis.
        rng (np.random.Generator, optional): Source of randomness. When
                                     omitted, NumPy's global random state is
                                     used (so `np.random.seed` still applies).
    Returns:
        numpy.ndarray: A (dimension, num_vectors) matrix where each column is
        one of the generated vectors.

    Raises:
        ValueError: If num_vectors exceeds dimension; no more than
        'dimension' mutually orthogonal vectors exist in that space.
    """
    if num_vectors is None:
        num_vectors = dimension

    if num_vectors > dimension:
        # The reduced Q factor has at most `dimension` columns, so without
        # this check fewer vectors than requested would be returned silently.
        raise ValueError(
            f"num_vectors ({num_vectors}) cannot exceed dimension ({dimension})"
        )

    # Generate a random matrix
    if rng is None:
        A = np.random.rand(dimension, num_vectors)
    else:
        A = rng.random((dimension, num_vectors))

    # Perform QR factorization; only the Q factor is needed.
    Q, _ = np.linalg.qr(A)

    return Q[:, :num_vectors]




In [30]:
# Generate a full basis for a 1280-dimensional space (num_vectors defaults to
# the dimension).
vectors = generate_orthogonal_vectors_qr(1280 ) #, num_vectors=512)

# Verification (dot product of any two distinct columns should be near zero):
print("\nDot product of the first two vectors:", np.dot(vectors[:, 0], vectors[:, 1]))
print(vectors.shape)


Dot product of the first two vectors: -1.734723475976807e-18
(1280, 1280)


In [29]:
# NOTE(review): the recorded output (1280, 512) disagrees with the (1280, 1280)
# printed by the cell above; presumably this cell last ran when num_vectors=512
# was still passed — re-run to refresh.
vectors.shape

(1280, 512)

In [31]:
import torch
import torch.nn as nn

class EmbedTokenizer(nn.Module):
    """Turn one D-dimensional embedding into N tokens of dimension d.

    Each token is the projection of the input onto d distinct vectors drawn
    without replacement from a single random orthonormal basis of the
    D-dimensional space. The projections are fixed at construction time and
    stored in the non-trainable buffer `P_buffer` of shape (N, d, D).
    """

    def __init__(self, D, d, N, rng=None):
        """
        Parameters
        ----------
        D : int
            Dimension of the embedding space.
        d : int
            Dimension of the token space. Must not exceed D.
        N : int
            Number of tokens.
        rng : np.random.Generator or None
            Random generator used for both the basis and the index sampling.
            If None, NumPy's global random state is used.

        Raises
        ------
        ValueError
            If d is larger than D.
        """
        super().__init__()

        if d > D:
            raise ValueError(
                f"d ({d}) cannot exceed D ({D}): each token uses d distinct basis vectors"
            )

        if rng is None:
            V = generate_orthogonal_vectors_qr(D)
            choose = np.random.choice
        else:
            # Build the basis from the supplied generator so the whole
            # tokenizer is reproducible from `rng`.
            V, _ = np.linalg.qr(rng.random((D, D)))
            choose = rng.choice

        # Row k is basis vector k. Casting once here avoids stacking a
        # float64 (N, d, D) intermediate.
        basis_rows = np.ascontiguousarray(V.T, dtype=np.float32)

        # Build P as a single tensor of shape (N, d, D).
        P_np = np.stack(
            [basis_rows[choose(D, size=d, replace=False)] for _ in range(N)],
            axis=0,
        )
        # register as buffer so it's not a parameter but moves with the module
        self.register_buffer('P_buffer', torch.from_numpy(P_np))

    def forward(self, x):
        """
        Parameters
        ----------
        x : Tensor, shape (D,) or (..., D)
            Input embeddings.

        Returns
        -------
        tokens : Tensor, shape (N, d) or (..., N, d)
            Token representations; leading dimensions of `x` are preserved.

        Raises
        ------
        ValueError
            If the last dimension of `x` is not D.
        """
        N, d, D = self.P_buffer.shape
        if x.shape[-1] != D:
            # Without this check a mismatching width could be reshaped into a
            # compatible-looking matrix and yield meaningless tokens.
            raise ValueError(
                f"expected input with last dimension {D}, got {tuple(x.shape)}"
            )

        # No-op when the buffer already has the input's dtype and device.
        P = self.P_buffer.to(dtype=x.dtype, device=x.device)  # (N, d, D)

        # (..., D) @ (D, N*d) -> (..., N*d). The transposed view avoids
        # materializing a permuted copy of the buffer on every call.
        tokens = torch.matmul(x, P.reshape(N * d, D).t())
        return tokens.reshape(*x.shape[:-1], N, d)


In [32]:
# Keep the projection buffer on `device`. Left on the CPU, forward() has to
# copy the whole (N, d, D) float32 buffer (128 * 512 * 1280 * 4 bytes, about
# 320 MiB) to the input's device on every call.
tokenizer = EmbedTokenizer(D=1280, d=512, N=128).to(device)

In [36]:
# Smoke test: tokenize the feature vector of the sample `aa` as a batch of
# one, then add one more leading axis to the result.
tokenizer(torch.tensor(aa['feature'], dtype=torch.float32).to(device).unsqueeze(0)).unsqueeze(0)

tensor([[[[-0.3775,  0.2384,  0.0141,  ...,  0.0190, -0.1658,  0.2868],
          [ 0.2066, -0.4773, -0.2407,  ..., -0.3053,  0.0404, -0.0880],
          [ 0.1087,  0.3423,  0.3232,  ..., -0.2379, -0.0897, -0.1857],
          ...,
          [ 0.1216, -0.1268, -0.1314,  ...,  0.0746, -0.1035,  0.2956],
          [-0.2053, -0.3775,  0.0284,  ...,  0.2384, -0.3817,  0.0690],
          [-0.0076, -0.1326,  0.1210,  ..., -0.3185, -0.3413, -0.1075]]]],
       device='cuda:0')

In [None]:
# data['features_embeds'] is a dict keyed by sequence ID and has no `.shape`,
# so look at the shape of one stored embedding instead.
next(iter(data['features_embeds'].values())).shape

(1280,)

In [54]:
def collate_tokenize(batch, tokenizer, device=None, dtype=torch.float32):
    """Collate dataset samples into one batch and tokenize the features.

    The 'feature', 'go_embed' and 'label' arrays of all samples are stacked
    along a new leading axis, cast to `dtype` and, when `device` is given,
    moved there. The stacked features are then passed through `tokenizer`.
    The remaining fields are returned as plain Python lists.

    Args:
        batch: List of samples from the dataset
        tokenizer: Tokenizer to apply to features
        device: Device to move tensors to (cuda or cpu)
        dtype: Target dtype for tensors (torch.float32, torch.float16, or torch.bfloat16)

    Returns:
        Dict with 'entryID', 'feature', 'go_embed', 'label',
        'predicted_terms' and 'true_terms'.
    """
    def stack_field(key):
        # torch.stack needs every sample's array for `key` to share one shape.
        stacked = torch.stack([torch.from_numpy(sample[key]) for sample in batch])
        if device:
            return stacked.to(dtype=dtype, device=device)
        return stacked.to(dtype=dtype)

    def gather_field(key):
        return [sample[key] for sample in batch]

    tokens = tokenizer(stack_field('feature'))
    go_embed = stack_field('go_embed')
    label = stack_field('label')

    collated = {}
    collated['entryID'] = gather_field('entryID')
    collated['feature'] = tokens
    collated['go_embed'] = go_embed
    collated['label'] = label
    collated['predicted_terms'] = gather_field('predicted_terms')
    collated['true_terms'] = gather_field('true_terms')
    return collated

In [55]:
# Options: torch.float32 (default), torch.float16, torch.bfloat16
target_dtype = torch.float32  # Change this to torch.float16 or torch.bfloat16 if needed

# The collate function stacks each batch, moves it to `device` and runs the
# tokenizer, so batches leave the loader ready for the model.
# NOTE(review): because collate_fn touches `device`, raising num_workers above
# 0 would run that code inside worker processes; presumably unsafe when
# device is CUDA — confirm before changing it.
dataloader = DataLoader(
    dataset, 
    batch_size=32, 
    shuffle=True, 
    num_workers=0, 
    collate_fn=lambda b: collate_tokenize(b, tokenizer, device, dtype=target_dtype)
)

In [None]:
class PrefetchLoader:
    """
    Wrap a dataloader so that the next batch is staged on a side CUDA stream.

    While the caller works on the current batch, the following batch is
    fetched from the wrapped loader and its tensors are moved to `device`
    inside a separate CUDA stream. On a non-CUDA device the wrapped loader is
    iterated directly.

    Args:
        dataloader: Iterable of batches (tensors, or dicts whose values may be
            tensors); must support len() for __len__ to work.
        device: torch.device or device string such as 'cuda' or 'cpu'.
    """
    def __init__(self, dataloader, device):
        self.dataloader = dataloader
        # Accept device strings as well as torch.device objects.
        self.device = torch.device(device)
        self.stream = (
            torch.cuda.Stream(device=self.device)
            if self.device.type == 'cuda'
            else None
        )

    def __iter__(self):
        if self.stream is None:
            # CPU fallback - no prefetching needed
            return iter(self.dataloader)
        return self._cuda_iter()

    def _cuda_iter(self):
        """Iterator with CUDA stream prefetching."""
        loader_iter = iter(self.dataloader)
        exhausted = object()  # sentinel: a batch may legitimately be falsy

        batch = self._load_next(loader_iter, exhausted)
        while batch is not exhausted:
            consumer_stream = torch.cuda.current_stream(self.device)
            # Work queued by the caller must not start before the side stream
            # has finished producing this batch.
            consumer_stream.wait_stream(self.stream)
            # The tensors were allocated on the side stream; tell the caching
            # allocator that the consumer stream uses them as well.
            for tensor in self._tensors(batch):
                if tensor.is_cuda:
                    tensor.record_stream(consumer_stream)

            # Stage the following batch before handing out the current one.
            upcoming = self._load_next(loader_iter, exhausted)
            yield batch
            batch = upcoming

    def _load_next(self, loader_iter, exhausted):
        """Fetch the next batch on the side stream; return `exhausted` at the end."""
        with torch.cuda.stream(self.stream):
            batch = next(loader_iter, exhausted)
            if batch is exhausted:
                return exhausted
            return self._to_device(batch)

    def _to_device(self, batch):
        """Move a tensor, or the tensor values of a dict, to `self.device`.

        Tensors already on the device are returned unchanged by `.to`, so
        batches moved by the collate function are not copied again. Anything
        that is neither a tensor nor a dict is returned as is.
        """
        if isinstance(batch, torch.Tensor):
            return batch.to(self.device, non_blocking=True)
        if isinstance(batch, dict):
            return {key: self._to_device(value) for key, value in batch.items()}
        return batch

    @staticmethod
    def _tensors(batch):
        """Yield the tensors that `_to_device` handles within `batch`."""
        if isinstance(batch, torch.Tensor):
            yield batch
        elif isinstance(batch, dict):
            for value in batch.values():
                yield from PrefetchLoader._tensors(value)

    def __len__(self):
        return len(self.dataloader)

In [64]:
# Create prefetch loader
prefetch_loader = PrefetchLoader(dataloader, device)

print(f"DataLoader length: {len(prefetch_loader)}")
print(f"Device: {device}")

DataLoader length: 2300
Device: cuda


In [65]:
# Test prefetch loader - compare timing
import time

# Test without prefetch
print("Testing regular dataloader:")
start = time.time()
for i, batch in enumerate(dataloader):
    if i >= 5:  # Just test 5 batches
        break
    # Simulate some processing
    _ = batch['feature'].sum()
regular_time = time.time() - start
print(f"Regular loader time (5 batches): {regular_time:.4f}s")

# Test with prefetch
print("\nTesting prefetch loader:")
start = time.time()
for i, batch in enumerate(prefetch_loader):
    if i >= 5:  # Just test 5 batches
        break
    # Simulate some processing
    _ = batch['feature'].sum()
prefetch_time = time.time() - start
print(f"Prefetch loader time (5 batches): {prefetch_time:.4f}s")
print(f"Speedup: {regular_time/prefetch_time:.2f}x")

Testing regular dataloader:
Regular loader time (5 batches): 0.9385s

Testing prefetch loader:
Prefetch loader time (5 batches): 0.7915s
Speedup: 1.19x


In [66]:
# Example: how to use the prefetch loader in a training loop.
# (`model`, `criterion`, `optimizer` and `num_epochs` are placeholders.)
#
# # Plain DataLoader:
# for epoch in range(num_epochs):
#     for batch in dataloader:
#         optimizer.zero_grad()  # clear gradients left over from the previous step
#         outputs = model(batch['go_embed'], batch['feature'])
#         loss = criterion(outputs, batch['label'])
#         loss.backward()
#         optimizer.step()
#
# # With prefetching (only the iterable changes):
# for epoch in range(num_epochs):
#     for batch in prefetch_loader:  # <-- Just replace dataloader with prefetch_loader
#         optimizer.zero_grad()
#         outputs = model(batch['go_embed'], batch['feature'])
#         loss = criterion(outputs, batch['label'])
#         loss.backward()
#         optimizer.step()
#
# NOTE(review): the only speed evidence in this notebook is the 5-batch timing
# cell; confirm the benefit on a full training run.

print("Usage: Simply replace 'dataloader' with 'prefetch_loader' in your training loop!")
print("The prefetch loader will automatically overlap data loading with GPU computation.")

Usage: Simply replace 'dataloader' with 'prefetch_loader' in your training loop!
The prefetch loader will automatically overlap data loading with GPU computation.


In [56]:
# Draw one batch (order is random because the loader shuffles) and display
# its label matrix: one row per sequence, one column per predicted term.
batch =  next(iter(dataloader))
batch['label']

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')

In [None]:
# NOTE(review): this displays the entire batch dict (all tensors and term
# lists); consider showing shapes or a small slice instead.
batch

In [44]:
from Model.Query2Label import Query2Label


# NOTE(review): num_classes=256 while prepare_data was called above with
# max_terms=192, so each label row has 192 entries. If the model emits one
# score per class (the recorded output shape [32, 256] suggests so), outputs
# and labels will not line up — confirm against Query2Label and align the two.
# in_dim=512 presumably matches the GO embedding / token width — TODO confirm.
model = Query2Label(num_classes = 256,
                    in_dim = 512,
                    nheads = 8,
                    num_encoder_layers = 1,
                    num_decoder_layers = 2,
                    dim_feedforward = 2048,
                    dropout = 0.1,
                    use_positional_encoding = True,
                )



In [47]:
# Move the model's parameters and buffers to the selected device.
model = model.to(device)

In [57]:
# Sanity check: the collate function should have cast this to target_dtype.
batch['go_embed'].dtype

torch.float32

In [58]:
# Sanity check: tokenized features should carry the same dtype as the GO embeddings.
batch['feature'].dtype

torch.float32

In [60]:
# Forward pass on one batch: GO-term embeddings first, tokenized features second.
model_output = model(batch['go_embed'], batch['feature'])

In [62]:
# NOTE(review): compare this against batch['label'].shape before computing a
# loss; the two must match element for element.
model_output.shape

torch.Size([32, 256])