In [2]:
import torch

In [3]:
# Load the CLN025 dataset
import matplotlib.pyplot as plt
import numpy as np
import os
from pathlib import Path

# Load the CLN025 dataset
dataset_path = "/home/shpark/prj-mlcv/lib/DESRES/DESRES-Trajectory_CLN025-0-protein/CLN025-0-cad.pt"
print(f"Loading dataset from: {dataset_path}")

# Load the dataset
dataset = torch.load(dataset_path, map_location="cpu")
print(f"Dataset type: {type(dataset)}")
print(f"Dataset keys: {dataset.keys() if isinstance(dataset, dict) else 'Not a dict'}")

# Examine the structure
if isinstance(dataset, dict):
    for key, value in dataset.items():
        if isinstance(value, torch.Tensor):
            print(f"{key}: shape={value.shape}, dtype={value.dtype}")
        else:
            print(f"{key}: type={type(value)}")
elif isinstance(dataset, torch.Tensor):
    print(f"Dataset shape: {dataset.shape}, dtype: {dataset.dtype}")
else:
    print(f"Dataset structure: {type(dataset)}")


Loading dataset from: /home/shpark/prj-mlcv/lib/DESRES/DESRES-Trajectory_CLN025-0-protein/CLN025-0-cad.pt
Dataset type: <class 'torch.Tensor'>
Dataset keys: Not a dict
Dataset shape: torch.Size([534743, 45]), dtype: torch.float32


In [4]:
# Find and list all available jit models for CLN025
model_dir = Path("/home/shpark/prj-mlcv/lib/bioemu/opes/model")
baseline_dir = model_dir / "_baseline_"

# Find all CLN025 jit models
cln025_jit_models = []

# Check baseline models
if baseline_dir.exists():
    baseline_models = list(baseline_dir.glob("*CLN025-jit.pt"))
    cln025_jit_models.extend(baseline_models)
    print("Baseline models:")
    for model in baseline_models:
        print(f"  {model.name}")

# Check your models  
your_models = list(model_dir.glob("*CLN025*jit.pt"))
your_models = [m for m in your_models if m.parent == model_dir]  # exclude baseline dir
cln025_jit_models.extend(your_models)

print("\nYour models:")
for model in your_models:
    print(f"  {model.name}")

print(f"\nTotal CLN025 jit models found: {len(cln025_jit_models)}")
for model in cln025_jit_models:
    print(f"  {model}")


Baseline models:
  tica-CLN025-jit.pt
  tda-CLN025-jit.pt
  vde-CLN025-jit.pt
  tae-CLN025-jit.pt

Your models:
  0816_171833-CLN025-jit.pt
  0816_171833-CLN025-f32-jit.pt

Total CLN025 jit models found: 6
  /home/shpark/prj-mlcv/lib/bioemu/opes/model/_baseline_/tica-CLN025-jit.pt
  /home/shpark/prj-mlcv/lib/bioemu/opes/model/_baseline_/tda-CLN025-jit.pt
  /home/shpark/prj-mlcv/lib/bioemu/opes/model/_baseline_/vde-CLN025-jit.pt
  /home/shpark/prj-mlcv/lib/bioemu/opes/model/_baseline_/tae-CLN025-jit.pt
  /home/shpark/prj-mlcv/lib/bioemu/opes/model/0816_171833-CLN025-jit.pt
  /home/shpark/prj-mlcv/lib/bioemu/opes/model/0816_171833-CLN025-f32-jit.pt


In [5]:
# Load all models and compute outputs
model_outputs = {}
model_info = {}

# Prepare input data - check what format the models expect
print("Dataset structure analysis:")
if isinstance(dataset, dict):
    # Try common keys that might contain the coordinate data
    input_data = None
    for key in ['pos', 'positions', 'coords', 'coordinates', 'cad']:
        if key in dataset:
            input_data = dataset[key]
            print(f"Using '{key}' as input data: shape={input_data.shape}")
            break
    
    if input_data is None:
        # Use the first tensor we find
        for key, value in dataset.items():
            if isinstance(value, torch.Tensor):
                input_data = value
                print(f"Using '{key}' as input data: shape={input_data.shape}")
                break
