In [1]:
from wandb_utils import RunInfo
import pandas as pd
from datetime import datetime

In [2]:
from typing import List, Optional


def make_df(
    runs: List[RunInfo],
    since: Optional[datetime] = None,
) -> pd.DataFrame:
    run_data = pd.DataFrame(
        {
            "id": [info.run.id for info in runs],
            "name": [info.run.name for info in runs],
            "split": [info.split_type for info in runs],
            "split_seed": [info.split_seed for info in runs],
            "created_at": [info.run.created_at for info in runs],
            "tags": [info.tags for info in runs],
            "model": [run.model_name for run in runs],
            "run_info": runs
        }
    )
    run_data.created_at = run_data.created_at.apply(datetime.fromisoformat)
    if since is not None:
        run_data = run_data[run_data.created_at >= since]
    return run_data

## Retrieve EGNN (R) test preds

In [3]:
runs = RunInfo.fetch_all()
run_data = make_df(runs)
run_data.head()

Unnamed: 0,id,name,split,split_seed,created_at,tags,model,run_info
0,epq0akyo,lucky-leaf-467,random,4,2023-06-06 02:15:08,[],EGNN (R/12),lucky-leaf-467(epq0akyo)
1,3dr19mv8,drawn-dawn-466,scaffold,4,2023-06-06 01:13:06,[],EGNN (R/12),drawn-dawn-466(3dr19mv8)
2,3m4ecbqg,summer-dew-465,pocket,4,2023-06-06 00:32:50,[],EGNN (R/12),summer-dew-465(3m4ecbqg)
3,2lhqwwwf,vivid-lake-464,random,3,2023-06-05 22:29:38,[],EGNN (R/12),vivid-lake-464(2lhqwwwf)
4,q9cykxhv,whole-sponge-463,scaffold,3,2023-06-05 21:17:05,[],EGNN (R/12),whole-sponge-463(q9cykxhv)


In [4]:
run_data["model"].unique()

array(['EGNN (R/12)', 'DTI', '', 'EGNN ', 'GIN'], dtype=object)

In [16]:
run_data = run_data[run_data["model"] == "EGNN (R)"]
run_data

Unnamed: 0,id,name,split,split_seed,created_at,tags,model,run_info
0,2u85y7s3,super-brook-433,random,3,2023-05-23 06:48:10,[],EGNN (R),super-brook-433(2u85y7s3)
1,2pe97uwz,unique-smoke-432,random,4,2023-05-23 05:57:21,[],EGNN (R),unique-smoke-432(2pe97uwz)
2,124gtbho,icy-yogurt-431,scaffold,3,2023-05-23 05:11:20,[],EGNN (R),icy-yogurt-431(124gtbho)
3,2njmyw4w,usual-armadillo-430,scaffold,4,2023-05-23 04:47:07,[],EGNN (R),usual-armadillo-430(2njmyw4w)
4,1fs5h4xb,volcanic-donkey-429,pocket,3,2023-05-23 04:41:52,[],EGNN (R),volcanic-donkey-429(1fs5h4xb)
5,14624p3e,generous-pine-428,pocket,4,2023-05-23 04:25:24,[],EGNN (R),generous-pine-428(14624p3e)
6,2zjdpvzk,youthful-violet-427,random,3,2023-05-23 03:16:56,[],EGNN (R),youthful-violet-427(2zjdpvzk)
7,1l9dlxbh,spring-aardvark-426,random,2,2023-05-23 02:53:10,[],EGNN (R),spring-aardvark-426(1l9dlxbh)
8,2m1oikkp,wise-resonance-425,scaffold,3,2023-05-23 02:10:56,[],EGNN (R),wise-resonance-425(2m1oikkp)
9,1itngdmc,good-dawn-424,pocket,3,2023-05-23 01:51:06,[],EGNN (R),good-dawn-424(1itngdmc)


In [35]:
run_data["run_info"].iloc[10].run.summary["test/corr"]

0.4085184931755066

In [19]:
def get_pre_artifacts(run_info):
    artifacts = {artifact.name: artifact for artifact in run_info.run.logged_artifacts() if "predictions" in artifact.name}
    return artifacts

In [20]:
run_data["artifacts"] = run_data["run_info"].apply(get_pre_artifacts)

In [21]:
run_data = run_data[run_data["artifacts"] != dict()]

In [22]:
run_data["artifact_paths"] = run_data["artifacts"].apply(lambda ad: list(ad.values())[0].download())

[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m: 

In [23]:
import json
from pathlib import Path

dfs = []
for _, eval_item in run_data.iterrows():
    preds_raw = json.loads(Path(eval_item["artifact_paths"]).joinpath("predictions.table.json").read_text())
    preds = pd.DataFrame(data=preds_raw["data"], columns=preds_raw["columns"])
    preds["ident"] = preds["ident"].astype(int)
    preds["model"] = eval_item.model
    preds["split"] = eval_item.split
    preds["split_seed"] = eval_item.split_seed
    dfs.append(preds)

In [26]:
all_results = pd.concat(dfs)
all_results

Unnamed: 0,pred,ident,model,split,split_seed
0,6.306590,13841248,EGNN (R),random,4
1,7.810676,17641064,EGNN (R),random,4
2,5.449872,17650892,EGNN (R),random,4
3,8.143995,17767038,EGNN (R),random,4
4,7.652162,6202640,EGNN (R),random,4
...,...,...,...,...,...
3982,8.562795,22419498,EGNN (R),pocket,0
3983,8.466667,22419500,EGNN (R),pocket,0
3984,8.461862,22419500,EGNN (R),pocket,0
3985,8.457900,22419502,EGNN (R),pocket,0


In [27]:
from kinodata.data.dataset import KinodataDocked
dataset = KinodataDocked()
source = dataset.df
source["ident"] = source["ident"].astype(int)

  from .autonotebook import tqdm as notebook_tqdm


Reading data frame..
Checking for missing pocket mol2 files...


100%|██████████| 2439/2439 [00:00<00:00, 34493.03it/s]


In [28]:
eval_data = pd.merge(
    all_results, source[["ident", "activities.standard_value"]], on="ident"
)
eval_data["target"] = eval_data["activities.standard_value"].astype(float)
eval_data.to_csv("~/projects/kinodata-docked-rescore/eval_data/egnnr_eval_data.csv", index=False)