In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import pickle
from tqdm.auto import tqdm

In [3]:
def pad_with_random_terms(terms, max_terms, all_terms_array, existing_terms_set=None):
    """Pad ``terms`` to ``max_terms`` entries with randomly drawn, unused terms.

    Args:
        terms: Sequence of terms already assigned to the sample.
        max_terms: Target length of the returned sequence.
        all_terms_array: Pool of candidate terms (numpy array or any array-like).
        existing_terms_set: Optional pre-computed set of terms that must not be
            drawn. It is copied, never mutated. Defaults to ``set(terms)``.

    Returns:
        ``terms[:max_terms]`` when ``terms`` is already long enough. Otherwise a
        new list: the original terms followed by random padding terms. Padding
        terms are unique while enough unused candidates exist; if there are too
        few, the remaining slots are filled by sampling the unused candidates
        with replacement. If no unused candidate exists at all (or the pool is
        empty) the terms are returned without further padding, so the result
        can be shorter than ``max_terms`` in that case.
    """
    if len(terms) >= max_terms:
        return terms[:max_terms]

    num_needed = max_terms - len(terms)
    all_terms_array = np.asarray(all_terms_array)

    # Private copy: the caller's set must not grow as a side effect of padding.
    used_terms = set(terms) if existing_terms_set is None else set(existing_terms_set)

    random_terms = []

    # Rejection sampling is only worthwhile when we need a small fraction
    # (< 10%) of the pool; it avoids building the full boolean mask.
    if num_needed < len(all_terms_array) * 0.1:
        max_attempts = num_needed * 10  # bounds the loop if most draws collide
        attempts = 0

        while len(random_terms) < num_needed and attempts < max_attempts:
            batch_size = min(num_needed * 2, 1000)
            candidates = np.random.choice(all_terms_array, size=batch_size, replace=True)

            # .tolist() yields native Python objects, matching the bulk path below.
            for candidate in candidates.tolist():
                if candidate not in used_terms:
                    random_terms.append(candidate)
                    used_terms.add(candidate)
                    if len(random_terms) >= num_needed:
                        break

            attempts += batch_size

    # Bulk path, also the fallback when rejection sampling came up short.
    remaining_needed = num_needed - len(random_terms)
    if remaining_needed > 0 and len(all_terms_array) > 0:
        available_mask = np.isin(all_terms_array, list(used_terms), invert=True)
        available_terms = all_terms_array[available_mask]

        # np.random.choice raises on an empty population, so only sample when
        # at least one unused term exists.
        if len(available_terms) > 0:
            allow_repeats = len(available_terms) < remaining_needed
            random_terms.extend(
                np.random.choice(
                    available_terms, size=remaining_needed, replace=allow_repeats
                ).tolist()
            )

    return list(terms) + random_terms

