#Config

In [6]:
class config:
    
    device = None
    seed = 27
    num_workers = 2
    prefetch_factor = 2
    fp16 = True
    warm_up = 0.1
    weight_decay = 0.01
    train_batch_size = 1
    eval_batch_size = 1
    train_epochs = 10
    gradient_accumulation_steps = 6
    adam_epsilon= 1e-6
    adam_betas = (0.9, 0.98)
    learning_rate= 1e-6
    max_grad_norm=0.0
    writer=False
    save_steps=773
    logging_steps=100
    max_step=1000000

    max_seq_length = 256
    load_examples_num_workers = 2

    # pretrained path
    pretrained_model_name_or_path = 'roberta-large'
    pretrained_model_name_or_path_cache = 'pretrained'

    # local paths
    train_data_path = '/content/dataset/train.json'
    val_data_path = '/content/dataset/val.json'
    test_data_path = '/content/dataset/test.json'
    output_path = 'content/output'
    tensor_cache_path = 'content/tensor/'

#Installs and Imports

In [7]:
import nltk
nltk.download('averaged_perceptron_tagger')
nltk.download('wordnet')
nltk.download('omw-1.4')
nltk.download('punkt')

[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [8]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [9]:
# !git lfs install
# !git clone https://huggingface.co/roberta-large

In [10]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [11]:
import shutil
 
# Full path of
# the archive file
filename = "/content/drive/MyDrive/Thesis/reclor_dataset.zip"
 
# Target directory
extract_dir = "/content/"
 
# Format of archive file
archive_format = "zip"
 
# Unpack the archive file
shutil.unpack_archive(filename, extract_dir, archive_format)

print("Archive file unpacked successfully.")

Archive file unpacked successfully.


In [12]:
import os
import json
import torch
import logging
import random
import numpy as np
import sys


import torch
from torch import nn, Tensor
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, RandomSampler, TensorDataset, SequentialSampler
from torch.cuda.amp import GradScaler
from transformers import AutoTokenizer, get_linear_schedule_with_warmup, AdamW, PreTrainedTokenizer
from transformers.modeling_outputs import MultipleChoiceModelOutput
from transformers.tokenization_utils_base import PaddingStrategy, TruncationStrategy
from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaPreTrainedModel, RobertaConfig, RobertaLMHead

from collections import Counter
from functools import partial
from multiprocessing import Pool
from typing import Dict, List
from nltk import sent_tokenize
from tqdm import tqdm
from abc import ABC

if config.writer:
    from torch.utils.tensorboard import SummaryWriter
import copy
from typing import Optional, Any, Union, Callable

import torch
from torch import Tensor
from torch.nn import functional as F
from torch.nn.modules import Module
from torch.nn.modules.activation import MultiheadAttention
from torch.nn.modules.container import ModuleList
from torch.nn.init import xavier_uniform_
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm

logging.basicConfig(level=logging.INFO)


from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from torch.types import _dtype as DType
else:
    # The JIT doesn't understand Union, nor torch.dtype here
    DType = int

# Model Dependencies

In [13]:
def get_accuracy(logits, labels):
    assert logits.size()[:-1] == labels.size()

    _, pred = logits.max(dim=-1)
    true_label_num = (labels != -1).sum().item()
    correct = (pred == labels).sum().item()
    if true_label_num == 0:
        return 0, 0
    acc = correct * 1.0 / true_label_num
    return acc, true_label_num


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        if isinstance(val, torch.Tensor):
            val = val.item()
        if isinstance(n, torch.Tensor):
            n = n.item()

        self.val = val
        self.sum += val * n
        self.count += n
        if self.count > 0:
            self.avg = self.sum / self.count
        else:
            self.avg = 0

    def save(self):
        return {
            'val': self.val,
            'avg': self.avg,
            'sum': self.sum,
            'count': self.count
        }

    def load(self, value: dict):
        if value is None:
            self.reset()
        self.val = value['val'] if 'val' in value else 0
        self.avg = value['avg'] if 'avg' in value else 0
        self.sum = value['sum'] if 'sum' in value else 0
        self.count = value['count'] if 'count' in value else 0
        
class LogMetric(object):
    """
    Record all metrics for logging.
    """

    def __init__(self, *metric_names):

        self.metrics = {
            key: AverageMeter() for key in metric_names
        }

    def update(self, metric_name, val, n=1):

        self.metrics[metric_name].update(val, n)

    def reset(self, metric_name=None):
        if metric_name is None:
            for key in self.metrics.keys():
                self.metrics[key].reset()
            return

        self.metrics[metric_name].reset()

    def get_log(self):

        log = {
            key: self.metrics[key].avg for key in self.metrics
        }
        return log

class LogMixin:
    eval_metrics: LogMetric = None

    def init_metric(self, *metric_names):
        self.eval_metrics = LogMetric(*metric_names)

    def get_eval_log(self, reset=False):
        if self.eval_metrics is None:
            print("The `eval_metrics` attribute hasn't been initialized.")

        results = self.eval_metrics.get_log()

        _eval_metric_log = '\t'.join([f"{k}: {v}" for k, v in results.items()])

        if reset:
            self.eval_metrics.reset()

        return _eval_metric_log, results

In [14]:
def _canonical_mask(
        mask: Optional[Tensor],
        mask_name: str,
        other_type: Optional[DType],
        other_name: str,
        target_type: DType,
        check_other: bool = True,
) -> Optional[Tensor]:

    if mask is not None:
        _mask_dtype = mask.dtype
        _mask_is_float = torch.is_floating_point(mask)
        if _mask_dtype != torch.bool and not _mask_is_float:
            raise AssertionError(
                f"only bool and floating types of {mask_name} are supported")
        if check_other and other_type is not None:
            if _mask_dtype != other_type:
                warnings.warn(
                    f"Support for mismatched {mask_name} and {other_name} "
                    "is deprecated. Use same type for both instead."
                )
        if not _mask_is_float:
            mask = (
                torch.zeros_like(mask, dtype=target_type)
                .masked_fill_(mask, float("-inf"))
            )
    return mask

def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
    if input is None:
        return None
    elif isinstance(input, torch.Tensor):
        return input.dtype
    raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")

# PowerNorm

In [15]:
# ref:https://github.com/sIncerass/powernorm/blob/master/fairseq/modules/norms/mask_powernorm.py
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : MaskPowerNorm.py
# Distributed under MIT License.


import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

__all__ = ['MaskPowerNorm']

def _sum_ft(tensor):
    """sum over the first and last dimention"""
    return tensor.sum(dim=0).sum(dim=-1)

class GroupScaling1D(nn.Module):
    r"""Scales inputs by the second moment for the entire layer.
    """
    def __init__(self, eps=1e-5, group_num=4):
        super(GroupScaling1D, self).__init__()
        self.eps = eps
        self.group_num = group_num

    def extra_repr(self):
        return f'eps={self.eps}, group={self.group_num}'

    def forward(self, input):
        # calculate second moment
        # different group use different mean
        T, B, C = input.shape[0], input.shape[1], input.shape[2]
        Cg = C // self.group_num
        gn_input = input.contiguous().reshape(T, B, self.group_num, Cg)
        moment2 = torch.repeat_interleave( torch.mean(gn_input * gn_input, dim=3, keepdim=True), \
            repeats=Cg, dim=-1).contiguous().reshape(T, B, C)
        # divide out second moment
        return input / torch.sqrt(moment2 + self.eps)

def _unsqueeze_ft(tensor):
    """add new dimensions at the front and the tail"""
    return tensor.unsqueeze(0).unsqueeze(-1)

class PowerFunction(torch.autograd.Function):
    @staticmethod

    def forward(ctx, x, weight, bias, running_phi, eps, afwd, abkw, ema_gz, \
                debug, warmup_iters, current_iter, mask_x):
        ctx.eps = eps
        ctx.debug = debug
        current_iter = current_iter.item()
        ctx.current_iter = current_iter
        ctx.warmup_iters = warmup_iters
        ctx.abkw = abkw
        rmax = 1
        N, C, H, W = x.size()
        x2 = (mask_x * mask_x).mean(dim=0)

       	var = x2.reshape(1, C, 1, 1)
        if current_iter <= warmup_iters:
            z = x /(var + eps).sqrt()
        else:
            z = x /(running_phi + eps).sqrt()
            
        y = z
        ctx.save_for_backward(z, var, weight, ema_gz)

        if current_iter < warmup_iters:
            running_phi.copy_(running_phi * (current_iter-1)/current_iter + var.mean(dim=0, keepdim=True)/current_iter)
        running_phi.copy_(afwd*running_phi + (1-afwd)*var.mean(dim=0, keepdim=True))
        y = weight.reshape(1,C,1,1) * y + bias.reshape(1,C,1,1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps
        debug = ctx.debug
        current_iter = ctx.current_iter
        warmup_iters = ctx.warmup_iters
        abkw = ctx.abkw

        N, C, H, W = grad_output.size()
        z, var, weight, ema_gz = ctx.saved_variables

        y = z
        g = grad_output * weight.reshape(1, C, 1, 1)
        g = g * 1

        gz = (g * z).mean(dim=3).mean(dim=2).mean(dim=0)
        
        approx_grad_g = (g - (1 - abkw) * ema_gz * z)
        ema_gz.add_((approx_grad_g * z).mean(dim=3, keepdim=True).mean(dim=2, keepdim=True).mean(dim=0, keepdim=True))

        gx = 1. / torch.sqrt(var + eps) * approx_grad_g 
        return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(dim=0), \
         None, None, None, None, None, None, None, None, None, None

class MaskPowerNorm(nn.Module):
    """
    An implementation of masked batch normalization, used for testing the numerical
    stability.
    """

    def __init__(self, num_features, eps=1e-5, alpha_fwd=0.9, alpha_bkw=0.9, \
                affine=True, warmup_iters=10000, group_num=1):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.affine = affine

        self.register_parameter('weight', nn.Parameter(torch.ones(num_features)))
        self.register_parameter('bias', nn.Parameter(torch.zeros(num_features)))
        self.register_buffer('running_phi', torch.ones(1,num_features,1,1))
        self.register_buffer('ema_gz', torch.zeros(1,num_features,1,1))
        self.register_buffer('iters', torch.zeros(1).type(torch.LongTensor))

        self.afwd = alpha_fwd
        self.abkw = alpha_bkw

        self.eps = eps
        self.debug = False
        self.warmup_iters = warmup_iters
        self.gp = GroupScaling1D(group_num=group_num)
        self.group_num = group_num

    def extra_repr(self):
        return '{num_features}, eps={eps}, alpha_fwd={afwd}, alpha_bkw={abkw}, ' \
               'affine={affine}, warmup={warmup_iters}, group_num={group_num}'.format(**self.__dict__)

    def forward(self, input, pad_mask=None, is_encoder=False):
        """
        input:  T x B x C -> B x C x T
             :  B x C x T -> T x B x C
        pad_mask: B x T (padding is True)
        """
        shaped_input = (len(input.shape) == 2)
        if shaped_input:
            input = input.unsqueeze(0)
        T, B, C = input.shape
        input = self.gp(input)

        # construct the mask_input, size to be (BxL) x C: L is the real length here
        if pad_mask is None:
            mask_input = input.clone()
        else:
            # Transpose the bn_mask (B x T -> T x B)
            bn_mask = ~pad_mask
            bn_mask = bn_mask.transpose(0, 1)

        if pad_mask is not None:
            pad_size = (~bn_mask).sum()
            mask_input = input[bn_mask, :]
        else:
            mask_input = input.clone()

        mask_input = mask_input.reshape(-1, self.num_features)

        input = input.permute(1, 2, 0).contiguous()
        input_shape = input.size()
        input = input.reshape(input.size(0), self.num_features, -1)
        input = input.unsqueeze(-1)

        if self.training:
            self.iters.copy_(self.iters + 1)
            output = PowerFunction.apply(input, self.weight, self.bias, self.running_phi, self.eps, \
                        self.afwd, self.abkw, self.ema_gz, self.debug, self.warmup_iters, self.iters, mask_input)
            
        else:
            N, C, H, W = input.size()
            var = self.running_phi
            output = input / (var + self.eps).sqrt()
            output = self.weight.reshape(1,C,1,1) * output + self.bias.reshape(1,C,1,1)

        output = output.reshape(input_shape)
        output = output.permute(2, 0, 1).contiguous()
        # Reshape it.
        if shaped_input:
            output = output.squeeze(0)

        return output


#Transformer Model

In [16]:
import copy
from typing import Optional, Any, Union, Callable

import torch
from torch import Tensor
from torch.nn import functional as F
from torch.nn import Module
from torch.nn import MultiheadAttention
from torch.nn import ModuleList
from torch.nn.init import xavier_normal_
from torch.nn import Dropout
from torch.nn import Linear
from torch.nn import LayerNorm

__all__ = ['Transformer', 'TransformerEncoder', 'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer']


class Transformer(Module):
    r"""A transformer model. User is able to modify the attributes as needed. The architecture
    is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
    Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
    Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
    Processing Systems, pages 6000-6010.

    Args:
        d_model: the number of expected features in the encoder/decoder inputs (default=512).
        nhead: the number of heads in the multiheadattention models (default=8).
        num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
        num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of encoder/decoder intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        custom_encoder: custom encoder (default=None).
        custom_decoder: custom decoder (default=None).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before
            other attention and feedforward operations, otherwise after. Default: ``False`` (after).

    Examples::
        >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
        >>> src = torch.rand((10, 32, 512))
        >>> tgt = torch.rand((20, 32, 512))
        >>> out = transformer_model(src, tgt)

    Note: A full example to apply nn.Transformer module for the word language model is available in
    https://github.com/pytorch/examples/tree/master/word_language_model
    """

    def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")

        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                    activation, layer_norm_eps, batch_first, norm_first,
                                                    **factory_kwargs)
            encoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                    activation, layer_norm_eps, batch_first, norm_first,
                                                    **factory_kwargs)
            decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

        self.batch_first = batch_first


    def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Take in and process masked source/target sequences.

        Args:
            src: the sequence to the encoder (required).
            tgt: the sequence to the decoder (required).
            src_mask: the additive mask for the src sequence (optional).
            tgt_mask: the additive mask for the tgt sequence (optional).
            memory_mask: the additive mask for the encoder output (optional).
            src_key_padding_mask: the Tensor mask for src keys per batch (optional).
            tgt_key_padding_mask: the Tensor mask for tgt keys per batch (optional).
            memory_key_padding_mask: the Tensor mask for memory keys per batch (optional).

        Shape:
            - src: :math:`(S, E)` for unbatched input, :math:`(S, N, E)` if `batch_first=False` or
              `(N, S, E)` if `batch_first=True`.
            - tgt: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
              `(N, T, E)` if `batch_first=True`.
            - src_mask: :math:`(S, S)` or :math:`(N\cdot\text{num\_heads}, S, S)`.
            - tgt_mask: :math:`(T, T)` or :math:`(N\cdot\text{num\_heads}, T, T)`.
            - memory_mask: :math:`(T, S)`.
            - src_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
            - tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise :math:`(N, T)`.
            - memory_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.

            Note: [src/tgt/memory]_mask ensures that position i is allowed to attend the unmasked
            positions. If a BoolTensor is provided, positions with ``True``
            are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
            is provided, it will be added to the attention weight.
            [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
            the attention. If a BoolTensor is provided, the positions with the
            value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.

            - output: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
              `(N, T, E)` if `batch_first=True`.

            Note: Due to the multi-head attention architecture in the transformer model,
            the output sequence length of a transformer is same as the input sequence
            (i.e. target) length of the decoder.

            where S is the source sequence length, T is the target sequence length, N is the
            batch size, E is the feature number

        Examples:
            >>> # xdoctest: +SKIP
            >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
        """

        is_batched = src.dim() == 3
        if not self.batch_first and src.size(1) != tgt.size(1) and is_batched:
            raise RuntimeError("the batch number of src and tgt must be equal")
        elif self.batch_first and src.size(0) != tgt.size(0) and is_batched:
            raise RuntimeError("the batch number of src and tgt must be equal")

        if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
            raise RuntimeError("the feature number of src and tgt must be equal to d_model")

        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask)
        return output



    @staticmethod
    def generate_square_subsequent_mask(sz: int, device='cpu') -> Tensor:
        r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
            Unmasked positions are filled with float(0.0).
        """
        return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)


    def _reset_parameters(self):
        r"""Initiate parameters in the transformer model."""

        for p in self.parameters():
            if p.dim() > 1:
                xavier_normal_(p)




