In [None]:
import pandas as pd
import wandb
import json

In [None]:
# Function which downloads the results from the W&B server
def get_results_table(run_id, table_name, project="r-pad/taxpose", version="v0"):
    """Download one logged W&B table for a run and return it as a DataFrame.

    Args:
        run_id: W&B run id, used in both the artifact name and the run path.
        table_name: Name of the logged table (e.g. ``"train_metrics"``). It is
            used to build the artifact name ``run-{run_id}-{table_name}`` and
            the file name ``{table_name}.table.json``.
        project: ``entity/project`` prefix for the artifact and the run.
            Defaults to the project this notebook was written against.
        version: Artifact version suffix. Defaults to ``"v0"``.

    Returns:
        A DataFrame holding the table's ``data`` rows. Columns are a two-level
        MultiIndex ``(object_class_name, original_column)``, where
        ``object_class_name`` is ``config["object_class"]["name"]`` of the run.
        Every row is labelled with ``config["model"]["name"]``.
    """
    api = wandb.Api()
    artifact = api.artifact(f"{project}/run-{run_id}-{table_name}:{version}")
    json_file = artifact.get_path(f"{table_name}.table.json").download()

    # The run config supplies the labels for the column group and the row index.
    config = api.run(f"{project}/{run_id}").config
    # Kept in its own local so the `table_name` argument is not overwritten.
    object_class = config["object_class"]["name"]
    model_name = config["model"]["name"]

    with open(json_file) as file:
        json_dict = json.load(file)
    df = pd.DataFrame(json_dict["data"], columns=json_dict["columns"])

    df.columns = pd.MultiIndex.from_product([[object_class], df.columns])

    # Label every row with the model name; a single hard-coded label would
    # raise for any table that does not contain exactly one row.
    df.index = [model_name] * len(df)

    return df

In [None]:
# Function which takes a list of results tables, and concatenates them into a single table with a multi-index.
def concat_results(tables):
    """Join a list of results tables side by side (column-wise) into one DataFrame."""
    return pd.concat(tables, axis="columns")

In [None]:
# W&B run ids for the main results: one run per task, listed in the same task order for both methods.
MLAT_IDS = [
    "jqyrs601",  # Stack Wine
    "h7a9oxtp",  # Put Toilet Roll on Stand
    "v075mup0",  # Place Hanger on Rack
    "fpu8sirp",  # Phone on Base
    "2rtpvbn3",  # Insert Onto Square Base
]

TAXPOSE_IDS = [
    "xxecq5xe",  # Stack Wine
    "dhkc7eva",  # Put Toilet Roll on Stand
    "cs4gc0lg",  # Place Hanger on Rack
    "tp9wuqcw",  # Phone on Base
    "mae2i315",  # Insert Onto Square Base
]

# MLAT: fetch the train and val metric tables of every run, then join them column-wise.
mlat_dfs_train, mlat_dfs_val = [], []
for run_id in MLAT_IDS:
    train_df = get_results_table(run_id, "train_metrics")
    val_df = get_results_table(run_id, "val_metrics")
    mlat_dfs_train += [train_df]
    mlat_dfs_val += [val_df]

mlat_train_table = concat_results(mlat_dfs_train)
mlat_val_table = concat_results(mlat_dfs_val)

# TAX-Pose: same procedure.
taxpose_dfs_train, taxpose_dfs_val = [], []
for run_id in TAXPOSE_IDS:
    train_df = get_results_table(run_id, "train_metrics")
    val_df = get_results_table(run_id, "val_metrics")
    taxpose_dfs_train += [train_df]
    taxpose_dfs_val += [val_df]

taxpose_train_table = concat_results(taxpose_dfs_train)
taxpose_val_table = concat_results(taxpose_dfs_val)

# Stack the two methods as rows (TAX-Pose first) and display the combined train table.
full_train_table = pd.concat([taxpose_train_table, mlat_train_table], axis="index")
full_train_table




In [None]:
# Emit the combined train table as LaTeX, with every value formatted to three decimals.
train_latex = full_train_table.style.format('{:.3f}').to_latex()
print(train_latex)

In [None]:
# One markdown table per task: the TAX-Pose row stacked on the MLAT row,
# with the task-name column level dropped so only the metric names remain.
for taxpose_df, mlat_df in zip(taxpose_dfs_train, mlat_dfs_train):
    stacked = pd.concat([taxpose_df, mlat_df], axis=0)
    per_task = stacked.droplevel(0, axis=1)
    print(per_task.to_markdown())

In [None]:
# Emit the combined train table as a markdown table.
train_markdown = full_train_table.to_markdown()
print(train_markdown)

