In [0]:
# This notebook uses AllenNLP 0.9-era APIs (pytorch_pretrained_bert, PretrainedBertModel,
# the "bert_adam" optimizer, "iterator" configs) that later AllenNLP releases removed, so pin it.
%pip install allennlp==0.9.0

In [0]:
import logging
from typing import Dict, Union
import tempfile

from allennlp.data.dataset_readers import SnliReader
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import Field
from allennlp.data.fields import LabelField
from allennlp.data.fields import TextField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import TokenIndexer
from allennlp.data.tokenizers import Tokenizer
from allennlp.data.vocabulary import Vocabulary
from allennlp.models.model import Model
from allennlp.modules.token_embedders.bert_token_embedder import PretrainedBertModel
from allennlp.nn.initializers import InitializerApplicator
from allennlp.training.metrics import CategoricalAccuracy
from allennlp.common import Params
from allennlp.commands.train import train_model
from overrides import overrides

import torch
from pytorch_pretrained_bert.modeling import BertModel

# Prerequisites

## Downloading data

In [0]:
# Download and extract SNLI only if it is not already present, so re-running this cell
# neither re-downloads the archive nor blocks on unzip's interactive overwrite prompt.
![ -d snli_1.0 ] || (wget -O snli_1.0.zip https://nlp.stanford.edu/projects/snli/snli_1.0.zip && unzip -q snli_1.0.zip && rm snli_1.0.zip)

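## Implementing SoftTripleLoss

SoftTriple loss (Qian et al., "SoftTriple Loss: Deep Metric Learning Without Triplet Sampling", ICCV 2019) represents every class by `K` learnable centers. It scores a sample against a class by a softmax-weighted similarity to that class's centers, subtracts a margin for the true class, and adds a regularizer that pulls centers of the same class together.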

In [0]:
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn import init

class SoftTriple(nn.Module):
    """
    SoftTriple loss: every one of ``cN`` classes is represented by ``K`` learnable
    centers, and a sample's similarity to a class is a softmax-weighted average of its
    similarities to that class's centers.

    Parameters
    ----------
    la : float
        Scale applied to the (margin-adjusted) class similarities before cross-entropy.
    gamma : float
        Temperature of the softmax over a class's ``K`` centers (stored as ``1/gamma``).
    tau : float
        Weight of the regularizer over distances between centers of the same class.
        Only used when ``tau > 0`` and ``K > 1``.
    margin : float
        Subtracted from the similarity of each sample's target class.
    dim : int
        Size of the input embeddings.
    cN : int
        Number of classes.
    K : int
        Number of centers per class.
    """
    def __init__(self, la, gamma, tau, margin, dim, cN, K):
        super(SoftTriple, self).__init__()
        self.la = la
        self.gamma = 1. / gamma
        self.tau = tau
        self.margin = margin
        self.cN = cN
        self.K = K
        # Centers are the columns of ``fc``: column i*K + j is the j-th center of class i.
        self.fc = Parameter(torch.Tensor(dim, cN * K))
        # Boolean mask over (center, center) pairs selecting each unordered pair of distinct
        # centers of the same class exactly once (strict upper triangle of every K x K
        # diagonal block). It is built on the CPU and moved lazily to the device of ``fc`` in
        # ``forward``. It is intentionally not a registered buffer, so state_dict keys stay
        # unchanged.
        self.weight = torch.zeros(cN * K, cN * K, dtype=torch.bool)
        for i in range(cN):
            for j in range(K):
                row = i * K + j
                self.weight[row, row + 1:(i + 1) * K] = True
        init.kaiming_uniform_(self.fc, a=math.sqrt(5))

    def forward(self, input, target):
        """
        Compute the SoftTriple loss.

        Parameters
        ----------
        input : torch.Tensor
            Embeddings of shape ``(batch, dim)``. SoftTriple is normally applied to
            L2-normalized embeddings; this module does not normalize ``input`` itself.
        target : torch.LongTensor
            Class indices of shape ``(batch,)`` with values in ``[0, cN)``.

        Returns
        -------
        A scalar loss tensor: the margin cross-entropy, plus ``tau`` times the
        same-class center regularizer when ``tau > 0`` and ``K > 1``.
        """
        centers = F.normalize(self.fc, p=2, dim=0)
        simInd = input.matmul(centers)                   # (batch, cN*K)
        simStruc = simInd.reshape(-1, self.cN, self.K)   # (batch, cN, K)
        prob = F.softmax(simStruc * self.gamma, dim=2)
        simClass = torch.sum(prob * simStruc, dim=2)     # (batch, cN)
        # Margin only on each sample's target class. zeros_like keeps simClass's device and
        # dtype, so this also works on CPU and with half precision.
        marginM = torch.zeros_like(simClass)
        rows = torch.arange(marginM.shape[0], device=marginM.device)
        marginM[rows, target] = self.margin
        lossClassify = F.cross_entropy(self.la * (simClass - marginM), target)
        if self.tau > 0 and self.K > 1:
            if self.weight.device != centers.device:
                # Move the mask once; later calls reuse the moved copy.
                self.weight = self.weight.to(centers.device)
            simCenter = centers.t().matmul(centers)
            # sqrt(2 - 2*cos) is the Euclidean distance between unit-norm centers; the small
            # epsilon keeps sqrt differentiable when two centers coincide.
            reg = torch.sum(torch.sqrt(2.0 + 1e-5 - 2. * simCenter[self.weight])) / (self.cN * self.K * (self.K - 1.))
            return lossClassify + self.tau * reg
        return lossClassify

## Implementing BertSnliReader

In [0]:
@DatasetReader.register("bert_snli")
class BertSnliReader(SnliReader):
    """
    Reads a file from the Stanford Natural Language Inference (SNLI) dataset and packs each
    premise/hypothesis pair into a single BERT-style token sequence.

    The data is formatted as jsonl, one json-formatted instance per line. The keys in the
    data are "gold_label", "sentence1", and "sentence2". Each instance gets a "tokens" field
    holding the premise followed by the hypothesis and, when a label is given, a "label"
    field.

    Parameters
    ----------
    tokenizer : ``Tokenizer``, optional
        Used for both the premise and the hypothesis; when omitted, the parent
        ``SnliReader``'s default is used. To get ``[CLS] premise [SEP] hypothesis [SEP]``
        the tokenizer must add the start/end tokens itself (e.g. ``pretrained_transformer``).
    token_indexers : ``Dict[str, TokenIndexer]``, optional
        Used for the combined "tokens" field; defaults to the parent's default.
    lazy : ``bool``, optional (default=``False``)
        Passed through to the parent reader.
    """
    def __init__(
        self,
        tokenizer: Tokenizer = None,
        token_indexers: Dict[str, TokenIndexer] = None,
        lazy: bool = False,
    ) -> None:
        super(BertSnliReader, self).__init__(tokenizer, token_indexers, lazy)

    def _num_start_tokens(self) -> int:
        """Number of leading start tokens (e.g. ``[CLS]``) the tokenizer adds to each text."""
        start_tokens = getattr(self._tokenizer, "_start_tokens", None)
        if start_tokens is None:
            # Tokenizer does not expose its start tokens: keep the original assumption of one.
            return 1
        return len(start_tokens)

    @overrides
    def text_to_instance(
        self,  # type: ignore
        premise: str,
        hypothesis: str,
        label: str = None,
    ) -> Instance:
        """
        Build an ``Instance`` from one premise/hypothesis pair (and optional gold label).

        The hypothesis's start tokens are dropped before concatenation. With a BERT
        tokenizer this yields ``[CLS] premise [SEP] hypothesis [SEP]``; with a tokenizer that
        adds no start tokens, no hypothesis word is lost.
        """
        fields: Dict[str, Field] = {}
        premise_tokens = self._tokenizer.tokenize(premise)
        hypothesis_tokens = self._tokenizer.tokenize(hypothesis)
        tokens = premise_tokens + hypothesis_tokens[self._num_start_tokens():]
        fields["tokens"] = TextField(tokens, self._token_indexers)
        if label:
            fields["label"] = LabelField(label)

        return Instance(fields)

## Implementing BERT with SoftTripleLoss




In [0]:
@Model.register("bert_for_classification_softtripleloss")
class BertForClassificationSortTripleLoss(Model):
    """
    An AllenNLP Model that runs pretrained BERT, takes the pooled output, adds a Linear
    layer on top, and trains it with the SoftTriple loss instead of plain cross-entropy.

    Note that this is a somewhat non-AllenNLP-ish model architecture,
    in that it essentially requires you to use the "bert-pretrained"
    token indexer, rather than configuring whatever indexing scheme you like.
    See `allennlp/tests/fixtures/bert/bert_for_classification.jsonnet`
    for an example of what your config might look like.

    Parameters
    ----------
    vocab : ``Vocabulary``
    bert_model : ``Union[str, BertModel]``
        The BERT model to be wrapped. If a string is provided, it is loaded with
        ``PretrainedBertModel.load(bert_model)``.
    dropout : ``float``, optional (default: 0.0)
        Dropout probability applied to BERT's pooled output before the linear layer.
    num_labels : ``int``, optional (default: None)
        How many output classes to predict. If not provided, we'll use the
        vocab_size for the ``label_namespace``.
    index : ``str``, optional (default: "bert")
        The index of the token indexer that generates the BERT indices.
    label_namespace : ``str``, optional (default : "labels")
        Used to determine the number of classes if ``num_labels`` is not supplied,
        and to map predicted indices back to label strings in ``decode``.
    trainable : ``bool``, optional (default : True)
        If True, the weights of the pretrained BERT model will be updated during training.
        Otherwise, they will be frozen and only the layers on top will be trained.
    initializer : ``InitializerApplicator``, optional
        If provided, will be used to initialize the final linear layer *only*.
    loss_la, loss_gamma, loss_tau, loss_margin : ``float``, optional
        ``la``, ``gamma``, ``tau`` and ``margin`` of the ``SoftTriple`` loss
        (defaults 20.0, 0.1, 0.2 and 0.01).
    loss_centers_per_class : ``int``, optional (default : 10)
        Number of SoftTriple centers per class (``K``).
    """
    def __init__(self,
                 vocab: Vocabulary,
                 bert_model: Union[str, BertModel],
                 dropout: float = 0.0,
                 num_labels: int = None,
                 index: str = "bert",
                 label_namespace: str = "labels",
                 trainable: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 loss_la: float = 20.0,
                 loss_gamma: float = 0.1,
                 loss_tau: float = 0.2,
                 loss_margin: float = 0.01,
                 loss_centers_per_class: int = 10) -> None:
        super().__init__(vocab)

        if isinstance(bert_model, str):
            self.bert_model = PretrainedBertModel.load(bert_model)
        else:
            self.bert_model = bert_model

        # Assigning ``requires_grad`` on the Module only creates a plain attribute; the
        # flag has to be set on every parameter to actually freeze (or unfreeze) BERT.
        for param in self.bert_model.parameters():
            param.requires_grad = trainable

        in_features = self.bert_model.config.hidden_size

        if num_labels:
            out_features = num_labels
        else:
            out_features = vocab.get_vocab_size(label_namespace)

        self._dropout = torch.nn.Dropout(p=dropout)

        self._classification_layer = torch.nn.Linear(in_features, out_features)

        self._accuracy = CategoricalAccuracy()

        # SoftTriple is applied to the classification layer's output, so both its embedding
        # size and its number of classes must equal ``out_features`` (previously dim=3 and
        # cN=98 were hard-coded).
        self._loss = SoftTriple(la=loss_la,
                                gamma=loss_gamma,
                                tau=loss_tau,
                                margin=loss_margin,
                                dim=out_features,
                                cN=out_features,
                                K=loss_centers_per_class)

        self._index = index
        self._label_namespace = label_namespace
        initializer(self._classification_layer)

    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            From a ``TextField`` (that has a bert-pretrained token indexer)
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            unnormalized log probabilities of the label.
        probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            probabilities of the label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        input_ids = tokens[self._index]
        token_type_ids = tokens[f"{self._index}-type-ids"]
        # Id 0 is treated as padding when building the attention mask.
        input_mask = (input_ids != 0).long()

        _, pooled = self.bert_model(input_ids=input_ids,
                                    token_type_ids=token_type_ids,
                                    attention_mask=input_mask)

        pooled = self._dropout(pooled)

        # apply classification layer
        logits = self._classification_layer(pooled)

        probs = torch.nn.functional.softmax(logits, dim=-1)

        output_dict = {"logits": logits, "probs": probs}

        if label is not None:
            # NOTE(review): ``logits`` are used both as the SoftTriple embedding (for the loss)
            # and as class scores (for ``probs`` and accuracy). SoftTriple scores classes by
            # similarity to its own centers, not by argmax(logits) — confirm this is intended.
            loss = self._loss(logits, label.long().view(-1))
            output_dict["loss"] = loss
            self._accuracy(logits, label)

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does a simple argmax over the probabilities, converts index to string label, and
        adds a ``"label"`` key to the dictionary with the result.
        """
        predictions = output_dict["probs"]
        if predictions.dim() == 2:
            predictions_list = [predictions[i] for i in range(predictions.shape[0])]
        else:
            predictions_list = [predictions]
        classes = []
        for prediction in predictions_list:
            label_idx = prediction.argmax(dim=-1).item()
            label_str = self.vocab.get_token_from_index(label_idx, namespace=self._label_namespace)
            classes.append(label_str)
        output_dict["label"] = classes
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running categorical accuracy (reset afterwards if ``reset``)."""
        metrics = {'accuracy': self._accuracy.get_metric(reset)}
        return metrics

# Putting it all together

## Setting params

In [0]:
bert_model = "bert-base-uncased"
bert_stl_model = "bert_for_classification_softtripleloss"

# Complete experiment configuration, using the same nested layout as an AllenNLP
# JSON/Jsonnet config file (reader, data paths, model, iterator, trainer).
all_params = Params(dict(
    dataset_reader=dict(
        lazy=False,
        type="bert_snli",
        tokenizer=dict(
            type="pretrained_transformer",
            model_name=bert_model,
            do_lowercase=True,
        ),
        token_indexers=dict(
            bert=dict(
                type="bert-pretrained",
                pretrained_model=bert_model,
            ),
        ),
    ),
    train_data_path="snli_1.0/snli_1.0_train.jsonl",
    validation_data_path="snli_1.0/snli_1.0_dev.jsonl",
    model=dict(
        type=bert_stl_model,
        bert_model=bert_model,
        dropout=0.1,
        num_labels=3,
    ),
    iterator=dict(
        type="bucket",
        sorting_keys=[["tokens", "num_tokens"]],
        batch_size=32,
    ),
    trainer=dict(
        optimizer=dict(
            type="bert_adam",
            lr=2e-5,
        ),
        validation_metric="+accuracy",
        num_serialized_models_to_keep=1,
        num_epochs=4,
        grad_norm=1.0,
        cuda_device=0,
    ),
))

## Training

In [0]:
# AllenNLP writes its training artifacts into a brand-new temporary directory; note that
# nothing here removes it afterwards.
serialization_dir = tempfile.mkdtemp()
model = train_model(params=all_params, serialization_dir=serialization_dir)

0it [00:00, ?it/s]
23351it [00:10, 2335.05it/s]
48185it [00:20, 2377.65it/s]
73019it [00:30, 2382.17it/s]
98015it [00:40, 2416.22it/s]
123011it [00:50, 2425.25it/s]
147477it [01:01, 2397.77it/s]
174746it [01:11, 2487.84it/s]
202015it [01:22, 2445.99it/s]
225551it [01:33, 2381.18it/s]
252366it [01:43, 2463.96it/s]
279224it [01:53, 2526.55it/s]
306082it [02:05, 2419.83it/s]
332795it [02:15, 2490.15it/s]
359508it [02:28, 2354.73it/s]
386499it [02:38, 2448.43it/s]
413490it [02:48, 2516.51it/s]
440402it [02:58, 2565.17it/s]
467266it [03:11, 2365.15it/s]
493922it [03:21, 2447.91it/s]
520853it [03:31, 2516.63it/s]
547784it [03:41, 2552.07it/s]
549367it [03:42, 2470.05it/s]

0it [00:00, ?it/s]
9842it [00:08, 1210.22it/s]

0it [00:00, ?it/s]
559209it [00:02, 203983.01it/s]

  0%|          | 0/407873900 [00:00<?, ?B/s]
  0%|          | 17408/407873900 [00:00<1:26:56, 78185.22B/s]
  0%|          | 52224/407873900 [00:00<1:13:59, 91866.41B/s]
  0%|          | 87040/407873900 [00:00<1:04:56, 104654