<a href="https://colab.research.google.com/github/sadra-barikbin/novel-solutions-for-sentiment-analysis/blob/main/Attentionist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [1]:
!pip install transformers pytorch-ignite wandb

Collecting transformers
  Downloading transformers-4.16.2-py3-none-any.whl (3.5 MB)
[K     |████████████████████████████████| 3.5 MB 5.0 MB/s 
[?25hCollecting pytorch-ignite
  Downloading pytorch_ignite-0.4.8-py3-none-any.whl (251 kB)
[K     |████████████████████████████████| 251 kB 70.9 MB/s 
[?25hCollecting wandb
  Downloading wandb-0.12.11-py2.py3-none-any.whl (1.7 MB)
[K     |████████████████████████████████| 1.7 MB 47.2 MB/s 
[?25hCollecting tokenizers!=0.11.3,>=0.10.1
  Downloading tokenizers-0.11.6-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.5 MB)
[K     |████████████████████████████████| 6.5 MB 52.6 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.47-py2.py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 65.5 MB/s 
[?25hCollecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.4.0-py3-none-any.whl (67 kB)
[K     |████████████████████████████████| 67 kB 6.8 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6

In [2]:
import os
import json
import torch
import ignite
import wandb
import json
import warnings
import functools
import itertools
import collections
import torch.nn.functional     as     F
import numpy                   as     np
import matplotlib.pyplot       as     plt
import scipy.stats
from  operator                                import attrgetter, itemgetter
from  enum                                    import IntEnum
from  dataclasses                             import dataclass
from  torch                                   import Tensor
from  tqdm                                    import tqdm
from  torch.utils.data                        import DataLoader, Dataset
from  torch.utils.tensorboard                 import SummaryWriter
from  torch.nn.utils.rnn                      import pad_sequence
from  torch                                   import nn, Tensor
from  torch.optim.lr_scheduler                import StepLR
from  transformers                            import RobertaModel, RobertaTokenizerFast
from  transformers                            import PreTrainedTokenizerFast
from  transformers                            import BertModel, BertTokenizerFast
from  transformers                            import AutoTokenizer, AutoModel
from  transformers.modeling_outputs           import BaseModelOutput
from  typing                                  import Tuple, Dict, List, Any, Sequence
from  timeit                                  import timeit
from  ignite.metrics                          import Accuracy, Fbeta, Average
from  ignite.metrics                          import Metric, RunningAverage
from  ignite.handlers.terminate_on_nan        import TerminateOnNan
from  ignite.handlers.checkpoint              import ModelCheckpoint
from  ignite.handlers                         import EarlyStopping, global_step_from_engine
from  ignite.handlers                         import EpochOutputStore
from  ignite.handlers.param_scheduler         import create_lr_scheduler_with_warmup
from  ignite.handlers.param_scheduler         import PiecewiseLinear
from  ignite.engine                           import create_supervised_trainer
from  ignite.engine                           import Engine, create_supervised_evaluator
from  ignite.engine.events                    import Events
from  ignite.contrib.handlers.tqdm_logger     import ProgressBar
from  ignite.contrib.handlers.wandb_logger    import WandBLogger

In [3]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DATASET_NAME = 'ALL'
SEQ_LABEL_WEIGHTED = False
torch.manual_seed(41)

<torch._C.Generator at 0x7fbf39805c30>

# Helper Classes

In [4]:
class TokenBIOLabelEnum(IntEnum):

  def __new__(cls, value, tag=None):
    obj = int.__new__(cls, value)
    obj._value_ = value
    obj.tag = tag
    return obj

  @classmethod
  def from_integers(cls, integers: Sequence[int]) -> Sequence[IntEnum]:
    return [cls(integer) for integer in integers]

  @classmethod
  def from_labels(cls, labels: Sequence[IntEnum]) -> Sequence[int]:
    return [label.value for label in labels]

  def is_begin(self):
    return self % 2 == 0 and self != self.__class__.OUT

  def is_in(self):
    return self % 2 == 1

  def is_out(self):
    return self == self.__class__.OUT

In [5]:
class SentExprTokenEnum(TokenBIOLabelEnum):
  POSITIVE_EXPR_BEGIN = (0, 'Positive')
  POSITIVE_EXPR_IN = (1, 'Positive')
  NEGATIVE_EXPR_BEGIN = (2, 'Negative')
  NEGATIVE_EXPR_IN = (3, 'Negative')
  NEUTRAL_EXPR_BEGIN = (4, 'Neutral')
  NEUTRAL_EXPR_IN = (5, 'Neutral')
  OUT = 6

class SentHolderTokenEnum(TokenBIOLabelEnum):
  EXPR_BEGIN = 0
  EXPR_IN = 1
  OUT = 2


class SentTargetTokenEnum(TokenBIOLabelEnum):
  EXPR_BEGIN = 0
  EXPR_IN = 1
  OUT = 2

In [6]:
@dataclass
class SentimentGraphNode:

  span_in_sentence: slice
  tag: Any=None

  @property
  def span_length(self) -> int:
    return self.span_in_sentence.stop - self.span_in_sentence.start


  @property
  def indices(self) -> List[int]:
    return list(range(self.span_in_sentence.start, self.span_in_sentence.stop))

  def __hash__(self) -> int:
    return (self.span_in_sentence.start, self.span_in_sentence.stop).__hash__()


  def __eq__(self, other) -> bool:
    return self.span_in_sentence == other.span_in_sentence


@dataclass
class SentimentGraphNodeSet:
  sentiment_nodes: List[SentimentGraphNode]
  target_nodes: List[SentimentGraphNode]
  holder_nodes: List[SentimentGraphNode]

In [7]:
@dataclass
class SentimentGraphEdgeSet:
  sentiment_target_edges: Tensor
  sentiment_holder_edges: Tensor

In [8]:
@dataclass
class SentimentGraph:
  nodes: SentimentGraphNodeSet
  edges: SentimentGraphEdgeSet

# Tokenizer

In [9]:
@functools.lru_cache()
def get_tokenizer(dataset_name: str) -> PreTrainedTokenizerFast:
  if dataset_name in ["opener_en", "mpqa", "darmstadt_unis"]:
    return RobertaTokenizerFast.from_pretrained('roberta-base', add_prefix_space=True)
  else:
    return BertTokenizerFast.from_pretrained('setu4993/LaBSE', add_prefix_space=True)

# Prepare Datasets

In [None]:
! unzip drive/MyDrive/semeval2022task10.zip

In [None]:
class SemEval2022Task10Dataset(Dataset):

  def __init__(self, name: str, split: str):
    super(SemEval2022Task10Dataset, self).__init__()

    dataset_names = ["opener_en", "opener_es", "norec", "darmstadt_unis", "mpqa",
                    "multibooked_ca", "multibooked_eu", "all"]

    assert name in dataset_names
    assert split in ["train", "dev", "test"]

    tokenizer = get_tokenizer(name)

    if name == 'all':
      assert split != 'test'

      data = []
      for _name in dataset_names:
        data.extend(json.load(open(f"{_name}/{split}.json")))
    else:
      data = json.load(open(f"{name}/{split}.json"))

    if split == 'test':
      self.data = data
      return

    self.data = []
    sentim_labels_stat = collections.Counter(SentExprTokenEnum)
    target_labels_stat = collections.Counter(SentTargetTokenEnum)
    holder_labels_stat = collections.Counter(SentHolderTokenEnum)
    for sent_idx, item in enumerate(data):
      if not item['opinions']:
          continue
      if name == 'mpqa' and split == 'train' and (sent_idx == 1897 or sent_idx == 2292):
        continue
      if name == 'mpqa' and split == 'dev' and sent_idx == 2006:
        continue
      if item['sent_id'] == 'multibooked/corpora/eu/kaype-quintamar-llanes_1-1':
        continue # Multibooked_eu train sample 833

      encoded = tokenizer([item['text']], add_special_tokens=False,
                         return_offsets_mapping=True, return_length=True)
      tokens_char_offsets = encoded['offset_mapping'][0]
      length = encoded['length'][0]
      tokens_labels_sentiment = [SentExprTokenEnum.OUT] * length
      tokens_labels_target = [SentTargetTokenEnum.OUT] * length
      tokens_labels_holder = [SentHolderTokenEnum.OUT] * length
      sentiment_nodes = []
      target_nodes = []
      holder_nodes = []
      sentim_tgt_edges = {}
      sentim_hld_edges = {}

      for opinion in item['opinions']:
        opinion_sentiment_nodes = []
        opinion_target_nodes = []
        opinion_holder_nodes = []

        for char_span in opinion["Polar_expression"][1]:
          char_span_begin, char_span_end = (int(x) for x in char_span.split(':'))
          token_span_begin, token_span_end = self._char_span_to_token_span_idx(
              sent_idx, tokens_char_offsets, (char_span_begin, char_span_end)
          )
          node = SentimentGraphNode(
              slice(token_span_begin, token_span_end), opinion['Polarity']
          )
          opinion_sentiment_nodes.append(node)
          if node not in sentiment_nodes:
            sentiment_nodes.append(node)

            self._update_sequence_labels(
                sent_idx, tokens_labels_sentiment,
                token_span_begin, token_span_end, "sentiment", opinion["Polarity"]
            )

        for char_span in opinion["Target"][1]:
          char_span_begin, char_span_end = (int(x) for x in char_span.split(':'))
          token_span_begin, token_span_end = self._char_span_to_token_span_idx(
              sent_idx, tokens_char_offsets, (char_span_begin, char_span_end)
          )
          node = SentimentGraphNode(
              slice(token_span_begin, token_span_end), "target"
          )
          opinion_target_nodes.append(node)
          if node not in target_nodes:
            target_nodes.append(node)

            self._update_sequence_labels(
                sent_idx, tokens_labels_target,
                token_span_begin, token_span_end, "target"
            )

        for char_span in opinion["Source"][1]:
          char_span_begin, char_span_end = (int(x) for x in char_span.split(':'))
          token_span_begin, token_span_end = self._char_span_to_token_span_idx(
              sent_idx, tokens_char_offsets, (char_span_begin, char_span_end)
          )
          node = SentimentGraphNode(
              slice(token_span_begin, token_span_end), "holder"
          )
          opinion_holder_nodes.append(node)
          if node not in holder_nodes:
            holder_nodes.append(node)

            self._update_sequence_labels(
                sent_idx, tokens_labels_holder,
                token_span_begin, token_span_end, "holder"
            )

        for sentiment_node in opinion_sentiment_nodes:
          sentim_tgt_edges[sentiment_node] = opinion_target_nodes
          sentim_hld_edges[sentiment_node] = opinion_holder_nodes

      sentiment_nodes.sort(key=attrgetter("span_in_sentence"))
      target_nodes.sort(key=attrgetter("span_in_sentence"))
      holder_nodes.sort(key=attrgetter("span_in_sentence"))

      sentim_tgt_adj_matrix = torch.zeros((len(sentiment_nodes), len(target_nodes)))
      sentim_hld_adj_matrix = torch.zeros((len(sentiment_nodes), len(holder_nodes)))
      for i, sentim_node in enumerate(sentiment_nodes):

        for neighbor_tgt_node in sentim_tgt_edges[sentim_node]:
          sentim_tgt_adj_matrix[i, target_nodes.index(neighbor_tgt_node)] = 1

        for neighbor_hld_node in sentim_hld_edges[sentim_node]:
          sentim_hld_adj_matrix[i, holder_nodes.index(neighbor_hld_node)] = 1
      self.data.append({
          "sent_id": item['sent_id'],
          "text": item['text'],
          "seq_labels": {
              "sentiment": tokens_labels_sentiment,
              "target": tokens_labels_target,
              "holder": tokens_labels_holder},
          "graph":SentimentGraph(
              SentimentGraphNodeSet(
                sentiment_nodes, target_nodes, holder_nodes
              ),
              SentimentGraphEdgeSet(
                sentim_tgt_adj_matrix, sentim_hld_adj_matrix
              )
          )}
        )

      sentim_labels_stat.update(tokens_labels_sentiment)
      target_labels_stat.update(tokens_labels_target)
      holder_labels_stat.update(tokens_labels_holder)

    self.sentim_seq_label_weights = torch.Tensor(
        [sentim_labels_stat[label] for label in sorted(sentim_labels_stat)])
    self.sentim_seq_label_weights = self.sentim_seq_label_weights.max() / \
                                          self.sentim_seq_label_weights

    self.target_seq_label_weights = torch.Tensor(
        [target_labels_stat[label] for label in sorted(target_labels_stat)])
    self.target_seq_label_weights = self.target_seq_label_weights.max() / \
                                          self.target_seq_label_weights

    self.holder_seq_label_weights = torch.Tensor(
        [holder_labels_stat[label] for label in sorted(holder_labels_stat)])
    self.holder_seq_label_weights = self.holder_seq_label_weights.max() / \
                                          self.holder_seq_label_weights

  def _update_sequence_labels(self, sent_idx: int, sequence: List[TokenBIOLabelEnum],
                              start_token: int, end_token: int,
                              expr_type: str, polarity: str=None):
    if expr_type == 'sentiment':
      if polarity == 'Positive':
        begin_label = SentExprTokenEnum.POSITIVE_EXPR_BEGIN
        in_label = SentExprTokenEnum.POSITIVE_EXPR_IN
      elif polarity == 'Negative':
        begin_label = SentExprTokenEnum.NEGATIVE_EXPR_BEGIN
        in_label = SentExprTokenEnum.NEGATIVE_EXPR_IN
      elif polarity == 'Neutral':
        begin_label = SentExprTokenEnum.NEUTRAL_EXPR_BEGIN
        in_label = SentExprTokenEnum.NEUTRAL_EXPR_IN
      else:
        raise ValueError("Given expr_type is 'sentiment' but 'polarity' is None.")
      out = SentExprTokenEnum.OUT
    elif expr_type == 'target':
      begin_label = SentTargetTokenEnum.EXPR_BEGIN
      in_label = SentTargetTokenEnum.EXPR_IN
      out = SentTargetTokenEnum.OUT
    elif expr_type == 'holder':
      begin_label = SentHolderTokenEnum.EXPR_BEGIN
      in_label = SentHolderTokenEnum.EXPR_IN
      out = SentHolderTokenEnum.OUT
    else:
      raise ValueError(f"Given expr_type is not recognized. expr_type={expr_type}")

    # First time I saw these warnings in MPQA dataset. They were
    # mistakes so I opened an issue. But generally such scenarios in which
    # a part of an expression is itself another expression, are acceptable.
    # So I took union of them as the final expression.
    if sequence[start_token] != out:
      warnings.warn(f"Sequence is already updated at index {start_token}, "
                    f"expr_type={expr_type}, sent_idx={sent_idx}")

      if expr_type == 'sentiment' and sequence[start_token].tag != begin_label.tag:
        warnings.warn("Two expressions with conflicting polarities"
                      f" in the same place. sent_idx={sent_idx} "
                      f"polarities={sequence[start_token].tag}-{begin_label.tag}")
    else:
      sequence[start_token] = begin_label
    for i in range(start_token + 1, end_token):
      if sequence[i] != out:
        warnings.warn(f"Sequence is already updated at index {i}, "
                      f"expr_type={expr_type}, sent_idx={sent_idx}")

        if expr_type == 'sentiment' and sequence[i].tag != in_label.tag:
          warnings.warn("Two expressions with conflicting polarities in the same place."
                        f" sent_idx={sent_idx} polarities={sequence[i].tag}-{in_label.tag}")

      sequence[i] = in_label


  def _char_span_to_token_span_idx(
      self, sent_idx: int, tokens_char_offsets: List[Tuple[int, int]],
      char_span: Tuple[int, int]) -> Tuple[int, int]:

    begin_token_idx, end_token_idx = None, None
    for i, (char_offset_begin, char_offset_end) in enumerate(tokens_char_offsets):
      if char_offset_begin == char_span[0] and begin_token_idx is None:
        begin_token_idx = i

      if char_offset_end == char_span[1] and end_token_idx is None:
        end_token_idx = i + 1

      if begin_token_idx is not None and end_token_idx is not None:
        return begin_token_idx, end_token_idx

    raise ValueError("Given char_span does not exist in given token char offsets. "
                     f"sent_idx={sent_idx} char_span={char_span} "
                     f"token_char_offsets={tokens_char_offsets}")


  def __len__(self):
    return len(self.data)


  def __getitem__(self, idx) -> Dict[str, Any] :
    item = self.data[idx]

    item["label_weights"] = {"sentiment": self.sentim_seq_label_weights,
                             "target": self.target_seq_label_weights,
                             "holder": self.holder_seq_label_weights}
    return item

In [None]:
train_ds = SemEval2022Task10Dataset(DATASET_NAME, 'train')
dev_ds = SemEval2022Task10Dataset(DATASET_NAME, 'dev')

# Model Definition

In SemEval2022 Task10 description, it is stated that we can see the sentiment, target and holder expressions in a sentence as a graph, which the expressions are its nodes and there is edge from a sentiment expression node to its pertaining target and holder ones.

In [None]:
class BIOSchemeTokenClassification:

  @classmethod
  def _nstate(cls, state: str, label: TokenBIOLabelEnum) -> str:
    if label.is_out():
      return 'skipping'
    if state == 'reading':
      if label.is_begin():
        return 'start_reading'
      else: # label.is_in()
        return 'reading'
    elif state == 'skipping':
      if label.is_begin():
        return 'start_reading'
      else:
        return 'skipping'
    else: # state == 'start_reading'
      if label.is_begin():
        return 'start_reading'
      else: # label.is_in()
        return 'reading'


  @classmethod
  def extract_spans(cls, token_labels: Sequence[TokenBIOLabelEnum]) -> List[Tuple[slice, Any]]:
    state = 'skipping'
    span_start = 0
    span_length = 0
    spans = []
    for i in range(len(token_labels)):
      label = token_labels[i]
      nstate = BIOSchemeTokenClassification._nstate(state, label)
      if nstate == 'start_reading':
        if state == 'reading' or state == 'start_reading':
          spans.append((slice(span_start, span_start + span_length), token_labels[i-1].tag))
        span_start = i
        span_length = 1
      elif nstate == 'reading':
        span_length += 1
      else: # nstate == 'skipping'
        if state == 'reading' or state == 'start_reading':
          spans.append((slice(span_start, span_start + span_length), token_labels[i-1].tag))
      state = nstate
    if state != 'skipping':
      spans.append((slice(span_start, span_start + span_length), token_labels[i].tag))
    return spans

## Sequence Labeler

In [None]:
class LinearSequenceLabeler(nn.Module):

  def __init__(self, input_dim: int, hidden_dim: int, label_enum: TokenBIOLabelEnum):
    super(LinearSequenceLabeler, self).__init__()

    self.seq_label = nn.ModuleList([nn.Linear(input_dim, hidden_dim), nn.Dropout(), 
                                    nn.ReLU(), nn.Linear(hidden_dim, len(label_enum))])
    self.label_enum = label_enum


  def forward(self, ptm_output: BaseModelOutput) -> Tensor:
    output = ptm_output.last_hidden_state
    for layer in self.seq_label:
      output = layer(output)
    return output


  @torch.no_grad()
  def predict(self, output: Tensor,
              padding_mask: Tensor) -> List[Sequence[TokenBIOLabelEnum]]:
    output = torch.argmax(output, dim=2)
    batch_size = output.shape[0]
    preds_batch = []
    for i in range(batch_size):
      mask_i = padding_mask[i] == 1
      output_i = output[i].detach()
      preds_batch.append(self.label_enum.from_integers(output_i[mask_i].cpu().numpy()))

    return preds_batch

## Node Extractor

In [None]:
class SentimentGraphNodeExtractor(nn.Module):

  def __init__(self):
    super(SentimentGraphNodeExtractor, self).__init__()

    self.sentiment_labeler = LinearSequenceLabeler(768, 64, SentExprTokenEnum)
    self.target_labeler = LinearSequenceLabeler(768, 64, SentTargetTokenEnum)
    self.holder_labeler = LinearSequenceLabeler(768, 64, SentHolderTokenEnum)


  def forward(self, ptm_output: BaseModelOutput) -> Tuple[Tensor, Tensor, Tensor]:
    return (self.sentiment_labeler(ptm_output),
            self.target_labeler(ptm_output),
            self.holder_labeler(ptm_output))


  @torch.no_grad()
  def predict(self, output: Tuple[Tensor, Tensor, Tensor],
              padding_mask: Tensor) -> List[SentimentGraphNodeSet]:
                                                          
    sentiment_labels = self.sentiment_labeler.predict(output[0], padding_mask)
    target_labels = self.target_labeler.predict(output[1], padding_mask)
    holder_labels = self.holder_labeler.predict(output[2], padding_mask)

    batch_size = padding_mask.shape[0]
    nodes_batch = []
    for i in range(batch_size):

      sentiment_nodes = BIOSchemeTokenClassification.extract_spans(sentiment_labels[i])
      sentiment_nodes = [SentimentGraphNode(*tup) for tup in sentiment_nodes]

      target_nodes = BIOSchemeTokenClassification.extract_spans(target_labels[i])
      target_nodes = [SentimentGraphNode(*tup) for tup in target_nodes]

      holder_nodes = BIOSchemeTokenClassification.extract_spans(holder_labels[i])
      holder_nodes = [SentimentGraphNode(*tup) for tup in holder_nodes]

      nodes_batch.append(SentimentGraphNodeSet(sentiment_nodes, target_nodes,
                                               holder_nodes))
    return nodes_batch

## Polarity Predictor

In [None]:
class PolarityPredictor(nn.Module):

  def __init__(self):
    self.net = nn.ModuleList([nn.Linear(768, 64), nn.Dropout(), nn.ReLU(),
                              nn.Linear(64, 16)])
    self.cls = nn.Linear(16, 3)


  def forward(self, ptm_output: BaseModelOutput,
              graphs_nodes: List[SentimentGraphNodeSet]) -> List[Tensor]:

    last_hidden_state_batch = ptm_output.last_hidden_state
    batch_size = last_hidden_state_batch.shape[0]
    fun_sum_over_span_tokens = functools.partial(torch.sum, dim=0)

    preds_batch = []
    for i in range(batch_size):
      last_hidden_state = last_hidden_state_batch[i]
      nodes = graphs_nodes[i].sentiment_nodes

      nodes_h = torch.cat([last_hidden_state[node.span_in_sentence] for node in nodes])
      for layer in self.net:
        nodes_h = layer(nodes_h)

      nodes_h = torch.split(nodes_h, [node.span_length for node in nodes])
      nodes_h = torch.stack([fun_sum_over_span_tokens(node_h) for node_h in nodes_h])

      preds_batch.append(self.cls(F.relu(nodes_h)))

    return preds_batch

## Edge Predictor

In [None]:
# Originally from <https://discuss.pytorch.org/t/does-nn-sigmoid-have-bias-parameter/10561/5>
class BiasedSigmoid(nn.Module):

  def __init__(self, device=None):
    super(BiasedSigmoid, self).__init__()
    self.bias = nn.Parameter(2 * torch.rand(1, device=device) - 1)


  def forward(self, input: Tensor) -> Tensor:
    return torch.sigmoid(input + self.bias)

In [None]:
class AttentionistSentimentGraphEdgePredictor(nn.Module):

  def __init__(self):
    super(AttentionistSentimentGraphEdgePredictor, self).__init__()
    self.edge_predictor = BiasedSigmoid()


  def forward(self, ptm_output: BaseModelOutput,
              graphs_nodes: List[SentimentGraphNodeSet]) -> List[SentimentGraphEdgeSet]:

    attentions_batch = ptm_output.attentions
    batch_size = len(graphs_nodes)
    graphs_edge_probs = []
    for i in range(batch_size):
      attentions = attentions_batch[i]

      graph_nodes = graphs_nodes[i]

      sentim_tgt_edges = []
      for sentim_node, tgt_node in itertools.product(graph_nodes.sentiment_nodes,
                                                     graph_nodes.target_nodes):
        sentim_tgt_edges.append(
            attentions[sentim_node.span_in_sentence, tgt_node.span_in_sentence].sum() + \
            attentions[tgt_node.span_in_sentence, sentim_node.span_in_sentence].sum()
        )
      if len(sentim_tgt_edges) == 0:
        sentim_tgt_edges = None
      else:
        sentim_tgt_edges = torch.stack(sentim_tgt_edges).view(len(graph_nodes.sentiment_nodes),
                                                              len(graph_nodes.target_nodes))
        sentim_tgt_edges = self.edge_predictor(sentim_tgt_edges)

      sentim_hld_edges = []
      for sentim_node, hld_node in itertools.product(graph_nodes.sentiment_nodes,
                                                     graph_nodes.holder_nodes):
        sentim_hld_edges.append(
            attentions[sentim_node.span_in_sentence, hld_node.span_in_sentence].sum() + \
            attentions[hld_node.span_in_sentence, sentim_node.span_in_sentence].sum()
        )
      if len(sentim_hld_edges) == 0:
        sentim_hld_edges = None
      else:  
        sentim_hld_edges = torch.stack(sentim_hld_edges).view(len(graph_nodes.sentiment_nodes),
                                                            len(graph_nodes.holder_nodes))
        sentim_hld_edges = self.edge_predictor(sentim_hld_edges)

      graphs_edge_probs.append(SentimentGraphEdgeSet(sentim_tgt_edges, sentim_hld_edges))

    return graphs_edge_probs


  @torch.no_grad()
  def predict(self, output: List[SentimentGraphEdgeSet]) -> List[SentimentGraphEdgeSet]:
    edges = []
    for graph_E in output:
      tgt_edges = graph_E.sentiment_target_edges
      tgt_edges = tgt_edges.round().int() if tgt_edges is not None else None

      hld_edges = graph_E.sentiment_holder_edges
      hld_edges = hld_edges.round().int() if hld_edges is not None else None

      edges.append(SentimentGraphEdgeSet(tgt_edges, hld_edges))

    return edges


## Model Base - RoBERTa | BERT

In [None]:
class ModelBase(nn.Module):

  def __init__(self):
    super(ModelBase, self).__init__()

    if DATASET_NAME in ['opener_en', 'mpqa', 'darmstadt_unis']:
      self.ptm = RobertaModel.from_pretrained('roberta-base', output_attentions=True,
                                              add_pooling_layer=False)
    else:
      self.ptm = BertModel.from_pretrained('setu4993/LaBSE',
                                           output_attentions=True, add_pooling_layer=False)


  def forward(self, ptm_input: Dict[str, Any]) -> BaseModelOutput:
    ptm_output = self.ptm(**ptm_input)

    if not ptm_output.attentions:
      raise ValueError(
          "Attentions should be included in model input. You might have forgotten "
          "to set `output_attentions=True` on constructing PTM or calling it."
      )

    return ptm_output

## Base Pooler

In [None]:
class BasePooler(nn.Module):

  def __init__(self, pool_all_heads_attention=False):
    super(BasePooler, self).__init__()

    self.pool_all_heads_attention = pool_all_heads_attention
    if pool_all_heads_attention:
      initial_attention_weights = torch.zeros(12, 12).float()

      if DATASET_NAME in ['opener_en', 'mpqa', 'darmstadt_unis']:
        initial_attention_weights[7, 6] = 1.
      else:
        initial_attention_weights[10, 8] = 1.

      self.attention_weights = nn.Parameter(initial_attention_weights)


  def forward(self, ptm_output: BaseModelOutput) -> BaseModelOutput:

    ptm_output.last_hidden_state = ptm_output.last_hidden_state[:, 1:-1, :] # Excluding CLS,SEP

    if not self.pool_all_heads_attention:
      if DATASET_NAME in ['opener_en', 'mpqa', 'darmstadt_unis']:
        ptm_output.attentions = ptm_output.attentions[7][:, 6] # Head 7 (pink) of layer 8
      else:
        ptm_output.attentions = ptm_output.attentions[10][:, 8] # Head 9 (green) of layer 11
    else:
      ptm_output.attentions = torch.tensordot(torch.stack(ptm_output.attentions),
                                              self.attention_weights, dims=([0,2],[0,1]))
      # Alternatively:
      # ptm_output.attentions = torch.einsum(
      #     'lbhmn, lh -> bmn', torch.stack(ptm_output.attentions), self.attention_weights
      # )
    
    ptm_output.attentions = ptm_output.attentions[:, 1:-1, 1:-1] # Excluding CLS,SEP

    return ptm_output

## Whole Model

In [None]:
class StructuredSentimentPredictor(nn.Module):

  def __init__(self, pool_all_heads_attention=False):
    super(StructuredSentimentPredictor, self).__init__()
    self.base = ModelBase()
    # self.polarity_predictor = PolarityPredictor()
    self.novelty = nn.ModuleDict({'base_pooler': BasePooler(pool_all_heads_attention),
                                  'edge_predictor': AttentionistSentimentGraphEdgePredictor(),
                                  'node_predictor': SentimentGraphNodeExtractor()})

  @dataclass
  class Output:
    node_extractor_output: Tuple[Tensor, Tensor, Tensor]
    edge_predictor_output: List[SentimentGraphEdgeSet]


  def forward(self, inputs: Dict[str, Any]) -> Output:

    base_output = self.base(inputs["ptm_input"])

    base_output = self.novelty['base_pooler'](base_output)

    node_predictor_out = self.novelty['node_predictor'](base_output)

    graphs_nodes = inputs["graphs_nodes"]

    edge_predictor_out = self.novelty['edge_predictor'](base_output, graphs_nodes)

    return StructuredSentimentPredictor.Output(
        node_predictor_out, edge_predictor_out)


  @torch.no_grad()
  def predict(self, inputs: Dict[str, Any], output: Output) -> List[SentimentGraph]:

    graphs_nodes = self.novelty['node_predictor'].predict(
      output.node_extractor_output, inputs['padding_mask'])

    base_output = self.base(inputs["ptm_input"])
    base_output = self.novelty['base_pooler'](base_output)

    graphs_edges_probs = self.novelty['edge_predictor'](base_output, graphs_nodes)
    graphs_edges = self.novelty['edge_predictor'].predict(graphs_edges_probs)

    return [SentimentGraph(graph_nodes, graph_edges) for graph_nodes,
            graph_edges in zip(graphs_nodes, graphs_edges)]


# Loss functions

## Node Extractor Loss - Cross Entropy

In [None]:
def node_extractor_loss(y_pred: StructuredSentimentPredictor.Output,
                        y: Dict[str,Any], weighted=False) -> torch.double:
  sentim_labels_pred, tgt_labels_pred, hld_labels_pred = y_pred.node_extractor_output
  sentim_labels, tgt_labels, hld_labels = y['seq_labels']

  sentim_labels_loss = F.cross_entropy(
      sentim_labels_pred.permute(0, 2, 1), sentim_labels,
      ignore_index=-1, weight=y["seq_label_weights"]["sentiment"] if weighted else None)
  tgt_labels_loss = F.cross_entropy(
      tgt_labels_pred.permute(0, 2, 1), tgt_labels,
      ignore_index=-1, weight=y["seq_label_weights"]["target"] if weighted else None)
  hld_labels_loss = F.cross_entropy(
      hld_labels_pred.permute(0, 2, 1), hld_labels,
      ignore_index=-1, weight=y["seq_label_weights"]["holder"] if weighted else None)

  return sentim_labels_loss + tgt_labels_loss + hld_labels_loss

## Edge Extractor Loss - Binary Cross Entropy

In [None]:
def edge_extractor_loss(y_pred: StructuredSentimentPredictor.Output,
                        y: Dict[str,Any]) -> torch.double:
  adj_matrices_pred_batch: List[SentimentGraphEdgeSet] = y_pred.edge_predictor_output
  adj_matrices_batch: List[SentimentGraphEdgeSet] = y["graphs_edges"]

  loss = torch.zeros((), device=DEVICE)

  for edges_pred, edges in zip(adj_matrices_pred_batch, adj_matrices_batch):

    target_edges_pred = edges_pred.sentiment_target_edges
    target_edges = edges.sentiment_target_edges
    if target_edges_pred is not None:
      loss += F.binary_cross_entropy(target_edges_pred, target_edges)

    holder_edges_pred = edges_pred.sentiment_holder_edges
    holder_edges = edges.sentiment_holder_edges
    if holder_edges_pred is not None:
      loss += F.binary_cross_entropy(holder_edges_pred, holder_edges)

  return loss

## Total Loss

In [None]:
def loss_fn(y_pred: StructuredSentimentPredictor.Output,
                        y: Dict[str,Any], seq_label_weighted=SEQ_LABEL_WEIGHTED) -> torch.double:
  return node_extractor_loss(y_pred, y, weighted=seq_label_weighted) + \
         edge_extractor_loss(y_pred, y)

# Train & Evaluation

## Data Loaders

In [None]:
tokenizer = get_tokenizer(DATASET_NAME)

In [None]:
def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:

  texts = [item['text'] for item in batch]
  ptm_input = tokenizer(texts, padding=True, return_tensors='pt')
  padding_mask = tokenizer(texts, padding=True, return_tensors='pt',
                           add_special_tokens=False)["attention_mask"]

  sentim_labels = pad_sequence([torch.LongTensor(
      SentExprTokenEnum.from_labels(item["seq_labels"]["sentiment"])
      ) for item in batch], batch_first=True, padding_value=-1)
  target_labels = pad_sequence([torch.LongTensor(
      SentTargetTokenEnum.from_labels(item["seq_labels"]["target"])
      ) for item in batch], batch_first=True, padding_value=-1)
  holder_labels = pad_sequence([torch.LongTensor(
      SentHolderTokenEnum.from_labels(item["seq_labels"]["holder"])
      ) for item in batch], batch_first=True, padding_value=-1)

  seq_labels = (sentim_labels, target_labels, holder_labels)

  graphs_edges = [item["graph"].edges for item in batch]
  graphs_nodes = [item["graph"].nodes for item in batch]

  graphs = [item["graph"] for item in batch]

  return {"seq_labels": seq_labels, "graphs": graphs,
          "seq_label_weights": batch[0]["label_weights"],
          "graphs_nodes": graphs_nodes, "graphs_edges": graphs_edges,
          "ptm_input": ptm_input, "padding_mask": padding_mask}


In [None]:
BATCH_SIZE = 32

In [None]:
train_dataloader = DataLoader(train_ds, collate_fn=collate_fn, batch_size=BATCH_SIZE)
dev_dataloader = DataLoader(dev_ds, collate_fn=collate_fn, batch_size= BATCH_SIZE * 2)

## Model & Optimizer

In [None]:
torch.cuda.empty_cache()

In [None]:
model = StructuredSentimentPredictor(pool_all_heads_attention=False).to(DEVICE)

optimizer_parameter_groups = [
  {'params': list(model.base.parameters())},
  {'params': list(model.novelty.parameters())}
]
optimizer = torch.optim.AdamW(optimizer_parameter_groups)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.decoder.weight', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Engines

In [None]:
def train_prepare_batch(batch: Dict[str, Any], device=DEVICE,
                        non_blocking=True) -> Tuple[Dict[str, Any], Dict[str, Any]]:
  x = {"ptm_input": {
      "attention_mask": batch["ptm_input"]["attention_mask"].to(
          device, non_blocking=non_blocking),
      "input_ids": batch["ptm_input"]["input_ids"].to(
          device, non_blocking=non_blocking)},
       "graphs_nodes": batch["graphs_nodes"]}

  y = {"seq_labels": tuple(tensor.to(device,
                                non_blocking=non_blocking) for tensor in batch["seq_labels"]),
       "graphs_edges": [SentimentGraphEdgeSet(
           item.sentiment_target_edges.to(device, non_blocking=non_blocking),
           item.sentiment_holder_edges.to(device, non_blocking=non_blocking)
           ) for item in batch["graphs_edges"]
           ],
       "padding_mask": batch["padding_mask"].to(device, non_blocking=non_blocking),
       "seq_label_weights": {k: v.to(
           device, non_blocking=non_blocking) for k,v in batch["seq_label_weights"].items()}}

  return x, y

In [None]:
def evaluate_prepare_batch(batch: Dict[str, Any], device='cuda',
                        non_blocking=True) -> Tuple[Dict[str, Any], Dict[str, Any]]:

  padding_mask = batch["padding_mask"].to(device, non_blocking=non_blocking)

  x = {"ptm_input": {
        "attention_mask": batch["ptm_input"]["attention_mask"].to(
          device, non_blocking=non_blocking),
        "input_ids": batch["ptm_input"]["input_ids"].to(
          device, non_blocking=non_blocking)},
       "padding_mask": padding_mask, "graphs_nodes": batch["graphs_nodes"]}

  y = {"seq_labels": tuple(tensor.to(device,
                                non_blocking=non_blocking) for tensor in batch["seq_labels"]),
       "graphs": batch["graphs"],
       "graphs_edges": [SentimentGraphEdgeSet(
           item.sentiment_target_edges.to(device, non_blocking=non_blocking),
           item.sentiment_holder_edges.to(device, non_blocking=non_blocking)
           ) for item in batch["graphs_edges"]
           ], "padding_mask": padding_mask,
       "seq_label_weights": {k: v.to(
           device, non_blocking=non_blocking) for k,v in batch["seq_label_weights"].items()}}
  return x, y

In [None]:
def train_output_transform(x: Dict[str, Any], y: Dict[str, Any],
                              y_pred: StructuredSentimentPredictor.Output, loss: Tensor
) -> Dict[str, Any]:

  return {"y": y, "y_pred_raw": y_pred, "loss": loss}

In [None]:
# Submit issue on Ignite design about this
def evaluate_output_transform(x: Dict[str, Any], y: Dict[str, Any],
                              y_pred: StructuredSentimentPredictor.Output
) -> Dict[str, Any]:

  return {"y_pred": model.predict(x, y_pred), "y": y, "y_pred_raw": y_pred}

In [None]:
trainer = create_supervised_trainer(
    model, optimizer, functools.partial(loss_fn, seq_label_weighted=SEQ_LABEL_WEIGHTED),
    deterministic=True, device=DEVICE,
    prepare_batch=train_prepare_batch, output_transform=train_output_transform
    )

evaluator = create_supervised_evaluator(
    model, prepare_batch=evaluate_prepare_batch,
    device=DEVICE, output_transform=evaluate_output_transform
    )

train_evaluator = create_supervised_evaluator(
    model, prepare_batch=evaluate_prepare_batch,
    device=DEVICE, output_transform=evaluate_output_transform
    )

NameError: ignored

## Metrics

In [None]:
class SeqLabelEntropy(Metric):
  def __init__(self, output_transform=lambda x: x):
    self.token_cnt = None
    self.entropies = None

    super(SeqLabelEntropy, self).__init__(output_transform=output_transform)


  def reset(self):
    self.token_cnt = {'sentiment': 0, 'target': 0, 'holder': 0}
    self.entropies = {'sentiment': 0., 'target': 0., 'holder': 0.}
    super(SeqLabelEntropy, self).reset()


  def update(self, output: Tuple[StructuredSentimentPredictor.Output, Dict[str, Any]]):
    sentim_labels, tgt_labels, hld_labels = output[0].node_extractor_output

    padding_mask = output[1]["padding_mask"].cpu().numpy()

    sentim_labels = sentim_labels.detach().cpu().softmax(dim=2)
    self.entropies['sentiment'] += (scipy.stats.entropy(sentim_labels,
                                                        axis=2) * padding_mask).sum()
    self.token_cnt['sentiment'] += sentim_labels.size()[:2].numel()
  
    tgt_labels = tgt_labels.detach().cpu().softmax(dim=2)
    self.entropies['target'] += (scipy.stats.entropy(tgt_labels, axis=2) * padding_mask).sum()
    self.token_cnt['target'] += tgt_labels.size()[:2].numel()
  
    hld_labels = hld_labels.detach().cpu().softmax(dim=2)
    self.entropies['holder'] += (scipy.stats.entropy(hld_labels, axis=2) * padding_mask).sum()
    self.token_cnt['holder'] += hld_labels.size()[:2].numel()


  def compute(self) -> Dict[str, float]:
    return {(k.capitalize() + " Label Entropy"): v / self.token_cnt[k] for k, v \
            in self.entropies.items()}


In [None]:
class MatchLogicBase:
  def is_match(self, sentim_overlap: int, target_overlap: int, holder_overlap: int,
               is_polarity_matched: bool) -> bool:
    ...


  def weighted_match(self,
                    sentim_node: SentimentGraphNode, target_neighbors: List[SentimentGraphNode],
                    holder_neighbors: List[SentimentGraphNode],
                    sentim_overlap: int, target_overlap: int, holder_overlap: int,
                    ) -> float:
    ...


  def _union_length(self, nodes: List[SentimentGraphNode]) -> int:
    if len(nodes) == 0:
      return 1
    return len(set(itertools.chain(*map(attrgetter('indices'), nodes))))

In [None]:
class GraphMatchLogic(MatchLogicBase):

  def __init__(self, keep_polarity: bool=True):
    self.keep_polarity = keep_polarity


  def is_match(self, sentim_overlap: int, target_overlap: int, holder_overlap: int,
               is_polarity_matched: bool) -> bool:
    return sentim_overlap > 0 and target_overlap > 0 and holder_overlap > 0 and \
              (is_polarity_matched or not self.keep_polarity)


  def weighted_match(self,
                    sentim_node: SentimentGraphNode, target_neighbors: List[SentimentGraphNode],
                    holder_neighbors: List[SentimentGraphNode],
                    sentim_overlap: int, target_overlap: int, holder_overlap: int,
                    ) -> float:
    return (sentim_overlap / float(sentim_node.span_length) + \
            target_overlap / float(self._union_length(target_neighbors)) + \
            holder_overlap / float(self._union_length(holder_neighbors))
            ) / 3.

In [None]:
class SentimentMatchLogic(MatchLogicBase):

  def __init__(self, keep_polarity: bool=True):
    self.keep_polarity = keep_polarity


  def is_match(self, sentim_overlap: int, target_overlap: int, holder_overlap: int,
               is_polarity_matched: bool) -> bool:
    return sentim_overlap > 0 and (is_polarity_matched or not self.keep_polarity)


  def weighted_match(self,
                    sentim_node: SentimentGraphNode, target_neighbors: List[SentimentGraphNode],
                    holder_neighbors: List[SentimentGraphNode],
                    sentim_overlap: int, target_overlap: int, holder_overlap: int,
                    ) -> float:
    return sentim_overlap / float(sentim_node.span_length) 

In [None]:
class TargetMatchLogic(MatchLogicBase):
  def is_match(self, sentim_overlap: int, target_overlap: int, holder_overlap: int,
               is_polarity_matched: bool) -> bool:
    return target_overlap > 0


  def weighted_match(self,
                    sentim_node: SentimentGraphNode, target_neighbors: List[SentimentGraphNode],
                    holder_neighbors: List[SentimentGraphNode],
                    sentim_overlap: int, target_overlap: int, holder_overlap: int,
                    ) -> float:
    return target_overlap / float(self._union_length(target_neighbors)) 

In [None]:
class HolderMatchLogic(MatchLogicBase):
  def is_match(self, sentim_overlap: int, target_overlap: int, holder_overlap: int,
               is_polarity_matched: bool) -> bool:
    return holder_overlap > 0


  def weighted_match(self,
                    sentim_node: SentimentGraphNode, target_neighbors: List[SentimentGraphNode],
                    holder_neighbors: List[SentimentGraphNode],
                    sentim_overlap: int, target_overlap: int, holder_overlap: int,
                    ) -> float:
    return holder_overlap / float(self._union_length(holder_neighbors)) 

In [None]:
class MetricBase(Metric):

  def __init__(self, match_logic: MatchLogicBase, weighted=True, output_transform=lambda x:x):
    self.weighted = weighted
    self.match_logic = match_logic

    super(MetricBase, self).__init__(output_transform=output_transform)


  def extract_sentim_nodes_neighbors(
    self, graph: SentimentGraph) -> List[Tuple[List[SentimentGraphNode], ...]]:
    edges_sentim_tgt = graph.edges.sentiment_target_edges
    if edges_sentim_tgt is not None:
      edges_sentim_tgt = edges_sentim_tgt.detach().cpu()

    edges_sentim_hld = graph.edges.sentiment_holder_edges
    if edges_sentim_hld is not None:
      edges_sentim_hld = edges_sentim_hld.detach().cpu()

    sentim_nodes_neighbors = []
    for i, sentim_node in enumerate(graph.nodes.sentiment_nodes):
      assert sentim_node.tag is not None

      if edges_sentim_tgt is not None:
        target_neighbors = [
          graph.nodes.target_nodes[j] for j in edges_sentim_tgt[i].nonzero()]
      else:
        target_neighbors = []

      if edges_sentim_hld is not None:
        holder_neighbors = [
          graph.nodes.holder_nodes[j] for j in edges_sentim_hld[i].nonzero()]
      else:
        holder_neighbors = []

      sentim_nodes_neighbors.append((target_neighbors, holder_neighbors))

    return sentim_nodes_neighbors


  def node_neighbors_overlap_length(self, nodes1: List[SentimentGraphNode],
                      nodes2: List[SentimentGraphNode]) -> int:
    if len(nodes1) == 0 and len(nodes2) == 0:
      return 1

    return len(set(itertools.chain(*map(attrgetter('indices'), nodes1))).intersection(
        set(itertools.chain(*map(attrgetter('indices'), nodes2)))
    ))


  def update(self, output: Tuple[List[SentimentGraph], List[SentimentGraph]]):

    y_pred, y = output

    for p_graph, g_graph in zip(y_pred, y):

      p_nodes_neighbors = self.extract_sentim_nodes_neighbors(p_graph)
      g_nodes_neighbors = self.extract_sentim_nodes_neighbors(g_graph)

      intersect_matrix = torch.zeros((len(p_graph.nodes.sentiment_nodes),
                                         len(g_graph.nodes.sentiment_nodes)), dtype=int)

      weighted_intersect_matrix_for_p = torch.zeros(len(p_graph.nodes.sentiment_nodes),
                                        len(g_graph.nodes.sentiment_nodes))
      weighted_intersect_matrix_for_g = torch.zeros(len(g_graph.nodes.sentiment_nodes),
                                        len(p_graph.nodes.sentiment_nodes))

      for i, p_sentim_node in enumerate(p_graph.nodes.sentiment_nodes):
        for j, g_sentim_node in enumerate(g_graph.nodes.sentiment_nodes):

          p_target_neighbors, p_holder_neighbors = p_nodes_neighbors[i]
          g_target_neighbors, g_holder_neighbors = g_nodes_neighbors[j]

          sentim_overlap = self.node_neighbors_overlap_length([p_sentim_node], [g_sentim_node])
          target_overlap = self.node_neighbors_overlap_length(p_target_neighbors, g_target_neighbors)
          holder_overlap = self.node_neighbors_overlap_length(p_holder_neighbors, g_holder_neighbors)

          if self.match_logic.is_match(sentim_overlap, target_overlap, holder_overlap,
                                       g_sentim_node.tag == p_sentim_node.tag):
          
            intersect_matrix[i, j] = 1

            if self.weighted:
              weighted_intersect_matrix_for_p[i, j] = \
                self.match_logic.weighted_match(p_sentim_node, p_target_neighbors,
                                                p_holder_neighbors,
                                                sentim_overlap, target_overlap, holder_overlap)

              weighted_intersect_matrix_for_g[j, i] = \
                self.match_logic.weighted_match(g_sentim_node, g_target_neighbors,
                                                g_holder_neighbors,
                                                sentim_overlap, target_overlap, holder_overlap)

      self.update_(intersect_matrix, weighted_intersect_matrix_for_g,
                   weighted_intersect_matrix_for_p)


In [None]:
class Precision(MetricBase):
  def __init__(self, match_logic: MatchLogicBase, weighted=True, output_transform=lambda x:x):
    self.weighted_tp = None
    self.tp = None
    self.fp = None

    super(Precision, self).__init__(match_logic, weighted=weighted,
                                    output_transform=output_transform)


  def reset(self):
      self.weighted_tp = 0.
      self.tp = 0
      self.fp = 0

      super(Precision, self).reset()


  def update_(self, intersect_matrix: Tensor, weighted_intersect_matrix_for_g: Tensor,
               weighted_intersect_matrix_for_p: Tensor):
    self.tp += intersect_matrix.any(dim=1).sum().item()
    self.fp += (~ intersect_matrix.any(dim=1)).sum().item()

    if self.weighted:
      self.weighted_tp += weighted_intersect_matrix_for_p.max(
          dim=1).values.sum().item()


  def compute(self) -> float:
    if self.weighted:
      return self.weighted_tp / (self.tp + self.fp + 1e-10)
    else:
      return self.tp / (self.tp + self.fp + 1e-10)

In [None]:
class Recall(MetricBase):

  def __init__(self, match_logic: MatchLogicBase, weighted=True, output_transform=lambda x:x):
    self.weighted_tp = None
    self.tp = None
    self.fn = None

    super(Recall, self).__init__(match_logic, weighted=weighted,
                                 output_transform=output_transform)


  def reset(self):
    self.weighted_tp = 0.
    self.tp = 0
    self.fn = 0

    super(Recall, self).reset()


  def update_(self, intersect_matrix: Tensor, weighted_intersect_matrix_for_g: Tensor,
               weighted_intersect_matrix_for_p: Tensor):

    self.tp += intersect_matrix.any(dim=0).sum().item()
    self.fn += (~ intersect_matrix.any(dim=0)).sum().item()

    if self.weighted:
      try:
        self.weighted_tp += weighted_intersect_matrix_for_g.max(
            dim=1).values.sum().item()
      except IndexError:
        pass


  def compute(self) -> float:
    if self.weighted:
      return self.weighted_tp / (self.tp + self.fn + 1e-10)
    else:
      return self.tp / (self.tp + self.fn + 1e-10)

In [None]:
def F1(match_logic: MatchLogicBase, weighted=True, output_transform=lambda x: x) -> MetricBase:
  p = Precision(match_logic, weighted=weighted, output_transform=output_transform)
  r = Recall(match_logic, weighted=weighted, output_transform=output_transform)
  return 2 * p * r / (p + r + 1e-10)

In [None]:
@torch.no_grad()
def loss_metric(input):
  y_pred, y = input["y_pred_raw"], input["y"]
  return loss_fn(y_pred, y, seq_label_weighted=SEQ_LABEL_WEIGHTED).item()

In [None]:
@torch.no_grad()
def acc_output_transform(inputs: Dict[str, Any]):
  return inputs['y_pred'], inputs['y']['graphs']

In [None]:
def loss_output_transform(inputs):
  return inputs["loss"]

In [None]:
train_loss = Average(output_transform=loss_output_transform)
eval_loss = Average(output_transform=loss_metric)
running_loss = RunningAverage(output_transform=loss_output_transform)

In [None]:
graph_f1 = F1(GraphMatchLogic(), output_transform=acc_output_transform)

# sentiment_f1 = F1(SentimentMatchLogic(), output_transform=acc_output_transform)

# target_f1 = F1(TargetMatchLogic(), output_transform=acc_output_transform)

# holder_f1 = F1(HolderMatchLogic(), output_transform=acc_output_transform)

In [None]:
label_entropy = SeqLabelEntropy(output_transform=lambda out: (out['y_pred_raw'], out['y']))

In [None]:
train_loss.attach(trainer, 'Loss')

graph_f1.attach(train_evaluator, 'Graph F1')
# sentiment_f1.attach(train_evaluator, 'Sentiment F1')
# target_f1.attach(train_evaluator, 'Target F1')
# holder_f1.attach(train_evaluator, 'Holder F1')


eval_loss.attach(evaluator, 'Loss')

graph_f1.attach(evaluator, 'Graph F1')
# sentiment_f1.attach(evaluator, 'Sentiment F1')
# target_f1.attach(evaluator, 'Target F1')
# holder_f1.attach(evaluator, 'Holder F1')


running_loss.attach(trainer, 'Running Loss')

In [None]:
label_entropy.attach(train_evaluator, 'Label Entropy')
label_entropy.attach(evaluator, 'Label Entropy')

## Safety Measure (Terminate on NaN)

In [None]:
trainer.add_event_handler(Events.ITERATION_COMPLETED,
                          TerminateOnNan(output_transform=itemgetter('loss')))

<ignite.engine.events.RemovableEventHandle at 0x7f7ab1557a90>

## Logging (WandB & Tqdm)

In [None]:
note = 'SemEval2022 Task 10 Subtask 1'
wandb_logger = WandBLogger(entity='sadra-barikbin',
                           project='ABSA',
                           name='Attentionist-Opener_En',
                           tags=['Opener_En', 'Warm-Up', 'StepLR'],
                           notes=note, resume=True)

wandb_logger.attach_opt_params_handler(trainer, event_name=Events.EPOCH_COMPLETED, 
                                       optimizer=optimizer, sync=False)

wandb_logger.attach_output_handler(trainer, event_name=Events.EPOCH_COMPLETED,
                                   tag="training", metric_names=["Loss"], sync=False)

metric_names = ["Graph F1"]#, "Sentiment F1", "Target F1", "Holder F1"]

wandb_logger.attach_output_handler(train_evaluator, event_name=Events.COMPLETED,
                                   global_step_transform=global_step_from_engine(trainer),
                                   tag="training", metric_names=metric_names, sync=False)

wandb_logger.attach_output_handler(evaluator, event_name=Events.COMPLETED,
                                   global_step_transform=global_step_from_engine(trainer),
                                   tag="evaluation", metric_names=["Loss"] + metric_names,
                                   sync=False)

<ignite.engine.events.RemovableEventHandle at 0x7f7ab17c3c90>

In [None]:
@trainer.on(Events.COMPLETED)
def close_logger():
  wandb_logger.close()

In [None]:
# Ignite bug? check if could be duplicately attached
pbar = ProgressBar()
pbar.attach(trainer,metric_names=['Running Loss'])

## LR Scheduling

### StepLR on model.novelty

In [None]:
from torch.optim.optimizer import Optimizer
from ignite.handlers.param_scheduler import ParamScheduler
from typing import Optional, Union

In [None]:
class StepParamScheduler(ParamScheduler):

  def __init__(self, optimizer: Optimizer, param_name: str, gamma: float = 0.1 ,
        step_size: int = 1, save_history: bool = False,
        param_group_index: Optional[int] = None):
    super(StepParamScheduler, self).__init__(optimizer, param_name, save_history = save_history,
                                          param_group_index = param_group_index)
    if step_size <= 0:
        raise ValueError(
                f"Argument step_size should be greater than zero, but given {step_size}"
        )
    self.gamma = gamma
    self.step_size = step_size
    self.current_step = step_size - 1

    self._state_attrs += ['gamma', 'step_size', 'current_step']

  def get_param(self) -> Union[List[float], float]:
    gamma = 1 if self.current_step != 0 else self.gamma

    self.current_step -= 1
    if self.current_step == -1:
        self.current_step += self.step_size

    if len(self.optimizer_param_groups) == 1:
      return self.optimizer_param_groups[0][self.param_name] * gamma
    else:
      return [pg[self.param_name] * gamma for pg in self.optimizer_param_groups]

In [None]:
lr_scheduler = StepParamScheduler(optimizer, 'lr', gamma=0.1, step_size=9,
                                  save_history=True, param_group_index=1)
scheduler_handler = trainer.add_event_handler(
    Events.EPOCH_COMPLETED, lr_scheduler)

wandb_logger.config['LR_scheduler'] = 'StepLR(gamma=0.5)'

### Just Warm-Up

In [None]:
base_scheduler = PiecewiseLinear(optimizer, "lr",
                                 milestones_values=[(1, 1e-6), (len(train_ds) // BATCH_SIZE, 1e-4)],
                                 param_group_index=0)

novelty_scheduler = PiecewiseLinear(optimizer, "lr",
                                 milestones_values=[(1, 1e-5), (len(train_ds) // BATCH_SIZE, 1e-3)],
                                 param_group_index=1)
event_filter = lambda _, event: event < (len(train_ds) // BATCH_SIZE) + 1
scheduler1_handler = trainer.add_event_handler(Events.ITERATION_STARTED(event_filter),
                                               base_scheduler)
scheduler2_handler = trainer.add_event_handler(Events.ITERATION_STARTED(event_filter),
                                               novelty_scheduler)

wandb_logger.config['LR_scheduler'] = 'Linear Warm-Up + StepLR(gamma=0.5)'

## Early Stopping

In [None]:
stopper = EarlyStopping(patience=6, score_function=lambda engine: engine.state.metrics['Graph F1'],
                        trainer=trainer)
stopper_handle = evaluator.add_event_handler(Events.COMPLETED, stopper)

## Checkpointing

In [None]:
# /tmp/run
checkpointer = ModelCheckpoint(wandb_logger.run.dir, 'Attentionist_Opener_En',
                               score_name='Graph F1', n_saved=2, require_empty=False)
# How can I do it every 2 evaluator epochs ?
evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model})

<ignite.engine.events.RemovableEventHandle at 0x7f7ab1884610>

## Run

In [None]:
@trainer.on(Events.EPOCH_COMPLETED)
def evaluate():
  evaluator.run(dev_dataloader)
  train_evaluator.run(train_dataloader)

In [None]:
wandb_logger.config['random_seed'] = 41
trainer.run(train_dataloader,max_epochs=30 )

# Inference

## Load Best Model

In [None]:
!wandb login --cloud <API_KEY>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
dataset_names_map = {'opener_en': 'Opener_En', 'opener_es': 'Opener_Es', 'norec': 'Norec',
                     'darmstadt_unis': 'Darmstadt_unis', 'multibooked_eu': 'Multibooked_eu',
                     'multibooked_ca': 'Multibooked_ca', 'mpqa': 'MPQA'}

In [None]:
api = wandb.Api()
runs = api.runs('sadra-barikbin/ABSA', order='summary_metrics.evaluation/SF1',
                filters={'tags': dataset_names_map[DATASET_NAME]})
model_path = list(list(runs)[0].files())[1].download()
state_dict = torch.load(model_path.name)

In [None]:
list(list(runs)[0].files())

[<File Attentionist_Darmstadt_unis_model_SF1=0.5069.pt (application/vnd.snesdev-page-table) 473.9MiB>,
 <File Attentionist_Darmstadt_unis_model_SF1=0.5193.pt (application/vnd.snesdev-page-table) 473.9MiB>,
 <File config.yaml () 514.0B>,
 <File output.log (text/plain; charset=utf-8) 104.0B>,
 <File requirements.txt (text/plain; charset=utf-8) 7.1KiB>,
 <File wandb-metadata.json (application/json) 676.0B>,
 <File wandb-summary.json (application/json) 515.0B>]

In [None]:
model = StructuredSentimentPredictor()
model.load_state_dict(state_dict)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.decoder.weight', 'roberta.pooler.dense.bias', 'lm_head.layer_norm.weight', 'roberta.pooler.dense.weight', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>

In [None]:
tokenizer = get_tokenizer(DATASET_NAME)

## Dataset, DataLoader and Engine

In [None]:
def test_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:

  texts = [item['text'] for item in batch]
  ptm_input = tokenizer(texts, padding=True, return_tensors='pt')
  ptm_input['input_ids'] = ptm_input['input_ids'].to(DEVICE)
  ptm_input['attention_mask'] = ptm_input['attention_mask'].to(DEVICE)
  if 'token_type_ids' in ptm_input:
    ptm_input['token_type_ids'] = ptm_input['token_type_ids'].to(DEVICE)
  padding_mask = tokenizer(texts, padding=True, return_tensors='pt',
                           add_special_tokens=False)["attention_mask"].to(DEVICE)

  return {"pred_template": batch, "ptm_input": ptm_input, "padding_mask": padding_mask}


In [None]:
test_ds = SemEval2022Task10Dataset(DATASET_NAME, 'test')
test_dataloader = DataLoader(test_ds, batch_size=BATCH_SIZE * 2, collate_fn=test_collate_fn)

In [None]:
@torch.no_grad()
def test_engine_process_function(engine: Engine, batch: Dict[str, Any]) -> List[Dict[str, Any]]:

  model.eval()

  base_output = model.base(batch["ptm_input"])
  base_output = model.novelty['base_pooler'](base_output)
  node_predictor_out = model.novelty['node_predictor'](base_output)
  output: List[SentimentGraph] = model.predict(
      batch, StructuredSentimentPredictor.Output(node_predictor_out, None)
  )

  pred_template = batch['pred_template']

  tokens_char_offsets_batch = tokenizer([item['text'] for item in pred_template],
                                  add_special_tokens=False, return_offsets_mapping=True,
                                  return_length=True)['offset_mapping']
  
  for sent_idx in range(len(output)):
    sentim_nodes = output[sent_idx].nodes.sentiment_nodes
    tgt_nodes = output[sent_idx].nodes.target_nodes
    hld_nodes = output[sent_idx].nodes.holder_nodes

    edges = output[sent_idx].edges

    text = pred_template[sent_idx]['text']

    for i in range(len(sentim_nodes)):
      opinion = {}
      
      node_tokens_char_offsets = tokens_char_offsets_batch[
                                                      sent_idx][sentim_nodes[i].span_in_sentence]
      node_char_idx_begin = node_tokens_char_offsets[0][0]
      node_char_idx_end = node_tokens_char_offsets[-1][1]

      opinion['Polar_expression'] = [[text[node_char_idx_begin:node_char_idx_end]],
                                     [f"{node_char_idx_begin}:{node_char_idx_end}"]]
      
      tgt_neighbors_exprs = []
      tgt_neighbors_char_spans = []
      if edges.sentiment_target_edges is not None:
        for tgt_node_idx in edges.sentiment_target_edges[i].nonzero():
          neighbor_tgt_node = tgt_nodes[tgt_node_idx]

          node_tokens_char_offsets = tokens_char_offsets_batch[
                                                      sent_idx][neighbor_tgt_node.span_in_sentence]

          node_char_idx_begin = node_tokens_char_offsets[0][0]
          node_char_idx_end = node_tokens_char_offsets[-1][1]

          tgt_neighbors_exprs.append(text[node_char_idx_begin:node_char_idx_end])
          tgt_neighbors_char_spans.append(f"{node_char_idx_begin}:{node_char_idx_end}")
      opinion['Target'] = [tgt_neighbors_exprs, tgt_neighbors_char_spans]

      hld_neighbors_exprs = []
      hld_neighbors_char_spans = []
      if edges.sentiment_holder_edges is not None:
        for hld_node_idx in edges.sentiment_holder_edges[i].nonzero():
          neighbor_hld_node = hld_nodes[hld_node_idx]

          node_tokens_char_offsets = tokens_char_offsets_batch[
                                                      sent_idx][neighbor_hld_node.span_in_sentence]

          node_char_idx_begin = node_tokens_char_offsets[0][0]
          node_char_idx_end = node_tokens_char_offsets[-1][1]

          hld_neighbors_exprs.append(text[node_char_idx_begin:node_char_idx_end])
          hld_neighbors_char_spans.append(f"{node_char_idx_begin}:{node_char_idx_end}")
      opinion['Source'] = [hld_neighbors_exprs, hld_neighbors_char_spans]

      opinion['Polarity'] = sentim_nodes[i].tag

      pred_template[sent_idx]['opinions'].append(opinion)

  return pred_template


In [None]:
tester = Engine(test_engine_process_function)

In [None]:
output_store = EpochOutputStore()
output_store.attach(tester, 'outputs')

In [None]:
pbar = ProgressBar()
pbar.attach(tester)

In [None]:
tester.run(test_dataloader)

[1/39]   3%|2          [00:00<?]

State:
	iteration: 39
	epoch: 1
	epoch_length: 39
	max_epochs: 1
	output: <class 'list'>
	batch: <class 'dict'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>
	outputs: <class 'list'>

## Save Predictions

In [None]:
path = f"monolingual/{DATASET_NAME}"
os.makedirs(path)

with open(f"{path}/predictions.json", 'w') as f:
  json.dump(list(itertools.chain(*tester.state.outputs)), f)

In [None]:
!zip -r monolingual.zip monolingual

  adding: monolingual/ (stored 0%)
  adding: monolingual/opener_es/ (stored 0%)
  adding: monolingual/opener_es/predictions.json (deflated 83%)
  adding: monolingual/norec/ (stored 0%)
  adding: monolingual/norec/predictions.json (deflated 80%)
  adding: monolingual/multibooked_eu/ (stored 0%)
  adding: monolingual/multibooked_eu/predictions.json (deflated 84%)
  adding: monolingual/multibooked_ca/ (stored 0%)
  adding: monolingual/multibooked_ca/predictions.json (deflated 83%)


## Out of Dataset Example

In [None]:
@torch.no_grad()
def predict(text: Union[List[str], str]) -> SentimentGraph:

  if type(text) == str:
    text = [text]
  ptm_input = tokenizer(text, padding=True, return_tensors='pt')
  ptm_input["input_ids"] = ptm_input["input_ids"].to(DEVICE)
  ptm_input["attention_mask"] = ptm_input["attention_mask"].to(DEVICE)
  padding_mask = tokenizer(text, padding=True, return_tensors='pt',
                          add_special_tokens=False)["attention_mask"].to(DEVICE)
  inputs = {'ptm_input': ptm_input, 'padding_mask': padding_mask}
  model.eval()
  base_output = model.base(inputs["ptm_input"])
  node_predictor_out = model.novelty['node_predictor'](base_output)
  return model.predict(inputs, StructuredSentimentPredictor.Output(node_predictor_out, None))[0]

In [None]:
predict("I love this book!")

# Sweep

In [None]:
dataset_names_map = {'opener_en': 'Opener_En', 'opener_es': 'Opener_Es', 'norec': 'Norec',
                     'darmstadt_unis': 'Darmstadt_unis', 'multibooked_eu': 'Multibooked_eu',
                     'multibooked_ca': 'Multibooked_ca', 'mpqa': 'MPQA'}

In [None]:
for random_seed in [41, 666, 1567, 4447]:
  torch.manual_seed(random_seed)

  for dataset_name in ['opener_en', 'opener_es', 'mpqa', 'darmstadt_unis', 'norec',
                       'multibooked_ca', 'multibooked_eu']:
    if random_seed in [41, 666, 1567]:
      break
    if (dataset_name in ['opener_en', 'opener_es', 'mpqa',
                         'darmstadt_unis', 'norec', 'multibooked_ca']) and random_seed == 4447:
      continue

    DATASET_NAME = dataset_name

    tokenizer = get_tokenizer(DATASET_NAME)

    train_ds = SemEval2022Task10Dataset(DATASET_NAME, 'train')
    dev_ds = SemEval2022Task10Dataset(DATASET_NAME, 'dev')

    if dataset_name in ['opener_es', 'multibooked_ca', 'multibooked_eu']:
      BATCH_SIZE = 16
    else:
      BATCH_SIZE = 32

    train_dataloader = DataLoader(train_ds, collate_fn=collate_fn, batch_size=BATCH_SIZE)
    dev_dataloader = DataLoader(dev_ds, collate_fn=collate_fn, batch_size= BATCH_SIZE * 2)

    torch.cuda.empty_cache()
    model = StructuredSentimentPredictor(pool_all_heads_attention=False).to(DEVICE)

    optimizer_parameter_groups = [
      {'params': list(model.base.parameters())},
      {'params': list(model.novelty.parameters())}
    ]
    optimizer = torch.optim.AdamW(optimizer_parameter_groups)

    trainer = create_supervised_trainer(
        model, optimizer, functools.partial(loss_fn, seq_label_weighted=SEQ_LABEL_WEIGHTED),
        deterministic=True, device=DEVICE,
        prepare_batch=train_prepare_batch, output_transform=train_output_transform
        )

    evaluator = create_supervised_evaluator(
        model, prepare_batch=evaluate_prepare_batch,
        device=DEVICE, output_transform=evaluate_output_transform
        )

    train_evaluator = create_supervised_evaluator(
        model, prepare_batch=evaluate_prepare_batch,
        device=DEVICE, output_transform=evaluate_output_transform
        )
    
    train_loss = Average(output_transform=loss_output_transform)
    eval_loss = Average(output_transform=loss_metric)
    running_loss = RunningAverage(output_transform=loss_output_transform)
    graph_f1 = F1(GraphMatchLogic(), output_transform=acc_output_transform)

    train_loss.attach(trainer, 'Loss')
    graph_f1.attach(train_evaluator, 'Graph F1')
    eval_loss.attach(evaluator, 'Loss')
    graph_f1.attach(evaluator, 'Graph F1')
    running_loss.attach(trainer, 'Running Loss')

    trainer.add_event_handler(Events.ITERATION_COMPLETED,
                              TerminateOnNan(output_transform=itemgetter('loss')))
    
    wandb_logger = WandBLogger(entity='sadra-barikbin',
                              project='ABSA', group=dataset_names_map[DATASET_NAME],
                              name=f'Attentionist-{dataset_names_map[DATASET_NAME]}',
                              tags=[dataset_names_map[DATASET_NAME], 'Warm-Up', 'StepLR'],
                              notes='SemEval2022 Task 10 Subtask 1')

    wandb_logger.attach_opt_params_handler(trainer, event_name=Events.EPOCH_COMPLETED, 
                                          optimizer=optimizer, sync=True)

    wandb_logger.attach_output_handler(trainer, event_name=Events.EPOCH_COMPLETED,
                                      tag="training", metric_names=["Loss"], sync=True)

    wandb_logger.config['random_seed'] = random_seed
    wandb_logger.config['batch size'] = BATCH_SIZE

    metric_names = ["Graph F1"]

    wandb_logger.attach_output_handler(train_evaluator, event_name=Events.COMPLETED,
                                      global_step_transform=global_step_from_engine(trainer),
                                      tag="training", metric_names=metric_names, sync=True)

    wandb_logger.attach_output_handler(evaluator, event_name=Events.COMPLETED,
                                      global_step_transform=global_step_from_engine(trainer),
                                      tag="evaluation", metric_names=["Loss"] + metric_names,
                                      sync=True)
    @trainer.on(Events.COMPLETED)
    def close_logger():
      wandb_logger.close()
    
    pbar = ProgressBar()
    pbar.attach(trainer,metric_names=['Running Loss'])

    lr_scheduler = StepParamScheduler(optimizer, 'lr', gamma=0.1, step_size=9,
                                      save_history=True, param_group_index=1)
    scheduler_handler = trainer.add_event_handler(
        Events.EPOCH_COMPLETED, lr_scheduler)

    wandb_logger.config['LR_scheduler'] = 'StepLR(gamma=0.5)'

    base_scheduler = PiecewiseLinear(optimizer, "lr",
                                    milestones_values=[(1, 1e-6), (len(train_ds) // BATCH_SIZE, 1e-4)],
                                    param_group_index=0)

    novelty_scheduler = PiecewiseLinear(optimizer, "lr",
                                    milestones_values=[(1, 1e-5), (len(train_ds) // BATCH_SIZE, 1e-3)],
                                    param_group_index=1)
    event_filter = lambda _, event: event < (len(train_ds) // BATCH_SIZE) + 1
    scheduler1_handler = trainer.add_event_handler(Events.ITERATION_STARTED(event_filter),
                                                  base_scheduler)
    scheduler2_handler = trainer.add_event_handler(Events.ITERATION_STARTED(event_filter),
                                                  novelty_scheduler)

    wandb_logger.config['LR_scheduler'] = 'Linear Warm-Up + StepLR(gamma=0.5)'

    stopper = EarlyStopping(patience=6, score_function=lambda engine: engine.state.metrics['Graph F1'],
                            trainer=trainer)
    stopper_handle = evaluator.add_event_handler(Events.COMPLETED, stopper)

    checkpointer = ModelCheckpoint(wandb_logger.run.dir,
                                   f'Attentionist_{dataset_names_map[DATASET_NAME]}',
                                   score_name='Graph F1', n_saved=2, require_empty=False)
    evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model})

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate():
      evaluator.run(dev_dataloader)
      train_evaluator.run(train_dataloader)

    trainer.run(train_dataloader,max_epochs=20)

Some weights of the model checkpoint at setu4993/LaBSE were not used when initializing BertModel: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
[34m[1mwandb[0m: Currently logged in as: [33msadra-barikbin[0m (use `wandb login --relogin` to force relogin)


[1/57]   2%|1          [00:00<?]

[1/57]   2%|1          [00:00<?]

[1/57]   2%|1          [00:00<?]

[1/57]   2%|1          [00:00<?]

[1/57]   2%|1          [00:00<?]

[1/57]   2%|1          [00:00<?]

[1/57]   2%|1          [00:00<?]

[1/57]   2%|1          [00:00<?]

[1/57]   2%|1          [00:00<?]

[1/57]   2%|1          [00:00<?]

[1/57]   2%|1          [00:00<?]

2022-03-01 13:36:16,279 ignite.handlers.early_stopping.EarlyStopping INFO: EarlyStopping: Stop training


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
evaluation/Graph F1,▁▆▇▇███████
evaluation/Loss,█▃▃▂▂▃▃▂▁▂▂
lr/group_0,▁▁▁▁▁▁▁▁▁▁▁
lr/group_1,█████████▁▁
training/Graph F1,▁▅▆▆▇▇▇▇█▇█
training/Loss,█▅▄▃▂▂▂▂▁▁▁

0,1
evaluation/Graph F1,0.61943
evaluation/Loss,7.27365
lr/group_0,0.0001
lr/group_1,0.0001
training/Graph F1,0.80496
training/Loss,2.82558