In [None]:
|               |   stack_wine |    |   put_toilet_roll_on_stand |    |   place_hanger_on_rack |   |   phone_on_base |    |   insert_onto_square_peg |    |
|               |   angle_err |   t_err |   angle_err |   t_err |   angle_err |   t_err |   angle_err |   t_err |   angle_err |   t_err |
|:--------------|------------------------------:|--------------------------:|--------------------------------------------:|----------------------------------------:|----------------------------------------:|------------------------------------:|---------------------------------:|-----------------------------:|------------------------------------------:|--------------------------------------:|
| taxpose       |                      1.48548  |                0.00308973 |                                     1.17297 |                              0.001249   |                                5.47136  |                          0.0119683  |                         4.14353  |                   0.00543616 |                                   7.0977  |                            0.00351971 |
| mlat_s256_vnn |                      0.764146 |                0.00122502 |                                     1.14988 |                              0.00134385 |                                0.623557 |                          0.00195536 |                         0.803998 |                   0.00106143 |                                   1.20883 |                            0.00328621 |

In [None]:
|               |  stack_wine\\ angle_err (°) |   t_err (mm) |   put_toilet_roll_on_stand\\ angle_err (°) |   t_err (mm) |   place_hanger_on_rack\\ angle_err (°) |   t_err (mm) |   phone_on_base\\ angle_err (°) |   t_err (mm) |   insert_onto_square_peg\\ angle_err (°) |   t_err (mm) |
|:--------------|------------------------------:|--------------------------:|--------------------------------------------:|----------------------------------------:|----------------------------------------:|------------------------------------:|---------------------------------:|-----------------------------:|------------------------------------------:|--------------------------------------:|
| TAX-Pose       |                      1.49  |                3.09 |                                     1.17 |                              **1.25**   |                                5.47  |                          12.0  |                         4.14  |                   5.44 |                                   7.10  |                            3.52 |
| Ours (RelDist) |                      **0.76** |                **1.23** |                                     **1.15** |                              1.34 |                                **0.62** |                          **1.96** |                         **0.80** |                   **1.06** |                                   **1.21** |                            **3.29** |


In [None]:
# Stack the validation tables of both methods as rows (TAX-Pose first) and display the result.
val_tables = [taxpose_val_table, mlat_val_table]
full_val_table = pd.concat(val_tables, axis="index")
full_val_table

In [None]:
# (Removed a stray scratch expression, `t`. The name is not defined anywhere in the
# visible notebook, so the cell raised NameError on Restart & Run All.)

In [None]:
# NOTE(review): this cell originally referenced `dfs_train`, which is not defined
# anywhere in the visible notebook (NameError on a fresh kernel). It presumably was an
# earlier name for one of the per-method lists; `taxpose_dfs_train` is used here so the
# cell runs. Confirm which list was intended, or delete the cell.
concat_results(taxpose_dfs_train)

In [None]:
# Toy example: column-wise concat of two frames whose MultiIndex rows do not overlap.
SUBGROUPS = ['X', 'Y', 'Z']
LEVEL_NAMES = ['Group', 'Subgroup']

# First frame: columns A/B, rows under 'Group1'.
data1 = dict(A=[1, 2, 3], B=[4, 5, 6])
index1 = pd.MultiIndex.from_tuples([('Group1', sub) for sub in SUBGROUPS], names=LEVEL_NAMES)
df1 = pd.DataFrame(data1, index=index1)

# Second frame: columns C/D, rows under 'Group2'.
data2 = dict(C=[7, 8, 9], D=[10, 11, 12])
index2 = pd.MultiIndex.from_tuples([('Group2', sub) for sub in SUBGROUPS], names=LEVEL_NAMES)
df2 = pd.DataFrame(data2, index=index2)

# Join side by side; rows present in only one frame get NaN in the other's columns.
result = pd.concat([df1, df2], axis=1)

# Show the outcome.
print(result)

In [None]:
# W&B run ids for the ablation over the number of demonstrations.
# Both lists are ordered by demo count: 1, 5, 10.
TAXPOSE_ABLATION_IDS = ["5do9r1ft", "awbr16hl", "n9likyeo"]
MLAT_ABLATION_IDS = ["zswyokhc", "1hhy8jy8", "ry1ggn0r"]

In [None]:
# Fetch the (train, val) metric tables of every ablation run, keeping the original
# download order (train then val, run by run), and split them into per-split lists.
taxpose_pairs = [
    (get_results_table(rid, "train_metrics"), get_results_table(rid, "val_metrics"))
    for rid in TAXPOSE_ABLATION_IDS
]
taxpose_train_dfs = [train_df for train_df, _ in taxpose_pairs]
taxpose_val_dfs = [val_df for _, val_df in taxpose_pairs]

