In [1]:
%load_ext lab_black
%load_ext autoreload
%autoreload 2

In [2]:
# Set for local or colab

import os
from os.path import join
import sys

# Check if running in colab
IN_COLAB = "google.colab" in sys.modules

# Project defaults
if IN_COLAB:
    print("ENVIRONMENT: Colab")

    # Mount drive
    from google.colab import drive

    drive.mount("/content/drive")

    # Set the project directory
    PROJECT_FOLDER = "/content/drive/MyDrive/MIDS/w266/w266-project-carlos"

    # Install dependencies
    !pip install -q transformers datasets SentencePiece

    # Set timezone
    !rm /etc/localtime
    !ln -s /usr/share/zoneinfo/US/Pacific /etc/localtime

else:
    print("ENVIRONMENT: Local")
    # Set the project directory
    PROJECT_FOLDER = "/user/w266/w266-project-carlos"

os.chdir(PROJECT_FOLDER)

# FOLDERS
DATASET_FOLDER = join(PROJECT_FOLDER, "dataset/dataset_final")
EXPERIMENT_BASE_FOLDER = join(PROJECT_FOLDER, "experiments")
EXPERIMENT_RESULTS_FOLDER = join(PROJECT_FOLDER, "experiment_results")

print(f"Working directory is: {os.getcwd()}")

ENVIRONMENT: Local
Working directory is: /user/w266/w266-project-carlos


In [3]:
import numpy as np
import pandas as pd
from pprint import pprint

from t5_model_support_functions import (
    sperarate_substring_by_spaces,
    get_pd_row_accuracy,
)

In [4]:
def print_accuracy_report(df):
    print("*" * 100)
    print("Overall results")
    print(df[["percent_equal", "equal"]].mean())

    print("")
    print("*" * 100)
    print("Succes by hardness")
    df_fail = df.loc[df["equal"] < 1].groupby("hardness").count()[["prediction"]]
    df_succ = df.loc[df["equal"] == 1].groupby("hardness").count()[["prediction"]]

    df_fail.columns = ["fail"]
    df_succ.columns = ["succes"]

    df_rep = pd.concat([df_succ, df_fail], axis=1)
    df_rep["succes_rate"] = df_rep["succes"] / (df_rep["succes"] + df_rep["fail"])
    print(df_rep)


def print_failed_examples(df):
    df_fail = df.loc[df["equal"] < 1]

    idx_range = np.arange(20)

    for i in idx_range:
        row = df_fail.iloc[i]
        print(f"Showing row: {row.name}")
        print("source // labels // prediction")
        pprint(f"S->  {row['source']}", width=200)
        print("")
        pprint(f"L->  {row['labels']}", width=200)
        print("")
        pprint(f"P->  {row['prediction']}", width=200)

        print("*" * 100)

### Select experiment folder

In [5]:
path = join(
    EXPERIMENT_RESULTS_FOLDER, "exp_03_codet5-large_results_2023_04_02-00_55.csv"
)

# path = join(
#     EXPERIMENT_RESULTS_FOLDER, "exp_03B_codet5-large_results_2023_04_08-20_28.csv"
# )


df = pd.read_csv(path)

df.head(2)

Unnamed: 0,tvBench_id,hardness,source,labels,prediction,percent_equal,equal
0,3092@y_name@DESC,Medium,<N> Give me the comparison about Team_ID over ...,mark bar data basketball_match encoding x all_...,mark bar data basketball_match encoding x all_...,1.0,1.0
1,3092@y_name@DESC,Medium,<N> Give me the comparison about Team_ID over ...,mark bar data basketball_match encoding x all_...,mark bar data basketball_match encoding x all_...,1.0,1.0


### Print Reports

In [6]:
print_accuracy_report(df)

****************************************************************************************************
Overall results
percent_equal    0.957882
equal            0.831624
dtype: float64

****************************************************************************************************
Succes by hardness
            succes  fail  succes_rate
hardness                             
Easy          1451   147     0.908010
Extra Hard     264   172     0.605505
Hard           503   237     0.679730
Medium        1837   265     0.873930


In [7]:
print_failed_examples(df)

