<a href="https://colab.research.google.com/github/tabasy/similarity_prompting/blob/main/exploiting_prompts_with_similarity.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Prepare Env

In [1]:
# @title install libs
from IPython.display import clear_output

try:
  import transformers
except:
  !pip install transformers datasets tokenizers
  # !pip install --upgrade pandas-profiling
  !pip install gdown

  !git clone https://github.com/dlukes/rbo
  import sys
  sys.path.append('/content/rbo')

!mkdir -p wictsv
!mkdir -p wic_gold

clear_output()

In [2]:
# @title import libs
from ipywidgets import interact, interactive, fixed, interact_manual

import os
import numpy as np
import pandas as pd

from copy import deepcopy
from operator import itemgetter
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, AutoModelForMaskedLM, AutoModelForCausalLM
from tokenizers import AddedToken
from datasets import list_datasets, load_dataset, concatenate_datasets, concatenate_datasets
from rbo import rbo
# `from rbo import rbo` may bind a module rather than the function, depending on how the
# cloned repository is laid out; descend through `.rbo` attributes until a callable is found.
while not callable(rbo):
  rbo = getattr(rbo, 'rbo')
  
from IPython.display import display, Markdown

from nltk.stem import WordNetLemmatizer, PorterStemmer
import nltk
nltk.download('wordnet')

from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.cluster import KMeans, DBSCAN
from sklearn.metrics import accuracy_score, normalized_mutual_info_score, adjusted_mutual_info_score

from scipy.stats import pearsonr, spearmanr

from matplotlib import pyplot as plt

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [3]:
#@title helper funcs
def pattern_parser(*texts, start='{', end='}', tok=None):
  """Tokenize text(s) and tag each token as inside (1) or outside (0) marker pairs.

  The marker tokens themselves are dropped from the output. A marker is matched as a
  whole token: it is padded with spaces (``' { '`` for a one-character marker),
  tokenized, and the first resulting token is used. Markers should therefore appear
  space-separated in the text, e.g. ``'a { b } c'``.

  Args:
    *texts: a text (or text pair) forwarded to ``tok.tokenize`` and ``tok.encode``.
    start, end: opening / closing marker strings.
    tok: tokenizer to use; defaults to the notebook-level ``tokenizer`` global.

  Returns:
    ``(tokens, ids, types)``: parallel lists without the marker tokens. ``types[i]`` is
    1 for tokens between a start and an end marker, 0 otherwise.

  Raises:
    ValueError: if ``tokenize`` and ``encode`` disagree on the number of tokens (for
      example, a tokenizer whose ``tokenize`` ignores ``add_special_tokens``). Zipping
      them would otherwise silently misalign tokens and ids.
  """
  if tok is None:
    tok = tokenizer

  tokens = tok.tokenize(*texts, add_special_tokens=True)
  all_ids = tok.encode(*texts, add_special_tokens=True)
  if len(tokens) != len(all_ids):
    raise ValueError(f'tokenize() gave {len(tokens)} tokens but encode() gave {len(all_ids)} ids')

  start = tok.tokenize(start.rjust(2, ' ').ljust(3, ' '))[0]
  end = tok.tokenize(end.rjust(2, ' ').ljust(3, ' '))[0]

  toks, ids, types = [], [], []
  inside = 0
  for wid, token in zip(all_ids, tokens):
    if token == start:
      inside = 1
    elif token == end:
      inside = 0
    else:
      toks.append(token)
      ids.append(wid)
      types.append(inside)

  return toks, ids, types

def log_step(path, **kwargs):
  """Append one row of keyword values to a tab-separated log file.

  A header row (the keyword names, in call order) is written only when `path`
  does not exist yet; later calls append rows without a header.
  """
  row = pd.DataFrame({name: [value] for name, value in kwargs.items()})
  already_exists = os.path.isfile(path)
  if already_exists:
    row.to_csv(path, mode='a', header=False, sep='\t', index=False)
    return
  row.to_csv(path, header=list(kwargs.keys()), sep='\t', index=False)
# list(zip(*pattern_parser(' { I really enjoy } that { film so } much. I am.', 'haaaaaaa aaa { aaa } ...')))

In [4]:
#@title loss helpers

crossent = nn.CrossEntropyLoss()

def get_crossent_loss(model, target_class):
  """Cross-entropy of ``model()`` (called without inputs) against a single target class.

  The target has batch size 1, so ``model()`` is expected to return logits with a
  leading batch dimension of 1. The target is created on the logits' device rather
  than always on CUDA, so the helper also works on CPU.
  """
  output = model()
  target = torch.tensor([target_class], device=output.device)
  return crossent(output, target)

def get_adaptive_crossent_loss(model, input_ids, target_class):
  """Cross-entropy of ``model(input_ids)`` with every row labelled `target_class`.

  One target entry is created per element of `input_ids`, on the logits' device.
  """
  output = model(input_ids)
  target = torch.tensor([target_class] * len(input_ids), device=output.device)
  return crossent(output, target)

def get_embedding_distances(pattern):
  """Per-row L2 distance between `pattern`'s current and initial target embeddings.

  Row ``i`` of the current embeddings is compared with row ``i`` of the initial ones
  (the diagonal of the pairwise distance matrix).
  """
  current_embs = pattern.get_target_embeddings()
  initial_embs = pattern.get_initial_embeddings(target_only=True)
  distances = torch.cdist(current_embs, initial_embs).diagonal()
  return distances

# Process Data

In [6]:
!mkdir -p sp_sst
!git clone https://github.com/toriving/Sentiment-analysis

fatal: destination path 'Sentiment-analysis' already exists and is not an empty directory.


In [7]:
#@title toriving/sst

%%writefile sp_sst/sp_sst.py

# Local `datasets` loading script for the binary SST splits (stsa_binary_*.txt) cloned from toriving/Sentiment-analysis.
import json
import datasets


_CITATION = """\
@ARTICLE{breit2021wictsv,
       author = {{Breit}, Anna and {Revenko}, Artem and {Rezaee}, Kiamehr and {Taher Pilehvar}, Mohammad and {Camacho-Collados}, Jose},
        title = "{WiC-TSV: An Evaluation Benchmark for Target Sense Verification of Words in Context}",
}
"""

_DESCRIPTION = """\
We present WiC-TSV, a new multi-domain evaluation benchmark for Word Sense Disambiguation.
depending on the input signals provided to the model.
"""
_PATHS = {
    "train_examples": "/content/Sentiment-analysis/data/stsa_binary_train.txt",
    "dev_examples":  "/content/Sentiment-analysis/data/stsa_binary_dev.txt",
    "test_examples": "/content/Sentiment-analysis/data/stsa_binary_test.txt",
}

def normalize_text(text):
  return text.replace('\n', '')
  return text.replace(' .', '.').replace(' ,', ',').replace(" '", "'").replace(" ?", "?").replace(" !", "!")

class SST(datasets.GeneratorBasedBuilder):
    """Binary SST built from local ``<label> <sentence>`` text files (see `_PATHS`)."""

    VERSION = datasets.Version("3.5.8")

    def _info(self):
        """Describe the two features: the sentence text and its integer label."""
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=datasets.Features(
                {
                    "text": datasets.Value("string"),
                    "label": datasets.Value("int32")
                }
            ),
            supervised_keys=None,
            # Repository the local data files in `_PATHS` are cloned from.
            homepage="https://github.com/toriving/Sentiment-analysis",
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Return train/test/validation splits backed by the local files in `_PATHS`.

        Nothing is downloaded, so `dl_manager` is unused.
        """
        dl = _PATHS

        return [
            datasets.SplitGenerator(name=datasets.Split.TRAIN,
                                    gen_kwargs={"ex": dl["train_examples"]}),
            datasets.SplitGenerator(name=datasets.Split.TEST,
                                    gen_kwargs={"ex": dl["test_examples"]}),
            datasets.SplitGenerator(name=datasets.Split.VALIDATION,
                                    gen_kwargs={"ex": dl["dev_examples"]}),
        ]

    def _generate_examples(self, ex):
        """Yield ``(line_number, {"text", "label"})`` for each non-blank line of file `ex`."""
        with open(ex, encoding="utf-8") as f:
            for id_, line in enumerate(f):
                if not line.strip():
                    # Skip blank (e.g. trailing) lines instead of failing on int('').
                    continue
                # Line format: "<label> <sentence>"
                label, _, sentence = line.partition(' ')
                yield id_, {"text": normalize_text(sentence), "label": int(label)}

Overwriting sp_sst/sp_sst.py


In [8]:
#@title dataset processor funcs

def balance_dataset(dataset, label_name='label', ex_per_class=None):
  """Keep the same number of examples from every class and interleave the classes.

  Args:
    dataset: a `datasets.Dataset` with a `label_name` column.
    label_name: column holding the class label.
    ex_per_class: examples kept per class.
      - ``None``, ``0`` or ``-1``: as many as the smallest class has.
      - A positive value: capped to the size of the smallest class.
      - Any value ``< -1``: `dataset` is returned unchanged (the "standard test set"
        setting).

  Returns:
    A dataset sorted so that classes alternate (c0, c1, ..., c0, c1, ...), with added
    columns `label_int` (index of the label in the sorted class list) and
    `balanced_index`.
  """
  # Check `None` first: the default value would otherwise raise on `None < -1`.
  if ex_per_class is not None and ex_per_class < -1:
    return dataset

  classes = sorted(dataset.unique(label_name))

  def make_balanced_index(example, index):
    class_num = classes.index(example[label_name])
    example['label_int'] = class_num
    # Sorting by this key interleaves the classes.
    example['balanced_index'] = len(classes) * index + class_num
    return example

  # Read the label column once instead of once per class.
  labels = np.array(dataset[label_name])
  subsets = [dataset.select(np.where(labels == label)[0]) for label in classes]

  smallest = min((len(subset) for subset in subsets), default=0)
  if ex_per_class is None or ex_per_class <= 0:
    n_keep = smallest
  else:
    n_keep = min(ex_per_class, smallest)

  subsets = [subset.select(range(n_keep)) for subset in subsets]
  subsets = [subset.map(make_balanced_index, with_indices=True) for subset in subsets]

  return concatenate_datasets(subsets).sort('balanced_index')

  
def reformat_dataset(dataset, rules, templates):
  """Add one text column per template, filled in from each example.

  Args:
    dataset: object with a `map(fn)` method (e.g. `datasets.Dataset`).
    rules: mapping tag -> replacement. A string that names a field of the example is
      replaced by that field's value; a callable is called with the example; anything
      else is inserted via `str()`. Rules are applied in dict order.
    templates: mapping new column name -> template string containing tags.

  Returns:
    ``dataset.map(...)`` with the new template columns added; the input examples are
    deep-copied, not mutated.
  """

  def apply_template(example, template):
    text_input = template

    for tag, repl in rules.items():
      # Evaluate a rule only when its tag is still present. A no-op `replace` is thus
      # skipped, and rules for unused tags (e.g. `ensure_punc` on an empty string) cannot
      # raise.
      if tag not in text_input:
        continue
      if isinstance(repl, str) and repl in example:
        repl = example[repl]
      if callable(repl):
        repl = repl(example)
      text_input = text_input.replace(tag, str(repl))

    return text_input

  def apply_templates(example):
    new_example = deepcopy(example)

    for name, template in templates.items():
      new_example[name] = apply_template(example, template)

    return new_example

  return dataset.map(apply_templates)


def tokenize_dataset(dataset, text_fields, base_model_name='roberta-base',
                     batch_size=32):
  """Tokenize text columns in batches and store ids / masks as new columns.

  Args:
    dataset: a `datasets.Dataset`.
    text_fields: mapping source column -> output prefix. Each source column produces
      ``<prefix>_input_ids`` and ``<prefix>_attention_mask``.
    base_model_name: checkpoint whose tokenizer is used.
    batch_size: `map` batch size. Padding is ``'longest'`` *within each batch*, so rows
      from different batches may have different lengths.
  """
  tok = AutoTokenizer.from_pretrained(base_model_name)
  if 'gpt' in base_model_name:
    # Reuse EOS as the pad token (GPT-style tokenizers presumably lack one).
    tok.pad_token = tok.eos_token

  def tokenize_batch(batch):
    for source, prefix in text_fields.items():
      encoded = tok(batch[source], padding='longest')
      batch[f'{prefix}_input_ids'] = encoded['input_ids']
      batch[f'{prefix}_attention_mask'] = encoded['attention_mask']
    return batch

  return dataset.map(tokenize_batch, batched=True, batch_size=batch_size)


def prepare_dataset(dataset, label_name, ex_per_class, base_model_name,
                    rules, templates, batch_size=32, convert_label=False):
  """Balance, fill templates, tokenize, and expose torch-formatted columns.

  The formatted columns are ``<template>_input_ids``, then ``<template>_attention_mask``
  for every template, then the label column. The label column is ``label_int`` when
  `convert_label` is true, otherwise `label_name`. Cell output is cleared afterwards.
  """
  balanced = balance_dataset(dataset, label_name, ex_per_class)
  reformatted = reformat_dataset(balanced, rules, templates)
  tokenized = tokenize_dataset(reformatted, {name: name for name in templates},
                               base_model_name, batch_size)

  label_column = 'label_int' if convert_label else label_name
  columns = ([f'{name}_input_ids' for name in templates]
             + [f'{name}_attention_mask' for name in templates]
             + [label_column])
  tokenized.set_format(type='pt', columns=columns)

  clear_output()
  return tokenized


def prepare_datasets(datasets, label_name, ex_per_class, base_model_name,
                     rules, templates, batch_size=32, convert_label=False):
  """Apply `prepare_dataset` to each dataset of a dict, list or tuple.

  The container type is preserved (a dict subclass such as `DatasetDict` yields a plain
  dict). Any other object is treated as a single dataset and prepared directly.
  `batch_size` and `convert_label` are passed by keyword so every container type gets
  the same settings.
  """
  def prepare(d):
    return prepare_dataset(d, label_name, ex_per_class, base_model_name, rules, templates,
                           batch_size=batch_size, convert_label=convert_label)

  if isinstance(datasets, dict):
    return {k: prepare(d) for k, d in datasets.items()}

  if isinstance(datasets, list):
    return [prepare(d) for d in datasets]

  if isinstance(datasets, tuple):
    return tuple(prepare(d) for d in datasets)

  return prepare(datasets)

In [9]:
#@title glue dataset rules

puncs = ',.?؟،؛!:;'

def ensure_punc(text, end_char='.'):
  """Strip whitespace and append `end_char` unless the text already ends in `puncs`.

  Empty or whitespace-only input returns ``''`` instead of raising IndexError.
  """
  text = text.strip()
  if text and text[-1] not in puncs:
    text = text + end_char
  return text

def remove_punc(text):
  """Strip whitespace, then drop every trailing character that is in `puncs`.

  Whitespace that preceded the removed punctuation is kept (``'hi .'`` -> ``'hi '``).
  Input that is empty or only punctuation returns ``''`` instead of raising IndexError.
  """
  return text.strip().rstrip(puncs)

# Template-tag rules for GLUE tasks. For each text field, `<k>` is the raw text, `<k.>`
# guarantees terminal punctuation, and `<k_>` strips trailing punctuation.
glue_rules = {}

glue_rules['sst2'] = {
    '<1>': lambda x: x['sentence'],
    '<1.>': lambda x: ensure_punc(x['sentence'], end_char='.'),
    '<1_>': lambda x: remove_punc(x['sentence']),
}

glue_rules['cola'] = glue_rules['sst2']

glue_rules['mnli'] = {
    '<1>': lambda x: x['premise'],
    '<1.>': lambda x: ensure_punc(x['premise'], end_char='.'),
    '<1_>': lambda x: remove_punc(x['premise']),
    '<2>': lambda x: x['hypothesis'],
    '<2.>': lambda x: ensure_punc(x['hypothesis'], end_char='.'),
    '<2_>': lambda x: remove_punc(x['hypothesis'])}

# The `<2*>` tags read the *second* question; they previously read `question1` as well,
# so templates compared a question with itself.
glue_rules['qqp'] = {
    '<1>': lambda x: x['question1'],
    '<1.>': lambda x: ensure_punc(x['question1'], end_char='.'),
    '<1_>': lambda x: remove_punc(x['question1']),
    '<2>': lambda x: x['question2'],
    '<2.>': lambda x: ensure_punc(x['question2'], end_char='.'),
    '<2_>': lambda x: remove_punc(x['question2']),
     }

glue_rules['rte'] = {
    '<1>': lambda x: x['sentence1'],
    '<1.>': lambda x: ensure_punc(x['sentence1'], end_char='.'),
    '<1_>': lambda x: remove_punc(x['sentence1']),
    '<2>': lambda x: x['sentence2'],
    '<2.>': lambda x: ensure_punc(x['sentence2'], end_char='.'),
    '<2_>': lambda x: remove_punc(x['sentence2']),}

glue_rules['qnli'] = {
    '<q>': lambda x: x['question'],
    '<q.>': lambda x: ensure_punc(x['question'], end_char='.'),
    '<q_>': lambda x: remove_punc(x['question']),
    '<p>': lambda x: x['sentence'],
    '<p.>': lambda x: ensure_punc(x['sentence'], end_char='.'),
    '<p_>': lambda x: remove_punc(x['sentence']),}

# These tasks share the sentence1/sentence2 field names with RTE (same dict object).
glue_rules['stsb'] = glue_rules['rte']
glue_rules['wnli'] = glue_rules['rte']
glue_rules['mrpc'] = glue_rules['rte']

In [10]:
#@title super glue dataset rules

superglue_rules = {}

# WiC span helpers; `i` selects sentence 1 or 2 of the example.
def _wic_prefix(x, i):
  """Text of sentence `i` before the target word."""
  return x[f'sentence{i}'][:x[f'start{i}']]

def _wic_suffix(x, i):
  """Text of sentence `i` after the target word."""
  return x[f'sentence{i}'][x[f'end{i}']:]

def _wic_target(x, i):
  """The target word itself in sentence `i`."""
  return x[f'sentence{i}'][x[f'start{i}']:x[f'end{i}']]

def _wic_last_word(x, i):
  """Last space-separated word before the target in sentence `i` ('' if none)."""
  return _wic_prefix(x, i).strip().split(' ')[-1]

# Key order is preserved on purpose: rules are applied in dict order.
superglue_rules['wic'] = {
    '<1->': lambda x: _wic_prefix(x, 1),
    '<1a>': lambda x: ' a' if _wic_last_word(x, 1) == 'a' else '',
    '<2a>': lambda x: ' a' if _wic_last_word(x, 2) == 'a' else '',
    '<1b>': lambda x: (' ' + _wic_last_word(x, 1).strip()).rstrip(),
    '<2b>': lambda x: (' ' + _wic_last_word(x, 2).strip()).rstrip(),
    '<-1>': lambda x: _wic_suffix(x, 1),
    '<2->': lambda x: _wic_prefix(x, 2),
    '<-2>': lambda x: _wic_suffix(x, 2),
    '<t1>': lambda x: _wic_target(x, 1),
    '<t2>': lambda x: _wic_target(x, 2),
}

def _superglue_text_rules(tag, field):
  """Raw, '.'-terminated and trailing-punctuation-stripped rules for one text field."""
  return {
      f'<{tag}>': lambda x: x[field],
      f'<{tag}.>': lambda x: ensure_punc(x[field], end_char='.'),
      f'<{tag}_>': lambda x: remove_punc(x[field]),
  }

superglue_rules['boolq'] = {**_superglue_text_rules('q', 'question'),
                            **_superglue_text_rules('p', 'passage')}

superglue_rules['cb'] = {**_superglue_text_rules('1', 'premise'),
                         **_superglue_text_rules('2', 'hypothesis')}

# SuperGLUE RTE uses the same premise/hypothesis fields (same dict object).
superglue_rules['rte'] = superglue_rules['cb']

In [11]:
#@title miscellaneous dataset rules

# Short task name -> name or local path passed to `load_dataset`.
misc_name_map = {
    'sst2': 'sp_sst',  # alternative: 'toriving/sst2'
    'sst5': 'toriving/sst5',
    'yelp2': 'yelp_polarity',
    'yelp5': 'yelp_review_full',
    'sick': 'sick',
    'wic_tsv': '/content/wictsv',
    'wic_gold': '/content/wic_gold',
}

def pass_(dataset):
  """Identity preprocessing: return `dataset` unchanged."""
  return dataset

def drop_long(dataset, field='text', max_len=32):
  """Keep only rows whose `field` contains fewer than `max_len` space characters."""
  def is_short(row):
    return row[field].count(' ') < max_len
  return dataset.filter(is_short)

# Yelp reviews are length-filtered; every other task is used as-is.
misc_preprocess = {
    name: (drop_long if name in ('yelp2', 'yelp5') else pass_)
    for name in ('yelp2', 'yelp5', 'sst2', 'sst5', 'sick', 'wic_tsv', 'wic_gold')
}

In [12]:
#@title miscellaneous dataset rules

misc_rules = {}

def _misc_text_rules(tag, field):
  """Raw, '.'-terminated and trailing-punctuation-stripped rules for one text field."""
  return {
      f'<{tag}>': lambda x: x[field],
      f'<{tag}.>': lambda x: ensure_punc(x[field], end_char='.'),
      f'<{tag}_>': lambda x: remove_punc(x[field]),
  }

# Reuses the SuperGLUE WiC rules (the same dict object); presumably the gold data has
# the same sentence/start/end fields.
misc_rules['wic_gold'] = superglue_rules['wic']

misc_rules['wic_tsv'] = {
    '<1>': lambda x: x['example'],
    # Whitespace-split words before / after the target word at position `index`.
    '<1->': lambda x: ' '.join(x['example'].split()[:x['index']]),
    '<-1>': lambda x: ' '.join(x['example'].split()[x['index'] + 1:]),
    '<2>': lambda x: x['definition'],
    '<t1>': lambda x: x['token'],
}

misc_rules['sick'] = {**_misc_text_rules('1', 'sentence_A'),
                      **_misc_text_rules('2', 'sentence_B')}

misc_rules['yelp5'] = _misc_text_rules('1', 'text')

# SST variants also expose a `text` field, so they share the Yelp rule dict.
misc_rules['sst2'] = misc_rules['yelp5']
misc_rules['sst5'] = misc_rules['yelp5']


In [13]:
#@title balanced dataset class

from copy import deepcopy

class BalancedDataset(torch.utils.data.Dataset):
  """Map-style dataset that cycles through classes and fills text templates on access.

  Item ``i`` is the ``i // C``-th example of class ``self.classes[i % C]``, where ``C``
  is the number of (sorted) classes, so consecutive indices alternate between classes.

  Args:
    dataset: `datasets.Dataset`-like object providing `unique(column)` and `filter(fn)`.
    template_rules: mapping tag -> replacement. A string that names a field of the
      example is replaced by that field's value; a callable is called with the example;
      anything else is inserted via `str()`.
    input_templates: a template string, or a list / tuple / dict of template strings.
    label_name: label column name.
    base_model_name: checkpoint whose tokenizer is loaded into `self.tokenizer`
      (not read by the methods of this class).
    ex_per_class: examples exposed per class. ``None`` means the size of the smallest class.
    return_str: stored as `self.return_str`; not read by this class.
    return_dict: if true, `get` returns a dict, otherwise an ``(inputs, label)`` tuple.
  """
  def __init__(self, dataset, template_rules, input_templates='<1>', label_name='label',
               base_model_name='roberta-base', ex_per_class=None, return_str='<str>', return_dict=True):

    self.class2ex = {}
    self.classes = sorted(dataset.unique(label_name))

    for label in self.classes:
      self.class2ex[label] = dataset.filter(lambda x: x[label_name] == label)

    if ex_per_class is None:
      # `None` would otherwise fail the `<` comparison below and break `__len__`.
      ex_per_class = min((len(subset) for subset in self.class2ex.values()), default=0)

    for label in self.classes:
      if len(self.class2ex[label]) < ex_per_class:
        print(f'warning! not enough examples in class {label}: {len(self.class2ex[label])} < {ex_per_class}')

    self.label_name = label_name
    self.template_rules = template_rules
    self.input_templates = input_templates
    self.ex_per_class = ex_per_class
    self.return_str = return_str
    self.return_dict = return_dict
    self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)

  def __len__(self):
    return len(self.classes) * self.ex_per_class

  def __getitem__(self, index):
    return self.get(index)

  def prepare_input_(self, example, template):
    """Fill every tag in `template` using `self.template_rules`, applied in dict order."""
    text_input = template

    for tag, repl in self.template_rules.items():
      # Evaluate a rule only when its tag is present (a no-op replace otherwise).
      if tag not in text_input:
        continue
      if isinstance(repl, str) and repl in example:
        repl = example[repl]
      if callable(repl):
        repl = repl(example)
      text_input = text_input.replace(tag, str(repl))

    # Return after the loop so that *every* rule is applied, not just the first one.
    return text_input

  def prepare_inputs(self, example):
    """Fill `self.input_templates`, preserving its container type (str/list/tuple/dict)."""
    inputs = deepcopy(self.input_templates)

    if isinstance(self.input_templates, str):
      return self.prepare_input_(example, inputs)

    if isinstance(self.input_templates, list):
      return [self.prepare_input_(example, input) for input in inputs]

    if isinstance(self.input_templates, tuple):
      return tuple(self.prepare_input_(example, input) for input in inputs)

    if isinstance(self.input_templates, dict):
      return {key: self.prepare_input_(example, input) for key, input in inputs.items()}

  def get(self, index):
    """Return the filled inputs and label for balanced position `index`."""
    n_classes = len(self.classes)
    # `class2ex` is keyed by label *values*; map the position to a value first.
    # Indexing with the bare position only worked for labels 0..C-1.
    label = self.classes[index % n_classes]
    index_in_class = index // n_classes

    example = self.class2ex[label][index_in_class]
    label = example[self.label_name]
    inputs = self.prepare_inputs(example)

    if self.return_dict and isinstance(inputs, dict):
      inputs['label'] = label
      return inputs
    elif self.return_dict:
      return {'inputs': inputs, 'label': label}
    return inputs, label

# wic_rules = {'<*>': '<mask>',
#          '<1->': lambda x: x['sentence1'][:x['start1']],
#          '<-1>': lambda x: x['sentence1'][x['end1']:],
#          '<2->': lambda x: x['sentence2'][:x['start2']],
#          '<-2>': lambda x: x['sentence2'][x['end2']:],
#          '<t1>': lambda x: x['sentence1'][x['start1']:x['end1']],
#          '<t2>': lambda x: x['sentence2'][x['start2']:x['end2']]}
         
# bw = BalancedDataset(wt, wic_rules, input_templates={'1': '<str><1-><t1> or <*><-1>', '2':'<str><1-><t1><-1>. <t1> means <*>.'},
#                      label_name='label', base_model_name='roberta-base', ex_per_class=128)

# Design Model

In [14]:
# bert = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')
# gpt = AutoModelForCausalLM.from_pretrained('gpt2')
# tk = AutoTokenizer.from_pretrained('gpt2')

In [15]:
# @title mask-for-mare

class MaskFormer(nn.Module):
  """Masked-LM wrapper that reads features at the mask-token positions.

  NOTE: the next cell redefines `MaskFormer` / `DoubleMaskFormer` (adding causal-LM
  support), so when the notebook runs top to bottom this version is shadowed.
  """
  # Accepted values of `output_mode`.
  _OUTPUT_MODES = ('avg_embeddings', 'embeddings', 'lm_logits', 'lm_probs')

  def __init__(self,
               base_model_name='roberta-base',
               lm_head_name='lm_head',
               freeze_bert=True):
    super().__init__()

    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    transformer = AutoModelForMaskedLM.from_pretrained(base_model_name)
    # Keep the encoder and the LM head separately so encoder outputs are accessible.
    self.transformer = transformer.base_model
    self.lm_head = getattr(transformer, lm_head_name)

    self.mask_id = tokenizer.mask_token_id

    if freeze_bert:
      self.freeze_bert(freeze=True)

    # NOTE(review): despite the name, this is the vocabulary size (the LM output width).
    self.hidden_size = self.transformer.config.vocab_size

  def freeze_bert(self, freeze):
    """Enable (freeze=False) or disable (freeze=True) gradients for the encoder."""
    for p in self.transformer.parameters():
      p.requires_grad = not freeze

  def single_forward(self, input_ids, att_masks, output_mode='lm_probs'):
    """Encode one batch and return the features selected by `output_mode`.

    Modes:
      'avg_embeddings': mean encoder output over non-padded positions, one row per example.
      'embeddings': encoder outputs at the mask-token positions, flattened over the batch.
      'lm_logits' / 'lm_probs': LM-head logits / softmax probabilities at those positions.

    Raises:
      ValueError: for any other `output_mode`. The previous fall-through returned None.
    """
    if output_mode not in self._OUTPUT_MODES:
      raise ValueError(f'unknown output_mode {output_mode!r}; expected one of {self._OUTPUT_MODES}')

    encoder_embs = self.transformer(input_ids, att_masks, return_dict=False)[0]

    if output_mode == 'avg_embeddings':
      sum_embes = (att_masks.unsqueeze(-1) * encoder_embs).sum(dim=1)
      return sum_embes / att_masks.sum(dim=1).unsqueeze(-1)

    # Representations at the [MASK] positions (boolean indexing flattens the batch).
    mask_indices = input_ids == self.mask_id
    mask_embs = encoder_embs[mask_indices]

    if output_mode == 'embeddings':
      return mask_embs

    token_logits = self.lm_head(mask_embs)

    if output_mode == 'lm_logits':
      return token_logits

    # output_mode == 'lm_probs'
    return torch.softmax(token_logits, -1)

  def forward(self, input_ids1, att_masks1, output_mode='lm_probs'):
    return self.single_forward(input_ids1, att_masks1, output_mode)

class DoubleMaskFormer(MaskFormer):
  """Runs `single_forward` on two inputs and returns both results as a tuple."""

  def forward(self, input_ids1, att_masks1, input_ids2, att_masks2, output_mode='lm_probs'):
    return (self.single_forward(input_ids1, att_masks1, output_mode),
            self.single_forward(input_ids2, att_masks2, output_mode))

In [16]:
# @title generic mask-for-mare

class MaskFormer(nn.Module):
  """Pretrained LM wrapper exposing features at the prediction position(s).

  Masked LMs read features at every mask-token position. Causal LMs (model names
  containing 'gpt', or ``is_causal=True``) read them at the last attended position of
  each row, computed as ``att_masks.sum(1) - 1``, which assumes right padding.
  """

  def __init__(self,
               base_model_name='roberta-base',
               freeze_bert=True,
               is_causal=False):
    super().__init__()

    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    causal = 'gpt' in base_model_name or is_causal
    self.lm_mode = 'causal' if causal else 'masked'
    model_cls = AutoModelForCausalLM if causal else AutoModelForMaskedLM
    self.transformer = model_cls.from_pretrained(base_model_name)
    self.mask_id = tokenizer.mask_token_id

    if freeze_bert:
      self.freeze_bert(freeze=True)

    # NOTE(review): holds the vocabulary size, not the hidden size.
    self.hidden_size = self.transformer.config.vocab_size

  def freeze_bert(self, freeze):
    """Turn gradient tracking off (freeze=True) or on (freeze=False) for all LM weights."""
    for param in self.transformer.parameters():
      param.requires_grad_(not freeze)

  @staticmethod
  def _mean_pool(embs, att_masks):
    """Average `embs` over the positions where `att_masks` is non-zero."""
    weighted_sum = (att_masks.unsqueeze(-1) * embs).sum(dim=1)
    return weighted_sum / att_masks.sum(dim=1).unsqueeze(-1)

  def masked_forward(self, input_ids, att_masks):
    """All output modes for a masked LM, read at the mask-token positions."""
    out = self.transformer(input_ids, att_masks, return_dict=True, output_hidden_states=True)
    last_hidden = out.hidden_states[-1]
    at_mask = input_ids == self.mask_id
    mask_logits = out.logits[at_mask]

    return {
        'avg_embeddings': self._mean_pool(last_hidden, att_masks),
        'embeddings': last_hidden[at_mask],
        'lm_logits': mask_logits,
        'lm_probs': mask_logits.softmax(dim=-1),
    }

  def causal_forward(self, input_ids, att_masks):
    """All output modes for a causal LM, read at each row's last attended position."""
    out = self.transformer(input_ids, attention_mask=att_masks, return_dict=True, output_hidden_states=True)
    last_hidden = out.hidden_states[-1]
    rows = torch.arange(input_ids.shape[0])
    cols = att_masks.sum(dim=1) - 1
    last_logits = out.logits[rows, cols]

    return {
        'avg_embeddings': self._mean_pool(last_hidden, att_masks),
        'embeddings': last_hidden[rows, cols],
        'lm_logits': last_logits,
        'lm_probs': last_logits.softmax(dim=-1),
    }

  def single_forward(self, input_ids, att_masks, output_mode='lm_probs'):
    """Run the mode-appropriate forward pass and pick one of its outputs."""
    run = self.causal_forward if self.lm_mode == 'causal' else self.masked_forward
    return run(input_ids, att_masks)[output_mode]

  def forward(self, input_ids1, att_masks1, output_mode='lm_probs'):
    return self.single_forward(input_ids1, att_masks1, output_mode)

class DoubleMaskFormer(MaskFormer):
  """Applies `single_forward` to two inputs, returning ``(first, second)``."""

  def forward(self, input_ids1, att_masks1, input_ids2, att_masks2, output_mode='lm_probs'):
    first = self.single_forward(input_ids1, att_masks1, output_mode)
    second = self.single_forward(input_ids2, att_masks2, output_mode)
    return first, second

In [17]:
#@title define similarity/correlation funcs

from scipy.stats import pearsonr, spearmanr, kendalltau

# Shared NLTK normalisers. `stemer` (sic) is read by `simplify_token`; `lemmatizer` is not
# used by the active helpers in this cell.
stemer = PorterStemmer() 
lemmatizer = WordNetLemmatizer()

def to_numpy(tensor):
  """Convert a torch tensor to a NumPy array (via CPU); return anything else unchanged."""
  return tensor.cpu().numpy() if torch.is_tensor(tensor) else tensor

def to_tensor(array):
  """Wrap a NumPy array with `torch.from_numpy`; return torch tensors unchanged."""
  return array if torch.is_tensor(array) else torch.from_numpy(array)

def simplify_token(token):
  """Remove the 'Ġ' word-boundary marker, lowercase, then Porter-stem the token."""
  bare = token.replace('Ġ', '').lower()
  return stemer.stem(bare)

def unify_token(token):
  """Lowercase only: unlike `simplify_token`, 'Ġ' is kept and nothing is stemmed."""
  return token.lower()

def simplify_tokens(token_ids):
  """Map token ids (through the global `tokenizer`) to `simplify_token` strings."""
  return list(map(simplify_token, tokenizer.convert_ids_to_tokens(token_ids)))

def unify_tokens(token_ids):
  """Map token ids (through the global `tokenizer`) to `unify_token` strings."""
  return list(map(unify_token, tokenizer.convert_ids_to_tokens(token_ids)))

def _rbo_score(topk1, topk2, p, m):
  """Lower-bound RBO plus `m` times the residual (first two values returned by `rbo`)."""
  lower_bound_sim, residual, _ = rbo(topk1, topk2, p=p)
  return lower_bound_sim + m * residual

def rbo_rank_sim(feature1, feature2, k=100, p=0.999, m=0.0):
  """Rank-biased overlap of the top-`k` entries of two rankings (e.g. token ids)."""
  return _rbo_score(feature1[:k].tolist(), feature2[:k].tolist(), p, m)

def rbo_rank_sim_stem(feature1, feature2, k=100, p=0.999, m=0.0):
  """Like `rbo_rank_sim`, but compares stemmed token strings (`simplify_tokens`)."""
  return _rbo_score(simplify_tokens(feature1[:k].tolist()),
                    simplify_tokens(feature2[:k].tolist()), p, m)

def rbo_rank_sim_lower(feature1, feature2, k=100, p=0.999, m=0.0):
  """Like `rbo_rank_sim`, but compares lowercased token strings (`unify_tokens`)."""
  return _rbo_score(unify_tokens(feature1[:k].tolist()),
                    unify_tokens(feature2[:k].tolist()), p, m)

def weighted_overlap(ranking1, ranking2):
  """Weighted overlap (WO) of two ranked lists of hashable items (token ids or strings).

  Every distinct item present in both lists contributes ``1 / (r1 + r2)``, where
  ``r1``/``r2`` are its 1-based ranks (first occurrence) in each list. The sum is
  divided by ``sum(1 / (2 * i) for i in 1..len(ranking1))``, so two identical
  duplicate-free lists score 1.0. Returns 0.0 when `ranking1` is empty.

  The items themselves are compared. The previous loop looked up the integers
  ``1..n`` instead, so it only worked for permutations of ``1..n``: rankings of vocab
  ids or token strings scored ~0.
  """
  if not ranking1:
    return 0.0

  # First-occurrence rank of every item in ranking2 (O(1) lookups instead of list.index).
  rank_in_2 = {}
  for rank, item in enumerate(ranking2, start=1):
    rank_in_2.setdefault(item, rank)

  overlap = 0.0
  seen = set()
  for rank, item in enumerate(ranking1, start=1):
    if item in seen or item not in rank_in_2:
      continue
    seen.add(item)
    overlap += 1 / (rank + rank_in_2[item])

  total = sum(1 / (2 * i) for i in range(1, len(ranking1) + 1))
  return overlap / total

def wo_rank_sim(feature1, feature2, k=100, **kwargs):
  """Weighted overlap of the top-`k` entries of two rankings."""
  topk1 = feature1[:k].tolist()
  topk2 = feature2[:k].tolist()
  return weighted_overlap(topk1, topk2)

def wo_rank_sim_stem(feature1, feature2, k=100, **kwargs):
  """Weighted overlap of the top-`k` entries after mapping ids to stemmed tokens."""
  topk1 = feature1[:k].tolist()
  topk2 = feature2[:k].tolist()
  topk1 = simplify_tokens(topk1)
  topk2 = simplify_tokens(topk2)
  return weighted_overlap(topk1, topk2)

def pearson_corr(feature1, feature2, **kwargs):
  """Pearson correlation coefficient of two 1-D feature vectors."""
  return pearsonr(to_numpy(feature1), to_numpy(feature2))[0]

def spearman_corr(feature1, feature2, **kwargs):
  """Spearman rank correlation of two 1-D feature vectors.

  Uses ``[0]`` like the other scipy wrappers here, which works with both the old named
  tuple and newer result objects.
  """
  return spearmanr(to_numpy(feature1), to_numpy(feature2))[0]

def kendall_tau(feature1, feature2, k=100, **kwargs):
  """Kendall's tau between the first `k` entries of the two feature vectors."""
  return kendalltau(to_numpy(feature1)[:k], to_numpy(feature2)[:k])[0]

def cosine_sim(feature1, feature2, **kwargs):
  """Cosine similarity along the last dimension (inputs cast to float)."""
  return torch.nn.functional.cosine_similarity(to_tensor(feature1).float(), to_tensor(feature2).float(), dim=-1)

def euclidean_dist(feature1, feature2, **kwargs):
  """Euclidean (L2) distance between two feature vectors.

  Returns the norm of the difference itself. The function used to return its square
  root, which is not the Euclidean distance.
  """
  return np.linalg.norm(to_numpy(feature1) - to_numpy(feature2))


# Registry of feature-comparison functions keyed by short name.
sim_funcs = dict(
    rbo=rbo_rank_sim,
    rbo_stem=rbo_rank_sim_stem,
    rbo_lower=rbo_rank_sim_lower,
    wo=wo_rank_sim,
    wo_stem=wo_rank_sim_stem,
    pearson=pearson_corr,
    spearman=spearman_corr,
    kendall_tau=kendall_tau,
    cosine=cosine_sim,
    euclidean=euclidean_dist,
)

In [18]:
def weighted_overlap(ranking1, ranking2):
  """Weighted overlap (WO) of two ranked lists of hashable items (token ids or strings).

  Every distinct item present in both lists contributes ``1 / (r1 + r2)``, where
  ``r1``/``r2`` are its 1-based ranks (first occurrence) in each list. The sum is
  divided by ``sum(1 / (2 * i) for i in 1..len(ranking1))``, so two identical
  duplicate-free lists score 1.0. Returns 0.0 when `ranking1` is empty.

  This cell redefines the function, so it must contain the same item-based fix:
  looking up the integers ``1..n`` only works for permutations of ``1..n``.
  """
  if not ranking1:
    return 0.0

  rank_in_2 = {}
  for rank, item in enumerate(ranking2, start=1):
    rank_in_2.setdefault(item, rank)

  overlap = 0.0
  seen = set()
  for rank, item in enumerate(ranking1, start=1):
    if item in seen or item not in rank_in_2:
      continue
    seen.add(item)
    overlap += 1 / (rank + rank_in_2[item])

  total = sum(1 / (2 * i) for i in range(1, len(ranking1) + 1))
  return overlap / total

weighted_overlap([1, 2, 3, 5, 4], [1, 2, 3, 4, 5])

0.9975669099756691

In [19]:
#@title define inference funcs
from itertools import combinations

def _model_device(model):
  """Device of `model`'s first parameter, or CUDA (the old hard-coded default) if it has none."""
  for param in model.parameters():
    return param.device
  return torch.device('cuda')


@torch.no_grad()
def get_outputs(model, batch, mode='lm_probs'):
  """Run `model` (in eval mode, without gradients) on every template in `batch`.

  Every key containing ``'input_ids'`` (e.g. ``'t1_input_ids'``) is paired with the
  matching ``<prefix>attention_mask`` key. Inputs are moved to the model's device
  instead of always to CUDA.

  Returns:
    dict mapping ``<prefix><mode>`` to ``model(input_ids, att_mask, mode)``, in batch key order.
  """
  model.eval()
  device = _model_device(model)
  outputs = {}
  for key in batch:
    if 'input_ids' in key:
      prefix = key.partition('input_ids')[0]
      mask_key = prefix + 'attention_mask'
      input_ids = batch[key].to(device)
      att_mask = batch[mask_key].to(device)
      outputs[prefix + mode] = model(input_ids, att_mask, mode)
  return outputs

def get_class_features(model, batch, labels, mode=('embeddings', 'raw')):
  """Per-class mean features for every template of `batch`.

  Args:
    mode: ``(output_mode, post)``. `output_mode` is passed to the model; when `post` is
      ``'ranking'``, each class mean is replaced by its descending argsort.

  Returns:
    Tensor of shape ``(T * C, D)`` in template-major order: row ``t * C + c`` holds
    template ``t`` (in `get_outputs` key order) and class ``c`` (sorted label order).

  NOTE(review): `output[labels == label]` assumes one output row per example (e.g.
  exactly one mask token per input) — confirm for each template.
  """
  outputs = get_outputs(model, batch, mode=mode[0])
  all_features = []
  for output in outputs.values():
    # Match the labels to the outputs' device once, instead of assuming CUDA.
    labels_dev = labels.to(output.device)
    template_features = []
    for label in labels_dev.unique():
      class_features = output[labels_dev == label].mean(dim=0)
      if mode[1] == 'ranking':
        class_features = class_features.argsort(descending=True)
      template_features.append(class_features)
    all_features.append(torch.stack(template_features, dim=0))
  all_features = torch.stack(all_features, dim=0)
  t, c, d = all_features.shape
  return all_features.view(t * c, d)

def get_batch_features(model, batch, mode=('embeddings', 'raw')):
  """Per-example features for every template, shaped ``(B, T, D)``.

  `mode` is ``(output_mode, post)``. With ``post == 'ranking'``, each feature vector is
  replaced by its descending argsort.
  """
  outputs = get_outputs(model, batch, mode=mode[0])
  per_template = [
      out.argsort(descending=True, dim=-1) if mode[1] == 'ranking' else out
      for out in outputs.values()
  ]
  stacked = torch.stack(per_template, dim=0)
  n_templates, n_batch, n_dim = stacked.shape
  return stacked.permute(1, 0, 2).view(n_batch, n_templates, n_dim)

def get_cosine_sims(batch_features, class_features):
  """Cosine similarity of each example's features to `class_features`, stacked on dim 0."""
  return torch.stack([cosine_sim(features, class_features) for features in batch_features], dim=0)

def get_class_sims(batch_features, class_features, sim_fn, **kwargs):
  """For each example, `sim_fn` between its i-th feature and the i-th class feature.

  Returns a tensor with one row per example and one column per zipped pair.
  """
  return torch.tensor([
      [sim_fn(mine, ref, **kwargs) for mine, ref in zip(features, class_features)]
      for features in batch_features
  ])

def get_pair_sims(batch_features, sim_fn, **kwargs):
  """For each example, `sim_fn` over every unordered pair ``(i < j)`` of its features."""
  rows = []
  for features in batch_features:
    pairs = combinations(range(len(features)), 2)
    rows.append([sim_fn(features[i], features[j], **kwargs) for i, j in pairs])
  return torch.tensor(rows)

def get_class_sim_dataset(model, data_loader, class_features, sim_fn, mode=('embeddings', 'raw', 'ranking'), **ranking_kwargs):
  """Similarity of every example to every (template, class) feature over a loader.

  `class_features` must be template-major: row ``t * C + c`` is template ``t``, class
  ``c``, which is the layout `get_class_features` returns.

  Returns:
    Tensor with one row per example and ``T * C`` columns in that same order.
  """
  all_sims = []

  for batch in tqdm(data_loader):
    batch_features = get_batch_features(model, batch, mode=mode[:2])
    n_classes = class_features.shape[0] // batch_features.shape[1]
    # Repeat each template C times in a row ([t0, t0, t1, t1, ...]) so that column
    # t * C + c lines up with class_features. `repeat` would tile as [t0, t1, t0, t1, ...]
    # and mis-pair templates with classes whenever there is more than one template.
    batch_features = batch_features.repeat_interleave(n_classes, dim=1)
    batch_sims = get_class_sims(batch_features, class_features, sim_fn, **ranking_kwargs)
    all_sims += batch_sims.tolist()

  return torch.tensor(all_sims)

def get_features_dataset(model, data_loader, mode=('embeddings', 'raw')):
  """Concatenate `get_batch_features` over a loader into a ``(N, T, D)`` tensor."""
  collected = []
  for batch in tqdm(data_loader):
    collected.extend(get_batch_features(model, batch, mode=mode[:2]).tolist())
  return torch.tensor(collected)

def get_pair_sim_dataset(model, data_loader, sim_fn, mode=('embeddings', 'raw', 'ranking'), **ranking_kwargs):
  """Per example, `sim_fn` between every pair of templates, over a whole loader."""
  collected = []
  for batch in tqdm(data_loader):
    features = get_batch_features(model, batch, mode=mode[:2])
    collected.extend(get_pair_sims(features, sim_fn, **ranking_kwargs).tolist())
  return torch.tensor(collected)

def get_labels_dataset(data_loader, label_name='label'):
  """Concatenate the `label_name` column of every batch into one tensor."""
  labels = []
  for batch in tqdm(data_loader):
    labels.extend(batch[label_name].tolist())
  return torch.tensor(labels)

In [20]:
#@title define inference funcs 2
def get_class_mean_sim_data(model, train_loader, test_loader, mode, sim_fn, sim_kwargs, label_name='label'):
  batch = next(iter(train_loader))
  class_features = get_class_features(model, batch, batch[label_name], mode=mode)
  train_sims = get_class_sim_dataset(model, train_loader, class_features, sim_fn, mode=mode, **sim_kwargs)
  test_sims = get_class_sim_dataset(model, test_loader, class_features, sim_fn, mode=mode, **sim_kwargs)
  return train_sims, test_sims

def get_pair_sim_data(model, train_loader, test_loader, mode, sim_fn, sim_kwargs):
  """Compute pairwise (within-example) similarities for the train and test splits.

  Returns:
    (train_sims, test_sims) tensors from `get_pair_sim_dataset`, train first.
  """
  train_sims, test_sims = (
      get_pair_sim_dataset(model, loader, sim_fn, mode=mode, **sim_kwargs)
      for loader in (train_loader, test_loader)
  )
  return train_sims, test_sims

def get_normal_data(model, train_loader, test_loader, mode):
  """Extract raw features for the train and test splits.

  Dimension 1 of each feature tensor is squeezed away; `squeeze` only removes
  it when its size is 1.

  Returns:
    (train_features, test_features) tensors.
  """
  split_features = [get_features_dataset(model, loader, mode=mode).squeeze(1)
                    for loader in (train_loader, test_loader)]
  return split_features[0], split_features[1]

# Paper Experiments

**wic**: 
* `template1`: `<1-><t1> or <*><-1>`
* `template2`: `<2-><t2> or <*><-2>`

**wic_gpt**: 
* `template1`: `<1-><t1><-1> <t1> means`
* `template2`: `<2-><t2><-2> <t2> means`

**sst**: `<1> this movie was <*>.`

**sst_autoprompt**: `<1> atmosphere alot dialogue Clone totally <*>.`

**sick**: `<1>? Answer: <*>, <2>`

**sick_autoprompt**: `<1> <*> concretepathic workplace <2>`

In [22]:
#@title data process
# Load the chosen dataset, render it through the prompt templates, and build the
# train/test DataLoaders. The `#@param` lines are Colab form fields.
dataset_name = 'wic' #@param ['sst2', 'sick', 'wic']
base_model_name = "roberta-large" #@param ["bert-base-uncased", "bert-large-uncased", "bert-base-cased", "bert-large-cased", "roberta-base", "roberta-large"]
template1 = "\u003C1->\u003Ct1> or \u003C*>\u003C-1>"  #@param {type: "string"}
template2 = "\u003C2->\u003Ct2> or \u003C*>\u003C-2>"  #@param {type: "string"}
#@markdown `template2` is used for WiC; leave it empty for single-template datasets.

# max_len =         64#@param {type: "integer"}
train_ex_per_class =    16#@param {type: "integer"}
test_ex_per_class =      -2#@param {type: "integer"}
#@markdown `test_ex_per_class=-1`: balanced, `-2`: standard testset
batch_size = 128 #@param {type: "integer"}

# Tokenizer of the chosen backbone; its mask token is what `<*>` expands to below.
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# SuperGLUE tasks come from the `super_glue` hub dataset; other tasks map their
# short name to a hub name through `misc_name_map`.
if dataset_name in superglue_rules:
  rules = superglue_rules[dataset_name]
  dataset = load_dataset('super_glue', dataset_name)
else:
  rules = misc_rules[dataset_name]
  real_dataset_name = misc_name_map[dataset_name]
  dataset = load_dataset(real_dataset_name)

# NOTE: `rules` is the object stored in `superglue_rules` / `misc_rules` (not a
# copy), so this also writes the current tokenizer's mask token into that shared table.
rules['<*>'] = tokenizer.mask_token

# Template '1' is always used; template '2' only when its text is non-blank.
templates = {'1': template1}
if len(template2.strip()) > 0:
  templates['2'] = template2

use_validation_as_test = True #@param {type: "boolean"}
testset_name = 'validation' if use_validation_as_test else 'test'

# Apply the templates and sample examples per class; see the form note above for
# the meaning of negative `test_ex_per_class` values.
trainset = prepare_dataset(dataset['train'], 'label',
                           train_ex_per_class, base_model_name,
                           rules, templates, batch_size)
# valset = prepare_dataset(dataset['validation'], 'label', 
                        #  test_ex_per_class, base_model_name,
#                            rules, templates, batch_size)
testset = prepare_dataset(dataset[testset_name], 'label', 
                          test_ex_per_class, base_model_name,
                           rules, templates, batch_size)

print(f'train examples: {len(trainset)}',
      # f'\nval examples: {len(valset)}',
      f'\ntest examples: {len(testset)}')

# shuffle=False keeps loader order equal to dataset order; the qualitative
# analysis later indexes `testset` by position in the test predictions.
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=False)
# val_loader = DataLoader(valset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)

# Peek at the first batch of each split as a sanity check.
batch = next(iter(train_loader))
# val_batch = next(iter(val_loader))
test_batch = next(iter(test_loader))
print('\nbatch keys:', batch.keys(),
      '\nbatch input shape:', batch['1_input_ids'].shape,
      '\nbatch labels:', batch['label'])

train examples: 32 
test examples: 638

batch keys: dict_keys(['label', '1_input_ids', '1_attention_mask', '2_input_ids', '2_attention_mask']) 
batch input shape: torch.Size([32, 28]) 
batch labels: tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
        0, 1, 0, 1, 0, 1, 0, 1])


In [23]:
# Wrap the `base_model_name` backbone in MaskFormer and move it to the GPU
# (`.cuda()` requires a CUDA device).
model = MaskFormer(base_model_name).cuda()

## test model

In [24]:
#@title cosine based classification

#@markdown For `wic`, set `sim_mode=pair`, for other datasets (which use one template) use `sim_mode=class`.

# Build the model only if no earlier cell has created it.
if 'model' not in globals():
  model = MaskFormer(base_model_name).cuda()

output_mode = "embeddings" #@param ["embeddings", "lm_logits", "lm_probs", "avg_embeddings"]
feature_mode = "raw" #@param ["raw", "ranking"]
sim_fn_name = "cosine" #@param ["cosine", "euclidean", "rbo", "rbo_stem", "wo", "wo_stem", "pearson", "spearman"]
sim_mode = "pair" #@param ["pair", "class"]


mode = output_mode, feature_mode
sim_fn = sim_funcs[sim_fn_name]
# Forwarded to the similarity helpers (presumably a top-k cutoff and an RBO
# persistence parameter — confirm against `sim_funcs`).
kwargs = {'k': 100, 'p':0.995}

# Similarity features: one within-example score (pair) or scores against
# per-class features (class).
if sim_mode == 'class':
  x, x_test = get_class_mean_sim_data(model, train_loader, test_loader, mode, sim_fn, kwargs)
else:
  x, x_test = get_pair_sim_data(model, train_loader, test_loader, mode, sim_fn, kwargs)

y, y_test = get_labels_dataset(train_loader), get_labels_dataset(test_loader)

# Unregularized logistic regression on the similarity features. scikit-learn 1.2
# deprecated penalty='none' in favour of penalty=None and 1.4 removed the string
# form, so use whichever spelling the installed version accepts.
import re
import sklearn
sklearn_major_minor = tuple(int(v) for v in re.match(r'(\d+)\.(\d+)', sklearn.__version__).groups())
no_penalty = None if sklearn_major_minor >= (1, 2) else 'none'
clf = LogisticRegression(solver='newton-cg', penalty=no_penalty)
clf.fit(x, y)

y_pred = clf.predict(x)
print('train accuracy:', accuracy_score(y, y_pred))

y_test_pred = clf.predict(x_test)
print('test accuracy: ', accuracy_score(y_test, y_test_pred))

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

train accuracy: 0.71875
test accuracy:  0.5924764890282131




In [25]:
#@title spearman based classification

#@markdown For `wic`, set `sim_mode=pair`, for other datasets (which use one template) use `sim_mode=class`.

# Build the model only if no earlier cell has created it.
if 'model' not in globals():
  model = MaskFormer(base_model_name).cuda()

output_mode = "embeddings" #@param ["embeddings", "lm_logits", "lm_probs", "avg_embeddings"]
feature_mode = "raw" #@param ["raw", "ranking"]
sim_fn_name = "spearman" #@param ["cosine", "euclidean", "rbo", "rbo_stem", "wo", "wo_stem", "pearson", "spearman"]
sim_mode = "pair" #@param ["pair", "class"]

mode = output_mode, feature_mode
sim_fn = sim_funcs[sim_fn_name]
# Forwarded to the similarity helpers (presumably a top-k cutoff and an RBO
# persistence parameter — confirm against `sim_funcs`).
kwargs = {'k': 100, 'p':0.995}

# Similarity features: one within-example score (pair) or scores against
# per-class features (class).
if sim_mode == 'class':
  x, x_test = get_class_mean_sim_data(model, train_loader, test_loader, mode, sim_fn, kwargs)
else:
  x, x_test = get_pair_sim_data(model, train_loader, test_loader, mode, sim_fn, kwargs)

y, y_test = get_labels_dataset(train_loader), get_labels_dataset(test_loader)

# Unregularized logistic regression on the similarity features. scikit-learn 1.2
# deprecated penalty='none' in favour of penalty=None and 1.4 removed the string
# form, so use whichever spelling the installed version accepts.
import re
import sklearn
sklearn_major_minor = tuple(int(v) for v in re.match(r'(\d+)\.(\d+)', sklearn.__version__).groups())
no_penalty = None if sklearn_major_minor >= (1, 2) else 'none'
clf = LogisticRegression(solver='newton-cg', penalty=no_penalty)
clf.fit(x, y)

y_pred = clf.predict(x)
print('train accuracy:', accuracy_score(y, y_pred))

y_test_pred = clf.predict(x_test)
print('test accuracy: ', accuracy_score(y_test, y_test_pred))

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

train accuracy: 0.78125
test accuracy:  0.6896551724137931




In [26]:
#@title rbo based classification

#@markdown For `wic`, set `sim_mode=pair`, for other datasets (which use one template) use `sim_mode=class`.

#@markdown The `rbo` similarity (`sim_fn_name`) works with `feature_mode=ranking, output_mode=lm_probs`, which is not reported in the paper.

# Build the model only if no earlier cell has created it.
if 'model' not in globals():
  model = MaskFormer(base_model_name).cuda()

output_mode = "lm_probs" #@param ["embeddings", "lm_logits", "lm_probs", "avg_embeddings"]
feature_mode = "ranking" #@param ["raw", "ranking"]
sim_fn_name = "rbo_stem" #@param ["cosine", "euclidean", "rbo", "rbo_stem", "wo", "wo_stem", "pearson", "spearman"]
sim_mode = "pair" #@param ["pair", "class"]

mode = output_mode, feature_mode
sim_fn = sim_funcs[sim_fn_name]
# Forwarded to the similarity helpers (presumably a top-k cutoff and an RBO
# persistence parameter — confirm against `sim_funcs`).
kwargs = {'k': 100, 'p':0.995}

# Similarity features: one within-example score (pair) or scores against
# per-class features (class).
if sim_mode == 'class':
  x, x_test = get_class_mean_sim_data(model, train_loader, test_loader, mode, sim_fn, kwargs)
else:
  x, x_test = get_pair_sim_data(model, train_loader, test_loader, mode, sim_fn, kwargs)

y, y_test = get_labels_dataset(train_loader), get_labels_dataset(test_loader)

# Unregularized logistic regression on the similarity features. scikit-learn 1.2
# deprecated penalty='none' in favour of penalty=None and 1.4 removed the string
# form, so use whichever spelling the installed version accepts.
import re
import sklearn
sklearn_major_minor = tuple(int(v) for v in re.match(r'(\d+)\.(\d+)', sklearn.__version__).groups())
no_penalty = None if sklearn_major_minor >= (1, 2) else 'none'
clf = LogisticRegression(solver='newton-cg', penalty=no_penalty)
clf.fit(x, y)

y_pred = clf.predict(x)
print('train accuracy:', accuracy_score(y, y_pred))

y_test_pred = clf.predict(x_test)
print('test accuracy: ', accuracy_score(y_test, y_test_pred))

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

train accuracy: 0.75
test accuracy:  0.7084639498432602




## WiC Qualitative Analysis (using RBO ranking similarity)

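Run the RBO-based classifier cell above on WiC before running this section: it reuses `clf`, `x_test`, and `y_test_pred` from whichever classifier cell was run last.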

In [46]:
#@title calculate confidence scores
# Signed decision-function value of each test example: positive values favour
# `clf.classes_[1]`, negative values `clf.classes_[0]`; magnitude = confidence.
conf_scores = torch.from_numpy(clf.decision_function(x_test))

# Test-set positions ordered three ways for the qualitative inspection:
# most confidently positive, most confidently negative, closest to the boundary.
sorted_indices = dict(
    high_conf_positive=torch.argsort(conf_scores, descending=True),
    high_conf_negative=torch.argsort(conf_scores, descending=False),
    low_conf=torch.argsort(conf_scores.abs(), descending=False),
)

In [53]:
#@title check examples with top-k predicted `<mask>` words

mode = "high_conf_positive"#@param ["high_conf_positive", "high_conf_negative", "low_conf"]
num_examples = 5#@param {type:"integer"}
top_k_words = 10#@param {type:"integer"}
wrong = True #@param {type:"boolean"}

# Walk the test set in `sorted_indices[mode]` order and show up to `num_examples`
# examples the classifier got wrong (`wrong=True`) or right (`wrong=False`),
# together with the top-k words the LM predicts at `<mask>` for each template.

# Tokens stripped from the decoded inputs. The mask token is intentionally kept
# so the prompt slot stays visible. pad/cls/sep come from the tokenizer, so this
# removes RoBERTa's `<pad>`/`<s>`/`</s>` and also BERT's `[PAD]`/`[CLS]`/`[SEP]`.
strip_tokens = [tok for tok in (tokenizer.pad_token, tokenizer.cls_token, tokenizer.sep_token) if tok]

def decode_without_padding(token_ids):
  """Decode `token_ids` to text, keeping the mask token but removing pad/cls/sep tokens."""
  text = tokenizer.decode(token_ids, skip_special_tokens=False)
  for tok in strip_tokens:
    text = text.replace(tok, '')
  return text

i, c = 0, 0

while c < num_examples and i < len(testset):
  ex_idx = sorted_indices[mode][i].item()
  i += 1
  ex = testset[ex_idx:ex_idx+1]
  label = ex['label'].item()

  # Filter on (in)correctness *before* running the model, so skipped examples
  # do not cost two LM forward passes each.
  if (label != y_test_pred[ex_idx]) != bool(wrong):
    continue

  with torch.no_grad():
    out1 = model(ex['1_input_ids'].cuda(), ex['1_attention_mask'].cuda(), 'lm_probs')
    out2 = model(ex['2_input_ids'].cuda(), ex['2_attention_mask'].cuda(), 'lm_probs')

  # Rank the vocabulary once per template; reused for display and for RBO.
  top_words1 = out1[0].argsort(descending=True)[:top_k_words]
  top_words2 = out2[0].argsort(descending=True)[:top_k_words]

  print('\nlabel:', label,
        '\tprediction:', y_test_pred[ex_idx],
        '\tsigned_conf:', conf_scores[ex_idx].item(), '\n')
  print(decode_without_padding(ex['1_input_ids'][0]))
  print(tokenizer.decode(top_words1, skip_special_tokens=False))
  print()
  print(decode_without_padding(ex['2_input_ids'][0]))
  print(tokenizer.decode(top_words2, skip_special_tokens=False))
  # NOTE: this RBO uses only the top `top_k_words` words, so it can differ from
  # the similarity feature the classifier was trained on.
  print('\nRBO similarity:', rbo_rank_sim_stem(top_words1, top_words2))
  print('_'*100)
  c += 1


label: 0 	prediction: 1 	signed_conf: 8.951171897124643 

He could not conceal his hostility or<mask>.
 anger disgust irritation contempt frustration rage fear disappointment annoyance resentment

He could no longer contain his hostility or<mask>.
 anger rage frustration aggression disgust bitterness fury hatred contempt resentment

RBO similarity: 0.02950917250218558
____________________________________________________________________________________________________

label: 0 	prediction: 1 	signed_conf: 6.653665639844856 

In the middle or<mask> of the marathon, David collapsed from fatigue.
 end late fall later beginning evening early course afternoon middle

Rain during the middle or<mask> of April.
 end beginning fall last all most late rest ending later

RBO similarity: 0.026825909887329944
____________________________________________________________________________________________________

label: 0 	prediction: 1 	signed_conf: 6.141114813328208 

There was a blockage or<mask> i