# COVID-19: Using Novel Language Models to Effectively Identify Articles related to Therapeutics and Vaccines
* Team: MD-Lab, ASU
* Author: Ashwin Karthik Ambalavanan, Email: aambalav@asu.edu, Kaggle ID: ashwinambal96
* Team Members: Rishab Banerjee, Hong Guan, Jitesh Pabla, Mihir Parmar, Murthy Devarakonda
* Email ID: loccapollo@gmail.com, hguan6@asu.edu, jpabla1@asu.edu, mparmar3@asu.edu, Murthy.Devarakonda@asu.edu
* Kaggle ID: loccapollo, hongguan, jiteshpabla, mihir3031, murthydevarakonda
* This is a Team Submission
* Here are the links to our team's Kernels: <br>
[Using Novel Language Models and Web scraping to Effectively Identify Articles related to Therapeutics and Vaccines](https://www.kaggle.com/jiteshpabla/scoring-cord-19-using-google-training-on-scibert) <br>
[Using Novel Language Models and elasticsearch to Effectively Identify Articles related to Therapeutics and Vaccines](https://www.kaggle.com/jiteshpabla/classifying-cord-19-articles-using-elasticbert/) <br>
[COVID-19: BERT-based STS Method to Effectively Identify Articles related to Therapeutics and Vaccines](https://www.kaggle.com/mihir3031/bert-sts-for-searching-relevant-research-papers) <br>
[Using embeddings from BERTModel, BioBERT, BertForSequenceClassification to classify articles related to Vaccines and Therapeutics](https://www.kaggle.com/loccapollo/lexicon-based-similarity-scoring-with-bert-biobert) <br>
[Micro-scorers for COVID-19 Open Challenge](https://www.kaggle.com/hongguan/micro-scorers-for-covid-19-open-challenge/) <br>
The final ensembling that combines everything together: [Ensemble model for COVID-19 Open Challenge](https://www.kaggle.com/hongguan/ensemble-model-for-covid-19-open-challenge/) <br>

# Introduction

### On March 19, 2020, the White House Office of Science and Technology Policy (WH-OSTP) issued a statement announcing the release of an extensive machine-readable collection of scientific articles about COVID-19, SARS-CoV-2, and the coronavirus group, prepared jointly by several institutions including the National Library of Medicine and the Allen Institute for AI. In the same statement, WH-OSTP: 
> .. [joined] the institutions in issuing a call to action to the Nation’s artificial intelligence experts to develop new text and data mining techniques that can help the science community answer high-priority scientific questions related to COVID-19
### The dataset, called the COVID-19 Open Research Dataset (CORD-19), presently has nearly 59,000 articles (extracted from various archives), more than 35,000 of which have full text. The institutions further compiled a series of questions to be answered. For example, some questions related to COVID-19 vaccines and therapeutics are:
* Effectiveness of drugs being developed and tried to treat COVID-19 patients.
* Exploration of use of best animal models and their predictive value for a human vaccine.
* Efforts targeted at a universal coronavirus vaccine.
* Efforts to develop prophylaxis clinical studies and prioritize in healthcare workers.

![Vaccine and Therapeutics for COVID-19](https://qtxasset.com/fiercebiotech/1584710738/Screen%20Shot%202020-03-13%20at%2010.09.06%20AM.png/Screen%20Shot%202020-03-13%20at%2010.09.06%20AM.png?afFuGG0s3SKhTTkKAPfwWsMUeTddeDna)

# Problems Addressed And Their Solution

#### *Problem 1*: Screening clinically relevant publications remains a formidable challenge, despite earlier efforts to improve the accuracy of existing search engines. 
#### *Problem 2*: The COVID-19 related questions are complex and the answers may be anywhere in the article's full text. 
#### *Solution*: We approach these problems in 2 steps:
#### Step 1: We use a new generation of neural network, BERT (Bidirectional Encoder Representations from Transformers), pre-trained on a biomedical corpus. Instead of using one model for the task, we use a “crowd” of models, each fine-tuned on a different but related task using existing publicly available training datasets. These models independently score an article for relevance to Treatment and Vaccine. Our hypothesis is that the “wisdom of the crowd” is more effective than any single approach.
#### Step 2: We will mine relevant articles for specific information, such as effectiveness, by identifying the key passages that contain it. To do so, we define a novel concept called semantic information availability (SIA) of a passage p relative to a question q, and we propose to develop a dataset to train the BERT model to score SIA for a variety of passages and queries. The top-scored passages from the most relevant articles then provide the answers to the question as an extracted summary. This approach generalizes beyond the COVID-19 dataset and can be used to automate information gathering for systematic reviews and other meta-analyses. (We aim to develop this for the June 16 deadline.)

#### ***Note*:** In this Kernel, we tackle part of Step 1 by building 2 scorers that contribute to the Crowd Wisdom described above. The Crowd Wisdom, which combines the scorers created by each member of our team (including the ones in this kernel), is available here: [Ensemble model for COVID-19 Open Challenge](https://www.kaggle.com/hongguan/ensemble-model-for-covid-19-open-challenge/).
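As a rough illustration of the “wisdom of the crowd” idea, the sketch below averages each article's relevance probability over several scorers. Unweighted averaging is an assumption made here for illustration only; the actual combination is done in the linked ensemble kernel and may differ. The scorer names and probabilities in the usage lines are hypothetical.

```python
from statistics import mean
from typing import Dict, List


def ensemble_scores(scores_by_scorer: Dict[str, Dict[str, float]]) -> Dict[str, float]:
    """Average each article's relevance probability over the scorers that scored it.

    scores_by_scorer maps a scorer name to {article_id: probability}.
    """
    per_article: Dict[str, List[float]] = {}
    for article_scores in scores_by_scorer.values():
        for article_id, prob in article_scores.items():
            per_article.setdefault(article_id, []).append(prob)
    # Unweighted mean: every scorer gets the same vote (an assumption)
    return {article_id: mean(probs) for article_id, probs in per_article.items()}


# Usage with two hypothetical scorers and made-up probabilities
crowd = {
    "nsp_scorer": {"article_1": 0.9, "article_2": 0.2},
    "hedges_scorer": {"article_1": 0.7, "article_2": 0.4},
}
print(ensemble_scores(crowd))
```

A weighted mean or a learned combiner could replace the plain average if some scorers prove more reliable than others.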

### All Code, Models, Data and Results mentioned in this Kernel are well documented and available in this [GitHub Link for Scorer 1](https://github.com/md-labs/covid19-kaggle/tree/master/document_nsp_text_sim) and this [GitHub Link for Scorer 2](https://github.com/md-labs/covid19-kaggle/tree/master/clinical_hedges_classification)

#### *Note*: The code below was run on high-performance GPUs and may take several hours to execute.

![BERT NSP Methodology](https://pytorch.org/tutorials/_images/bert1.png)

# Scorer 1- Next Sentence Prediction for Article Identification:
1. We model the problem using Next Sentence Prediction (NSP), one of the pre-training objectives of the transformer language model BERT ([paper](https://arxiv.org/abs/1810.04805)). In particular, we use SciBERT ([link](https://github.com/allenai/scibert)), a BERT model pre-trained on a large corpus of scientific articles from Semantic Scholar. 
2. NSP in BERT is pre-trained using the following format for input: `[CLS] + Text_A + [SEP] + Text_B + [SEP]`
3. We hand-pick sentences from biomedical papers from renowned conferences and use them as the query (Text_A). We formulate two queries: one that strongly signifies Therapeutics and one that strongly signifies Vaccines. 
4. We then use the title + abstract + journal_id of each paper in the COVID dataset ([link](https://www.kaggle.com/allen-institute-for-ai/CORD-19-research-challenge)) as Text_B.
5. This combined text (title + abstract + journal_id) is paired with the Vaccine query and, separately, with the Therapeutics query. A feed-forward network acts as the NSP head: it takes as input the embedding of the [CLS] token produced by SciBERT, and a softmax layer on top gives the text similarity (probability) score between the two texts. 
6. We set a threshold of 0.999: if the similarity score exceeds it, we classify the article as relevant to Treatment or Vaccine, according to the query it is most similar to.  
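Steps 5 and 6 can be sketched without any deep-learning framework, assuming the two NSP logits for each (query, article) pair are already available and ordered as in BERT's NSP head (index 0 = "is next", index 1 = "not next"). The function names and the tie-breaking rule (ties go to Vaccine) are our own assumptions, not code from the repository.

```python
import math
from typing import Sequence, Tuple

THRESHOLD = 0.999  # similarity threshold from step 6


def is_next_probability(logits: Sequence[float]) -> float:
    """Softmax over the two NSP logits; returns P(IsNext), i.e. index 0."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    return exps[0] / sum(exps)


def classify_article(vaccine_logits: Sequence[float],
                     therapeutics_logits: Sequence[float],
                     threshold: float = THRESHOLD) -> Tuple[str, float]:
    """Label an article by the query it is most similar to, if that score clears the threshold."""
    p_vaccine = is_next_probability(vaccine_logits)
    p_therapeutics = is_next_probability(therapeutics_logits)
    # max() keeps the first entry on a tie, so ties go to "Vaccine" (an assumption)
    label, score = max([("Vaccine", p_vaccine), ("Therapeutics", p_therapeutics)],
                       key=lambda pair: pair[1])
    if score > threshold:
        return label, score
    return "Other", score


# Usage with made-up logits
print(classify_article([9.0, -1.0], [2.0, 0.5]))
```

Because 0.999 is very strict, only pairs whose "is next" logit exceeds the "not next" logit by roughly 7 or more pass the threshold.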


#### Next, we describe our Vaccine and Therapeutics queries, which we formed by manually picking sentences from scientific articles published by Elsevier and other medical publishers. 

#### Therapeutics Query:
   Therapeutics is the branch of medicine concerned with the treatment of disease and the action of remedial agents. There is no specific antiviral therapy and treatment given by doctors is largely supportive, consisting of supplemental oxygen and conservative fluid administration. Drugs like Chloroquine, Hydroxychloroquine, Lopinavir, Ritonavir, Azithromycin and Tocilizumab are being prescribed by doctors in ICU testing. The drug Remdesivir has shown promise against other coronaviruses in animal models. Patients with respiratory failure require intubation. Patients in shock require urgent fluid resuscitation and administration of empiric antimicrobial therapy. Corticosteroid therapy is not recommended for viral pneumonia; however, use may be considered for patients with refractory shock or acute respiratory distress syndrome
    
#### Vaccine Query:
   Vaccine is a substance used to stimulate the production of antibodies and provide immunity against diseases. They are treated to act as an antigen without inducing the disease. When the virulent version of an agent comes along, the immune system is prepared to respond due to the generation of B cells (memory and plasma cells), which will generate antibodies that will bind to pathogens and destroy them. Vaccine researchers are working on the development of a vaccine candidate expressing the viral spike protein of SARS-CoV-2 using a messenger RNA vaccine. Scientists are also focusing on the development of a chimpanzee adenovirus-vectored vaccine candidate against COVID-19. In addition, scientists are also working to see if vaccines developed for SARS coronavirus are effective against COVID-19.
   
### Note that the first sentence of each query is a definition of the word therapeutics or vaccine. All subsequent sentences are chosen to emphasize treatments and vaccines related to the coronavirus and COVID-19.
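The data-preparation code further below reads these two queries from `vaccine_and_therapeutics_query.json`, looking them up under the keys `"vaccine"` and `"therapeutics"`. A minimal sketch of producing that file follows; the helper name is hypothetical, the output directory is a temporary one for illustration, and only the first (definition) sentence of each query is used to keep it short.

```python
import json
import os
import tempfile


def write_query_file(data_dir: str, vaccine_query: str, therapeutics_query: str) -> str:
    """Write both queries to <data_dir>/vaccine_and_therapeutics_query.json and return its path."""
    os.makedirs(data_dir, exist_ok=True)
    path = os.path.join(data_dir, "vaccine_and_therapeutics_query.json")
    with open(path, "w", encoding="utf-8") as fp:
        # Keys must match what the data-preparation script looks up
        json.dump({"vaccine": vaccine_query, "therapeutics": therapeutics_query},
                  fp, ensure_ascii=False, indent=2)
    return path


# Usage: paste the full query text from above in practice
out_path = write_query_file(
    tempfile.mkdtemp(),
    "Vaccine is a substance used to stimulate the production of antibodies and provide immunity against diseases.",
    "Therapeutics is the branch of medicine concerned with the treatment of disease and the action of remedial agents.",
)
with open(out_path, encoding="utf-8") as fp:
    print(sorted(json.load(fp)))  # prints ['therapeutics', 'vaccine']
```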

#### We tested the code below with 4 variants of SciBERT fine-tuning:
1. The pretrained SciBERT model (provided by AllenAI) used for NSP without fine-tuning (Yes! The NSP head is randomly initialized in this variant.)

2. SciBERT fine-tuned with MLM on the abstract/title text of the COVID dataset, then used for NSP. This step fine-tunes the model with the Masked Language Model and Next Sentence Prediction training methods employed in the BERT paper; such intermediate fine-tuning seems to help, as reported in the [ULMFiT paper](https://arxiv.org/abs/1801.06146). Refer to this GitHub [link](https://github.com/md-labs/covid19-kaggle/tree/master/document_nsp_text_sim/src/lm_finetuning) for more details. (The NSP head is randomly initialized in this variant too.)

3. Fine Tuning the Pretrained SciBERT model in Variant 1 with the *Opioid NSP Question-Answer dataset and then use this model for NSP
4. Fine Tuning the Fine-Tuned SciBERT model in Variant 2 with the *Opioid NSP Question-Answer dataset and then use this model for NSP
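For readers unfamiliar with the masked language model (MLM) objective used in variant 2, the sketch below shows BERT-style masking on a plain token list: roughly 15% of positions are selected, and of those, 80% become `[MASK]`, 10% become a random vocabulary token and 10% are left unchanged. The toy vocabulary and the function name are assumptions for illustration; this is not the repository's `lm_finetuning` code.

```python
import random
from typing import List, Optional, Tuple


def mask_tokens(tokens: List[str], vocab: List[str], rng: random.Random,
                mask_prob: float = 0.15) -> Tuple[List[str], List[Optional[str]]]:
    """Apply BERT-style MLM masking.

    Returns the corrupted tokens and, per position, the original token the model
    must predict (None where no prediction is made). [CLS]/[SEP] are never masked.
    """
    corrupted = list(tokens)
    targets: List[Optional[str]] = [None] * len(tokens)
    for i, tok in enumerate(tokens):
        if tok in ("[CLS]", "[SEP]") or rng.random() >= mask_prob:
            continue
        targets[i] = tok
        roll = rng.random()
        if roll < 0.8:
            corrupted[i] = "[MASK]"           # 80%: replace with the mask token
        elif roll < 0.9:
            corrupted[i] = rng.choice(vocab)  # 10%: replace with a random token
        # remaining 10%: keep the original token
    return corrupted, targets


# Usage with a toy sentence and vocabulary
toks = "[CLS] remdesivir inhibits coronavirus replication in animal models [SEP]".split()
masked, labels = mask_tokens(toks, vocab=["virus", "vaccine", "cell"], rng=random.Random(0))
print(masked)
print(labels)
```

The loss is computed only at positions with a non-None target, which is what lets the model adapt to COVID-19 vocabulary without any labels.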

Each of these models can be found [here](https://www.dropbox.com/sh/ko0d8jayaapb7xq/AABZ1yPVCLFuKUrPoBXBfjD0a?dl=0).

The folder names indicate which model corresponds to which variant.

*The Opioid NSP Question-Answer dataset mentioned above is a weak, noisy dataset scraped from Reddit opioid forums. It contains questions asked by people struggling with opioid addiction and the answers they received in the forum. We use this data to fine-tune the SciBERT model in order to initialize the weights of the NSP head.

***Note***: Although the Opioid NSP QA dataset does not directly match the COVID-19 data seen at prediction time, our intuition is that it weakly initializes the NSP head and yields more robust similarity scores than the random initialization used in the first two variants. Better-matched datasets such as Clinical STS are available, but another of our scorers already uses that dataset in a similar manner, so we kept the two scorers distinct to get different perspectives when ensembling the scores as described above.
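The notebook does not show how the forum question-answer pairs were turned into NSP training examples. One common construction, assumed here purely for illustration, pairs each question with its own answer as a positive example (label 0, "is next", following BERT's NSP convention) and with an answer from a different thread as a negative example (label 1, "not next"). The function name and the placeholder texts are hypothetical.

```python
import random
from typing import List, Tuple


def build_nsp_pairs(qa_pairs: List[Tuple[str, str]], seed: int = 42
                    ) -> List[Tuple[str, str, int]]:
    """Turn (question, answer) pairs into NSP examples (text_a, text_b, label).

    label 0 = the answer belongs to the question ("is next"),
    label 1 = the answer comes from another question ("not next").
    """
    if len(qa_pairs) < 2:
        raise ValueError("need at least two QA pairs to sample negatives")
    rng = random.Random(seed)
    examples = []
    for i, (question, answer) in enumerate(qa_pairs):
        examples.append((question, answer, 0))
        # Draw an index from the other n - 1 pairs, skipping i itself
        j = rng.randrange(len(qa_pairs) - 1)
        if j >= i:
            j += 1
        examples.append((question, qa_pairs[j][1], 1))
    return examples


# Usage with placeholder texts
pairs = [("question 1", "answer 1"), ("question 2", "answer 2"), ("question 3", "answer 3")]
for text_a, text_b, label in build_nsp_pairs(pairs):
    print(label, text_a, "->", text_b)
```

This yields a balanced 50/50 mix of positives and negatives, the same ratio BERT uses for NSP pre-training.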

# Scorer 2- Classification for Article Identification:
1. #### We model the problem as a text classification problem using the SciBERT model.

2. #### Classification in BERT uses the following format for input: `[CLS] + Text_A + [SEP]`

3. #### The [Clinical Hedges dataset](https://hiru.mcmaster.ca/hiru/HIRU_Hedges_home.aspx) is a set of articles manually annotated for several categories: Format (Original, Review, etc.), Purpose (Treatment, Etiology, etc.), Rigor (how closely the methodology in the paper corresponds to its purpose) and whether the paper is related to Human Health Care (HHC). 
 
4. #### We use the Purpose column of this dataset (described above) to label documents as Treatment or Not Treatment, and train our model on this labeled dataset.
 
5. #### At prediction time, assuming our nsp-text-sim model (Scorer 1, described above) performs perfectly, we feed the documents it categorizes as Treatment or Vaccine to the Clinical Hedges model. The Clinical Hedges model labels each text as Treatment or Not Treatment; since its input then contains only Treatment and Vaccine documents, "Not Treatment" effectively means Vaccine. 

6. #### Our architecture consists of the SciBERT model (fine-tuned on Clinical Hedges data) with an FFN on top that outputs the classification probability scores. The text input to the model is Title + Abstract + Journal, as used previously.

![](https://yashuseth.files.wordpress.com/2019/06/fig1-1.png)
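Steps 4 and 5 can be sketched as follows: binary labels are derived from the Clinical Hedges Purpose annotation, and at prediction time the classifier output is used only for articles Scorer 1 has already flagged. The Purpose string `"Treatment"`, the 0.5 decision threshold and the helper names are assumptions for illustration.

```python
from typing import Dict


def hedges_label(purpose: str) -> int:
    """1 = Treatment, 0 = Not Treatment, from the Purpose annotation."""
    return 1 if purpose.strip().lower() == "treatment" else 0


def cascade(nsp_labels: Dict[str, str], treatment_prob: Dict[str, float],
            threshold: float = 0.5) -> Dict[str, str]:
    """Refine Scorer 1 output with the Clinical Hedges classifier (Scorer 2).

    nsp_labels: article_id -> "Vaccine", "Therapeutics" or "Other" (from Scorer 1).
    treatment_prob: article_id -> P(Treatment) from Scorer 2.
    Among articles Scorer 1 flagged as relevant, "Not Treatment" is read as
    "Vaccine", because that input is assumed to hold only the two classes.
    """
    result = {}
    for article_id, label in nsp_labels.items():
        if label == "Other":
            result[article_id] = "Other"  # Scorer 2 is not applied to these
        elif treatment_prob[article_id] >= threshold:
            result[article_id] = "Therapeutics"
        else:
            result[article_id] = "Vaccine"
    return result


# Usage with made-up labels and probabilities
print(cascade({"a": "Vaccine", "b": "Therapeutics", "c": "Other"},
              {"a": 0.9, "b": 0.1, "c": 0.99}))
```

In this arrangement Scorer 2 can only reassign articles between the two relevant classes, so its output is only as good as Scorer 1's filtering.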

### Error Analysis and Future Work:
* The dataset includes articles in languages other than English; these may need to be removed so that the articles of interest are restricted to the English domain.
* In the NSP modeling, fine-tuning on the Opioid dataset might add excessive noise to the model and hence contribute negatively to the ensemble. Training with Clinical STS might improve this.
* The classification model depends entirely on the performance of the NSP models and cannot, on its own, separate Vaccine and Therapeutics articles from other articles.
* Fine-tuning BERT on other related clinical datasets is a possible direction for future work that may improve results significantly.
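For the first point, a language-identification library would normally be used. As a dependency-free stand-in, the sketch below keeps an article only if a minimum share of its words are common English function words. The stopword list and the 10% cutoff are arbitrary assumptions that would need tuning on CORD-19 before use.

```python
import re
from typing import Iterable, List

# Tiny hand-picked list of frequent English function words (an assumption, not exhaustive)
ENGLISH_STOPWORDS = frozenset(
    "the of and to in is that for with as was on by are this be from or an were".split()
)


def looks_english(text: str, min_ratio: float = 0.1) -> bool:
    """Heuristic: True if at least `min_ratio` of the words are English stopwords."""
    # Runs of letters only; accented letters and CJK characters count as letters
    words = re.findall(r"[^\W\d_]+", text.lower())
    if not words:
        return False
    hits = sum(word in ENGLISH_STOPWORDS for word in words)
    return hits / len(words) >= min_ratio


def filter_english(texts: Iterable[str], min_ratio: float = 0.1) -> List[str]:
    """Keep only the texts that pass the stopword heuristic."""
    return [text for text in texts if looks_english(text, min_ratio)]


# Usage
print(filter_english(["The role of the spike protein.",
                      "Le virus est détecté chez les patients."]))
```

The heuristic is cheap enough to run over every title and abstract, but it can misfire on very short texts, so it is best treated as a first-pass filter.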

In [None]:
"""
Source: https://github.com/md-labs/covid19-kaggle/blob/master/document_nsp_text_sim/src/create_nsp_data.py
Code to create the input to the BERT Next Sentence Prediction Model (run_nsp.py)
Input:  vaccine_and_therapeutics_query.json, metadata.csv
Output: COVID_NSP_Data/Test_Data.tsv with columns ID, Text_A (query), Text_B (title + abstract + journal), DEF;
        run_nsp.py turns each row into: [CLS] Query [SEP] (Title + Abstract) Text [SEP]
"""
import os
import json
import csv
import shutil
import sys
from document_nsp_text_sim.utils import ReadData

# Abstracts can exceed the default csv field size limit. sys.maxsize overflows a
# C long on some platforms (e.g. Windows), so back off until the value is accepted.
max_field_size = sys.maxsize
while True:
    try:
        csv.field_size_limit(max_field_size)
        break
    except OverflowError:
        max_field_size //= 10

path_to_data = os.path.abspath("../data")


def WriteDataNSP(dataset, path, directory):
    """Write the NSP rows to <path>/<directory>/Test_Data.tsv, replacing any existing directory."""
    out_dir = os.path.join(path, directory)
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)
    with open(os.path.join(out_dir, 'Test_Data.tsv'), 'w', newline='', encoding='utf-8') as fp:
        writer = csv.writer(fp, delimiter='\t')
        writer.writerow(['ID', 'Text_A', 'Text_B', 'DEF'])
        for row in dataset:
            writer.writerow(row)


RetCovData = ReadData(os.path.join(path_to_data, 'raw_data'), "metadata.csv")
with open(os.path.join(path_to_data, "vaccine_and_therapeutics_query.json"), encoding='utf-8') as fp:
    vctdict = json.load(fp)

dataset = []
for i, row in enumerate(RetCovData):
    if i == 0:  # skip the CSV header row
        continue
    if row[2] == '' or row[3] == '':
        continue
    # Text_B: columns 3, 8 and 11 of metadata.csv (assumed title, abstract, journal).
    # Collapsing all whitespace (newlines, tabs, carriage returns) keeps each article on
    # one TSV line; BERT's tokenizer splits on whitespace, so tokens are unchanged.
    text_b = ' '.join((row[3] + ' ' + row[8] + ' ' + row[11]).split())
    # One row per query; the ID suffix ('V' / 'T') tells the two predictions apart
    dataset.append([row[3] + 'V', vctdict['vaccine'], text_b, 'VC'])
    dataset.append([row[3] + 'T', vctdict['therapeutics'], text_b, 'TR'])
WriteDataNSP(dataset, path_to_data, 'COVID_NSP_Data')

#### The run_nsp code below is run with the following command:
`python src/run_nsp.py --req_pretrained=True --model_dir=models/scibert_scivocab_uncased --vocab_dir=models/scibert_scivocab_uncased --do_lower_case --task_name=covid --data_dir=data/COVID_NSP_Data --learning_rate=2e-5 --num_train_epochs=10 --output_dir=models/scibert_scivocab_uncased/ --cache_dir=./BERT_CACHE --eval_batch_size=16 --max_seq_length=512 --train_batch_size=16 --do_eval`
#### For more details visit this GitHub ([link](https://github.com/md-labs/covid19-kaggle/tree/master/document_nsp_text_sim))
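Because the data-preparation step writes two rows per article, with IDs formed by appending `V` (vaccine query) or `T` (therapeutics query) to the same article key (`row[3]` in the script), the NSP outputs have to be regrouped per article before thresholding. The sketch below assumes the predictions are available as `(guid, probability)` pairs; that input shape and the helper name are assumptions, not the documented output format of `run_nsp.py`.

```python
from collections import defaultdict
from typing import Dict, Iterable, Tuple

SUFFIX_TO_QUERY = {"V": "vaccine", "T": "therapeutics"}


def group_by_article(predictions: Iterable[Tuple[str, float]]) -> Dict[str, Dict[str, float]]:
    """Map 'keyV' / 'keyT' guids to {key: {'vaccine': p, 'therapeutics': p}}."""
    grouped: Dict[str, Dict[str, float]] = defaultdict(dict)
    for guid, prob in predictions:
        key, suffix = guid[:-1], guid[-1]  # the suffix is the last character
        if suffix not in SUFFIX_TO_QUERY:
            raise ValueError(f"unexpected guid suffix in {guid!r}")
        grouped[key][SUFFIX_TO_QUERY[suffix]] = prob
    return dict(grouped)


# Usage with made-up predictions
print(group_by_article([("Paper AV", 0.9995), ("Paper AT", 0.2)]))
```

Since the key is the article title, two articles with identical titles would collide; using a unique ID column instead would avoid that.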

In [0]:
"""
Source: https://github.com/md-labs/covid19-kaggle/blob/master/document_nsp_text_sim/src/run_nsp.py
BERT finetuning runner for Next Sentence Prediction
Input: [CLS] Query [SEP] (Title + Abstract) Text [SEP]
Output: Probability of similarity between Query and Text
"""

from __future__ import absolute_import, division, print_function

import argparse
import csv
import logging
import os
import random
import sys
import copy

import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForNextSentencePrediction, BertConfig, WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO)
logger = logging.getLogger(__name__)


class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Constructs an InputExample.
        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
            text_b: (Optional) string. The untokenized text of the second sequence.
            Only must be specified for sequence pair tasks.
            label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
        with open(input_file, "r") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            lines = []
            for i, line in enumerate(reader):
                if i == 0:
                    continue
                if sys.version_info[0] == 2:
                    line = list(unicode(cell, 'utf-8') for cell in line)
                lines.append(line)
            return lines


class COVIDProcessor(DataProcessor):
    """Processor for the COVID-19 NSP data set (TSV files written by create_nsp_data.py)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {}".format(os.path.join(data_dir, "Train_Data.tsv")))
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "Train_Data.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "Dev_Data.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "Test_Data.tsv")), "test")

    def get_labels(self):
        """See base class."""
        return ["VC", "TR"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the train, dev and test sets in a fixed shuffled order."""
        examples = []
        # Shuffle deterministically (seed 42); each example keeps its guid, so
        # predictions can still be matched back to their articles.
        random.seed(42)
        shuffled_indices = random.sample(range(len(lines)), len(lines))
        for i in shuffled_indices:
            guid = lines[i][0]
            text_a = lines[i][1]
            text_b = lines[i][2]
            label = lines[i][-1]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


def convert_examples_to_features(examples, label_list, max_seq_length,
                                 tokenizer, output_mode="classification"):
    """Converts `InputExample`s into a list of `InputFeatures`."""

    label_map = {label : i for i, label in enumerate(label_list)}

    features = []
    for (ex_index, example) in enumerate(examples):
        if ex_index % 10000 == 0:
            logger.info("Writing example %d of %d" % (ex_index, len(examples)))

        tokens_a = tokenizer.tokenize(example.text_a)

        tokens_b = None
        if example.text_b:
            tokens_b = tokenizer.tokenize(example.text_b)
            # Modifies `tokens_a` and `tokens_b` in place so that the total
            # length is less than the specified length.
            # Account for [CLS], [SEP], [SEP] with "- 3"
            _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
        else:
            # Account for [CLS] and [SEP] with "- 2"
            if len(tokens_a) > max_seq_length - 2:
                tokens_a = tokens_a[:(max_seq_length - 2)]

        # The convention in BERT is:
        # (a) For sequence pairs:
        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
        #  type_ids: 0   0  0    0    0     0       0 0    1  1  1  1   1 1
        # (b) For single sequences:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids: 0   0   0   0  0     0 0
        #
        # Where "type_ids" are used to indicate whether this is the first
        # sequence or the second sequence. The embedding vectors for `type=0` and
        # `type=1` were learned during pre-training and are added to the wordpiece
        # embedding vector (and position vector). This is not *strictly* necessary
        # since the [SEP] token unambiguously separates the sequences, but it makes
        # it easier for the model to learn the concept of sequences.
        #
        # For classification tasks, the first vector (corresponding to [CLS]) is
        # used as the "sentence vector". Note that this only makes sense because
        # the entire model is fine-tuned.
        tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
        segment_ids = [0] * len(tokens)

        if tokens_b:
            tokens += tokens_b + ["[SEP]"]
            segment_ids += [1] * (len(tokens_b) + 1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding = [0] * (max_seq_length - len(input_ids))
        input_ids += padding
        input_mask += padding
        segment_ids += padding

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        if output_mode == "classification":
            label_id = label_map[example.label]
        elif output_mode == "regression":
            label_id = float(example.label)
        else:
            raise KeyError(output_mode)

        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s" % (example.guid))
            logger.info("tokens: %s" % " ".join(
                    [str(x) for x in tokens]))
            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            logger.info(
                    "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
            logger.info("label: %s (id = %d)" % (example.label, label_id))

        features.append(
                InputFeatures(input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids,
                              label_id=label_id))
    return features



def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()


def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--req_pretrained", default="True", type=str, required=True)
    parser.add_argument("--model_dir", default=None, type=str, required=True,
                        help="Directory with the (Sci)BERT weights and config (or a BERT shortcut name) to load for NSP.")
    parser.add_argument("--data_dir",
                        default="",
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--vocab_dir", default="", type=str,
                        required=True,
                        help="Path to the vocabulary file, a directory containing vocab.txt "
                        "(e.g. the SciBERT model directory), or a BERT shortcut name such as bert-base-uncased.")
    parser.add_argument("--task_name",
                        default="",
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir",
                        default="",
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--cache_dir",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length",
                        default=512,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    processors = {
        "covid": COVIDProcessor,
    }

    num_labels_task = {
        "covid": 2,
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")
        

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    print(processor)
    num_labels = num_labels_task[task_name]
    print(num_labels)
    label_list = processor.get_labels()
    print(label_list)
    
    tokenizer = BertTokenizer.from_pretrained(args.vocab_dir, do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank))
    model = BertForNextSentencePrediction.from_pretrained(args.model_dir,
              cache_dir=cache_dir)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    if args.do_train:
        train_features = convert_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

        for ep in trange(int(args.num_train_epochs), desc="Epoch"):
            model.train()
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(train_dataloader):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                loss = model(input_ids, segment_ids, input_mask, label_ids)
                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            eval_examples = processor.get_dev_examples(args.data_dir)
            eval_features = convert_examples_to_features(
                eval_examples, label_list, args.max_seq_length, tokenizer)
            print("\n")
            print("Running evaluation for epoch: {}".format(ep))
            all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
            all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
            all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
            all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
            eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
            # Run prediction for full data
            eval_sampler = SequentialSampler(eval_data)
            eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
    					
            model.eval()
            eval_loss, eval_accuracy = 0, 0
            nb_eval_steps, nb_eval_examples = 0, 0
    
            for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                segment_ids = segment_ids.to(device)
                label_ids = label_ids.to(device)
    
                with torch.no_grad():
                    tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
                    logits = model(input_ids, segment_ids, input_mask)
    
                logits = logits.detach().cpu().numpy()
                label_ids = label_ids.to('cpu').numpy()
                tmp_eval_accuracy = accuracy(logits, label_ids)
    
                eval_loss += tmp_eval_loss.mean().item()
                eval_accuracy += tmp_eval_accuracy
    
                nb_eval_examples += input_ids.size(0)
                nb_eval_steps += 1
    
            eval_loss = eval_loss / nb_eval_steps
            eval_accuracy = eval_accuracy / nb_eval_examples
            loss = tr_loss/nb_tr_steps if args.do_train else None
            result = {'eval_loss': eval_loss,
                      'eval_accuracy': eval_accuracy,
                      'global_step': global_step,
                      'loss': loss}
    
            for key in sorted(result.keys()):
                print(key, str(result[key]))
            print()
            
    if args.do_train and args.do_eval:
        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())
        # Load a trained model and config that you have fine-tuned
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        config = BertConfig(output_config_file)
        model = BertForNextSentencePrediction(config)
        model.load_state_dict(torch.load(output_model_file))
    elif args.do_train:
        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())	
    else:
        # Load a trained model and config that you have fine-tuned
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        config = BertConfig(output_config_file)
        if args.req_pretrained == "True":
            model = BertForNextSentencePrediction.from_pretrained(args.model_dir,
                                                                  cache_dir=cache_dir)
        else:
            model = BertForNextSentencePrediction(config)
            model.load_state_dict(torch.load(output_model_file))

    model.to(device)

    if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_test_examples(args.data_dir)
        eval_features = convert_examples_to_features(
            eval_examples, label_list, args.max_seq_length, tokenizer)
        complete_user_ids = list()
        for example in eval_examples:
            complete_user_ids.append(example.guid)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        complete_label_ids = list()
        complete_outputs = list()
        complete_probs = list()
        for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)
            with torch.no_grad():
                tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
                logits = model(input_ids, segment_ids, input_mask)

            last_layer_op = copy.deepcopy(logits)
            logits = logits.detach().cpu().numpy()
            print(logits)
            sm = torch.nn.Softmax(dim=1)  # softmax over the class dimension of (batch, 2) logits
            probabilities = sm(last_layer_op)
            probabilities = probabilities.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            tmp_eval_accuracy = accuracy(logits, label_ids)
            outputs = np.argmax(logits, axis=1)
            complete_outputs.extend(outputs)
            complete_label_ids.extend(label_ids)
            complete_probs.extend(probabilities[:, 1])

        outcsv = open(os.path.join(args.output_dir, "Reqd_Labels.csv"),'w', encoding = 'utf8', newline='')
        writer = csv.writer(outcsv,quotechar = '"')
        writer.writerow(["ID", "Probs"])
        for user,true, pred, probs in zip(complete_user_ids, complete_label_ids, complete_outputs, complete_probs):
            writer.writerow([user, probs])
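The evaluation loop above turns each batch of NSP logits into the probability of label index 1 (`probabilities[:, 1]`) and writes one `[ID, Probs]` row per example. A minimal NumPy sketch of that post-processing, assuming the logits arrive as an `(N, 2)` array; `positive_class_probs` and `write_probs_csv` are hypothetical helper names, not part of the original script:

```python
import csv

import numpy as np


def positive_class_probs(logits):
    """Row-wise softmax over an (N, 2) logit array; returns P(label index 1)."""
    logits = np.asarray(logits, dtype=np.float64)
    # Subtract each row's maximum before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=1, keepdims=True)
    return probs[:, 1]


def write_probs_csv(path, ids, probs):
    """Write an ["ID", "Probs"] CSV; the `with` block closes the file even on error."""
    with open(path, "w", encoding="utf8", newline="") as f:
        writer = csv.writer(f, quotechar='"')
        writer.writerow(["ID", "Probs"])
        for example_id, p in zip(ids, probs):
            writer.writerow([example_id, p])
```

In PyTorch the equivalent is `torch.softmax(logits, dim=1)[:, 1]`; passing `dim` explicitly avoids the deprecated implicit-dimension choice that `torch.nn.Softmax()` makes when `dim` is omitted.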

In [None]:
"""
Source: https://github.com/md-labs/covid19-kaggle/blob/master/document_nsp_text_sim/src/analyze_nsp_results.py
Code to read the output CSV from run_nsp.py together with metadata.csv and filter papers using a threshold of 0.999
on either the Treatment (T) or the Vaccine (V) query probability. A paper is labelled TR or VC according to whichever
probability is higher when at least one of them exceeds 0.999, and O otherwise.
Input: metadata.csv, Reqd_Labels.csv from run_nsp.py (read here as Reqd_Labels_Before_FineTuning.csv)
Output: Threshold-filtered results with the columns 'Title', 'Text', 'Label', 'Prob_T', 'Prob_V'
"""


from document_nsp_text_sim.utils import ReadData
import os
import csv
import sys

path_to_data = os.path.abspath("../data")
path_to_results = os.path.abspath("../results")
csv.field_size_limit(sys.maxsize)


def main():
    RetCovData = ReadData(path_to_data + '/raw_data', "metadata.csv")
    RetCovDict = dict()
    text = []
    for i, row in enumerate(RetCovData):
        if i == 0:
            continue
        RetCovDict[row[3]] = ' '.join((row[3] + ' ' + row[8] + ' ' + row[11]).split('\n'))
        text.append(' '.join((row[3] + ' ' + row[8] + ' ' + row[11]).split('\n')))
    labels = ReadData(os.path.join(path_to_results, "nsp_results_final"), "Reqd_Labels_Before_FineTuning.csv")
    label_dict = dict()
    for i, row in enumerate(labels):
        if i == 0:
            continue
        if row[0][:-1] not in label_dict:
            label_dict[row[0][:-1]] = dict()
        label_dict[row[0][:-1]][row[0][-1]] = float(row[1])
    count = 0
    final_labels = []
    for key in label_dict.keys():
        if key == 'paper_id':
            continue
        dictionary = label_dict[key]
        if 'T' not in dictionary or 'V' not in dictionary:
            continue
        th = 0.999
        if dictionary['T'] > th or dictionary['V'] > th:
            label = 'TR' if dictionary['T'] > dictionary['V'] else 'VC'
            final_labels.append([key, RetCovDict[key], label, dictionary['T'], dictionary['V']])
            count += 1
        else:
            final_labels.append([key, RetCovDict[key], 'O', dictionary['T'], dictionary['V']])
    final_labels = list(sorted(final_labels, key=lambda x: x[3], reverse=True))
    with open(os.path.join(path_to_results, 'filtered_results_final', 'Filter_Before_FineTuning.tsv'), 'w', newline='',
              encoding='utf-8') as fp:
        writer = csv.writer(fp, delimiter='\t')
        writer.writerow(['Title', 'Text', 'Label', 'Prob_T', 'Prob_V'])
        for row in final_labels:
            writer.writerow(row)
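The labelling rule in `main()` above can be isolated into a pure function that is easy to test. A minimal sketch, assuming (as the code above does) that each ID in the results CSV is a paper key followed by a one-character query suffix, `T` for the treatment query and `V` for the vaccine query; `label_papers` is a hypothetical name:

```python
def label_papers(rows, threshold=0.999):
    """Group (suffixed_id, prob) rows by paper and assign TR, VC or O.

    rows: iterable of (id_with_suffix, probability) pairs, where the last
    character of the ID is 'T' (treatment query) or 'V' (vaccine query).
    Returns a dict mapping paper key -> (label, prob_t, prob_v).
    """
    grouped = {}
    for suffixed_id, prob in rows:
        key, query = suffixed_id[:-1], suffixed_id[-1]
        grouped.setdefault(key, {})[query] = float(prob)

    labels = {}
    for key, probs in grouped.items():
        # Papers scored against only one of the two queries are skipped.
        if "T" not in probs or "V" not in probs:
            continue
        p_t, p_v = probs["T"], probs["V"]
        if p_t > threshold or p_v > threshold:
            # A tie goes to VC, matching the strict '>' comparison in main().
            label = "TR" if p_t > p_v else "VC"
        else:
            label = "O"
        labels[key] = (label, p_t, p_v)
    return labels
```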

In [None]:
import pandas as pd
pd.set_option('display.max_colwidth', None)  # None = no truncation; -1 is deprecated in pandas >= 1.0
import os
import csv

In [None]:
metadata = dict()
with open("/kaggle/input/cord19metadata/metadata.csv", newline='', encoding='utf-8') as fp:
    reader = csv.reader(fp)
    for row in reader:
        metadata[row[3]] = [row[0], row[8], row[11]]
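The cell above indexes `metadata.csv` by position (`row[3]` for the title, `row[0]`, `row[8]` and `row[11]` for the stored fields), which silently breaks if a CORD-19 release reorders its columns. A sketch of the same title-keyed lookup using `csv.DictReader`, assuming the release has `cord_uid`, `title`, `abstract` and `journal` headers; `load_title_index` is a hypothetical helper:

```python
import csv


def load_title_index(fp):
    """Map each title to [cord_uid, abstract, journal] using header names.

    fp: an open text file (or any iterable of CSV lines) with a header row.
    As in the positional version, a later row with the same title
    overwrites an earlier one.
    """
    index = {}
    for row in csv.DictReader(fp):
        index[row["title"]] = [row["cord_uid"], row["abstract"], row["journal"]]
    return index
```

A side benefit is that the header row is consumed by `DictReader` instead of being stored as a `'title'` entry in the dictionary.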

# Results from Scorer 1 after filtering each of the 4 variants with a similarity-score threshold of 0.999:

# **Variant 1: SciBERT Before Fine-Tuning**

In [None]:
df = pd.read_csv('/kaggle/input/covid-nsp-filtered-results/Filter_Before_FineTuning.tsv', delimiter='\t')
df = df[['ID', 'Text', 'Label']]
df.columns = ['Title', 'Abstract', 'Label']
df['CORD_ID'] = df['Title'].map(lambda title: metadata[title][0])
df['Abstract'] = df['Title'].map(lambda title: metadata[title][1])
df['Journal'] = df['Title'].map(lambda title: metadata[title][2])
df = df.set_index('CORD_ID')
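The loading cell above is repeated for all four variants. A sketch that factors it into one function, assuming every filtered TSV has the `ID`, `Text` and `Label` columns read above and that `metadata` maps titles to `[cord_uid, abstract, journal]`; `load_variant` is a hypothetical name, and titles are looked up through `Series.map` on the `Title` column:

```python
import pandas as pd


def load_variant(tsv_path, metadata):
    """Load one filtered-results TSV and attach CORD-19 metadata by title."""
    df = pd.read_csv(tsv_path, delimiter="\t")
    df = df[["ID", "Text", "Label"]]
    df.columns = ["Title", "Abstract", "Label"]
    # A title missing from `metadata` raises KeyError, as in the cells above.
    df["CORD_ID"] = df["Title"].map(lambda t: metadata[t][0])
    df["Abstract"] = df["Title"].map(lambda t: metadata[t][1])
    df["Journal"] = df["Title"].map(lambda t: metadata[t][2])
    return df.set_index("CORD_ID")
```

Each variant cell would then reduce to one call, e.g. `df = load_variant('/kaggle/input/covid-nsp-filtered-results/Filter_After_FineTuning.tsv', metadata)`.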

In [None]:
# TREATMENT CLASSIFIED ARTICLES
df[df['Label'] == 'TR'].head()

In [None]:
# VACCINE CLASSIFIED ARTICLES
df.loc[df['Label'] == 'VC'].head()

# **Variant 2: SciBERT After Fine-Tuning**

In [None]:
df = pd.read_csv('/kaggle/input/covid-nsp-filtered-results/Filter_After_FineTuning.tsv', delimiter='\t')
df = df[['ID', 'Text', 'Label']]
df.columns = ['Title', 'Abstract', 'Label']
df['CORD_ID'] = df['Title'].map(lambda title: metadata[title][0])
df['Abstract'] = df['Title'].map(lambda title: metadata[title][1])
df['Journal'] = df['Title'].map(lambda title: metadata[title][2])
df = df.set_index('CORD_ID')

In [None]:
# TREATMENT CLASSIFIED ARTICLES
df[df['Label'] == 'TR'].head()

In [None]:
# VACCINE CLASSIFIED ARTICLES
df[df['Label'] == 'VC'].head()

# **Variant 3: Pretrained SciBERT After Fine-Tuning with Opioid NSP**

In [None]:
df = pd.read_csv('/kaggle/input/covid-nsp-filtered-results/Filter_Before_FineTuning_Classification.tsv', delimiter='\t')
df = df[['ID', 'Text', 'Label']]
df.columns = ['Title', 'Abstract', 'Label']
df['CORD_ID'] = df['Title'].map(lambda title: metadata[title][0])
df['Abstract'] = df['Title'].map(lambda title: metadata[title][1])
df['Journal'] = df['Title'].map(lambda title: metadata[title][2])
df = df.set_index('CORD_ID')

In [None]:
# TREATMENT CLASSIFIED ARTICLES
df[df['Label'] == 'TR'].head()

In [None]:
# VACCINE CLASSIFIED ARTICLES
df[(df['Label'] == 'VC') & (df['Abstract'] != '')].head()

# **Variant 4: Fine-Tuned SciBERT After Fine-Tuning with Opioid NSP**

In [None]:
df = pd.read_csv('/kaggle/input/covid-nsp-filtered-results/Filter_After_FineTuning_Classification.tsv', delimiter='\t')
df = df[['ID', 'Text', 'Label']]
df.columns = ['Title', 'Abstract', 'Label']
df['CORD_ID'] = df['Title'].map(lambda title: metadata[title][0])
df['Abstract'] = df['Title'].map(lambda title: metadata[title][1])
df['Journal'] = df['Title'].map(lambda title: metadata[title][2])
df = df.set_index('CORD_ID')

In [None]:
# TREATMENT CLASSIFIED ARTICLES
df[(df['Label'] == 'TR') & (df['Abstract'] != '')].head()

In [None]:
# VACCINE CLASSIFIED ARTICLES
df[(df['Label'] == 'VC') & (df['Abstract'] != '')].head()

# Inference from Results of Scorer 1
### The results above suggest that fine-tuning on the opioid dataset added excessive noise to the model, causing it to generalize to domains other than coronavirus. In contrast, intermediate fine-tuning (MLM + NSP) of the SciBERT model on the Title + Abstract text before prediction appears to help the model identify treatment- and vaccine-related documents more accurately.

#### Although not perfect, these methods help filter the relevant articles out of thousands of irrelevant ones.

In [None]:
"""
Source: https://github.com/md-labs/covid19-kaggle/blob/master/clinical_hedges_classification/src/combine_nsp_results.py
Code to combine all text which is classified as Vaccine or Therapeutics by the Query-Document NSP Model and prepare
input to the Clinical Hedges classifier (run_classifier.py)
Input: Directory of filtered results from NSP model
Output: Data prepped in format required for input to Clinical Hedges Classification Model
Output File Header Format: ["ID", "Text", "Label", "Prob_T", "Prob_V"]
"""

import csv
import os

path_to_results = os.path.abspath('../../document_nsp_text_sim/results/filtered_results_final')
dirListing = os.listdir(path_to_results)

combined_data_dict = dict()
for file in dirListing:
    with open(os.path.join(path_to_results, file), newline='', encoding='utf-8') as fp:
        reader = csv.reader(fp, delimiter='\t')
        for i, row in enumerate(reader):
            if i == 0:
                continue  # skip the header row
            # row = [Title, Text, Label, Prob_T, Prob_V]; later files overwrite earlier ones
            combined_data_dict[row[0]] = row[1:]


with open(os.path.abspath("../data/Pred_Data_COVID/Test_Data.tsv"), 'w', newline='', encoding='utf-8') as fp:
    writer = csv.writer(fp, delimiter='\t')
    writer.writerow(["ID", "Text", "Label", "Prob_T", "Prob_V"])
    for key in combined_data_dict.keys():
        if combined_data_dict[key][1] == 'O':
            continue  # keep only papers labelled TR or VC
        writer.writerow([key] + combined_data_dict[key])
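The combining step above keeps one row per paper ID across all filtered files (a later file overwrites an earlier one, in `os.listdir` order) and drops papers labelled `O`. A pure-function sketch of that merge, assuming each input row has the layout `[ID, Text, Label, ...]` with the header already removed; `combine_filtered` is a hypothetical name:

```python
def combine_filtered(tables):
    """Merge filtered-result tables and keep only TR/VC papers.

    tables: list of row lists (header rows removed); each row starts with
    [ID, Text, Label, ...]. Later tables overwrite earlier ones per ID.
    Returns output rows in first-seen ID order.
    """
    combined = {}
    for rows in tables:
        for row in rows:
            combined[row[0]] = row[1:]
    # row[1:] starts at Text, so the label sits at index 1 of each stored value.
    return [[key] + rest for key, rest in combined.items() if rest[1] != "O"]
```

Because the last file wins, the order of files in the results directory can change which label a paper ends up with when variants disagree.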

#### The run_classifier code below is run with the following command:
`python src/run_classifier.py  --task_num=3 --model_file=models/scibert_scivocab_uncased --vocab_dir=models/scibert_scivocab_uncased --do_lower_case --task_name=clinicalhedges --data_dir=data/Pred_Data_COVID --learning_rate=2e-5 --num_train_epochs=10 --output_dir=models/SciBERT_Trained_Treatment_Model/ --eval_batch_size=16 --max_seq_length=400 --train_batch_size=16 --do_eval`
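With `--max_seq_length=400`, each Title + Abstract text reaches the classifier as `[CLS] tokens [SEP]`: truncated to at most 398 WordPiece tokens, zero-padded to 400 positions, with an attention mask of 1 for real tokens and segment ids that are all 0 (this mirrors the single-sequence branch of `convert_examples_to_features` in the script). A toy sketch of that layout, assuming the text is already tokenized and using a hypothetical `vocab` dict in place of `BertTokenizer.convert_tokens_to_ids`:

```python
def single_sequence_features(tokens, vocab, max_seq_length):
    """Build (input_ids, input_mask, segment_ids) for one text, BERT-style."""
    # Reserve two positions for [CLS] and [SEP].
    tokens = tokens[: max_seq_length - 2]
    tokens = ["[CLS]"] + tokens + ["[SEP]"]
    input_ids = [vocab[t] for t in tokens]
    input_mask = [1] * len(input_ids)   # 1 = real token, attended to
    segment_ids = [0] * len(input_ids)  # single sequence: every token is type 0
    padding = [0] * (max_seq_length - len(input_ids))
    return input_ids + padding, input_mask + padding, segment_ids + padding
```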
#### For more details, see the GitHub repository ([link](https://github.com/md-labs/covid19-kaggle/tree/master/clinical_hedges_classification)).

In [None]:
"""
Source: https://github.com/md-labs/covid19-kaggle/blob/master/clinical_hedges_classification/src/run_classifier.py
BERT finetuning runner for Text Classification
Input: [CLS] (Title + Abstract) Text [SEP]
Output: Classification Labels of Treatment or Other for the Text Input
"""

from __future__ import absolute_import, division, print_function

import argparse
import csv
import logging
import os
import random
import sys
import shutil
import copy

import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam

from sklearn.metrics import classification_report

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO)
logger = logging.getLogger(__name__)


class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Constructs an InputExample.
        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
            text_b: (Optional) string. The untokenized text of the second sequence.
            Only must be specified for sequence pair tasks.
            label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self, index):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
        with open(input_file, "r") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            lines = []
            for line in reader:
                if sys.version_info[0] == 2:
                    line = list(unicode(cell, 'utf-8') for cell in line)
                lines.append(line)
            return lines


class InputProcessor(DataProcessor):
    """Processor for the Clinical Hedges data set."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "Train_Data.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "Dev_Data.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "Test_Data.tsv")), "test")

    def get_labels(self, index):
        """See base class."""
        index = int(index)
        if(index == 0):
            return ["NA", "O"]
        elif(index == 1):
            return ["F", "T"]
        elif(index == 2):
            return ["NA", "TR"]
        elif(index == 3):
            return ["F", "T"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets, in shuffled order."""
        examples = []
        # Visit the rows in a random order; row 0 is the TSV header and is skipped.
        shuffled_rows = random.sample(range(len(lines)), len(lines))
        for i in shuffled_rows:
            if i == 0:
                continue
            guid = lines[i][0]
            text_a = lines[i][1]
            text_b = None
            label = lines[i][2]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


def accuracy(out, labels):
    outputs = np.argmax(out, axis=1)
    return np.sum(outputs == labels)
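
# Worked example for get_tp_fp_fn / compute_metrics below (hypothetical counts,
# for illustration only): with tp=8, fp=2, fn=4,
#   precision = 8 / (8 + 2) = 0.800
#   recall    = 8 / (8 + 4) ~= 0.667
#   f1        = 2 * 0.800 * 0.667 / (0.800 + 0.667) ~= 0.727
# The np.finfo(float).eps terms only prevent division by zero when a count is 0;
# their effect on these values is negligible.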


def get_tp_fp_fn(logits, labels):
    assert labels.shape[1] == 1
    labels = labels.squeeze()
    predictions = np.argmax(logits, axis=1)
    labels, predictions = labels.astype(int), predictions.astype(int)
    tp = np.sum(np.logical_and(predictions == 1, labels == 1))
    fp = np.sum(np.logical_and(predictions == 1, labels == 0))
    fn = np.sum(np.logical_and(predictions == 0, labels == 1))
    return tp, fp, fn


def compute_metrics(tp, fp, fn):
    precision = tp / (tp + fp + np.finfo(float).eps)
    recall = tp / (tp + fn + np.finfo(float).eps)
    f1 = 2 * precision * recall / (precision + recall + np.finfo(float).eps)
    return precision, recall, f1


def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
    """Converts a list of `InputExample`s into a list of `InputFeatures`."""

    label_map = {label : i for i, label in enumerate(label_list)}

    features = []
    for (ex_index, example) in enumerate(examples):
        tokens_a = tokenizer.tokenize(example.text_a)

        tokens_b = None
        if example.text_b:
            tokens_b = tokenizer.tokenize(example.text_b)
            # Modifies `tokens_a` and `tokens_b` in place so that the total
            # length is less than the specified length.
            # Account for [CLS], [SEP], [SEP] with "- 3"
            _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
        else:
            # Account for [CLS] and [SEP] with "- 2"
            if len(tokens_a) > max_seq_length - 2:
                tokens_a = tokens_a[:(max_seq_length - 2)]

        # The convention in BERT is:
        # (a) For sequence pairs:
        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
        #  type_ids: 0   0  0    0    0     0       0 0    1  1  1  1   1 1
        # (b) For single sequences:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids: 0   0   0   0  0     0 0
        #
        # Where "type_ids" are used to indicate whether this is the first
        # sequence or the second sequence. The embedding vectors for `type=0` and
        # `type=1` were learned during pre-training and are added to the wordpiece
        # embedding vector (and position vector). This is not *strictly* necessary
        # since the [SEP] token unambiguously separates the sequences, but it makes
        # it easier for the model to learn the concept of sequences.
        #
        # For classification tasks, the first vector (corresponding to [CLS]) is
        # used as the "sentence vector". Note that this only makes sense because
        # the entire model is fine-tuned.
        tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
        segment_ids = [0] * len(tokens)

        if tokens_b:
            tokens += tokens_b + ["[SEP]"]
            segment_ids += [1] * (len(tokens_b) + 1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding = [0] * (max_seq_length - len(input_ids))
        input_ids += padding
        input_mask += padding
        segment_ids += padding

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        label_id = label_map[example.label]
        #if ex_index < 5:
        #    logger.info("*** Example ***")
        #    logger.info("guid: %s" % (example.guid))
        #    logger.info("tokens: %s" % " ".join(
        #            [str(x) for x in tokens]))
        #    logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        #    logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        #    logger.info(
        #            "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        #    logger.info("label: %s (id = %d)" % (example.label, label_id))

        features.append(
                InputFeatures(input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids,
                              label_id=label_id))
    return features


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()
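
# Worked example (illustrative): with len(tokens_a) == 5, len(tokens_b) == 3 and
# max_length == 6, the loop pops from tokens_a twice (5+3 -> 4+3 -> 3+3) and stops
# once the total reaches 6; tokens_b is left untouched. On a tie, tokens_b is popped.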

def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir",
                        default="",
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--model_dir", default="", type=str,
                        required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                        "bert-base-multilingual-cased, bert-base-chinese or any pretrained model directory with model.bin and config file")
    parser.add_argument("--vocab_dir", default="", type=str,
                        required=True,
                        help="BERT vocabulary for the tokenizer: a shortcut name such as bert-base-uncased, "
                        "or a path to a vocab file or a directory containing vocab.txt.")
    parser.add_argument("--task_name",
                        default="",
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir",
                        default="",
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    parser.add_argument("--task_num",
                        default=-1,
                        type=int,
                        required=True,
                        help="The task number of Clinical Hedges Tasks to run")

    ## Other parameters
    parser.add_argument("--cache_dir",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length",
                        default=512,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    processors = {
        "clinicalhedges": InputProcessor,
    }

    num_labels_task = {
        "clinicalhedges": [2, 2, 2, 2],
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")


    task_name = args.task_name.lower()
    task_num = args.task_num
    if task_name not in processors:
        raise ValueError("Task not found: %s" % task_name)

    processor = processors[task_name]()
    print(processor)
    num_labels = num_labels_task[task_name][task_num-1]
    print(num_labels)
    label_list = processor.get_labels(task_num-1)
    print(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.vocab_dir, do_lower_case=args.do_lower_case)
    file = open(os.path.join(args.output_dir, "Classification_Reports_Task_{}.txt".format(task_num)), 'w')

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank))
    model = BertForSequenceClassification.from_pretrained(args.model_dir,
              cache_dir=cache_dir,
              num_labels = num_labels)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    if args.do_train:
        train_features = convert_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer)
        logger.info("***** Running training for Task {}*****".format(task_num))
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

        for ep in trange(int(args.num_train_epochs), desc="Epoch"):
            model.train()
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(train_dataloader):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                loss = model(input_ids, segment_ids, input_mask, label_ids)
                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            eval_examples = processor.get_dev_examples(args.data_dir)
            eval_features = convert_examples_to_features(
                eval_examples, label_list, args.max_seq_length, tokenizer)
            print("\n")
            print("Running evaluation for epoch: {}".format(ep))
            all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
            all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
            all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
            all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
            eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
            # Run prediction for full data
            eval_sampler = SequentialSampler(eval_data)
            eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

            model.eval()
            eval_loss, eval_accuracy = 0, 0
            nb_eval_steps, nb_eval_examples = 0, 0

            for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                segment_ids = segment_ids.to(device)
                label_ids = label_ids.to(device)

                with torch.no_grad():
                    tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
                    logits = model(input_ids, segment_ids, input_mask)

                logits = logits.detach().cpu().numpy()
                label_ids = label_ids.to('cpu').numpy()
                tmp_eval_accuracy = accuracy(logits, label_ids)

                eval_loss += tmp_eval_loss.mean().item()
                eval_accuracy += tmp_eval_accuracy

                nb_eval_examples += input_ids.size(0)
                nb_eval_steps += 1

            eval_loss = eval_loss / nb_eval_steps
            eval_accuracy = eval_accuracy / nb_eval_examples
            loss = tr_loss/nb_tr_steps if args.do_train else None
            result = {'eval_loss': eval_loss,
                      'eval_accuracy': eval_accuracy,
                      'global_step': global_step,
                      'loss': loss}

            for key in sorted(result.keys()):
                print(key, str(result[key]))
            print()

    if args.do_train:
        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
        model_out_dir = os.path.join(args.output_dir, "Model_Part_Task_{}".format(task_num))
        # Recreate the output directory; the existence check must use the same path that is removed
        if os.path.exists(model_out_dir):
            shutil.rmtree(model_out_dir)
        os.mkdir(model_out_dir)
        output_model_file = os.path.join(args.output_dir, "Model_Part_Task_{}".format(task_num), WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.output_dir, "Model_Part_Task_{}".format(task_num), CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())
    if args.do_eval:
        # Load a trained model and config that you have fine-tuned
        output_model_file = os.path.join(args.output_dir, "Model_Part_Task_{}".format(task_num), WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, "Model_Part_Task_{}".format(task_num), CONFIG_NAME)
        config = BertConfig(output_config_file)
        model = BertForSequenceClassification(config, num_labels=num_labels)
        model.load_state_dict(torch.load(output_model_file, map_location='cpu'))
    model.to(device)

    if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_test_examples(args.data_dir)
        eval_features = convert_examples_to_features(
            eval_examples, label_list, args.max_seq_length, tokenizer)
        complete_user_ids = list()
        for example in eval_examples:
            complete_user_ids.append(example.guid)
        logger.info("***** Running Test for Task {}*****".format(task_num))
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        complete_label_ids = list()
        complete_outputs = list()
        complete_probs = list()
        for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)
            with torch.no_grad():
                tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
                logits = model(input_ids, segment_ids, input_mask)

            # Softmax over the class dimension; column 1 is the positive-class probability
            probabilities = torch.softmax(logits, dim=1).detach().cpu().numpy()
            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            tmp_eval_accuracy = accuracy(logits, label_ids)
            outputs = np.argmax(logits, axis=1)
            complete_outputs.extend(outputs)
            complete_label_ids.extend(label_ids)
            complete_probs.extend(probabilities[:,1])

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        out_csv_path = os.path.join(args.output_dir, "Reqd_Labels_Task_{}.csv".format(task_num))
        with open(out_csv_path, 'w', encoding='utf8', newline='') as outcsv:
            writer = csv.writer(outcsv, quotechar='"')
            # One row per test article: ID, gold label, predicted label, positive-class probability
            writer.writerow(["ID", "True", "Pred", "Prob"])
            for user, true, pred, prob in zip(complete_user_ids, complete_label_ids, complete_outputs, complete_probs):
                writer.writerow([user, true, pred, prob])

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        loss = tr_loss/nb_tr_steps if args.do_train else None
        result = {'eval_loss': eval_loss,
                  'eval_accuracy': eval_accuracy,
                  'global_step': global_step,
                  'loss': loss}
        print(result)

        file.write("\nClassification Report\n\n" + classification_report(complete_label_ids, complete_outputs) + "\n\n\n")
    file.close()

if __name__ == "__main__":
    main()
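
The `_truncate_seq_pair` helper above always trims whichever sequence is currently longer, one token at a time, until the pair fits. As an illustration only, here is a standalone copy of that loop (renamed `truncate_seq_pair`; the toy token lists are made up) applied to an uneven pair and an equal pair:

```python
def truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Trim the longer of two token lists, in place, until their combined length fits."""
    while len(tokens_a) + len(tokens_b) > max_length:
        # Ties go to tokens_b, exactly as in the script's helper.
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()


# Hypothetical toy inputs: a long sequence and a short one.
a = ["tok%d" % i for i in range(10)]
b = ["t0", "t1", "t2"]
truncate_seq_pair(a, b, 8)
print(len(a), len(b))  # 5 3: only the longer sequence lost tokens

# Equal lengths: the two sides are trimmed alternately.
c = ["x"] * 4
d = ["y"] * 4
truncate_seq_pair(c, d, 6)
print(len(c), len(d))  # 3 3
```

Because `pop()` removes from the end of each list, the beginning of each text is always preserved.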

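Two pieces of arithmetic in `main()` work together: `--train_batch_size` is divided by `--gradient_accumulation_steps` to get the per-forward-pass batch, and each micro-batch loss is divided by the same factor before `backward()`, with `optimizer.step()` called only every `gradient_accumulation_steps` batches. The plain-Python sketch below is not the BERT classifier; it assumes a one-parameter least-squares model on made-up data. It checks that the accumulated gradient equals the full-batch gradient when micro-batches are equal-sized, and it mirrors the script's step-count formula (the function names are hypothetical):

```python
def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)**2) for a one-parameter linear model."""
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, accumulation_steps):
    """Sum of micro-batch gradients, each scaled by 1/accumulation_steps (like loss / gas)."""
    micro = len(xs) // accumulation_steps  # assumes len(xs) divides evenly
    total = 0.0
    for k in range(accumulation_steps):
        bx = xs[k * micro:(k + 1) * micro]
        by = ys[k * micro:(k + 1) * micro]
        total += grad_mse(w, bx, by) / accumulation_steps
    return total


def num_optimization_steps(num_examples, train_batch_size, accumulation_steps, num_epochs):
    """Mirror of the script's formula; train_batch_size is the value before the division."""
    per_forward_batch = train_batch_size // accumulation_steps
    return int(num_examples / per_forward_batch / accumulation_steps) * num_epochs


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
print(grad_mse(0.5, xs, ys), accumulated_grad(0.5, xs, ys, 2))  # -22.5 -22.5
print(num_optimization_steps(1000, 32, 2, 3.0))  # 93.0
```

With `--train_batch_size 32` and `--gradient_accumulation_steps 2`, each forward pass sees 16 examples, but each optimizer update still reflects 32, which trades speed for lower peak memory.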

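At test time the script writes, for each article, the softmax probability of label index 1 (`probabilities[:,1]`) alongside the argmax prediction. For a single row of logits this can be sketched in plain Python. Subtracting the maximum logit is the usual numerical-stability trick and leaves the result unchanged. The logit values and helper names below are illustrative, not taken from the model:

```python
import math


def softmax(logits):
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def positive_prob_and_label(logits):
    """Return (probability of class index 1, argmax class index) for one example."""
    probs = softmax(logits)
    pred = max(range(len(probs)), key=probs.__getitem__)  # first index wins ties, like np.argmax
    return probs[1], pred


for row in ([2.0, 0.0], [0.0, 0.0], [-1.0, 3.0]):
    p1, pred = positive_prob_and_label(row)
    print(row, round(p1, 4), pred)
# [2.0, 0.0] 0.1192 0
# [0.0, 0.0] 0.5 0
# [-1.0, 3.0] 0.982 1
```

Since softmax is monotonic, taking the argmax of the probabilities gives the same prediction as `np.argmax(logits, axis=1)` in the script.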
# Results of Scorer 2:

In [None]:
df = pd.read_csv('/kaggle/input/covid-ch-classification-results/COVID_Labels_After_Classification_CH.tsv', delimiter='\t')
df = df[['ID', 'Text', 'Label']]
df.columns = ['Title', 'Abstract', 'Label']
# metadata maps each title to (CORD_ID, abstract, journal); look it up via the Title column
df['CORD_ID'] = df['Title'].map(lambda title: metadata[title][0])
df['Abstract'] = df['Title'].map(lambda title: metadata[title][1])
df['Journal'] = df['Title'].map(lambda title: metadata[title][2])
df = df.set_index('CORD_ID')

In [None]:
# TREATMENT CLASSIFIED ARTICLES
df[(df['Label'] == 'TR') & (df['Abstract'] != '')].head()

In [None]:
# VACCINE CLASSIFIED ARTICLES
df[(df['Label'] == 'VC') & (df['Abstract'] != '')].head()

# Inference from Results of Scorer 2
### The performance of this method depends entirely on the accuracy of our NSP models. Although the model identifies Treatment-related articles with high accuracy, it classifies only a small number of articles as Treatment-related.
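
To put a figure on how few articles fall into each class, one could tally the classifier's labels and keep only rows with a non-empty abstract, as the two cells above do. The following is a stdlib sketch over hypothetical `(label, abstract)` rows. The placeholder texts are made up, and the real values would come from `df`:

```python
from collections import Counter


def count_labels_with_abstract(rows):
    """Count labels among (label, abstract) rows whose abstract is non-empty."""
    return Counter(label for label, abstract in rows if abstract != "")


# Hypothetical rows; in the notebook these would be df['Label'] and df['Abstract'].
rows = [
    ("TR", "abstract text A"),
    ("VC", "abstract text B"),
    ("VC", "abstract text C"),
    ("TR", ""),  # dropped: empty abstract
]
print(count_labels_with_abstract(rows))  # Counter({'VC': 2, 'TR': 1})
```

With pandas, the equivalent tally on the real data would be `df.loc[df['Abstract'] != '', 'Label'].value_counts()`.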