Showing row: 12
source // labels // prediction
("S->  <N> Visualize a bar chart for what are the names and account balances of customers with the letter 'a' in their names ? , could you sort from high to low by the y axis ? </N> <C> mark [T] "
 'data customer encoding x [X] y aggregate [AggFunction] [Y] color [Z] transform filter [F] group [G] bin [B] sort [S] topk [K] </C> <D> customer <COL> cust_name </COL> <VAL> Mary </VAL> </D>')

"L->  mark bar data customer encoding x cust_name y aggregate none acc_bal transform filter cust_name like '%a%' sort y desc"

"P->  mark bar data customer encoding x cust_name y aggregate none acc_bal transform filter cust_name like '%a' sort y desc"
****************************************************************************************************
Showing row: 13
source // labels // prediction
("S->  <N> Visualize a bar chart for what are the names and account balances of customers with the letter 'a' in their names ? , could you sort from high to low 

#### Process `!=` errors

In [8]:
df["prediction"] = df["prediction"].apply(
    sperarate_substring_by_spaces, args=("!=", True)
)

In [9]:
df["percent_equal"] = df.apply(get_pd_row_accuracy, axis=1)
df["equal"] = df["percent_equal"].apply(lambda var: 1.0 if var == 1.0 else 0.0)

print_accuracy_report(df)

****************************************************************************************************
Overall results
percent_equal    0.978454
equal            0.884537
dtype: float64

****************************************************************************************************
Succes by hardness
            succes  fail  succes_rate
hardness                             
Easy          1451   147     0.908010
Extra Hard     384    52     0.880734
Hard           609   131     0.822973
Medium        1869   233     0.889153


In [10]:
print_failed_examples(df)

Showing row: 12
source // labels // prediction
("S->  <N> Visualize a bar chart for what are the names and account balances of customers with the letter 'a' in their names ? , could you sort from high to low by the y axis ? </N> <C> mark [T] "
 'data customer encoding x [X] y aggregate [AggFunction] [Y] color [Z] transform filter [F] group [G] bin [B] sort [S] topk [K] </C> <D> customer <COL> cust_name </COL> <VAL> Mary </VAL> </D>')

"L->  mark bar data customer encoding x cust_name y aggregate none acc_bal transform filter cust_name like '%a%' sort y desc"

"P->  mark bar data customer encoding x cust_name y aggregate none acc_bal transform filter cust_name like '%a' sort y desc"
****************************************************************************************************
Showing row: 13
source // labels // prediction
("S->  <N> Visualize a bar chart for what are the names and account balances of customers with the letter 'a' in their names ? , could you sort from high to low 

#### Process unmatched `%` errors

In [11]:
import re


def fix_unmatched_percent(text):
    pattern = r"'(%\w+)(')"
    repl_pattern = r"\1%\2"

    matches = re.findall(pattern, text)
    replacement = re.sub(pattern=pattern, repl=repl_pattern, string=text)

    return replacement


df["prediction"] = df["prediction"].apply(fix_unmatched_percent)

In [12]:
df["percent_equal"] = df.apply(get_pd_row_accuracy, axis=1)
df["equal"] = df["percent_equal"].apply(lambda var: 1.0 if var == 1.0 else 0.0)

print_accuracy_report(df)

****************************************************************************************************
Overall results
percent_equal    0.978408
equal            0.883716
dtype: float64

****************************************************************************************************
Succes by hardness
            succes  fail  succes_rate
hardness                             
Easy          1451   147     0.908010
Extra Hard     384    52     0.880734
Hard           605   135     0.817568
Medium        1869   233     0.889153


In [13]:
print_failed_examples(df)

Showing row: 12
source // labels // prediction
("S->  <N> Visualize a bar chart for what are the names and account balances of customers with the letter 'a' in their names ? , could you sort from high to low by the y axis ? </N> <C> mark [T] "
 'data customer encoding x [X] y aggregate [AggFunction] [Y] color [Z] transform filter [F] group [G] bin [B] sort [S] topk [K] </C> <D> customer <COL> cust_name </COL> <VAL> Mary </VAL> </D>')

"L->  mark bar data customer encoding x cust_name y aggregate none acc_bal transform filter cust_name like '%a%' sort y desc"

"P->  mark bar data customer encoding x cust_name y aggregate none acc_bal transform filter cust_name like %a%' sort y desc"
****************************************************************************************************
Showing row: 13
source // labels // prediction
("S->  <N> Visualize a bar chart for what are the names and account balances of customers with the letter 'a' in their names ? , could you sort from high to low 