### Goals: 
1. Verify that the JGA using Trippy's script is correct. 
1. Calculate hallucination and coreference JGA 
    - Hallucination: need to get untokenized inputs and see if any slot values are not from the conversational context. (I think it should be zero, given TripPy's implementation)
    - get list of dialogue ids that need coreference resolution from MultiWOZ2.3 

### Conclusion: 
1. JGA from Trippy's script is correct. 

### Deliverables achieved: 
- Cleaned up code in here for getting coref JGA and no hallucation frequency 
    - get_coref_jga.py
    - get_no_hallucination_frequency.py
- running metric_bert_dst generates a CSV file of the predictions in the format we want. Use this file to calculate cJGA

In [65]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [66]:
import json
from metric_bert_dst import load_dataset_config, tokenize, check_slot_inform

In [None]:
prediction_fn= "results/multiwoz23_lr1e-4_gpu2/pred_res.test.4736.json"

In [67]:
with open(prediction_fn, "r") as f: 
    data = json.load(f)

dataset_config = "dataset_config/multiwoz21.json"
class_types, slots, label_maps = load_dataset_config(dataset_config)

label_maps_tmp = {}
for v in label_maps:
    label_maps_tmp[tokenize(v)] = [tokenize(nv) for nv in label_maps[v]]
label_maps = label_maps_tmp


In [68]:
turnwise_jga = [] 

In [69]:
len(data)

7372

In [162]:
def get_slot_vals(slot, pred, joint_pd_slot, swap_with_map=True): 
    guid = pred['guid']  # List: set_type, dialogue_idx, turn_idx

    key_class_label_id = 'class_label_id_%s'%slot
    key_class_prediction = 'class_prediction_%s'%slot
    key_start_pos = 'start_pos_%s'%slot
    key_start_prediction = 'start_prediction_%s'%slot
    key_end_pos = 'end_pos_%s'%slot
    key_end_prediction = 'end_prediction_%s'%slot
    key_refer_id = 'refer_id_%s'%slot
    key_refer_prediction = 'refer_prediction_%s'%slot
    key_slot_groundtruth = 'slot_groundtruth_%s'%slot
    key_slot_prediction = 'slot_prediction_%s'%slot

    turn_gt_class = pred[key_class_label_id]
    turn_pd_class = pred[key_class_prediction]
    gt_start_pos = pred[key_start_pos]
    pd_start_pos = pred[key_start_prediction]
    gt_end_pos = pred[key_end_pos]
    pd_end_pos = pred[key_end_prediction]
    gt_refer = pred[key_refer_id]
    pd_refer = pred[key_refer_prediction]
    gt_slot = pred[key_slot_groundtruth]
    pd_slot = pred[key_slot_prediction]

    if swap_with_map: 
        gt_slot = tokenize(gt_slot)
        pd_slot = tokenize(pd_slot)

    # Make sure the true turn labels are contained in the prediction json file!
    joint_gt_slot = gt_slot

    if guid[-1] == '0': # First turn, reset the slots
        joint_pd_slot = 'none'

    # If turn_pd_class or a value to be copied is "none", do not update the dialog state.
    if turn_pd_class == class_types.index('none'):
        pass
    elif turn_pd_class == class_types.index('dontcare'):
        joint_pd_slot = 'dontcare'
    elif turn_pd_class == class_types.index('copy_value'):
        joint_pd_slot = pd_slot
    elif 'true' in class_types and turn_pd_class == class_types.index('true'):
        joint_pd_slot = 'true'
    elif 'false' in class_types and turn_pd_class == class_types.index('false'):
        joint_pd_slot = 'false'
    elif 'refer' in class_types and turn_pd_class == class_types.index('refer'):
        if pd_slot[0:3] == "§§ ":
            if pd_slot[3:] != 'none':
                joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[3:], label_maps, swap_with_map)
        elif pd_slot[0:2] == "§§":
            if pd_slot[2:] != 'none':
                joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[2:], label_maps, swap_with_map)
        elif pd_slot != 'none':
            joint_pd_slot = pd_slot
    elif 'inform' in class_types and turn_pd_class == class_types.index('inform'):
        if pd_slot[0:3] == "§§ ":
            if pd_slot[3:] != 'none':
                joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[3:], label_maps, swap_with_map)
        elif pd_slot[0:2] == "§§":
            if pd_slot[2:] != 'none':
                joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[2:], label_maps, swap_with_map)
        else:
            print("ERROR: Unexpected slot value format. Aborting.")
            exit()
    else:
        print("ERROR: Unexpected class_type. Aborting.")
        exit()


    # Check the joint slot correctness.
    # If the value label is not none, then we need to have a value prediction.
    # Even if the class_type is 'none', there can still be a value label,
    # it might just not be pointable in the current turn. It might however
    # be referrable and thus predicted correctly.
    if joint_gt_slot != 'none' and joint_gt_slot != 'dontcare' and joint_gt_slot != 'true' and joint_gt_slot != 'false' and joint_gt_slot in label_maps:
        for variant in label_maps[joint_gt_slot]:
            if variant == joint_pd_slot and swap_with_map:
                joint_pd_slot = joint_gt_slot 

    return joint_gt_slot, joint_pd_slot 

