In [5]:
from pytorch_pretrained_bert.modeling import BertForSequenceClassification
from pytorch_pretrained_bert.tokenization import BertTokenizer
import torch
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from train import InputExample, convert_examples_to_features
import pandas as pd
import numpy as np
from sklearn.metrics import classification_report

### Argument Classification Using fine tuned BERT Model

train_tweets = pd.read_pickle('/Users/william/Desktop/Thesis_code/train_tweets.pkl')
test_tweets = pd.read_pickle('/Users/william/Desktop/Thesis_code/test_tweets.pkl')


def convert_to_input_examples(data):
    input_examples = []
    target = data['Target'].iloc[0]
    data = data['Processed_Tweet_sw'].tolist()
    for tweet in data:
        input_examples.append(InputExample(text_a=target,text_b=tweet,label="NoArgument"))
    return input_examples

input_examples = convert_to_input_examples(test_tweets)



num_labels = 3
model_path = 'argument_classification_ukp/'
label_list = ["NoArgument", "Argument_against", "Argument_for"]
max_seq_length = 64
eval_batch_size = 8

#Input examples. The model 'bert_output/ukp/bert-base-topic-sentence/all_topics/' expects text_a to be the topic
#and text_b to be the sentence. label is an optional value, only used when we print the output in this script.





tokenizer = BertTokenizer.from_pretrained(model_path, do_lower_case=True)
eval_features = convert_examples_to_features(input_examples, label_list, max_seq_length, tokenizer)

all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)

eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids)
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=eval_batch_size)



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BertForSequenceClassification.from_pretrained(model_path, num_labels=num_labels)
model.to(device)
model.eval()

predicted_labels = []
with torch.no_grad():
    for input_ids, input_mask, segment_ids in eval_dataloader:
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)


        logits = model(input_ids, segment_ids, input_mask)
        logits = logits.detach().cpu().numpy()

        for prediction in np.argmax(logits, axis=1):
            predicted_labels.append(label_list[prediction])



test_tweets['Predicted_Label'] = predicted_labels

test_tweets['Predicted_Label'] = test_tweets['Predicted_Label'].replace(['Argument_against','Argument_for'],'Argumentative')
test_tweets['Predicted_Label'] = test_tweets['Predicted_Label'].replace(['NoArgument'],'Non_Argumentative')



report = classification_report(test_tweets['Opinion Towards'],test_tweets['Predicted_Label'],output_dict=True)


print(report)

loading vocabulary file argument_classification_ukp/vocab.txt
:: Sentences longer than max_sequence_length: 0
:: Num sentences: 1956
loading archive file argument_classification_ukp/
Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "max_position_embeddings": 512,
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "type_vocab_size": 2,
  "vocab_size": 30522
}



{'Argumentative': {'precision': 0.5833333333333334, 'recall': 0.006097560975609756, 'f1-score': 0.01206896551724138, 'support': 1148}, 'Non_Argumentative': {'precision': 0.4130658436213992, 'recall': 0.9938118811881188, 'f1-score': 0.5835755813953489, 'support': 808}, 'accuracy': 0.41411042944785276, 'macro avg': {'precision': 0.4981995884773663, 'recall': 0.4999547210818643, 'f1-score': 0.29782227345629514, 'support': 1956}, 'weighted avg': {'precision': 0.5129978876854587, 'recall': 0.41411042944785276, 'f1-score': 0.24815145305789113, 'support': 1956}}
