<a href="https://colab.research.google.com/gist/timellemeet/73087b0e0c49b2b0fbe1c570ce948708/copy-of-thesis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Soft-reset the namespace without asking for confirmation (-s keeps history, -f skips the prompt),
# so the cells below run against a clean set of user variables.
%reset -s -f

In [None]:
%%capture
# Silence pip output. Only transformers is pinned; wandb and tensorflow_addons install whatever is current.
!pip install transformers==4.3.3 wandb tensorflow_addons --quiet

Logging to Weights & Biases

In [None]:
import wandb
from wandb.keras import WandbCallback
# Authenticate with W&B without hardcoding a secret in the notebook: wandb.login()
# uses the WANDB_API_KEY environment variable (or a stored login) and otherwise prompts.
wandb.login()

Define the configuration of the experiment.

In [None]:
import logging
import secrets

import numpy as np
import tensorflow as tf
from transformers import BertConfig
from transformers import logging as hf_logging

# Hugging Face checkpoint used for the tokenizer and for the model configs/weights.
PRETRAINED_MODEL = "bert-base-uncased"

# NOTE(review): SHOW_MASKS and SKIP_RUNS are not read in the visible cells;
# presumably they toggle mask printing and skip already-finished runs — confirm.
SHOW_MASKS = True

SKIP_RUNS = 0

# Experiment configuration. Entries written as [option_a, option_b, ...][i]
# select option i from the listed alternatives.
config = {
# General
"name": "Main",
"remark": "Full experiment with proposed metrics",
"dataset": ["SST-2","QQP","QNLI"][0],
"target_size":0.01,
"n_experiments": 15,
"label_type": ["natural","token"][0], # "token" registers [NEG]/[POS] as special tokens

# SpanBERT
"spanbert": True,
"n_spanbert_repeat":10, # times the data is tiled, each copy with fresh span masks
"span_clip_multiplier":2, # span length clipped to this multiple of the expected length
"span_clip_max":10, # absolute cap on span length
"sb_epochs":10,

# Span Extraction
# "spex":  [False, "finetune","pretrain","both"][0], # False / finetune, pretrain, both
"spex_mode": ["span", "token"][0], # "span": start/end targets, "token": single-token targets
"spex_pt_epochs": 10,
"spex_ft_epochs": 10,
"shuffle_labels": True, # randomly permute the prepended labels per row
"ws_mean": 0.57, # mean of the Beta distribution used to simulate weak labels
"ws_var": 0.05, # variance of that Beta distribution
"ws_histogram": False,
"ws_metrics": False,

# # Target Analysis
# "ta_epochs": 1,

# Augmentation
"augmentation": [False, "base","hetero"][0], # False: none / "base": uniform random masking / "hetero": attention-weighted masking (uses target analysis)
"n_train_aug": 10,
"n_target_aug": 2,
"aug_epochs": 15, # NOTE(review): an earlier note said "later set to 10" — confirm the intended value
"probabilistic_labels": True,
"min_weak_prob": 0.55,
"mask_prob": 0.15, # base probability of masking a token
"attention_multiplier": 3, # amplification of attention deviations in "hetero" masking
"UB":1.,
"LB":0.6,

# Finetuning
# "finetune_am": False,
"extrinsic_epochs": 30,
"extrinsic_batch_size": 16,
"eval_epochs": 10,
"eval_batch_size": 32,


# Hyperparameters
"seed":1234, # NOTE(review): confirm where this seed is applied (numpy / TF / random)
"batch_size": 32,
"shuffle_batches": 5, # rolling shuffle window, in batches, used with smart batching
"max_length": 200, # tokenizer truncation length
"smart_batching":True,
"fast_tokenizer": True,
"optimizer_lr": 2e-5,
"optimizer_epsilon": 1e-8
}

# logging.basicConfig(level=logging.WARNING)
# Silence TF / autograph / transformers logging below error level.
tf.get_logger().setLevel(logging.ERROR)
tf.autograph.set_verbosity(0)
hf_logging.set_verbosity_error()

### GPU
Show information about the GPU assigned to this Colab runtime.

In [None]:
# Print the GPU model, memory usage and driver version of this runtime.
!nvidia-smi

# Dataset processing

### Dataset loading

In [None]:
import pandas as pd

def load_dataset(dataset, split, head=True, return_class_names=False):
  """Load a GLUE-style TSV split and normalise its column names.

  Reads ``/content/drive/Shareddrives/Thesis/datasets/{dataset}/{split}.tsv``.
  Malformed lines are skipped with a warning. Columns are renamed so every
  known dataset exposes ``text`` (plus ``text_pair`` for sentence-pair tasks)
  and ``label``.

  Args:
    dataset: "SST-2", "QQP" or "QNLI". Other names are loaded unchanged.
    split: file stem of the split, e.g. "train" or "dev".
    head: if True, display the first rows of the frame.
    return_class_names: if True, also return the human-readable class names.

  Returns:
    ``df``, or ``(df, class_names)`` when ``return_class_names`` is set.
    ``class_names`` is None for datasets without a known mapping.
  """
  file_path = f"/content/drive/Shareddrives/Thesis/datasets/{dataset}/{split}.tsv"

  # ``error_bad_lines`` was deprecated in pandas 1.3 and removed in 2.0 in favour
  # of ``on_bad_lines``. Use whichever the installed pandas understands.
  pandas_version = tuple(int(part) for part in pd.__version__.split(".")[:2])
  if pandas_version >= (1, 3):
    bad_line_kwargs = {"on_bad_lines": "warn"}
  else:
    bad_line_kwargs = {"error_bad_lines": False}
  df = pd.read_csv(file_path, sep='\t', **bad_line_kwargs)

  class_names = None
  if dataset == "SST-2":
    df = df.rename(columns={"sentence": "text"})
    class_names = ["Negative", "Positive"]

  elif dataset == "QQP":
    df = df.rename(columns={"question1": "text", "question2": "text_pair", "is_duplicate": "label"})
    df = df.drop(columns=["id", "qid1", "qid2"])
    class_names = ["Different", "Similar"]

  elif dataset == "QNLI":
    df = df.rename(columns={"question": "text", "sentence": "text_pair"})
    df = df.drop(columns=["index"])
    # binarise the textual label: 1 = entailment, 0 = anything else
    df["label"] = df["label"].apply(lambda x: 1 if x == "entailment" else 0)
    class_names = ["Missing", "Entailed"]

  if head:
    display(df.head())

  if return_class_names:
    return df, class_names
  return df
    

### Prepending strategies

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL, use_fast=config["fast_tokenizer"])

# How labels are written when they are prepended to the text:
# - "token": dedicated special tokens, registered with the tokenizer here;
# - otherwise ("natural"): label_names is set to False at this point.
#   NOTE(review): prepend_labels indexes label_names, so for natural labels it
#   presumably gets reassigned (e.g. to the dataset class names) before use — confirm.
if config["label_type"] != "token":
  label_names = False
else:
  label_names = ["[NEG]", "[POS]"]
  tokenizer.add_special_tokens({'additional_special_tokens': label_names})

In [None]:
from random import sample
def prepend_labels(df, task):
  """Prepend label strings to every text in ``df`` for the given task.

  - ``"augmentation"``: prepend the row's own label, ``label_names[row.label]``.
  - ``"spex"``: prepend *all* label names. When ``config["shuffle_labels"]`` is
    set they are randomly permuted per row, so the model cannot learn a fixed
    label position.

  Args:
    df: DataFrame with a ``text`` column (and ``label`` for augmentation).
    task: ``"augmentation"`` or ``"spex"``.

  Returns:
    ``(text, label_ids)``. ``text`` holds the prepended strings, one per row.
    ``label_ids`` holds, for each label name, the token ids it tokenizes to
    (without special tokens).

  Raises:
    ValueError: if ``task`` is not one of the supported tasks.
  """
  if task == "augmentation":
    mapping = lambda row: f"{label_names[row.label]} {row.text}"
  elif task == "spex":
    if config["shuffle_labels"]:
      # a fresh permutation per row prevents fitting on label position
      mapping = lambda row: " ".join([*sample(label_names, len(label_names)), row.text])
    else:
      mapping = lambda row: " ".join([*label_names, row.text])
  else:
    raise ValueError(f"Unknown task for label prepending: {task!r}")

  text = df.apply(mapping, axis=1)

  # token ids of each label, later used to locate/protect the prepended labels
  label_ids = tokenizer(label_names, add_special_tokens=False).input_ids

  return text, label_ids

In [None]:
def decode_inputs(inputs):
  """Print the first five rows of ``inputs["input_ids"]`` as text, with " [PAD]" removed."""
  for token_ids in inputs["input_ids"][:5]:
    decoded = tokenizer.decode(token_ids)
    print(decoded.replace(" [PAD]", ""))

## Task processing

### Classification

In [None]:
def classification_preprocessing(inputs, df):
  """Turn ``df.label`` into two-column class-probability targets.

  Column 1 holds ``df.label`` (the positive-class probability) and column 0
  its complement. ``inputs`` is returned unchanged.

  Returns:
    ``(inputs, labels)`` with ``labels`` of shape (n_rows, 2).
  """
  positive = df.label.values
  labels = np.stack([1 - positive, positive], axis=1)
  return inputs, labels

### SpanBERT

In [None]:
from itertools import accumulate