def load_data(data_paths, max_terms=256, aspect=None):
    """Load every input needed by the dataset and return them in one dict.

    Args:
        data_paths: Mapping with the keys ``'seq_2_terms_df'`` (parquet),
            ``'train_terms_df'`` (tab-separated file with ``term`` and
            ``aspect`` columns), ``'plm_features_path'``, ``'prot_2_pmid_path'``,
            ``'pmid_2_embed_path'`` (``.npy`` files each holding a pickled dict)
            and ``'go_embeds_paths'`` (pickle with an ``'embeddings'`` entry).
        max_terms: Length every ``terms_predicted`` list is padded/truncated to.
            Only applied when ``aspect`` is given.
        aspect: If not ``None``, keep only terms of this aspect in
            ``terms_predicted`` / ``terms_true``, drop rows left with an empty
            list, and pad ``terms_predicted`` with random same-aspect terms
            that have an embedding.

    Returns:
        dict with keys ``'seq_2_terms'`` (DataFrame restricted to ``qseqid``
        values present in the PLM features), ``'plm_embeds'``,
        ``'prot_2_pmid'``, ``'pmid_2_embed'`` and ``'go_embeds'``.
    """
    # Resolve all paths first so a missing key fails before any expensive load.
    seq_2_terms_df = data_paths['seq_2_terms_df']
    train_terms_df = data_paths['train_terms_df']
    plm_features_path = data_paths['plm_features_path']
    go_embeds_paths = data_paths['go_embeds_paths']
    prot_2_pmid_path = data_paths['prot_2_pmid_path']
    pmid_2_embed_path = data_paths['pmid_2_embed_path']

    seq_2_terms = pd.read_parquet(seq_2_terms_df, engine='fastparquet')
    train_terms = pd.read_csv(train_terms_df, sep='\t')

    term_to_aspect = train_terms.groupby('term')['aspect'].first().to_dict()

    # NOTE(security): pickle.load (and np.load(allow_pickle=True) below) can
    # execute arbitrary code; only point these paths at trusted files.
    with open(go_embeds_paths, 'rb') as f:
        embeddings_dict = pickle.load(f)['embeddings']

    # Filter to keep only terms from a specific aspect if aspect is provided
    print('filtering by aspect:', aspect)
    if aspect is not None:
        def keep_aspect(terms):
            return [t for t in terms if term_to_aspect.get(t) == aspect]

        seq_2_terms['terms_predicted'] = seq_2_terms['terms_predicted'].apply(keep_aspect)
        seq_2_terms['terms_true'] = seq_2_terms['terms_true'].apply(keep_aspect)

        # Drop rows where either list is now empty. The explicit copy keeps the
        # column assignment below from writing into a view of the original frame.
        has_predicted = seq_2_terms['terms_predicted'].apply(len) > 0
        has_true = seq_2_terms['terms_true'].apply(len) > 0
        seq_2_terms = seq_2_terms[has_predicted & has_true].copy()

        # Pad terms with random terms from the same aspect
        print(f"Padding terms_predicted with random terms from aspect {aspect}...")
        # Candidate pool: terms of this aspect that have an embedding
        all_aspect_terms = np.array(
            [term for term, asp in term_to_aspect.items()
             if asp == aspect and term in embeddings_dict]
        )

        tqdm.pandas(desc=f"Padding with random {aspect} terms")
        seq_2_terms['terms_predicted'] = seq_2_terms['terms_predicted'].progress_apply(
            lambda terms: pad_with_random_terms(terms, max_terms, all_aspect_terms)
        )

        # Verify padding
        term_lengths_after = seq_2_terms['terms_predicted'].apply(len)
        print(f"After padding - Min: {term_lengths_after.min()}, Max: {term_lengths_after.max()}, Mean: {term_lengths_after.mean():.2f}")

    plm_embeds_dict = np.load(plm_features_path, allow_pickle=True).item()

    # NOTE: despite the message, no filtering by term length happens here any
    # more; the step below only keeps sequences that have PLM features.
    print("filtering sequences by term lengths")

    train_ids = pd.DataFrame(list(plm_embeds_dict.keys()), columns=['qseqid'])
    seq_2_terms = seq_2_terms.merge(train_ids, on='qseqid', how='inner')

    prot_2_pmid = np.load(prot_2_pmid_path, allow_pickle=True).item()
    pmid_2_embed = np.load(pmid_2_embed_path, allow_pickle=True).item()

    return {
        'seq_2_terms': seq_2_terms,
        'plm_embeds': plm_embeds_dict,
        'prot_2_pmid': prot_2_pmid,
        'pmid_2_embed': pmid_2_embed,
        'go_embeds': embeddings_dict,
    }

In [4]:
# Location of the prepared data packet; change this one line when the data moves.
DATA_DIR = '/mnt/d/ML/Kaggle/CAFA6-new/data_packet1'
# The raw competition file lives outside the data packet.
TRAIN_TERMS_TSV = '/mnt/d/ML/Kaggle/CAFA6/cafa-6-protein-function-prediction/Train/train_terms.tsv'

# Number of GO terms per sequence after padding, and the aspect to keep.
MAX_TERMS = 64
ASPECT = 'F'

