In [None]:
# Sample feather file
!gdown 1JbWFWAMonwpQjVleR5UIxFpQqmtdPTjk

In [None]:
!pip install -q wandb

## Usage

1. Select the model, task and metric to plot, each as a string or a list of strings. Any of them left as None (the default) selects all available values.
2. Task names can be abbreviated: task="math" plots every task whose name contains "math".
```
lmp.lineplot(x="step", model=["1.3B_deduped", "1.3B"], task="math", metric="acc", compare=True,
                project="my-project", name="my-run") # wandb args
```
3. Set compare=True (default: False) to compare the selected models against each other in each task/metric plot.    
Otherwise, separate plots are logged for each model.
4. Get a filtered dataframe for any model, task and metric (optionally also saving it as CSV).
  ```
  lmp.filter_df(model=["19M", "19M_dedup"], task="math", metric="acc", to_csv="19M.csv")
  ```
5. Get the raw, unfiltered dataframe
  ```
  lmp.get_df()
  ```

In [None]:
#@title Run Code

import glob
import os
import re
import json
import sys
import pandas as pd
import wandb


def read_feather(file_path):
    """
    Load a feather file (for example one written by ``collect(save_feather=...)``)
    and wrap the resulting data frame for filtering and plotting.

    Args:
        file_path: path to the feather file.

    Returns:
        lmplot object wrapping the loaded data frame.
    """
    return lmplot(pd.read_feather(file_path))


def collect(pathname, save_feather=None):
    """
    Combine multiple eval_results.json files into a single data frame.

    Every usable file contributes one row per (task, metric) pair found in its
    "results" section. A file is usable when its name matches
    ``<run_id>_eval_results_<timestamp>.json``, its ``<run_id>`` ends in
    ``-global_step<N>``, and its JSON can be read. Any other file is skipped
    with a warning on stderr.

    Args:
        pathname: string
                Glob pattern used to find eval_results.json files.
        save_feather: string
                path to save the data frame as a feather file.

    Returns:
        lmplot object. Its data frame has the columns path, timestamp, task,
        metric, value, model and step (empty if nothing was collected).
    """

    rgx_file_name = re.compile(r"^(?:(.*)_)?eval_results_([0-9-]+)\.json$")
    # Run ids look like "<model>-global_step<N>".
    rgx_meta = re.compile(r"^(.*)-global_step(\d+)$")

    dict_list = []
    # Sorted so the row order does not depend on the file system's listing order.
    for file_path in sorted(glob.glob(pathname)):
        file_name = os.path.basename(file_path)
        m = rgx_file_name.match(file_name)
        if m is None:
            print(
                "WARNING: cannot parse results file name '{}'".format(file_name),
                file=sys.stderr,
            )
            continue

        run_id = m[1]
        # The run-id prefix is optional in the file-name pattern. Without it
        # (or without a "-global_step<N>" suffix) there is no model/step to plot.
        rgx_m = rgx_meta.match(run_id) if run_id is not None else None
        if rgx_m is None:
            print(
                "WARNING: cannot parse model/step from '{}', skipping".format(file_name),
                file=sys.stderr,
            )
            continue

        header = {"path": file_name, "timestamp": m[2]}
        metadata = {"model": rgx_m[1], "step": int(rgx_m[2])}

        # Read the json file. Unreadable or malformed files are skipped with a warning.
        try:
            with open(file_path, encoding="utf-8") as f:
                eval_json = json.load(f)
            result_json = eval_json["results"]
        except (OSError, ValueError, KeyError, TypeError) as e:
            print(
                "WARNING: cannot load file '{}' ({})".format(file_path, e),
                file=sys.stderr,
            )
            continue

        # Prefer the checkpoint directory name recorded in the config over the
        # model name parsed from the file name. Fall back to the latter when absent.
        try:
            load_path = eval_json["config"]["model_args"]["load"]
        except (KeyError, TypeError):
            load_path = None
        if isinstance(load_path, str) and load_path.rstrip("/\\"):
            # rstrip so a trailing separator does not yield an empty basename.
            metadata["model"] = os.path.basename(load_path.rstrip("/\\"))

        for task, task_results in result_json.items():
            for metric, value in task_results.items():
                record = dict(header)
                record["task"] = task
                record["metric"] = metric
                record["value"] = value
                record.update(metadata)
                dict_list.append(record)

    if not dict_list:
        print(
            "WARNING: no results collected from '{}'".format(pathname),
            file=sys.stderr,
        )

    df = pd.json_normalize(dict_list)
    for name, dtype in (("model", str), ("step", int)):
        if name in df:
            df[name] = df[name].astype(dtype)

    if save_feather:
        df.to_feather(save_feather)

    return lmplot(df)


