In [1]:
from wandb_utils import RunInfo, retrieve_best_model_artifact, load_state_dict
import wandb
import pandas as pd
from datetime import datetime

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
from typing import List, Optional


def make_df(
    runs: List[RunInfo],
    since: Optional[datetime] = None,
) -> pd.DataFrame:
    run_data = pd.DataFrame(
        {
            "id": [info.run.id for info in runs],
            "name": [info.run.name for info in runs],
            "split": [info.split_type for info in runs],
            "split_seed": [info.split_seed for info in runs],
            "created_at": [info.run.created_at for info in runs],
            "tags": [info.tags for info in runs],
            "model": [run.model_name for run in runs],
            "run_info": runs
        }
    )
    run_data.created_at = run_data.created_at.apply(datetime.fromisoformat)
    if since is not None:
        run_data = run_data[run_data.created_at >= since]
    return run_data

## Retrieve test preds

In [8]:
runs = RunInfo.fetch(since=datetime(2023, 6, 25))
run_data = make_df(runs)
run_data.head()

Unnamed: 0,id,name,split,split_seed,created_at,tags,model,run_info
0,2nwzsp8f,comic-sun-615,random,4,2023-06-28 16:20:09,[transformer],Covalent Transformer,comic-sun-615(2nwzsp8f)
1,uizeqr17,autumn-grass-614,random,3,2023-06-28 15:26:39,[transformer],Covalent Transformer,autumn-grass-614(uizeqr17)
2,2q7acs57,sage-paper-613,scaffold,4,2023-06-28 14:37:23,[transformer],Covalent Transformer,sage-paper-613(2q7acs57)
3,1xjzm16g,apricot-cherry-612,scaffold,3,2023-06-28 14:09:00,[transformer],Covalent Transformer,apricot-cherry-612(1xjzm16g)
4,3iw1oqo3,eager-waterfall-611,random,2,2023-06-28 13:58:17,[transformer],Covalent Transformer,eager-waterfall-611(3iw1oqo3)


In [9]:
run_data["test/corr"] = run_data["run_info"].apply(lambda r: r.run.summary.get("test/corr", None))
run_data

Unnamed: 0,id,name,split,split_seed,created_at,tags,model,run_info,test/corr
0,2nwzsp8f,comic-sun-615,random,4,2023-06-28 16:20:09,[transformer],Covalent Transformer,comic-sun-615(2nwzsp8f),0.737658
1,uizeqr17,autumn-grass-614,random,3,2023-06-28 15:26:39,[transformer],Covalent Transformer,autumn-grass-614(uizeqr17),0.748805
2,2q7acs57,sage-paper-613,scaffold,4,2023-06-28 14:37:23,[transformer],Covalent Transformer,sage-paper-613(2q7acs57),0.585272
3,1xjzm16g,apricot-cherry-612,scaffold,3,2023-06-28 14:09:00,[transformer],Covalent Transformer,apricot-cherry-612(1xjzm16g),0.579323
4,3iw1oqo3,eager-waterfall-611,random,2,2023-06-28 13:58:17,[transformer],Covalent Transformer,eager-waterfall-611(3iw1oqo3),0.754974
5,1m12unh2,fine-paper-610,pocket,4,2023-06-28 13:50:22,[transformer],Covalent Transformer,fine-paper-610(1m12unh2),0.196236
6,sccvrwh8,balmy-feather-609,pocket,3,2023-06-28 13:05:00,[transformer],Covalent Transformer,balmy-feather-609(sccvrwh8),0.219494
7,ln37plen,giddy-universe-608,scaffold,2,2023-06-28 13:04:52,[transformer],Covalent Transformer,giddy-universe-608(ln37plen),0.591955
8,3ic69qjl,misunderstood-microwave-607,random,1,2023-06-28 12:52:07,[transformer],Covalent Transformer,misunderstood-microwave-607(3ic69qjl),0.748559
9,1ogolv3j,misunderstood-rain-606,pocket,2,2023-06-28 12:29:20,[transformer],Covalent Transformer,misunderstood-rain-606(1ogolv3j),0.418222


In [10]:
run_data = run_data.dropna()

In [11]:
artifact = retrieve_best_model_artifact(run_data.iloc[15].run_info.run)

In [13]:
ckpt = load_state_dict(artifact)