data_paths = {
    'seq_2_terms_df':    f'{DATA_DIR}/seq_2_terms.parquet',
    'train_terms_df':    TRAIN_TERMS_TSV,
    'plm_features_path': f'{DATA_DIR}/plm_features.npy',
    'prot_2_pmid_path':  f'{DATA_DIR}/prot_2_pmid.npy',
    'pmid_2_embed_path': f'{DATA_DIR}/pmid_2_embed.npy',
    'go_embeds_paths':   f'{DATA_DIR}/go_embeddings.pkl',
}

data = load_data(data_paths, max_terms=MAX_TERMS, aspect=ASPECT)
data['seq_2_terms'].shape

filtering by aspect: F
Padding terms_predicted with random terms from aspect F...


Padding with random F terms:   0%|          | 0/57960 [00:00<?, ?it/s]

After padding - Min: 64, Max: 64, Mean: 64.00
filtering sequences by term lengths


(57960, 3)

In [8]:
# Sanity check: every terms_predicted list should have the padded length
term_lengths = data['seq_2_terms']['terms_predicted'].apply(len)
first_row_terms = data['seq_2_terms'].iloc[0]['terms_predicted']

report_lines = (
    f"Min terms: {term_lengths.min()}",
    f"Max terms: {term_lengths.max()}",
    f"Mean terms: {term_lengths.mean():.2f}",
    f"Median terms: {term_lengths.median():.2f}",
    f"No. empty predicted terms: {(term_lengths == 0).sum()}",
    f"\nSample row with padded terms:",
    f"Number of predicted terms: {len(first_row_terms)}",
)
for line in report_lines:
    print(line)

Min terms: 64
Max terms: 64
Mean terms: 64.00
Median terms: 64.00
No. empty predicted terms: 0

Sample row with padded terms:
Number of predicted terms: 64


## Dataset and Dataloader

In [25]:
class EmbeddingsDataset(Dataset):
    """Dataset that yields raw embeddings; tokenization is done in collate_fn for batching.

    Args:
        data: dict as returned by ``load_data`` with the keys ``'seq_2_terms'``
            (DataFrame with ``qseqid``, ``terms_predicted``, ``terms_true``),
            ``'plm_embeds'``, ``'prot_2_pmid'``, ``'pmid_2_embed'`` and
            ``'go_embeds'``.
        max_go_embeds: Stored on the instance; it is not enforced here, the
            predicted term lists are used at whatever length they have.
        oversample_indices: Optional list of row positions into
            ``data['seq_2_terms']`` (repeats allowed). Defaults to every row once.

    Raises:
        ValueError: if ``data['pmid_2_embed']`` contains no usable embedding,
            so the BLM dimension cannot be determined.
    """
    def __init__(self,
                 data,
                 max_go_embeds = 256,
                 oversample_indices=None
                ):

        self.data = data
        self.max_go_embeds = max_go_embeds
        if oversample_indices is None:
            oversample_indices = list(range(len(self.data['seq_2_terms'])))
        self.oversample_indices = oversample_indices

        # All-zero stand-in for predicted terms that have no GO embedding.
        first_go_embed = next(iter(self.data['go_embeds'].values()))
        self.mask_embed = np.zeros(first_go_embed.shape, dtype=np.float32)

        self.plm_dim = next(iter(self.data['plm_embeds'].values())).shape[0]

        # __getitem__ tolerates None entries in pmid_2_embed, so skip them here too
        # instead of failing on the first value.
        first_blm_embed = next(
            (embed for embed in self.data['pmid_2_embed'].values() if embed is not None),
            None,
        )
        if first_blm_embed is None:
            raise ValueError("pmid_2_embed contains no embeddings; cannot infer BLM dim")
        self.blm_dim = first_blm_embed.shape[0]

        print(f"PLM dim: {self.plm_dim}, BLM dim: {self.blm_dim}")

    def __len__(self):
        return len(self.oversample_indices)

    def __getitem__(self, idx):
        """Return the raw (un-collated) sample for position ``idx``.

        Returns a dict with ``entryID``, ``plm_embed``, ``blm_embeds``
        (``[num_pmids_with_embedding, blm_dim]``, possibly zero rows),
        ``go_embed`` (one row per predicted term), ``label`` (1.0 where the
        predicted term is in ``terms_true``), ``predicted_terms`` and
        ``true_terms``.
        """
        sample_idx = self.oversample_indices[idx]

        row = self.data['seq_2_terms'].iloc[sample_idx]
        qseqid = row['qseqid']

        plm_embed = self.data['plm_embeds'][qseqid]

        true_terms_set = set(row['terms_true'])
        predicted_terms = row['terms_predicted']

        # Terms without an embedding fall back to the all-zero mask embedding.
        go_lookup = self.data['go_embeds']
        if len(predicted_terms) > 0:
            go_embeds = np.array([go_lookup.get(term, self.mask_embed) for term in predicted_terms])
        else:
            # Keep the [num_terms, go_dim] rank even when there are no terms.
            go_embeds = np.zeros((0,) + self.mask_embed.shape, dtype=np.float32)
        label = np.array([term in true_terms_set for term in predicted_terms], dtype=np.float32)

        # Collect the embeddings of this protein's PMIDs, skipping PMIDs that
        # are unknown or map to None (single lookup per PMID).
        pmid_lookup = self.data['pmid_2_embed']
        valid_blm_embeds = []
        for pmid in self.data['prot_2_pmid'].get(qseqid, []):
            embed = pmid_lookup.get(pmid)
            if embed is not None:
                valid_blm_embeds.append(embed)

        if valid_blm_embeds:
            blm_embeds = np.vstack(valid_blm_embeds)
        else:
            # Empty array with shape [0, blm_dim] so downstream code can still index it.
            blm_embeds = np.zeros((0, self.blm_dim), dtype=np.float32)

        return {
            'entryID'   :       qseqid,
            'plm_embed' :       plm_embed,
            'blm_embeds':       blm_embeds,
            'go_embed'  :       go_embeds,
            'label'     :       label,
            'predicted_terms':  predicted_terms,
            'true_terms':       row['terms_true']
        }