class lmplot:
    """
    Wrapper around an evaluation-results data frame that can filter it and log
    line plots of the results to Weights & Biases.

    The plotting and filtering code reads the columns ``model``, ``task``,
    ``metric``, ``value`` and ``step`` (as produced by ``collect``).
    """

    def __init__(self, df):
        self.df = df

    def get_df(self):
        """Return the underlying, unfiltered data frame."""
        return self.df

    def all_models(self):
        """
        Return a list of all models in the data frame, in order of first appearance.
        """

        return self.df["model"].unique().tolist()

    @staticmethod
    def _resolve_values(colname, value, available):
        """
        Turn the user's selection for one column into the list of values to keep.

        Args:
            colname: name of the column being filtered ("model", "task" or "metric").
            value: None (select everything), a single value, or a list/tuple of values.
            available: values present in that column of the frame being filtered.

        Returns:
            A new list of values suitable for ``Series.isin``.

        Raises:
            ValueError: if a requested value is not present. For the "task"
                column, a string that is a substring of one or more available
                task names selects all of those tasks instead of raising.
        """
        if value is None:
            return list(available)

        # Always copy, so the caller's list is never mutated by the expansion below.
        requested = list(value) if isinstance(value, (list, tuple)) else [value]
        missing = [v for v in requested if v not in available]
        if not missing:
            return requested

        if colname != "task":
            raise ValueError(
                f"{colname} names {missing} not found. Available are {available}"
            )

        # Partial task names: task="math" selects every task containing "math".
        for partial in missing:
            matches = [
                t for t in available if isinstance(partial, str) and partial in str(t)
            ]
            if not matches:
                raise ValueError(
                    f"{colname} names {partial} not found. Available are {available}"
                )
            requested.extend(matches)
        return requested

    def filter_df(self, model=None, task=None, metric=None, to_csv=None):
        """
        Filter the data frame for the specified model, task, metric.

        Selections are applied in the order model -> task -> metric. Each one is
        checked against the rows left by the previous one, so partial task
        names only match tasks present for the selected models.

        Args:
            model: string or list of strings or None
                If specified, only return info for the specified model(s).
                None selects all models.
            task: string or list of strings or None
                If specified, only return info for the specified task(s). A
                string that is not an exact task name selects every task whose
                name contains it (e.g. "math"). None selects all tasks.
            metric: string or list of strings or None
                If specified, only return info for the specified metric(s).
                None selects all metrics.
            to_csv: string or None
                If given, also write the filtered frame to this CSV path.

        Returns:
            A new data frame with the matching rows, ``step`` cast to int and a
            fresh 0..n-1 index. ``self.df`` is left untouched.

        Raises:
            ValueError: if a requested model, task or metric is not found.
        """

        df = self.df

        for colname, value in (("model", model), ("task", task), ("metric", metric)):
            available = df[colname].unique().tolist()
            keep = self._resolve_values(colname, value, available)
            df = df[df[colname].isin(keep)]

        # assign() builds a new frame instead of writing into a filtered slice
        # (which would raise SettingWithCopyWarning).
        if "step" in df:
            df = df.assign(step=df["step"].astype(int))
        df = df.reset_index(drop=True)

        if to_csv:
            df.to_csv(to_csv, index=False)

        return df

    def _lineplot_tasks(
        self,
        df,
        x="step",
        model=None,
        task=None,
        metric=None,
        hue="model",
        compare=False,
    ):

        """
        Log one W&B custom line chart per (task, metric) pair present in ``df``.

        Args:
            df: already-filtered data frame to plot.
            x: column used for the x-axis. Rows are sorted by it.
            model: model name used as the panel prefix when ``compare`` is False.
            task, metric: unused. They are kept for signature compatibility;
                the pairs to plot are taken from ``df``.
            hue: column used to colour the lines.
            compare: if True, charts are logged under task panels; otherwise
                under a "<model>/" panel.
        """

        for task_value in df["task"].unique().tolist():
            # Compare against the raw column value; str() is only for labels.
            task_df = df[df["task"] == task_value]
            task_metrics = task_df["metric"].unique().tolist()
            task_name = str(task_value)

            for metric_value in task_metrics:
                metric_name = str(metric_value)
                metric_df = task_df[task_df["metric"] == metric_value].sort_values(by=x)

                table = wandb.Table(dataframe=metric_df)
                fields = {
                    "x-axis": f"{x}",
                    "y-axis": "value",
                    "color": f"{hue}",
                    "metric": f"{metric_name}",
                    "title": f"{task_name} ({metric_name})",
                }

                # Custom Vega spec stored on W&B (presumably under the
                # "satpalsr" entity) -- it must be accessible to the logged-in user.
                custom_chart = wandb.plot_table(
                    vega_spec_name="satpalsr/multiplot", data_table=table, fields=fields
                )

                if not compare:
                    key = f"{model}/{task_name} ({metric_name})"  # model panel
                elif len(task_metrics) == 1:
                    key = f"{task_name} {metric_name}"
                else:
                    key = f"{task_name}/{metric_name}"  # task panel
                wandb.log({key: custom_chart})

    def lineplot(
        self,
        x="step",
        model=None,
        task=None,
        metric=None,
        hue="model",
        compare=False,
        **kwargs,
    ):

        """
        Draw lineplot for each model, task and metric combination.

        Args:
            x (str): x-axis column name.
            model (str or list or None): List of model names. None selects all.
            task (str or list or None): List of task names (substrings allowed). None selects all.
            metric (str or list or None): List of metric names. None selects all.
            hue (str): Column name for hue
            compare (bool): If True, Models are compared in each plot. Plots are saved in task folders.
                            Else, Models are not compared. Plots are saved in model folders.
            **kwargs: forwarded to ``wandb.init`` (e.g. project, name).

        Returns:
            None

        Raises:
            ValueError: if a requested model, task or metric is not found.
        """

        if model is None:
            model = self.all_models()

        df = self.filter_df(model, task, metric)

        if df.empty:
            print("Model, task, metric combination not found.")
            return None

        run = wandb.init(**kwargs)
        try:
            if compare:
                self._lineplot_tasks(df, x, model, task, metric, hue, compare)
            else:
                for model_name in df["model"].unique().tolist():
                    model_df = df[df["model"] == model_name]
                    self._lineplot_tasks(
                        model_df, x, model_name, task, metric, hue, compare
                    )
        finally:
            # Close the run even if plotting fails, so the next call starts cleanly.
            run.finish()