In [93]:
# Get JGA line by line 
correct_ct = 0 
pred_dst = {}
gt_dst = {} 

pred_dst_set = set() 
gt_dst_set= set() 
dict_results = [] 
for pred in data: 
   # print(sorted(pred.keys()))
   # break
   guid = pred['guid']
   # reset dsts for new conversation 
   if guid[-1] == '0': 
      pred_dst = {} 
      gt_dst ={} 
      pred_dst_set = set() 
      gt_dst_set= set() 

   # guid: ['test/valid/train', 'dialog id', turn_idx]
   dialog_id, turn_idx = guid[1], guid[2]
   dict_result = {
      "id": f"{dialog_id.lower()}-{turn_idx}"
   } 

   for slot in slots: 
      joint_pd_slot = pred_dst.get(slot, None)

      joint_gt_slot, joint_pd_slot = get_slot_vals(slot, pred, joint_pd_slot)
      pred_dst[slot] = joint_pd_slot 
      gt_dst[slot] = joint_gt_slot  

      dict_result[f"{slot}_gold"]= joint_gt_slot
      dict_result[f"{slot}_pred"] = joint_pd_slot 

   dict_results.append(dict_result)

   # just need to compare these dictionaries, no need for fancy string formation 
   if pred_dst == gt_dst: 
      correct_ct += 1 
   elif dict_result['id'] in correct_list: 
      print(sorted(list(pred_dst_set)))
      print(sorted(list(gt_dst_set)))
      print(pred_dst_set - gt_dst_set)
      print(gt_dst_set - pred_dst_set)
      break 

#    break

print(correct_ct / len(data))

0.6121812262615302


In [72]:
import pandas as pd

In [73]:
# (row: ids: each slot pred, gold, )
mine_df = pd.DataFrame(dict_results)

In [74]:
# get original TripPy's predictions as similar dataframe object 
# generated from runnign `python metric_bert_dst.py multiwoz21 dataset_config/multiwoz21.json results/multiwoz23_lr1e-4_gpu2/pred_res.test.4736.json`

trippy_df = pd.read_csv("/data/home/justincho/trippy-public-master/results/multiwoz23_lr1e-4_gpu2/pred_res.test.4736.csv")

In [75]:
compare = mine_df == trippy_df

In [76]:
import math 

In [77]:
str(float("nan"))

'nan'

In [86]:
count = 0 
for idx, row in compare.iterrows(): 
    # print(idx, row )
    x = row.to_dict() 
    found= False 
    for k,v in x.items():  
        if not v: 
            # found = True 
            if mine_df.iloc[idx][k].strip() == "" and isinstance(trippy_df.iloc[idx][k], float) and math.isnan(trippy_df.iloc[idx][k]):  
                count += 1 
                continue 
            print(f"{k}, mine: {mine_df.iloc[idx][k]},  trippy: {trippy_df.iloc[idx][k]}")
        # if found: 
            # break 
    # if found: 
        # break 
count 

47

In [87]:
trippy_df.head(1)

Unnamed: 0,id,taxi-leaveAt_gold,taxi-leaveAt_pred,taxi-destination_gold,taxi-destination_pred,taxi-departure_gold,taxi-departure_pred,taxi-arriveBy_gold,taxi-arriveBy_pred,restaurant-book_people_gold,...,train-leaveAt_gold,train-leaveAt_pred,train-destination_gold,train-destination_pred,train-day_gold,train-day_pred,train-arriveBy_gold,train-arriveBy_pred,train-departure_gold,train-departure_pred
0,mul0003.json-0,none,none,none,none,none,none,none,none,none,...,none,none,none,none,none,none,none,none,none,none