class TransformerEncoder(Module):
    r"""TransformerEncoder is a stack of N encoder layers. Users can build the
    BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).
        enable_nested_tensor: if True, input will automatically convert to nested tensor
            (and convert back on output). This will improve the overall performance of
            TransformerEncoder when padding rate is high. Default: ``True`` (enabled).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=True, mask_check=True):
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.enable_nested_tensor = enable_nested_tensor
        self.mask_check = mask_check


    def forward(
            self,
            src: Tensor,
            mask: Optional[Tensor] = None,
            src_key_padding_mask: Optional[Tensor] = None,
            is_causal: Optional[bool] = None) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            is_causal: If specified, applies a causal mask as mask (optional)
                and ignores attn_mask for computing scaled dot product attention.
                Default: ``False``.
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        src_key_padding_mask = _canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type= _none_or_dtype(mask),
            other_name="mask",
            target_type=src.dtype
        )

        output = src
        convert_to_nested = False
        first_layer = self.layers[0]
        src_key_padding_mask_for_layers = src_key_padding_mask
        why_not_sparsity_fast_path = ''
        str_first_layer = "self.layers[0]"
        if not isinstance(first_layer, torch.nn.TransformerEncoderLayer):
            why_not_sparsity_fast_path = f"{str_first_layer} was not TransformerEncoderLayer"
        elif first_layer.norm_first :
            why_not_sparsity_fast_path = f"{str_first_layer}.norm_first was True"
        elif first_layer.training:
            why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
        elif not first_layer.self_attn.batch_first:
            why_not_sparsity_fast_path = f" {str_first_layer}.self_attn.batch_first was not True"
        elif not first_layer.self_attn._qkv_same_embed_dim:
            why_not_sparsity_fast_path = f"{str_first_layer}.self_attn._qkv_same_embed_dim was not True"
        elif not first_layer.activation_relu_or_gelu:
            why_not_sparsity_fast_path = f" {str_first_layer}.activation_relu_or_gelu was not True"
        elif not (first_layer.norm1.eps == first_layer.norm2.eps) :
            why_not_sparsity_fast_path = f"{str_first_layer}.norm1.eps was not equal to {str_first_layer}.norm2.eps"
        elif not src.dim() == 3:
            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
        elif not self.enable_nested_tensor:
            why_not_sparsity_fast_path = "enable_nested_tensor was not True"
        elif src_key_padding_mask is None:
            why_not_sparsity_fast_path = "src_key_padding_mask was None"
        elif (((not hasattr(self, "mask_check")) or self.mask_check)
                and not torch._nested_tensor_from_mask_left_aligned(src, src_key_padding_mask.logical_not())):
            why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask was not left aligned"
        elif output.is_nested:
            why_not_sparsity_fast_path = "NestedTensor input is not supported"
        elif mask is not None:
            why_not_sparsity_fast_path = "src_key_padding_mask and mask were both supplied"
        elif first_layer.self_attn.num_heads % 2 == 1:
            why_not_sparsity_fast_path = "num_head is odd"
        elif torch.is_autocast_enabled():
            why_not_sparsity_fast_path = "autocast is enabled"

        if not why_not_sparsity_fast_path:
            tensor_args = (
                src,
                first_layer.self_attn.in_proj_weight,
                first_layer.self_attn.in_proj_bias,
                first_layer.self_attn.out_proj.weight,
                first_layer.self_attn.out_proj.bias,
                first_layer.norm1.weight,
                first_layer.norm1.bias,
                first_layer.norm2.weight,
                first_layer.norm2.bias,
                first_layer.linear1.weight,
                first_layer.linear1.bias,
                first_layer.linear2.weight,
                first_layer.linear2.bias,
            )

            if torch.overrides.has_torch_function(tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
            elif not (src.is_cuda or 'cpu' in str(src.device)):
                why_not_sparsity_fast_path = "src is neither CUDA nor CPU"
            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
                why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
                                              "input/output projection weights or biases requires_grad")

            if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None):
                convert_to_nested = True
                output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
                src_key_padding_mask_for_layers = None

        # Prevent type refinement
        make_causal = (is_causal is True)

        if is_causal is None:
            if mask is not None:
                sz = mask.size(0)
                causal_comparison = torch.triu(
                    torch.ones(sz, sz, device=mask.device) * float('-inf'), diagonal=1
                ).to(mask.dtype)

                if torch.equal(mask, causal_comparison):
                    make_causal = True

        is_causal = make_causal

        for mod in self.layers:
            output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)

        if convert_to_nested:
            output = output.to_padded_tensor(0.)

        if self.norm is not None:
            output = self.norm(output)

        return output




class TransformerDecoder(Module):
    r"""TransformerDecoder is a stack of N decoder layers

    Args:
        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
        num_layers: the number of sub-decoder-layers in the decoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = transformer_decoder(tgt, memory)
    """
    __constants__ = ['norm']

    def __init__(self, decoder_layer, num_layers, norm=None):
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm


    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer in turn.

        Args:
            tgt: the sequence to the decoder (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        output = tgt

        for mod in self.layers:
            output = mod(output, memory, tgt_mask=tgt_mask,
                         memory_mask=memory_mask,
                         tgt_key_padding_mask=tgt_key_padding_mask,
                         memory_key_padding_mask=memory_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output



class TransformerEncoderLayer(Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to attention and feedforward
            operations, respectively. Otherwise it's done after. Default: ``False`` (after).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)

    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)

    Fast path:
        forward() will use a special optimized implementation described in
        `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_ if all of the following
        conditions are met:

        - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
          argument ``requires_grad``
        - training is disabled (using ``.eval()``)
        - batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
        - activation is one of: ``"relu"``, ``"gelu"``, ``torch.functional.relu``, or ``torch.functional.gelu``
        - at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
        - if src is a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_, neither ``src_mask``
          nor ``src_key_padding_mask`` is passed
        - the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
          unless the caller has manually modified one without modifying the other)

        If the optimized implementation is in use, a
        `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be
        passed for ``src`` to represent padding more efficiently than using a padding
        mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ will be
        returned, and an additional speedup proportional to the fraction of the input that
        is padding can be expected.

        .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
         https://arxiv.org/abs/2205.14135

    """
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                            **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        # breakpoint()
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)

        # We can't test self.activation in forward() in TorchScript,
        # so stash some information about it instead.
        if activation is F.relu or isinstance(activation, torch.nn.ReLU):
            self.activation_relu_or_gelu = 1
        elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
            self.activation_relu_or_gelu = 2
        else:
            self.activation_relu_or_gelu = 0
        self.activation = activation

    def __setstate__(self, state):
        super().__setstate__(state)
        if not hasattr(self, 'activation'):
            self.activation = F.relu



    def forward(
            self,
            src: Tensor,
            src_mask: Optional[Tensor] = None,
            src_key_padding_mask: Optional[Tensor] = None,
            is_causal: bool = False) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            is_causal: If specified, applies a causal mask as src_mask.
              Default: ``False``.
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        src_key_padding_mask = _canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=_none_or_dtype(src_mask),
            other_name="src_mask",
            target_type=src.dtype
        )

        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        why_not_sparsity_fast_path = ''
        if not src.dim() == 3:
            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
        elif self.training:
            why_not_sparsity_fast_path = "training is enabled"
        elif not self.self_attn.batch_first :
            why_not_sparsity_fast_path = "self_attn.batch_first was not True"
        elif not self.self_attn._qkv_same_embed_dim :
            why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
        elif not self.activation_relu_or_gelu:
            why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
        elif not (self.norm1.eps == self.norm2.eps):
            why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
        elif src.is_nested and (src_key_padding_mask is not None or src_mask is not None):
            why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input"
        elif self.self_attn.num_heads % 2 == 1:
            why_not_sparsity_fast_path = "num_head is odd"
        elif torch.is_autocast_enabled():
            why_not_sparsity_fast_path = "autocast is enabled"
        if not why_not_sparsity_fast_path:
            tensor_args = (
                src,
                self.self_attn.in_proj_weight,
                self.self_attn.in_proj_bias,
                self.self_attn.out_proj.weight,
                self.self_attn.out_proj.bias,
                self.norm1.weight,
                self.norm1.bias,
                self.norm2.weight,
                self.norm2.bias,
                self.linear1.weight,
                self.linear1.bias,
                self.linear2.weight,
                self.linear2.bias,
            )

            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            if torch.overrides.has_torch_function(tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
            elif not all((x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument is neither CUDA nor CPU"
            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
                why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
                                              "input/output projection weights or biases requires_grad")

            if not why_not_sparsity_fast_path:
                merged_mask, mask_type = self.self_attn.merge_masks(src_mask, src_key_padding_mask, src)
                return torch._transformer_encoder_layer_fwd(
                    src,
                    self.self_attn.embed_dim,
                    self.self_attn.num_heads,
                    self.self_attn.in_proj_weight,
                    self.self_attn.in_proj_bias,
                    self.self_attn.out_proj.weight,
                    self.self_attn.out_proj.bias,
                    self.activation_relu_or_gelu == 2,
                    self.norm_first,
                    self.norm1.eps,
                    self.norm1.weight,
                    self.norm1.bias,
                    self.norm2.weight,
                    self.norm2.bias,
                    self.linear1.weight,
                    self.linear1.bias,
                    self.linear2.weight,
                    self.linear2.bias,
                    merged_mask,
                    mask_type,
                )


        x = src
        dt_fix = 0.01646
        dt_fix_bool = True
        if dt_fix_bool:
            x = x + self._sa_block(x, src_mask, src_key_padding_mask)
            x = x + self._ff_block(x * dt_fix)
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
            x = self.norm2(x + self._ff_block(x))

        return x


    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)




class TransformerDecoderLayer(Module):
    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
    This standard decoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to self attention, multihead
            attention and feedforward operations, respectively. Otherwise it's done after.
            Default: ``False`` (after).

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = decoder_layer(tgt, memory)

    Alternatively, when ``batch_first`` is ``True``:
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> memory = torch.rand(32, 10, 512)
        >>> tgt = torch.rand(32, 20, 512)
        >>> out = decoder_layer(tgt, memory)
    """
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                            **factory_kwargs)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                                 **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        # breakpoint()
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super().__setstate__(state)


    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        tgt_is_causal: bool = False,
        memory_is_causal: bool = False,
    ) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
            tgt_is_causal: If specified, applies a causal mask as tgt mask.
                Mutually exclusive with providing tgt_mask. Default: ``False``.
            memory_is_causal: If specified, applies a causal mask as tgt mask.
                Mutually exclusive with providing memory_mask. Default: ``False``.
        Shape:
            see the docs in Transformer class.
        """
        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf

        x = tgt
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
            x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal)
            x = x + self._ff_block(self.norm3(x))
        else:
            x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
            x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
            x = self.norm3(x + self._ff_block(x))

        return x


    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           is_causal=is_causal,
                           need_weights=False)[0]
        return self.dropout1(x)

    # multihead attention block
    def _mha_block(self, x: Tensor, mem: Tensor,
                   attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
        x = self.multihead_attn(x, mem, mem,
                                attn_mask=attn_mask,
                                key_padding_mask=key_padding_mask,
                                is_causal=is_causal,
                                need_weights=False)[0]
        return self.dropout2(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)



def _get_clones(module, N):
    # FIXME: copy.deepcopy() is not defined on nn.module
    return ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))