In [31]:
dataset = EmbeddingsDataset(data, max_go_embeds=64)
sample = dataset[100]

PLM dim: 1280, BLM dim: 768


In [32]:
def collate_with_blm_projection(batch, blm_projection_layer, tokenizer, device=None, dtype=torch.float32, num_plm_tokens=32, num_blm_tokens=32):
    """
    Collate raw dataset samples into one batch of PLM + BLM tokens.

    PLM embeddings are stacked and passed through ``tokenizer``; each sample's
    BLM embeddings are capped at ``num_blm_tokens`` rows, projected with
    ``blm_projection_layer`` and zero-padded to ``num_blm_tokens``. The two
    token blocks are concatenated along the token axis.

    Args:
        batch: List of sample dicts with the keys ``entryID``, ``plm_embed``,
            ``blm_embeds``, ``go_embed``, ``label``, ``predicted_terms`` and
            ``true_terms`` (numpy arrays for the embedding/label entries).
        blm_projection_layer: nn.Linear projecting BLM features to model_dim
            (read from its ``out_features``).
        tokenizer: Callable applied to the stacked PLM embeddings - REQUIRED.
        device: Device to move tensors to (None keeps them on CPU).
        dtype: Target dtype for the floating-point tensors.
        num_plm_tokens: Expected number of PLM tokens. Kept for interface
            compatibility; the mask length is taken from the tokenizer output
            so that mask and features always agree.
        num_blm_tokens: Number of BLM token slots per sample (default: 32).

    Returns:
        Dictionary with:
            - entryID: List of entry IDs
            - features: PLM tokens followed by BLM tokens [batch, plm + blm, model_dim]
            - mask: bool attention mask [batch, plm + blm]; PLM positions are
              always True, BLM positions are True only where a real embedding
              was placed
            - go_embed: stacked GO embeddings
            - label: stacked labels
            - predicted_terms: List of predicted terms
            - true_terms: List of true terms
    """
    def _to_target(tensor):
        # Cast to the requested dtype, then move only if a device was requested.
        tensor = tensor.to(dtype=dtype)
        return tensor if device is None else tensor.to(device)

    batch_size = len(batch)
    model_dim = blm_projection_layer.out_features

    # PLM branch: stack, cast/move, tokenize
    plm_embeds = _to_target(torch.stack([torch.from_numpy(item['plm_embed']) for item in batch]))
    plm_tokens = tokenizer(plm_embeds)

    # BLM branch: allocate the padded buffers directly on the target device
    blm_embeds_padded = torch.zeros(batch_size, num_blm_tokens, model_dim, dtype=dtype, device=device)
    blm_attention_mask = torch.zeros(batch_size, num_blm_tokens, dtype=torch.bool, device=device)

    for i, item in enumerate(batch):
        raw_blm = item['blm_embeds']
        actual_tokens = min(raw_blm.shape[0], num_blm_tokens)
        if actual_tokens == 0:
            continue  # sample has no BLM embeddings: slots stay zero / masked out

        # Slice before converting so rows beyond the cap are never copied.
        blm = _to_target(torch.from_numpy(raw_blm[:actual_tokens]))

        # NOTE(review): no_grad means the projection layer receives no gradient
        # through this path, while the tokenizer call above is not wrapped.
        # Confirm that a fixed projection is intended.
        with torch.no_grad():
            blm_embeds_padded[i, :actual_tokens] = blm_projection_layer(blm)
        blm_attention_mask[i, :actual_tokens] = True

    features = torch.cat([plm_tokens, blm_embeds_padded], dim=1)

    # PLM tokens are always valid. Size the mask from the actual tokenizer
    # output so it cannot drift from the feature tensor.
    plm_attention_mask = torch.ones(batch_size, plm_tokens.shape[1], dtype=torch.bool, device=device)
    mask = torch.cat([plm_attention_mask, blm_attention_mask], dim=1)

    # GO embeddings and labels
    go_embed = _to_target(torch.stack([torch.from_numpy(item['go_embed']) for item in batch]))
    label = _to_target(torch.stack([torch.from_numpy(item['label']) for item in batch]))

    return {
        'entryID': [item['entryID'] for item in batch],
        'features': features,
        'mask': mask,
        'go_embed': go_embed,
        'label': label,
        'predicted_terms': [item['predicted_terms'] for item in batch],
        'true_terms': [item['true_terms'] for item in batch]
    }


