### **Toward Consistent, Verifiable, and Coherent Commonsense Reasoning in Large LMs**

This notebook provides source code for our two papers in Findings of EMNLP 2021:


1.  Shane Storks, Qiaozi Gao, Yichi Zhang, and Joyce Y. Chai (2021). *Tiered Reasoning for Intuitive Physics: Toward Verifiable Commonsense Language Understanding.* Findings of EMNLP 2021.
2.   Shane Storks and Joyce Y. Chai (2021). *Beyond the Tip of the Iceberg: Assessing Coherence of Text Classifiers.* Findings of EMNLP 2021.

*If you have any questions or problems, please open an issue on our [GitHub repo](https://github.com/sled-group/Verifiable-Coherent-NLU) or email Shane Storks.*

***First, configure the execution mode by selecting a few settings (expand cell if needed):***




   0. (Colab only) Insert the path in your Google Drive to the folder where this notebook is located.

In [1]:
DRIVE_PATH = '.'

1.   Model type (choose from BERT large, RoBERTa large, RoBERTa large + MNLI, DeBERTa base, and DeBERTa large).






In [2]:
# mode = 'bert' # BERT large
# mode = 'roberta' # RoBERTa large
# mode = 'roberta_mnli' # RoBERTa large pre-trained on MNLI
# mode = 'deberta' # DeBERTa base for training on TRIP
# mode = 'deberta_large' # DeBERTa large for training on CE and ART

mode = 'roberta'

2.   Name of the task we want to train or evaluate on. Set `debug` to `True` to run quick training/evaluation jobs on only a small amount of data.

In [3]:
task_name = 'trip'
# task_name = 'ce'
# task_name = 'art'

debug = False

3.   (If training models) Training batch size, learning rate, and maximum number of epochs. Settings for results in the paper are provided as examples.

In [4]:
config_batch_size = 1
config_lr = 1e-5 # Selected learning rate for best RoBERTa-based model in TRIP paper
config_epochs = 10

4.   (For training TRIP models only) Configure the loss weighting scheme for training models here. We provide the 4 modes from the paper as examples.


In [5]:
# Loss weights for (attributes, preconditions, effects, conflicts, story choices)
if task_name != 'trip':
  print("We do not need a loss weighting scheme for %s dataset. Ignoring this cell." % task_name)
# loss_weights = [0.0, 0.4, 0.4, 0.1, 0.1] # "All losses"
loss_weights = [0.0, 0.4, 0.4, 0.2, 0.0] # "Omit story choice loss"
# loss_weights = [0.0, 0.4, 0.4, 0.0, 0.2] # "Omit conflict detection loss"
# loss_weights = [0.0, 0.0, 0.0, 0.5, 0.5] # "Omit state classification losses"
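
To make the weighting concrete, the sketch below shows one way five per-task losses could be combined with `loss_weights`. It assumes the total loss is a plain weighted sum. The actual combination happens inside the `www` code base, which is not shown here, and the `combine_losses` helper and the example loss values are hypothetical.

```python
# Minimal sketch (assumption: total loss = weighted sum of per-task losses).
# `combine_losses` and the example loss values are hypothetical, for illustration only.

def combine_losses(task_losses, weights):
    """Return the weighted sum of per-task losses."""
    if len(task_losses) != len(weights):
        raise ValueError("Need exactly one weight per loss term.")
    return sum(w * l for w, l in zip(weights, task_losses))

# Order: (attributes, preconditions, effects, conflicts, story choices)
example_losses = [1.0, 2.0, 3.0, 4.0, 5.0]
omit_story_choice = [0.0, 0.4, 0.4, 0.2, 0.0]  # "Omit story choice loss" setting

total = combine_losses(example_losses, omit_story_choice)
print(total)
```

A weight of 0.0 removes a term from the objective entirely, which is how the "omit" settings above switch off individual losses.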

   5. (If evaluating models) Provide the name of the pre-trained model directory here. This should be the name of a directory within the *saved_models* directory, which should be located where this notebook is. Names of provided pre-trained model directories are listed.

In [6]:
# TRIP, all losses
# eval_model_dir = 'bert-large-uncased_cloze_1_5e-06_4_0.0-0.4-0.4-0.1-0.1_tiered_pipeline_ablate_attributes_states-logits'
# eval_model_dir = 'roberta-large_cloze_1_1e-05_7_0.0-0.4-0.4-0.1-0.1_tiered_pipeline_ablate_attributes_states-logits'
# eval_model_dir = 'microsoft-deberta-base_cloze_1_5e-06_5_0.0-0.4-0.4-0.1-0.1_tiered_pipeline_ablate_attributes_states-logits'

# TRIP, no story classification loss
# eval_model_dir = 'bert-large-uncased_cloze_1_5e-05_8_0.0-0.4-0.4-0.2-0.0_tiered_pipeline_ablate_attributes_states-logits'
# eval_model_dir = 'roberta-large_cloze_1_1e-05_5_0.0-0.4-0.4-0.2-0.0_tiered_pipeline_lc_ablate_attributes_states-logits' # Best model trained in the TRIP paper
# eval_model_dir = 'microsoft-deberta-base_cloze_1_5e-05_5_0.0-0.4-0.4-0.2-0.0_tiered_pipeline_ablate_attributes_states-logits'

# eval_model_dir = 'google-electra-large-discriminator_cloze_1_1e-05_6_0.0-0.4-0.4-0.2-0.0_tiered_pipeline_lc_ablate_attributes_states-logits'


# TRIP, no conflict detection loss
# eval_model_dir = 'bert-large-uncased_cloze_1_1e-06_1_0.0-0.4-0.4-0.0-0.2_tiered_pipeline_ablate_attributes_states-logits'
# eval_model_dir = 'roberta-large_cloze_1_5e-06_8_0.0-0.4-0.4-0.0-0.2_tiered_pipeline_ablate_attributes_states-logits'
# eval_model_dir = 'microsoft-deberta-base_cloze_1_1e-06_3_0.0-0.4-0.4-0.0-0.2_tiered_pipeline_ablate_attributes_states-logits'

# TRIP, no physical state classification loss
# eval_model_dir = 'bert-large-uncased_cloze_1_1e-05_3_0.0-0.0-0.0-0.5-0.5_tiered_pipeline_ablate_attributes_states-logits'
# eval_model_dir = 'roberta-large_cloze_1_1e-06_7_0.0-0.0-0.0-0.5-0.5_tiered_pipeline_ablate_attributes_states-logits'
# eval_model_dir = 'microsoft-deberta-base_cloze_1_5e-06_9_0.0-0.0-0.0-0.5-0.5_tiered_pipeline_ablate_attributes_states-logits'

# CE
# eval_model_dir = 'bert-large-uncased_ConvEnt_32_7.5e-06_7_xval'
# eval_model_dir = 'roberta-large_ConvEnt_32_7.5e-06_9_xval'
# eval_model_dir = 'roberta-large-mnli_ConvEnt_32_7.5e-06_7_xval'
# eval_model_dir = 'microsoft-deberta-large_ConvEnt_16_1e-05_9_xval'

# ART
# eval_model_dir = 'bert-large-uncased_art_64_5e-06_8'
# eval_model_dir = 'roberta-large_art_64_2.5e-06_4'
# eval_model_dir = 'DeBERTa-deberta-large_art_32_1e-06_8'
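
The TRIP directory names above appear to pack the hyperparameters into underscore-separated fields (model, sub-task, batch size, learning rate, epoch, loss weights). The sketch below pulls those fields back out. The field order is inferred from the listed names, and `parse_trip_model_dir` is a hypothetical helper, not part of the `www` code base.

```python
# Hypothetical helper: split a TRIP model directory name into its leading fields.
# Assumes the field order model_subtask_batchsize_lr_epoch_lossweights_...

def parse_trip_model_dir(name):
    parts = name.split('_')
    return {
        'model': parts[0],
        'subtask': parts[1],
        'batch_size': int(parts[2]),
        'learning_rate': float(parts[3]),
        'epoch': int(parts[4]),
        # Loss weights are joined with hyphens, e.g. "0.0-0.4-0.4-0.2-0.0"
        'loss_weights': [float(w) for w in parts[5].split('-')],
    }

info = parse_trip_model_dir(
    'roberta-large_cloze_1_1e-05_5_0.0-0.4-0.4-0.2-0.0'
    '_tiered_pipeline_lc_ablate_attributes_states-logits'
)
print(info)
```

This only applies to the TRIP names; the CE and ART directory names listed above have fewer fields.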

**For more configuration options, scroll down to the Train Models > Configure Hyperparameters cell for the task you're working on.**

# Setup
Run this block every time you start the notebook. It gets Colab ready, preprocesses the data, and loads the model packages and classes needed later. The first run may take several minutes.

**If you get a `ModuleNotFoundError` for the `www` code base, try the following:**


1.   Ensure that `DRIVE_PATH` is set properly above.
2.   (Colab only) Verify that this notebook has access to your Google Drive (click the folder icon on the left and then the Google Drive icon).
3.   Try to restart the runtime and refresh your browser window.
4.   (Colab only) If the problem persists, revoke access to Google Drive and re-enable it.





## Colab Setup

Enable auto reloading of code libraries from Google Drive, set up connection to Google Drive, and import some packages. 🔌

In [7]:
%load_ext autoreload
%autoreload 2

In [8]:
import os
import json
import sys
import torch
import random
import numpy as np
import spacy
!pip install jsonlines

sys.path.append(DRIVE_PATH)


## Model Setup

Next, we'll load up the transformer model, tokenizer, etc. ⏳

### Install HuggingFace transformers and other dependencies

In [9]:
!pip install 'transformers==4.2.2'
!pip install sentencepiece
!pip3 install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2
!pip install deberta


### Get Model Components

Specify which model parameters from transformers we want to use:

In [10]:
if task_name in ['trip', 'ce']:
  multiple_choice = False
elif task_name == 'art':
  multiple_choice = True
else:
  raise ValueError("Task name should be set to 'trip', 'ce', or 'art' in the first cell of the notebook!")

if mode == 'bert':
  model_name = 'bert-large-uncased'
elif mode == 'roberta':
  model_name = 'roberta-large'
elif mode == 'roberta_mnli':
  model_name = 'roberta-large-mnli'
elif mode == 'deberta':
  model_name = 'microsoft/deberta-base'
elif mode == 'deberta_large':
  model_name = 'microsoft/deberta-large'
elif mode == 'electra':
  model_name = 'google/electra-large-discriminator'
else:
  # Fail early instead of hitting an undefined `model_name` later on
  raise ValueError("Unknown mode '%s'! Check the model type cell at the top of the notebook." % mode)

Load the tokenizer:

In [11]:
from transformers import BertTokenizer, RobertaTokenizer, DebertaTokenizer, AlbertTokenizer, T5Tokenizer, GPT2Tokenizer
from transformers import AutoTokenizer
from DeBERTa import deberta

if mode in ['bert']:
  tokenizer_class = BertTokenizer
elif mode in ['roberta', 'roberta_mnli']:
  tokenizer_class = RobertaTokenizer
elif mode in ['deberta', 'deberta_large']:
  tokenizer_class = DebertaTokenizer


if mode not in ['electra']:
  tokenizer = tokenizer_class.from_pretrained(model_name, 
                                                do_lower_case = False, 
                                                cache_dir=os.path.join(DRIVE_PATH, 'cache'))
else:
  tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case = False, 
                                                cache_dir=os.path.join(DRIVE_PATH, 'cache'))

Load the model and optimizer:



In [12]:
from transformers import BertForSequenceClassification, RobertaForSequenceClassification, DebertaForSequenceClassification, AlbertForSequenceClassification, AdamW
from transformers import BertForMultipleChoice, RobertaForMultipleChoice, AlbertForMultipleChoice
from transformers import BertModel, RobertaModel, AlbertModel, DebertaModel, T5Model, T5EncoderModel, GPT2Model
from transformers import RobertaForMaskedLM
from transformers import BertConfig, RobertaConfig, DebertaConfig, AlbertConfig, T5Config, GPT2Config
from transformers import ElectraForSequenceClassification, ElectraConfig, ElectraModel
from www.model.transformers_ext import DebertaForMultipleChoice
from torch.optim import Adam
if not multiple_choice:
  if mode == 'bert':
    model_class = BertForSequenceClassification
    config_class = BertConfig
    emb_class = BertModel
  elif mode in ['roberta', 'roberta_mnli']:
    model_class = RobertaForSequenceClassification
    config_class = RobertaConfig
    emb_class = RobertaModel
    lm_class = RobertaForMaskedLM
  elif mode in ['deberta', 'deberta_large']:
    model_class = DebertaForSequenceClassification
    config_class = DebertaConfig
    emb_class = DebertaModel
  elif mode in ['electra']:
    model_class = ElectraForSequenceClassification
    config_class = ElectraConfig
    emb_class = ElectraModel
else:
  if mode == 'bert':
    model_class = BertForMultipleChoice
    config_class = BertConfig
    emb_class = BertModel    
  elif mode in ['roberta', 'roberta_mnli']:
    model_class = RobertaForMultipleChoice
    config_class = RobertaConfig
    emb_class = RobertaModel
    lm_class = RobertaForMaskedLM
  elif mode in ['deberta', 'deberta_large']:
    model_class = DebertaForMultipleChoice
    config_class = DebertaConfig
    emb_class = DebertaModel

## Data Setup

Preprocess the dataset.

### Preprocessing

Construct the dataset from the .txt files collected from AMT. Save a backup copy in Drive.

In [13]:
from www.utils import print_dict

partitions = ['train', 'dev', 'test']
subtasks = ['cloze', 'order']

# We can split the data into multiple json files later
data_file = os.path.join(DRIVE_PATH, 'all_data/www.json')
with open(data_file, 'r') as f:
  dataset = json.load(f)

print('Preprocessed examples:')
for ex_idx in [0,1,5,10]:
  ex = dataset['dev'][list(dataset['dev'].keys())[ex_idx]]
  print_dict(ex)

Preprocessed examples:
{
  story_id: 
    13,
  worker_id: 
    A32W24TWSWXW,
  type: 
    None,
  idx: 
    None,
  aug: 
    False,
  actor: 
    John,
  location: 
    kitchen,
  objects: 
    cabinet, counter, knife, pan, potato, pizza,
  sentences: 
    [
      John was getting the snacks ready for the party.
      John opened the cabinet, took out a pan and put it on the counter.
      John opened the fridge and got out the pizza.
      John put the pizza on the pan and put them into the oven.
      John took a knife and cut the hot pizza in eight slices.
    ],
  length: 
    5,
  example_id: 
    13,
  plausible: 
    True,
  breakpoint: 
    -1,
  confl_sents: 
    [],
  confl_pairs: 
    [],
  states: 
    [
      {'h_location': [['John', 0]], 'conscious': [['John', 2]], 'wearing': [['John', 0]], 'h_wet': [['John', 0]], 'hygiene': [['John', 0]], 'location': [['snacks', 0], ['party', 0]], 'exist': [['snacks', 4], ['party', 2]], 'clean': [['snacks', 0], ['party', 0]], 'power': 

### Data Filtering and Sampling
Since there is a big imbalance between plausible/implausible class labels, we will upsample the plausible stories.

For now, we will also break the dataset into two sub-datasets: cloze and ordering.



In [14]:
cloze_dataset = {p: [] for p in dataset}
order_dataset = {p: [] for p in dataset}

for p in dataset:
  for exid in dataset[p]:
    ex = dataset[p][exid]

    if ex['type'] is None:
      continue
    
    ex_plaus = dataset[p][str(ex['story_id'])]

    if ex['type'] == 'cloze':
      cloze_dataset[p].append(ex)
      cloze_dataset[p].append(ex_plaus) # For every implausible story, add a copy of its corresponding plausible story

    # Exclude augmented ordering examples from dev and test, since the breakpoints aren't always accurate in those
    elif ex['type'] == 'order' and not (p != 'train' and ex['aug']): 
      order_dataset[p].append(ex)
      order_dataset[p].append(ex_plaus)
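
As a toy illustration of the pairing above, the sketch below appends the matching plausible story after every implausible one, which balances the two classes. The miniature `toy_partition` dictionary is invented for illustration and only mimics the fields used in the cell above.

```python
# Toy data (hypothetical): one plausible story and two implausible cloze variants of it.
toy_partition = {
    '13': {'story_id': 13, 'type': None, 'plausible': True},
    '13-C0': {'story_id': 13, 'type': 'cloze', 'plausible': False},
    '13-C1': {'story_id': 13, 'type': 'cloze', 'plausible': False},
}

toy_cloze = []
for ex in toy_partition.values():
    if ex['type'] is None:
        continue  # Plausible originals are only added as partners below
    ex_plaus = toy_partition[str(ex['story_id'])]
    if ex['type'] == 'cloze':
        toy_cloze.append(ex)
        toy_cloze.append(ex_plaus)  # One plausible copy per implausible story

num_plausible = sum(ex['plausible'] for ex in toy_cloze)
num_implausible = len(toy_cloze) - num_plausible
print(num_plausible, num_implausible)
```

Because each implausible story brings along one copy of its plausible source, the two labels end up equally frequent regardless of how many implausible variants a story has.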



### Convert TRIP to Two-Story Classification Task

Ready the TRIP dataset for two-story classification.

In [15]:
# old load

# from www.utils import print_dict
# import json
# from collections import Counter

# data_file = os.path.join(DRIVE_PATH, 'all_data/www_2s_new.json')
# with open(data_file, 'r') as f:
#   cloze_dataset_2s, order_dataset_2s = json.load(f)  

# for p in cloze_dataset_2s:
#   label_dist = Counter([ex['label'] for ex in cloze_dataset_2s[p]])
#   print('Cloze label distribution (%s):' % p)
#   print(label_dist.most_common())
# print_dict(cloze_dataset_2s['train'][0])

In [16]:
# new load


from www.utils import print_dict
import json
from collections import Counter

data_file = os.path.join(DRIVE_PATH, 'all_data/www_2s_new.json')
data_cloze = os.path.join(DRIVE_PATH, 'aug_cloze.json')
data_order = os.path.join(DRIVE_PATH, 'aug_order.json')
with open(data_file, 'r') as f:
  cloze_dataset_2s_origin, order_dataset_2s_origin = json.load(f)  

with open(data_cloze, 'r') as f:
  cloze_dataset_2s = json.load(f)

with open(data_order, 'r') as f:
  order_dataset_2s = json.load(f)


cloze_dataset_2s['dev'] = cloze_dataset_2s_origin['dev']
cloze_dataset_2s['test'] = cloze_dataset_2s_origin['test']
order_dataset_2s['dev'] = order_dataset_2s_origin['dev']
order_dataset_2s['test'] = order_dataset_2s_origin['test']

for p in cloze_dataset_2s:
  label_dist = Counter([ex['label'] for ex in cloze_dataset_2s[p]])
  print('Cloze label distribution (%s):' % p)
  print(label_dist.most_common())
print_dict(cloze_dataset_2s['train'][0])

Cloze label distribution (train):
[(0, 2043), (1, 2032)]
Cloze label distribution (dev):
[(0, 161), (1, 161)]
Cloze label distribution (test):
[(1, 176), (0, 175)]
{
  example_id: 
    0-C0,
  stories: 
    [
      {'story_id': 0, 'worker_id': 'A1F01FVEPYCPHO', 'type': 'cloze', 'idx': 0, 'aug': False, 'actor': 'Tom', 'location': 'kitchen', 'objects': 'dustbin, microwave, pan, plate, cereal, soup', 'sentences': ['Tom bought a new dustbin for the kitchen.', 'Tom threw a broken plate in the dustbin.', 'Tom got some soup from the fridge.', 'Tom put the soup in the microwave.', 'Tom ate the cold soup.'], 'length': 5, 'example_id': '0-C0', 'plausible': False, 'breakpoint': 4, 'confl_sents': [3], 'confl_pairs': [[3, 4]], 'states': [{'h_location': [['Tom', 0]], 'conscious': [['Tom', 2]], 'wearing': [['Tom', 0]], 'h_wet': [['Tom', 0]], 'hygiene': [['Tom', 0]], 'location': [['dustbin', 6]], 'exist': [['dustbin', 4]], 'clean': [['dustbin', 0]], 'power': [['dustbin', 0]], 'functional': [['dustbin'

---

# TRIP Results

Contains code for the tiered and random TRIP baselines.

In [17]:
if task_name != 'trip':
  raise ValueError('Please configure task_name in first cell to "trip" to run TRIP results!')

## Random Tiered Classifier for TRIP

For the random baseline, we average the results of 10 runs. Running the cells below reports (mean, standard deviation) for each evaluation partition.
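
The spread in the baseline cell below is computed as `np.var(...) ** 0.5`, which is the population standard deviation rather than the variance. The sketch below reproduces that summary with the standard library on made-up per-run scores; `summarize_runs` and the scores are hypothetical.

```python
import statistics

def summarize_runs(scores):
    """Return (mean, population standard deviation) of per-run scores."""
    mean = statistics.fmean(scores)
    variance = statistics.pvariance(scores)  # Population variance, like np.var
    return mean, variance ** 0.5

run_scores = [0.48, 0.52, 0.50, 0.50]  # Hypothetical story accuracies from 4 runs
mean, spread = summarize_runs(run_scores)
print(mean, spread)
```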

In [18]:
from www.dataset.prepro import get_tiered_data
from www.dataset.featurize import add_bert_features_tiered, get_tensor_dataset_tiered
from collections import Counter
import numpy as np
from www.dataset.ann import att_to_num_classes, idx_to_att
from sklearn.metrics import accuracy_score, f1_score
from www.utils import print_dict

tiered_dataset = cloze_dataset_2s

seq_length = 16 # Max sequence length to pad to

tiered_dataset = get_tiered_data(tiered_dataset)
tiered_dataset = add_bert_features_tiered(tiered_dataset, tokenizer, seq_length, add_segment_ids=True)



In [19]:
from www.dataset.prepro import get_tiered_data, balance_labels
from www.dataset.featurize import add_bert_features_tiered, get_tensor_dataset_tiered
from collections import Counter
import numpy as np
from www.dataset.ann import att_to_num_classes, idx_to_att, att_default_values
from sklearn.metrics import accuracy_score, f1_score
from www.utils import print_dict

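# NOTE: the lists below are created once, outside both loops, so labels and predictions
# accumulate across runs and partitions; each run's metrics cover everything collected so far.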
num_runs = 10
stories = []
pred_stories = []
conflicts = []
pred_conflicts = []
preconditions = []
pred_preconditions = []
effects = []
pred_effects = []
verifiability = []
consistency = []
for p in tiered_dataset:
  if p == 'train':
    continue
  metr_avg = {}
  print('starting %s...' % p)
  for r in range(num_runs):
    print('starting run %s...' % str(r))
    for ex in tiered_dataset[p]:
      verifiable = True
      consistent = True

      stories.append(ex['label'])
      pred_stories.append(np.random.randint(2))

      if stories[-1] != pred_stories[-1]:
        verifiable = False

      labels_ex_p = []
      preds_ex_p = []

      labels_ex_e = []
      preds_ex_e = []

      labels_ex_c = []
      preds_ex_c = []

      for si, story in enumerate(ex['stories']):
        labels_story_p = []
        preds_story_p = []

        labels_story_e = []
        preds_story_e = []      

        for ent_ann in story['entities']:
          entity = ent_ann['entity']

          if si == 1 - ex['label']:
            labels_ex_c.append(ent_ann['conflict_span_onehot'])
            pred = np.zeros(ent_ann['conflict_span_onehot'].shape)
            for cs in np.random.choice(len(pred), size=2, replace=False):
              pred[cs] = 1
            preds_ex_c.append(pred)

          labels_ent = []
          preds_ent = []
          for s, sent_ann in enumerate(ent_ann['preconditions']):
            if s < len(story['sentences']):
              if entity in story['sentences'][s]:

                labels_ent.append(sent_ann)
                sent_ann_pred = []
                for i, l in enumerate(sent_ann):
                  pl = np.random.randint(att_to_num_classes[idx_to_att[i]])
                  if pl > 0 and pl != att_default_values[idx_to_att[i]]:
                    if pl != l:
                      verifiable = False
                  sent_ann_pred.append(pl)
                preds_ent.append(sent_ann_pred)

          labels_story_p.append(labels_ent)
          preds_story_p.append(preds_ent)

          labels_ent = []
          preds_ent = []
          for s, sent_ann in enumerate(ent_ann['effects']):
            if s < len(story['sentences']):
              if entity in story['sentences'][s]:
    
                labels_ent.append(sent_ann)
                sent_ann_pred = []
                for i, l in enumerate(sent_ann):
                  pl = np.random.randint(att_to_num_classes[idx_to_att[i]])
                  if pl > 0 and pl != att_default_values[idx_to_att[i]]:
                    if pl != l:
                      verifiable = False
                  sent_ann_pred.append(pl)
                preds_ent.append(sent_ann_pred)

          labels_story_e.append(labels_ent)
          preds_story_e.append(preds_ent)

        labels_ex_p.append(labels_story_p)
        preds_ex_p.append(preds_story_p)

        labels_ex_e.append(labels_story_e)
        preds_ex_e.append(preds_story_e)

      conflicts.append(labels_ex_c)
      pred_conflicts.append(preds_ex_c)

      preconditions.append(labels_ex_p)
      pred_preconditions.append(preds_ex_p)

      effects.append(labels_ex_e)
      pred_effects.append(preds_ex_e)

      p_confl = np.nonzero(np.sum(np.array(preds_ex_c), axis=0))[0]
      l_confl = np.nonzero(np.sum(np.array(labels_ex_c), axis=0))[0]
      assert len(l_confl) == 2, str(labels_ex_c)
      if not (p_confl[0] == l_confl[0] and p_confl[1] == l_confl[1]):
        verifiable = False    
        consistent = False

      verifiability.append(1 if verifiable else 0)
      consistency.append(1 if consistent else 0)

    # Compute metrics
    metr = {}
    metr['story_accuracy'] = accuracy_score(stories, pred_stories)

    conflicts_flat = [c for c_ex in conflicts for c_ent in c_ex for c in c_ent]
    pred_conflicts_flat = [c for c_ex in pred_conflicts for c_ent in c_ex for c in c_ent]
    metr['confl_f1'] = f1_score(conflicts_flat, pred_conflicts_flat, average='macro')

    preconditions_flat = [p for p_ex in preconditions for p_story in p_ex for p_sent in p_story for p_ent in p_sent for p in p_ent]
    pred_preconditions_flat = [p for p_ex in pred_preconditions for p_story in p_ex for p_sent in p_story for p_ent in p_sent for p in p_ent]
    metr['precondition_f1'] = f1_score(preconditions_flat, pred_preconditions_flat, average='macro')

    effects_flat = [p for p_ex in effects for p_story in p_ex for p_sent in p_story for p_ent in p_sent for p in p_ent]
    pred_effects_flat = [p for p_ex in pred_effects for p_story in p_ex for p_sent in p_story for p_ent in p_sent for p in p_ent]
    metr['effect_f1'] = f1_score(effects_flat, pred_effects_flat, average='macro')

    metr['verifiability'] = np.mean(verifiability)
    metr['consistency'] = np.mean(consistency)

    for k in metr:
      if k not in metr_avg:
        metr_avg[k] = []
      metr_avg[k].append(metr[k])

  for k in metr_avg:
    metr_avg[k] = (np.mean(metr_avg[k]), np.std(metr_avg[k])) # (mean, standard deviation) over runs
  print('RANDOM BASELINE (%s, %s runs)' % (str(p), str(num_runs)))
  print_dict(metr_avg)

starting dev...
starting run 0...
starting run 1...
starting run 2...
starting run 3...
starting run 4...
starting run 5...
starting run 6...
starting run 7...
starting run 8...
starting run 9...
RANDOM BASELINE (dev, 10 runs)
{
  story_accuracy: 
    (0.4963973922902495, 0.010229298100185335),
  confl_f1: 
    (0.48455594481671227, 0.0008871871462012627),
  precondition_f1: 
    (0.040229947186837825, 7.253229846940695e-05),
  effect_f1: 
    (0.040330445095181036, 0.00014750079110504094),
  verifiability: 
    (0.0, 0.0),
  consistency: 
    (0.11687148772552498, 0.0028343896293991067),
}


starting test...
starting run 0...
starting run 1...
starting run 2...
starting run 3...
starting run 4...
starting run 5...
starting run 6...
starting run 7...
starting run 8...
starting run 9...
RANDOM BASELINE (test, 10 runs)
{
  story_accuracy: 
    (0.5051997625768491, 0.001184504029598542),
  confl_f1: 
    (0.48478595705412203, 0.00034789292746123193),
  precondition_f1: 
    (0.04007996092
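
The per-example bookkeeping in the random baseline cell above can be condensed as follows. This is a simplified sketch of those checks, not the evaluation code in `www.model.eval`: in that cell, an example counts as consistent only if the predicted conflicting sentence pair matches the labeled pair, and as verifiable only if, in addition, the story choice is correct and every predicted non-default state matches its label. `is_consistent`, `is_verifiable`, the single shared `default_value`, and the toy inputs are all hypothetical.

```python
def is_consistent(pred_pair, label_pair):
    """The predicted conflicting sentence pair must equal the labeled pair."""
    return sorted(pred_pair) == sorted(label_pair)

def is_verifiable(story_correct, pred_pair, label_pair,
                  pred_states, label_states, default_value=0):
    """Correct story, correct conflict pair, and no wrong non-default state."""
    if not story_correct or not is_consistent(pred_pair, label_pair):
        return False
    for pred, label in zip(pred_states, label_states):
        # Only predictions that are non-zero and non-default are checked
        nontrivial = pred > 0 and pred != default_value
        if nontrivial and pred != label:
            return False
    return True

print(is_verifiable(True, (3, 4), (3, 4), [0, 2, 1], [1, 2, 1]))
print(is_verifiable(True, (3, 4), (3, 4), [2], [1]))
```

The sketch uses one default value for all attributes to stay short; the baseline cell looks up a per-attribute default in `att_default_values`.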

## Transformer-Based Tiered Classifier for TRIP

This is the baseline model presented in the paper. Depending on the settings above, the cells below can be used to train and evaluate models.


### Featurization for Tiered Classification

Get the data ready for input to the model.

In [20]:
from www.dataset.prepro import get_tiered_data, balance_labels
from www.dataset.featurize import add_bert_features_tiered, get_tensor_dataset_tiered
from collections import Counter

tiered_dataset = cloze_dataset_2s

# Debug the code on a small amount of data
if debug:
  for k in tiered_dataset:
    tiered_dataset[k] = tiered_dataset[k][:20]

# train_spans = True
train_spans = False
if train_spans:
  tiered_dataset = get_story_spans_2s(tiered_dataset, train_only=True)
  tiered_dataset['train'] = [ex for ex in tiered_dataset['train'] if ex['label'] != -1] # For now, ignore examples where both stories are plausible :(

seq_length = 80 # Max sequence length to pad to

tiered_dataset = get_tiered_data(tiered_dataset)
tiered_dataset = add_bert_features_tiered(tiered_dataset, tokenizer, seq_length, add_segment_ids=True)

tiered_tensor_dataset = {}
max_story_length = max([len(ex['stories'][0]['sentences']) for p in tiered_dataset for ex in tiered_dataset[p]])
for p in tiered_dataset:
  tiered_tensor_dataset[p] = get_tensor_dataset_tiered(tiered_dataset[p], max_story_length, add_segment_ids=True)
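
`seq_length` sets the fixed length that inputs are padded to. The featurization itself lives in `www.dataset.featurize` and is not shown here; as a rough sketch of what fixed-length padding usually involves, the hypothetical `pad_to_length` helper below truncates or pads a list of token IDs and builds the matching attention mask. The pad ID of 0 is an assumption; the real value depends on the tokenizer.

```python
def pad_to_length(token_ids, seq_length, pad_id=0):
    """Truncate or right-pad token IDs to seq_length and build an attention mask."""
    ids = list(token_ids[:seq_length])
    mask = [1] * len(ids)            # 1 marks real tokens
    padding = seq_length - len(ids)
    return ids + [pad_id] * padding, mask + [0] * padding  # 0 marks padding

ids, mask = pad_to_length([5, 6, 7], seq_length=5)
print(ids, mask)
```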



### Train Models

#### Configure Hyperparameters
We perform a grid search over (batch size, learning rate) pairs. Configure the training sub-task and the search space, and set the maximum number of training epochs here. The cell is currently configured to re-train the best RoBERTa-based model instance; see the code comments for details.

**Additional configuration options:**
* Change the `generate_learning_curve` variable to `True` to generate data for training curves in the style presented in the paper.
* You may ablate the input to the Conflict Detector based on a few pre-defined ablation modes. To do so, change the `ablation` variable based on the comments in the code.

In [21]:
from www.dataset.ann import att_to_idx, att_to_num_classes, att_types

subtask = 'cloze'
batch_sizes = [config_batch_size]
learning_rates = [config_lr]
# learning_rates = [1e-3, 1e-4, 1e-5, 1e-6]
epochs = config_epochs
eval_batch_size = 16
generate_learning_curve = False # Generate data for training curve figure in TRIP paper

num_state_labels = {}
for att in att_to_idx:
  if att_types[att] == 'default':
    num_state_labels[att_to_idx[att]] = 3
  else:
    num_state_labels[att_to_idx[att]] = att_to_num_classes[att] # Location attributes fall into this since they don't have well-defined pre- and post-conditions yet

# Ablation options:
# - attributes: skip attribute prediction phase
# - embeddings: DON'T input contextual embeddings to conflict detector
# - states: DON'T input states to conflict detector
# - states-labels: in states input to conflict detector, include predicted labels
# - states-logits: in states input to conflict detector, include state logits (preferred)
# - states-teacher-forcing: train conflict detector on ground truth state labels (not predictions)
# - states-attention: re-weight input to conflict detector with weights conditioned on states representation
ablation = ['attributes', 'states-logits'] # This is the default mode presented in the paper
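
The search configured here can be pictured as a loop over every (batch size, learning rate) pair that keeps the best validation score. The sketch below uses a made-up `toy_objective` in place of real training, so its numbers carry no meaning; only the control flow is relevant. `grid_search` is likewise a hypothetical helper.

```python
from itertools import product

def toy_objective(batch_size, learning_rate):
    # Hypothetical stand-in for "train, then measure validation verifiability"
    return 1.0 / (1.0 + abs(learning_rate - 1e-5) * 1e5) - 0.01 * batch_size

def grid_search(batch_sizes, learning_rates, objective):
    """Try every combination and return the best one with its score."""
    best_score, best_combo = float('-inf'), None
    for bs, lr in product(batch_sizes, learning_rates):
        score = objective(bs, lr)
        if score > best_score:
            best_score, best_combo = score, (bs, lr)
    return best_combo, best_score

best_combo, best_score = grid_search([1, 2], [1e-4, 1e-5, 1e-6], toy_objective)
print(best_combo, best_score)
```

With single-element `batch_sizes` and `learning_rates`, as configured above, the search reduces to one training run.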

#### Perform Grid Search

Perform hyperparameter tuning to find the best story classification model.


In [22]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import get_linear_schedule_with_warmup
from www.model.train import train_epoch_tiered
from www.model.eval import evaluate_tiered, save_results, save_preds, add_entity_attribute_labels
from sklearn.metrics import accuracy_score, f1_score
from www.utils import print_dict, get_model_dir
from www.model.transformers_ext import TieredModelPipeline
from www.dataset.ann import att_to_num_classes
import shutil
import pandas as pd

seed_val = 22 # Fix the random seed for reproducibility
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

# We'll keep the validation data here with a constant eval batch size
dev_sampler = SequentialSampler(tiered_tensor_dataset['dev'])
dev_dataloader = DataLoader(tiered_tensor_dataset['dev'], sampler=dev_sampler, batch_size=eval_batch_size)
dev_dataset_name = subtask + '_%s_dev'
dev_ids = [ex['example_id'] for ex in tiered_dataset['dev']]

all_losses = []
param_combos = []
combo_names = []
all_val_objs = []
output_dirs = []
best_obj = 0.0
best_model = '<none>'
best_dir = ''
best_obj2 = 0.0
best_model2 = '<none>'
best_dir2 = ''

print('Beginning grid search for the %s sub-task over %s parameter combination(s)!' % (subtask, str(len(batch_sizes) * len(learning_rates))))
for bs in batch_sizes:
  for lr in learning_rates:
    print('\nTRAINING MODEL: bs=%s, lr=%s' % (str(bs), str(lr)))

    loss_values = []
    obj_values = []

    # Set up training dataset with new batch size
    train_sampler = RandomSampler(tiered_tensor_dataset['train'])
    train_dataloader = DataLoader(tiered_tensor_dataset['train'], sampler=train_sampler, batch_size=bs)

    # Set up model
    config = config_class.from_pretrained(model_name,
                                          cache_dir=os.path.join(DRIVE_PATH, 'cache'))    
    emb = emb_class.from_pretrained(model_name,
                                          config=config,
                                          cache_dir=os.path.join(DRIVE_PATH, 'cache'))    
    if torch.cuda.is_available():
      emb.cuda()
    device = emb.device
    max_story_length = max([len(ex['stories'][0]['sentences']) for p in tiered_dataset for ex in tiered_dataset[p]])
    model = TieredModelPipeline(emb, max_story_length, len(att_to_num_classes), num_state_labels,
                                config_class, model_name, device, 
                                ablation=ablation, loss_weights=loss_weights).to(device)

    # Set up optimizer
    optimizer = AdamW(model.parameters(), lr=lr)
    total_steps = len(train_dataloader) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps = total_steps)

    train_lc_data = []
    val_lc_data = []
    for epoch in range(epochs):
      # Train the model for one epoch
      print('[%s] Beginning epoch...' % str(epoch))

      epoch_loss, _ = train_epoch_tiered(model, optimizer, train_dataloader, device, seg_mode=False, 
                                         build_learning_curves=generate_learning_curve, val_dataloader=dev_dataloader, 
                                         train_lc_data=train_lc_data, val_lc_data=val_lc_data)
      
      # Save loss
      loss_values.append(epoch_loss)

      # Validate on dev set
      validation_results = evaluate_tiered(model, dev_dataloader, device, [(accuracy_score, 'accuracy'), (f1_score, 'f1')], seg_mode=False, return_explanations=True)
      metr_attr, all_pred_atts, all_atts, \
      metr_prec, all_pred_prec, all_prec, \
      metr_eff, all_pred_eff, all_eff, \
      metr_conflicts, all_pred_conflicts, all_conflicts, \
      metr_stories, all_pred_stories, all_stories, explanations = validation_results[:16]
      explanations = add_entity_attribute_labels(explanations, tiered_dataset['dev'], list(att_to_num_classes.keys()))

      print('[%s] Validation results:' % str(epoch))
      print('[%s] Preconditions:' % str(epoch))
      print_dict(metr_prec)
      print('[%s] Effects:' % str(epoch))
      print_dict(metr_eff)
      print('[%s] Conflicts:' % str(epoch))
      print_dict(metr_conflicts)
      print('[%s] Stories:' % str(epoch))
      print_dict(metr_stories)

      # Save accuracy - want to maximize verifiability of tiered predictions
      ver = metr_stories['verifiability']
      acc = metr_stories['accuracy']
      obj_values.append(ver)
      
      # Save model checkpoint
      print('[%s] Saving model checkpoint...' % str(epoch))
      model_param_str = get_model_dir(model_name.replace('/', '-'), subtask, bs, lr, epoch) + '_' +  '-'.join([str(lw) for lw in loss_weights]) +  '_tiered_pipeline_lc'
      if train_spans:
        model_param_str += 'spans'
      if len(model.ablation) > 0:
        model_param_str += '_ablate_'
        model_param_str += '_'.join(model.ablation)
      output_dir = os.path.join(DRIVE_PATH, 'saved_models', model_param_str)
      output_dirs.append(output_dir)
      if not os.path.exists(output_dir):
        os.makedirs(output_dir)

      save_results(metr_attr, output_dir, dev_dataset_name % 'attributes')
      save_results(metr_prec, output_dir, dev_dataset_name % 'preconditions')
      save_results(metr_eff, output_dir, dev_dataset_name % 'effects')
      save_results(metr_conflicts, output_dir, dev_dataset_name % 'conflicts')
      save_results(metr_stories, output_dir, dev_dataset_name % 'stories')
      save_results(explanations, output_dir, dev_dataset_name % 'explanations')

      # Just save story preds
      save_preds(dev_ids, all_stories, all_pred_stories, output_dir, dev_dataset_name % 'stories')

      emb = emb.module if hasattr(emb, 'module') else emb
      emb.save_pretrained(output_dir)
      torch.save(model, os.path.join(output_dir, 'classifiers.pth'))
      tokenizer.save_vocabulary(output_dir)

      if ver > best_obj:
        best_obj = ver
        best_model = model_param_str
        best_dir = output_dir
      if acc > best_obj2:
        best_obj2 = acc
        best_model2 = model_param_str
        best_dir2 = output_dir        

      # for od in output_dirs:
      #   if od != best_dir and od != best_dir2 and os.path.exists(od):
      #     shutil.rmtree(od)

      print('[%s] Finished epoch.' % str(epoch))

    all_losses.append(loss_values)
    all_val_objs.append(obj_values)
    param_combos.append((bs, lr))
    combo_names.append('bs=%s, lr=%s' % (str(bs), str(lr)))

print('Finished grid search! :)')
print('Best validation *verifiability* %s from model %s.' % (str(best_obj), best_model))
print('Best validation *accuracy* %s from model %s.' % (str(best_obj2), best_model2))

if generate_learning_curve:
  print('Saving learning curve data...')
  train_lc_data = [subrecord for record in train_lc_data for subrecord in record] # flatten
  val_lc_data = [subrecord for record in val_lc_data for subrecord in record] # flatten

  train_lc_data = pd.DataFrame(train_lc_data)
  print(os.path.join(best_dir if best_dir != '' else best_dir2, 'learning_curve_data_train.csv'))
  train_lc_data.to_csv(os.path.join(best_dir if best_dir != '' else best_dir2, 'learning_curve_data_train.csv'), index=False)
  val_lc_data = pd.DataFrame(val_lc_data)
  val_lc_data.to_csv(os.path.join(best_dir if best_dir != '' else best_dir2, 'learning_curve_data_val.csv'), index=False)
  print('Learning curve data saved. %s rows saved for training, %s rows saved for validation.' % (str(len(train_lc_data.index)), str(len(val_lc_data.index))))

Beginning grid search for the cloze sub-task over 1 parameter combination(s)!

TRAINING MODEL: bs=1, lr=1e-05


[                                                                        ] N/A%

[0] Beginning epoch...


[########################################################################] 100%
[                                                                        ] N/A%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:03:01s.
[0] Validation results:
[0] Preconditions:
{
  accuracy: 
    0.9962154999299491,
  f1: 
    0.6000385613090384,
  accuracy_0: 
    0.9976766450287209,
  f1_0: 
    0.7281326698947256,
  accuracy_1: 
    0.9992761406622146,
  f1_1: 
    0.663381544942331,
  accuracy_2: 
    0.9994862933731845,
  f1_2: 
    0.8488198762496322,
  accuracy_3: 
    0.9991360388549012,
  f1_3: 
    0.6141415591904161,
  accuracy_4: 
    0.9997314715359829,
  f1_4: 
    0.4132885710115312,
  accuracy_5: 
    0.9868771307149862,
  f1_5: 
    0.5263723044567366,
  accuracy_6: 
    0.993251762947742,
  f1_6: 
    0.8073351709419829,
  accuracy_7: 
    0.9981203007518797,
  f1_7: 
    0.48816562188011386,
  accuracy_8: 
    0.9973847662634848,
  f1_8: 
    0.7702837031700575,
  accuracy_9: 
    0.9939289216830897,
  f1_9: 
    0.7373444375469692,
  accuracy_10: 
    0.997232989305562,
  f1_10: 
    0.6954612117945156,
  accuracy_11: 
    0.9972796899079999

[                                                                        ] N/A%

[0] Finished epoch.
[1] Beginning epoch...


[########################################################################] 100%
[                                                                        ] N/A%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:03:01s.
[1] Validation results:
[1] Preconditions:
{
  accuracy: 
    0.9961939009013216,
  f1: 
    0.6051704728849067,
  accuracy_0: 
    0.9972446644561714,
  f1_0: 
    0.7073562330774997,
  accuracy_1: 
    0.9989609115957596,
  f1_1: 
    0.6619661171770587,
  accuracy_2: 
    0.999684770933545,
  f1_2: 
    0.905682169588338,
  accuracy_3: 
    0.9992060897585578,
  f1_3: 
    0.6963537567817225,
  accuracy_4: 
    0.9998482230420772,
  f1_4: 
    0.5809270790323545,
  accuracy_5: 
    0.9869471816186429,
  f1_5: 
    0.5249330625753703,
  accuracy_6: 
    0.9925746042123943,
  f1_6: 
    0.8260439797477694,
  accuracy_7: 
    0.9984588801195535,
  f1_7: 
    0.5072542620698215,
  accuracy_8: 
    0.9973730911128753,
  f1_8: 
    0.748672182858173,
  accuracy_9: 
    0.9931233362910381,
  f1_9: 
    0.7601628433254701,
  accuracy_10: 
    0.9967893335824032,
  f1_10: 
    0.6661312510747394,
  accuracy_11: 
    0.9980152243963947,

[                                                                        ] N/A%

[1] Finished epoch.
[2] Beginning epoch...


[########################################################################] 100%
[                                                                        ] N/A%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:03:01s.
[2] Validation results:
[2] Preconditions:
{
  accuracy: 
    0.9961915658711997,
  f1: 
    0.61180009090497,
  accuracy_0: 
    0.9973030402092187,
  f1_0: 
    0.7070823112741803,
  accuracy_1: 
    0.9993228412646523,
  f1_1: 
    0.9969219676442833,
  accuracy_2: 
    0.999603044879279,
  f1_2: 
    0.8797271246945764,
  accuracy_3: 
    0.9989842618969784,
  f1_3: 
    0.5831638685498685,
  accuracy_4: 
    0.9998482230420772,
  f1_4: 
    0.5809270790323545,
  accuracy_5: 
    0.9863750992387802,
  f1_5: 
    0.5339462813947862,
  accuracy_6: 
    0.9923411012002055,
  f1_6: 
    0.8012631212889092,
  accuracy_7: 
    0.9985289310232102,
  f1_7: 
    0.5301932197750857,
  accuracy_8: 
    0.9974197917153131,
  f1_8: 
    0.7438531838177479,
  accuracy_9: 
    0.9931116611404287,
  f1_9: 
    0.7603359856156559,
  accuracy_10: 
    0.9970228365945921,
  f1_10: 
    0.7148426355055046,
  accuracy_11: 
    0.9979918740951759,

[                                                                        ] N/A%

[2] Finished epoch.
[3] Beginning epoch...


[########################################################################] 100%
[                                                                        ] N/A%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:03:02s.
[3] Validation results:
[3] Preconditions:
{
  accuracy: 
    0.9958448138980993,
  f1: 
    0.6101517601156232,
  accuracy_0: 
    0.9974314668659225,
  f1_0: 
    0.6730552385540634,
  accuracy_1: 
    0.9988091346378368,
  f1_1: 
    0.9946169852597522,
  accuracy_2: 
    0.9994862933731845,
  f1_2: 
    0.872930150157873,
  accuracy_3: 
    0.9991477140055107,
  f1_3: 
    0.6924906257333406,
  accuracy_4: 
    0.9999532993975623,
  f1_4: 
    0.6666588810513695,
  accuracy_5: 
    0.9861999719796385,
  f1_5: 
    0.532839525169434,
  accuracy_6: 
    0.9909167328258535,
  f1_6: 
    0.804925622871965,
  accuracy_7: 
    0.9984939055713818,
  f1_7: 
    0.5355608780637615,
  accuracy_8: 
    0.9972796899079999,
  f1_8: 
    0.7437561700605677,
  accuracy_9: 
    0.9916172418624201,
  f1_9: 
    0.7390068524360273,
  accuracy_10: 
    0.9967893335824032,
  f1_10: 
    0.7008640714237493,
  accuracy_11: 
    0.9977116704805492,

[                                                                        ] N/A%

[3] Finished epoch.
[4] Beginning epoch...


[########################################################################] 100%
[                                                                        ] N/A%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:03:02s.
[4] Validation results:
[4] Preconditions:
{
  accuracy: 
    0.9960783169102881,
  f1: 
    0.6120299298102946,
  accuracy_0: 
    0.9974781674683604,
  f1_0: 
    0.6973224494318738,
  accuracy_1: 
    0.9992994909634334,
  f1_1: 
    0.9968092583365986,
  accuracy_2: 
    0.9996263951804978,
  f1_2: 
    0.8990004741767823,
  accuracy_3: 
    0.9991944146079484,
  f1_3: 
    0.7002743464463634,
  accuracy_4: 
    0.9999416242469528,
  f1_4: 
    0.6589049967198761,
  accuracy_5: 
    0.986293373184514,
  f1_5: 
    0.5310478482737274,
  accuracy_6: 
    0.9919324709288749,
  f1_6: 
    0.813004846569844,
  accuracy_7: 
    0.9986223322280857,
  f1_7: 
    0.562440675324473,
  accuracy_8: 
    0.9972096390043431,
  f1_8: 
    0.7383656641862452,
  accuracy_9: 
    0.9925979545136132,
  f1_9: 
    0.7591074696157184,
  accuracy_10: 
    0.996625881473871,
  f1_10: 
    0.6950643446534551,
  accuracy_11: 
    0.9978050716854248,


[                                                                        ] N/A%

[4] Finished epoch.
[5] Beginning epoch...


[########################################################################] 100%
[                                                                        ] N/A%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:03:01s.
[5] Validation results:
[5] Preconditions:
{
  accuracy: 
    0.9961098398169337,
  f1: 
    0.6236570961760664,
  accuracy_0: 
    0.9973964414140942,
  f1_0: 
    0.7003629631886734,
  accuracy_1: 
    0.998972586746369,
  f1_1: 
    0.9953435766289808,
  accuracy_2: 
    0.9995446691262317,
  f1_2: 
    0.8574351542670205,
  accuracy_3: 
    0.9992060897585578,
  f1_3: 
    0.6985939182437543,
  accuracy_4: 
    0.9998365478914678,
  f1_4: 
    0.568600202917924,
  accuracy_5: 
    0.9866436277027973,
  f1_5: 
    0.5433720984212754,
  accuracy_6: 
    0.9918740951758278,
  f1_6: 
    0.8126119296559615,
  accuracy_7: 
    0.998657357679914,
  f1_7: 
    0.5768942695415371,
  accuracy_8: 
    0.997338065661047,
  f1_8: 
    0.7437181342167771,
  accuracy_9: 
    0.9923878018026433,
  f1_9: 
    0.7362845105057422,
  accuracy_10: 
    0.9969994862933732,
  f1_10: 
    0.7182421899294237,
  accuracy_11: 
    0.9980152243963947,


[                                                                        ] N/A%

[5] Finished epoch.
[6] Beginning epoch...


[########################################################################] 100%
[                                                                        ] N/A%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:03:00s.
[6] Validation results:
[6] Preconditions:
{
  accuracy: 
    0.9961203474524821,
  f1: 
    0.6161977492610499,
  accuracy_0: 
    0.9975248680707981,
  f1_0: 
    0.7007200946278247,
  accuracy_1: 
    0.9993345164152617,
  f1_1: 
    0.99697013519775,
  accuracy_2: 
    0.9996380703311073,
  f1_2: 
    0.8983305092022157,
  accuracy_3: 
    0.9991944146079484,
  f1_3: 
    0.699207984507276,
  accuracy_4: 
    0.9998131975902489,
  f1_4: 
    0.541635526387494,
  accuracy_5: 
    0.9866553028534069,
  f1_5: 
    0.5385953657370541,
  accuracy_6: 
    0.9920492224349694,
  f1_6: 
    0.7928905775018228,
  accuracy_7: 
    0.9985289310232102,
  f1_7: 
    0.5325219671063768,
  accuracy_8: 
    0.9976883201793303,
  f1_8: 
    0.7455823614383331,
  accuracy_9: 
    0.9928197823751926,
  f1_9: 
    0.7511433193926664,
  accuracy_10: 
    0.9968827347872787,
  f1_10: 
    0.6890344971043918,
  accuracy_11: 
    0.9976416195768926,


[                                                                        ] N/A%

[6] Finished epoch.
[7] Beginning epoch...


[########################################################################] 100%
[                                                                        ] N/A%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:02:59s.
[7] Validation results:
[7] Preconditions:
{
  accuracy: 
    0.9959790781301079,
  f1: 
    0.6043691982593784,
  accuracy_0: 
    0.9974081165647037,
  f1_0: 
    0.6628309732603169,
  accuracy_1: 
    0.9994045673189185,
  f1_1: 
    0.9972866697312611,
  accuracy_2: 
    0.9996497454817167,
  f1_2: 
    0.9036452603995531,
  accuracy_3: 
    0.9992411152103862,
  f1_3: 
    0.7177871909780237,
  accuracy_4: 
    0.9999532993975623,
  f1_4: 
    0.6666588810513695,
  accuracy_5: 
    0.9861182459253724,
  f1_5: 
    0.5220504971187017,
  accuracy_6: 
    0.9914421146032784,
  f1_6: 
    0.811588928585941,
  accuracy_7: 
    0.9984939055713818,
  f1_7: 
    0.5373803583407768,
  accuracy_8: 
    0.9972913650586093,
  f1_8: 
    0.742265061260455,
  accuracy_9: 
    0.9922360248447205,
  f1_9: 
    0.7753982439230879,
  accuracy_10: 
    0.9965324802689954,
  f1_10: 
    0.7076400537716893,
  accuracy_11: 
    0.9972446644561714

[                                                                        ] N/A%

[7] Finished epoch.
[8] Beginning epoch...


[########################################################################] 100%
[                                                                        ] N/A%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:03:00s.
[8] Validation results:
[8] Preconditions:
{
  accuracy: 
    0.9961717181151637,
  f1: 
    0.6099204154143765,
  accuracy_0: 
    0.9973030402092187,
  f1_0: 
    0.6867335967513174,
  accuracy_1: 
    0.9992294400597768,
  f1_1: 
    0.996497411457288,
  accuracy_2: 
    0.9995796945780601,
  f1_2: 
    0.8667937018469235,
  accuracy_3: 
    0.9992060897585578,
  f1_3: 
    0.7026288755537302,
  accuracy_4: 
    0.9999299490963434,
  f1_4: 
    0.6507819725070889,
  accuracy_5: 
    0.9865619016485313,
  f1_5: 
    0.525830856670891,
  accuracy_6: 
    0.9922710502965488,
  f1_6: 
    0.8086199414240353,
  accuracy_7: 
    0.9985289310232102,
  f1_7: 
    0.5369485113401821,
  accuracy_8: 
    0.9975365432214075,
  f1_8: 
    0.7431721686067005,
  accuracy_9: 
    0.9928781581282399,
  f1_9: 
    0.7712335154220535,
  accuracy_10: 
    0.9966609069256993,
  f1_10: 
    0.6986646765657665,
  accuracy_11: 
    0.9980969504506608

[                                                                        ] N/A%

[8] Finished epoch.
[9] Beginning epoch...


[########################################################################] 100%
[                                                                        ] N/A%

	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:03:00s.
[9] Validation results:
[9] Preconditions:
{
  accuracy: 
    0.9960327838229113,
  f1: 
    0.5989567465992012,
  accuracy_0: 
    0.9971746135525148,
  f1_0: 
    0.6634756124887446,
  accuracy_1: 
    0.9993111661140429,
  f1_1: 
    0.9968647479671763,
  accuracy_2: 
    0.9995213188250128,
  f1_2: 
    0.8652704730194637,
  accuracy_3: 
    0.9991710643067295,
  f1_3: 
    0.7070453593461951,
  accuracy_4: 
    0.9999065987951244,
  f1_4: 
    0.6333177624664253,
  accuracy_5: 
    0.9859547938168403,
  f1_5: 
    0.5183150634452525,
  accuracy_6: 
    0.9919091206276561,
  f1_6: 
    0.7952708325548645,
  accuracy_7: 
    0.9984588801195535,
  f1_7: 
    0.5227819671533945,
  accuracy_8: 
    0.9972913650586093,
  f1_8: 
    0.7426802937789062,
  accuracy_9: 
    0.9928781581282399,
  f1_9: 
    0.7405917993600148,
  accuracy_10: 
    0.9967426329799655,
  f1_10: 
    0.7063522114250383,
  accuracy_11: 
    0.99775837108298

Delete all non-best model checkpoints:


In [23]:
import shutil

# Delete non-best model checkpoints
for od in output_dirs:
  if od != best_dir and od != best_dir2 and os.path.exists(od):
    shutil.rmtree(od)

### Test Models

Evaluate accuracy, consistency, and verifiability on the test set.
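
The three metrics build on one another. Going by the metric code in this notebook, a prediction counts toward accuracy when the story choice is correct, toward consistency when the conflicting sentence pair is also correct, and toward verifiability when the explanation is flagged as valid. The following is a minimal sketch of that relationship, not the evaluation code itself: the helper `tiered_metrics` and the toy records are hypothetical, and the record keys are assumed to match those of the saved explanation files.

```python
def tiered_metrics(explanations):
    """Compute accuracy, consistency, and verifiability from explanation records.

    Hypothetical helper. Each record is assumed to be a dict with the keys
    story_pred, story_label, conflict_pred, conflict_label, valid_explanation.
    """
    total = len(explanations)
    if total == 0:
        raise ValueError('Need at least one explanation record.')

    correct = consistent = verifiable = 0
    for expl in explanations:
        story_ok = expl['story_pred'] == expl['story_label']
        # Position-wise comparison of the predicted and labeled conflict pair
        conflict_ok = list(expl['conflict_pred']) == list(expl['conflict_label'])
        correct += int(story_ok)
        consistent += int(story_ok and conflict_ok)
        verifiable += int(bool(expl['valid_explanation']))

    return {
        'accuracy': correct / total,
        'consistency': consistent / total,
        'verifiability': verifiable / total,
    }


# Toy records for illustration only (not real model output)
toy_explanations = [
    {'story_pred': 1, 'story_label': 1, 'conflict_pred': [1, 3], 'conflict_label': [1, 3], 'valid_explanation': True},
    {'story_pred': 1, 'story_label': 1, 'conflict_pred': [1, 3], 'conflict_label': [1, 3], 'valid_explanation': False},
    {'story_pred': 0, 'story_label': 0, 'conflict_pred': [0, 2], 'conflict_label': [1, 2], 'valid_explanation': False},
    {'story_pred': 1, 'story_label': 0, 'conflict_pred': [0, 2], 'conflict_label': [1, 2], 'valid_explanation': False},
]
toy_metrics = tiered_metrics(toy_explanations)
print(toy_metrics)
```

On the toy records, each metric is no larger than the one before it, which is the pattern to expect from real results as well.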

#### Load the Trained Model

Load the trained model we want to probe and select the appropriate dataset. Paths to the pre-trained models presented in the paper are already provided (download links are in the GitHub repo).

In [24]:
from www.model.transformers_ext import TieredModelPipeline
from www.dataset.ann import att_to_num_classes, att_to_idx, att_types
# Name of the best checkpoint directory, without the DRIVE_PATH/saved_models prefix
eval_model_dir = os.path.basename(best_dir)
probe_model = eval_model_dir
probe_model = os.path.join(DRIVE_PATH, 'saved_models', probe_model)

ablation = ['attributes', 'states-logits']

if 'cloze' in probe_model:
  subtask = 'cloze'
elif 'order' in probe_model:
  subtask = 'order'
  
if subtask == 'cloze':
  subtask_dataset = cloze_dataset_2s
elif subtask == 'order':
  subtask_dataset = order_dataset_2s

# Load the model
model = None
# model = torch.load(os.path.join(probe_model, 'classifiers.pth'), map_location=torch.device('cpu'))
model = torch.load(os.path.join(probe_model, 'classifiers.pth'))
if torch.cuda.is_available():
  model.cuda()
device = model.embedding.device

for layer in model.precondition_classifiers:
  layer.eval()
for layer in model.effect_classifiers:
  layer.eval()

#### Test the Model

Run inference on the test set of TRIP. To run inference on other partitions, edit the top-level `for` loop.

In [25]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from www.model.eval import evaluate_tiered, save_results, save_preds, list_comparison, add_entity_attribute_labels
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
metrics = [(accuracy_score, 'accuracy'), (precision_score, 'precision'), (recall_score, 'recall'), (f1_score, 'f1')]
import numpy as np
from www.utils import print_dict

print('Testing model: %s.' % probe_model)

# May alter this depending on which partition(s) you want to run inference on
for p in tiered_dataset:
  if p != 'test':
    continue

  p_dataset = tiered_dataset[p]
  p_tensor_dataset = tiered_tensor_dataset[p]
  p_sampler = SequentialSampler(p_tensor_dataset)
  p_dataloader = DataLoader(p_tensor_dataset, sampler=p_sampler, batch_size=16)
  dev_dataset_name = subtask + '_%s_' + p
  p_ids = [ex['example_id'] for ex in tiered_dataset[p]]

  # Get preds and metrics on this partition
  metr_attr, all_pred_atts, all_atts, \
  metr_prec, all_pred_prec, all_prec, \
  metr_eff, all_pred_eff, all_eff, \
  metr_conflicts, all_pred_conflicts, all_conflicts, \
  metr_stories, all_pred_stories, all_stories, explanations = evaluate_tiered(model, p_dataloader, device, [(accuracy_score, 'accuracy'), (f1_score, 'f1')], seg_mode=False, return_explanations=True)
  explanations = add_entity_attribute_labels(explanations, tiered_dataset[p], list(att_to_num_classes.keys()))

  save_results(metr_attr, probe_model, dev_dataset_name % 'attributes')
  save_results(metr_prec, probe_model, dev_dataset_name % 'preconditions')
  save_results(metr_eff, probe_model, dev_dataset_name % 'effects')
  save_results(metr_conflicts, probe_model, dev_dataset_name % 'conflicts')
  save_results(metr_stories, probe_model, dev_dataset_name % 'stories')
  save_results(explanations, probe_model, dev_dataset_name % 'explanations')

  print('\nPARTITION: %s' % p)
  print('Stories:')
  print_dict(metr_stories)
  print('Conflicts:')
  print_dict(metr_conflicts)
  print('Preconditions:')
  print_dict(metr_prec)
  print('Effects:')
  print_dict(metr_eff)

[                                                                        ] N/A%

Testing model: ./saved_models/roberta-large_cloze_1_1e-05_2_0.0-0.4-0.4-0.2-0.0_tiered_pipeline_lc_ablate_attributes_states-logits.
	Beginning evaluation...
		Running prediction...


[########################################################################] 100%


		Computing metrics...
	Finished evaluation in 0:04:43s.

PARTITION: test
Stories:
{
  accuracy: 
    0.8290598290598291,
  f1: 
    0.8290473407364114,
  verifiability: 
    0.1794871794871795,
}


Conflicts:
{
  accuracy: 
    0.9862298195631529,
  f1: 
    0.7901516340352739,
}


Preconditions:
{
  accuracy: 
    0.996863458900496,
  f1: 
    0.5930098252192376,
  accuracy_0: 
    0.9980931277227574,
  f1_0: 
    0.591249496779566,
  accuracy_1: 
    0.9988543692247396,
  f1_1: 
    0.6593973480470631,
  accuracy_2: 
    0.999351814166629,
  f1_2: 
    0.7752617249234879,
  accuracy_3: 
    0.9994573327906662,
  f1_3: 
    0.595588524844265,
  accuracy_4: 
    0.9999095554651111,
  f1_4: 
    0.530849121850951,
  accuracy_5: 
    0.989930508449027,
  f1_5: 
    0.5306447502642644,
  accuracy_6: 
    0.9925081776933629,
  f1_6: 
    0.8213565487048328,
  accuracy_7: 
    0.9984172206394428,
  f1_7: 
    0.46350531310577026,
  accuracy_8: 
    0.9977313495832014,
  f1_8: 
    0.829302

#### Add Consistency Metric to Model Results
The intermediate consistency metric isn't included in the originally calculated metrics. This block adds the consistency metric to a pre-existing model directory based on the tiered predictions. It generates a new `results_cloze_stories_final_[partition].json` file that includes the consistency metric.



In [26]:
import json
import os

model_directories = [eval_model_dir]

partitions = ['dev', 'test']
expl_fname = 'results_cloze_explanations_%s.json'
endtask_fname = 'results_cloze_stories_%s.json'
endtask_fname_new = 'results_cloze_stories_final_%s.json'
for md in model_directories:
  for p in partitions:
    explanations = json.load(open(os.path.join(DRIVE_PATH, 'saved_models', md, expl_fname % p), 'r'))
    endtask_results = json.load(open(os.path.join(DRIVE_PATH, 'saved_models', md, endtask_fname % p), 'r'))

    consistent_preds = 0
    verifiable_preds = 0
    total = 0
    for expl in explanations:
      if expl['valid_explanation']:
        verifiable_preds += 1
      if expl['story_pred'] == expl['story_label']:
        if len(expl['conflict_pred']) == len(expl['conflict_label']) and expl['conflict_pred'][0] == expl['conflict_label'][0] and expl['conflict_pred'][1] == expl['conflict_label'][1]:
          expl['consistent'] = True
          consistent_preds += 1
        else:
          expl['consistent'] = False
      total += 1

    endtask_results['consistency'] = float(consistent_preds) / total
    print('Found %s consistent preds in %s (versus %s verifiable)' % (str(consistent_preds), p, str(verifiable_preds)))
    json.dump(explanations, open(os.path.join(DRIVE_PATH, 'saved_models', md, (expl_fname % p).replace('explanations', 'explanations_consistency')), 'w'))
    json.dump(endtask_results, open(os.path.join(DRIVE_PATH, 'saved_models', md, endtask_fname_new % p), 'w'))

Found 162 consistent preds in dev (versus 68 verifiable)
Found 173 consistent preds in test (versus 63 verifiable)



# Conversational Entailment (CE) Results

Code for the coherence experiments on CE.

In [27]:
task_name = 'ce'
if task_name != 'ce':
  raise ValueError('Please configure task_name in first cell to "ce" to run CE results!')

## Load Conversational Entailment Dataset

In [28]:
import xml.etree.ElementTree as ET
import pickle
cache_train = os.path.join(DRIVE_PATH, 'all_data/ConvEnt/ConvEnt_train_resplit.json')
cache_dev = os.path.join(DRIVE_PATH,'all_data/ConvEnt/ConvEnt_dev_resplit.json')
cache_test = os.path.join(DRIVE_PATH,'all_data/ConvEnt/ConvEnt_test_resplit.json')
ConvEnt_train = json.load(open(cache_train))
ConvEnt_dev = json.load(open(cache_dev))
ConvEnt_test = json.load(open(cache_test))

# Combine train and dev and do cross-validation
cache_folds = os.path.join(DRIVE_PATH,'all_data/ConvEnt/ConvEnt_folds.pkl') # Folds used for results presented in paper
ConvEnt_train = ConvEnt_train + ConvEnt_dev
train_sources = list(set([ex['dialog_source'] for ex in ConvEnt_train]))
print("Reserved %s dialog sources for training and validation." % len(train_sources))

no_folds = 8
if not os.path.exists(cache_folds):
  folds = []
  for k in range(no_folds):
    folds.append(np.random.choice(train_sources, size=5, replace=False))
    train_sources = [s for s in train_sources if s not in folds[-1]]
  assert len(train_sources) == 0
  print(folds)
  pickle.dump(folds, open(cache_folds, 'wb'))
else:
  folds = pickle.load(open(cache_folds, 'rb'))

Reserved 40 dialog sources for training and validation.
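
The folds above are drawn with an unseeded `np.random.choice`, so they are only reproducible through the cached pickle file. For new experiments, a seeded split at the level of dialog sources may be preferable. The sketch below uses only the standard library; `make_source_folds`, the seed, and the toy source names are hypothetical and do not reproduce the folds used in the paper.

```python
import random


def make_source_folds(sources, n_folds, seed=0):
    """Partition `sources` into `n_folds` disjoint, equally sized folds."""
    if len(sources) % n_folds != 0:
        raise ValueError('Number of sources must be divisible by n_folds.')
    # Sort first so the result does not depend on the input order
    shuffled = sorted(sources)
    random.Random(seed).shuffle(shuffled)
    size = len(shuffled) // n_folds
    return [shuffled[i * size:(i + 1) * size] for i in range(n_folds)]


# Toy stand-ins for the 40 dialog sources
toy_sources = ['dialog_%02d' % i for i in range(40)]
toy_folds = make_source_folds(toy_sources, n_folds=8)
print([len(fold) for fold in toy_folds])
```

Splitting by source rather than by example keeps all examples from one dialog in the same fold, which is the same design choice the cell above makes.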


In [29]:
print('train examples:', len(ConvEnt_train))
print('dev examples:', len(ConvEnt_dev))
print('test examples:', len(ConvEnt_test))

train examples: 703
dev examples: 110
test examples: 172


## Featurize Conversational Entailment

In [30]:
from www.dataset.featurize import add_bert_features_ConvEnt, get_tensor_dataset
import pickle
seq_length = 128

ConvEnt_train = add_bert_features_ConvEnt(ConvEnt_train, tokenizer, seq_length, add_segment_ids=True)
ConvEnt_dev = add_bert_features_ConvEnt(ConvEnt_dev, tokenizer, seq_length, add_segment_ids=True)
ConvEnt_test = add_bert_features_ConvEnt(ConvEnt_test, tokenizer, seq_length, add_segment_ids=True)

ConvEnt_train_folds = [[] for _ in range(no_folds)]
ConvEnt_dev_folds = [[] for _ in range(no_folds)]
for k in range(no_folds):
  ConvEnt_train_folds[k] = [ex for ex in ConvEnt_train if ex['dialog_source'] not in folds[k]]
  ConvEnt_dev_folds[k] = [ex for ex in ConvEnt_train if ex['dialog_source'] in folds[k]]

  if debug:
    ConvEnt_train_folds[k] = ConvEnt_train_folds[k][:10]
    ConvEnt_dev_folds[k] = ConvEnt_dev_folds[k][:10]

if debug:
  ConvEnt_train = ConvEnt_train[:10]
  ConvEnt_dev = ConvEnt_dev[:10]
  ConvEnt_test = ConvEnt_test[:10]

ConvEnt_train_tensor = get_tensor_dataset(ConvEnt_train, label_key='label', add_segment_ids=True)
ConvEnt_test_tensor = get_tensor_dataset(ConvEnt_test, label_key='label', add_segment_ids=True)

# Training sets for each validation fold
ConvEnt_train_folds_tensor = [get_tensor_dataset(ConvEnt_train_folds[k], label_key='label', add_segment_ids=True) for k in range(no_folds)]
ConvEnt_dev_folds_tensor = [get_tensor_dataset(ConvEnt_dev_folds[k], label_key='label', add_segment_ids=True) for k in range(no_folds)]

In [31]:
print('train examples:', len(ConvEnt_train))
print('dev examples:', len(ConvEnt_dev))
print('test examples:', len(ConvEnt_test))

train examples: 703
dev examples: 110
test examples: 172


## Train Models on Conversational Entailment

### Train Models

#### Configure Hyperparameters

#### Grid Search and Cross-Validation

#### Re-Train Best Model from Cross-Validation

Re-train a model with the best parameters from the search above. If this isn't run directly after the above cell, replace `save_fname.split('/')[-1]` in `xval_fnames` with the name of the `pkl` file previously generated in the `saved_models` directory.
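
Choosing the best parameters from a cross-validated search usually means averaging each configuration's validation score over the folds and taking the maximum. The sketch below illustrates that step under stated assumptions: `select_best_config` is a hypothetical helper, the input is assumed to map `(batch size, learning rate)` to a list of per-fold scores, and the numbers are toy values. The structure of the `pkl` file produced by this notebook may differ.

```python
from statistics import mean


def select_best_config(xval_scores):
    """Return (config, mean score) for the config with the highest mean fold score."""
    if not xval_scores:
        raise ValueError('No cross-validation scores provided.')
    averaged = {config: mean(scores) for config, scores in xval_scores.items()}
    best = max(averaged, key=averaged.get)
    return best, averaged[best]


# Toy per-fold validation accuracies keyed by (batch size, learning rate)
toy_scores = {
    (16, 1e-5): [0.50, 0.75],
    (32, 1e-5): [0.75, 0.75],
}
best_config, best_mean = select_best_config(toy_scores)
print(best_config, best_mean)
```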

## Test Models on Conversational Entailment

## Coherence Checks on Conversational Entailment

### Load and Featurize Span Data

### Load the Trained Model

Load the trained model we want to probe and select the appropriate dataset.

#### Load Trained Model's Base Predictions

For comparison, we also want the preds and labels for the previous level.

### Check a Model

This prints the strict and lenient coherence metrics.

# ART Results

Code for the coherence experiments on ART.

## Load ART dataset

ART was originally gathered from [HuggingFace datasets](https://huggingface.co/docs/datasets/), but we added some of our own annotations for the coherence evaluation.

## Train Models on ART

### Featurize ART

### Train Models

Train models on ART. Note that ART's test set is not public, so we cannot test the model (unless we submit to their [leaderboard](https://leaderboard.allenai.org/anli/submissions/public)).

#### Configure Hyperparameters

#### Grid Search

Delete non-best model checkpoints:


## Coherence Checks on ART

### Load and Featurize Span Data

### Load the Trained Model

Load the trained model we want to probe and select the appropriate dataset.

#### Load Trained Model's Two-Story Classification Predictions

For comparison, we also want the preds and labels for the previous level.

### Calculate Coherence Metrics

As ART is a multiple-choice task, we will need to tune the confidence threshold $\rho$. This code will print out the strict and lenient coherence metrics, as well as the chosen $\rho$ (`best_threshold`).
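
Tuning a threshold of this kind can be done with a simple sweep over candidate values. The sketch below is illustrative only: it assumes that a confidence at or above $\rho$ counts as a positive prediction and that the best threshold is the one maximizing accuracy on held-out data. The notebook may optimize a different criterion, and `tune_threshold` and the toy values are hypothetical.

```python
def tune_threshold(confidences, labels, candidates):
    """Pick the candidate threshold with the highest accuracy.

    A confidence >= threshold is treated as a positive prediction (assumption).
    Ties are broken in favor of the first candidate tried.
    """
    if len(confidences) != len(labels) or not labels:
        raise ValueError('confidences and labels must be non-empty and equally long.')
    best_threshold, best_acc = None, -1.0
    for rho in candidates:
        preds = [int(c >= rho) for c in confidences]
        acc = sum(p == y for p, y in zip(preds, labels)) / len(labels)
        if acc > best_acc:
            best_threshold, best_acc = rho, acc
    return best_threshold, best_acc


# Toy confidences and binary labels for illustration only
toy_conf = [0.9, 0.8, 0.55, 0.4, 0.2]
toy_labels = [1, 1, 0, 0, 0]
best_threshold, best_acc = tune_threshold(toy_conf, toy_labels, [0.3, 0.5, 0.7])
print(best_threshold, best_acc)
```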