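# Evaluating the fine-tuned BioMed-RoBERTa classifier

Loads the best checkpoint (`classifier_best.pth`) of the linear classifier trained on top of `allenai/biomed_roberta_base`, then builds the NLI4CT dataloaders used to evaluate it.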
---

## Imports

In [14]:
import sys
import os
import pandas as pd
from collections import namedtuple

In [28]:
__MODEL_SLUG__ = 'biomed'
# The project-local imports in the next cell (models.*, util) need the parent
# directory on sys.path. BASE_PATH is formally defined again in the Constants
# cell, but it must already exist here or a fresh kernel raises NameError.
BASE_PATH = os.path.join(os.pardir)
# Guard so re-running this cell does not keep prepending duplicate entries.
if BASE_PATH not in sys.path:
    sys.path.insert(0, BASE_PATH)

In [29]:
import torch
from transformers import RobertaTokenizer, RobertaModel, RobertaConfig
from models.linear_classifier import LinearClassifier
from util import get_dataloaders

## Constants

In [60]:
# All paths are relative to the notebook's working directory.
# Parent directory of the notebook (presumably the repository root) — also placed on sys.path above.
BASE_PATH = os.path.join(os.pardir)
assert os.path.exists(BASE_PATH), 'BASE_PATH not found: {}'.format(os.path.abspath(BASE_PATH))
# Fine-tuned checkpoints are stored outside this directory tree, four levels up.
MODELS_PATH = os.path.join(os.pardir, os.pardir, os.pardir, os.pardir, 'biomedical-datasets', 'models', 'biomed-roberta')
assert os.path.exists(MODELS_PATH), 'MODELS_PATH not found: {}'.format(os.path.abspath(MODELS_PATH))
# Preprocessed dataset files passed to get_dataloaders via args.data_folder.
DATASET_PATH = os.path.join(os.pardir, 'datasets', 'preprocessed')
assert os.path.exists(DATASET_PATH), 'DATASET_PATH not found: {}'.format(os.path.abspath(DATASET_PATH))

In [64]:
# Run configuration. Each field name sits next to its default so the two can
# never drift out of alignment: namedtuple's `defaults` is matched positionally,
# and a 23-item parallel list is easy to misalign when a field is added.
ARG_DEFAULTS = {
    'max_seq_length': 128,
    'num_classes': 2,
    'model_name': 'biomed',
    'evidence_retrieval': False,
    'dataset': 'nli4ct',
    'dataset_suffix': None,
    'combine': False,
    'data_folder': DATASET_PATH,
    'encoder_ckpt': None,
    'epochs': 3,
    'batch_size': 16,
    'learning_rate': 1e-05,
    'weight_decay': 0.01,
    'l1_regularization': 0.1,
    'gradient_accumulation_steps': 1,
    'alpha': 1.,
    'temp': 0.05,
    'ckpt': None,
    # model_path / save_folder are derived from the other fields in the next cell.
    'model_path': None,
    'save_folder': None,
    'print_freq': 100,
    'evaluation_metric': 'f1',
    'evaluate_state': 'validation',
}

# dict preserves insertion order, so field order matches the defaults order.
Args = namedtuple('Args', list(ARG_DEFAULTS), defaults=tuple(ARG_DEFAULTS.values()))

args = Args()

In [65]:
# Derive checkpoint locations from the configuration. The composed model_name is
# presumably the directory name the training run saved its checkpoints under.
# Guarded on save_folder (None by default) so that re-running this cell does not
# wrap model_name in the run suffix a second time; re-run the Args cell to reset.
if args.save_folder is None:
    if args.dataset_suffix is None:
        _models_dir = '{}_models'.format(args.dataset)
    else:
        # Previously referenced undefined `models_path` and dropped the suffix.
        # NOTE(review): suffix placement inferred from the original's two format
        # arguments — confirm against the training script's directory layout.
        _models_dir = '{}_{}_models'.format(args.dataset, args.dataset_suffix)
    args = args._replace(model_path=os.path.join(MODELS_PATH, _models_dir))
    args = args._replace(model_name='{}_{}L_len_{}_lr_{}_w_decay_{}_bsz_{}_temp_{}{}'.format(
        args.model_name, args.num_classes, args.max_seq_length, args.learning_rate,
        args.weight_decay, args.batch_size, args.temp,
        '_er' if args.evidence_retrieval else ''))
    args = args._replace(save_folder=os.path.join(args.model_path, args.model_name))

In [38]:
# Single source for the Hugging Face model id used by config, encoder and tokenizer.
PRETRAINED_NAME = "allenai/biomed_roberta_base"

# Hidden states and attentions are requested so they are available when inspecting the model.
config = RobertaConfig.from_pretrained(PRETRAINED_NAME, output_hidden_states=True, output_attentions=True)
model = RobertaModel.from_pretrained(PRETRAINED_NAME, config=config)
classifier = LinearClassifier(model, num_classes=args.num_classes)
tokenizer = RobertaTokenizer.from_pretrained(PRETRAINED_NAME)

# args.save_folder == os.path.join(args.model_path, args.model_name) (set above).
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a CPU-only
# machine; load_state_dict copies the values into the classifier's own parameters.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
classifier_ckpt = torch.load(os.path.join(args.save_folder, 'classifier_best.pth'), map_location='cpu')

In [57]:
# Checkpoint keys carry a 'module.' prefix (presumably from a DataParallel wrapper
# during training); strip it only where present so keys match the bare classifier.
CKPT_KEY_PREFIX = 'module.'
# Presumably a buffer stored by an older transformers version that the current
# model does not expect — skipped so load_state_dict stays strict. TODO confirm.
CKPT_SKIP_KEYS = {'module.encoder.embeddings.position_ids'}

_state = {
    (key[len(CKPT_KEY_PREFIX):] if key.startswith(CKPT_KEY_PREFIX) else key): value
    for key, value in classifier_ckpt['models'].items()
    if key not in CKPT_SKIP_KEYS
}
classifier.load_state_dict(_state)

<All keys matched successfully>

In [66]:
# Build the evaluation dataloaders (get_dataloaders is project-local, from util).
# NOTE(review): the meaning of the positional literal 1 is not visible from this
# notebook — TODO confirm against util.get_dataloaders and name it in the config cell.
dataloader_struct = get_dataloaders(
    args.dataset,
    args.data_folder,
    tokenizer,
    args.batch_size,
    1,
    args.max_seq_length,
    args.num_classes,
)

Loading NLI Dataset


Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

In [77]:
# Inspect which components get_dataloaders returned (keys shown in the output below).
dataloader_struct.keys()

dict_keys(['loader', 'iids', 'trials', 'orders', 'genres', 'types'])