mlat_pairs = [
    (get_results_table(rid, "train_metrics"), get_results_table(rid, "val_metrics"))
    for rid in MLAT_ABLATION_IDS
]
mlat_train_dfs = [train_df for train_df, _ in mlat_pairs]
mlat_val_dfs = [val_df for _, val_df in mlat_pairs]


In [None]:
# Build one table per metric and split, where the columns are the number of demonstrations.

# Demo counts of the ablation runs, in the same order as the *_ABLATION_IDS lists.
DEMO_COUNTS = [1, 5, 10]


def _metric_by_demos(results_table, metric, demo_counts=DEMO_COUNTS):
    """Pick one metric (column level 1) from a concatenated results table and
    relabel the resulting columns with the demo counts.

    Raises ValueError (from pandas) if the number of selected columns differs
    from ``len(demo_counts)``.
    """
    metric_table = results_table.xs(metric, axis=1, level=1).copy()
    metric_table.columns = list(demo_counts)
    return metric_table


# NOTE: these four names are also bound by the main-results cell; after this cell runs
# they hold the ablation tables instead.
taxpose_train_table = concat_results(taxpose_train_dfs)
taxpose_val_table = concat_results(taxpose_val_dfs)
mlat_train_table = concat_results(mlat_train_dfs)
mlat_val_table = concat_results(mlat_val_dfs)

# TAX-Pose
taxpose_train_table_angle = _metric_by_demos(taxpose_train_table, "angle_err")
taxpose_val_table_angle = _metric_by_demos(taxpose_val_table, "angle_err")
taxpose_train_table_t = _metric_by_demos(taxpose_train_table, "t_err")
taxpose_val_table_t = _metric_by_demos(taxpose_val_table, "t_err")

# MLAT
mlat_train_table_angle = _metric_by_demos(mlat_train_table, "angle_err")
mlat_val_table_angle = _metric_by_demos(mlat_val_table, "angle_err")
mlat_train_table_t = _metric_by_demos(mlat_train_table, "t_err")
mlat_val_table_t = _metric_by_demos(mlat_val_table, "t_err")

# Stack the two methods as rows (TAX-Pose first).
full_train_table_angle = pd.concat([taxpose_train_table_angle, mlat_train_table_angle], axis=0)
full_val_table_angle = pd.concat([taxpose_val_table_angle, mlat_val_table_angle], axis=0)
full_train_table_t = pd.concat([taxpose_train_table_t, mlat_train_table_t], axis=0)
full_val_table_t = pd.concat([taxpose_val_table_t, mlat_val_table_t], axis=0)




In [None]:
# Display the train-split angle-error ablation table (columns: number of demonstrations).
full_train_table_angle

In [None]:
# Emit the train-split angle-error ablation table as markdown.
angle_markdown = full_train_table_angle.to_markdown()
print(angle_markdown)

In [None]:
# Display the train-split translation-error ablation table (columns: number of demonstrations).
full_train_table_t

In [None]:
# Emit the train-split translation-error ablation table as markdown.
translation_markdown = full_train_table_t.to_markdown()
print(translation_markdown)

In [None]:
# Display the val-split angle-error ablation table (columns: number of demonstrations).
full_val_table_angle

In [None]:
# Display the val-split translation-error ablation table (columns: number of demonstrations).
full_val_table_t

In [None]:
# Make two line plots of the angle error and translation error, with the number of demonstrations on the x-axis.
import matplotlib.pyplot as plt

# Legend entries in the row order of the tables (TAX-Pose rows are concatenated before the MLAT rows).
legend_labels = ["TAX-Pose", "Ours (RelDist)"]

# One (table, title, y-label, y-limits) entry per panel.
# NOTE(review): t_err appears to be logged in metres. The hand-written results table in
# this notebook reports the same values multiplied by 1000 as "mm", so the translation
# table is scaled here to match the "(mm)" axis label. Confirm the unit against the
# logging code.
panels = [
    (full_train_table_angle, "Angle Error", "Angle Error (°)", (0, 5.5)),
    (full_train_table_t * 1000, "Translation Error", "Translation Error (mm)", (0, 12.5)),
]

fig = plt.figure(figsize=(10, 5))
for position, (table, title, ylabel, ylim) in enumerate(panels, start=1):
    ax = fig.add_subplot(1, 2, position)
    table.T.plot(ax=ax)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel("Number of Demonstrations")
    ax.set_ylim(*ylim)

    # Only show ticks at the demonstration counts 1, 5, 10.
    ax.set_xticks([1, 5, 10])

    # pandas has already drawn a legend from the row labels, so after renaming the
    # lines the legend has to be rebuilt for the new labels to show up.
    for line, label in zip(ax.get_lines(), legend_labels):
        line.set_label(label)
    ax.legend()