In [None]:
# Load previously collected results from the feather file downloaded above.
lmp = read_feather("data.feather")

# Alternatively, build the frame from raw eval_results JSON files (and cache it):
# lmp = collect("pythia/results/json/*/*.json", save_feather="data.feather")

# Last expression -> rich display of the available model names.
lmp.all_models()

The logged visuals depend on the value of `compare` and on whether a single metric or several metrics are selected. Run the three cells below to see the difference.

In [None]:
# Single metric: models compared in one chart per task.
lmp.lineplot(
    x="step",
    model=["1.3B_deduped", "1.3B"],
    task="math",
    metric="acc",
    compare=True,
    project="pythia",
    name="my-run-sm",
)

In [None]:
# All metrics (metric left as None): one task panel with a chart per metric.
lmp.lineplot(
    x="step",
    model=["1.3B_deduped", "1.3B"],
    task="math",
    compare=True,
    project="pythia",
    name="my-run-am",
)

In [None]:
# compare left at its default (False): separate charts under each model's panel.
lmp.lineplot(
    x="step",
    model=["1.3B_deduped", "1.3B"],
    task="math",
    project="pythia",
    name="my-run-cf",
)

In [None]:
# Any of model, task or metric may be left as None (the default), which
# selects all values; here only the models are specified.
lmp.lineplot(
    x="step",
    model=["1.3B_deduped", "1.3B"],
    compare=True,
    project="pythia",
    name="my-run-models",
)