#Roberta Model

In [17]:
class RobertaForMultipleChoice(RobertaPreTrainedModel, LogMixin, ABC):
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config: RobertaConfig,
                 re_init_cls: bool = False,
                 fs_checkpoint: bool = False,
                 fs_checkpoint_offload_to_cpu: bool = False,
                 fs_checkpoint_maintain_forward_counter: bool = False,
                 freeze_encoder: bool = False,
                 no_pooler: bool = False):
        super().__init__(config)

        self.roberta = RobertaModel(config)
        self.lm_head = RobertaLMHead(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.re_init_cls = re_init_cls
        if self.re_init_cls:
            self.classifier_i = nn.Linear(config.hidden_size, 1)
        self.classifier = nn.Linear(config.hidden_size, 1)
        self.no_pooler = no_pooler
        self.freeze_encoder = freeze_encoder
        print(self.freeze_encoder)
        if freeze_encoder:
            for param in self.roberta.parameters():
                param.requires_grad = False

        self.init_weights()
        self.max_norm = -sys.maxsize - 1
        self.init_metric("loss", "acc")
        self.transformer = TransformerEncoder(
            TransformerEncoderLayer(d_model=config.hidden_size, nhead=8),num_layers=2)

    @staticmethod
    def fold_tensor(x: Tensor):
        if x is None:
            return x
        return x.reshape(-1, x.size(-1))

    def forward(
            self,
            input_ids: Tensor,
            attention_mask: Tensor = None,
            token_type_ids: Tensor = None,
            labels: Tensor = None,
            sentence_index: Tensor = None,
            sentence_mask: Tensor = None,
            sent_token_mask: Tensor = None,
            mlm_labels: Tensor = None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
    ):

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1]

        input_ids = self.fold_tensor(input_ids)
        attention_mask = self.fold_tensor(attention_mask)
        token_type_ids = self.fold_tensor(token_type_ids)

        outputs = self.roberta(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if self.no_pooler:
            pooled_output = outputs[0][:, 0]
        else:
            pooled_output = outputs[1]
        
        pooled_output = self.transformer(pooled_output)

        pooled_output = self.dropout(pooled_output)
        self.max_norm = max(torch.max(pooled_output), self.max_norm)
        # print("#$#"*12)
        # print(self.max_norm)
        # print("#$#"*12)
        if self.re_init_cls:
            logits = self.classifier_i(pooled_output)
        else:
            logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        choice_mask = (attention_mask.sum(dim=-1) == 0).reshape(-1, num_choices)
        reshaped_logits = reshaped_logits + choice_mask * -10000.0

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(reshaped_logits, labels)

            if mlm_labels is not None:
                mlm_scores = self.lm_head(outputs[0])
                mlm_loss = loss_fct(mlm_scores.reshape(-1, self.config.vocab_size), mlm_labels.reshape(-1))
                loss += mlm_loss

            if not self.training:
                acc, true_label_num = get_accuracy(reshaped_logits, labels)
                self.eval_metrics.update("acc", val=acc, n=true_label_num)
                self.eval_metrics.update("loss", val=loss.item(), n=true_label_num)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

#Data Providers

In [18]:
def get_sep_tokens(_tokenizer):
    return [_tokenizer.sep_token] * (_tokenizer.max_len_single_sentence - _tokenizer.max_len_sentences_pair)


def is_bpe(_tokenizer: PreTrainedTokenizer):
    return _tokenizer.__class__.__name__ in [
        "RobertaTokenizer",
        "LongformerTokenizer",
        "BartTokenizer",
        "RobertaTokenizerFast",
        "LongformerTokenizerFast",
        "BartTokenizerFast",
    ]


def load_dataset(config, tokenizer, split='train'):
    if split == 'train':
        file_path = config.train_data_path
    elif split == 'val':
        file_path = config.val_data_path
    elif split == 'test':
        file_path = config.test_data_path
    else:
        raise Exception(split)

    examples, features, tensors = convert_examples_into_features(file_path=file_path,
                                                                 tokenizer=tokenizer,
                                                                 max_seq_length=config.max_seq_length,
                                                                 num_workers=config.load_examples_num_workers,
                                                                suffix=split)
    dataset = TensorDataset(*tensors)
    return dataset, features


def collator(batch):
    if len(batch[0]) == 5:
        input_ids, attention_mask, token_type_ids, labels, sentence_spans = list(zip(*batch))
    elif len(batch[0]) == 4:
        input_ids, attention_mask, labels, sentence_spans = list(zip(*batch))
        token_type_ids = None
    else:
        raise RuntimeError()

    input_ids = torch.stack(input_ids, dim=0)
    attention_mask = torch.stack(attention_mask, dim=0)
    labels = torch.stack(labels, dim=0)
    sentence_spans = torch.stack(sentence_spans, dim=0)

    batch, option_num, _, _ = sentence_spans.size()
    # [batch, option_num, max_sent_num]
    max_sent_len = (sentence_spans[:, :, :, 1] - sentence_spans[:, :, :, 0]).max().item()
    # [batch, option_num, max_sent_num]
    sent_mask = (sentence_spans[:, :, :, 0] != -1)
    # [batch, option_num]
    sent_num = sent_mask.sum(dim=2)
    b_max_sent_num = sent_num.max().item()
    sentence_spans = sentence_spans[:, :, :b_max_sent_num]
    sent_mask = sent_mask[:, :, :b_max_sent_num]

    sentence_index = torch.zeros(batch, option_num, b_max_sent_num, max_sent_len, dtype=torch.long)
    sent_token_mask = torch.zeros(batch, option_num, b_max_sent_num, max_sent_len, dtype=torch.long)
    for b_id, b_spans in enumerate(sentence_spans):
        for op_id, op_spans in enumerate(b_spans):
            for sent_id, span in enumerate(op_spans):
                s, e = span[0].item(), span[1].item()
                if s == -1:
                    break
                _len = e - s
                sentence_index[b_id, op_id, sent_id, :_len] = torch.arange(s, e, dtype=torch.long)
                sent_token_mask[b_id, op_id, sent_id, :_len] = 1

    outputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "sentence_index": sentence_index,
        "sentence_mask": sent_mask,
        "sent_token_mask": sent_token_mask
    }
    if token_type_ids is not None:
        outputs["token_type_ids"] = torch.stack(token_type_ids, dim=0)

    return outputs