def spanbert_preprocessing(inputs, df):
  """Apply SpanBERT-style span masking to tokenized inputs.

  Each row gets exactly one contiguous masked span of "normal" tokens, i.e.
  tokens that are neither special nor padding.
  - Span lengths are drawn from a geometric distribution whose mean is
    ``config["mask_prob"]`` times the number of candidate tokens. They are then
    clipped by ``span_clip_multiplier`` and ``span_clip_max``.
  - The first and last normal token of every run of normal tokens are excluded
    from the candidates. Every span therefore has a left and a right boundary
    token for the span boundary objective (SBO).

  Args:
    inputs: tokenizer output dict of numpy arrays. It must contain
      ``attention_mask``, ``special_tokens_mask`` and ``input_ids``.
    df: not used in this function.

  Returns:
    ``(inputs, labels)``.
    - Rows without candidate tokens are dropped.
    - When ``config["n_spanbert_repeat"] > 1`` the data is tiled that many times.
    - ``inputs`` gains ``masked_indices`` (0/1 per token) and ``pair_indices``
      (per row: the position just before and just after the span).
    - ``labels`` holds the original ids at masked positions and -100 elsewhere.
  """
  normal_tokens_mask = inputs["attention_mask"]*(1-inputs["special_tokens_mask"])

  #set boundary normal tokens to zero such that spans are not allowed to go there
  boundary_indices = np.diff(normal_tokens_mask, n=1)
  boundary_indices = boundary_indices.nonzero()
  # Transitions come in (start, end) pairs per run of normal tokens. A diff at a
  # start points one position before the first normal token, so starts (even
  # entries) are shifted by one; ends already point at the last normal token.
  boundary_indices[1][::2] +=1
  normal_tokens_mask[boundary_indices] = 0  

  #filter out samples where there are no candidate tokens left (too short)
  n_normal_tokens = normal_tokens_mask.sum(axis=1) 
  filter = n_normal_tokens > 0

  n_normal_tokens = n_normal_tokens[filter]
  normal_tokens_mask = normal_tokens_mask[filter]
  inputs = {key: value[filter] for key, value in inputs.items()}
  

  # For each position, count the consecutive candidate tokens starting there
  # (running sum from the right that resets at non-candidates), i.e. the
  # longest span that could start at that position.
  flipped = np.flip(normal_tokens_mask,axis=1)
  cumsum_f = lambda row: list(accumulate(row, lambda acc, elem: acc + elem if elem else 0))
  flipped_cumsum = np.array([cumsum_f(row) for row in flipped])
  cumsum_mask = np.flip(flipped_cumsum,axis=1)
  
  #copy dataset if repeat for multiple epochs (each copy gets its own spans)
  if config["n_spanbert_repeat"] > 1:
    inputs = {key:  np.tile(value, (config["n_spanbert_repeat"], 1)) for key, value in inputs.items()}
    n_normal_tokens = np.tile(n_normal_tokens, config["n_spanbert_repeat"])
    cumsum_mask = np.tile(cumsum_mask, (config["n_spanbert_repeat"], 1))

  #sample the span lengths: geometric with mean mask_prob * n_normal_tokens (p = 1 / mean)
  geometric_means = config["mask_prob"] * n_normal_tokens
  geometric_probs = 1/geometric_means
  geometric_probs = np.minimum(geometric_probs, 1) #clip to a valid probability
  
  span_lengths = np.random.geometric(geometric_probs)
  span_lengths = np.minimum(span_lengths, config["span_clip_multiplier"] * geometric_means)
  span_lengths = np.minimum(span_lengths, config["span_clip_max"])

  #Make sure span length doesnt exceed the longest candidate run
  span_lengths = np.minimum(span_lengths, np.amax(cumsum_mask, axis=-1))

  #positions where a span of the sampled length fits are valid starting positions; sample uniformly among them
  candidate_indices = np.greater_equal(cumsum_mask, span_lengths[:,None]).astype(float)
  candidate_probs = candidate_indices / candidate_indices.sum(axis=-1)[:,None]
  
  #sample one start per row, then extend it to the span length (int() truncates fractional clipped lengths)
  masked_indices = np.array([np.random.multinomial(n=1, pvals=row) for row in candidate_probs])
  _, starting_indices = np.where(masked_indices==1)
  for i, span in enumerate(span_lengths):
    s = starting_indices[i]
    e = int(s+span)
    masked_indices[i, s:e] = 1

  # get pair indices for SBO loss: with one span per row the diff has exactly two
  # non-zeros, at (start - 1) and (end - 1)
  pair_indices = np.diff(masked_indices, n=1).nonzero()[1] #take indices of diff cols
  pair_indices = pair_indices.reshape(-1, 2) #one (left, right) pair per row
  pair_indices[:,1] += 1 #right boundary is the token just after the span
  inputs["pair_indices"] = pair_indices
  inputs["masked_indices"] = masked_indices

  # tokens with -100 are ignored for the loss, otherwise input id of token
  labels = np.copy(inputs["input_ids"])
  labels = labels * masked_indices
  labels[labels == 0] = -100

  #masking strategies, chosen per row (whole span): 80% replaced by mask, 10% random, 10% original (not removed)
  mask_types = np.random.multinomial(n=1, 
                          pvals=[0.8, 0.1, 0.1],
                          size=len(masked_indices))
  
  indices_replaced = mask_types[:,0, None] * masked_indices
  indices_random = mask_types[:,1, None] * masked_indices
  indices_filter = 1 - (indices_replaced + indices_random) # 0 where the original token is removed
  
  #draw random replacement tokens from the non-special vocabulary
  all_tokens = list(tokenizer.get_vocab().values())
  normal_tokens = [t for t in all_tokens if t not in tokenizer.all_special_ids]
  random_words = np.random.choice(normal_tokens, size=labels.shape)

  #remove selected tokens, then write [MASK] or a random token in their place
  inputs["input_ids"] *= indices_filter 
  inputs["input_ids"] += indices_replaced * tokenizer.mask_token_id
  inputs["input_ids"] += indices_random * random_words
  
  return inputs, labels

### Span Extraction

In [None]:
#scan for positions test
def find_positions(sequence, ids):
  """Locate the first contiguous occurrence of ``ids`` in ``sequence``.

  Args:
    sequence: token ids of one tokenized row.
    ids: non-empty list of token ids forming one label.

  Returns:
    ``(start_position, end_position)``: inclusive indices of the first and the
    last token of the earliest occurrence.

  Raises:
    ValueError: if ``ids`` is empty or does not occur contiguously in ``sequence``.
  """
  n_ids = len(ids)
  if n_ids == 0:
    raise ValueError("ids must contain at least one token id")

  target = list(ids)
  first = target[0]
  # A sliding-window scan. Unlike searching each id independently, it also finds
  # labels that repeat a token (e.g. [5, 5]).
  for start in range(len(sequence) - n_ids + 1):
    if sequence[start] == first and list(sequence[start:start + n_ids]) == target:
      return start, start + n_ids - 1

  raise ValueError(f"{target!r} does not occur contiguously in the sequence")


def spex_preprocessing(inputs, df, label_ids):
  """Build span-selection targets that point at the prepended label tokens.

  In every row, the start and end position of each label's token span get that
  class's probability (from ``classification_preprocessing``). All other
  positions are 0.

  Returns:
    ``(inputs, labels)``. ``labels`` has shape (n, 2, seq_len) in "span" mode
    (start and end targets stacked) and (n, seq_len) in "token" mode (start
    targets only).

  Raises:
    Exception: if "token" mode is used with multi-token labels, or
      ``config["spex_mode"]`` is unknown.
  """
  inputs, label_probs = classification_preprocessing(inputs, df)

  start_labels = np.zeros(inputs["input_ids"].shape)
  end_labels = np.zeros_like(start_labels)

  for row_idx, token_ids in enumerate(inputs["input_ids"].tolist()):
    for class_idx in (0, 1):
      start, end = find_positions(token_ids, label_ids[class_idx])
      start_labels[row_idx, start] = label_probs[row_idx, class_idx]
      end_labels[row_idx, end] = label_probs[row_idx, class_idx]

  mode = config["spex_mode"]
  if mode == "span":
    labels = np.stack([start_labels, end_labels], axis=1)
  elif mode == "token":
    if max(len(ids) for ids in label_ids) > 1:
      raise Exception("Token model cant be used for natural labels with multiple tokens")
    labels = start_labels
  else:
    print(config["spex_mode"])
    raise Exception("No labels defined")

  return inputs, labels

### Augmentation

In [None]:
def augmentation_preprocessing(inputs, df, label_ids, train, mask_strategy, repeat=1, analysis=None):
  """Mask tokens of label-prepended inputs for conditional MLM augmentation.

  Args:
    inputs: tokenizer output dict of numpy arrays (``input_ids``,
      ``special_tokens_mask``, ...). When ``repeat == 1`` the caller's arrays
      are updated in place.
    df: DataFrame whose ``label`` column indexes ``label_ids``.
    label_ids: token ids of each label name, as returned by ``prepend_labels``.
    train: if True, build MLM training labels with 80% [MASK] / 10% random
      token / 10% unchanged replacement. Otherwise replace every selected
      token with [MASK].
    mask_strategy: "base" masks uniformly with probability
      ``config["mask_prob"]``. "hetero" uses per-token probabilities derived
      from the ``analysis`` attentions.
    repeat: number of times the data is tiled, so each copy gets its own mask.
    analysis: ``(complexity, attentions)`` from target analysis. Required for
      "hetero"; the caller's arrays are not modified.

  Returns:
    ``(inputs, labels, masked_indices, df_labels)``, restricted to rows with at
    least one masked token. In train mode ``labels`` holds the original ids at
    masked positions and -100 elsewhere; otherwise it holds the original ids.

  Raises:
    ValueError: for an unknown ``mask_strategy``.
  """
  df_labels = df.label.values
  # copy dataset for "dynamic" token masking: every repetition draws its own mask
  if repeat > 1:
    inputs = {key: np.tile(value, (repeat, 1)) for key, value in inputs.items()}
    df_labels = np.tile(df_labels, repeat)

  labels = np.copy(inputs["input_ids"])

  # Protect the prepended label tokens (positions 1..len(label), right after the
  # leading [CLS]) from masking by marking them as special tokens.
  label_lengths = [len(label) for label in label_ids]
  prepend_mask = np.zeros(labels.shape, dtype=int)
  for i, row in enumerate(prepend_mask):
    label_ids_length = label_lengths[df_labels[i]]
    row[1:label_ids_length + 1] = 1

  inputs["special_tokens_mask"] += prepend_mask
  is_normal_token = inputs["special_tokens_mask"] == 0

  # determine which tokens to mask
  if mask_strategy == "base":
    # uniform random masking
    masked_indices = np.random.binomial(n=1, p=config["mask_prob"], size=labels.shape) * is_normal_token

  elif mask_strategy == "hetero":
    complexity, attentions = analysis

    if repeat > 1:
      complexity, attentions = np.tile(complexity, repeat), np.tile(attentions, (repeat, 1))

    max_length = inputs["special_tokens_mask"].shape[1]
    # Copy: the slice is a view. Without it, the in-place updates below would
    # overwrite the caller's ``analysis`` attentions (when repeat == 1).
    attentions = attentions[:, :max_length].copy()

    inputs["complexity"] = complexity

    # drop the attention on special tokens and redistribute it over the normal
    # tokens in proportion to their own attention
    special_token_probs = (attentions * inputs["special_tokens_mask"]).sum(axis=1)
    attentions *= is_normal_token
    normal_token_probs = attentions.sum(axis=1)
    redistribution_weights = attentions / normal_token_probs[:, None]
    extra_weights = redistribution_weights * special_token_probs[:, None]
    attentions += extra_weights

    # subtract the mean weight (1 / number of non-zero weights), amplify the
    # deviations and shift them around 1
    attentions -= (1 / np.count_nonzero(attentions, axis=1))[:, None]
    attentions *= config["attention_multiplier"]
    attentions += 1
    attentions *= is_normal_token

    # scale the relative weights by the base masking probability
    attentions *= config["mask_prob"]

    # clip elementwise probabilities to a valid range
    attentions = np.clip(attentions, 0, 1)

    masked_indices = np.random.binomial(n=1, p=attentions) * is_normal_token

  else:
    raise ValueError(f"Unknown mask_strategy: {mask_strategy!r}")

  # filter out samples where no token was masked
  n_masked_tokens = masked_indices.sum(axis=1)
  filter = n_masked_tokens > 0

  masked_indices = masked_indices[filter]
  inputs = {key: value[filter] for key, value in inputs.items()}
  df_labels = df_labels[filter]
  labels = labels[filter]

  if train:
    # tokens with -100 are ignored for the loss, otherwise input id of token
    labels = labels * masked_indices
    labels[labels == 0] = -100

    # 80% of masked tokens become [MASK]. Half of the remaining 20% (10% overall)
    # become a random token; the rest stay unchanged.
    indices_replaced = np.random.binomial(n=1, p=0.8, size=labels.shape) * masked_indices
    indices_random = np.random.binomial(n=1, p=0.5, size=labels.shape) * masked_indices
    indices_random = np.maximum(indices_random - indices_replaced, 0)

    # feasible random replacements: every non-special vocabulary id
    all_tokens = list(tokenizer.get_vocab().values())
    special_ids = set(tokenizer.all_special_ids)
    normal_tokens = [t for t in all_tokens if t not in special_ids]

    random_words = np.random.choice(normal_tokens, size=labels.shape) * indices_random

    # remove the selected tokens, then write [MASK] / random ids in their place
    indices_filter = 1 - (indices_replaced + indices_random)
    inputs["input_ids"] *= indices_filter
    inputs["input_ids"] += indices_replaced * tokenizer.mask_token_id
    inputs["input_ids"] += random_words

  else:
    # in predict mode, replace all masked tokens with [MASK]
    indices_filter = 1 - masked_indices
    inputs["input_ids"] *= indices_filter
    inputs["input_ids"] += masked_indices * tokenizer.mask_token_id

  return inputs, labels, masked_indices, df_labels