In [90]:
# calculate jga on csv loaded from metric_bert_dst.py 
jgas =[] 
correct_list =[] 
for idx, row in mine_df.iterrows(): 
    jga = 1 
    for slot in slots: 
        if row[f"{slot}_gold"] != row[f"{slot}_pred"]: 
            jga = 0 
    if jga == 1: 
        correct_list.append(row['id'])
    jgas.append(jga)

import numpy as np 
np.mean(jgas)


0.6121812262615302

In [79]:
data[0]['guid']

['test', 'MUL0003.json', '0']

In [80]:
import pandas as pd

In [95]:
# Get conversations that require coreference 
with open("data/MULTIWOZ2.3/data.json", "r") as f: 
    multiwoz23 = json.load(f)

In [96]:
from tqdm import tqdm 

In [99]:
need_corefs =[] 
for dial_id, dial in tqdm(multiwoz23.items()):
    context = []

    # if dial_id in ["pmul4707.json", "pmul2245.json", "pmul4776.json",
    #                 "pmul3872.json", "pmul4859.json"]:

    #     """
    #     note: these five dialogs do not contain any annotation
    #     for user side, including span_info or dialog acts
    #     """
    #     pdb.set_trace()
    #     continue

    need_coref = False
    for turn_num in range(math.ceil(len(dial["log"]) / 2)):
        # # # turn number
        turn = {"turn_num": turn_num, "dial_id": dial_id}

        # # # user utterance
        user_utt = dial["log"][turn_num * 2]["text"]
        sys_resp = dial["log"][turn_num * 2 + 1]["text"]
        # any turn that comes after requiring coreference resolution will also need coref resolution
        need_coref = "coreference" in dial['log'][turn_num * 2] or need_coref
        turn['need_coref'] = need_coref

        if need_coref: 
            need_corefs.append(f"{dial_id.lower()}-{turn_num}")


100%|██████████| 10438/10438 [00:00<00:00, 78060.21it/s]


In [101]:
len(need_corefs)

7566

In [102]:
# calculate coreference jga on csv loaded from metric_bert_dst.py 
jgas =[] 
correct_list =[] 
for idx, row in mine_df.iterrows(): 
    if row['id'] not in need_corefs: 
        continue 
    jga = 1 
    for slot in slots: 
        if row[f"{slot}_gold"] != row[f"{slot}_pred"]: 
            jga = 0 
    if jga == 1: 
        correct_list.append(row['id'])
    jgas.append(jga)

np.mean(jgas)


0.3669985775248933

In [115]:
# calculate hallucination frequency 
# 1. get decoded conversational context 
# 2. named entity slots 

named_entity_slots = {
    "attraction--name",
    "restaurant--name",
    "hotel--name",
    "taxi--departure",
    "taxi--destination",
    "train--departure",
    "train--destination",
}

In [121]:
mine_df.keys()
named_entity_slots_for_trippy ={
    "attraction-name",
    "restaurant-name",
    "hotel-name",
    "taxi-departure",
    "taxi-destination",
    "train-departure",
    "train-destination",
}

In [118]:
# get decoded conversational context
data_version = "2.3"
dials_form = {} 
for dial_id, dial in tqdm(multiwoz23.items()):
    context = []
    for turn_num in range(math.ceil(len(dial["log"]) / 2)):
        # # # turn number
        turn = {"turn_num": turn_num, "dial_id": dial_id}

        # # # user utterance
        user_utt = dial["log"][turn_num * 2]["text"]
        sys_resp = dial["log"][turn_num * 2 + 1]["text"]
        # any turn that comes after requiring coreference resolution will also need coref resolution
        need_coref = "coreference" in dial['log'][turn_num * 2] or need_coref
        turn['need_coref'] = need_coref

        # # # dialog states, extracted based on "metadata", only in system side (turn_num * 2 + 1)
        slots_inf = []
        for domain, slot in dial["log"][turn_num * 2 + 1]["metadata"].items():
            for slot_type, slot_val in slot["book"].items():
                if data_version == "2.3":
                    slot_val = [] if slot_val == "" else [slot_val]
                if (
                    slot_val != []
                    and slot_type != "booked"
                    and slot_val[0] != "not mentioned"
                ):
                    slots_inf += [domain, slot_type, slot_val[0] + ","]

            for slot_type, slot_val in slot["semi"].items():
                # 2.3 doesn't have a list of possible values. just a single value. wrap as a list
                if data_version == "2.3":
                    slot_val = [] if slot_val == "" else [slot_val]
                if slot_val != [] and slot_val[0] != "not mentioned":
                    slots_inf += [domain, slot_type, slot_val[0] + ","]

        turn["slots_inf"] = " ".join(slots_inf)
        # turn["slots_err"] = self.create_err(slots_inf[:])
        # turn["slots_err"] = ""
        # # adding current turn to dialog history
        context.append("<user> " + user_utt)
        # # # dialog history
        turn["context"] = " ".join(context)
        # adding system response to next turn
        context.append("<system> " + sys_resp)
        dials_form[dial_id.lower() + "-" + str(turn_num)] = turn