def read_examples(file_path: str):
    data = json.load(open(file_path, 'r'))

    examples = []
    for sample in data:
        _context = sample["context"]
        _question = sample["question"]
        if "label" not in sample:
            _label = -1
        else:
            _label = sample["label"]
        examples.append({
            "context": _context,
            "question": _question,
            "options": sample["answers"],
            "label": _label
        })

    print(f"{len(examples)} examples are loaded from {file_path}.")
    return examples


def _convert_example_to_features(example, tokenizer, max_seq_length):
    context = example["context"]
    question = example["question"]
    context_sentences = [sent for sent in sent_tokenize(context) if sent]

    context_tokens = []
    for _sent_id, _sent in enumerate(context_sentences):
        _sent_tokens = tokenizer.tokenize(_sent)
        context_tokens.extend([(_sent_id, _tok) for _tok in _sent_tokens])

    _q_sent_id_offset = len(context_sentences)
    question_tokens = [(_q_sent_id_offset, _tok) for _tok in tokenizer.tokenize(question)]

    features = []
    for option in example["options"]:
        sep_tokens = get_sep_tokens(tokenizer)
        _op_sent_id_offset = _q_sent_id_offset + 1
        opt_tokens = [(_op_sent_id_offset, _tok) for _tok in tokenizer.tokenize(option)]

        lens_to_remove = len(context_tokens) + len(question_tokens) + len(opt_tokens) + len(sep_tokens) + (
                tokenizer.model_max_length - tokenizer.max_len_sentences_pair) - max_seq_length

        tru_c_tokens, tru_q_o_tokens, _ = tokenizer.truncate_sequences(context_tokens,
                                                                       question_tokens + sep_tokens + opt_tokens,
                                                                       num_tokens_to_remove=lens_to_remove,
                                                                       truncation_strategy=TruncationStrategy.LONGEST_FIRST)

        c_tokens, q_op_tokens = [], []
        sent_id_map = Counter()

        for _sent_id, _tok in tru_c_tokens:
            sent_id_map[_sent_id] += 1
            c_tokens.append(_tok)

        for _tok in tru_q_o_tokens:
            if isinstance(_tok, tuple):
                _sent_id, _tok = _tok
                q_op_tokens.append(_tok)
                sent_id_map[_sent_id] += 1
            elif isinstance(_tok, str):
                q_op_tokens.append(_tok)
            else:
                raise RuntimeError(_tok)

        sent_span_offset = 1
        sent_spans = []
        for i in range(len(context_sentences) + 2):
            if i == _q_sent_id_offset or i == _op_sent_id_offset:
                sent_span_offset += (tokenizer.max_len_single_sentence - tokenizer.max_len_sentences_pair)
            if i in sent_id_map:
                _cur_len = sent_id_map.pop(i)
                sent_spans.append((sent_span_offset, sent_span_offset + _cur_len))
                sent_span_offset += _cur_len
        assert not sent_id_map

        tokenizer_outputs = tokenizer(tokenizer.convert_tokens_to_string(c_tokens),
                                      text_pair=tokenizer.convert_tokens_to_string(q_op_tokens),
                                      padding=PaddingStrategy.MAX_LENGTH,
                                      max_length=max_seq_length)
        assert len(tokenizer_outputs["input_ids"]) == max_seq_length, (
        len(c_tokens), len(q_op_tokens), len(tokenizer_outputs["input_ids"]))
        features.append({
            "input_ids": tokenizer_outputs["input_ids"],
            "attention_mask": tokenizer_outputs["attention_mask"],
            "token_type_ids": tokenizer_outputs["token_type_ids"] if "token_type_ids" in tokenizer_outputs else None,
            "sentence_spans": sent_spans,
        })

    return {
        "features": features,
        "label": example["label"]
    }