### TF Data transformation

In [None]:
def preprocess_data(df,
                    train,
                    task,
                    mask_strategy=None,
                    analysis=None,
                    repeat=1,
                    show_masks=False,
                    max_length=config["max_length"],
                    batch_size=config["batch_size"],
                    shuffle_batches=config["shuffle_batches"],
                    smart_batching=config["smart_batching"],
                    unpad=True):
  """Convert a dataframe into a tokenized, batched ``tf.data.Dataset``.

  The text is tokenized (with label strings prepended for "augmentation" and
  "spex"), passed through the task-specific preprocessing and batched.
  Training data can be "smart batched": sorted by length, shuffled within a
  small rolling window, then shuffled at batch level
  (http://mccormickml.com/2020/07/29/smart-batching-tutorial/).

  Args:
    df: DataFrame with a ``text`` column (and ``text_pair`` for sentence pairs).
    train: if True the dataset yields ``(inputs, labels)`` and is shuffled;
      otherwise it yields ``inputs`` only.
    task: "augmentation", "spex", "spanbert" or "classification".
    mask_strategy, analysis, repeat: forwarded to ``augmentation_preprocessing``.
    show_masks: print the first decoded inputs for inspection.
    max_length: tokenizer truncation length.
    batch_size: examples per batch.
    shuffle_batches: size, in batches, of the rolling shuffle window used with
      smart batching.
    smart_batching: sort training examples by length before batching.
    unpad: trim each batch to its longest sequence.

  Returns:
    ``(dataset, input_ids, masked_indices, df_labels)`` for augmentation in
    predict mode, ``(dataset, labels)`` for spex in predict mode, otherwise
    ``dataset``.

  Raises:
    Exception: for an unknown ``task``.
  """
  # size before any repetition; keeps the rolling shuffle window small so
  # epochs do not get mixed too much
  n_obs = len(df.index)

  label_ids = None
  # token prepending
  if task in ["augmentation", "spex"]:
    text, label_ids = prepend_labels(df, task)
  else:
    text = df.text

  # Use the Hugging Face tokenizer to convert text into token ids and masks.
  inputs = tokenizer(
      text=text.values.tolist(),
      text_pair=df.text_pair.values.tolist() if 'text_pair' in df.columns else None,
      return_tensors="np",
      return_attention_mask=True,
      return_token_type_ids=True,
      return_special_tokens_mask=True,
      max_length=max_length,
      padding=True,
      truncation=True,
  ).data

  # task-specific processing
  if task == "augmentation":
    inputs, labels, masked_indices, df_labels = augmentation_preprocessing(inputs, df, label_ids, train, mask_strategy, repeat, analysis)
  elif task == "spex":
    inputs, labels = spex_preprocessing(inputs, df, label_ids)
  elif task == "spanbert":
    inputs, labels = spanbert_preprocessing(inputs, df)
  elif task == "classification":
    inputs, labels = classification_preprocessing(inputs, df)
  else:
    raise Exception("Unknown task")

  if show_masks:
    print("decode inputs")
    decode_inputs(inputs)

  if train and smart_batching:
    # order based on token length
    order = inputs["attention_mask"].sum(axis=1).argsort()
    inputs = {key: inputs[key][order] for key in inputs.keys()}
    labels = labels[order]

  dataset = tf.data.Dataset.from_tensor_slices(
      (inputs, labels) if train else inputs
  )

  # Shuffle samples. With smart batching only a small rolling window is used,
  # so each batch keeps roughly equal sequence lengths.
  if train:
    buffer_size = min(shuffle_batches * batch_size, n_obs) if smart_batching else n_obs

    dataset = dataset.shuffle(
        buffer_size=buffer_size,
        reshuffle_each_iteration=True,
        seed=config["seed"],
    )

  dataset = dataset.batch(batch_size)

  # Shuffle whole batches so sequence lengths vary across training steps.
  if train:
    # tf.data expects an integer buffer size; np.ceil returns a float.
    n_batches = int(np.ceil(n_obs / batch_size))
    dataset = dataset.shuffle(n_batches, reshuffle_each_iteration=True)

  def unpad_batch(batch, labels=None):
    """Trim padded fields (and per-token labels) to the longest sequence in the batch."""
    batch_sequence_lengths = tf.math.reduce_sum(batch["attention_mask"], axis=1)

    max_batch_length = tf.reduce_max(batch_sequence_lengths)
    for field in ["input_ids", "token_type_ids", "attention_mask"]:
      batch[field] = batch[field][:, :max_batch_length]

    if task in ["spex", "augmentation", "spanbert"] and train:
      labels = labels[..., :max_batch_length]

      if task == "spanbert":
        batch["masked_indices"] = batch["masked_indices"][:, :max_batch_length]

    return batch, labels

  if unpad:
    dataset = dataset.map(unpad_batch)

  if task == "augmentation" and not train:
    return dataset, inputs["input_ids"], masked_indices, df_labels
  elif task == "spex" and not train:
    return dataset, labels
  else:
    return dataset

# Tasks

### Model helpers

In [None]:
def switch_head(old_model, task, summary=False, **kwargs):
  """Reload ``old_model``'s weights into a model of class ``task``.

  The old model is saved to a temporary directory and reloaded with
  ``task.from_pretrained``. The config is taken from ``PRETRAINED_MODEL`` and
  updated with ``kwargs`` (e.g. ``num_labels``). Layers absent from the
  checkpoint (typically the new head) are presumably freshly initialised by
  ``from_pretrained`` — confirm for each task class.

  NOTE(review): the config comes from PRETRAINED_MODEL, not from
  ``old_model.config``. A resized vocabulary (e.g. after adding label tokens)
  would not be reflected — confirm.

  Args:
    old_model: transformers TF model whose weights are reused.
    task: transformers TF model class to instantiate.
    summary: if True, print the new model's summary.
    **kwargs: config overrides.

  Returns:
    The new ``task`` model.
  """
  import tempfile

  model_config = BertConfig.from_pretrained(PRETRAINED_MODEL)
  model_config.update(kwargs)

  # The checkpoint is only needed while from_pretrained reads it. A temporary
  # directory avoids leaving a full model copy on disk for every head switch.
  with tempfile.TemporaryDirectory(prefix="switch_head_") as path:
    old_model.save_pretrained(path)
    new_model = task.from_pretrained(path, config=model_config)

  if summary:
    new_model.summary()

  return new_model

In [None]:
from sklearn.metrics import  (
    log_loss,
    mean_squared_error,
    mean_absolute_error,
    median_absolute_error,
    accuracy_score,
    confusion_matrix,
    f1_score,
    matthews_corrcoef,
    precision_score,
    recall_score,
    roc_auc_score,
)

def calculate_metrics(y_true, y_prob, prefix=False):
  """Compute binary-classification metrics from (possibly soft) labels.

  Args:
    y_true: array of true labels. Soft labels in [0, 1] are accepted; they are
      rounded for every metric except MSE / MAE / MedianAE.
    y_prob: array of predicted positive-class probabilities.
    prefix: optional string prepended (with a space) to every metric name.

  Returns:
    dict of metric name -> value.
    - Threshold metrics (accuracy, F1, confusion counts, ...) use ``y_prob``
      rounded at 0.5.
    - Log loss, MSE, MAE, MedianAE and ROC AUC use the raw probabilities.
  """
  proba_metrics = {
      "Log Loss": log_loss(y_true.round(), y_prob),
      "MSE": mean_squared_error(y_true, y_prob),
      "MAE": mean_absolute_error(y_true, y_prob),
      "MedianAE": median_absolute_error(y_true, y_prob),
  }

  y_pred = y_prob.round()
  y_true = y_true.round()

  tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[False, True]).ravel()
  binary_metrics = {
      "Accuracy": accuracy_score(y_true, y_pred),
      "Macro F1": f1_score(y_true, y_pred, average="macro"),
      "Matthews": matthews_corrcoef(y_true, y_pred),
      "Precision": precision_score(y_true, y_pred),
      "Recall": recall_score(y_true, y_pred),
      # ROC AUC ranks scores, so it needs the probabilities. On 0/1 predictions
      # it collapses to balanced accuracy.
      "ROC AUC": roc_auc_score(y_true, y_prob),
      "Micro F1": f1_score(y_true, y_pred, average="micro"),
      "True Negatives": tn,
      "True Positives": tp,
      "False Positives": fp,
      "False Negatives": fn
  }

  combined = {**binary_metrics, **proba_metrics}

  if prefix:
    combined = {f"{prefix} {key}": val for key, val in combined.items()}

  return combined

