In [1]:
import logging
import os
import typing
from collections.abc import Callable
from pathlib import Path
from typing import Literal

import hydra
import numpy as np
import torch
import yaml
from huggingface_hub import hf_hub_download
from torch_geometric.data.batch import Batch
from tqdm import tqdm


from bioemu.models import DiGConditionalScoreModel
from bioemu.sde_lib import SDE


In [2]:
ckpt_path = "../model/checkpoint.ckpt"
cfg_path = "../model/config.yaml"

with open(cfg_path) as f:
	model_config = yaml.safe_load(f)

In [3]:
model_state = torch.load(ckpt_path, map_location="cpu", weights_only=True)
score_model: DiGConditionalScoreModel = hydra.utils.instantiate(model_config["score_model"])
score_model.load_pretrained(model_state)
sdes: dict[str, SDE] = hydra.utils.instantiate(model_config["sdes"])

Loading pretrained model from {'model_nn.step_emb.dummy': tensor([]), 'model_nn.x1d_proj.0.weight': tensor([ 1.5789,  1.6632,  0.7087,  2.0426,  1.8947,  1.6089,  1.1531,  1.5849,
         1.3797,  1.2903,  1.6233,  1.4974,  1.4384,  1.6689,  1.4800,  1.9022,
         1.4978,  1.7501,  1.9123,  1.1528,  1.3177,  1.5758,  1.4895,  1.9559,
         0.9697,  1.5946,  1.3306,  1.3673,  2.2045,  1.3551,  1.7627,  2.2268,
         1.6442,  1.5997,  1.6451,  1.7478,  1.6228,  1.4556,  1.6747,  1.8330,
         1.2021,  1.5726,  1.5630,  0.9163,  0.4207,  1.6263,  1.4297,  1.5666,
         0.5863,  1.1037,  0.1450,  1.2744,  1.8637,  1.3461,  2.0432,  1.8054,
         2.0542,  1.8250,  1.6202,  1.9141,  1.5349,  1.5607,  0.9694,  1.4380,
         1.6122,  1.3108,  1.3651,  1.4898,  1.6006,  1.8075,  1.4324,  0.8528,
         1.6853,  2.0325,  1.6123,  0.6026,  1.8759,  0.5754,  1.7199,  1.9134,
         1.8842,  1.3549,  1.7653,  0.3734,  0.4770,  1.6190,  1.7670,  1.2457,
         1.3492,  0.

In [14]:
score_model.model_nn.st_module

StructureModuleControl(
  (encoder): SAEncoder(
    (layers): ModuleList(
      (0-7): 8 x SAEncoderLayer(
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): SAAttention(
          (scalar_query): Linear(in_features=512, out_features=512, bias=False)
          (scalar_key): Linear(in_features=512, out_features=512, bias=False)
          (scalar_value): Linear(in_features=512, out_features=512, bias=False)
          (pair_bias): Linear(in_features=256, out_features=32, bias=False)
          (point_query): Linear(in_features=512, out_features=384, bias=False)
          (point_key): Linear(in_features=512, out_features=384, bias=False)
          (point_value): Linear(in_features=512, out_features=768, bias=False)
          (pair_value): Linear(in_features=256, out_features=512, bias=False)
          (fc_out): Linear(in_features=2048, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (norm2): LayerNorm((512

In [None]:
from bioemu.structure_module import SAEncoderLayer

test_score_model = score_model

for i in range(0, 8):
	test_score_model.model_nn.st_module.encoder.add_module(f"control_encoder_{i}", SAEncoderLayer(
        d_model=512,
        d_pair=256,
        n_head=32,
        dim_feedforward=1024,
        dropout=0.1,
    ))

In [19]:
test_score_model.model_nn.st_module.encoder._modules

{'layers': ModuleList(
   (0-7): 8 x SAEncoderLayer(
     (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
     (attn): SAAttention(
       (scalar_query): Linear(in_features=512, out_features=512, bias=False)
       (scalar_key): Linear(in_features=512, out_features=512, bias=False)
       (scalar_value): Linear(in_features=512, out_features=512, bias=False)
       (pair_bias): Linear(in_features=256, out_features=32, bias=False)
       (point_query): Linear(in_features=512, out_features=384, bias=False)
       (point_key): Linear(in_features=512, out_features=384, bias=False)
       (point_value): Linear(in_features=512, out_features=768, bias=False)
       (pair_value): Linear(in_features=256, out_features=512, bias=False)
       (fc_out): Linear(in_features=2048, out_features=512, bias=True)
       (dropout): Dropout(p=0.1, inplace=False)
     )
     (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
     (ffn): FeedForward(
       (ff): Sequential(
      