def _data_to_tensors(features):
    data_num = len(features)
    option_num = len(features[0]["features"])

    input_ids = torch.tensor([[op["input_ids"] for op in f["features"]] for f in features])
    attention_mask = torch.tensor([[op["attention_mask"] for op in f["features"]] for f in features], dtype=torch.long)
    if features[0]["features"][0]["token_type_ids"] is not None:
        token_type_ids = torch.tensor([[op["token_type_ids"] for op in f["features"]] for f in features],
                                      dtype=torch.long)
    else:
        token_type_ids = None
    labels = torch.tensor([f["label"] for f in features], dtype=torch.long)

    # List[List[List[Tuple[int, int]]]]
    sentence_spans_ls = [[op["sentence_spans"] for op in f["features"]] for f in features]
    max_sent_num = 0
    for f in sentence_spans_ls:
        f_max_sent_num = max(map(len, f))
        max_sent_num = max(f_max_sent_num, max_sent_num)

    sentence_spans = torch.zeros(data_num, option_num, max_sent_num, 2, dtype=torch.long).fill_(-1)
    for f_id, f in enumerate(sentence_spans_ls):
        for op_id, op in enumerate(f):
            f_op_sent_num = len(op)
            sentence_spans[f_id, op_id, :f_op_sent_num] = torch.tensor(op, dtype=torch.long)

    if token_type_ids is not None:
        return input_ids, attention_mask, token_type_ids, labels, sentence_spans
    else:
        return input_ids, attention_mask, labels, sentence_spans


