In [1]:
from wandb_utils import RunInfo
import wandb
import pandas as pd
from datetime import datetime

In [2]:
from typing import List, Optional


def make_df(
    runs: List[RunInfo],
    since: Optional[datetime] = None,
) -> pd.DataFrame:
    run_data = pd.DataFrame(
        {
            "id": [info.run.id for info in runs],
            "name": [info.run.name for info in runs],
            "split": [info.split_type for info in runs],
            "split_seed": [info.split_seed for info in runs],
            "created_at": [info.run.created_at for info in runs],
            "tags": [info.tags for info in runs],
            "model": [run.model_name for run in runs],
            "run_info": runs
        }
    )
    run_data.created_at = run_data.created_at.apply(datetime.fromisoformat)
    if since is not None:
        run_data = run_data[run_data.created_at >= since]
    return run_data

## Retrieve test preds

In [3]:
runs = RunInfo.fetch(since=datetime(2023, 6, 25))
run_data = make_df(runs)
run_data.head()

Unnamed: 0,id,name,split,split_seed,created_at,tags,model,run_info
0,1muefho3,sleek-wave-588,random,3,2023-06-27 05:43:45,[transformer],,sleek-wave-588(1muefho3)
1,38m5aduh,jumping-river-587,scaffold,4,2023-06-27 04:22:58,[transformer],,jumping-river-587(38m5aduh)
2,2vjlgatf,prime-river-586,scaffold,3,2023-06-26 23:20:54,[transformer],,prime-river-586(2vjlgatf)
3,2ehdpcvy,northern-frost-585,pocket,4,2023-06-26 22:48:44,[transformer],,northern-frost-585(2ehdpcvy)
4,kuec8wy3,decent-gorge-584,scaffold,2,2023-06-26 19:39:51,[transformer],,decent-gorge-584(kuec8wy3)


In [4]:
run_data["test/corr"] = run_data["run_info"].apply(lambda r: r.run.summary.get("test/corr", None))
run_data

Unnamed: 0,id,name,split,split_seed,created_at,tags,model,run_info,test/corr
0,1muefho3,sleek-wave-588,random,3,2023-06-27 05:43:45,[transformer],,sleek-wave-588(1muefho3),
1,38m5aduh,jumping-river-587,scaffold,4,2023-06-27 04:22:58,[transformer],,jumping-river-587(38m5aduh),0.606501
2,2vjlgatf,prime-river-586,scaffold,3,2023-06-26 23:20:54,[transformer],,prime-river-586(2vjlgatf),0.638406
3,2ehdpcvy,northern-frost-585,pocket,4,2023-06-26 22:48:44,[transformer],,northern-frost-585(2ehdpcvy),0.144944
4,kuec8wy3,decent-gorge-584,scaffold,2,2023-06-26 19:39:51,[transformer],,decent-gorge-584(kuec8wy3),0.632014
5,3fqf8ej2,summer-bird-583,pocket,3,2023-06-26 18:42:00,[transformer],,summer-bird-583(3fqf8ej2),0.3382
6,29d8ls8g,jumping-blaze-582,random,2,2023-06-26 18:32:47,[transformer],,jumping-blaze-582(29d8ls8g),0.777079
7,1yjpmg8z,pretty-sky-581,random,1,2023-06-26 18:31:37,[transformer],,pretty-sky-581(1yjpmg8z),
8,2gm5y7zq,celestial-breeze-580,pocket,2,2023-06-26 15:20:09,[transformer],,celestial-breeze-580(2gm5y7zq),0.356832
9,3v68sndg,tough-galaxy-579,scaffold,1,2023-06-26 14:56:56,[transformer],,tough-galaxy-579(3v68sndg),0.719366


In [5]:
run_data = run_data.dropna()

In [6]:
def get_pred_artifacts(run_info):
    artifacts = {artifact.name: artifact for artifact in run_info.run.logged_artifacts() if "predictions" in artifact.name}
    return artifacts

In [7]:
run_data["artifacts"] = run_data["run_info"].apply(get_pred_artifacts)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  run_data["artifacts"] = run_data["run_info"].apply(get_pred_artifacts)


In [8]:
run_data = run_data[run_data["artifacts"] != dict()]

In [9]:
run_data["artifact_paths"] = run_data["artifacts"].apply(lambda ad: list(ad.values())[0].download())

[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [10]:
import json
from pathlib import Path

dfs = []
for _, eval_item in run_data.iterrows():
    preds_raw = json.loads(Path(eval_item["artifact_paths"]).joinpath("predictions.table.json").read_text())
    preds = pd.DataFrame(data=preds_raw["data"], columns=preds_raw["columns"])
    preds["ident"] = preds["ident"].astype(int)
    preds["model"] = eval_item.model
    preds["split"] = eval_item.split
    preds["split_seed"] = eval_item.split_seed
    dfs.append(preds)

In [11]:
all_results = pd.concat(dfs)
all_results

Unnamed: 0,pred,ident,model,split,split_seed
0,7.139952,44685,,scaffold,4
1,8.714574,105738,,scaffold,4
2,6.789509,108868,,scaffold,4
3,8.640821,117213,,scaffold,4
4,8.773761,120871,,scaffold,4
...,...,...,...,...,...
4401,8.404881,23312098,,scaffold,0
4402,7.110132,23314854,,scaffold,0
4403,7.757973,23314980,,scaffold,0
4404,8.256553,23314980,,scaffold,0


In [12]:
from kinodata.data.dataset import KinodataDocked
dataset = KinodataDocked()
source = dataset.df
source["ident"] = source["ident"].astype(int)

  from .autonotebook import tqdm as notebook_tqdm


Reading data frame..
Checking for missing pocket mol2 files...


100%|██████████| 2439/2439 [00:00<00:00, 13729.32it/s]


Adding pocket sequences from cached file /Users/joschka/projects/kinodata-docked-rescore/data/raw/pocket_sequences.csv.


In [13]:
eval_data = pd.merge(
    all_results, source[["ident", "activities.standard_value"]], on="ident"
)
eval_data["model"] = "Transformer"
eval_data["target"] = eval_data["activities.standard_value"].astype(float)
eval_data.to_csv("~/projects/kinodata-docked-rescore/eval_data/transformer_eval_data.csv", index=False)