[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [19]:
for x in ckpt["state_dict"].keys():
    print(x)

edge_bias
atomic_num_embedding.weight
lin_atom_features.weight
lin_atom_features.bias
attention_blocks.0.attention.normalizer
attention_blocks.0.attention.lin_query.weight
attention_blocks.0.attention.lin_key_value.weight
attention_blocks.0.attention.lin_bias.weight
attention_blocks.0.attention.lin_out.weight
attention_blocks.0.ln1.weight
attention_blocks.0.ln1.bias
attention_blocks.0.ff.0.weight
attention_blocks.0.ff.0.bias
attention_blocks.0.ff.2.weight
attention_blocks.0.ff.2.bias
attention_blocks.0.ln2.weight
attention_blocks.0.ln2.bias
attention_blocks.0.ff_edge.0.weight
attention_blocks.0.ff_edge.0.bias
attention_blocks.0.ff_edge.2.weight
attention_blocks.0.ff_edge.2.bias
attention_blocks.0.ln3.weight
attention_blocks.0.ln3.bias
attention_blocks.1.attention.normalizer
attention_blocks.1.attention.lin_query.weight
attention_blocks.1.attention.lin_key_value.weight
attention_blocks.1.attention.lin_bias.weight
attention_blocks.1.attention.lin_out.weight
attention_blocks.1.ln1.weight


In [23]:
means = ckpt["state_dict"]["dist_embedding.0.means"]

In [27]:
(means[1:] - means[:-1]).clamp_(max=0)

tensor([-0.0162,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0244,  0.0000,
         0.0000, -0.0145,  0.0000,  0.0000, -0.0027,  0.0000, -0.0270,  0.0000,
        -0.0088,  0.0000, -0.0112,  0.0000,  0.0000,  0.0000, -0.0139,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
        -0.0143,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000, -0.0144,  0.0000,  0.0000,  0.0000,  0.0000,
        -0.0006,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
        -0.0121,  0.0000,  0.0000,  0.0000, -0.0004,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0029,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0002,  0.0000,
         0.0000,  0.0000, -0.0040,  0.0000,  0.0000,  0.0000,  0.0000, -0.0007,
         0.0000,  0.0000,  0.0000,  0.00

In [6]:
def get_pred_artifacts(run_info):
    artifacts = {artifact.name: artifact for artifact in run_info.run.logged_artifacts() if "predictions" in artifact.name}
    return artifacts

In [7]:
run_data["artifacts"] = run_data["run_info"].apply(get_pred_artifacts)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  run_data["artifacts"] = run_data["run_info"].apply(get_pred_artifacts)


In [8]:
run_data = run_data[run_data["artifacts"] != dict()]

In [9]:
run_data["artifact_paths"] = run_data["artifacts"].apply(lambda ad: list(ad.values())[0].download())

[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m: 

In [10]:
import json
from pathlib import Path

dfs = []
for _, eval_item in run_data.iterrows():
    preds_raw = json.loads(Path(eval_item["artifact_paths"]).joinpath("predictions.table.json").read_text())
    preds = pd.DataFrame(data=preds_raw["data"], columns=preds_raw["columns"])
    preds["ident"] = preds["ident"].astype(int)
    preds["model"] = eval_item.model
    preds["split"] = eval_item.split
    preds["split_seed"] = eval_item.split_seed
    dfs.append(preds)

In [11]:
all_results = pd.concat(dfs)
all_results

Unnamed: 0,pred,ident,model,split,split_seed
0,6.861015,13841248,Covalent Transformer,random,4
1,8.251204,17641064,Covalent Transformer,random,4
2,5.887022,17650892,Covalent Transformer,random,4
3,7.638943,17767038,Covalent Transformer,random,4
4,8.907273,6202640,Covalent Transformer,random,4
...,...,...,...,...,...
4401,8.404881,23312098,Transformer,scaffold,0
4402,7.110132,23314854,Transformer,scaffold,0
4403,7.757973,23314980,Transformer,scaffold,0
4404,8.256553,23314980,Transformer,scaffold,0


In [12]:
from kinodata.data.dataset import KinodataDocked
dataset = KinodataDocked()
source = dataset.df
source["ident"] = source["ident"].astype(int)

  from .autonotebook import tqdm as notebook_tqdm


Reading data frame..
Checking for missing pocket mol2 files...


100%|██████████| 2439/2439 [00:00<00:00, 12268.97it/s]


Adding pocket sequences from cached file /Users/joschka/projects/kinodata-docked-rescore/data/raw/pocket_sequences.csv.


In [14]:
eval_data = pd.merge(
    all_results, source[["ident", "activities.standard_value"]], on="ident"
)
eval_data["target"] = eval_data["activities.standard_value"].astype(float)
eval_data

Unnamed: 0,pred,ident,model,split,split_seed,activities.standard_value,target
0,6.861015,13841248,Covalent Transformer,random,4,6.136677,6.136677
1,6.836371,13841248,Covalent Transformer,random,3,6.136677,6.136677
2,6.633666,13841248,Transformer,random,4,6.136677,6.136677
3,6.653460,13841248,Transformer,random,3,6.136677,6.136677
4,8.251204,17641064,Covalent Transformer,random,4,7.995679,7.995679
...,...,...,...,...,...,...,...
108045,7.833258,15152531,Transformer,random,0,9.619789,9.619789
108046,5.179553,20638952,Covalent Transformer,random,0,6.173925,6.173925
108047,5.913575,20638952,Transformer,random,0,6.173925,6.173925
108048,8.231318,19175820,Covalent Transformer,random,0,8.045757,8.045757


In [15]:

eval_data.to_csv("~/projects/kinodata-docked-rescore/eval_data/transformer_eval_data.csv", index=False)