100%|██████████| 10438/10438 [00:01<00:00, 8585.57it/s]


In [119]:
# make sure that all conversatiosn in mine_df are findable in reformatted multiwoz2.3 
for k in mine_df['id'].tolist(): 
    if k not in dials_form: 
        print(f"{k} not found.")

In [138]:
no_hall_freqs = [] 
for idx, row in mine_df.iterrows(): 
    no_hall_freq = 1
    dial_id = row['id']
    context = dials_form[dial_id]['context'].lower().replace(" '", "'")
    for slot in named_entity_slots_for_trippy: 
        # pred_slot = row[f'{slot}_pred'].replace(" ' ", "'")
        pred_slot = row[f'{slot}_pred']
        
        if pred_slot != "none" and pred_slot not in context: 
            no_hall_freq = 0 
            if pred_slot in label_maps: 
                for alternative in label_maps[pred_slot]: 
                    # print(alternative)
                    if alternative.replace(" ' ", "'") in context: 
                        no_hall_freq = 1 

            if no_hall_freq == 0 : 
                print(pred_slot)
                print(context)
    if no_hall_freq == 0 : 
        break 
    no_hall_freqs.append(no_hall_freq)

np.mean(no_hall_freqs)

frankie and benny ' s
<user> hi , what options are available in the south of cambridge for upscale dining ? <system> is there a particular cuisine you are looking for ? <user> i'm not picky , just let me know a few types of cuisine that are in the area please . <system> peking restaurant as well as the good luck food takeaway serve chinese food . taj tandoori serves indian , and frankie and benny's serves italian . there is also a mexican restaurant , chiquito . <user> frankie and benny's sounds good . what is the phone number for that restaurant ?


1.0

In [133]:
label_maps["christ college"]

["christ ' s college", 'christs college']

In [136]:
label_maps["queens' college"]

KeyError: "queens' college"

In [170]:
from transformers import (WEIGHTS_NAME, BertConfig, BertTokenizer,
                          RobertaConfig, RobertaTokenizer)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
data[0].keys()