class PrefetchLoaderWithBLM:
    """
    Wraps a DataLoader and turns each raw batch into model-ready tensors via
    ``collate_with_blm_projection`` (PLM tokenization + BLM projection).

    On a CUDA device the next batch is prepared on a side stream while the
    current one is being consumed; on CPU batches are processed one by one.

    Args:
        dataloader: Iterable of raw batches (lists of dataset samples).
        device: ``torch.device`` or device string such as ``'cpu'`` / ``'cuda'``.
        blm_projection_layer: Projection module for BLM features; moved to ``device``.
        tokenizer: PLM tokenizer module; moved to ``device``.
        num_plm_tokens: Forwarded to the collate function.
        num_blm_tokens: Forwarded to the collate function.
        max_prefetch: Stored but currently unused; exactly one batch is prefetched.
    """
    def __init__(self, dataloader, device, blm_projection_layer, tokenizer, num_plm_tokens=32, num_blm_tokens=32, max_prefetch=1):
        self.dataloader = dataloader
        # Normalise so plain strings work as well as torch.device objects.
        self.device = torch.device(device)
        self.num_plm_tokens = num_plm_tokens
        self.num_blm_tokens = num_blm_tokens
        self.max_prefetch = max_prefetch
        # Bind the side stream to the requested device, not the current one.
        self.stream = torch.cuda.Stream(device=self.device) if self.device.type == 'cuda' else None

        # Both modules must live on the device the batches are created on.
        self.blm_projection_layer = blm_projection_layer.to(self.device)
        self.tokenizer = tokenizer.to(self.device)

    def __iter__(self):
        if self.stream is not None:
            return self._cuda_iter()
        return self._cpu_iter()

    def _cpu_iter(self):
        """Iterator without prefetching for CPU."""
        for raw_batch in self.dataloader:
            yield self._process_batch(raw_batch)

    def _prefetch(self, loader_iter):
        """Fetch and process the next batch on the side stream; None when exhausted."""
        # Only next() may signal exhaustion; errors from processing propagate.
        try:
            raw_batch = next(loader_iter)
        except StopIteration:
            return None
        with torch.cuda.stream(self.stream):
            return self._process_batch(raw_batch)

    def _cuda_iter(self):
        """Iterator with CUDA stream prefetching."""
        loader_iter = iter(self.dataloader)
        next_batch = self._prefetch(loader_iter)

        while next_batch is not None:
            # Make the consumer stream wait until the prefetched batch is ready.
            current_stream = torch.cuda.current_stream(self.device)
            current_stream.wait_stream(self.stream)
            batch = next_batch

            # Tell the allocator these tensors are in use on the consumer stream.
            if isinstance(batch, dict):
                for value in batch.values():
                    if isinstance(value, torch.Tensor) and value.is_cuda:
                        value.record_stream(current_stream)

            # Start preparing the following batch before handing this one out.
            next_batch = self._prefetch(loader_iter)

            yield batch
            del batch

    def _process_batch(self, batch):
        """Process batch with PLM tokenization, BLM projection, and move to device."""
        return collate_with_blm_projection(
            batch,
            self.blm_projection_layer,
            tokenizer=self.tokenizer,
            device=self.device,
            # Match the projection layer's parameter dtype.
            dtype=next(self.blm_projection_layer.parameters()).dtype,
            num_plm_tokens=self.num_plm_tokens,
            num_blm_tokens=self.num_blm_tokens
        )

    def __len__(self):
        return len(self.dataloader)