def aggregate_metrics(results):
  """Summarise per-run metric dicts by their mean and standard deviation.

  Args:
    results: list of dicts (one per run) with numeric values. The keys of the
      first run determine which metrics are aggregated.

  Returns:
    dict with ``n_runs`` plus ``"<metric> (mean)"`` and ``"<metric> (std)"``
    (population std, ddof=0) for every metric. An empty ``results`` yields
    ``{"n_runs": 0}``.
  """
  metrics_agg = {"n_runs": len(results)}
  if not results:
    return metrics_agg

  for metric in results[0]:
    metric_rows = np.array([row[metric] for row in results])
    metrics_agg[f"{metric} (mean)"] = metric_rows.mean()
    metrics_agg[f"{metric} (std)"] = metric_rows.std()

  return metrics_agg

In [None]:
def flatten_metric(metric):
  """Reduce a Keras log value to a scalar that can be logged.

  Array-valued metrics (per-class values, e.g. Matthews correlation, which is
  symmetric across classes) become their first element as a Python scalar.
  Anything else is returned unchanged.
  """
  if isinstance(metric, np.ndarray):
    # reshape(-1) also covers 0-d arrays, where metric[0] would raise IndexError
    return metric.reshape(-1)[0].item()
  return metric

class TrainLogger(tf.keras.callbacks.Callback):
  """Keras callback that streams training metrics to Weights & Biases.

  - After every batch and every epoch, the Keras ``logs`` dict is logged as
    ``"Train <metric> (<task>)"`` or ``"Epoch <metric> (<task>)"``.
  - When training ends, the model graph is stored in the run summary.
  """

  def __init__(self, task=""):
    # Callback.__init__ sets up attributes Keras relies on (such as ``model``).
    super().__init__()
    self.task = task

  def on_train_end(self, logs=None):
    wandb.run.summary["graph"] = wandb.Graph.from_keras(self.model)

  def _log(self, kind, logs):
    # logs defaults to None in the Callback API; treat that as "nothing to log"
    metrics = {f"{kind} {key} ({self.task})": flatten_metric(value) for key, value in (logs or {}).items()}
    wandb.log(metrics, commit=True)

  def on_train_batch_end(self, batch, logs=None):
    self._log("Train", logs)

  def on_epoch_end(self, epoch, logs=None):
    self._log("Epoch", logs)


In [None]:
#optimizers
from tensorflow_addons.optimizers import LAMB
from transformers import AdamWeightDecay

from transformers.models.bert.modeling_tf_bert import TFBertPreTrainedModel, TFBertMainLayer, TFBertForMaskedLM

### SpanBERT

In [None]:
from transformers.models.bert.modeling_tf_bert import TFBertPredictionHeadTransform, TFBertMLMHead, TFBertPositionEmbeddings

class TFBertForSpanBert(TFBertForMaskedLM):
  """BERT masked-LM extended with a SpanBERT span boundary objective (SBO) head.

  ``call`` receives the feature dict built by ``spanbert_preprocessing`` as
  ``input_ids``.
  - If the dict contains ``masked_indices`` (and ``pair_indices``), it returns
    the MLM and SBO logits of every masked token, stacked along axis 1.
  - Otherwise it returns the encoder's sequence output.
  """
  def __init__(self, config, *inputs, position_size=200,**kwargs):
      super().__init__(config, *inputs, **kwargs)

      # SBO head: a transform layer maps [left; right; position] back to the
      # hidden size, then an MLM-style head (tied to the word embeddings)
      # scores the vocabulary.
      self.sbo_hidden = TFBertPredictionHeadTransform(config, name="sbo_hidden")
      self.sbo = TFBertMLMHead(config, input_embeddings=self.bert.embeddings.word_embeddings, name="sbo___cls")
      
      # width of the relative-position embeddings fed to the SBO head
      self.position_size = position_size

      self.position_embeddings = TFBertPositionEmbeddings(
            max_position_embeddings=config.max_position_embeddings,
            hidden_size=position_size, # as per spanbert
            initializer_range=config.initializer_range,
            name="position_embeddings",
        )

  def call(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
        **kwargs,
    ):
      # ``input_ids`` is the whole feature dict. The standard BERT inputs are
      # unpacked from it. ``input_processing`` is imported in a later cell,
      # which must have run before the model is called.
      inputs = input_processing(
            func=self.call,
            config=self.config,
            input_ids=input_ids["input_ids"],
            attention_mask=input_ids.get("attention_mask"),
            token_type_ids=input_ids.get("token_type_ids"),
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            labels=labels,
            training=training,
            kwargs_call=kwargs,
        )
    
      outputs = self.bert(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            token_type_ids=inputs["token_type_ids"],
            position_ids=inputs["position_ids"],
            head_mask=inputs["head_mask"],
            inputs_embeds=inputs["inputs_embeds"],
            output_attentions=inputs["output_attentions"],
            output_hidden_states=inputs["output_hidden_states"],
            return_dict=inputs["return_dict"],
            training=inputs["training"],
        )
      sequence_output = outputs[0]

      if input_ids.get("masked_indices") is not None:
        #mlm: keep only the masked positions (row-major order, matching the
        #order in which spanbert_loss selects labels != -100)
        masked_sequence_output = tf.boolean_mask(sequence_output, input_ids["masked_indices"])
        masked_sequence_output = tf.expand_dims(masked_sequence_output, axis=0)
        
        mlm_preds = self.mlm(sequence_output=masked_sequence_output, training=inputs["training"])[0]
        
        #sbo
        span_lengths = tf.reduce_sum(input_ids["masked_indices"], axis=-1)
        
        #get indices for gather: 
        # if span_lengths [3,1,2]
        #position relative from start [0,1,2,0,0,1]
        #hidden state repeats indices to get hidden state * span_lengths [0,0,0,1,2,2]
        
        position_indices = tf.ragged.range(span_lengths).flat_values
        # NOTE(review): len() on a tensor needs a statically known batch size;
        # confirm this works when the model is traced in graph mode.
        hidden_repeats = tf.repeat(tf.range(len(span_lengths)), span_lengths)

        #get left and right boundary hidden states (token before / after each span)
        #[hidden_state x obs]
        left_indices, right_indices = tf.unstack(input_ids["pair_indices"], axis=1)
        left_hidden = tf.gather(sequence_output, left_indices, batch_dims=1)
        right_hidden = tf.gather(sequence_output, right_indices, batch_dims=1)

        # [n_masked_tokens x hidden_size]
        left_hidden = tf.gather(left_hidden, hidden_repeats)
        right_hidden = tf.gather(right_hidden, hidden_repeats)

        # [max_span_length x position_size]
        # The position-embedding layer is called on a dummy slice shaped
        # (1, max_span_length, position_size); presumably only that shape is used
        # to select the first max_span_length embeddings — confirm against the
        # installed transformers version.
        max_span_length = tf.math.reduce_max(span_lengths)
        position_embeds = self.position_embeddings(sequence_output[:1, :max_span_length,:self.position_size])[0]
        # [n_masked_tokens x position_size] 
        position_hidden = tf.gather(position_embeds, position_indices)

        #combine the three inputs into a single vector of size [n_masked_tokens x 2*hidden_size + position_size]
        sbo_sequence = tf.concat([left_hidden, right_hidden, position_hidden], axis=1)
        #apply compression dense layer to [n_masked_tokens x hidden_size] 
        sbo_output = self.sbo_hidden(hidden_states=sbo_sequence,  training=inputs["training"])

        #Apply second layer and map to tokens, extra dim is needed for input [1 x n_masked_tokens x hidden_size] 
        sbo_output = tf.expand_dims(sbo_output, axis=0)
        sbo_preds = self.sbo(sequence_output=sbo_output, training=inputs["training"])[0] #remove dim again
        
        #stack both preds [n_masked_tokens x 2 (tasks) x output size of the MLM heads]
        combined_preds = tf.stack([mlm_preds, sbo_preds], axis=1)
        
        return combined_preds
      else:
        return sequence_output

In [None]:
scce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
def spanbert_loss(y_true, y_pred):
    """Per-token sum of the MLM and span-boundary-objective (SBO) losses.

    ``y_pred`` stacks both heads' logits on axis 1, one row per masked token.
    ``y_true`` holds target ids with -100 at unmasked positions. Those entries
    are dropped so the remaining targets line up with the masked-token
    predictions.
    """
    target_ids = tf.boolean_mask(y_true, tf.not_equal(y_true, -100))
    head_losses = [scce(target_ids, head_logits) for head_logits in tf.unstack(y_pred, axis=1)]
    return head_losses[0] + head_losses[1]

def spanbert_training(train_df, dev_df, model):
  """Continue pre-training ``model`` with the SpanBERT objective (MLM + SBO).

  Both splits are span-masked through ``preprocess_data(task="spanbert")``.
  Training uses LAMB for up to ``config["sb_epochs"]`` epochs, with Keras
  early stopping that restores the best weights.

  Returns:
    The trained model.
  """
  train_dataset, dev_dataset = (
      preprocess_data(split_df, train=True, task="spanbert")
      for split_df in (train_df, dev_df)
  )

  optimizer = LAMB(learning_rate=config["optimizer_lr"], epsilon=config["optimizer_epsilon"])

  tf.keras.backend.clear_session()
  model.compile(optimizer, spanbert_loss)

  callbacks = [
      TrainLogger(task="SpanBERT"),
      tf.keras.callbacks.EarlyStopping(restore_best_weights=True),
  ]
  model.fit(train_dataset,
            validation_data=dev_dataset,
            epochs=config["sb_epochs"],
            callbacks=callbacks)

  return model

In [None]:
import os.path
from os import path

def from_spanbert(train_df=None, dev_df=None):
  """Return a SpanBERT-pretrained model, loading it from the Drive cache if present.

  The cache directory name is built from the config values listed in
  ``model_name_cols``. If that directory does not exist, a model is trained
  with ``spanbert_training`` on the given splits and saved there.

  NOTE(review): settings such as ``mask_prob``, ``span_clip_max`` and the seed
  are not part of the cache key. Changing them silently reuses a previously
  trained model — confirm this is intended.

  Raises:
    Exception: if training is needed but ``train_df`` or ``dev_df`` is missing.
  """
  base_path = "/content/drive/Shareddrives/Thesis/models/spanbert/"
  model_name_cols = ["dataset","max_length", "n_spanbert_repeat","span_clip_multiplier","sb_epochs"]
  model_name = " - ".join(f"{col} {config[col]}" for col in model_name_cols)
  print(f"checking for spanbert model: {model_name}")

  full_path = base_path + model_name

  if not path.exists(full_path):
    if train_df is None or dev_df is None:
      raise Exception("Spanbert Model needs to be trained but no datasets provided")
    print("Training new SpanBERT model")
    fresh_model = TFBertForSpanBert.from_pretrained(PRETRAINED_MODEL)
    trained_model = spanbert_training(train_df, dev_df, fresh_model)
    trained_model.save_pretrained(full_path)
    return trained_model

  print("Retrieving SpanBERT model from cache")
  return TFBertForSpanBert.from_pretrained(full_path, config=BertConfig.from_pretrained(PRETRAINED_MODEL))