dict_keys(['guid', 'class_prediction_taxi-leaveAt', 'class_label_id_taxi-leaveAt', 'start_prediction_taxi-leaveAt', 'start_pos_taxi-leaveAt', 'end_prediction_taxi-leaveAt', 'end_pos_taxi-leaveAt', 'refer_prediction_taxi-leaveAt', 'refer_id_taxi-leaveAt', 'input_ids_taxi-leaveAt', 'class_prediction_taxi-destination', 'class_label_id_taxi-destination', 'start_prediction_taxi-destination', 'start_pos_taxi-destination', 'end_prediction_taxi-destination', 'end_pos_taxi-destination', 'refer_prediction_taxi-destination', 'refer_id_taxi-destination', 'input_ids_taxi-destination', 'class_prediction_taxi-departure', 'class_label_id_taxi-departure', 'start_prediction_taxi-departure', 'start_pos_taxi-departure', 'end_prediction_taxi-departure', 'end_pos_taxi-departure', 'refer_prediction_taxi-departure', 'refer_id_taxi-departure', 'input_ids_taxi-departure', 'class_prediction_taxi-arriveBy', 'class_label_id_taxi-arriveBy', 'start_prediction_taxi-arriveBy', 'start_pos_taxi-arriveBy', 'end_predictio

In [172]:
tokenizer.decode(data[0]['input_ids_restaurant-food'])
# print(data[0]['refer_id_restaurant-food'])


"[CLS] i'm looking for a place to stay. it needs to be a guest house and include free wifi. [SEP] [SEP] [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [P

In [183]:
# Get JGA line by line 
correct_ct = 0 
pred_dst = {}
gt_dst = {} 

pred_dst_set = set() 
gt_dst_set= set() 
dict_results = [] 
no_hall_freqs = [] 
count = 0 

for pred in data: 
   # print(sorted(pred.keys()))
   # break
   guid = pred['guid']
   # reset dsts for new conversation 
   if guid[-1] == '0': 
      pred_dst = {} 
      gt_dst ={} 
      pred_dst_set = set() 
      gt_dst_set= set() 

   # guid: ['test/valid/train', 'dialog id', turn_idx]
   dialog_id, turn_idx = guid[1], guid[2]
   dict_result = {
      "id": f"{dialog_id.lower()}-{turn_idx}"
   } 

   input_context = tokenizer.decode(pred['input_ids_restaurant-food'])

   no_hall_freq = 1
   context = dials_form[dict_result["id"]]['context'].lower().replace(" '", "'")
   for slot in named_entity_slots_for_trippy: 
      joint_pd_slot = pred_dst.get(slot, None)

      joint_gt_slot, joint_pd_slot = get_slot_vals(slot, pred, joint_pd_slot, swap_with_map=False)
      pred_dst[slot] = joint_pd_slot 
      gt_dst[slot] = joint_gt_slot  

      dict_result[f"{slot}_gold"]= joint_gt_slot
      dict_result[f"{slot}_pred"] = joint_pd_slot

      def preprocess(slot): 
         slot = slot.replace(" ' s", "'s")
         slot = slot.replace(" '", "'")
         slot = slot.replace("guest house", "guesthouse")
         slot = slot.replace(" ", "")
         slot = slot.replace("[SEP]", "")
         return slot 
      if joint_pd_slot != "none" \
         and joint_pd_slot not in context \
         and preprocess(joint_pd_slot) not in context \
         and preprocess(joint_pd_slot) not in preprocess(context) \
         and preprocess(joint_pd_slot) not in input_context \
         and preprocess(joint_pd_slot) not in preprocess(input_context): 
         no_hall_freq = 0 
         print(joint_pd_slot)
         print(preprocess(joint_pd_slot))
         print(context)
         print(preprocess(input_context))
         # print(context)

         count += 1 
      # elif joint_pd_slot != "none" and joint_pd_slot in context: 
      #    print(joint_pd_slot)
      #    print(context)
      #    no_hall_freq = 0 
      
      if joint_pd_slot in label_maps: 
         for alternative in label_maps[joint_pd_slot]: 
            # print(alternative)
            if alternative in context: 
               no_hall_freq = 1 

      # if no_hall_freq == 0 : 
      #    print(joint_pd_slot)
      #    print(context)
   no_hall_freqs.append(no_hall_freq)    

   # if no_hall_freq == 0 : 
   #    break 

# 100% is expected
np.mean(no_hall_freqs), count 

(1.0, 0)

In [164]:
"museum of archaelogy and anthropology" == "museum of archaeology and anthropology"

False

In [1]:
from get_conditional_jga import load_pred_as_df
from get_coref_jga import format_data_as_df

In [2]:
import os

In [3]:
dir = "results/multiwoz23_lr1e-4_gpu2"

orig_fn = os.path.join(dir, "pred_res.test.final.json")

tp_fn = os.path.join(dir, f"pred_res.test.finalTP.json")
sd_fn = os.path.join(dir, f"pred_res.test.finalSD.json")
nei_fn = os.path.join(dir, f"pred_res.test.finalNEI.json")

orig_df = load_pred_as_df(orig_fn)
tp_df = load_pred_as_df(tp_fn)
sd_df = load_pred_as_df(sd_fn)
nei_df = load_pred_as_df(nei_fn)

In [4]:
orig_ids = orig_df['id'].tolist() 
sd_ids = sd_df['id'].tolist() 
tp_ids = tp_df['id'].tolist() 
nei_ids = nei_df['id'].tolist()

In [11]:
set(orig_ids) == set(sd_ids), tp_ids == orig_ids, nei_ids == orig_ids

(False, False, True)

In [6]:
tp_ids  == sd_ids

True

In [10]:
sorted(tp_ids[:5]), sorted(orig_ids[:5])

(['pmul4648-0', 'sng0073-0', 'sng0073-1', 'sng0073-2', 'sng0073-3'],
 ['mul0003.json-0',
  'mul0003.json-1',
  'mul0003.json-2',
  'mul0003.json-3',
  'mul0003.json-4'])

In [13]:
len(orig_ids)

7372

In [12]:
len(tp_ids)

7372

In [8]:
set(orig_ids).intersection(set(sd_ids))

set()