def convert_examples_into_features(file_path, tokenizer, max_seq_length, num_workers = 16, suffix=''):
    tokenizer_name = tokenizer.__class__.__name__
    tokenizer_name = tokenizer_name.replace('TokenizerFast', '')
    tokenizer_name = tokenizer_name.replace('Tokenizer', '').lower()

    file_suffix = f"{tokenizer_name}_{max_seq_length}_{suffix}"
    cached_file_path = config.tensor_cache_path + file_suffix

    if os.path.exists(cached_file_path):
        print(f"Loading cached file from {cached_file_path}")
        examples, features, tensors = torch.load(cached_file_path)
        return examples, features, tensors

    examples = read_examples(file_path)

    with Pool(num_workers) as p:
        _annotate = partial(_convert_example_to_features, tokenizer=tokenizer, max_seq_length=max_seq_length)
        features = list(tqdm(
            p.imap(_annotate, examples, chunksize=32),
            total=len(examples),
            desc='converting examples to features:'
        ))

    print("Transform features into tensors...")
    tensors = _data_to_tensors(features)

    print(f"Saving processed features into {cached_file_path}.")
    if not os.path.exists(config.tensor_cache_path):
        os.makedirs(config.tensor_cache_path)
    torch.save((examples, features, tensors), cached_file_path)

    return examples, features, tensors

