<a href="https://colab.research.google.com/github/sadra-barikbin/AIChallengeSSA/blob/master/Attentionist_SSA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
# Install the notebook's dependencies into the running kernel's environment.
# Versions are pinned to the ones resolved in the recorded run so that a fresh
# "Restart & Run All" reproduces the same environment.
%pip install transformers==4.18.0 pytorch-ignite==0.4.9 wandb==0.12.16 stanza==1.1.1 seqeval==1.2.2

Collecting transformers
  Downloading transformers-4.18.0-py3-none-any.whl (4.0 MB)
[K     |████████████████████████████████| 4.0 MB 5.1 MB/s 
[?25hCollecting pytorch-ignite
  Downloading pytorch_ignite-0.4.9-py3-none-any.whl (259 kB)
[K     |████████████████████████████████| 259 kB 49.5 MB/s 
[?25hCollecting wandb
  Downloading wandb-0.12.16-py2.py3-none-any.whl (1.8 MB)
[K     |████████████████████████████████| 1.8 MB 37.8 MB/s 
[?25hCollecting stanza==1.1.1
  Downloading stanza-1.1.1-py3-none-any.whl (227 kB)
[K     |████████████████████████████████| 227 kB 45.1 MB/s 
[?25hCollecting seqeval
  Downloading seqeval-1.2.2.tar.gz (43 kB)
[K     |████████████████████████████████| 43 kB 856 kB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 44.5 MB/s 
[?25hCollecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenize

In [None]:
import os
import gc
import json
import torch
import ignite
import wandb
import json
import warnings
import functools
import itertools
import collections
import copy
import yaml
import torch.nn.functional     as     F
import numpy                   as     np
import matplotlib.pyplot       as     plt
import scipy.stats
from  tabulate                                import tabulate
from  operator                                import attrgetter, itemgetter
from  enum                                    import IntEnum
from  dataclasses                             import dataclass
from  torch                                   import Tensor
from  tqdm                                    import tqdm
from  torch.utils.data                        import DataLoader, Dataset
from  torch.utils.tensorboard                 import SummaryWriter
from  torch.nn.utils.rnn                      import pad_sequence
from  torch                                   import nn, Tensor
from  torch.optim.lr_scheduler                import StepLR
from  transformers                            import RobertaModel, RobertaTokenizerFast
from  transformers                            import PreTrainedTokenizerFast
from  transformers                            import BertModel, BertTokenizerFast
from  transformers                            import AutoTokenizer, AutoModel
from  transformers.modeling_outputs           import BaseModelOutput
from  typing                                  import Tuple, Dict, List, Any, Sequence, Union
from  typing                                  import Optional, Callable, Type
from  timeit                                  import timeit
from  ignite.metrics                          import Accuracy, Fbeta, Average
from  ignite.metrics                          import Metric, RunningAverage
from  ignite.handlers.terminate_on_nan        import TerminateOnNan
from  ignite.handlers.checkpoint              import ModelCheckpoint
from  ignite.handlers                         import EarlyStopping, global_step_from_engine
from  ignite.handlers                         import EpochOutputStore, LRScheduler
from  ignite.handlers.param_scheduler         import create_lr_scheduler_with_warmup
from  ignite.handlers.param_scheduler         import PiecewiseLinear
from  ignite.engine                           import create_supervised_trainer
from  ignite.engine                           import Engine, create_supervised_evaluator
from  ignite.engine.events                    import Events
from  ignite.contrib.handlers.tqdm_logger     import ProgressBar
from  ignite.contrib.handlers.wandb_logger    import WandBLogger

In [None]:
# Device string used when tensors are created explicitly: the GPU when CUDA is
# available, otherwise the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Maps the lower-case dataset identifiers (the same spelling as the dataset
# directory names / CONFIG["dataset"] values) to capitalised names.
# NOTE(review): presumably used for display/logging; the consumer is not visible here.
DATASET_NAMES_MAP = {'opener_en': 'Opener_En', 'opener_es': 'Opener_Es', 'norec': 'Norec',
                     'darmstadt_unis': 'Darmstadt_unis', 'multibooked_eu': 'Multibooked_eu',
                     'multibooked_ca': 'Multibooked_ca', 'mpqa': 'MPQA'}

In [None]:
# SECURITY: never embed the Weights & Biases API key in the notebook. Any key
# that was previously committed here must be considered leaked and be revoked.
# `wandb.login()` picks the key up from the WANDB_API_KEY environment variable
# (or an existing netrc entry) and otherwise prompts for it interactively.
wandb.login()

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


# Helper Classes

In [None]:
class TokenBIOLabelEnum(IntEnum):
  """Base class for integer-valued BIO token labels.

  Members are declared either as a bare integer or as ``(value, tag)``; the
  optional ``tag`` is stored on the member (``None`` when omitted).
  Subclasses are expected to define a member called ``OUT``. Every other
  member counts as a "begin" label when its value is even and as an "in"
  label when its value is odd.
  """

  def __new__(cls, value, tag=None):
    member = int.__new__(cls, value)
    member._value_ = value
    member.tag = tag
    return member

  @classmethod
  def from_integers(cls, integers: Sequence[int]) -> Sequence[IntEnum]:
    """Convert raw integer label ids into members of this enum."""
    return list(map(cls, integers))

  @classmethod
  def from_labels(cls, labels: Sequence[IntEnum]) -> Sequence[int]:
    """Convert enum members back into their integer values."""
    return [member.value for member in labels]

  def is_begin(self):
    """True for an even-valued member other than ``OUT``."""
    if self.is_out():
      return False
    return self % 2 == 0

  def is_in(self):
    """True for an odd-valued member."""
    return self % 2 != 0

  def is_out(self):
    """True when this member is the subclass's ``OUT`` label."""
    return self == type(self).OUT

In [None]:
class NonPolarSentExprTokenEnum(TokenBIOLabelEnum):
  """BIO labels for sentiment-expression tokens that carry no polarity.

  Used by the dataset when CONFIG["polarity_predictor"]["exist"] is true.
  """
  EXPR_BEGIN = 0  # first token of a sentiment expression
  EXPR_IN = 1     # any subsequent token of a sentiment expression
  OUT = 2         # token outside every sentiment expression

class SentExprTokenEnum(TokenBIOLabelEnum):
  """BIO labels for sentiment-expression tokens, split by polarity.

  Each polar member is declared as ``(value, tag)``; the tag is the polarity
  string. Even values are "begin" labels, odd values are "in" labels.
  """
  POSITIVE_EXPR_BEGIN = (0, 'Positive')
  POSITIVE_EXPR_IN = (1, 'Positive')
  NEGATIVE_EXPR_BEGIN = (2, 'Negative')
  NEGATIVE_EXPR_IN = (3, 'Negative')
  NEUTRAL_EXPR_BEGIN = (4, 'Neutral')
  NEUTRAL_EXPR_IN = (5, 'Neutral')
  OUT = 6  # token outside every sentiment expression; its tag is None

class SentHolderTokenEnum(TokenBIOLabelEnum):
  """BIO labels for tokens of holder expressions."""
  EXPR_BEGIN = 0  # first token of a holder expression
  EXPR_IN = 1     # any subsequent token of a holder expression
  OUT = 2         # token outside every holder expression


class SentTargetTokenEnum(TokenBIOLabelEnum):
  """BIO labels for tokens of target expressions."""
  EXPR_BEGIN = 0  # first token of a target expression
  EXPR_IN = 1     # any subsequent token of a target expression
  OUT = 2         # token outside every target expression

In [None]:
@dataclass
class SentimentGraphNode:
  """A contiguous token span acting as a node of a sentiment graph.

  Identity (equality and hashing) is determined by the span only; ``tag`` is
  free-form metadata (e.g. a polarity string) and does not take part in
  comparisons.
  """

  span_in_sentence: slice  # token-index span [start, stop) inside the sentence
  tag: Any=None

  @property
  def span_length(self) -> int:
    """Number of tokens covered by the span."""
    return self.span_in_sentence.stop - self.span_in_sentence.start


  @property
  def indices(self) -> List[int]:
    """Token indices covered by the span, in ascending order."""
    return list(range(self.span_in_sentence.start, self.span_in_sentence.stop))

  def __hash__(self) -> int:
    # Hash the (start, stop) pair rather than the slice itself, which is not
    # hashable on older Python versions.
    return hash((self.span_in_sentence.start, self.span_in_sentence.stop))


  def __eq__(self, other) -> bool:
    # Defer to the other operand for foreign types instead of failing with an
    # AttributeError on ``other.span_in_sentence``.
    if not isinstance(other, SentimentGraphNode):
      return NotImplemented
    return self.span_in_sentence == other.span_in_sentence


@dataclass
class SentimentGraphNodeSet:
  """The nodes of one sentence's sentiment graph, grouped by role."""
  sentiment_nodes: List[SentimentGraphNode]  # sentiment (polar) expression spans
  target_nodes: List[SentimentGraphNode]     # target expression spans
  holder_nodes: List[SentimentGraphNode]     # holder expression spans

In [None]:
@dataclass
class SentimentGraphGoldNode:
  """A gold-standard expression made of one or more span nodes.

  Two gold nodes are equal when their lists of span nodes are equal; the hash
  is derived from the same list so that gold nodes can be used as dict keys.
  """
  nodes: List[SentimentGraphNode]  # the spans that together form the expression

  def __hash__(self) -> int:
    return hash(tuple(self.nodes))


  def __eq__(self, other) -> bool:
    # Defer to the other operand for foreign types instead of failing with an
    # AttributeError on ``other.nodes``.
    if not isinstance(other, SentimentGraphGoldNode):
      return NotImplemented
    return self.nodes == other.nodes


@dataclass
class SentimentGraphGoldNodeSet:
  """Gold expressions of one sentence, grouped by role.

  The dataset appends one entry per opinion to each list, so the three lists
  are aligned by opinion.
  """
  sentiment_nodes: List[SentimentGraphGoldNode]
  target_nodes: List[SentimentGraphGoldNode]
  holder_nodes: List[SentimentGraphGoldNode]

In [None]:
@dataclass
class SentimentGraphEdgeSet:
  """Edges of a sentiment graph stored as matrices.

  Rows correspond to sentiment nodes and columns to target / holder nodes.
  The dataset fills these with 0/1 adjacency matrices; the edge predictor
  stores probabilities (or rounded decisions) and uses ``None`` when there is
  no candidate pair.
  """
  sentiment_target_edges: Tensor
  sentiment_holder_edges: Tensor

@dataclass
class SentimentGraphGoldEdgeSet:
  """Gold edges of one sentence, keyed by the gold sentiment expression."""
  # gold sentiment expression -> its gold target expression
  sentiment_target_edges: Dict[SentimentGraphGoldNode, SentimentGraphGoldNode]
  # gold sentiment expression -> its gold holder expression
  sentiment_holder_edges: Dict[SentimentGraphGoldNode, SentimentGraphGoldNode]

In [None]:
@dataclass
class SentimentGraph:
  """Span-level sentiment graph: node lists plus matrix-encoded edges."""
  nodes: SentimentGraphNodeSet
  edges: SentimentGraphEdgeSet

@dataclass
class GoldSentimentGraph:
  """Opinion-level gold graph: gold expressions plus dict-encoded edges."""
  nodes: SentimentGraphGoldNodeSet
  edges: SentimentGraphGoldEdgeSet

# Tokenizer

In [None]:
@functools.lru_cache()
def get_tokenizer(dataset_name: str) -> PreTrainedTokenizerFast:
  """Return the (cached) fast tokenizer matching the given dataset.

  The datasets "opener_en", "mpqa" and "darmstadt_unis" use the
  `roberta-base` tokenizer; every other name gets the `setu4993/LaBSE`
  tokenizer. Results are memoised per dataset name.
  """
  uses_roberta = dataset_name in ("opener_en", "mpqa", "darmstadt_unis")
  if not uses_roberta:
    return BertTokenizerFast.from_pretrained('setu4993/LaBSE', add_prefix_space=True)
  return RobertaTokenizerFast.from_pretrained('roberta-base', add_prefix_space=True)

# Config

In [None]:
# Experiment configuration shared by the dataset and the model components.
# Keys read in this part of the notebook:
#   - "dataset": dataset to load; it also selects the tokenizer / pretrained model.
#   - "polarity_predictor"/"exist": switches the sentiment labels to the non-polar scheme.
#   - "base_pooler" and "edge_predictor": options of the pooler and the edge predictor.
# NOTE(review): the "LR" policy / warm-up strings look descriptive; the code that
# builds the schedulers is not visible here — confirm they match it.
CONFIG = {
    "sequence_labeler":{
        # NOTE(review): presumably toggles class-weighted loss for the labelers — confirm.
        "WEIGHTED_LOSS" : False
    },
    "base_pooler":{
        "AVERAGE_HEADS" : False,
        # When true the edge predictor expects separate target / holder attention maps.
        "TARGET_AND_HOLDER_SEPARATE_HEAD" : False
    },
    "edge_predictor":{
        # When true a sigmoid with a learnable bias is used instead of a plain sigmoid.
        "USE_BIAS" : False
    },
    "polarity_predictor" : {
        "exist" : False,
        "WEIGHTED_LOSS" : False
    },
    "dataset" : "opener_en",
    "batch_size" : 32,
    "random_seed" : 41,
    "optimizer" : "AdamW",
    "LR" : {
        "base" : {
            "init" : 1e-4,
            "policy" : "StepLR(9,.1)",
            "warm-up" : "Linear(1e-6 -> 1e-4, 1 epoch)"
        },
        "novelty" : {
            "init" : 1e-3,
            "policy" : "StepLR(9,.1)",
            "warm-up" : "Linear(1e-5 -> 1e-3, 1 epoch)"
        }
    }
}

# Data

In [None]:
!git clone https://github.com/sadra-barikbin/semeval22_structured_sentiment.git

Cloning into 'semeval22_structured_sentiment'...
remote: Enumerating objects: 1007, done.[K
remote: Counting objects: 100% (113/113), done.[K
remote: Compressing objects: 100% (65/65), done.[K
remote: Total 1007 (delta 64), reused 78 (delta 48), pack-reused 894[K
Receiving objects: 100% (1007/1007), 16.09 MiB | 12.01 MiB/s, done.
Resolving deltas: 100% (493/493), done.


In [None]:
import stanza

# Download the default Stanza models for every language code needed by the
# data-processing scripts.
for language_code in ('en', 'es', 'eu', 'ca', 'nb'):
  stanza.download(language_code)

Downloading https://raw.githubusercontent.com/stanfordnlp/stanza-resources/master/resources_1.1.0.json: 122kB [00:00, 20.6MB/s]                    
2022-05-11 06:07:14 INFO: Downloading default packages for language: en (English)...
Downloading http://nlp.stanford.edu/software/stanza/1.1.0/en/default.zip: 100%|██████████| 428M/428M [01:16<00:00, 5.61MB/s]
2022-05-11 06:08:38 INFO: Finished downloading models and saved to /root/stanza_resources.
Downloading https://raw.githubusercontent.com/stanfordnlp/stanza-resources/master/resources_1.1.0.json: 122kB [00:00, 43.5MB/s]                    
2022-05-11 06:08:38 INFO: Downloading default packages for language: es (Spanish)...
Downloading http://nlp.stanford.edu/software/stanza/1.1.0/es/default.zip: 100%|██████████| 583M/583M [01:45<00:00, 5.51MB/s]
2022-05-11 06:10:33 INFO: Finished downloading models and saved to /root/stanza_resources.
Downloading https://raw.githubusercontent.com/stanfordnlp/stanza-resources/master/resources_1.1.0.json

In [None]:
! cd semeval22_structured_sentiment/data/mpqa && bash process_mpqa.sh
! cd semeval22_structured_sentiment/data/darmstadt_unis && bash process_darmstadt.sh

# Prepare Datasets

In [None]:
class SemEval2022Task10Dataset(Dataset):
  """SemEval-2022 Task 10 (structured sentiment) sentences as graph samples.

  For the "train" and "dev" splits every sentence with at least one opinion is
  converted into a dict with the keys ``sent_id``, ``text``, ``seq_labels``
  (BIO label sequences for sentiment / target / holder), ``graph`` (span nodes
  plus adjacency matrices) and ``gold_graph`` (opinion-level gold nodes and
  edges). For the "test" split the raw JSON records are kept untouched.

  Args:
    split: One of "train", "dev" or "test".
    config: Experiment configuration; ``config["dataset"]`` names the dataset
      (or "ALL" for the union of all of them, not allowed for "test") and
      ``config["polarity_predictor"]["exist"]`` selects the non-polar
      sentiment label scheme.
    path: Directory containing one sub-directory per dataset.
  """

  # Samples that are known to be erroneous in the source data.
  _ERRONEOUS_SENT_IDS = ('multibooked/corpora/eu/kaype-quintamar-llanes_1-1',
                         'temp_fbis/21.31.56-18015-2', '20020414/21.16.03-15717-6',
                         'xbank/wsj_0583-27')

  def __init__(self, split: str, config: Dict[str, Any],
               path: str='semeval22_structured_sentiment/data'):
    super(SemEval2022Task10Dataset, self).__init__()

    dataset_names = ["opener_en", "opener_es", "norec", "darmstadt_unis", "mpqa",
                    "multibooked_ca", "multibooked_eu", "ALL"]

    self.have_polarity_predictor = config["polarity_predictor"]["exist"]

    name = config["dataset"]
    assert split in ["train", "dev", "test"]
    self.split = split

    tokenizer = get_tokenizer(name)

    if name == 'ALL':
      assert split != 'test'

      data = []
      for _name in dataset_names[:-1]:
        data.extend(self._load_json(f"{path}/{_name}/{split}.json"))
    else:
      data = self._load_json(f"{path}/{name}/{split}.json")

    if split == 'test':
      self.data = data
      return

    # The sentiment label scheme depends on whether polarity is predicted by a
    # dedicated head; statistics must be collected over the same scheme.
    if self.have_polarity_predictor:
      sentim_label_enum = NonPolarSentExprTokenEnum
    else:
      sentim_label_enum = SentExprTokenEnum

    self.data = []
    # Seeding each counter with the enum gives every label an initial count of
    # one, so no label weight below divides by zero.
    sentim_labels_stat = collections.Counter(sentim_label_enum)
    target_labels_stat = collections.Counter(SentTargetTokenEnum)
    holder_labels_stat = collections.Counter(SentHolderTokenEnum)
    polarity_stat = collections.Counter()
    for sent_idx, item in enumerate(data):
      if not item['opinions']:
        continue
      if item['sent_id'] in self._ERRONEOUS_SENT_IDS:
        continue

      encoded = tokenizer([item['text']], add_special_tokens=False,
                          return_offsets_mapping=True, return_length=True)
      tokens_char_offsets = encoded['offset_mapping'][0]
      length = encoded['length'][0]
      tokens_labels_sentiment = [sentim_label_enum.OUT] * length
      tokens_labels_target = [SentTargetTokenEnum.OUT] * length
      tokens_labels_holder = [SentHolderTokenEnum.OUT] * length
      sentiment_nodes = []
      target_nodes = []
      holder_nodes = []
      sentim_tgt_edges = {}
      sentim_hld_edges = {}

      gold_sentiment_nodes = []
      gold_target_nodes = []
      gold_holder_nodes = []
      gold_sentim_tgt_edges = {}
      gold_sentim_hld_edges = {}

      for opinion in item['opinions']:
        # Count the polarity string itself; Counter.update() on a string would
        # count its individual characters.
        polarity_stat[opinion['Polarity']] += 1

        opinion_sentiment_nodes = self._collect_nodes(
            sent_idx, opinion["Polar_expression"][1], tokens_char_offsets,
            tokens_labels_sentiment, sentiment_nodes, "sentiment",
            tag=opinion['Polarity'], polarity=opinion["Polarity"]
        )
        opinion_target_nodes = self._collect_nodes(
            sent_idx, opinion["Target"][1], tokens_char_offsets,
            tokens_labels_target, target_nodes, "target", tag="target"
        )
        opinion_holder_nodes = self._collect_nodes(
            sent_idx, opinion["Source"][1], tokens_char_offsets,
            tokens_labels_holder, holder_nodes, "holder", tag="holder"
        )

        # A sentiment span may be shared by several opinions; accumulate its
        # neighbours instead of letting the last opinion overwrite the others.
        for sentiment_node in opinion_sentiment_nodes:
          sentim_tgt_edges.setdefault(sentiment_node, []).extend(opinion_target_nodes)
          sentim_hld_edges.setdefault(sentiment_node, []).extend(opinion_holder_nodes)

        gold_sentiment_node = SentimentGraphGoldNode(opinion_sentiment_nodes)
        gold_target_node = SentimentGraphGoldNode(opinion_target_nodes)
        gold_holder_node = SentimentGraphGoldNode(opinion_holder_nodes)

        gold_sentiment_nodes.append(gold_sentiment_node)
        gold_target_nodes.append(gold_target_node)
        gold_holder_nodes.append(gold_holder_node)

        gold_sentim_tgt_edges[gold_sentiment_node] = gold_target_node
        gold_sentim_hld_edges[gold_sentiment_node] = gold_holder_node

      sentiment_nodes.sort(key=attrgetter("span_in_sentence"))
      target_nodes.sort(key=attrgetter("span_in_sentence"))
      holder_nodes.sort(key=attrgetter("span_in_sentence"))

      # Adjacency matrices: rows follow the sorted sentiment nodes, columns the
      # sorted target / holder nodes.
      sentim_tgt_adj_matrix = torch.zeros((len(sentiment_nodes), len(target_nodes)))
      sentim_hld_adj_matrix = torch.zeros((len(sentiment_nodes), len(holder_nodes)))
      for i, sentim_node in enumerate(sentiment_nodes):

        for neighbor_tgt_node in sentim_tgt_edges[sentim_node]:
          sentim_tgt_adj_matrix[i, target_nodes.index(neighbor_tgt_node)] = 1

        for neighbor_hld_node in sentim_hld_edges[sentim_node]:
          sentim_hld_adj_matrix[i, holder_nodes.index(neighbor_hld_node)] = 1

      gold_node_set = SentimentGraphGoldNodeSet(gold_sentiment_nodes, gold_target_nodes,
                                                gold_holder_nodes)
      gold_edge_set = SentimentGraphGoldEdgeSet(gold_sentim_tgt_edges, gold_sentim_hld_edges)
      gold_graph = GoldSentimentGraph(gold_node_set, gold_edge_set)

      self.data.append({
          "sent_id": item['sent_id'],
          "text": item['text'],
          "seq_labels": {
              "sentiment": tokens_labels_sentiment,
              "target": tokens_labels_target,
              "holder": tokens_labels_holder},
          "graph":SentimentGraph(
              SentimentGraphNodeSet(
                sentiment_nodes, target_nodes, holder_nodes
              ),
              SentimentGraphEdgeSet(
                sentim_tgt_adj_matrix, sentim_hld_adj_matrix
              )
          ),
          "gold_graph": gold_graph}
        )

      sentim_labels_stat.update(tokens_labels_sentiment)
      target_labels_stat.update(tokens_labels_target)
      holder_labels_stat.update(tokens_labels_holder)

    self.sentim_seq_label_weights = self._inverse_frequency_weights(sentim_labels_stat)
    self.target_seq_label_weights = self._inverse_frequency_weights(target_labels_stat)
    self.holder_seq_label_weights = self._inverse_frequency_weights(holder_labels_stat)

    # Raw opinion counts per polarity, in the order Positive, Neutral, Negative.
    self.polarity_weights = torch.Tensor([polarity_stat['Positive'], polarity_stat['Neutral'],
                                          polarity_stat['Negative']])

  @staticmethod
  def _load_json(file_path: str) -> Any:
    """Read a JSON file, closing the file handle afterwards."""
    with open(file_path, encoding="utf-8") as json_file:
      return json.load(json_file)

  @staticmethod
  def _inverse_frequency_weights(labels_stat: collections.Counter) -> Tensor:
    """Return ``max(count) / count`` per label, ordered by ascending label value."""
    counts = torch.Tensor([labels_stat[label] for label in sorted(labels_stat)])
    return counts.max() / counts

  def _collect_nodes(self, sent_idx: int, char_spans: Sequence[str],
                     tokens_char_offsets: List[Tuple[int, int]],
                     tokens_labels: list, known_nodes: list,
                     expr_type: str, tag: Any, polarity: str=None) -> list:
    """Turn the "begin:end" character spans of one opinion role into nodes.

    Every span becomes a SentimentGraphNode carrying ``tag``. Nodes whose span
    is not yet in ``known_nodes`` are appended to it and written into the BIO
    sequence ``tokens_labels`` (both are modified in place).

    Returns:
      The nodes of this opinion for the given role, in span order.
    """
    opinion_nodes = []
    for char_span in char_spans:
      char_span_begin, char_span_end = (int(x) for x in char_span.split(':'))
      token_span_begin, token_span_end = self._char_span_to_token_span_idx(
          sent_idx, tokens_char_offsets, (char_span_begin, char_span_end)
      )
      node = SentimentGraphNode(slice(token_span_begin, token_span_end), tag)
      opinion_nodes.append(node)
      if node not in known_nodes:
        known_nodes.append(node)

        self._update_sequence_labels(
            sent_idx, tokens_labels,
            token_span_begin, token_span_end, expr_type, polarity
        )
    return opinion_nodes

  def _update_sequence_labels(self, sent_idx: int, sequence: list,
                              start_token: int, end_token: int,
                              expr_type: str, polarity: str=None):
    """Write the BIO labels of the span [start_token, end_token) into ``sequence``.

    Args:
      sent_idx: Index of the sentence, used in warnings only.
      sequence: Label sequence of the sentence; modified in place.
      start_token: Index of the first token of the span.
      end_token: Index one past the last token of the span.
      expr_type: "sentiment", "target" or "holder".
      polarity: "Positive", "Negative" or "Neutral"; required for polar
        sentiment labels, ignored otherwise.

    Raises:
      ValueError: If ``expr_type`` is unknown, or if polar sentiment labels
        are requested with an unrecognized ``polarity``.
    """
    if expr_type == 'sentiment':
      if self.have_polarity_predictor:
        begin_label = NonPolarSentExprTokenEnum.EXPR_BEGIN
        in_label = NonPolarSentExprTokenEnum.EXPR_IN
        out = NonPolarSentExprTokenEnum.OUT
      else:
        if polarity == 'Positive':
          begin_label = SentExprTokenEnum.POSITIVE_EXPR_BEGIN
          in_label = SentExprTokenEnum.POSITIVE_EXPR_IN
        elif polarity == 'Negative':
          begin_label = SentExprTokenEnum.NEGATIVE_EXPR_BEGIN
          in_label = SentExprTokenEnum.NEGATIVE_EXPR_IN
        elif polarity == 'Neutral':
          begin_label = SentExprTokenEnum.NEUTRAL_EXPR_BEGIN
          in_label = SentExprTokenEnum.NEUTRAL_EXPR_IN
        else:
          raise ValueError("Given expr_type is 'sentiment' but 'polarity' is not "
                           f"recognized. polarity={polarity}")
        # The polar OUT label must only be used with the polar scheme;
        # otherwise every token would look "already updated" below.
        out = SentExprTokenEnum.OUT
    elif expr_type == 'target':
      begin_label = SentTargetTokenEnum.EXPR_BEGIN
      in_label = SentTargetTokenEnum.EXPR_IN
      out = SentTargetTokenEnum.OUT
    elif expr_type == 'holder':
      begin_label = SentHolderTokenEnum.EXPR_BEGIN
      in_label = SentHolderTokenEnum.EXPR_IN
      out = SentHolderTokenEnum.OUT
    else:
      raise ValueError(f"Given expr_type is not recognized. expr_type={expr_type}")

    # These warnings first showed up for the MPQA dataset, where they were
    # annotation mistakes (an issue was opened upstream). In general, though,
    # an expression that is part of another expression is acceptable, so the
    # union of the overlapping expressions is kept as the final expression:
    # an already-labelled start token keeps its label, later tokens become "in".
    if sequence[start_token] != out:
      warnings.warn(f"Sequence is already updated at index {start_token}, "
                    f"expr_type={expr_type}, sent_idx={sent_idx}")

      if expr_type == 'sentiment' and sequence[start_token].tag != begin_label.tag:
        warnings.warn("Two expressions with conflicting polarities"
                      f" in the same place. sent_idx={sent_idx} "
                      f"polarities={sequence[start_token].tag}-{begin_label.tag}")
    else:
      sequence[start_token] = begin_label
    for i in range(start_token + 1, end_token):
      if sequence[i] != out:
        warnings.warn(f"Sequence is already updated at index {i}, "
                      f"expr_type={expr_type}, sent_idx={sent_idx}")

        if expr_type == 'sentiment' and sequence[i].tag != in_label.tag:
          warnings.warn("Two expressions with conflicting polarities in the same place."
                        f" sent_idx={sent_idx} polarities={sequence[i].tag}-{in_label.tag}")

      sequence[i] = in_label


  def _char_span_to_token_span_idx(
      self, sent_idx: int, tokens_char_offsets: List[Tuple[int, int]],
      char_span: Tuple[int, int]) -> Tuple[int, int]:
    """Map a character span onto the token span [begin, end) that covers it.

    The begin token is the first token starting at ``char_span[0]`` and the end
    index is one past the first token ending at ``char_span[1]``.

    Raises:
      ValueError: If no token starts or no token ends at the span boundaries.
    """
    begin_token_idx, end_token_idx = None, None
    for i, (char_offset_begin, char_offset_end) in enumerate(tokens_char_offsets):
      if char_offset_begin == char_span[0] and begin_token_idx is None:
        begin_token_idx = i

      if char_offset_end == char_span[1] and end_token_idx is None:
        end_token_idx = i + 1

      if begin_token_idx is not None and end_token_idx is not None:
        return begin_token_idx, end_token_idx

    raise ValueError("Given char_span does not exist in given token char offsets. "
                     f"sent_idx={sent_idx} char_span={char_span} "
                     f"token_char_offsets={tokens_char_offsets}")


  def __len__(self):
    return len(self.data)


  def __getitem__(self, idx) -> Dict[str, Any] :
    """Return the sample at ``idx``; train samples also carry the loss weights."""
    item = self.data[idx]

    if self.split == 'train':
      item["label_weights"] = {"sentiment": self.sentim_seq_label_weights,
                              "target": self.target_seq_label_weights,
                              "holder": self.holder_seq_label_weights}

      item["polarity_weights"] = self.polarity_weights

    return item

In [None]:
train_ds = SemEval2022Task10Dataset('train', CONFIG)
dev_ds = SemEval2022Task10Dataset('dev', CONFIG)

# Model Definition

The SemEval-2022 Task 10 description states that the sentiment, target and holder expressions of a sentence can be viewed as a graph: the expressions are its nodes, and there is an edge from each sentiment expression node to the target and holder nodes that pertain to it.

In [None]:
class BIOSchemeTokenClassification:
  """Decodes BIO label sequences into spans with a three-state machine.

  The states are 'skipping' (outside a span), 'start_reading' (on the first
  token of a span) and 'reading' (on a later token of a span).
  """

  @classmethod
  def _nstate(cls, state: str, label: TokenBIOLabelEnum) -> str:
    """Return the state reached from ``state`` after consuming ``label``."""
    if label.is_out():
      return 'skipping'
    if label.is_begin():
      # A begin label opens a new span whatever the previous state was.
      return 'start_reading'
    # An "in" label continues a span only when one is currently open.
    return 'skipping' if state == 'skipping' else 'reading'


  @classmethod
  def extract_spans(cls, token_labels: Sequence[TokenBIOLabelEnum]) -> List[Tuple[slice, Any]]:
    """Extract ``(token_slice, tag)`` pairs from a BIO label sequence.

    The tag of a span is the tag of its last token. "In" labels that do not
    follow an open span are ignored.
    """
    found = []
    state = 'skipping'
    span_start, span_length = 0, 0

    for position, label in enumerate(token_labels):
      next_state = BIOSchemeTokenClassification._nstate(state, label)
      span_is_open = state in ('reading', 'start_reading')

      if next_state == 'reading':
        span_length += 1
      else:
        # Either a new span starts here or we fall outside: close the open one.
        if span_is_open:
          found.append((slice(span_start, span_start + span_length),
                        token_labels[position - 1].tag))
        if next_state == 'start_reading':
          span_start, span_length = position, 1

      state = next_state

    # A span that runs up to the end of the sequence is still open here.
    if state != 'skipping':
      found.append((slice(span_start, span_start + span_length), label.tag))
    return found

## Sequence Labeler

In [None]:
class LinearSequenceLabeler(nn.Module):
  """Token classifier: Linear -> Dropout -> ReLU -> Linear over hidden states.

  Args:
    input_dim: Size of the incoming hidden states.
    hidden_dim: Size of the intermediate layer.
    label_enum: Label enum class; its length sets the number of output
      classes and it is used to decode predictions.
  """

  def __init__(self, input_dim: int, hidden_dim: int, label_enum: TokenBIOLabelEnum):
    super().__init__()

    layers = [
        nn.Linear(input_dim, hidden_dim),
        nn.Dropout(),
        nn.ReLU(),
        nn.Linear(hidden_dim, len(label_enum)),
    ]
    self.seq_label = nn.ModuleList(layers)
    self.label_enum = label_enum


  def forward(self, ptm_output: BaseModelOutput) -> Tensor:
    """Return per-token label logits computed from ``ptm_output.last_hidden_state``."""
    return functools.reduce(lambda hidden, layer: layer(hidden),
                            self.seq_label, ptm_output.last_hidden_state)


  @torch.no_grad()
  def predict(self, output: Tensor,
              padding_mask: Tensor) -> List[Sequence[TokenBIOLabelEnum]]:
    """Decode logits into label sequences, keeping positions whose mask equals 1.

    The arg-max is taken over dimension 2 of ``output``.
    """
    label_ids = output.argmax(dim=2)
    return [
        self.label_enum.from_integers(label_ids[row][padding_mask[row] == 1].cpu().numpy())
        for row in range(label_ids.shape[0])
    ]

## Node Extractor

In [None]:
class SentimentGraphNodeExtractor(nn.Module):
  """Three parallel sequence labelers that yield the nodes of a sentiment graph.

  One labeler each for sentiment, target and holder expressions; all of them
  take 768-dimensional hidden states and use a 64-dimensional hidden layer.
  """

  def __init__(self):
    super().__init__()

    self.sentiment_labeler = LinearSequenceLabeler(768, 64, SentExprTokenEnum)
    self.target_labeler = LinearSequenceLabeler(768, 64, SentTargetTokenEnum)
    self.holder_labeler = LinearSequenceLabeler(768, 64, SentHolderTokenEnum)


  def forward(self, ptm_output: BaseModelOutput) -> Tuple[Tensor, Tensor, Tensor]:
    """Return the (sentiment, target, holder) logits for the batch."""
    sentiment_logits = self.sentiment_labeler(ptm_output)
    target_logits = self.target_labeler(ptm_output)
    holder_logits = self.holder_labeler(ptm_output)
    return sentiment_logits, target_logits, holder_logits


  @torch.no_grad()
  def predict(self, output: Tuple[Tensor, Tensor, Tensor],
              padding_mask: Tensor) -> List[SentimentGraphNodeSet]:
    """Decode the three logit tensors into one node set per sample.

    Args:
      output: The (sentiment, target, holder) logits returned by ``forward``.
      padding_mask: Mask with one row per sample; positions equal to 1 are kept.
    """
    labels_by_role = (
        self.sentiment_labeler.predict(output[0], padding_mask),
        self.target_labeler.predict(output[1], padding_mask),
        self.holder_labeler.predict(output[2], padding_mask),
    )

    node_sets = []
    for sample in range(padding_mask.shape[0]):
      # Order matches the SentimentGraphNodeSet fields: sentiment, target, holder.
      nodes_by_role = [
          [SentimentGraphNode(*span)
           for span in BIOSchemeTokenClassification.extract_spans(role_labels[sample])]
          for role_labels in labels_by_role
      ]
      node_sets.append(SentimentGraphNodeSet(*nodes_by_role))
    return node_sets

## Polarity Predictor

In [None]:
class PolarityPredictor(nn.Module):
  """Classifies the polarity of each sentiment node from its token states.

  The hidden states of a node's tokens go through a small MLP, are summed over
  the span and are classified into three classes.
  """

  # Class index -> polarity tag written onto the nodes by `predict`.
  _POLARITY_TAGS = ("Positive", "Neutral", "Negative")

  def __init__(self):
    super(PolarityPredictor, self).__init__()

    self.net = nn.ModuleList([nn.Linear(768, 64), nn.Dropout(), nn.ReLU(),
                              nn.Linear(64, 16)])
    self.cls = nn.Linear(16, 3)


  def forward(self, ptm_output: BaseModelOutput,
              graphs_nodes: List[SentimentGraphNodeSet]) -> List[Tensor]:
    """Return, per sample, a (num_sentiment_nodes, 3) tensor of polarity logits.

    Samples without sentiment nodes yield an empty tensor of shape (0,),
    created on the same device as the hidden states.
    """
    last_hidden_state_batch = ptm_output.last_hidden_state
    batch_size = last_hidden_state_batch.shape[0]

    preds_batch: List[Tensor] = []
    for i in range(batch_size):
      last_hidden_state = last_hidden_state_batch[i]
      nodes = graphs_nodes[i].sentiment_nodes

      if len(nodes) == 0: # This case arises when NodeExtractor has not predicted any Sentim. node.
        # Follow the input's device instead of a module-level device constant.
        preds_batch.append(torch.empty(0, device=last_hidden_state.device))
        continue

      # Run all tokens of all nodes through the MLP at once, then regroup the
      # rows per node and sum over each node's tokens.
      nodes_h = torch.cat([last_hidden_state[node.span_in_sentence] for node in nodes])
      for layer in self.net:
        nodes_h = layer(nodes_h)

      nodes_h = torch.split(nodes_h, [node.span_length for node in nodes])
      nodes_h = torch.stack([torch.sum(node_h, dim=0) for node_h in nodes_h])

      preds_batch.append(self.cls(F.relu(nodes_h)))

    return preds_batch


  @torch.no_grad()
  def predict(self, ptm_output: BaseModelOutput,
              graphs_nodes: List[SentimentGraphNodeSet]) -> List[SentimentGraphNodeSet]:
    """Set ``tag`` of every sentiment node to its predicted polarity.

    The nodes are modified in place and the same ``graphs_nodes`` list is
    returned.
    """
    output = self(ptm_output, graphs_nodes)
    for i, graph_nodes in enumerate(graphs_nodes):
      for j, node in enumerate(graph_nodes.sentiment_nodes):
        pred = int(torch.argmax(output[i][j]))
        assert pred <= 2
        node.tag = self._POLARITY_TAGS[pred]
    return graphs_nodes

## Edge Predictor

In [None]:
# Originally from <https://discuss.pytorch.org/t/does-nn-sigmoid-have-bias-parameter/10561/5>
class BiasedSigmoid(nn.Module):
  """Sigmoid applied after adding a single learnable scalar bias.

  The bias is initialised uniformly at random in [-1, 1).
  """

  def __init__(self, device=None):
    super().__init__()
    initial_bias = torch.rand(1, device=device) * 2 - 1
    self.bias = nn.Parameter(initial_bias)


  def forward(self, input: Tensor) -> Tensor:
    """Return ``sigmoid(input + bias)``."""
    shifted = input + self.bias
    return torch.sigmoid(shifted)

In [None]:
class AttentionistSentimentGraphEdgePredictor(nn.Module):
  """Predicts sentiment->target and sentiment->holder edges from attention.

  The score of a node pair is the attention mass flowing between the two
  spans in both directions; a (optionally biased) sigmoid turns it into an
  edge probability.
  """

  def __init__(self, config: Dict[str, Any]):
    super().__init__()

    self.tgt_hld_separate_head = config["base_pooler"]["TARGET_AND_HOLDER_SEPARATE_HEAD"]
    make_predictor = BiasedSigmoid if config["edge_predictor"]["USE_BIAS"] else nn.Sigmoid

    if self.tgt_hld_separate_head:
      self.target_edge_predictor = make_predictor()
      self.holder_edge_predictor = make_predictor()
    else:
      self.edge_predictor = make_predictor()


  def _edge_probs(self, attentions, sentiment_nodes, other_nodes, predictor):
    """Return the (n_sentiment, n_other) edge probabilities, or None if there is no pair."""
    # assumes `attentions` is a single-sample token-by-token map — TODO confirm
    scores = [
        attentions[sentim_node.span_in_sentence, other_node.span_in_sentence].sum() + \
        attentions[other_node.span_in_sentence, sentim_node.span_in_sentence].sum()
        for sentim_node, other_node in itertools.product(sentiment_nodes, other_nodes)
    ]
    if not scores:
      return None
    return predictor(torch.stack(scores).view(len(sentiment_nodes), len(other_nodes)))


  def forward(self, ptm_output: BaseModelOutput,
              graphs_nodes: List[SentimentGraphNodeSet]) -> List[SentimentGraphEdgeSet]:
    """Return one edge-probability set per sample of the batch.

    With separate heads, ``ptm_output.attentions`` is unpacked into a target
    and a holder attention batch; otherwise the same batch serves both.
    """
    if self.tgt_hld_separate_head:
      attentions_target_batch, attentions_holder_batch = ptm_output.attentions
      target_predictor = self.target_edge_predictor
      holder_predictor = self.holder_edge_predictor
    else:
      attentions_target_batch = ptm_output.attentions
      attentions_holder_batch = ptm_output.attentions
      target_predictor = holder_predictor = self.edge_predictor

    graphs_edge_probs = []
    for sample, graph_nodes in zip(range(len(graphs_nodes)), graphs_nodes):
      attentions_target = attentions_target_batch[sample]
      attentions_holder = attentions_holder_batch[sample]

      target_probs = self._edge_probs(attentions_target, graph_nodes.sentiment_nodes,
                                      graph_nodes.target_nodes, target_predictor)
      holder_probs = self._edge_probs(attentions_holder, graph_nodes.sentiment_nodes,
                                      graph_nodes.holder_nodes, holder_predictor)

      graphs_edge_probs.append(SentimentGraphEdgeSet(target_probs, holder_probs))

    return graphs_edge_probs


  @torch.no_grad()
  def predict(self, output: List[SentimentGraphEdgeSet]) -> List[SentimentGraphEdgeSet]:
    """Round edge probabilities to integer 0/1 decisions; ``None`` stays ``None``."""
    def binarize(probs):
      return None if probs is None else probs.round().int()

    return [
        SentimentGraphEdgeSet(binarize(graph_E.sentiment_target_edges),
                              binarize(graph_E.sentiment_holder_edges))
        for graph_E in output
    ]


## Model Base - RoBERTa | BERT

In [None]:
class ModelBase(nn.Module):
  """Pre-trained transformer backbone that returns its attention maps.

  `roberta-base` is loaded for the 'opener_en', 'mpqa' and 'darmstadt_unis'
  datasets and `setu4993/LaBSE` for every other dataset. Both are created
  with `output_attentions=True` and without a pooling layer.
  """

  def __init__(self, config: Dict[str, Any]):
    """Load the backbone selected by `config["dataset"]`."""
    super().__init__()

    ptm_kwargs = dict(output_attentions=True, add_pooling_layer=False)
    uses_roberta = config["dataset"] in ('opener_en', 'mpqa', 'darmstadt_unis')
    if uses_roberta:
      self.ptm = RobertaModel.from_pretrained('roberta-base', **ptm_kwargs)
    else:
      self.ptm = BertModel.from_pretrained('setu4993/LaBSE', **ptm_kwargs)


  def forward(self, ptm_input: Dict[str, Any]) -> BaseModelOutput:
    """Run the backbone on the keyword arguments in `ptm_input`.

    Returns:
      The backbone's output object.

    Raises:
      ValueError: if the output carries no attentions (missing or empty).
    """
    ptm_output = self.ptm(**ptm_input)

    # Downstream modules rely on the attention maps, so fail loudly without them.
    if ptm_output.attentions:
      return ptm_output

    raise ValueError(
        "Attentions should be included in model input. You might have forgotten "
        "to set `output_attentions=True` on constructing PTM or calling it."
    )

## Base Pooler

In [None]:
class BasePooler(nn.Module):
  """Selects or pools the backbone's attention maps and strips CLS/SEP.

  Depending on `config["base_pooler"]` the attention handed downstream is
  either one hand-picked (layer, head) attention map, or a learnable weighted
  sum over all layers and heads that is initialised as a one-hot selection of
  that same (layer, head). With `TARGET_AND_HOLDER_SEPARATE_HEAD` two maps are
  produced, `[target_map, holder_map]`; otherwise a single map is produced.
  """

  # Same dataset list that ModelBase maps to the RoBERTa backbone.
  _ROBERTA_DATASETS = ('opener_en', 'mpqa', 'darmstadt_unis')
  # Size of the learnable (layer, head) weight matrices.
  _NUM_LAYERS = 12
  _NUM_HEADS = 12

  def __init__(self, config: Dict[str, Any]):
    super(BasePooler, self).__init__()

    self.avg_heads = config["base_pooler"]["AVERAGE_HEADS"]
    self.tgt_hld_separate_head = config["base_pooler"]["TARGET_AND_HOLDER_SEPARATE_HEAD"]
    self.dataset = config["dataset"]

    if self.avg_heads:
      target_head, holder_head = self._selected_heads()
      if self.tgt_hld_separate_head:
        self.attention_weights_target = nn.Parameter(self._one_hot_weights(target_head))
        self.attention_weights_holder = nn.Parameter(self._one_hot_weights(holder_head))
      else:
        self.attention_weights = nn.Parameter(self._one_hot_weights(target_head))


  def _selected_heads(self):
    """Return the 0-based ((layer, head) for target, (layer, head) for holder) picks."""
    if self.dataset in self._ROBERTA_DATASETS:
      # Head 7 (pink) of layer 8 / Head 10 of layer 11
      return (7, 6), (10, 9)
    # Head 9 (green) of layer 11 / Head 4 (green) of layer 11
    return (10, 8), (10, 3)


  def _one_hot_weights(self, layer_and_head) -> Tensor:
    """Build a (layers, heads) weight matrix that is 1 at `layer_and_head`, 0 elsewhere."""
    weights = torch.zeros(self._NUM_LAYERS, self._NUM_HEADS).float()
    weights[layer_and_head] = 1.
    return weights


  def forward(self, ptm_output: BaseModelOutput) -> BaseModelOutput:
    """Replace `ptm_output.attentions` by the pooled map(s), in place.

    `ptm_output.last_hidden_state` and the resulting attention map(s) lose
    their first and last sequence positions (CLS and SEP).

    Returns:
      The same `ptm_output` object, mutated.
    """

    ptm_output.last_hidden_state = ptm_output.last_hidden_state[:, 1:-1, :] # Excluding CLS,SEP

    if self.avg_heads:
      # Stack the per-layer attentions once and reuse the result for every
      # weight set; stacking per weight set would duplicate the large tensor.
      stacked_attentions = torch.stack(ptm_output.attentions)
      # Contract the layer (0) and head (2) dims against the weight matrix.
      # Alternatively:
      # torch.einsum('lbhmn, lh -> bmn', stacked_attentions, weights)
      if self.tgt_hld_separate_head:
        ptm_output.attentions = [
            torch.tensordot(stacked_attentions, weights, dims=([0, 2], [0, 1]))
            for weights in (self.attention_weights_target, self.attention_weights_holder)
        ]
      else:
        ptm_output.attentions = torch.tensordot(stacked_attentions, self.attention_weights,
                                                dims=([0, 2], [0, 1]))
    else:
      (tgt_layer, tgt_head), (hld_layer, hld_head) = self._selected_heads()
      target_attention = ptm_output.attentions[tgt_layer][:, tgt_head]
      if self.tgt_hld_separate_head:
        ptm_output.attentions = [target_attention,
                                 ptm_output.attentions[hld_layer][:, hld_head]]
      else:
        ptm_output.attentions = target_attention

    # Excluding CLS,SEP
    if isinstance(ptm_output.attentions, list):
      ptm_output.attentions = [item[:, 1:-1, 1:-1] for item in ptm_output.attentions]
    else:
      ptm_output.attentions = ptm_output.attentions[:, 1:-1, 1:-1]

    return ptm_output

## Whole Model

In [None]:
class StructuredSentimentPredictor(nn.Module):
  """Full model: backbone, attention pooler, node extractor, edge predictor
  and, when `config["polarity_predictor"]["exist"]` is set, a polarity predictor.
  """

  def __init__(self, config: Dict[str, Any]):
    super().__init__()

    self.base = ModelBase(config)

    # Everything that is not the pre-trained backbone lives in `novelty`.
    novelty_modules = {'base_pooler': BasePooler(config),
                       'edge_predictor': AttentionistSentimentGraphEdgePredictor(config),
                       'node_extractor': SentimentGraphNodeExtractor()}
    self.have_polarity_predictor = config["polarity_predictor"]["exist"]
    if self.have_polarity_predictor:
      novelty_modules["polarity_predictor"] = PolarityPredictor()
    self.novelty = nn.ModuleDict(novelty_modules)

  @dataclass
  class Output:
    # Raw output of the node extractor.
    node_extractor_output: Tuple[Tensor, Tensor, Tensor]
    # One edge set per graph, as returned by the edge predictor's forward pass.
    edge_predictor_output: List[SentimentGraphEdgeSet]
    # Raw output of the polarity predictor; None when the model has none.
    polarity_predictor_output: List[Tensor] = None


  def forward(self, inputs: Dict[str, Any]) -> Output:
    """Run every sub-module on a batch.

    Args:
      inputs: dict holding "ptm_input" (backbone keyword arguments) and
        "graphs_nodes" (the node sets handed to the edge/polarity predictors).

    Returns:
      An `Output` bundling the raw output of each head.
    """
    pooled_output = self.novelty['base_pooler'](self.base(inputs["ptm_input"]))

    node_extractor_out = self.novelty['node_extractor'](pooled_output)

    graphs_nodes = inputs["graphs_nodes"]
    edge_predictor_out = self.novelty['edge_predictor'](pooled_output, graphs_nodes)

    polarity_predictor_out = None
    if self.have_polarity_predictor:
      polarity_predictor_out = self.novelty['polarity_predictor'](pooled_output, graphs_nodes)

    return StructuredSentimentPredictor.Output(
        node_extractor_out, edge_predictor_out, polarity_predictor_out)


  @torch.no_grad()
  def predict(self, inputs: Dict[str, Any], output: Output) -> List[SentimentGraph]:
    """Decode one `SentimentGraph` per batch item.

    Nodes are decoded from `output.node_extractor_output`; the backbone and
    pooler are then run again on `inputs["ptm_input"]` so that edges (and
    polarities, if enabled) are predicted for those decoded nodes.
    """
    node_extractor = self.novelty['node_extractor']
    graphs_nodes = node_extractor.predict(output.node_extractor_output,
                                          inputs['padding_mask'])

    pooled_output = self.novelty['base_pooler'](self.base(inputs["ptm_input"]))

    edge_predictor = self.novelty['edge_predictor']
    graphs_edges = edge_predictor.predict(edge_predictor(pooled_output, graphs_nodes))

    if self.have_polarity_predictor:
      graphs_nodes = self.novelty['polarity_predictor'].predict(pooled_output, graphs_nodes)

    decoded_graphs = []
    for graph_nodes, graph_edges in zip(graphs_nodes, graphs_edges):
      decoded_graphs.append(SentimentGraph(graph_nodes, graph_edges))
    return decoded_graphs


# Loss functions

## Node Extractor Loss - Cross Entropy

In [None]:
def node_extractor_loss(y_pred: StructuredSentimentPredictor.Output,
                        y: Dict[str,Any], weighted=False) -> torch.double:
  """Sum of the three sequence-labelling cross-entropy losses.

  Args:
    y_pred: model output; `node_extractor_output` holds the sentiment, target
      and holder predictions, in that order.
    y: batch targets; `y['seq_labels']` holds the matching gold label tensors
      and `y["seq_label_weights"]` the per-class weights for each label kind.
    weighted: when true, the per-class weights are passed to the loss.

  Returns:
    sentiment loss + target loss + holder loss. Label value -1 is ignored.
  """
  sentim_pred, tgt_pred, hld_pred = y_pred.node_extractor_output
  sentim_gold, tgt_gold, hld_gold = y['seq_labels']

  def labelling_loss(scores, gold, weight_key):
    class_weights = y["seq_label_weights"][weight_key] if weighted else None
    # permute(0, 2, 1) swaps the last two dims so that dim 1 is the class dim
    # F.cross_entropy expects (assumes scores are (batch, tokens, classes) —
    # TODO confirm against the node extractor).
    return F.cross_entropy(scores.permute(0, 2, 1), gold,
                           ignore_index=-1, weight=class_weights)

  sentim_loss = labelling_loss(sentim_pred, sentim_gold, "sentiment")
  tgt_loss = labelling_loss(tgt_pred, tgt_gold, "target")
  hld_loss = labelling_loss(hld_pred, hld_gold, "holder")

  return sentim_loss + tgt_loss + hld_loss

## Edge Extractor Loss - Binary Cross Entropy

In [None]:
def edge_extractor_loss(y_pred: StructuredSentimentPredictor.Output,
                        y: Dict[str,Any]) -> torch.double:
  """Binary cross-entropy summed over every predicted edge matrix of the batch.

  Predicted and gold edge sets are paired graph by graph. A predicted matrix
  that is `None` contributes nothing. The accumulator is created on the
  module-level `DEVICE`.
  """
  predicted_sets = y_pred.edge_predictor_output
  gold_sets = y["graphs_edges"]

  total = torch.zeros((), device=DEVICE)

  for predicted, gold in zip(predicted_sets, gold_sets):
    edge_pairs = (
        (predicted.sentiment_target_edges, gold.sentiment_target_edges),
        (predicted.sentiment_holder_edges, gold.sentiment_holder_edges),
    )
    for edge_probs, edge_labels in edge_pairs:
      if edge_probs is None:
        continue
      total += F.binary_cross_entropy(edge_probs, edge_labels)

  return total

## Polarity Predictor Loss - Cross Entropy

In [None]:
def polarity_predictor_loss(y_pred: StructuredSentimentPredictor.Output,
                            y: Dict[str,Any], weighted: bool = False) -> torch.double:
  """Cross-entropy between predicted and gold sentiment-node polarities.

  The per-graph predictions in `y_pred.polarity_predictor_output` are padded
  into one batch tensor; gold entries equal to -1 are ignored by the loss.
  `y["polarity_weights"]` is used as class weights only when `weighted`.
  """
  padded_scores = pad_sequence(y_pred.polarity_predictor_output, batch_first=True)
  class_weights = y["polarity_weights"] if weighted else None

  return F.cross_entropy(padded_scores.permute(0, 2, 1),
                         y["sentim_nodes_polarities"],
                         ignore_index=-1,
                         weight=class_weights)

## Total Loss

In [None]:
def loss_fn(y_pred: StructuredSentimentPredictor.Output,
                        y: Dict[str,Any], config: Dict[str, Any]) -> torch.double:
  """Total training loss: node loss + edge loss + (optional) polarity loss.

  The polarity term is 0. when `config["polarity_predictor"]["exist"]` is
  false. The `WEIGHTED_LOSS` flags of the polarity predictor and of the
  sequence labeler select class-weighted cross-entropy for their terms.
  """
  polarity_cfg = config["polarity_predictor"]
  if polarity_cfg["exist"]:
    polarity_loss = polarity_predictor_loss(y_pred, y,
                                            weighted=polarity_cfg["WEIGHTED_LOSS"])
  else:
    polarity_loss = 0.

  node_loss = node_extractor_loss(y_pred, y,
                                  weighted=config["sequence_labeler"]["WEIGHTED_LOSS"])
  edge_loss = edge_extractor_loss(y_pred, y)

  return node_loss + edge_loss + polarity_loss

# Train & Evaluation

## Data Loaders

In [None]:
def collate_fn(batch: List[Dict[str, Any]], tokenizer: RobertaTokenizerFast) -> Dict[str, Any]:
  """Assemble a list of dataset items into one training/evaluation batch.

  The texts are tokenized twice: once with special tokens (the backbone
  input) and once without them, whose attention mask serves as the padding
  mask. Sequence labels and sentiment-node polarities are padded with -1.
  Label and polarity weights are taken from the first item of the batch.

  Raises:
    ValueError: if a sentiment node's tag is not 'Positive', 'Neutral' or
      'Negative'.
  """

  texts = [item['text'] for item in batch]
  ptm_input = tokenizer(texts, padding=True, return_tensors='pt')
  padding_mask = tokenizer(texts, padding=True, return_tensors='pt',
                           add_special_tokens=False)["attention_mask"]

  def padded_labels(token_enum, label_kind):
    # One LongTensor of label ids per item, right-padded with -1.
    per_item = [torch.LongTensor(token_enum.from_labels(item["seq_labels"][label_kind]))
                for item in batch]
    return pad_sequence(per_item, batch_first=True, padding_value=-1)

  seq_labels = (padded_labels(SentExprTokenEnum, "sentiment"),
                padded_labels(SentTargetTokenEnum, "target"),
                padded_labels(SentHolderTokenEnum, "holder"))

  graphs_edges = [item["graph"].edges for item in batch]
  graphs_nodes = [item["graph"].nodes for item in batch]

  # Only is used if config.polarity_predictor.exist == True
  # A tag's position in this tuple is its polarity class id.
  polarity_tags = ('Positive', 'Neutral', 'Negative')
  sentim_nodes_polarities = []
  for node_set in graphs_nodes:
    polarities = []
    for node in node_set.sentiment_nodes:
      if node.tag not in polarity_tags:
        raise ValueError(f"Given invalid sentiment node tag: {node.tag}")
      polarities.append(polarity_tags.index(node.tag))
    sentim_nodes_polarities.append(torch.LongTensor(polarities))
  sentim_nodes_polarities = pad_sequence(sentim_nodes_polarities, batch_first=True,
                                         padding_value=-1)

  graphs = [item["graph"] for item in batch]
  gold_graphs = [item["gold_graph"] for item in batch]

  first_item = batch[0]
  return {"seq_labels": seq_labels, "graphs": graphs,
          "seq_label_weights": first_item["label_weights"],
          "graphs_nodes": graphs_nodes, "graphs_edges": graphs_edges,
          "ptm_input": ptm_input, "padding_mask": padding_mask,
          'sentim_nodes_polarities': sentim_nodes_polarities,
          "polarity_weights": first_item["polarity_weights"],
          "gold_graphs": gold_graphs}


In [None]:
# Both loaders use `collate_fn` bound to the tokenizer of the configured dataset.
# NOTE(review): no `shuffle` argument is passed to the training loader — confirm
# that iterating the training set in a fixed order is intended.
train_dataloader = DataLoader(
    train_ds, collate_fn=functools.partial(collate_fn, tokenizer=get_tokenizer(CONFIG["dataset"])),
    batch_size=CONFIG["batch_size"])
# The dev loader uses twice the training batch size.
dev_dataloader = DataLoader(
    dev_ds, collate_fn=functools.partial(collate_fn, tokenizer=get_tokenizer(CONFIG["dataset"])),
    batch_size= CONFIG["batch_size"] * 2)

## Model & Optimizer

In [None]:
# Release cached, unused GPU memory before the model is created.
torch.cuda.empty_cache()

In [None]:
model = StructuredSentimentPredictor(CONFIG).to(DEVICE)

# Two parameter groups so each can follow its own learning-rate schedule:
# index 0 = pre-trained backbone, index 1 = all `novelty` sub-modules.
# The warm-up schedulers below address them through `param_group_index`.
optimizer_parameter_groups = [
  {'params': list(model.base.parameters())},
  {'params': list(model.novelty.parameters())}
]
optimizer = torch.optim.AdamW(optimizer_parameter_groups)

Downloading:   0%|          | 0.00/478M [00:00<?, ?B/s]

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias', 'roberta.pooler.dense.weight', 'lm_head.dense.bias', 'lm_head.bias', 'lm_head.layer_norm.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Engines

In [None]:
def prepare_batch(
    batch: Dict[str, Any],
    device=DEVICE,
    non_blocking=True
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
  """Split a collated batch into model inputs `x` and targets `y` on `device`.

  Every tensor is moved with `.to(device, non_blocking=non_blocking)`; graph
  objects (`graphs`, `gold_graphs`, `graphs_nodes`) are passed through as-is.
  Only `attention_mask` and `input_ids` of the tokenizer output are forwarded
  to the backbone.
  """

  def move(tensor):
    return tensor.to(device, non_blocking=non_blocking)

  padding_mask = move(batch["padding_mask"])
  tokenized = batch["ptm_input"]

  x = {
      "ptm_input": {
          "attention_mask": move(tokenized["attention_mask"]),
          "input_ids": move(tokenized["input_ids"]),
      },
      "graphs_nodes": batch["graphs_nodes"],
      "padding_mask": padding_mask,
  }

  y = {
      "seq_labels": tuple(move(labels) for labels in batch["seq_labels"]),
      "graphs": batch["graphs"],
      "gold_graphs": batch["gold_graphs"],
      "graphs_edges": [
          SentimentGraphEdgeSet(move(edge_set.sentiment_target_edges),
                                move(edge_set.sentiment_holder_edges))
          for edge_set in batch["graphs_edges"]
      ],
      "padding_mask": padding_mask,
      "seq_label_weights": {
          kind: move(weights) for kind, weights in batch["seq_label_weights"].items()
      },
      "polarity_weights": move(batch["polarity_weights"]),
      "sentim_nodes_polarities": move(batch["sentim_nodes_polarities"]),
  }

  return x, y

In [None]:
def train_output_transform(
    x: Dict[str, Any],
    y: Dict[str, Any],
    y_pred: StructuredSentimentPredictor.Output,
    loss: Tensor
    ) -> Dict[str, Any]:
  """Build the trainer's per-iteration output.

  Returns:
    A dict with the targets ("y"), the raw model output ("y_pred_raw") and the
    loss as a Python number ("loss"). `x` is not used.
  """
  loss_value = loss.item()
  step_output = {"y": y, "y_pred_raw": y_pred, "loss": loss_value}
  return step_output

In [None]:
# Submit issue on Ignite design about this
def evaluate_output_transform(
    x: Dict[str, Any],
    y: Dict[str, Any],
    y_pred: StructuredSentimentPredictor.Output,
    model: nn.Module
    ) -> Dict[str, Any]:
  """Build an evaluator's per-iteration output.

  Besides the targets ("y") and the raw model output ("y_pred_raw"), the
  decoded graphs from `model.predict(x, y_pred)` are stored under "y_pred".
  `model` is expected to be bound with `functools.partial`.
  """
  decoded_graphs = model.predict(x, y_pred)
  step_output = {"y": y, "y_pred_raw": y_pred, "y_pred": decoded_graphs}
  return step_output

In [None]:
# Training engine: optimises `loss_fn` (bound to CONFIG) and emits the dict
# built by `train_output_transform` after every iteration.
trainer = create_supervised_trainer(
    model, optimizer, functools.partial(loss_fn, config=CONFIG),
    deterministic=True, device=DEVICE,
    prepare_batch=prepare_batch, output_transform=train_output_transform
    )

# Evaluation engine; `evaluate()` in the Run section runs it on `dev_dataloader`.
# Its output additionally carries the decoded graphs from `model.predict`.
evaluator = create_supervised_evaluator(
    model, prepare_batch=prepare_batch,
    device=DEVICE, output_transform=functools.partial(evaluate_output_transform, model=model)
    )

# Second, identically configured evaluation engine; `evaluate()` runs it on
# `train_dataloader`.
train_evaluator = create_supervised_evaluator(
    model, prepare_batch=prepare_batch,
    device=DEVICE, output_transform=functools.partial(evaluate_output_transform, model=model)
    )

## Metrics

In [None]:
class SeqLabelEntropy(Metric):
  """Mean entropy of the node extractor's per-token label distributions.

  One running mean is kept for each label kind ('sentiment', 'target',
  'holder'). Padding positions are excluded from both the entropy sum and the
  token count, so the value does not depend on how much padding a batch has.
  """

  _LABEL_KINDS = ('sentiment', 'target', 'holder')

  def __init__(self, output_transform=lambda x: x):
    self.token_cnt = None
    self.entropies = None

    super(SeqLabelEntropy, self).__init__(output_transform=output_transform)


  def reset(self):
    self.token_cnt = {kind: 0 for kind in self._LABEL_KINDS}
    self.entropies = {kind: 0. for kind in self._LABEL_KINDS}
    super(SeqLabelEntropy, self).reset()


  def update(self, output: "Tuple[StructuredSentimentPredictor.Output, Dict[str, Any]]"):
    """Accumulate entropies for one batch.

    Args:
      output: `(model_output, y)`; `model_output.node_extractor_output` holds
        the sentiment, target and holder scores (softmax is taken over dim 2)
        and `y["padding_mask"]` marks real tokens with a non-zero value.
    """
    padding_mask = output[1]["padding_mask"].detach().cpu().numpy()
    # Count only real tokens: the entropy sum below is masked the same way, so
    # dividing by the padded size would bias the mean towards zero.
    real_token_cnt = int(padding_mask.sum())

    for kind, scores in zip(self._LABEL_KINDS, output[0].node_extractor_output):
      # Hand SciPy a NumPy array rather than a torch tensor.
      probs = scores.detach().cpu().softmax(dim=2).numpy()
      self.entropies[kind] += (scipy.stats.entropy(probs, axis=2) * padding_mask).sum()
      self.token_cnt[kind] += real_token_cnt


  def compute(self) -> Dict[str, float]:
    """Return `{"<Kind> Label Entropy": mean entropy per real token}`."""
    return {(k.capitalize() + " Label Entropy"): v / self.token_cnt[k] for k, v \
            in self.entropies.items()}


In [None]:
class MatchLogicBase:
  """Interface for deciding whether, and how strongly, a predicted sentiment
  tuple matches a gold one. Subclasses implement both methods."""

  def is_match(self, sentim_overlap: int, target_overlap: int, holder_overlap: int,
               is_polarity_matched: bool) -> bool:
    """Tell whether the given overlaps / polarity agreement count as a match."""
    ...


  def weighted_match(self,
                     sentim_nodes: List[SentimentGraphNode],
                     target_neighbors: List[SentimentGraphNode],
                     holder_neighbors: List[SentimentGraphNode],
                     sentim_overlap: int, target_overlap: int, holder_overlap: int,
                     ) -> float:
    """Score a match from the overlaps relative to the given nodes' sizes."""
    ...


  def _union_length(self, nodes: List[SentimentGraphNode]) -> int:
    """Number of distinct indices covered by `nodes`; 1 for an empty list."""
    if len(nodes) == 0:
      # Empty lists get length 1 so callers can divide by the result.
      return 1

    covered_indices = set()
    for node in nodes:
      covered_indices.update(node.indices)
    return len(covered_indices)

In [None]:
class GraphMatchLogic(MatchLogicBase):
  """Match logic over the whole tuple: expression, target and holder."""

  def __init__(self, keep_polarity: bool=True):
    # When false, polarity agreement is not required for a match.
    self.keep_polarity = keep_polarity


  def is_match(self, sentim_overlap: int, target_overlap: int, holder_overlap: int,
               is_polarity_matched: bool) -> bool:
    """All three overlaps must be positive and, if polarity is kept, it must agree."""
    if not sentim_overlap > 0:
      return False
    if not target_overlap > 0:
      return False
    if not holder_overlap > 0:
      return False
    return is_polarity_matched or not self.keep_polarity


  def weighted_match(self,
                     sentim_nodes: List[SentimentGraphNode],
                     target_neighbors: List[SentimentGraphNode],
                     holder_neighbors: List[SentimentGraphNode],
                     sentim_overlap: int, target_overlap: int, holder_overlap: int,
                     ) -> float:
    """Average of the three overlap / union-length ratios."""
    sentim_ratio = sentim_overlap / float(self._union_length(sentim_nodes))
    target_ratio = target_overlap / float(self._union_length(target_neighbors))
    holder_ratio = holder_overlap / float(self._union_length(holder_neighbors))
    return (sentim_ratio + target_ratio + holder_ratio) / 3.

In [None]:
class SentimentMatchLogic(MatchLogicBase):
  """Match logic that looks at the sentiment expression (and polarity) only."""

  def __init__(self, keep_polarity: bool=True):
    # When false, polarity agreement is not required for a match.
    self.keep_polarity = keep_polarity


  def is_match(self, sentim_overlap: int, target_overlap: int, holder_overlap: int,
               is_polarity_matched: bool) -> bool:
    """Expression overlap must be positive and, if polarity is kept, it must agree."""
    if not sentim_overlap > 0:
      return False
    return is_polarity_matched or not self.keep_polarity


  def weighted_match(self,
                     sentim_nodes: List[SentimentGraphNode],
                     target_neighbors: List[SentimentGraphNode],
                     holder_neighbors: List[SentimentGraphNode],
                     sentim_overlap: int, target_overlap: int, holder_overlap: int,
                     ) -> float:
    """Expression overlap divided by the union length of `sentim_nodes`."""
    sentim_span = float(self._union_length(sentim_nodes))
    return sentim_overlap / sentim_span

In [None]:
class TargetMatchLogic(MatchLogicBase):
  """Match logic that looks at the target overlap only."""

  def is_match(self, sentim_overlap: int, target_overlap: int, holder_overlap: int,
               is_polarity_matched: bool) -> bool:
    """A match is any positive target overlap; the other arguments are ignored."""
    has_target_overlap = target_overlap > 0
    return has_target_overlap


  def weighted_match(self,
                     sentim_nodes: List[SentimentGraphNode],
                     target_neighbors: List[SentimentGraphNode],
                     holder_neighbors: List[SentimentGraphNode],
                     sentim_overlap: int, target_overlap: int, holder_overlap: int,
                     ) -> float:
    """Target overlap divided by the union length of `target_neighbors`."""
    target_span = float(self._union_length(target_neighbors))
    return target_overlap / target_span

In [None]:
class HolderMatchLogic(MatchLogicBase):
  """Match logic that looks at the holder overlap only."""

  def is_match(self, sentim_overlap: int, target_overlap: int, holder_overlap: int,
               is_polarity_matched: bool) -> bool:
    """A match is any positive holder overlap; the other arguments are ignored."""
    has_holder_overlap = holder_overlap > 0
    return has_holder_overlap


  def weighted_match(self,
                     sentim_nodes: List[SentimentGraphNode],
                     target_neighbors: List[SentimentGraphNode],
                     holder_neighbors: List[SentimentGraphNode],
                     sentim_overlap: int, target_overlap: int, holder_overlap: int,
                     ) -> float:
    """Holder overlap divided by the union length of `holder_neighbors`."""
    holder_span = float(self._union_length(holder_neighbors))
    return holder_overlap / holder_span

In [None]:
class MetricBase(Metric):
  """Shared matching machinery for the graph-level precision/recall metrics.

  For every (predicted graph, gold graph) pair, `update` builds a binary
  match matrix plus two weighted match matrices and forwards them to
  `update_`, which subclasses implement.
  """

  def __init__(self, match_logic: MatchLogicBase, weighted=True, output_transform=lambda x:x):
    # Set before the base initialiser runs.
    self.weighted = weighted
    self.match_logic = match_logic

    super().__init__(output_transform=output_transform)


  def extract_sentim_nodes_neighbors(
    self, graph: SentimentGraph) -> List[Tuple[List[SentimentGraphNode], ...]]:
    """For each sentiment node of `graph`, list its target and holder neighbours.

    Row i of an edge matrix belongs to sentiment node i; its non-zero columns
    index into the graph's target (resp. holder) nodes. A missing (`None`)
    edge matrix yields an empty neighbour list.

    Returns:
      One `(target_neighbors, holder_neighbors)` tuple per sentiment node.
    """
    edges_sentim_tgt = graph.edges.sentiment_target_edges
    edges_sentim_hld = graph.edges.sentiment_holder_edges
    if edges_sentim_tgt is not None:
      edges_sentim_tgt = edges_sentim_tgt.detach().cpu()
    if edges_sentim_hld is not None:
      edges_sentim_hld = edges_sentim_hld.detach().cpu()

    sentim_nodes_neighbors = []
    for i, sentim_node in enumerate(graph.nodes.sentiment_nodes):
      assert sentim_node.tag is not None

      target_neighbors = [] if edges_sentim_tgt is None else [
          graph.nodes.target_nodes[j] for j in edges_sentim_tgt[i].nonzero()]
      holder_neighbors = [] if edges_sentim_hld is None else [
          graph.nodes.holder_nodes[j] for j in edges_sentim_hld[i].nonzero()]

      sentim_nodes_neighbors.append((target_neighbors, holder_neighbors))

    return sentim_nodes_neighbors


  def node_neighbors_overlap_length(self, nodes1: List[SentimentGraphNode],
                      nodes2: List[SentimentGraphNode]) -> int:
    """Number of indices shared by the two node lists; 1 if both are empty."""
    if len(nodes1) + len(nodes2) == 0:
      return 1

    indices1 = {index for node in nodes1 for index in node.indices}
    indices2 = {index for node in nodes2 for index in node.indices}
    return len(indices1 & indices2)


  def update(self, output: Tuple[List[SentimentGraph], List[GoldSentimentGraph]]):
    """Compare predicted and gold graphs pairwise and call `update_` per pair.

    `update_` receives the binary match matrix (predicted x gold), the
    weighted matrix scored against the gold side (gold x predicted) and the
    weighted matrix scored against the predicted side (predicted x gold).
    The weighted matrices stay all-zero when `self.weighted` is false.
    """
    y_pred, y = output

    for p_graph, g_graph in zip(y_pred, y):

      p_nodes_neighbors = self.extract_sentim_nodes_neighbors(p_graph)

      p_sentim_nodes = p_graph.nodes.sentiment_nodes
      g_sentim_nodes = g_graph.nodes.sentiment_nodes
      n_pred, n_gold = len(p_sentim_nodes), len(g_sentim_nodes)

      intersect_matrix = torch.zeros((n_pred, n_gold), dtype=int)
      weighted_intersect_matrix_for_p = torch.zeros(n_pred, n_gold)
      weighted_intersect_matrix_for_g = torch.zeros(n_gold, n_pred)

      node_pairs = itertools.product(enumerate(p_sentim_nodes), enumerate(g_sentim_nodes))
      for (i, p_sentim_node), (j, g_sentim_node) in node_pairs:

        p_target_neighbors, p_holder_neighbors = p_nodes_neighbors[i]

        # Gold edges are looked up by the gold sentiment node itself.
        g_target_neighbors = g_graph.edges.sentiment_target_edges[g_sentim_node].nodes
        g_holder_neighbors = g_graph.edges.sentiment_holder_edges[g_sentim_node].nodes

        sentim_overlap = self.node_neighbors_overlap_length([p_sentim_node], g_sentim_node.nodes)
        target_overlap = self.node_neighbors_overlap_length(p_target_neighbors, g_target_neighbors)
        holder_overlap = self.node_neighbors_overlap_length(p_holder_neighbors, g_holder_neighbors)

        polarity_matched = g_sentim_node.nodes[0].tag == p_sentim_node.tag
        if not self.match_logic.is_match(sentim_overlap, target_overlap, holder_overlap,
                                         polarity_matched):
          continue

        intersect_matrix[i, j] = 1

        if not self.weighted:
          continue

        weighted_intersect_matrix_for_p[i, j] = \
          self.match_logic.weighted_match([p_sentim_node], p_target_neighbors,
                                          p_holder_neighbors,
                                          sentim_overlap, target_overlap, holder_overlap)

        weighted_intersect_matrix_for_g[j, i] = \
          self.match_logic.weighted_match(g_sentim_node.nodes, g_target_neighbors,
                                          g_holder_neighbors,
                                          sentim_overlap, target_overlap, holder_overlap)

      self.update_(intersect_matrix, weighted_intersect_matrix_for_g,
                   weighted_intersect_matrix_for_p)


In [None]:
class Precision(MetricBase):
  """Precision over predicted sentiment nodes.

  A predicted node is a true positive if it matches at least one gold node
  and a false positive otherwise. In weighted mode the numerator is the sum,
  over predicted nodes, of their best weighted match score.
  """

  def __init__(self, match_logic: MatchLogicBase, weighted=True, output_transform=lambda x:x):
    self.weighted_tp = None
    self.tp = None
    self.fp = None

    super(Precision, self).__init__(match_logic, weighted=weighted,
                                    output_transform=output_transform)


  def reset(self):
    self.weighted_tp = 0.
    self.tp = 0
    self.fp = 0

    super(Precision, self).reset()


  def update_(self, intersect_matrix: Tensor, weighted_intersect_matrix_for_g: Tensor,
               weighted_intersect_matrix_for_p: Tensor):
    """Accumulate counts for one (predicted, gold) graph pair.

    Args:
      intersect_matrix: binary match matrix, predicted rows x gold columns.
      weighted_intersect_matrix_for_g: unused here.
      weighted_intersect_matrix_for_p: weighted scores, predicted rows x gold
        columns.
    """
    predicted_is_matched = intersect_matrix.any(dim=1)
    self.tp += predicted_is_matched.sum().item()
    self.fp += (~ predicted_is_matched).sum().item()

    # With no gold nodes there is nothing to take a row-wise max over, and
    # `max` over a zero-size dimension raises; such a pair contributes no
    # weighted true positives. Recall guards the symmetric case as well.
    if self.weighted and weighted_intersect_matrix_for_p.size(1) > 0:
      self.weighted_tp += weighted_intersect_matrix_for_p.max(dim=1).values.sum().item()


  def compute(self) -> float:
    """Return (weighted) true positives over all predicted nodes."""
    # The small epsilon avoids a division by zero when nothing was predicted.
    if self.weighted:
      return self.weighted_tp / (self.tp + self.fp + 1e-10)
    else:
      return self.tp / (self.tp + self.fp + 1e-10)

In [None]:
class Recall(MetricBase):
  """Recall over gold sentiment nodes.

  A gold node is a true positive if at least one predicted node matches it
  and a false negative otherwise. In weighted mode the numerator is the sum,
  over gold nodes, of their best weighted match score.
  """

  def __init__(self, match_logic: MatchLogicBase, weighted=True, output_transform=lambda x:x):
    self.weighted_tp = None
    self.tp = None
    self.fn = None

    super(Recall, self).__init__(match_logic, weighted=weighted,
                                 output_transform=output_transform)


  def reset(self):
    self.weighted_tp = 0.
    self.tp = 0
    self.fn = 0

    super(Recall, self).reset()


  def update_(self, intersect_matrix: Tensor, weighted_intersect_matrix_for_g: Tensor,
               weighted_intersect_matrix_for_p: Tensor):
    """Accumulate counts for one (predicted, gold) graph pair.

    Args:
      intersect_matrix: binary match matrix, predicted rows x gold columns.
      weighted_intersect_matrix_for_g: weighted scores, gold rows x predicted
        columns.
      weighted_intersect_matrix_for_p: unused here.
    """
    gold_is_matched = intersect_matrix.any(dim=0)
    self.tp += gold_is_matched.sum().item()
    self.fn += (~ gold_is_matched).sum().item()

    # With no predicted nodes there is nothing to take a row-wise max over
    # (`max` over a zero-size dimension raises); check for that explicitly
    # instead of swallowing the resulting IndexError.
    if self.weighted and weighted_intersect_matrix_for_g.size(1) > 0:
      self.weighted_tp += weighted_intersect_matrix_for_g.max(dim=1).values.sum().item()


  def compute(self) -> float:
    """Return (weighted) true positives over all gold nodes."""
    # The small epsilon avoids a division by zero when there were no gold nodes.
    if self.weighted:
      return self.weighted_tp / (self.tp + self.fn + 1e-10)
    else:
      return self.tp / (self.tp + self.fn + 1e-10)

In [None]:
def F1(match_logic: MatchLogicBase, weighted=True, output_transform=lambda x: x) -> MetricBase:
  """Compose an F1 metric from a fresh `Precision` and `Recall`.

  Both metrics share `match_logic`, `weighted` and `output_transform`. The
  result is the arithmetic combination 2PR / (P + R + 1e-10) of the two
  metric objects, the epsilon guarding against a zero denominator.
  """
  precision = Precision(match_logic, weighted=weighted, output_transform=output_transform)
  recall = Recall(match_logic, weighted=weighted, output_transform=output_transform)

  numerator = 2 * precision * recall
  denominator = precision + recall + 1e-10
  return numerator / denominator

In [None]:
@torch.no_grad()
def loss_metric(input, config: Dict[str, Any]):
  """Recompute the total loss from an evaluator output dict.

  Reads the raw model output (`"y_pred_raw"`) and the targets (`"y"`) from
  `input` and returns `loss_fn`'s value as a Python number.
  """
  batch_loss = loss_fn(input["y_pred_raw"], input["y"], config)
  return batch_loss.item()

In [None]:
@torch.no_grad()
def acc_output_transform(
    inputs: Dict[str, Any]
) -> Tuple[List[SentimentGraph], List[GoldSentimentGraph]]:
  """Pick (decoded graphs, gold graphs) out of an evaluator output dict."""
  predicted_graphs = inputs['y_pred']
  gold_graphs = inputs['y']['gold_graphs']
  return predicted_graphs, gold_graphs

In [None]:
def loss_output_transform(inputs):
  """Pick the scalar loss out of the trainer's output dict."""
  loss_value = inputs["loss"]
  return loss_value

## Attach Metrics

In [None]:
# Training loss: average of the per-iteration loss reported by the trainer.
train_loss = Average(output_transform=loss_output_transform)
# Evaluation loss: the evaluators do not compute a loss themselves, so it is
# recomputed from the raw model output by `loss_metric`.
eval_loss = Average(output_transform=functools.partial(loss_metric, config=CONFIG))
# Running average of the training loss (shown in the progress bar).
running_loss = RunningAverage(output_transform=loss_output_transform)

In [None]:
# One F1 metric per match logic, all fed with (decoded graphs, gold graphs).
# Each uses the default weighted=True; the graph and sentiment logics also keep
# their default keep_polarity=True.
graph_f1 = F1(GraphMatchLogic(), output_transform=acc_output_transform)

sentiment_f1 = F1(SentimentMatchLogic(), output_transform=acc_output_transform)

target_f1 = F1(TargetMatchLogic(), output_transform=acc_output_transform)

holder_f1 = F1(HolderMatchLogic(), output_transform=acc_output_transform)

In [None]:
# SeqLabelEntropy expects (raw model output, targets) from the evaluator output dict.
label_entropy = SeqLabelEntropy(output_transform=lambda out: (out['y_pred_raw'], out['y']))

In [None]:
# Trainer: averaged training loss.
train_loss.attach(trainer, 'Loss')

# Note: the same four F1 metric objects are attached to both evaluators.
graph_f1.attach(train_evaluator, 'Graph F1')
sentiment_f1.attach(train_evaluator, 'Sentiment F1')
target_f1.attach(train_evaluator, 'Target F1')
holder_f1.attach(train_evaluator, 'Holder F1')


# Dev evaluator: loss plus the four F1 scores.
eval_loss.attach(evaluator, 'Loss')

graph_f1.attach(evaluator, 'Graph F1')
sentiment_f1.attach(evaluator, 'Sentiment F1')
target_f1.attach(evaluator, 'Target F1')
holder_f1.attach(evaluator, 'Holder F1')


# Trainer: running loss, displayed by the progress bar.
running_loss.attach(trainer, 'Running Loss')

In [None]:
# The same SeqLabelEntropy object is attached to both evaluators.
label_entropy.attach(train_evaluator, 'Label Entropy')
label_entropy.attach(evaluator, 'Label Entropy')

## Safety Measure (Terminate on NaN)

In [None]:
# After every training iteration, let TerminateOnNan inspect the 'loss' entry
# of the trainer output.
trainer.add_event_handler(Events.ITERATION_COMPLETED,
                          TerminateOnNan(output_transform=itemgetter('loss')))

<ignite.engine.events.RemovableEventHandle at 0x7faabae82710>

## Logging (WandB & Tqdm)

In [None]:
note = 'SemEval2022 Task 10 Subtask 1'
# NOTE(review): `resume=True` together with the fixed run name 'test' — confirm
# that continuing a previous run is intended.
wandb_logger = WandBLogger(entity='sadra-barikbin',
                           project='SSA-Attentionist',
                           name='test',
                           tags=[],
                           notes=note, resume=True)

# Log optimizer parameters at the end of every training epoch.
wandb_logger.attach_opt_params_handler(trainer, event_name=Events.EPOCH_COMPLETED, 
                                       optimizer=optimizer)

# Log the averaged training loss at the end of every training epoch.
wandb_logger.attach_output_handler(trainer, event_name=Events.EPOCH_COMPLETED,
                                   tag="training", metric_names=["Loss"])

metric_names = ["Graph F1", "Sentiment F1", "Target F1", "Holder F1"]

# Evaluator metrics are logged against the trainer's step counter so that
# training and evaluation curves share one x-axis.
wandb_logger.attach_output_handler(train_evaluator, event_name=Events.COMPLETED,
                                   global_step_transform=global_step_from_engine(trainer),
                                   tag="training", metric_names=metric_names)

wandb_logger.attach_output_handler(evaluator, event_name=Events.COMPLETED,
                                   global_step_transform=global_step_from_engine(trainer),
                                   tag="evaluation", metric_names=["Loss"] + metric_names)

[34m[1mwandb[0m: Currently logged in as: [33msadra-barikbin[0m (use `wandb login --relogin` to force relogin)


<ignite.engine.events.RemovableEventHandle at 0x7faab8d61c10>

In [None]:
@trainer.on(Events.COMPLETED)
def close_logger():
  """Close the W&B logger once the trainer has completed."""
  wandb_logger.close()

In [None]:
# Progress bar over training iterations, showing the running loss.
pbar = ProgressBar()
pbar.attach(trainer,metric_names=['Running Loss'])

## LR Scheduling

### StepLR

In [None]:
# StepLR with step_size=9 and gamma=0.1, stepped at the end of every training epoch.
lr_scheduler = LRScheduler(StepLR(optimizer, 9, gamma=0.1), save_history=True)
scheduler_handler = trainer.add_event_handler(
    Events.EPOCH_COMPLETED, lr_scheduler)

### Warm-Up

In [None]:
# Linear warm-up of the backbone group (param group 0) from 1e-6 to 1e-4
# between event 1 and event `len(train_ds) // batch_size`.
base_scheduler = PiecewiseLinear(
    optimizer, "lr",
    milestones_values=[(1, 1e-6), (len(train_ds) // CONFIG["batch_size"], 1e-4)],
    param_group_index=0
)

# Same warm-up for the novelty group (param group 1), from 1e-5 to 1e-3.
novelty_scheduler = PiecewiseLinear(
    optimizer, "lr",
    milestones_values=[(1, 1e-5), (len(train_ds) // CONFIG["batch_size"], 1e-3)],
    param_group_index=1
)

# Run the warm-up schedulers only while the iteration counter is at most
# `len(train_ds) // batch_size` (presumably the first epoch — the floor
# division drops a final partial batch, TODO confirm).
event_filter = lambda _, event: event < (len(train_ds) // CONFIG["batch_size"]) + 1
scheduler1_handler = trainer.add_event_handler(Events.ITERATION_STARTED(event_filter),
                                               base_scheduler)
scheduler2_handler = trainer.add_event_handler(Events.ITERATION_STARTED(event_filter),
                                               novelty_scheduler)

In [None]:
# Record the schedule in the run config; keep this string in sync with the
# scheduler cells above.
wandb_logger.config['LR_scheduler'] = 'Linear Warm-Up + StepLR(9,0.1)'

## Early Stopping

In [None]:
# Early stopping (patience=6) scored by the dev evaluator's 'Graph F1'; the
# handler runs when the dev evaluator completes and acts on `trainer`.
stopper = EarlyStopping(patience=6, score_function=lambda engine: engine.state.metrics['Graph F1'],
                        trainer=trainer)
stopper_handle = evaluator.add_event_handler(Events.COMPLETED, stopper)

## Checkpointing

In [None]:
# Checkpoints are written into the W&B run directory (alternative: /tmp/run).
# n_saved=2 limits how many checkpoints are kept.
# NOTE(review): `score_name` is given without a `score_function` — confirm the
# installed Ignite version ranks checkpoints by the 'Graph F1' metric in that case.
checkpointer = ModelCheckpoint(wandb_logger.run.dir, 'Attentionist',
                               score_name='Graph F1', n_saved=2, require_empty=False)
# Open question: how to checkpoint only every 2 evaluator epochs?
evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model})

<ignite.engine.events.RemovableEventHandle at 0x7f7a5b266190>

## Run

In [None]:
@trainer.on(Events.EPOCH_COMPLETED)
def evaluate():
  """After every training epoch, run the dev evaluator and then the train evaluator."""
  evaluator.run(dev_dataloader)
  train_evaluator.run(train_dataloader)

In [None]:
# Log the random seed with the run, then train for at most 20 epochs.
wandb_logger.config['random_seed'] = CONFIG["random_seed"]
trainer.run(train_dataloader,max_epochs=20)

<a name="inference"></a>
# Inference

## Load Best Model

In [None]:
# Selection of the run to load for inference.
# NOTE(review): the two `%%...%%` values look like template placeholders that are
# substituted before the notebook is executed — confirm.
variant_name = "%%VARIANT_NAME%%"
dataset_name = "%%DATASET_NAME%%"
project_name = 'SSA-Attentionist'

In [None]:
# Query the W&B project for the runs tagged with the selected dataset, ordered by
# their summary 'evaluation/Graph F1'.
api = wandb.Api()
runs = api.runs(
    f'sadra-barikbin/{project_name}',
    order='summary_metrics.evaluation/Graph F1',
    filters={'tags': DATASET_NAMES_MAP[dataset_name]}
    # As per the MongoDB query language, with this form the order of the array's
    # elements matters. To ignore the order:
    # filters={'tags': {'$all': [DATASET_NAMES_MAP[dataset_name], variant_name]} }
)
# NOTE(review): the checkpoint is picked by position (second file of the first
# run), which depends on the run's file listing order — verify that this is
# still the intended checkpoint.
# `replace=True` lets this cell be re-run: without it `download` raises when the
# file already exists locally.
model_path = list(list(runs)[0].files())[1].download(replace=True)
# `download` returns an open file object; only its name is needed.
model_path.close()
# Load onto the CPU so that the checkpoint can be read on machines without the
# device it was saved from; `load_state_dict` copies the values into the
# model's own parameters afterwards.
state_dict = torch.load(model_path.name, map_location='cpu')

In [None]:
# Build a fresh model and restore the downloaded weights into it.
model = StructuredSentimentPredictor()
model.load_state_dict(state_dict)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.decoder.weight', 'roberta.pooler.dense.bias', 'lm_head.layer_norm.weight', 'roberta.pooler.dense.weight', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>

In [None]:
# Tokenizer matching the configured dataset; used as a global by the inference
# helpers below.
tokenizer = get_tokenizer(CONFIG["dataset"])

## Dataset, DataLoader and Engine

In [None]:
def test_collate_fn(batch: List[Dict[str, Any]], device: str='cuda') -> Dict[str, Any]:
  """Collate raw test items into model inputs.

  Uses the global `tokenizer`.

  Args:
    batch: dataset items; each must have a 'text' entry.
    device: device the produced tensors are moved to.

  Returns:
    A dict with
      "pred_template": the untouched `batch` list,
      "ptm_input": the padded tokenizer output (with special tokens) whose
        'input_ids', 'attention_mask' and, if present, 'token_type_ids'
        tensors live on `device`,
      "padding_mask": the attention mask of the same texts tokenized WITHOUT
        special tokens, on `device`.
  """
  texts = [item['text'] for item in batch]

  ptm_input = tokenizer(texts, padding=True, return_tensors='pt')
  tensor_keys = ['input_ids', 'attention_mask']
  if 'token_type_ids' in ptm_input:
    # Not every tokenizer emits segment ids.
    tensor_keys.append('token_type_ids')
  for key in tensor_keys:
    ptm_input[key] = ptm_input[key].to(device)

  bare_encoding = tokenizer(texts, padding=True, return_tensors='pt',
                            add_special_tokens=False)
  padding_mask = bare_encoding["attention_mask"].to(device)

  return dict(pred_template=batch, ptm_input=ptm_input, padding_mask=padding_mask)


In [None]:
# Test split and its loader; evaluation uses twice the training batch size.
# NOTE(review): the dataset is constructed here as (dataset_name, split) whereas
# the ablation loop calls SemEval2022Task10Dataset('train', config), i.e.
# (split, config) — confirm which order matches the class signature.
test_ds = SemEval2022Task10Dataset(CONFIG["dataset"], 'test')
test_dataloader = DataLoader(test_ds, batch_size=CONFIG["batch_size"] * 2,
                             collate_fn=test_collate_fn)

In [None]:
def _char_span_of(text, token_char_offsets, token_span):
  """Convert a token-level span into its surface string and a "begin:end" char span.

  Args:
    text: the sentence the span belongs to.
    token_char_offsets: per-token (char_begin, char_end) pairs of the sentence.
    token_span: index into `token_char_offsets` selecting the span's tokens
      (assumed to be a non-empty slice — TODO confirm).

  Returns:
    (expression, "begin:end") where begin is the first selected token's start
    offset and end is the last selected token's end offset.
  """
  span_offsets = token_char_offsets[token_span]
  char_begin = span_offsets[0][0]
  char_end = span_offsets[-1][1]
  return text[char_begin:char_end], f"{char_begin}:{char_end}"


def _linked_node_spans(text, token_char_offsets, adjacency, sentiment_idx, nodes):
  """Collect expressions and char spans of the nodes linked to one sentiment node.

  `adjacency` may be None (no edges of this kind), in which case both lists are
  empty. Otherwise the non-zero entries of row `sentiment_idx` index into `nodes`.

  Returns:
    [expressions, char_spans], the layout used in the prediction JSON.
  """
  expressions, char_spans = [], []
  if adjacency is not None:
    for node_idx in adjacency[sentiment_idx].nonzero():
      expression, char_span = _char_span_of(
          text, token_char_offsets, nodes[node_idx].span_in_sentence)
      expressions.append(expression)
      char_spans.append(char_span)
  return [expressions, char_spans]


@torch.no_grad()
def test_engine_process_function(engine: Engine, batch: Dict[str, Any]) -> List[Dict[str, Any]]:
  """Ignite process function that predicts sentiment graphs for one test batch.

  Uses the global `model` and `tokenizer`. The batch's 'pred_template' items
  are updated IN PLACE: each gets an 'opinions' list with one entry per
  predicted sentiment node, holding 'Polar_expression', 'Target', 'Source'
  (each `[expressions, "begin:end" char spans]`) and 'Polarity'.

  Args:
    engine: the running engine (unused).
    batch: output of `test_collate_fn`.

  Returns:
    The (mutated) list of template items of the batch.
  """
  model.eval()

  base_output = model.base(batch["ptm_input"])
  base_output = model.novelty['base_pooler'](base_output)
  node_extractor_out = model.novelty['node_extractor'](base_output)
  output: List[SentimentGraph] = model.predict(
      batch, StructuredSentimentPredictor.Output(node_extractor_out, None)
  )

  pred_template = batch['pred_template']

  # Character offsets of every token, tokenized without special tokens so that
  # they can be indexed with the nodes' `span_in_sentence`.
  tokens_char_offsets_batch = tokenizer([item['text'] for item in pred_template],
                                        add_special_tokens=False,
                                        return_offsets_mapping=True)['offset_mapping']

  for sent_idx, graph in enumerate(output):
    sentim_nodes = graph.nodes.sentiment_nodes
    tgt_nodes = graph.nodes.target_nodes
    hld_nodes = graph.nodes.holder_nodes
    edges = graph.edges

    text = pred_template[sent_idx]['text']
    token_char_offsets = tokens_char_offsets_batch[sent_idx]

    opinions = []
    for i, sentim_node in enumerate(sentim_nodes):
      polar_expr, polar_char_span = _char_span_of(
          text, token_char_offsets, sentim_node.span_in_sentence)

      opinions.append({
          'Polar_expression': [[polar_expr], [polar_char_span]],
          'Target': _linked_node_spans(
              text, token_char_offsets, edges.sentiment_target_edges, i, tgt_nodes),
          'Source': _linked_node_spans(
              text, token_char_offsets, edges.sentiment_holder_edges, i, hld_nodes),
          'Polarity': sentim_node.tag,
      })

    pred_template[sent_idx]['opinions'] = opinions

  return pred_template


In [None]:
# Inference engine: one call of the process function per test batch.
tester = Engine(test_engine_process_function)

In [None]:
# Collect the per-batch outputs of the epoch; they are read back later from
# `tester.state.outputs`.
output_store = EpochOutputStore()
output_store.attach(tester, 'outputs')

In [None]:
# Progress bar for the inference engine (rebinds the `pbar` name used for training).
pbar = ProgressBar()
pbar.attach(tester)

In [None]:
# Run inference over the whole test set (a single epoch).
tester.run(test_dataloader)

[1/39]   3%|2          [00:00<?]

State:
	iteration: 39
	epoch: 1
	epoch_length: 39
	max_epochs: 1
	output: <class 'list'>
	batch: <class 'dict'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>
	outputs: <class 'list'>

## Save Predictions

In [None]:
# Write the predictions of the current dataset to
# monolingual/<dataset>/predictions.json.
path = f"monolingual/{CONFIG['dataset']}"
# `exist_ok=True` so that re-running this cell (or saving again for the same
# dataset) overwrites the predictions instead of raising FileExistsError.
os.makedirs(path, exist_ok=True)

with open(f"{path}/predictions.json", 'w') as f:
  # `tester.state.outputs` holds one list of items per batch; flatten them.
  json.dump(list(itertools.chain.from_iterable(tester.state.outputs)), f)

In [None]:
# Bundle the predictions of all datasets saved so far into a single archive.
! zip -r monolingual.zip monolingual

  adding: monolingual/ (stored 0%)
  adding: monolingual/opener_es/ (stored 0%)
  adding: monolingual/opener_es/predictions.json (deflated 83%)
  adding: monolingual/norec/ (stored 0%)
  adding: monolingual/norec/predictions.json (deflated 80%)
  adding: monolingual/multibooked_eu/ (stored 0%)
  adding: monolingual/multibooked_eu/predictions.json (deflated 84%)
  adding: monolingual/multibooked_ca/ (stored 0%)
  adding: monolingual/multibooked_ca/predictions.json (deflated 83%)


## Out of Dataset Example

In [None]:
@torch.no_grad()
def predict(text: Union[List[str], str]) -> SentimentGraph:
  """Predict the sentiment graph of an arbitrary (out-of-dataset) text.

  Uses the global `model`, `tokenizer` and `DEVICE`.

  Args:
    text: a single sentence, or a list of sentences.

  Returns:
    The predicted graph of the FIRST sentence only, even when a list is given.
  """
  if isinstance(text, str):
    text = [text]

  ptm_input = tokenizer(text, padding=True, return_tensors='pt')
  # Move every tensor the base model may consume; `token_type_ids` is only
  # produced by some tokenizers (same handling as in `test_collate_fn`).
  for key in ('input_ids', 'attention_mask', 'token_type_ids'):
    if key in ptm_input:
      ptm_input[key] = ptm_input[key].to(DEVICE)

  # Attention mask of the text tokenized without special tokens.
  padding_mask = tokenizer(text, padding=True, return_tensors='pt',
                           add_special_tokens=False)["attention_mask"].to(DEVICE)
  inputs = {'ptm_input': ptm_input, 'padding_mask': padding_mask}

  model.eval()
  base_output = model.base(inputs["ptm_input"])
  # Same pipeline as `test_engine_process_function`: pool the base output before
  # node extraction. Guarded so that a model without this stage still works.
  if 'base_pooler' in model.novelty:
    base_output = model.novelty['base_pooler'](base_output)
  node_extractor_out = model.novelty['node_extractor'](base_output)
  return model.predict(inputs, StructuredSentimentPredictor.Output(node_extractor_out, None))[0]

In [None]:
# Smoke test on a sentence that is not part of any dataset.
predict("I love this book!")

# Analysis

## Ablation Study

### Helper classes & functions

In [None]:
class ModelFactory:
  """Pairs a model class with a (possibly nested) configuration dictionary.

  Configuration entries are overridden in place with `With`, which can be
  chained, and the model is instantiated with `build`.
  """

  def __init__(self, model_class: Type[nn.Module], config: Dict[str, Any]):
    self.config = config
    self.model_class = model_class

  def With(self, key: str, value: Any):
    """Set one configuration entry and return `self` for chaining.

    `key` is a dot-separated path such as "edge_predictor.USE_BIAS". All
    components but the last must already exist in the configuration; the last
    one is created or overwritten. The configuration is mutated in place.
    """
    parent_path, _, leaf = key.rpartition('.')
    path_parts = parent_path.split('.')

    node = self.config
    # An empty first component (top-level key) means: write at the root.
    if path_parts[0] != '':
      for part in path_parts:
        node = node[part]

    node[leaf] = value
    return self

  def build(
      self,
      factory_function: Optional[Callable[[Type[nn.Module], Dict[str, Any]], nn.Module]] = None
  ) -> nn.Module:
    """Instantiate the model.

    With a `factory_function`, return `factory_function(model_class, config)`;
    otherwise the configuration is unpacked as keyword arguments of the class.
    """
    if factory_function is not None:
      return factory_function(self.model_class, self.config)
    return self.model_class(**self.config)


In [None]:
def factory_function(cls, config):
  """Instantiate `cls` with a private deep copy of `config` as its only argument."""
  config_snapshot = copy.deepcopy(config)
  return cls(config_snapshot)

In [None]:
def attach_handlers(trainer: Engine, evaluator: Engine):
  """Attach the run-independent handlers of an ablation run.

  * `trainer`: stop on a NaN loss (checked after every iteration) and show a
    progress bar with the 'Running Loss' metric.
  * `evaluator`: early stopping with patience 6 on its 'Graph F1' metric,
    checked whenever an evaluation run completes; it terminates `trainer`.
  """
  def graph_f1_score(engine):
    return engine.state.metrics['Graph F1']

  nan_guard = TerminateOnNan(output_transform=itemgetter('loss'))
  trainer.add_event_handler(Events.ITERATION_COMPLETED, nan_guard)

  ProgressBar().attach(trainer, metric_names=['Running Loss'])

  evaluator.add_event_handler(
      Events.COMPLETED,
      EarlyStopping(patience=6, score_function=graph_f1_score, trainer=trainer))

In [None]:
def create_engines(model, optimizer, config) -> Tuple[Engine, Engine, Engine]:
  """Create the training engine and two identically configured evaluators.

  The trainer is deterministic and uses `loss_fn` bound to `config`. Both
  evaluators transform their output with `evaluate_output_transform` bound to
  `model`.

  Returns:
    (trainer, evaluator, train_evaluator)
  """
  def make_evaluator():
    return create_supervised_evaluator(
        model,
        device=DEVICE,
        prepare_batch=prepare_batch,
        output_transform=functools.partial(evaluate_output_transform, model=model))

  trainer = create_supervised_trainer(
      model,
      optimizer,
      functools.partial(loss_fn, config=config),
      device=DEVICE,
      prepare_batch=prepare_batch,
      output_transform=train_output_transform,
      deterministic=True)

  evaluator = make_evaluator()
  train_evaluator = make_evaluator()

  return trainer, evaluator, train_evaluator

In [None]:
def create_and_attach_common_metrics(trainer, evaluator, train_evaluator, config):
  """Create the loss and F1 metrics and attach them to the three engines.

  * Both evaluators get 'Graph F1', 'Sentiment F1', 'NSentiment F1'
    (sentiment match with `keep_polarity=False`), 'Target F1' and 'Holder F1'.
    The same metric objects are attached to both engines.
  * `trainer` gets 'Loss' (average) and 'Running Loss' (running average).
  * `evaluator` additionally gets 'Loss', computed with `loss_metric` bound to
    `config`.
  """
  train_loss = Average(output_transform=loss_output_transform)
  eval_loss = Average(output_transform=functools.partial(loss_metric, config=config))
  running_loss = RunningAverage(output_transform=loss_output_transform)

  f1_metrics = {
      'Graph F1': F1(GraphMatchLogic(), output_transform=acc_output_transform),
      'Sentiment F1': F1(SentimentMatchLogic(), output_transform=acc_output_transform),
      'NSentiment F1': F1(SentimentMatchLogic(keep_polarity=False),
                          output_transform=acc_output_transform),
      'Target F1': F1(TargetMatchLogic(), output_transform=acc_output_transform),
      'Holder F1': F1(HolderMatchLogic(), output_transform=acc_output_transform),
  }

  for engine in (train_evaluator, evaluator):
    for metric_name, metric in f1_metrics.items():
      metric.attach(engine, metric_name)

  train_loss.attach(trainer, 'Loss')
  eval_loss.attach(evaluator, 'Loss')
  running_loss.attach(trainer, 'Running Loss')

In [None]:
def attach_wandb_logger(trainer, evaluator, train_evaluator,
                        optimizer, name: str, config: Dict[str, Any]) -> WandBLogger:
  """Create a W&B logger for one ablation run and wire it to the engines.

  The run goes to the 'SSA-Attentionist-Ablation-Study' project, is grouped as
  "<dataset display name>:<name>" and tagged with the dataset display name and
  `name`.

  Logged values:
    * optimizer parameters and the training 'Loss' after every training epoch;
    * the F1 metrics of `train_evaluator` (tag "training") and the 'Loss' plus
      F1 metrics of `evaluator` (tag "evaluation") whenever they complete,
      using the trainer's step as global step.

  The logger is closed when `trainer` completes.

  Returns:
    The created logger.
  """
  dataset_tag = DATASET_NAMES_MAP[config['dataset']]
  f1_names = ["Graph F1", "Sentiment F1", "NSentiment F1", "Target F1", "Holder F1"]

  wandb_logger = WandBLogger(
      entity='sadra-barikbin',
      project='SSA-Attentionist-Ablation-Study',
      config=config,
      group=f"{dataset_tag}:{name}",
      notes='SemEval2022 Task 10 Subtask 1',
      resume=True,
      tags=[dataset_tag, name])

  wandb_logger.attach_opt_params_handler(
      trainer, event_name=Events.EPOCH_COMPLETED, optimizer=optimizer)

  wandb_logger.attach_output_handler(
      trainer, event_name=Events.EPOCH_COMPLETED, tag="training", metric_names=["Loss"])

  evaluator_logging = (
      (train_evaluator, "training", f1_names),
      (evaluator, "evaluation", ["Loss"] + f1_names),
  )
  for engine, tag, names in evaluator_logging:
    wandb_logger.attach_output_handler(
        engine,
        event_name=Events.COMPLETED,
        global_step_transform=global_step_from_engine(trainer),
        tag=tag,
        metric_names=names)

  @trainer.on(Events.COMPLETED)
  def close_logger():
    wandb_logger.close()

  return wandb_logger

### Variants

In [None]:
# Ablation variants as (factory, display name) pairs. Every factory starts from
# its own deep copy of CONFIG and switches the listed flags to True, in order.
variants = [
    (
        functools.reduce(
            lambda factory, flag: factory.With(flag, True),
            enabled_flags,
            ModelFactory(StructuredSentimentPredictor, copy.deepcopy(CONFIG)),
        ),
        display_name,
    )
    for display_name, enabled_flags in [
        ("+Bias+SepH+AvgH+wSeq", ["edge_predictor.USE_BIAS",
                                  "base_pooler.TARGET_AND_HOLDER_SEPARATE_HEAD",
                                  "base_pooler.AVERAGE_HEADS",
                                  "sequence_labeler.WEIGHTED_LOSS"]),
        ("+Bias+SepH+AvgH", ["edge_predictor.USE_BIAS",
                             "base_pooler.TARGET_AND_HOLDER_SEPARATE_HEAD",
                             "base_pooler.AVERAGE_HEADS"]),
        ("+Bias+AvgH", ["edge_predictor.USE_BIAS",
                        "base_pooler.AVERAGE_HEADS"]),
        ("+Bias+SepH", ["edge_predictor.USE_BIAS",
                        "base_pooler.TARGET_AND_HOLDER_SEPARATE_HEAD"]),
        ("+AvgH", ["base_pooler.AVERAGE_HEADS"]),
        ("+SepH", ["base_pooler.TARGET_AND_HOLDER_SEPARATE_HEAD"]),
        ("+Bias", ["edge_predictor.USE_BIAS"]),
        ("base", []),
    ]
]

### Run

In [None]:
# Ablation study: train every variant on every dataset for every random seed.
#
# Seeds / datasets listed in the two sets below are skipped.
# NOTE(review): they are presumably the runs already finished in earlier
# sessions — confirm, and edit the sets to resume from a different point.
skipped_seeds = {41, 666, 1567, 4447}
skipped_datasets = {'opener_en', 'mpqa', 'opener_es', 'norec', 'multibooked_ca'}
# Datasets trained with batch size 16 instead of 32.
small_batch_datasets = {'opener_es', 'multibooked_ca', 'multibooked_eu'}

for random_seed in [41, 666, 1567, 4447, 37773]:

  if random_seed in skipped_seeds:
    continue

  for dataset_name in ['opener_en', 'mpqa', 'darmstadt_unis',
                       'opener_es', 'norec', 'multibooked_ca', 'multibooked_eu']:

    if dataset_name in skipped_datasets:
      continue

    for m_factory, variant_name in variants:

      # Seed before EVERY run. Seeding only once per seed value makes the
      # random stream of a run depend on how many runs were executed (or
      # skipped) before it, so resumed runs would not be reproducible.
      torch.manual_seed(random_seed)

      m_factory.With("dataset", dataset_name)
      m_factory.With("batch_size", 16 if dataset_name in small_batch_datasets else 32)
      m_factory.With("random_seed", random_seed)

      config = m_factory.config

      with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        train_ds = SemEval2022Task10Dataset('train', config)
        dev_ds = SemEval2022Task10Dataset('dev', config)

      # One tokenizer / collate function shared by both loaders.
      batch_collate_fn = functools.partial(
          collate_fn, tokenizer=get_tokenizer(config["dataset"]))

      train_dataloader = DataLoader(
          train_ds,
          collate_fn=batch_collate_fn,
          batch_size=config["batch_size"])

      # Evaluation needs no gradients, hence the doubled batch size.
      dev_dataloader = DataLoader(
          dev_ds,
          collate_fn=batch_collate_fn,
          batch_size=config["batch_size"] * 2)

      # torch.cuda.empty_cache()
      print(
          f"Variant: {variant_name}, Dataset: {dataset_name}, Random seed: {random_seed}"
          f", Batch size: {config['batch_size']}"
      )
      model = m_factory.build(factory_function).to(DEVICE)

      # Group 0: pretrained base model; group 1: the task-specific modules.
      # Their learning rates are set by the schedulers below.
      optimizer_parameter_groups = [
        {'params': list(model.base.parameters())},
        {'params': list(model.novelty.parameters())}
      ]
      optimizer = torch.optim.AdamW(optimizer_parameter_groups)

      trainer, evaluator, train_evaluator = create_engines(model, optimizer, config)

      create_and_attach_common_metrics(trainer, evaluator, train_evaluator, config)

      attach_handlers(trainer, evaluator)

      wandb_handler = attach_wandb_logger(trainer, evaluator, train_evaluator, optimizer,
                                          variant_name, config)

      # Step decay, stepped once per completed training epoch.
      lr_scheduler = LRScheduler(StepLR(optimizer, 9, gamma=0.1), save_history=True)
      scheduler_handler = trainer.add_event_handler(
          Events.EPOCH_COMPLETED, lr_scheduler
      )

      # Linear warm-up of each parameter group, ending at iteration
      # len(train_ds) // batch_size.
      base_warmup_scheduler = PiecewiseLinear(
          optimizer, "lr",
          milestones_values=[(1, 1e-6), (len(train_ds) // config["batch_size"], 1e-4)],
          param_group_index=0
      )

      novelty_warmup_scheduler = PiecewiseLinear(
          optimizer, "lr",
          milestones_values=[(1, 1e-5), (len(train_ds) // config["batch_size"], 1e-3)],
          param_group_index=1
      )

      # The warm-up schedulers only fire during the warm-up iterations.
      event_filter = lambda _, event: event < (len(train_ds) // config["batch_size"]) + 1
      trainer.add_event_handler(
          Events.ITERATION_STARTED(event_filter),
          base_warmup_scheduler
      )
      trainer.add_event_handler(
          Events.ITERATION_STARTED(event_filter),
          novelty_warmup_scheduler
      )

      @trainer.on(Events.EPOCH_COMPLETED)
      def evaluate():
        evaluator.run(dev_dataloader)
        train_evaluator.run(train_dataloader)

      trainer.run(train_dataloader,max_epochs=20)

      # Move the finished model off the GPU before the next run is built.
      # Without this line CUDA out-of-memory errors were raised. An earlier
      # version of this notebook did not need it because checkpointing
      # implicitly did the same job.
      model.cpu()

In [None]:
# Explicitly close the W&B logger of the last run (useful when the loop above was
# interrupted before the trainer fired its COMPLETED event).
wandb_handler.close()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
evaluation/Graph F1,▂▁▂▁▃▇▇▆█████
evaluation/Holder F1,▅▁▄▆▅▇▇▇█████
evaluation/Loss,▄▁▁▅▄▅▅▇▇▇▇▇█
evaluation/NSentiment F1,▁▄▄▄▆▇▇▆█████
evaluation/Sentiment F1,▁▄▃▁▃▇▇▆█▇▇██
evaluation/Target F1,▂▁▃▅▅▆▆▅▇▇▇██
lr/group_0,█████████▁▁▁▁
lr/group_1,█████████▁▁▁▁
training/Graph F1,▁▁▂▂▄▆▆▇█████
training/Holder F1,▄▁▃▅▆▇▇██████

0,1
evaluation/Graph F1,0.63964
evaluation/Holder F1,0.93756
evaluation/Loss,11.38285
evaluation/NSentiment F1,0.8361
evaluation/Sentiment F1,0.81649
evaluation/Target F1,0.7679
lr/group_0,1e-05
lr/group_1,0.0001
training/Graph F1,0.87952
training/Holder F1,0.96531


### Compare

In [None]:
# All runs of the ablation-study project, ordered by their summary
# 'evaluation/Graph F1'.
api = wandb.Api()
runs = api.runs(
    'sadra-barikbin/SSA-Attentionist-Ablation-Study',
    order='summary_metrics.evaluation/Graph F1',
)

In [None]:
# Variant tags to compare; 'base' comes first because the comparison below uses
# it as the reference row.
settings = ["base", "+AvgH", "+SepH", "+Bias", "+Bias+AvgH", "+Bias+SepH", "+Bias+SepH+AvgH"]

In [None]:
# For every (dataset, setting) pair, print the best dev 'Graph F1' reached by
# each matching run (one value per run, i.e. per random seed).
for ds, setting in itertools.product(DATASET_NAMES_MAP, settings):
  print([
      run.history(keys=["evaluation/Graph F1"]).max()["evaluation/Graph F1"]
      for run in runs
      if DATASET_NAMES_MAP[ds] in run.tags and setting in run.tags
  ])

In [None]:
# Table rows: [setting, score per dataset..., average over datasets].
# A dataset score is the mean over the matching runs of the best dev 'Graph F1'
# of each run. The 'base' row holds absolute scores; every other row holds the
# difference to 'base' (so 'base' must precede them in `settings`).
results = []
base_result = None
for setting in settings:

  f1_s = []
  for ds in DATASET_NAMES_MAP:
    best_f1_per_run = [
        run.history(keys=["evaluation/Graph F1"]).max()["evaluation/Graph F1"]
        for run in runs
        if DATASET_NAMES_MAP[ds] in run.tags and setting in run.tags
    ]
    f1_s.append(np.mean(best_f1_per_run))
  f1_s.append(np.mean(f1_s))

  if setting != 'base':
    f1_s = (np.array(f1_s) - base_result).tolist()
  else:
    base_result = np.array(f1_s)

  results.append([setting, *f1_s])

In [None]:
# Plain-text table. Each row has one more cell than there are headers: the
# leading setting-name column has no header of its own.
print(tabulate(results, headers=list(DATASET_NAMES_MAP.values())+["Average"]))

                    Opener_En     Opener_Es        Norec    Darmstadt_unis    Multibooked_eu    Multibooked_ca        MPQA      Average
---------------  ------------  ------------  -----------  ----------------  ----------------  ----------------  ----------  -----------
base              0.670929      0.627744      0.45091            0.432957         0.667957          0.609643     0.387436    0.549654
+AvgH            -0.0583241     0.0186089    -0.0549741         -0.117793         0.00157354        0.0423888   -0.331165   -0.0713835
+SepH            -0.00351852   -0.00114257    0.00168822         0.0280018       -0.00140467       -0.0201359   -0.0920118  -0.0126462
+Bias            -0.0341817     0.0172585    -0.0135171          0.0142917        0.00257007        0.028886    -0.0122436   0.00043769
+Bias+AvgH       -0.0640342     0.000325614  -0.0615373         -0.128843         0.0213508         0.0386022   -0.320668   -0.0735435
+Bias+SepH        0.000599423   0.0376719    -0.06553

In [None]:
# Restrict the comparison to the '+Bias' family; '+Bias' is the reference row of
# the next cell. NOTE: this rebinds the `settings` list defined earlier.
settings = ["+Bias", "+Bias+AvgH", "+Bias+SepH", "+Bias+SepH+AvgH"]

In [None]:
# Same table as before, but relative to the '+Bias' variant: its row holds the
# absolute scores and is labelled 'base'; the other rows hold differences to it
# and are labelled without the leading '+Bias'.
results = []
base_result = None
for setting in settings:

  f1_s = []
  for ds in DATASET_NAMES_MAP:
    best_f1_per_run = [
        run.history(keys=["evaluation/Graph F1"]).max()["evaluation/Graph F1"]
        for run in runs
        if DATASET_NAMES_MAP[ds] in run.tags and setting in run.tags
    ]
    f1_s.append(np.mean(best_f1_per_run))
  f1_s.append(np.mean(f1_s))

  if setting == '+Bias':
    base_result = np.array(f1_s)
    s = 'base'
  else:
    f1_s = (np.array(f1_s) - base_result).tolist()
    s = setting[5:]  # drop the '+Bias' prefix

  results.append([s, *f1_s])

In [None]:
# Same table as LaTeX source, with values rounded to two decimals.
print(tabulate(results, headers=list(DATASET_NAMES_MAP.values())+["Average"],tablefmt='latex',floatfmt=".2f"))

\begin{tabular}{lrrrrrrrr}
\hline
            &   Opener\_En &   Opener\_Es &   Norec &   Darmstadt\_unis &   Multibooked\_eu &   Multibooked\_ca &   MPQA &   Average \\
\hline
 base       &        0.64 &        0.65 &    0.44 &             0.45 &             0.67 &             0.64 &   0.38 &      0.55 \\
 +AvgH      &       -0.03 &       -0.02 &   -0.05 &            -0.14 &             0.02 &             0.01 &  -0.31 &     -0.07 \\
 +SepH      &        0.03 &        0.02 &   -0.05 &            -0.03 &            -0.00 &            -0.02 &  -0.08 &     -0.02 \\
 +SepH+AvgH &        0.01 &       -0.00 &   -0.03 &            -0.07 &            -0.02 &            -0.00 &  -0.24 &     -0.05 \\
\hline
\end{tabular}


## Other analyses (given by organizers)

In [None]:
# Clone the shared-task repository; the cells below use its data and analysis scripts.
!git clone https://github.com/sadra-barikbin/semeval22_structured_sentiment.git

Cloning into 'semeval22_structured_sentiment'...
remote: Enumerating objects: 1045, done.[K
remote: Counting objects: 100% (151/151), done.[K
remote: Compressing objects: 100% (86/86), done.[K
remote: Total 1045 (delta 79), reused 112 (delta 56), pack-reused 894[K
Receiving objects: 100% (1045/1045), 16.44 MiB | 20.50 MiB/s, done.
Resolving deltas: 100% (508/508), done.


In [None]:
# Run the repository's processing scripts for the MPQA and Darmstadt datasets.
! cd semeval22_structured_sentiment/data/mpqa && bash process_mpqa.sh
! cd semeval22_structured_sentiment/data/darmstadt_unis && bash process_darmstadt.sh

In [None]:
# Bring the local clone up to date.
! cd semeval22_structured_sentiment && git pull

Already up to date.


In [None]:
# Run the repository's analysis script on the NoReC predictions stored under
# submissions/sadra-barikbin.
! cd semeval22_structured_sentiment/analysis && bash analysis_script.sh ../submissions/sadra-barikbin/monolingual/norec/predictions.json ../data ../submissions

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0100  5877  100  5877    0     0  49805      0 --:--:-- --:--:-- --:--:-- 49805
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  4443  100  4443    0     0  33659      0 --:--:-- --:--:-- --:--:-- 33659
Domain Analysis on Norec:
stage:     	44
restaurants:     	50
games:     	76
literature:     	165
products:     	224
screen:     	346
music:     	367


stage: 0.396
restaurants: 0.407
games: 0.396
literature: 0.314
products: 0.312
screen: 0.390
music: 0.314
Negation Analysis on Norec:


Polar expression count:
########################################
In neg scope: 

# Run on Test Data

## Helper classes & functions

In [None]:
class ModelFactory:
  """Pairs a model class with a (possibly nested) configuration dictionary.

  Configuration entries are overridden in place with `With`, which can be
  chained, and the model is instantiated with `build`.

  NOTE(review): this class is also defined in the "Ablation Study" section of
  this notebook; this later definition shadows the earlier one.
  """

  def __init__(self, model_class: Type[nn.Module], config: Dict[str, Any]):
    self.config = config
    self.model_class = model_class

  def With(self, key: str, value: Any):
    """Set one configuration entry and return `self` for chaining.

    `key` is a dot-separated path such as "edge_predictor.USE_BIAS". All
    components but the last must already exist in the configuration; the last
    one is created or overwritten. The configuration is mutated in place.
    """
    parent_path, _, leaf = key.rpartition('.')
    path_parts = parent_path.split('.')

    node = self.config
    # An empty first component (top-level key) means: write at the root.
    if path_parts[0] != '':
      for part in path_parts:
        node = node[part]

    node[leaf] = value
    return self

  def build(
      self,
      factory_function: Optional[Callable[[Type[nn.Module], Dict[str, Any]], nn.Module]] = None
  ) -> nn.Module:
    """Instantiate the model.

    With a `factory_function`, return `factory_function(model_class, config)`;
    otherwise the configuration is unpacked as keyword arguments of the class.
    """
    if factory_function is not None:
      return factory_function(self.model_class, self.config)
    return self.model_class(**self.config)


In [None]:
def factory_function(cls, config):
  """Instantiate `cls` with a private deep copy of `config` as its only argument."""
  config_snapshot = copy.deepcopy(config)
  return cls(config_snapshot)

In [None]:
def attach_handlers(trainer: Engine, evaluator: Engine, checkpoint_dir, checkpoint_name,
                    model: Optional[nn.Module] = None):
  """Attach the run-independent handlers, including model checkpointing.

  * `trainer`: stop on a NaN loss (checked after every iteration) and show a
    progress bar with the 'Running Loss' metric.
  * `evaluator`: early stopping with patience 6 on its 'Graph F1' metric
    (terminates `trainer`), and checkpointing of the 2 best models, scored by
    'Graph F1'.

  Args:
    trainer: the training engine.
    evaluator: the engine whose 'Graph F1' drives early stopping and
      checkpointing.
    checkpoint_dir: directory the checkpoints are written to (may be non-empty).
    checkpoint_name: filename prefix of the checkpoints.
    model: the model to checkpoint. When omitted, the notebook-global `model`
      at call time is used, which is what this function silently relied on
      before the parameter existed.
  """
  if model is None:
    # Backward-compatible fallback for existing four-argument calls.
    model = globals()['model']

  trainer.add_event_handler(Events.ITERATION_COMPLETED,
                            TerminateOnNan(output_transform=itemgetter('loss')))
  pbar = ProgressBar()
  pbar.attach(trainer,metric_names=['Running Loss'])

  stopper = EarlyStopping(patience=6, score_function=lambda engine: engine.state.metrics['Graph F1'],
                          trainer=trainer)
  evaluator.add_event_handler(Events.COMPLETED, stopper)

  checkpointer = ModelCheckpoint(checkpoint_dir, checkpoint_name,
                               score_name='Graph F1', n_saved=2, require_empty=False)
  evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model})

In [None]:
def create_engines(model, optimizer, config) -> Tuple[Engine, Engine, Engine]:
  """Create the training engine and two identically configured evaluators.

  The trainer is deterministic and uses `loss_fn` bound to `config`. Both
  evaluators transform their output with `evaluate_output_transform` bound to
  `model`.

  Returns:
    (trainer, evaluator, train_evaluator)
  """
  def make_evaluator():
    return create_supervised_evaluator(
        model,
        device=DEVICE,
        prepare_batch=prepare_batch,
        output_transform=functools.partial(evaluate_output_transform, model=model))

  trainer = create_supervised_trainer(
      model,
      optimizer,
      functools.partial(loss_fn, config=config),
      device=DEVICE,
      prepare_batch=prepare_batch,
      output_transform=train_output_transform,
      deterministic=True)

  evaluator = make_evaluator()
  train_evaluator = make_evaluator()

  return trainer, evaluator, train_evaluator

In [None]:
def create_and_attach_common_metrics(trainer, evaluator, train_evaluator, config):
  """Create the loss and F1 metrics and attach them to the three engines.

  * Both evaluators get 'Graph F1', 'Sentiment F1', 'NSentiment F1'
    (sentiment match with `keep_polarity=False`), 'Target F1' and 'Holder F1'.
    The same metric objects are attached to both engines.
  * `trainer` gets 'Loss' (average) and 'Running Loss' (running average).
  * `evaluator` additionally gets 'Loss', computed with `loss_metric` bound to
    `config`.
  """
  train_loss = Average(output_transform=loss_output_transform)
  eval_loss = Average(output_transform=functools.partial(loss_metric, config=config))
  running_loss = RunningAverage(output_transform=loss_output_transform)

  f1_metrics = {
      'Graph F1': F1(GraphMatchLogic(), output_transform=acc_output_transform),
      'Sentiment F1': F1(SentimentMatchLogic(), output_transform=acc_output_transform),
      'NSentiment F1': F1(SentimentMatchLogic(keep_polarity=False),
                          output_transform=acc_output_transform),
      'Target F1': F1(TargetMatchLogic(), output_transform=acc_output_transform),
      'Holder F1': F1(HolderMatchLogic(), output_transform=acc_output_transform),
  }

  for engine in (train_evaluator, evaluator):
    for metric_name, metric in f1_metrics.items():
      metric.attach(engine, metric_name)

  train_loss.attach(trainer, 'Loss')
  eval_loss.attach(evaluator, 'Loss')
  running_loss.attach(trainer, 'Running Loss')

In [None]:
def attach_wandb_logger(trainer, evaluator, train_evaluator,
                        optimizer, name: str, config: Dict[str, Any]) -> WandBLogger:
  """Create a ``WandBLogger`` for one run and hook it onto the three engines.

  Args:
    trainer: Training engine; its optimizer parameters and 'Loss' are logged on every
      ``EPOCH_COMPLETED`` and it supplies the global step for the evaluators' logs.
    evaluator: Engine whose 'Loss' and F1 metrics are logged under the "evaluation" tag.
    train_evaluator: Engine whose F1 metrics are logged under the "training" tag.
    optimizer: Optimizer whose parameters are logged.
    name: Variant name; used in the run's group and tags.
    config: Run configuration; ``config['dataset']`` is looked up in ``DATASET_NAMES_MAP``.

  Returns:
    The logger. It is also closed automatically when ``trainer`` fires ``COMPLETED``.
  """
  dataset_label = DATASET_NAMES_MAP[config['dataset']]
  f1_metric_names = ["Graph F1", "Sentiment F1", "NSentiment F1", "Target F1", "Holder F1"]

  logger = WandBLogger(
      entity='sadra-barikbin',
      project='SSA-Attentionist-best-variants',
      config=config,
      group=f"{dataset_label}:{name}",
      notes='SemEval2022 Task 10 Subtask 1',
      resume=True,
      tags=[dataset_label, name],
  )

  # Trainer-side logging, once per epoch.
  logger.attach_opt_params_handler(
      trainer, event_name=Events.EPOCH_COMPLETED, optimizer=optimizer)
  logger.attach_output_handler(
      trainer, event_name=Events.EPOCH_COMPLETED, tag="training", metric_names=["Loss"])

  # Evaluator-side logging, keyed by the trainer's global step.
  evaluation_targets = (
      (train_evaluator, "training", f1_metric_names),
      (evaluator, "evaluation", ["Loss"] + f1_metric_names),
  )
  for engine, tag, names in evaluation_targets:
    logger.attach_output_handler(
        engine,
        event_name=Events.COMPLETED,
        global_step_transform=global_step_from_engine(trainer),
        tag=tag,
        metric_names=names,
    )

  def close_logger():
    logger.close()

  trainer.add_event_handler(Events.COMPLETED, close_logger)

  return logger

## Best Variants

In [None]:
# One entry per dataset: (config overrides applied in the listed order, variant label).
_variant_specs = [
  # Kept first because this dataset contains the longest sentence; running it first is
  # meant to prevent CUDA out-of-memory errors.
  ([("edge_predictor.USE_BIAS", True),
    ("base_pooler.TARGET_AND_HOLDER_SEPARATE_HEAD", True),
    ("dataset", "multibooked_ca")],
   "+Bias+SepH"),
  ([("edge_predictor.USE_BIAS", True),
    ("base_pooler.TARGET_AND_HOLDER_SEPARATE_HEAD", True),
    ("dataset", "opener_en")],
   "+Bias+SepH"),
  ([("dataset", "mpqa")],
   "base"),
  ([("base_pooler.TARGET_AND_HOLDER_SEPARATE_HEAD", True),
    ("dataset", "darmstadt_unis")],
   "+SepH"),
  ([("edge_predictor.USE_BIAS", True),
    ("base_pooler.TARGET_AND_HOLDER_SEPARATE_HEAD", True),
    ("dataset", "opener_es")],
   "+Bias+SepH"),
  ([("base_pooler.TARGET_AND_HOLDER_SEPARATE_HEAD", True),
    ("dataset", "norec")],
   "+SepH"),
  ([("base_pooler.AVERAGE_HEADS", True),
    ("dataset", "multibooked_eu")],
   "+AvgH"),
]

# Each factory starts from its own deep copy of CONFIG; the overrides are then chained
# through `With` one after another, exactly as a hand-written `.With(...).With(...)` chain.
variants = [
  (functools.reduce(lambda factory, override: factory.With(*override),
                    overrides,
                    ModelFactory(StructuredSentimentPredictor, copy.deepcopy(CONFIG))),
   label)
  for overrides, label in _variant_specs
]

## Run

In [None]:
# Train every variant once per random seed; each (variant, seed) pair gets its own model,
# optimizer, engines and W&B logger.
for variant_idx, (m_factory, variant_name) in enumerate(variants):

  for random_seed in [41, 666, 1567, 4447, 37773]:

    # Only torch's RNG is seeded explicitly here.
    torch.manual_seed(random_seed)

    # Three datasets use batch size 16, all others 32.
    # NOTE(review): presumably a GPU-memory constraint — confirm.
    if m_factory.config["dataset"] in ['opener_es', 'multibooked_ca', 'multibooked_eu']:
      m_factory.With("batch_size", 16)
    else:
      m_factory.With("batch_size", 32)

    m_factory.With("random_seed", random_seed)

    # The `With` calls above are used for their effect on the factory (their return values
    # are discarded); the config is read back from the factory afterwards.
    config = m_factory.config

    dataset_name = config["dataset"]

    # Everything below — data loading, model construction and the whole training run —
    # executes with all warnings ignored.
    with warnings.catch_warnings():
      warnings.simplefilter("ignore")
      train_ds = SemEval2022Task10Dataset('train', config)
      dev_ds = SemEval2022Task10Dataset('dev', config)

      # No shuffle argument is passed; the same loader is reused by `train_evaluator` below.
      train_dataloader = DataLoader(
          train_ds,
          collate_fn=functools.partial(collate_fn, tokenizer=get_tokenizer(config["dataset"])),
          batch_size=config["batch_size"])

      # Evaluation on the dev split uses twice the training batch size.
      dev_dataloader = DataLoader(
          dev_ds,
          collate_fn=functools.partial(collate_fn, tokenizer=get_tokenizer(config["dataset"])),
          batch_size= config["batch_size"] * 2)

      print(
          f"Variant: {variant_name}, Dataset: {dataset_name}, Random seed: {random_seed}"
          f", Batch size: {config['batch_size']}"
      )
      model = m_factory.build(factory_function).to(DEVICE)

      # Two parameter groups so each can follow its own learning-rate schedule:
      # index 0 = `model.base`, index 1 = `model.novelty`. No lr is given here; the
      # warm-up schedulers below write the "lr" of each group.
      optimizer_parameter_groups = [
        {'params': list(model.base.parameters())},
        {'params': list(model.novelty.parameters())}
      ]
      optimizer = torch.optim.AdamW(optimizer_parameter_groups)

      trainer, evaluator, train_evaluator = create_engines(model, optimizer, config)

      create_and_attach_common_metrics(trainer, evaluator, train_evaluator, config)

      wandb_handler = attach_wandb_logger(trainer, evaluator, train_evaluator, optimizer,
                                          variant_name, config)

      # `attach_handlers` is defined elsewhere in the notebook; it is given the W&B run
      # directory and a per-dataset/per-variant name.
      # NOTE(review): presumably this sets up checkpointing into the run directory — confirm.
      attach_handlers(trainer, evaluator, wandb_handler.run.dir,
                      f"Attentionist_{dataset_name}_{variant_name}")



      # StepLR with step size 9 and gamma 0.1, triggered at the end of every epoch.
      lr_scheduler = LRScheduler(StepLR(optimizer, 9, gamma=0.1), save_history=True)
      # `scheduler_handler` is not referenced again in this cell.
      scheduler_handler = trainer.add_event_handler(
          Events.EPOCH_COMPLETED, lr_scheduler
      )

      # Warm-up for parameter group 0: lr goes from 1e-6 at event 1 to 1e-4 at event
      # `len(train_ds) // batch_size` (the number of full batches in `train_ds`).
      base_warmup_scheduler = PiecewiseLinear(
          optimizer, "lr",
          milestones_values=[(1, 1e-6), (len(train_ds) // config["batch_size"], 1e-4)],
          param_group_index=0
      )

      # Warm-up for parameter group 1: same milestones, lr from 1e-5 to 1e-3.
      novelty_warmup_scheduler = PiecewiseLinear(
          optimizer, "lr",
          milestones_values=[(1, 1e-5), (len(train_ds) // config["batch_size"], 1e-3)],
          param_group_index=1
      )

      # Restrict both warm-up handlers to events up to and including the last milestone.
      event_filter = lambda _, event: event < (len(train_ds) // config["batch_size"]) + 1
      trainer.add_event_handler(
          Events.ITERATION_STARTED(event_filter),
          base_warmup_scheduler
      )
      trainer.add_event_handler(
          Events.ITERATION_STARTED(event_filter),
          novelty_warmup_scheduler
      )

      # After every training epoch, evaluate on the dev split first, then on the training split.
      @trainer.on(Events.EPOCH_COMPLETED)
      def evaluate():
        evaluator.run(dev_dataloader)
        train_evaluator.run(train_dataloader)

      trainer.run(train_dataloader,max_epochs=20)

      # Moving the model back to the CPU is required: without the line below,
      # CUDA_OUT_OF_MEMORY was raised on later runs. A previous version of this notebook
      # did not hit the problem because its checkpointing implicitly did this job.
      model.cpu()

In [None]:
# Close the logger created by the most recent iteration of the training loop.
# `attach_wandb_logger` already registers a close on the trainer's COMPLETED event, so
# this explicit call presumably covers a run that stopped before that event — confirm.
wandb_handler.close()

## Evaluate

> First run the functions in the [Inference](#inference) section.

In [None]:
# For every logged W&B run of the selected datasets: download its checkpoint and config,
# rebuild the model, run it over the test split and write the predictions to
# test_preds/<dataset>/<run index>/predictions.json.
api = wandb.Api()
for ds in DATASET_NAMES_MAP:

  # NOTE(review): hard-coded filter — only these two datasets are processed by this cell,
  # while the scoring cell further down reads predictions for every dataset in
  # DATASET_NAMES_MAP. Looks like a leftover from resuming an interrupted run; confirm
  # before relying on a fresh "Run All".
  if ds not in ['mpqa','multibooked_ca']:
    continue

  runs = api.runs(
      'sadra-barikbin/SSA-Attentionist-best-variants',
      order='summary_metrics.evaluation/Graph F1',
      filters={'tags': DATASET_NAMES_MAP[ds]}
  )

  for idx, run in enumerate(runs):

    # NOTE(review): the first multibooked_ca run is skipped unconditionally — same
    # resume-style leftover as the dataset filter above; confirm.
    if ds == 'multibooked_ca' and idx==0:
      continue

    # Fetch the file listing once instead of once per download. The checkpoint and the
    # config are picked by position (1 and 2) in that listing.
    run_files = run.files()
    model_path = run_files[1].download(replace=True)
    config_path = run_files[2].download(replace=True)

    # Read the config inside a `with` block so the handle is closed deterministically.
    # SECURITY: yaml.Loader can construct arbitrary Python objects; acceptable only
    # because the file comes from our own W&B project.
    with open(config_path.name, 'r') as config_file:
      raw_config = yaml.load(config_file, yaml.Loader)
    # Keep the stored 'value' of every entry and drop the 'wandb_version' bookkeeping key.
    config = {k:v['value'] for k,v in raw_config.items() if k != 'wandb_version'}

    # Load onto the CPU first so a checkpoint saved from a GPU also loads on a CPU-only
    # runtime; the model is moved to DEVICE right after the weights are restored.
    state_dict = torch.load(model_path.name, map_location='cpu')
    model = StructuredSentimentPredictor(config)
    model.load_state_dict(state_dict)
    model.to(DEVICE)

    # Not referenced again in this cell. Kept because `model`/`tokenizer` are notebook
    # globals that the Inference-section helpers may read — confirm before removing.
    tokenizer = get_tokenizer(ds)

    test_ds = SemEval2022Task10Dataset('test', config)
    test_dataloader = DataLoader(
        test_ds,
        batch_size=config["batch_size"] * 2,
        collate_fn=functools.partial(test_collate_fn, device=DEVICE)
    )

    # Collect the per-batch outputs of the whole epoch in `tester.state.outputs`.
    tester = Engine(test_engine_process_function)
    output_store = EpochOutputStore()
    output_store.attach(tester, 'outputs')

    pbar = ProgressBar()
    pbar.attach(tester)

    tester.run(test_dataloader)

    path = f"test_preds/{ds}/{idx}"
    # exist_ok=True keeps the cell re-runnable: without it a second execution raises
    # FileExistsError on the directory created by the first one.
    os.makedirs(path, exist_ok=True)

    # Flatten the list of per-batch outputs into one list of predictions.
    with open(f"{path}/predictions.json", 'w') as f:
      json.dump(list(itertools.chain(*tester.state.outputs)), f)

Some weights of the model checkpoint at setu4993/LaBSE were not used when initializing BertModel: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


[1/11]   9%|9          [00:00<?]

Some weights of the model checkpoint at setu4993/LaBSE were not used when initializing BertModel: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


[1/11]   9%|9          [00:00<?]

Some weights of the model checkpoint at setu4993/LaBSE were not used when initializing BertModel: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


[1/11]   9%|9          [00:00<?]

Some weights of the model checkpoint at setu4993/LaBSE were not used when initializing BertModel: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


[1/11]   9%|9          [00:00<?]

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'roberta.pooler.dense.weight', 'lm_head.bias', 'roberta.pooler.dense.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


[1/33]   3%|3          [00:00<?]

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'roberta.pooler.dense.weight', 'lm_head.bias', 'roberta.pooler.dense.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


[1/33]   3%|3          [00:00<?]

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'roberta.pooler.dense.weight', 'lm_head.bias', 'roberta.pooler.dense.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


[1/33]   3%|3          [00:00<?]

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'roberta.pooler.dense.weight', 'lm_head.bias', 'roberta.pooler.dense.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


[1/33]   3%|3          [00:00<?]

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'roberta.pooler.dense.weight', 'lm_head.bias', 'roberta.pooler.dense.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


[1/33]   3%|3          [00:00<?]

In [None]:
# Clone the shared-task repository: the next cells import its evaluation helpers and
# read the gold test files from its data/ directory.
! git clone https://github.com/sadra-barikbin/semeval22_structured_sentiment.git

In [None]:
from semeval22_structured_sentiment.evaluation.evaluate import convert_opinion_to_tuple, tuple_f1

In [None]:
# For every dataset, score the five stored prediction files against the gold test
# annotations and report the mean F1 plus the index of the best-scoring file.
for ds in DATASET_NAMES_MAP:

  with open(f"semeval22_structured_sentiment/data/{ds}/test.json") as gold_file:
    gold = {s["sent_id"]: convert_opinion_to_tuple(s) for s in json.load(gold_file)}

  f1_scores = []
  for idx in range(5):

    with open(f"test_preds/{ds}/{idx}/predictions.json") as preds_file:
      preds = {s["sent_id"]: convert_opinion_to_tuple(s) for s in json.load(preds_file)}

    gold_ids = set(gold)
    pred_ids = set(preds)

    # Gold and predictions must cover exactly the same sentence ids.
    assert gold_ids <= pred_ids, f"missing some sentences: {gold_ids - pred_ids}"
    assert pred_ids <= gold_ids, f"predictions contain sentences that are not in golds: {pred_ids - gold_ids}"

    f1_scores.append(tuple_f1(gold, preds))

  # Plain left-to-right accumulation of the scores.
  f1_sum = 0
  for f1 in f1_scores:
    f1_sum += f1

  # Index of the first maximal score. Despite the "rand_seed" wording in the printed
  # line, this is the prediction-folder index, not a seed value.
  f1_argmax = max(range(len(f1_scores)), key=f1_scores.__getitem__)

  print(f"{DATASET_NAMES_MAP[ds]} F1: {f1_sum / len(f1_scores)}, argbest rand_seed: {f1_argmax}")

Opener_En F1: 0.5898529367600808, argbest rand_seed: 4
Opener_Es F1: 0.5538162046890416, argbest rand_seed: 3
Norec F1: 0.3241464435669958, argbest rand_seed: 4
Darmstadt_unis F1: 0.2586611903403816, argbest rand_seed: 2
Multibooked_eu F1: 0.5956168588807508, argbest rand_seed: 4
Multibooked_ca F1: 0.6173553271467782, argbest rand_seed: 2
MPQA F1: 0.26075571778518675, argbest rand_seed: 4
