This notebook analyzes the results of the wandb sweeps. It is used to generate the plots in the paper.

## Imports


In [1]:
import wandb
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch

# wandb.init(project="6DimCachespliteinSweep", entity="st7ma784")
entity, project = "st7ma784", "6DIMCLIPTOKSweepvfinal6.0DIM"  # set to your entity and project 

api = wandb.Api()
runs = api.runs(entity + "/" + project) 

# runs = runs.filter("config.learning_rate = 0.0005")
# runs = runs.filter("config.batch_size = 10")
# runs = runs.filter("config.precision = 32")
# runs = runs.filter("config.maskLosses = 0")
# runs = runs.filter("config.embed_dim = 512")
# runs = runs.filter("config.transformer_width = 512")
# runs = runs.filter("config.transformer_heads = 16")
# runs = runs.filter("config.transformer_layers = 24")
# runs = runs.filter("config.prune = False")
# runs = runs.filter("config.meanloss = True")

# runs = runs.filter("config.logitsversion = 0")
# runs = runs.filter("config.projection = ''")

# runs = runs.filter("state = finished")
# print("found",len(runs),"runs")



summary_list, config_list, name_list = [], [], []
for run in runs: 
    # .summary contains the output keys/values for metrics like accuracy.
    #  We call ._json_dict to omit large files 
    summary=run.summary._json_dict
    #filter to just the keys of 
    key_set=set(["improbe","textprobe","train_loss","val_loss","first_logit"])
    summary={k:summary[k] for k in key_set if k in summary}
    summary_list.append(summary)

    # .config contains the hyperparameters.
    #  We remove special values that start with _.
    config_list.append(
        {k: v for k,v in run.config.items()
         if not k.startswith('_')})

    # .name is the human-readable name of the run.
    name_list.append(run.name)

runs_df = pd.DataFrame({
    "summary": summary_list,
    "config": config_list,
    "name": name_list
    })
#show the dataframe

print(runs_df.to_latex())





\begin{tabular}{llll}
\toprule
{} &                                summary &                                             config &                    name \\
\midrule
0  &        \{'train\_loss': 13932.708984375\} &  \{'JSE': 0, 'dir': '/data', 'dims': 6, 'debug':... &            solar-sea-32 \\
1  &      \{'train\_loss': 3.458843231201172\} &  \{'JSE': 0, 'dir': '/data', 'dims': 6, 'debug':... &          desert-wood-31 \\
2  &          \{'train\_loss': 8126.68359375\} &  \{'JSE': 0, 'dir': '/data', 'dims': 6, 'debug':... &           icy-vortex-30 \\
3  &      \{'train\_loss': 2.810676336288452\} &  \{'JSE': 0, 'dir': '/data', 'dims': 6, 'debug':... &          wild-sunset-29 \\
4  &     \{'train\_loss': 2.8677778244018555\} &  \{'JSE': 0, 'dir': '/data', 'dims': 6, 'debug':... &     celestial-violet-28 \\
5  &                  \{'train\_loss': 'NaN'\} &  \{'JSE': 0, 'dir': '/data', 'dims': 6, 'debug':... &       olive-capybara-27 \\
6  &             \{'train\_loss': 3141957.25\} &  \{'J

  print(runs_df.to_latex())


In [3]:


from sys import version


image_dims_list,text_dims_list= {}, {}
image_version_list,text_version_list= {}, {}

for run in runs: 
    #group runs by dims and record range of linear probe accuracy
    dims = run.config["dims"]
    version=run.config["logitsversion"]
    summary=run.summary._json_dict
    text_dims_list[dims] = text_dims_list.get("dims", []) + (summary["TProbe"])
    image_dims_list[dims] = image_dims_list.get("dims", []) + (summary["ImProbe"])
    text_version_list[version] = text_version_list.get("version", []) + (summary["TProbe"])
    image_version_list[version] = image_version_list.get("version", []) + (summary["ImProbe"])
    

#for each dims, compute mean and range of linear probe accuracy
text_dims_mean = {k: np.mean(v) for k,v in text_dims_list.items()}
text_dims_range = {k: np.max(v) - np.min(v) for k,v in text_dims_list.items()}
text_dims_std = {k: np.std(v) for k,v in text_dims_list.items()}
text_dims_count = {k: len(v) for k,v in text_dims_list.items()}

image_dims_mean = {k: np.mean(v) for k,v in image_dims_list.items()}
image_dims_range = {k: np.max(v) - np.min(v) for k,v in image_dims_list.items()}
image_dims_std = {k: np.std(v) for k,v in image_dims_list.items()}
image_dims_count = {k: len(v) for k,v in image_dims_list.items()}

text_version_mean = {k: np.mean(v) for k,v in text_version_list.items()}
text_version_range = {k: np.max(v) - np.min(v) for k,v in text_version_list.items()}
text_version_std = {k: np.std(v) for k,v in text_version_list.items()}
text_version_count = {k: len(v) for k,v in text_version_list.items()}

image_version_mean = {k: np.mean(v) for k,v in image_version_list.items()}
image_version_range = {k: np.max(v) - np.min(v) for k,v in image_version_list.items()}
image_version_std = {k: np.std(v) for k,v in image_version_list.items()}
image_version_count = {k: len(v) for k,v in image_version_list.items()}

#make the results into a dataframes: column headers are dims with each key split into image and text and the rows are mean, range, std, and count

dims_df = pd.DataFrame({
    "text_mean": text_dims_mean,
    "text_range": text_dims_range,
    "text_std": text_dims_std,
    "text_count": text_dims_count,
    "image_mean": image_dims_mean,
    "image_range": image_dims_range,
    "image_std": image_dims_std,
    "image_count": image_dims_count
    }) 
#print(dims_df.to_latex())

version_df = pd.DataFrame({
    "text_mean": text_version_mean,
    "text_range": text_version_range,
    "text_std": text_version_std,
    "text_count": text_version_count,
    "image_mean": image_version_mean,
    "image_range": image_version_range,
    "image_std": image_version_std,
    "image_count": image_version_count
    })
# print(version_df.to_latex())
version_df

KeyError: 'textprobe'