In [19]:
def forward_and_backward(model, inputs, config, scaler):
    if config.fp16 and scaler:
        with torch.cuda.amp.autocast():
            outputs = model(**inputs)
            loss = outputs["loss"]
    else:
        outputs = model(**inputs)
        loss = outputs["loss"]

    if config.gradient_accumulation_steps > 1:
        loss = loss / config.gradient_accumulation_steps

    if scaler:
        scaler.scale(loss).backward()
    else:
        loss.backward()

    return loss.item()

def batch_to_device(batch, device):
    batch_on_device = {}
    for k, v in batch.items():
        batch_on_device[k] = v.to(device)
    return batch_on_device

def train_model(config, train_dataset,val_dataset, model, tokenizer, start_global_step):
    output_path_split = config.output_path.split('/')
    log_dir = '/'.join([output_path_split[0], 'runs'] + output_path_split[1:])
    if config.writer:
        writer = SummaryWriter(log_dir=log_dir)
    else:
        writer = None

    train_loader = DataLoader(dataset=train_dataset,
                              sampler=RandomSampler(train_dataset),
                              batch_size=config.train_batch_size,
                              collate_fn=collator,
                              num_workers=config.num_workers,
                              pin_memory=True,
                              prefetch_factor=config.prefetch_factor)
    
    val_loader = DataLoader(dataset=val_dataset,
                              sampler=RandomSampler(val_dataset),
                              batch_size=config.train_batch_size,
                              collate_fn=collator,
                              num_workers=config.num_workers,
                              pin_memory=True,
                              prefetch_factor=config.prefetch_factor)

    no_decay = ['bias', 'LayerNorm.weight', 'layer_norm.weight']
    grouped_parameters = [
        {
            'params': [p for n, p in model.named_parameters() if
                       (not any(nd in n for nd in no_decay)) and p.requires_grad],
            'weight_decay': config.weight_decay
        },
        {
            'params': [p for n, p in model.named_parameters() if (any(nd in n for nd in no_decay)) and p.requires_grad],
            'weight_decay': 0.0
        }
    ]

    total_steps = len(train_loader) // config.gradient_accumulation_steps * config.train_epochs

    optimizer = AdamW(grouped_parameters,
                      lr=config.learning_rate,
                      eps=config.adam_epsilon,
                      betas=config.adam_betas)
    import transformers
    scheduler = transformers.get_scheduler(
                        "linear",    # Create a schedule with a learning rate that decreases linearly 
                                     # from the initial learning rate set in the optimizer to 0.
                        optimizer = optimizer,
                        num_warmup_steps = 0,
                        num_training_steps = total_steps)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, min_lr = 0.005,patience=1)

    if config.fp16 and config.device.type == 'cuda':
        scaler = GradScaler()
    else:
        scaler = None

    print(optimizer)
    print("-- Start Training --")
    print("  Num examples = ", len(train_dataset))
    print("  Num Epochs = ", config.train_epochs)
    print("  Batch size = ", config.train_batch_size)
    print("  Gradient Accumulation steps = ", config.gradient_accumulation_steps)
    print("  Total optimization steps = ", total_steps)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)

    for epoch in range(config.train_epochs):
        for step, batch in enumerate(train_loader):

            if global_step < start_global_step:
                if (step + 1) % config.gradient_accumulation_steps == 0:
                    scheduler.step()
                    global_step += 1
                continue
                
            model.train()
            batch = batch_to_device(batch, config.device)
            
            loss = forward_and_backward(model, batch, config, scaler)
            tr_loss += loss

            if (step + 1) % config.gradient_accumulation_steps == 0:

                if scaler:
                    scaler.unscale_(optimizer)

                if config.max_grad_norm:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)

                if scaler:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()

                scheduler.step((tr_loss - logging_loss))
                model.zero_grad(set_to_none=True)
                global_step += 1

                

                if config.logging_steps > 0 and global_step % config.logging_steps == 0:
                    if config.writer:
                        writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                        writer.add_scalar('loss', (tr_loss - logging_loss) / config.logging_steps, global_step)
                    else:
                        print('gb_step={}, loss={}'.format(global_step, (tr_loss - logging_loss) / config.logging_steps))
                    logging_loss = tr_loss

                if config.save_steps > 0 and global_step % config.save_steps == 0:
                    output_dir = os.path.join(config.output_path, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                        
                    model.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    print("Saving model checkpoint to ", output_dir)
                    
            if global_step >= config.max_step:
                break
        
        if global_step >= config.max_step:
            break
        model.eval()
        val_loss = 0.0
        val_acc = 0.0
        num_val_steps = 0
        correct_predictions = 0.0
        pred_list = []
        prob_list = []  
        
        for val_batch in (val_loader):
            batch = batch_to_device(val_batch, config.device)
            with torch.cuda.amp.autocast():
              with torch.no_grad():
                  outputs = model(**batch)
                  val_loss += outputs.loss.item()
                  num_val_steps = num_val_steps +1
                  predictions = outputs.logits.argmax(dim=-1)
                  targets = batch["labels"]
                  correct_predictions += (predictions == targets).sum().item()
        val_loss /= num_val_steps
        
        val_acc = correct_predictions / len(val_batch)
        print("VAL ACCURACY",val_acc)
        print("val_loss",val_loss)
    return global_step, tr_loss / global_step

In [20]:
def train_main(start_global_step=0):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config.device = device
    
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)

    tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model_name_or_path, cache_dir=config.pretrained_model_name_or_path_cache)
    model = RobertaForMultipleChoice.from_pretrained(config.pretrained_model_name_or_path, cache_dir=config.pretrained_model_name_or_path_cache)

    model.to(config.device)

    train_dataset, features = load_dataset(config, tokenizer=tokenizer, split='train')
    val_dataset, features_val = load_dataset(config, tokenizer=tokenizer, split='val')
    
    step, loss = train_model(config, train_dataset, val_dataset, model, tokenizer, start_global_step)
    print('Train finished, step: {}, loss: {}'.format(step, loss))

