In [1]:
import os

os.chdir("/home/robsyc/Desktop/thesis/MB-VAE-DTI")



In [2]:
import h5torch

from mb_vae_dti.processing import load_h5torch_DTI

04:19:00 - INFO - Old pandas version detected. Patching DataFrame.map to DataFrame.applymap


In [3]:
file_path = "data/processed/data.h5torch"
with h5torch.File(file_path, "r") as f:
    for key in f.keys():
        print(key)
        print(f[key].keys())

0
<KeysViewHDF5 ['Drug_ID', 'Drug_InChIKey', 'Drug_SMILES']>
1
<KeysViewHDF5 ['Target_AA', 'Target_DNA', 'Target_Gene_name', 'Target_ID', 'Target_RefSeq_ID', 'Target_UniProt_ID']>
central
<KeysViewHDF5 ['data', 'indices']>
unstructured
<KeysViewHDF5 ['Y_KIBA', 'Y_pKd', 'Y_pKi', 'in_BindingDB_Kd', 'in_BindingDB_Ki', 'in_DAVIS', 'in_KIBA', 'in_Metz', 'split_cold', 'split_rand']>


In [4]:
with h5torch.File(file_path, "r") as f:
    print(f["1/Target_fp"][100][:20])
    print(f["1/Target_fp"][100].shape)

KeyError: "Unable to synchronously open object (object 'Target_fp' doesn't exist)"

In [5]:
test_davis_metz = load_h5torch_DTI(
    setting="split_cold",
    split="test",
    datasets=["in_DAVIS", "in_Metz", "in_BindingDB_Kd"]
)
test_davis_metz[153]

Using boolean mask for mapping (7145 indices)
Verified alignment: all unstructured data has 396469 elements