### Span Extraction

Simulating weak supervision

In [None]:
import seaborn as sns
sns.set_theme()
def simulate_ws(df):
  """Replace true labels with simulated weak-supervision probabilities.

  For each row, a confidence in the true class is drawn from a Beta
  distribution with mean ``config["ws_mean"]`` and variance
  ``config["ws_var"]``. The new label is that confidence for positives and one
  minus it for negatives. The draws (as a histogram) and the weak labels'
  metrics against the true labels are logged to wandb.

  Note: ``df.label`` is overwritten in place and the same frame is returned.

  Returns:
    ``(df, metrics)``.

  Raises:
    Exception: if the mean is below 0.5 or the variance is too large for a
      Beta distribution with that mean.
  """
  mu = config["ws_mean"]
  var = config["ws_var"]

  if mu < 0.5 or var >= mu * (1-mu):
    raise Exception("Invalid mu or variance provided")

  # method-of-moments Beta(alpha, beta) parameters for the requested mean/variance
  alpha = ((1-mu)/var - 1/mu) * (mu ** 2)
  beta = alpha*(1/mu - 1)

  draws = np.random.beta(alpha, beta, size=len(df.index))

  table_title = f'Weak Supervison draws'
  if config["ws_histogram"]:
    sns.distplot(draws, axlabel=table_title)

  draws_table = wandb.Table(data=[[draw] for draw in draws], columns=["draws"])
  wandb.log({table_title: wandb.plot.histogram(draws_table, "draws", title=table_title)})

  y_true = df.label.values
  # positives get the drawn confidence, negatives one minus it
  y_weak = np.absolute(y_true - (1-draws))

  metrics = calculate_metrics(y_true, y_weak, prefix="WS")
  wandb.log(metrics)

  if config["ws_metrics"]:
    print("Weak Supervision metrics")
    display(metrics)

  df.label = y_weak
  return df, metrics

Span/token selection model

In [None]:
from transformers import TFBertForQuestionAnswering
from transformers.modeling_tf_utils import input_processing
class TFBertForSpanSelection(TFBertForQuestionAnswering):       
  """Question-answering BERT reused to select the label tokens in the input.

  The QA output layer is split into ``num_labels`` per-token score maps.
  - 1 label gives a (batch, seq_len) map ("token" mode).
  - 2 labels give start/end maps stacked as (batch, 2, seq_len) ("span" mode).

  When called outside training with ``run_eagerly`` enabled, ``call`` also
  returns the last layer's attention, averaged over heads and query positions
  (one weight per token).
  """
  def call(
        self,
        input_ids = None,
        attention_mask = None,
        token_type_ids = None,
        position_ids = None,
        head_mask = None,
        inputs_embeds = None,
        output_attentions = None,
        output_hidden_states = None,
        return_dict = None,
        start_positions = None,
        end_positions = None,
        training = False,
        **kwargs,
    ):
        # attentions are only requested for eager inference (target analysis)
        calc_attentions = not training and self.run_eagerly

        # ``input_ids`` is the whole feature dict; unpack the standard BERT inputs
        inputs = input_processing(
            func=self.call,
            config=self.config,
            input_ids=input_ids["input_ids"],
            attention_mask= input_ids.get("attention_mask"),
            token_type_ids=input_ids.get("token_type_ids"),
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=calc_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            start_positions=start_positions,
            end_positions=end_positions,
            training=training,
            kwargs_call=kwargs,
        )

        outputs = self.bert(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            token_type_ids=inputs["token_type_ids"],
            position_ids=inputs["position_ids"],
            head_mask=inputs["head_mask"],
            inputs_embeds=inputs["inputs_embeds"],
            output_attentions=inputs["output_attentions"],
            output_hidden_states=inputs["output_hidden_states"],
            return_dict=inputs["return_dict"],
            training=inputs["training"],
        )
        sequence_output = outputs[0]
        # one score per token and label, split into num_labels maps of (batch, seq_len)
        logits = self.qa_outputs(inputs=sequence_output)
        splits = tf.split(value=logits, num_or_size_splits=self.num_labels , axis=-1)
        squeezed = [tf.squeeze(input=s, axis=-1) for s in splits]

        if len(squeezed) == 1:
          # token mode: (batch, seq_len)
          predictions = squeezed[0]
        else:
          # span mode: (batch, num_labels, seq_len)
          predictions = tf.stack(squeezed, axis=1)


        if calc_attentions:
          # (A disabled variant averaged the attentions of all layers; only the
          # last layer is used.)

          #get attentions from last layer; presumably (batch, heads, seq, seq)
          #per the transformers convention — confirm
          attentions = outputs["attentions"][-1]

          #mean over heads
          attentions = tf.math.reduce_mean(attentions, axis=1)

          #mean over query positions -> average attention received by each token
          attentions = tf.math.reduce_mean(attentions, axis=1)
          
          return predictions, attentions
        else:
          return predictions

In [None]:
cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

def span_selection_loss(labels, logits):
  """Cross-entropy over token positions for span/token selection.

  The loss is taken over the last (sequence) axis. In "span" mode it yields one
  value per output map (start and end), which are averaged per example.
  """
  per_map_loss = cce(labels, logits)
  if config["spex_mode"] != "span":
    return per_map_loss
  # average the start/end losses of each example
  return tf.math.reduce_mean(per_map_loss, axis=1)

In [None]:
def spex_training(train_df,  dev_df, model, log_name, epochs=1):
  """Train `model` on the span-extraction (spex) task and return it.

  Replaces the model head with TFBertForSpanSelection: one output in "token"
  mode, two in "span" mode (presumably start/end, TODO confirm). The model is
  fitted on `train_df` with EarlyStopping (default monitor, best weights
  restored) on `dev_df`.

  Args:
    train_df: training examples, passed to preprocess_data(task="spex").
    dev_df: validation examples used for early stopping.
    model: model whose encoder is reused through switch_head.
    log_name: label for the console header and the TrainLogger task name.
    epochs: maximum number of training epochs.

  Returns:
    The model with the span-selection head, carrying the best dev weights.
  """
  print(f"\nSpan Extraction {log_name} training\n")
  num_labels = 1 if config["spex_mode"] == "token" else 2
  model = switch_head(model, TFBertForSpanSelection, num_labels=num_labels)

  train_dataset = preprocess_data(df=train_df, train=True, task="spex")
  dev_dataset = preprocess_data(df=dev_df, train=True, task="spex")

  loss = span_selection_loss
  optimizer = LAMB(learning_rate=config["optimizer_lr"], epsilon=config["optimizer_epsilon"])
  train_metrics = [
      tf.keras.metrics.CategoricalAccuracy(name="Accuracy")
    ]

  tf.keras.backend.clear_session()
  model.compile(optimizer, loss, train_metrics)

  early_stopping = tf.keras.callbacks.EarlyStopping(restore_best_weights=True)

  # Log under the caller-supplied name. Do not derive it from the notebook-global
  # `prefix`: that would make pretrain and finetune runs share one task name and
  # would raise NameError when no global `prefix` exists.
  model.fit(train_dataset,
            validation_data=dev_dataset,
            epochs=epochs,
            callbacks=[TrainLogger(task=f"spex {log_name}"), early_stopping]
            )

  return model

In [None]:
def target_analysis(df, model):
  """Score how hard each example in `df` is for the span-extraction head.

  The model gets a TFBertForSpanSelection head (1 output in "token" mode,
  2 in "span" mode) and predicts on the unpadded spex dataset. Complexity is
  1 minus the softmax probability mass placed on the gold label positions. In
  "span" mode it is additionally averaged over the head outputs.

  Returns:
    (complexity, attentions). `attentions` is the second output of
    model.predict. It is presumably the per-token mean last-layer attention
    computed in the span-selection head when attentions are requested
    (TODO confirm that run_eagerly triggers this).
  """
  print("\nTarget Analysis\n")
  n_outputs = 1 if config["spex_mode"] == "token" else 2
  model = switch_head(model, TFBertForSpanSelection, num_labels=n_outputs)

  dataset, labels = preprocess_data(df=df, train=False, task="spex", unpad=False)

  model.compile(run_eagerly=True)
  predictions, attentions = model.predict(dataset)

  # Probability mass on the gold positions; what remains is the complexity.
  gold_mass = (softmax(predictions, axis=-1) * labels).sum(axis=-1)
  complexity = 1 - gold_mass

  if config["spex_mode"] == "span":
    complexity = complexity.mean(axis=1)

  return complexity, attentions

### Augmentation