In [33]:
from Utils.tokenizer import EmbedTokenizer

# PLM tokenizer shared by the loader cells below.
# NOTE(review): the positional arguments are presumably
# (plm_dim=1280, model_dim=512, num_tokens=32) — confirm against Utils.tokenizer.EmbedTokenizer.
tokenizer = EmbedTokenizer(1280, 512, 32)

In [38]:
# Example usage: wire the dataset into a DataLoader and the prefetching wrapper.

# Feature widths come from the dataset; model_dim is the shared target width.
model_dim = 512
blm_dim = dataset.blm_dim
plm_dim = dataset.plm_dim

# Linear map from BLM feature width to model_dim
blm_projection = torch.nn.Linear(blm_dim, model_dim)

# The PLM tokenizer (32 tokens) is expected to exist already, e.g.
# EmbedTokenizer(1280, 512, 32) from the cell above.
# tokenizer = EmbedTokenizer(plm_dim, model_dim, 32)


def simple_collate(batch):
    """Identity collate: pass the list of samples through without stacking."""
    return batch


# The DataLoader only groups samples; real collation happens in PrefetchLoaderWithBLM.
train_dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,  # single-process loading
    pin_memory=torch.cuda.is_available(),
    collate_fn=simple_collate,
)

# PLM tokenization (32) + BLM projection (32) = 64 tokens per sample
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = PrefetchLoaderWithBLM(
    train_dataloader,
    device=device,
    blm_projection_layer=blm_projection,
    tokenizer=tokenizer,  # REQUIRED: tokenizer for PLM features (32 tokens)
    num_plm_tokens=32,    # 32 PLM tokens
    num_blm_tokens=32,    # 32 BLM tokens (padded if fewer)
)

# Iterating the loader yields dict batches:
# for batch in train_loader:
#     features = batch['features']  # [batch, 64, 512] - 32 PLM + 32 BLM tokens
#     mask = batch['mask']          # [batch, 64] - combined attention mask
#     go_embed = batch['go_embed']  # [batch, num_terms, go_dim]
#     label = batch['label']        # [batch, num_terms]

In [41]:
# Smoke test: pull a single batch through the full loader pipeline.
batch = next(iter(train_loader))

In [42]:
# Shapes of the collated tensors: features, mask, GO embeddings, labels.
batch['features'].shape, batch['mask'].shape, batch['go_embed'].shape, batch['label'].shape

