In [1]:
from wandb_utils import RunInfo
import pandas as pd
from datetime import datetime

In [2]:
from typing import List, Optional


def make_df(
    runs: List[RunInfo],
    since: Optional[datetime] = None,
) -> pd.DataFrame:
    run_data = pd.DataFrame(
        {
            "id": [info.run.id for info in runs],
            "name": [info.run.name for info in runs],
            "split": [info.split_type for info in runs],
            "split_seed": [info.split_seed for info in runs],
            "created_at": [info.run.created_at for info in runs],
            "tags": [info.tags for info in runs],
            "model": [run.model_name for run in runs],
            "run_info": runs
        }
    )
    run_data.created_at = run_data.created_at.apply(datetime.fromisoformat)
    if since is not None:
        run_data = run_data[run_data.created_at >= since]
    return run_data

## Retrieve EGNN (R) test preds

In [3]:
runs = RunInfo.fetch_all()
run_data = make_df(runs)
run_data.head()

Unnamed: 0,id,name,split,split_seed,created_at,tags,model,run_info
0,epq0akyo,lucky-leaf-467,random,4,2023-06-06 02:15:08,[],EGNN (R/12),lucky-leaf-467(epq0akyo)
1,3dr19mv8,drawn-dawn-466,scaffold,4,2023-06-06 01:13:06,[],EGNN (R/12),drawn-dawn-466(3dr19mv8)
2,3m4ecbqg,summer-dew-465,pocket,4,2023-06-06 00:32:50,[],EGNN (R/12),summer-dew-465(3m4ecbqg)
3,2lhqwwwf,vivid-lake-464,random,3,2023-06-05 22:29:38,[],EGNN (R/12),vivid-lake-464(2lhqwwwf)
4,q9cykxhv,whole-sponge-463,scaffold,3,2023-06-05 21:17:05,[],EGNN (R/12),whole-sponge-463(q9cykxhv)


In [4]:
run_data["model"].unique()

array(['EGNN (R/12)', 'DTI', '', 'EGNN ', 'GIN'], dtype=object)

In [5]:
run_data = run_data[run_data["model"] == "EGNN (R/12)"]

In [7]:
run_data["corr"] = run_data["run_info"].apply(lambda r: r.run.summary.get("test/corr"))

In [8]:
run_data = run_data.dropna()

In [15]:
run_data = run_data[run_data["created_at"].apply(lambda d: d.month) >= 6]

In [17]:
run_data.sort_values(by="corr")

Unnamed: 0,id,name,split,split_seed,created_at,tags,model,run_info,corr
5,1imwje0p,vibrant-pyramid-462,pocket,3,2023-06-05 20:54:03,[],EGNN (R/12),vibrant-pyramid-462(1imwje0p),0.170296
2,3m4ecbqg,summer-dew-465,pocket,4,2023-06-06 00:32:50,[],EGNN (R/12),summer-dew-465(3m4ecbqg),0.281284
11,1fece45o,sparkling-terrain-456,pocket,1,2023-06-05 12:31:32,[],EGNN (R/12),sparkling-terrain-456(1fece45o),0.343874
14,1i3alfvt,morning-paper-453,pocket,0,2023-06-05 08:33:09,[],EGNN (R/12),morning-paper-453(1i3alfvt),0.364961
8,1yqui9m5,lilac-gorge-459,pocket,2,2023-06-05 16:58:44,[],EGNN (R/12),lilac-gorge-459(1yqui9m5),0.370401
1,3dr19mv8,drawn-dawn-466,scaffold,4,2023-06-06 01:13:06,[],EGNN (R/12),drawn-dawn-466(3dr19mv8),0.569598
13,24ezo8f9,ethereal-cloud-454,scaffold,0,2023-06-05 08:57:51,[],EGNN (R/12),ethereal-cloud-454(24ezo8f9),0.580239
4,q9cykxhv,whole-sponge-463,scaffold,3,2023-06-05 21:17:05,[],EGNN (R/12),whole-sponge-463(q9cykxhv),0.58899
7,1zflr1yv,resilient-music-460,scaffold,2,2023-06-05 17:54:24,[],EGNN (R/12),resilient-music-460(1zflr1yv),0.598421
10,34h17f8e,curious-frog-457,scaffold,1,2023-06-05 12:59:19,[],EGNN (R/12),curious-frog-457(34h17f8e),0.688779


In [18]:
def get_pre_artifacts(run_info):
    artifacts = {artifact.name: artifact for artifact in run_info.run.logged_artifacts() if "predictions" in artifact.name}
    return artifacts

In [19]:
run_data["artifacts"] = run_data["run_info"].apply(get_pre_artifacts)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  run_data["artifacts"] = run_data["run_info"].apply(get_pre_artifacts)


In [20]:
run_data = run_data[run_data["artifacts"] != dict()]

In [21]:
run_data["artifact_paths"] = run_data["artifacts"].apply(lambda ad: list(ad.values())[0].download())

[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [22]:
import json
from pathlib import Path

dfs = []
for _, eval_item in run_data.iterrows():
    preds_raw = json.loads(Path(eval_item["artifact_paths"]).joinpath("predictions.table.json").read_text())
    preds = pd.DataFrame(data=preds_raw["data"], columns=preds_raw["columns"])
    preds["ident"] = preds["ident"].astype(int)
    preds["model"] = eval_item.model
    preds["split"] = eval_item.split
    preds["split_seed"] = eval_item.split_seed
    dfs.append(preds)

In [23]:
all_results = pd.concat(dfs)
all_results

Unnamed: 0,pred,ident,model,split,split_seed
0,6.700248,13841248,EGNN (R/12),random,4
1,7.670762,17641064,EGNN (R/12),random,4
2,5.971170,17650892,EGNN (R/12),random,4
3,8.267131,17767038,EGNN (R/12),random,4
4,8.771928,6202640,EGNN (R/12),random,4
...,...,...,...,...,...
3982,8.256366,22419498,EGNN (R/12),pocket,0
3983,8.205363,22419500,EGNN (R/12),pocket,0
3984,8.198938,22419500,EGNN (R/12),pocket,0
3985,8.223491,22419502,EGNN (R/12),pocket,0


In [24]:
from kinodata.data.dataset import KinodataDocked
dataset = KinodataDocked()
source = dataset.df
source["ident"] = source["ident"].astype(int)

  from .autonotebook import tqdm as notebook_tqdm


Reading data frame..
Checking for missing pocket mol2 files...


100%|██████████| 2439/2439 [00:00<00:00, 10287.80it/s]


In [25]:
eval_data = pd.merge(
    all_results, source[["ident", "activities.standard_value"]], on="ident"
)
eval_data["target"] = eval_data["activities.standard_value"].astype(float)
eval_data.to_csv("~/projects/kinodata-docked-rescore/eval_data/egnnr12_eval_data.csv", index=False)