{'central': True,
 '0/Drug_ID': 'D000028',
 '0/Drug_InChIKey': 'XZXHXSATPCNXJR-ZIADKAODSA-N',
 '0/Drug_SMILES': 'COC(=O)c1ccc2c(c1)NC(=O)C2=C(Nc1ccc(N(C)C(=O)CN2CCN(C)CC2)cc1)c1ccccc1',
 '1/Target_AA': 'MEQPPAPKSKLKKLSEDSLTKQPEEVFDVLEKLGEGSYGSVFKAIHKESGQVVAIKQVPVESDLQEIIKEISIMQQCDSPYVVKYYGSYFKNTDLWIVMEYCGAGSVSDIIRLRNKTLIEDEIATILKSTLKGLEYLHFMRKIHRDIKAGNILLNTEGHAKLADFGVAGQLTDTMAKRNTVIGTPFWMAPEVIQEIGYNCVADIWSLGITSIEMAEGKPPYADIHPMRAIFMIPTNPPPTFRKPELWSDDFTDFVKKCLVKNPEQRATATQLLQHPFIKNAKPVSILRDLITEAMEIKAKRHEEQQRELEEEEENSDEDELDSHTMVKTSVESVGTMRATSTMSEGAQTMIEHNSTMLESDLGTMVINSEDEEEEDGTMKRNATSPQVQRPSFMDYFDKQDFKNKSHENCNQNMHEPFPMSKNVFPDNWKVPQDGDFDFLKNLSLEELQMRLKALDPMMEREIEELRQRYTAKRQPILDAMDAKKRRQQNF',
 '1/Target_DNA': 'AGTAAACTAAAAAAGCTGAGTGAAGACAGTTTGACTAAGCAGCCTGAAGAAGTTTTTGATGTATTAGAGAAGCTTGGAGAAGGGTCTTATGGAAGTGTATTTAAAGCAATACACAAGGAATCCGGTCAAGTTGTCGCAATTAAACAAGTACCTGTTGAATCAGATCTTCAGGAAATAATCAAAGAAATTTCCATAATGCAGCAATGTGACAGCCCATATGTTGTAAAGTACTATGGCAGTTATTTTAAGAATACAGACCTCTGGATTGTTATGGAGTACTGTGGC

In [6]:
aa = test_davis_metz[153]['1/Target_AA']
aa

'MEQPPAPKSKLKKLSEDSLTKQPEEVFDVLEKLGEGSYGSVFKAIHKESGQVVAIKQVPVESDLQEIIKEISIMQQCDSPYVVKYYGSYFKNTDLWIVMEYCGAGSVSDIIRLRNKTLIEDEIATILKSTLKGLEYLHFMRKIHRDIKAGNILLNTEGHAKLADFGVAGQLTDTMAKRNTVIGTPFWMAPEVIQEIGYNCVADIWSLGITSIEMAEGKPPYADIHPMRAIFMIPTNPPPTFRKPELWSDDFTDFVKKCLVKNPEQRATATQLLQHPFIKNAKPVSILRDLITEAMEIKAKRHEEQQRELEEEEENSDEDELDSHTMVKTSVESVGTMRATSTMSEGAQTMIEHNSTMLESDLGTMVINSEDEEEEDGTMKRNATSPQVQRPSFMDYFDKQDFKNKSHENCNQNMHEPFPMSKNVFPDNWKVPQDGDFDFLKNLSLEELQMRLKALDPMMEREIEELRQRYTAKRQPILDAMDAKKRRQQNF'

In [7]:
from mb_vae_dti.processing.embed_helper import get_target_fingerprint

aa_fp = get_target_fingerprint(aa)
print(aa_fp.shape)
print(aa_fp[:20])

(4170,)
[0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]


In [8]:
import h5torch
import numpy as np
import tqdm
from typing import Callable, Tuple, Optional, List, Union

def add_processed_feature(
    file_path: str,
    entity_path: str,
    process_func: Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]],
    feature_name: str,
    batch_size: int = 1,
    dtype_save: str = "float32",
    dtype_load: str = "float32",
    overwrite: bool = False
):
    """
    Generic function to add processed features to h5torch file with batch processing support.
    
    Args:
        file_path (str): Path to the h5torch file
        entity_path (str): Path to entity in h5torch file (e.g. "1/Target_AA")
        process_func (Callable): Function to process entities, can handle batches if batch_size > 1
        feature_name (str): Name for the new feature
        batch_size (int): Number of entities to process at once (1 for no batching)
        dtype_save (str): Data type to save as
        dtype_load (str): Data type to load as
        overwrite (bool): Whether to overwrite existing feature if it exists
    """
    # Parse entity path to get axis
    parts = entity_path.split("/")
    axis = parts[0]
    
    # Open the file in append mode
    f = h5torch.File(file_path, "a")
    
    # Get the axis number if it's a digit
    axis_num = None
    if axis.isdigit():
        axis_num = int(axis)
    
    # Check if feature exists and should be overwritten
    feature_path = f"{axis}/{feature_name}"
    if feature_path in f:
        if overwrite:
            print(f"Found existing {feature_name}, deleting it")
            del f[feature_path]
        else:
            print(f"Feature {feature_name} already exists and overwrite=False. Skipping.")
            f.close()
            return
    
    # Get all entity values
    entities = f[entity_path][:]
    
    # Determine if entities need decoding (from bytes to str)
    needs_decoding = isinstance(entities[0], bytes)
    
    # Process a sample to determine output shape
    sample_entity = entities[0].decode('utf-8') if needs_decoding else entities[0]
    sample_result = process_func(sample_entity)
    result_shape = sample_result.shape if hasattr(sample_result, 'shape') else (1,)
    
    # Create array to store processed results
    n_entities = len(entities)
    output_array = np.zeros((n_entities, *result_shape), dtype=np.float32)
    
    # Adjust batch size if needed
    effective_batch_size = min(batch_size, n_entities)
    
    print(f"Processing {n_entities} entities with batch size {effective_batch_size}, output dimension {result_shape}")
    
    # Process entities in batches
    num_batches = (n_entities + effective_batch_size - 1) // effective_batch_size
    
    for batch_idx in tqdm.tqdm(range(num_batches), desc=f"Processing {feature_name} batches"):
        # Get batch indices
        start_idx = batch_idx * effective_batch_size
        end_idx = min(start_idx + effective_batch_size, n_entities)
        
        # Prepare batch of entities
        batch_entities = entities[start_idx:end_idx]
        
        # Decode if needed
        if needs_decoding:
            batch_entities = [e.decode('utf-8') for e in batch_entities]
        
        # Process batch
        if effective_batch_size == 1:
            # Single item processing
            batch_results = [process_func(batch_entities[0])]
        else:
            # True batch processing
            batch_results = process_func(batch_entities)
            if not isinstance(batch_results, list):
                # If process_func returns a single array for the whole batch
                # Split it into individual results
                batch_results = [batch_results[i] for i in range(len(batch_entities))]
        
        # Store results
        for i, result in enumerate(batch_results):
            idx = start_idx + i
            if idx >= n_entities:
                break
                
            if result_shape == (1,):  # Handle scalar results
                output_array[idx, 0] = result
            else:
                output_array[idx] = result
    
    # Register the processed features
    f.register(
        output_array, 
        mode="N-D", 
        axis=axis_num if axis_num is not None else axis, 
        name=feature_name, 
        dtype_save=dtype_save, 
        dtype_load=dtype_load
    )
    
    # Close the file
    f.close()
    print(f"Successfully added {feature_name} to {file_path}")

In [9]:
from mb_vae_dti.processing.embed_helper import get_target_fingerprint

# Example 1: For a function that can process batches:
def batch_target_fingerprint(sequences):
    """Process multiple sequences at once"""
    if isinstance(sequences, str):
        return get_target_fingerprint(sequences)
    else:
        # Process a batch of sequences
        return [get_target_fingerprint(seq) for seq in sequences]

# Using batch processing (process 32 sequences at a time)
add_processed_feature(
    file_path="data/processed/data.h5torch",
    entity_path="1/Target_AA",
    process_func=batch_target_fingerprint,
    feature_name="Target_fp",
    batch_size=32,
    overwrite=True
)

# Example 2: If you had a drug fingerprint function
from mb_vae_dti.processing.embed_helper import get_drug_fingerprint

add_processed_feature(
    file_path="data/processed/data.h5torch",
    entity_path="0/Drug_SMILES",
    process_func=get_drug_fingerprint,
    feature_name="Drug_fp"
)

Processing 2047 entities with batch size 32, output dimension (4170,)


Processing Target_fp batches: 100%|██████████| 64/64 [00:15<00:00,  4.08it/s]


Successfully added Target_fp to data/processed/data.h5torch
Processing 149962 entities with batch size 1, output dimension (2048,)


Processing Drug_fp batches: 100%|██████████| 149962/149962 [02:12<00:00, 1135.22it/s]


Successfully added Drug_fp to data/processed/data.h5torch