In [21]:
%%time
train_main()

Downloading (…)lve/main/config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.43G [00:00<?, ?B/s]

False


Some weights of RobertaForMultipleChoice were not initialized from the model checkpoint at roberta-large and are newly initialized: ['transformer.layers.0.norm1.weight', 'transformer.layers.0.self_attn.in_proj_weight', 'transformer.layers.0.self_attn.out_proj.weight', 'transformer.layers.1.norm1.bias', 'lm_head.decoder.bias', 'transformer.layers.1.norm2.weight', 'transformer.layers.1.linear2.bias', 'transformer.layers.0.self_attn.in_proj_bias', 'transformer.layers.0.linear2.bias', 'transformer.layers.1.linear1.weight', 'transformer.layers.0.self_attn.out_proj.bias', 'transformer.layers.1.self_attn.out_proj.bias', 'transformer.layers.1.self_attn.in_proj_weight', 'transformer.layers.0.linear1.bias', 'transformer.layers.1.self_attn.in_proj_bias', 'transformer.layers.0.norm1.bias', 'classifier.bias', 'transformer.layers.0.linear1.weight', 'classifier.weight', 'transformer.layers.0.norm2.bias', 'transformer.layers.0.linear2.weight', 'transformer.layers.1.self_attn.out_proj.weight', 'transfo

4638 examples are loaded from /content/dataset/train.json.


converting examples to features::   0%|          | 1/4638 [00:00<14:33,  5.31it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
converting exa

Transform features into tensors...
Saving processed features into content/tensor/roberta_256_train.
500 examples are loaded from /content/dataset/val.json.


converting examples to features:: 100%|██████████| 500/500 [00:01<00:00, 311.44it/s]


Transform features into tensors...
Saving processed features into content/tensor/roberta_256_val.
AdamW (
Parameter Group 0
    betas: (0.9, 0.98)
    correct_bias: True
    eps: 1e-06
    initial_lr: 1e-06
    lr: 1e-06
    weight_decay: 0.01

Parameter Group 1
    betas: (0.9, 0.98)
    correct_bias: True
    eps: 1e-06
    initial_lr: 1e-06
    lr: 1e-06
    weight_decay: 0.0
)
-- Start Training --
  Num examples =  4638
  Num Epochs =  10
  Batch size =  1
  Gradient Accumulation steps =  6
  Total optimization steps =  7730




gb_step=100, loss=1.3925063107907771
gb_step=200, loss=1.3863941572606564
gb_step=300, loss=1.3864710664749145
gb_step=400, loss=1.3998775558173657
gb_step=500, loss=1.3927386225759983
gb_step=600, loss=1.3887938837707043
gb_step=700, loss=1.402943127155304
Saving model checkpoint to  content/output/checkpoint-773
VAL ACCURACY 0.322
val_loss 1.3815461189746856
gb_step=800, loss=1.384882756769657
gb_step=900, loss=1.4003713808953762
gb_step=1000, loss=1.3978631748259067
gb_step=1100, loss=1.3919439755380154
gb_step=1200, loss=1.3958239254355431
gb_step=1300, loss=1.3962407697737218
gb_step=1400, loss=1.3805039994418622
gb_step=1500, loss=1.3705406035482883
Saving model checkpoint to  content/output/checkpoint-1546
VAL ACCURACY 0.338
val_loss 1.3449450380802155
gb_step=1600, loss=1.3659130029380322
gb_step=1700, loss=1.348380241766572
gb_step=1800, loss=1.3505318797379733
gb_step=1900, loss=1.3437963265180588
gb_step=2000, loss=1.295428769737482
gb_step=2100, loss=1.3079252706468105
gb_s

In [22]:
n = 5.3338

In [23]:
def evaluate(config, split="val", checkpoint_name=''):
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config.device = device
    tokenizer = AutoTokenizer.from_pretrained(config.output_path + '/' + checkpoint_name)
    model = RobertaForMultipleChoice.from_pretrained(config.output_path + '/' + checkpoint_name)

    model.to(config.device)

    dataset, features = load_dataset(config, tokenizer, split=split)
    eval_sampler = SequentialSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=config.eval_batch_size,
                                  collate_fn=collator)

    torch.cuda.empty_cache()
    print("***** Running evaluation {} on {} *****".format(split, checkpoint_name))
    print("  Num examples =", len(dataset))
    print("  Batch size =", config.eval_batch_size)

    model.eval()
    pred_list = []
    prob_list = []

    for batch in tqdm(eval_dataloader, desc="Evaluating", dynamic_ncols=True):
        batch = batch_to_device(batch, config.device)
        with torch.cuda.amp.autocast():
            with torch.no_grad():
                outputs = model(**batch)
                probs = outputs["logits"].softmax(dim=-1).detach().float().cpu()
                prob, pred = probs.max(dim=-1)
                pred_list.extend(pred.tolist())
                prob_list.extend(prob.tolist())
  
    metric_log, results = model.get_eval_log(reset=True)
    print("****** Evaluation Results ******")
    print(metric_log)
    
    prediction_file = os.path.join(config.output_path, "{checkpoint_name}_eval_predictions.npy")
    np.save(prediction_file, pred_list)
    json.dump(prob_list, open(os.path.join(config.output_path, "{checkpoint_name}_eval_probs.json"), "w"))
    return results

In [24]:
evaluate(config, checkpoint_name='checkpoint-7730')

False
Loading cached file from content/tensor/roberta_256_val
***** Running evaluation val on checkpoint-7730 *****
  Num examples = 500
  Batch size = 1


Evaluating: 100%|██████████| 500/500 [00:31<00:00, 15.69it/s]

****** Evaluation Results ******
loss: 1.1224138465886935	acc: 0.58





{'loss': 1.1224138465886935, 'acc': 0.58}