(torch.Size([32, 64, 512]),
 torch.Size([32, 64]),
 torch.Size([32, 64, 512]),
 torch.Size([32, 64]))

## Analysis: PMIDs per protein

In [10]:
# Collect the PMID list of every protein that has an entry in prot_2_pmid
pmid_lists = list(data['prot_2_pmid'].values())

# Number of PMIDs attached to each protein
num_pmids_per_protein = [len(pmids) for pmids in pmid_lists]

print(f"Number of proteins with PMID annotations: {len(pmid_lists)}")
print(f"\nPMIDs per protein statistics:")
print(f"Mean: {np.mean(num_pmids_per_protein):.2f}")
print(f"Min: {np.min(num_pmids_per_protein)}")
print(f"Max: {np.max(num_pmids_per_protein)}")
print(f"Median: {np.median(num_pmids_per_protein):.2f}")
print(f"Std: {np.std(num_pmids_per_protein):.2f}")

# Distribution over inclusive, non-overlapping buckets. Labels are generated
# from the same bounds that are counted, so they cannot drift apart.
pmid_buckets = [("0", 0, 0), ("1-5", 1, 5), ("6-10", 6, 10), ("11-64", 11, 64)]

print(f"\nDistribution:")
for bucket_label, low, high in pmid_buckets:
    in_bucket = sum(1 for n in num_pmids_per_protein if low <= n <= high)
    print(f"Proteins with {bucket_label} PMIDs: {in_bucket}")
print(f"Proteins with >64 PMIDs: {sum(1 for n in num_pmids_per_protein if n > 64)}")

Number of proteins with PMID annotations: 82404

PMIDs per protein statistics:
Mean: 6.97
Min: 0
Max: 224
Median: 5.00
Std: 7.50

Distribution:
Proteins with 0 PMIDs: 589
Proteins with 1-5 PMIDs: 45223
Proteins with 6-10 PMIDs: 22675
Proteins with 10 -64 PMIDs: 13756
Proteins with >64 PMIDs: 161


In [11]:
# Count proteins with more than 32 PMIDs; 32 is the default num_blm_tokens cap
# in collate_with_blm_projection, so these proteins have BLM rows cut off.
print(f"Proteins with >32 PMIDs: {sum(1 for n in num_pmids_per_protein if n > 32)}")

Proteins with >32 PMIDs: 1106


## Integration Complete ✅

Successfully integrated the new dataset and dataloader with BLM embeddings into the training pipeline:

### Changes Made:

1. **EmbeddingsDataset.py**:
   - Updated `EmbeddingsDataset` to return both `plm_embed` and `blm_embeds`
   - Added `simple_collate` function to avoid premature stacking
   - Added `collate_with_blm_projection` for PLM tokenization + BLM projection
   - Replaced `PrefetchLoader` with `PrefetchLoaderWithBLM` (32 PLM + 32 BLM = 64 tokens)

2. **train.py**:
   - Updated imports to use new loader classes
   - Created BLM projection layer (`blm_projection`)
   - Modified tokenizer to only handle PLM (32 tokens instead of 64)
   - Updated dataloaders to use `simple_collate` and `PrefetchLoaderWithBLM`

3. **configs_new.json**:
   - Changed `features_embeds_path` → `plm_features_path`
   - Added `prot_2_pmid_path` and `pmid_2_embed_path` for BLM data
   - Removed `features_ids_path` (not needed with .npy dict format)

4. **Dataset/Utils.py**:
   - Updated `load_data` to load PLM, BLM mappings, and PMID embeddings
   - Changed output keys: `features_embeds` → `plm_embeds`
   - Added `prot_2_pmid` and `pmid_2_embed` to output

### Output Format:
- **features**: `[batch, 64, 512]` - 32 PLM tokens + 32 BLM tokens
- **mask**: `[batch, 64]` - Combined attention mask
- **go_embed**: `[batch, num_terms, go_dim]`
- **label**: `[batch, num_terms]`

Ready to train with: `python train.py --config configs_new.json`