In [None]:
# Masked-LM head used both for augmentation training (augmentation_training fits it)
# and for target augmentation. During training, or when the model is not compiled
# with run_eagerly=True, it behaves exactly like TFBertForMaskedLM.
class TFBertForPredictMaskedLM(TFBertForMaskedLM):
  """TFBertForMaskedLM whose eager inference returns token predictions.

  `input_ids` is a dict holding "input_ids" and, optionally,
  "attention_mask", "token_type_ids" and "complexity" (a per-observation
  score that drives the "hetero" sampling).

  In eager, non-training mode the returned output object is modified:
    logits:     predicted token ids, right-padded with tokenizer.pad_token_id
                to config["max_length"].
    attentions: probability of each predicted token, right-padded with 0.
  """

  def call(
      self,
      input_ids=None,
      attention_mask=None,
      token_type_ids=None,
      position_ids=None,
      head_mask=None,
      inputs_embeds=None,
      output_attentions=None,
      output_hidden_states=None,
      return_dict=None,
      labels=None,
      training=False,
      **kwargs,
  ):
    features = input_ids
    outputs = super().call(
        features["input_ids"],
        features.get("attention_mask"),
        features.get("token_type_ids"),
        position_ids,
        head_mask,
        inputs_embeds,
        output_attentions,
        output_hidden_states,
        return_dict,
        labels,
        training
    )

    # Training and graph-mode evaluation: plain MLM outputs.
    if training or not self.run_eagerly:
      return outputs

    probs = tf.nn.softmax(outputs.logits, axis=2)

    complexity = features.get("complexity")
    # NOTE(review): this is gated on the global config["augmentation"], not on the
    # mask_strategy passed to target_augmentation. With config["augmentation"] False,
    # "hetero" evaluation runs still use greedy argmax here. Confirm this is intended.
    if config["augmentation"] == "hetero" and complexity is not None:
      preds, selected_probs = self._sample_within_bounds(probs, complexity)
    else:
      # Greedy choice: most likely token and its probability.
      preds = tf.math.argmax(probs, axis=2)
      selected_probs = tf.math.reduce_max(probs, axis=2)

    # Right-pad to config["max_length"] so predictions from batches of different
    # lengths can be concatenated by predict().
    seq_len = tf.shape(features["input_ids"])[1]
    paddings = [[0, 0], [0, config["max_length"] - seq_len]]
    outputs.logits = tf.pad(preds, paddings, 'CONSTANT',
                            constant_values=tokenizer.pad_token_id)
    outputs.attentions = tf.pad(selected_probs, paddings, 'CONSTANT',
                                constant_values=0)

    return outputs

  def _sample_within_bounds(self, probs, complexity):
    """Sample one token per position from a complexity-dependent candidate set.

    Vocabulary entries are sorted by descending probability. A candidate is one
    whose cumulative probability is >= 1 - UB and <= 1 - lb, where
    lb = LB + (UB - LB) * complexity is set per observation. Positions with no
    candidate fall back to their most likely token.

    Returns:
      (token ids, probabilities of those tokens under the unfiltered softmax).
    """
    # Observation-specific lower bounds.
    obs_lower_bounds = config["LB"] + (config["UB"] - config["LB"]) * complexity

    sorted_probs = tf.sort(probs, direction='DESCENDING')
    sorted_indices = tf.argsort(probs, direction='DESCENDING')
    cum_prob = tf.math.cumsum(sorted_probs, axis=-1)

    ub_mask = tf.math.greater_equal(cum_prob, 1 - config["UB"])

    lb_threshold = tf.stack([tf.fill(tf.shape(ub_mask[0]), 1 - lb) for lb in obs_lower_bounds], axis=0)
    lb_mask = tf.math.less_equal(cum_prob, lb_threshold)

    prob_mask = tf.cast(tf.math.logical_and(ub_mask, lb_mask), tf.float32)

    n_viable_tokens = tf.math.reduce_sum(prob_mask, axis=-1)
    mean_viable_tokens = tf.math.reduce_mean(n_viable_tokens).numpy()

    if mean_viable_tokens < 10:
      print(f"Mean suitable candidate tokens: {mean_viable_tokens.round(2)}")
      print(f"Mean complexity: {tf.math.reduce_mean(complexity).numpy()}")

    if tf.math.reduce_min(n_viable_tokens).numpy() == 0:
      # Positions without a viable token: enable sorted index 0 (the most likely token).
      backup_mask = 1 - tf.math.minimum(n_viable_tokens, 1)
      backup_mask = tf.expand_dims(backup_mask, axis=2)
      paddings = [[0, 0], [0, 0], [0, tf.shape(prob_mask)[-1] - 1]]
      prob_mask += tf.pad(backup_mask, paddings)

      print("No candidate tokens for an observation")
      print(f"Mean suitable candidate tokens: {mean_viable_tokens.round(2)}")

    # Drop out-of-bounds candidates and renormalise the rest.
    candidate_probs = sorted_probs * prob_mask
    candidate_probs /= tf.expand_dims(tf.reduce_sum(candidate_probs, axis=-1), axis=-1)
    candidate_logits = tf.math.log(candidate_probs)

    # tf.random.categorical expects 2-D logits, so sample each observation's slice.
    sampled = tf.stack([tf.random.categorical(logits, 1) for logits in candidate_logits], axis=0)
    # Remove only the trailing sample axis. A bare squeeze would also drop batch or
    # sequence axes of size 1 (e.g. a final batch of one) and break the gathers below.
    sampled_indices = tf.squeeze(sampled, axis=-1)

    # Map positions in the sorted vocabulary back to token ids and probabilities.
    preds = tf.gather(sorted_indices, sampled_indices, axis=2, batch_dims=2)
    selected_probs = tf.gather(sorted_probs, sampled_indices, axis=2, batch_dims=2)
    return preds, selected_probs

In [None]:
scce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
def masked_sparse_categorical_crossentropy(y_true, y_pred):
    """Unreduced MLM loss over positions whose label is not -100.

    Positions labelled -100 are removed from both the labels and the logits
    before `scce` is applied. The result therefore has one loss value per
    kept position (a flat vector), and Keras performs the final reduction.
    """
    keep = tf.not_equal(y_true, -100)
    kept_labels = tf.boolean_mask(y_true, keep)
    kept_logits = tf.boolean_mask(y_pred, keep)
    return scce(kept_labels, kept_logits)

def augmentation_training(target_df, dev_df, model, logname=""):
  """Fine-tune `model` as a masked LM on the target data.

  Both `target_df` and `dev_df` are masked with the "base" strategy and
  repeated config["n_train_aug"] times. The model gets a
  TFBertForPredictMaskedLM head and is trained with
  masked_sparse_categorical_crossentropy for up to config["aug_epochs"]
  epochs, with EarlyStopping (default monitor, best weights restored).
  Returns the trained model.
  """
  print("Augmentation training")

  def mlm_dataset(df):
    # Same masking/repetition for train and dev so their losses are comparable.
    return preprocess_data(df,
                           train=True,
                           task="augmentation",
                           mask_strategy="base",
                           repeat=config["n_train_aug"])

  train_dataset = mlm_dataset(target_df)
  dev_dataset = mlm_dataset(dev_df)

  model = switch_head(model, TFBertForPredictMaskedLM)

  optimizer = LAMB(learning_rate=config["optimizer_lr"], epsilon=config["optimizer_epsilon"])
  loss = masked_sparse_categorical_crossentropy

  tf.keras.backend.clear_session()
  model.compile(optimizer, loss)

  early_stopping = tf.keras.callbacks.EarlyStopping(restore_best_weights=True)
  callbacks = [TrainLogger(task=logname + " - aug"), early_stopping]

  model.fit(train_dataset,
            validation_data=dev_dataset,
            epochs=config["aug_epochs"],
            callbacks=callbacks)

  return model

In [None]:
from transformers import BertConfig

def target_augmentation(target_df, model, analysis=None, mask_strategy=config["augmentation"], prob_labels=config["probabilistic_labels"], batch_size=config["batch_size"], prefix="", head=False):
  """Create augmented copies of `target_df` by filling masked tokens with MLM predictions.

  The target examples are masked by preprocess_data(task="augmentation",
  repeat=config["n_target_aug"]). A TFBertForPredictMaskedLM head predicts the
  masked tokens, and the filled-in sequences are decoded back to text. The
  type-token ratio of the predicted tokens is logged to W&B.

  Args:
    target_df: target examples; a "text_pair" column, if present, is reproduced.
    model: model whose encoder is reused through switch_head.
    analysis: forwarded to preprocess_data. Presumably the
      (complexity, attentions) tuple from target_analysis, used by the
      "hetero" masking strategy (TODO confirm).
    mask_strategy: masking strategy forwarded to preprocess_data.
    prob_labels: if truthy, replace hard labels with a weak P(class 1).
    batch_size: prediction batch size.
    prefix: prepended to the logged "Type Token Ratio" key.
    head: if truthy, print/display that many example rows.

  NOTE(review): the config-based defaults are evaluated once, when this cell
  runs. Later edits to `config` do not change them.

  Returns:
    (augmented_df, {prefix + " Type Token Ratio": float}).
  """
  print("Target Augmentation")
  # input_ids / masked_indices are per-token arrays; df_labels holds one label per row.
  dataset, input_ids, masked_indices, df_labels = preprocess_data(
                            df=target_df, 
                            train=False,
                            task="augmentation",
                            analysis=analysis,
                            mask_strategy=mask_strategy,
                            batch_size=batch_size,
                            repeat=config["n_target_aug"]
                            )
  
  model = switch_head(model, TFBertForPredictMaskedLM, output_attentions=True)
  
  # Perform augmentation. run_eagerly=True makes TFBertForPredictMaskedLM return
  # token ids in .logits and their probabilities in .attentions.
  tf.keras.backend.clear_session()
  model.compile(run_eagerly=True)
  predictions = model.predict(dataset, verbose=1)

  preds = predictions.logits
  probs = predictions.attentions
  
  # Predictions are padded to config["max_length"]; cut them back to the
  # sequence length of the masked inputs.
  max_length = masked_indices.shape[1]
  preds = preds[:,:max_length]
  probs = probs[:,:max_length]

  # Place token predictions: keep original ids where unmasked and use the
  # prediction where masked (masked_indices is used as a 0/1 selector).
  inputs_masked = input_ids * (1 - masked_indices)
  prediction_mask = preds * masked_indices
  inputs_masked += prediction_mask
  texts = tokenizer.batch_decode(inputs_masked)

  # Fill the augmented df. Decoded sequences start with "[CLS] <label name>"
  # (presumably the label text is prepended by preprocess_data — TODO confirm);
  # that prefix is stripped and the text ends at the first [SEP].
  def text_mapping(text, label):
    start = len(f"[CLS] {label_names[label]}")
    end = text.find("[SEP]")
    return text[start:end].strip()

  # For pair tasks, the second segment lies between the first and second [SEP].
  def text_pair_mapping(text):
    start =  text.find("[SEP]")+len("[SEP]")
    end = text.find("[SEP]", start)
    return text[start:end].strip()

  text = [text_mapping(text, df_labels[i]) for i,text in enumerate(texts)]
  
  if 'text_pair' in target_df.columns:
    text_pair = [text_pair_mapping(text) for text in texts]
  else:
    text_pair = False

  # Calculate weak labels.
  if prob_labels:
    # Per row: summed probability of the predicted tokens (K) at masked positions.
    probs_sum = np.sum(probs * masked_indices, axis=1)
    K = masked_indices.sum(axis=1)

    # N: non-special tokens plus K. Masked positions are presumably special
    # ([MASK]) ids in input_ids, hence the "+ K" (TODO confirm).
    normal_mask = ~np.any([input_ids == t for t in tokenizer.all_special_ids], axis=0)
    N = normal_mask.sum(axis=1) + K

    # In "natural" mode, the tokens of the prepended label name are not counted.
    if config["label_type"] == "natural":
      label_lengths = [len(l_ids) for l_ids in tokenizer(label_names, add_special_tokens=False).input_ids]
      N -= np.array([label_lengths[label] for label in df_labels])

    # Formula: unmasked tokens count as certain (1 each); masked tokens contribute
    # their predicted probability.
    ##confidence = (np.square(N-K) + K * probs_sum) / np.square(N)
    confidence = ((N-K) + probs_sum) / N

    # Lower-bound the confidence at config["min_weak_prob"].
    confidence = np.fmax(confidence, config["min_weak_prob"])

    # Set probabilistic labels. For 0/1 labels this yields `confidence` for
    # label 1 and 1 - confidence for label 0, i.e. a weak P(class 1).
    df_labels = np.absolute(1-df_labels-confidence)

  # Save to df.
  augmented_df = pd.DataFrame({"text":text, "label":df_labels})

  if text_pair is not False:
    augmented_df["text_pair"] = text_pair

  # Calculate diversity: distinct predicted ids at masked positions per masked token.
  # The "- 1" removes the 0 that fills unmasked positions (a genuine prediction
  # of id 0 would also be excluded).
  n_unique_tokens = len(np.unique(prediction_mask)) - 1 #remove 0 
  n_masked_tokens = masked_indices.sum()

  type_token_ratio = {
      prefix+" Type Token Ratio": n_unique_tokens / n_masked_tokens
  }
  wandb.log(type_token_ratio)
  

  # Print some augmentations for inspection.
  if head:
    print("\nraw predictions\n")
    texts = tokenizer.batch_decode(preds[:head])
    for row in texts:
      print(row.replace(" [PAD]", ""))

    print("\ninputs\n")
    texts = tokenizer.batch_decode(input_ids[:head])
    for row in texts:
      print(row.replace(" [PAD]", ""))

    print("\nAugmented sequences\n")
    texts = tokenizer.batch_decode(inputs_masked[:head])
    for row in texts:
      print(row.replace(" [PAD]", ""))

    print("\n")

    display(target_df.head(head))
    display(augmented_df.head(head))

  return augmented_df, type_token_ratio