elif isinstance(dataset, torch.Tensor):
    input_data = dataset
    print(f"Using dataset directly: shape={input_data.shape}")

print(f"\nInput data shape: {input_data.shape}")
print(f"Input data dtype: {input_data.dtype}")
print(f"Input data range: [{input_data.min():.4f}, {input_data.max():.4f}]")

# Take a subset for faster testing (first 1000 frames if available)
if len(input_data) > 1000:
    input_subset = input_data[:1000]
    print(f"Using subset for testing: {input_subset.shape}")
else:
    input_subset = input_data
    print(f"Using full dataset: {input_subset.shape}")


Dataset structure analysis:
Using dataset directly: shape=torch.Size([534743, 45])

Input data shape: torch.Size([534743, 45])
Input data dtype: torch.float32
Input data range: [0.3508, 3.0170]
Using subset for testing: torch.Size([1000, 45])


In [None]:
# Import additional dependencies for compute_cv_values
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

def compute_cv_values(
    mlcv_model,
    cad_torch,
    model_type,
    reference_cad=None,
    batch_size=10000,
):
    """Compute CV values from the model with optional sign flipping using batch processing."""
    dataset = TensorDataset(cad_torch)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    print(f"Computing CV values in batches of {batch_size}...")
    
    with torch.no_grad():
        sample_batch = next(iter(dataloader))[0]
        sample_output = mlcv_model(sample_batch)
        output_dim = sample_output.shape[1]
    cv_batches = torch.zeros((len(cad_torch), output_dim)).to(cad_torch.device)
    
    with torch.no_grad():
        for batch_idx, (batch_data,) in enumerate(tqdm(
            dataloader,
            desc="Computing CV values",
            total=len(dataloader),
            leave=False,
        )):
            batch_cv = mlcv_model(batch_data)
            start_idx = batch_idx * batch_size
            end_idx = start_idx + batch_cv.shape[0]  # Handle last batch size correctly
            cv_batches[start_idx:end_idx] = batch_cv
    
    cv = cv_batches.detach().cpu().numpy()
    MLCV_DIM = cv.shape[1]
    
    print(f"CV computation complete. Shape: {cv.shape}")
    
    if model_type == "mlcv":
        # Normalize CV values for MLCV
        cv_normalized = np.zeros_like(cv)
        
        for cv_dim in range(MLCV_DIM):
            cv_dim_val = cv[:, cv_dim]
            cv_range_min, cv_range_max = cv_dim_val.min(), cv_dim_val.max()
            cv_range_mean = (cv_range_min + cv_range_max) / 2.0
            cv_range = (cv_range_max - cv_range_min) / 2.0
            cv_normalized[:, cv_dim] = (cv_dim_val - cv_range_mean) / cv_range
        
        cv = cv_normalized
        
    # Additional sign flipping based on reference structure
    if reference_cad is not None:
        with torch.no_grad():
            ref_cv = mlcv_model(torch.from_numpy(reference_cad).to(cad_torch.device))
            ref_cv = ref_cv.detach().cpu().numpy()
        
        # Normalize reference CV the same way
        if model_type == "mlcv":
            ref_cv_normalized = np.zeros_like(ref_cv)
            for cv_dim in range(MLCV_DIM):
                ref_cv_dim_val = ref_cv[:, cv_dim]
                cv_dim_val = cv[:, cv_dim]
                cv_range_min, cv_range_max = cv_dim_val.min(), cv_dim_val.max()
                cv_range_mean = (cv_range_min + cv_range_max) / 2.0
                cv_range = (cv_range_max - cv_range_min) / 2.0
                ref_cv_normalized[:, cv_dim] = (ref_cv_dim_val - cv_range_mean) / cv_range
            ref_cv = ref_cv_normalized
        
        # Flip signs to ensure reference CV is positive
        for cv_dim in range(MLCV_DIM):
            if ref_cv[0, cv_dim] < 0:
                cv[:, cv_dim] = -cv[:, cv_dim]
                print(f"Flipped sign for CV dimension {cv_dim} to ensure positive reference value")
    
    return cv

print("✓ compute_cv_values function defined")