# Evaluation methods

In [None]:
from transformers import TFBertForSequenceClassification
#use to ensure attention mask and token type ids is used
class TFBertForClassification(TFBertForSequenceClassification):
  """TFBertForSequenceClassification that accepts a dict of tokenizer outputs.

  Keras passes the whole feature dict as `input_ids`. This override unpacks
  "input_ids" (required) and "attention_mask" / "token_type_ids" (optional)
  from it, so that the attention mask and segment ids are actually used. The
  separately passed attention_mask/token_type_ids arguments and any extra
  **kwargs are ignored.
  """

  def call(
      self,
      input_ids=None,
      attention_mask=None,
      token_type_ids=None,
      position_ids=None,
      head_mask=None,
      inputs_embeds=None,
      output_attentions=None,
      output_hidden_states=None,
      return_dict=None,
      labels=None,
      training=False,
      **kwargs,
  ):
    features = input_ids
    return super().call(
        input_ids=features["input_ids"],
        attention_mask=features.get("attention_mask"),
        token_type_ids=features.get("token_type_ids"),
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        labels=labels,
        training=training,
    )

In [None]:
from tensorflow_addons.metrics import MatthewsCorrelationCoefficient, F1Score
from scipy.special import softmax

def extrinsic_evaluation(train_df, test_df, dev_df, epochs=config["extrinsic_epochs"], model=None, finetune=False, prefix=False, patience=0, batch_size=config["extrinsic_batch_size"]):
  """Train a classifier on `train_df` and report its test-set metrics.

  With finetune=True, `model` gets a TFBertForClassification head through
  switch_head. Otherwise a fresh TFBertForClassification is loaded from
  PRETRAINED_MODEL. Training uses EarlyStopping on val_Accuracy (mode "max",
  best weights restored). Test metrics from calculate_metrics are logged to
  W&B together with PR/ROC curves and a confusion matrix.

  Args:
    train_df, test_df, dev_df: DataFrames for preprocess_data(task="classification").
      test_df.label is compared against argmax predictions.
    epochs: maximum training epochs.
    model: model to fine-tune; required when finetune=True.
    finetune: reuse `model` instead of starting from PRETRAINED_MODEL.
    prefix: prepended (with a space) to logged names. A falsy value means no prefix.
    patience: EarlyStopping patience.
    batch_size: training batch size (dev/test use preprocess_data's default).

  Returns:
    The metrics dict produced by calculate_metrics.

  Raises:
    ValueError: if finetune=True but `model` is None.
  """
  if finetune:
    if model is None:
      raise ValueError("extrinsic_evaluation: finetune=True requires a model")
    print("Finetuning Augmentation model")
    # NOTE(review): confirm switch_head does not alter the caller's model in place.
    model = switch_head(model, TFBertForClassification)
  else:
    print("Training new end model")
    model = TFBertForClassification.from_pretrained(PRETRAINED_MODEL)

  def log_key(name):
    # "<prefix> <name>" when a prefix is set, else the bare name. Concatenating
    # onto the default prefix=False would raise TypeError.
    return f"{prefix} {name}" if prefix else name

  optimizer = LAMB(learning_rate=config["optimizer_lr"], epsilon=config["optimizer_epsilon"])
  loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
  train_metrics = [
      tf.keras.metrics.CategoricalAccuracy(name= "Accuracy"),
      F1Score(num_classes=2, name="Micro F1", average="micro"),
    ]

  tf.keras.backend.clear_session()

  model.compile(optimizer, loss, train_metrics)

  train_dataset = preprocess_data(train_df, train=True, task="classification", batch_size=batch_size)
  dev_dataset = preprocess_data(dev_df, train=True, task="classification")
  test_dataset = preprocess_data(test_df, train=False, task="classification")

  early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_Accuracy",
                                                    mode="max",
                                                    patience=patience,
                                                    restore_best_weights=True)

  model.fit(train_dataset,
            epochs=epochs,
            validation_data=dev_dataset,
            callbacks=[TrainLogger(task=log_key("ExEval")), early_stopping])

  predictions = model.predict(test_dataset)[0]
  y_probs = softmax(predictions, axis=1)
  y_pred = np.argmax(predictions, axis=1)
  y_true = test_df.label.values

  metrics = calculate_metrics(y_true, y_pred, prefix=prefix)

  wandb.log({log_key("pr"): wandb.plot.pr_curve(y_true, y_probs, labels=class_names),
             log_key("roc"): wandb.plot.roc_curve(y_true, y_probs, labels=class_names),
             log_key("conf_mat"): wandb.plot.confusion_matrix(y_probs, y_true, class_names=class_names)})

  wandb.log(metrics)
  return metrics

In [None]:
def train_eval_model(train_df, dev_df, test_df, epochs=config["eval_epochs"], batch_size=config["eval_batch_size"]):
  """Train the reference classifier used for semantic evaluation.

  A fresh TFBertForClassification from PRETRAINED_MODEL is trained on
  `train_df` with EarlyStopping (default monitor, best weights restored) on
  `dev_df`. It is then scored on `test_df`, and its metrics and PR/ROC/
  confusion-matrix plots are logged to W&B under the "Semantic" prefix.

  Returns:
    (eval_model, semantic_metrics).
  """
  print("Training Eval model")
  eval_model = TFBertForClassification.from_pretrained(PRETRAINED_MODEL)

  optimizer = LAMB(learning_rate=config["optimizer_lr"], epsilon=config["optimizer_epsilon"])
  loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
  train_metrics = [tf.keras.metrics.CategoricalAccuracy(name= "Accuracy")]

  tf.keras.backend.clear_session()
  eval_model.compile(optimizer, loss, train_metrics)

  def classification_data(df, train):
    # All three splits share the same task and batch size.
    return preprocess_data(df, train=train, task="classification", batch_size=batch_size)

  train_dataset = classification_data(train_df, True)
  dev_dataset = classification_data(dev_df, True)

  early_stopping = tf.keras.callbacks.EarlyStopping(restore_best_weights=True)
  eval_model.fit(train_dataset,
                 validation_data=dev_dataset,
                 epochs=epochs,
                 callbacks=[TrainLogger(task="Semantic"), early_stopping])

  test_dataset = classification_data(test_df, False)
  logits = eval_model.predict(test_dataset)[0]
  y_probs = softmax(logits, axis=1)
  y_pred = np.argmax(logits, axis=1)
  y_true = test_df.label.values

  semantic_metrics = calculate_metrics(y_true, y_pred, prefix="Semantic")

  plots = {
      "Semantic pr": wandb.plot.pr_curve(y_true, y_probs, labels=class_names),
      "Semantic roc": wandb.plot.roc_curve(y_true, y_probs, labels=class_names),
      "Semantic conf_mat": wandb.plot.confusion_matrix(y_probs, y_true, class_names=class_names),
  }
  wandb.log(plots)
  wandb.log(semantic_metrics)

  return eval_model, semantic_metrics

In [None]:
def semantic_eval(eval_model, augmented_df, prefix=""):
  """Compare the weak labels of `augmented_df` with the eval model's beliefs.

  The eval model's class-1 probability serves as the reference ("y_true") and
  the augmented labels as the predictions. Metrics from calculate_metrics
  (prefix "<prefix> sem_eval") are logged to W&B and returned.
  """
  print("Semantic evaluation "+prefix)
  dataset = preprocess_data(augmented_df, train=False, task="classification")
  logits = eval_model.predict(dataset)[0]

  # The eval model's P(class 1) is the reference the weak labels are judged against.
  reference = softmax(logits, axis=1)[:, 1]
  weak_labels = augmented_df.label.values

  result = calculate_metrics(reference, weak_labels, prefix=prefix + " sem_eval")
  wandb.log(result)
  return result

In [None]:
def augmentation_eval(target_df, augmented_df, test_data, dev_df, model=None, prefix="eval"):
  """Extrinsic evaluation of an augmentation run.

  Trains a fresh classifier on target + augmented data ("<prefix> - Combined").
  If `model` is given, it also fine-tunes that model on the same data with
  patience 2 ("<prefix> - Combined wAM"). Returns the merged metrics dicts.

  Variants that are currently disabled: training on augmented data only, and
  training on target data only with proportionally more epochs.
  """
  print("Extrinsic Evaluation")
  collected = {}

  print("Combined dataset")
  combined_df = pd.concat([target_df, augmented_df], ignore_index=True)
  collected.update(
      extrinsic_evaluation(combined_df, test_data, dev_df, finetune=False, prefix=f"{prefix} - Combined")
  )

  if model is None:
    return collected

  print("Combined dataset with finetuned model")
  collected.update(
      extrinsic_evaluation(combined_df, test_data, dev_df, finetune=True, model=model,
                           prefix=f"{prefix} - Combined wAM", patience=2)
  )
  return collected

# Experiment main

### Development loop
Useful for toggling settings during development. The cell below is currently commented out in full.

In [None]:
# steps = ["max_length", "target_size", "ws_mean", "ws_var", "shuffle_labels","mask_prob", "spex_mode"]
# tags = [f"{step}: {config[step]}" for step in steps]
# group = " - ".join(tags)

# wandb.init(project="thesis-"+config["dataset"],
#            reinit=True,
#            id=config["id"],
#            name=f'{config["name"]} ({config["id"]})',
#            notes=config["remark"],
#            save_code=True,
#            tags=tags,
#            group=group,
#            config=config)

# import numpy as np
# from sklearn.model_selection import train_test_split
# from transformers import TFBertModel

# train_data, class_names = load_dataset(config["dataset"], split="train", return_class_names=True)
# test_data = load_dataset(config["dataset"], split="dev", head=False)

# if not label_names:
#   label_names = class_names

# np.random.seed(config["seed"])
# target_data_seeds = np.random.randint(low=1, high=10000, size=config["n_experiments"], )

# results = []
# for run, seed in enumerate(target_data_seeds):
#   print(f"\nExecuting run {run+1}/{config['n_experiments']}\n")
  
#   metrics = {}

#   model = TFBertModel.from_pretrained(PRETRAINED_MODEL)
#   weak_df, target_df = train_test_split(train_data, test_size=config["target_size"], random_state=seed, stratify=train_data.label)
  
#   if config["finetune_am"]:
#     print("Initial Evaluation")
#     extrinsic_metrics = extrinsic_evaluation(target_df, test_data, finetune=False, prefix="Initial")
#     metrics.update(extrinsic_metrics)

#   if config["spanbert"] is not False: 
#     print("SpanBERT training")
#     model = spanbert_training(weak_df, target_df, model)

#     if config["finetune_am"]:
#       print("After SpanBERT Evaluation")
#       extrinsic_metrics = extrinsic_evaluation(target_df, test_data, finetune=True, model=model, prefix="After SpanBERT")
#       metrics.update(extrinsic_metrics)

  
#   if config["spex"] is not False:

#     if config["spex"] in ["pretrain","both"]:
#       print("Simulating Weak Supervision")
#       weak_df, ws_metrics = simulate_ws(weak_df)
#       metrics.update(ws_metrics)

#       print("Span Extractive pre-training")
#       model = spex_training(weak_df, model, log_name="pretrain", epochs=config["spex_pt_epochs"])

#       if config["augmentation"] == "hetero":
#         print("Target Analysis")
#         analysis = target_analysis(target_df, model)
#         target_complexity = {"target_complexity": analysis[0].mean()}
#         wandb.log(target_complexity)
#         metrics.update(target_complexity)

#       else:
#         analysis = None

#     if config["spex"] in ["finetune","both"]:
#       print("Span Extractive finetuning")
#       model = spex_training(target_df, model, log_name="finetune", epochs=config["spex_ft_epochs"])

#     if config["finetune_am"]:
#       print("After SpanEx Evaluation")
#       extrinsic_metrics = extrinsic_evaluation(target_df, test_data, finetune=True, model=model, prefix="After SPEX")
#       metrics.update(extrinsic_metrics)

#   if config["augmentation"] is not False:
#     print("Augmentation model training")
#     model = augmentation_training(target_df, model)
    
#     print("Target Augmentation")
#     augmented_df, type_token_ratio = target_augmentation(target_df, model, analysis, head=10)
#     metrics.update(type_token_ratio)

#     combined_df = pd.concat([target_df, augmented_df], ignore_index=True)

#     # print("Semantic Evaluation")
#     # semantic cross entropy = intrinsic_evaluaton(augmented_df)

#   else:
#     combined_df = target_df

#   print("Extrinsic Evaluation")

#   print("only augmented data")
#   extrinsic_metrics = extrinsic_evaluation(augmented_df, test_data, finetune=False, prefix="Only AUG")
#   metrics.update(extrinsic_metrics)

#   print("only target but equal amount of data")
#   extrinsic_metrics = extrinsic_evaluation(target_df, test_data, finetune=False, prefix="Only target equal epochs", epochs=config["extrinsic_epochs"] * (config["n_target_aug"] + 1))
#   metrics.update(extrinsic_metrics)
  
#   print("Combined dataset")
#   extrinsic_metrics = extrinsic_evaluation(combined_df, test_data, finetune=False, prefix="Combined")
#   metrics.update(extrinsic_metrics)

#   if config["finetune_am"]: 
#     print("Combined dataset with finetuned model")
#     extrinsic_metrics = extrinsic_evaluation(combined_df, test_data, finetune=True, model=model, prefix="Combined wAM")
#     metrics.update(extrinsic_metrics)

#   print("\nRun Metrics\n")
#   display(metrics)
#   results.append(metrics)

#   #aggegate metrics
#   overall_metrics = aggregate_metrics(results)
#   wandb.log(overall_metrics)
#   wandb.run.summary.update(overall_metrics)

# print("\nOverall Metrics\n")
# display(overall_metrics)


### Test loop
Runs the full set of experiments used for the paper's results.

In [None]:
import numpy as np
from sklearn.model_selection import train_test_split
from transformers import TFBertModel

# Training split (each run carves its test/target/weak splits from it) and the
# dev split, used as the validation set for every model below.
main_data, class_names = load_dataset(config["dataset"], split="train", return_class_names=True)
dev_df = load_dataset(config["dataset"], split="dev", head=False)

# Without custom label names, fall back to the dataset's own class names.
if not label_names:
  label_names = class_names

# Seed the global NumPy RNG, then draw one split seed per experiment.
np.random.seed(config["seed"])
target_data_seeds = np.random.randint(low=1, high=10000, size=config["n_experiments"])

# Number of example rows target_augmentation prints per run.
aug_preview = 5

# W&B tags ("<setting>_<value>") and the group name built from them.
steps = [
    "max_length",
    "target_size",
    "spex_pt_epochs",
    "spex_ft_epochs",
    "aug_epochs",
    "mask_prob",
    "spex_mode",
]
tags = ["{}_{}".format(step, config[step]) for step in steps]
group = " - ".join(tags)

for run, seed in enumerate(target_data_seeds):
  print(f"\nExecuting run {run+1}/{config['n_experiments']}\n")

  #skip runs if continuing from new point
  if run < SKIP_RUNS:
    print(f"Skipping run {run+1}")
    continue
  
  config["id"] = str(secrets.token_hex(4))

  wandb.init(project="thesis-"+config["dataset"],
            reinit=True,
            id=config["id"],
            name=f'{config["name"]} - {run+1} ({config["id"]})',
            notes=config["remark"],
            save_code=True,
            tags=tags,
            group=group,
            config=config)
  
  #already train spanbert before hand just in case
  # from_spanbert(main_data, dev_df)

  # prep dataset
  # sample a test and target df, both same percentage on total data
  train_df, test_df = train_test_split(main_data, test_size=config["target_size"], random_state=seed, stratify=main_data.label)
  weak_df, target_df = train_test_split(train_df, test_size=config["target_size"] / (1-config["target_size"]), random_state=seed, stratify=train_df.label)

  # train model for semantic evaluation and reference performance
  eval_model, top_metrics = train_eval_model(weak_df, dev_df, test_df)

  # generate weak supervision
  weak_df, ws_metrics = simulate_ws(weak_df)

  # initial evaluation
  print("Initial Evaluation")
  initial_metrics = extrinsic_evaluation(target_df, test_df, dev_df, finetune=False, prefix="Baseline")

  #vanilla augmentation train + eval
  prefix = "Benchmark"
  print(f"\n{prefix} Evaluation\n")
  model = TFBertModel.from_pretrained(PRETRAINED_MODEL)
  model = augmentation_training(target_df, dev_df, model, logname=prefix)
  augmented_df, type_token_ratio = target_augmentation(target_df, 
                                                       model, 
                                                       analysis=None,
                                                       mask_strategy="base",
                                                       prob_labels=False,
                                                       prefix=prefix,
                                                       head=aug_preview)

  semantic_quality =  semantic_eval(eval_model, augmented_df, prefix=prefix)

  eval_metrics = augmentation_eval(target_df, augmented_df, test_df, dev_df, model, prefix=prefix)

  # spex ft
  # vanilla augmentation train + eval
  prefix = "spex-ft"
  print(f"\n{prefix} Evaluation\n")
  model = TFBertModel.from_pretrained(PRETRAINED_MODEL)
  model = spex_training(target_df, dev_df, model, log_name=prefix+" finetune", epochs=config["spex_ft_epochs"])
  model = augmentation_training(target_df, dev_df, model, logname=prefix)
  augmented_df, type_token_ratio = target_augmentation(target_df, 
                                                       model, 
                                                       analysis=None,
                                                       mask_strategy="base",
                                                       prob_labels=False,
                                                       prefix=prefix,
                                                       head=aug_preview)

  semantic_quality =  semantic_eval(eval_model, augmented_df, prefix=prefix)

  eval_metrics = augmentation_eval(target_df, augmented_df, test_df, dev_df, model, prefix=prefix)

  ##if we have ws
  # spex pt
  # spex ft
  # vanilla augmentation train + eval
  prefix = "spex-full"
  print(f"\n{prefix} Evaluation\n")
  model = TFBertModel.from_pretrained(PRETRAINED_MODEL)
  model = spex_training(weak_df, dev_df, model, log_name=prefix+" pretrain", epochs=config["spex_pt_epochs"])
  model = spex_training(target_df, dev_df, model, log_name=prefix+" finetune", epochs=config["spex_ft_epochs"])

  model = augmentation_training(target_df, dev_df, model, logname=prefix)
  augmented_df, type_token_ratio = target_augmentation(target_df, 
                                                       model, 
                                                       analysis=None,
                                                       mask_strategy="base",
                                                       prob_labels=False,
                                                       prefix=prefix,
                                                       head=aug_preview)

  semantic_quality =  semantic_eval(eval_model, augmented_df, prefix=prefix)

  eval_metrics = augmentation_eval(target_df, augmented_df, test_df, dev_df, model, prefix=prefix)

  ##if we have spanbert
  # load spanbert
  # spex pt
  # target analsyis
  # spex ft
  # vanilla augmentation train
  # vanilla augmentation eval
  # heterogenous augmentation eval
  # hetero + prob_label augmentation eval
  prefix = "wSB"
  print(f"\n{prefix} Evaluation\n")

  model = from_spanbert(main_data, dev_df)
  model = spex_training(weak_df, dev_df, model, log_name=prefix+" pretrain", epochs=config["spex_pt_epochs"])

  analysis = target_analysis(target_df, model)
  target_complexity = {"target_complexity": analysis[0].mean()}
  wandb.log(target_complexity)

  model = spex_training(target_df, dev_df, model, log_name=prefix+" finetune", epochs=config["spex_ft_epochs"])
  model = augmentation_training(target_df, dev_df, model, logname=prefix)
  
  eval_configs = ["sb+spex", "sb+spex+hetero", "sb+spex+hetero+probs"]
  for i, name in enumerate(eval_configs):
    mask_strategy = "base" if i == 0 else "hetero"
    prob_labels = False if i <= 1 else True

    augmented_df, type_token_ratio = target_augmentation(target_df, 
                                                        model, 
                                                        analysis=analysis,
                                                        mask_strategy=mask_strategy,
                                                        prob_labels=prob_labels,
                                                        prefix=name,
                                                        head=aug_preview)

    semantic_quality =  semantic_eval(eval_model, augmented_df, prefix=name)

    eval_metrics = augmentation_eval(target_df, augmented_df, test_df, dev_df, model, prefix=name)