In [None]:
from utils import  load_data
from bert_format import input_format, position_ids_compute, format_time, flat_accuracy
import params
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from transformers import BertForSequenceClassification
import torch, pickle
import numpy as np
from sklearn.metrics import classification_report, ConfusionMatrixDisplay, confusion_matrix
from torch import nn   
from multitask_format import MultiTaskModel, Task

device = torch.device(params.device)

test_data = load_data(params.data_path + "test_data.json", map_relations=params.map_relations)

# Number of relation classes for the relation head: 17 by default, 18 for the
# "squished" STAC data directory.
# NOTE(review): this is an exact string match, so a data_path written without
# the trailing slash silently falls back to 17 — confirm params.data_path format.
num_labels = 17
if params.data_path == "data/stac_squished_data/":
    num_labels = 18

# Task id 0 = binary attachment head, task id 1 = relation head; the prediction
# cells below route every example to task id 1.
attach_task = Task(id=0, name='attach prediction', type="seq_classification", num_labels=2)
relation_task = Task(id=1, name='relation prediction', type="seq_classification", num_labels=num_labels)
tasks = [attach_task, relation_task]

model = MultiTaskModel(params.model_name, tasks)
output_model = params.model_path + 'bert_multitask.pth'
print(output_model)
# Load the checkpoint directly onto the configured device. Hard-coding 'cuda'
# here made loading fail on CPU-only machines even when params.device is 'cpu'.
checkpoint = torch.load(output_model, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
print('loaded')

### Relation prediction on gold attachments

In [None]:
# Relation classification on the gold attachments. input_format(relations=True)
# returns the model inputs plus `labels_relation`, which serves as the gold
# target for the report and confusion matrix at the end of this cell.
(input_ids, attention_masks, token_type_ids, tokens,
 labels_relation, labels, raw) = input_format(test_data, relations=True)
position_ids = position_ids_compute(input_ids, raw, labels)
# Route every row to task id 1 (the relation head); float32 ones, one per example.
task_ids = torch.ones(len(input_ids))

prediction_data = TensorDataset(
    input_ids,
    attention_masks,
    token_type_ids,
    position_ids,
    labels_relation,
    task_ids,
)
prediction_sampler = SequentialSampler(prediction_data)
prediction_dataloader = DataLoader(
    prediction_data,
    sampler=prediction_sampler,
    batch_size=params.batch_size_bert,
)

print('Predicting labels for {:,} test sentences...'.format(len(input_ids)))

model.eval()

predictions, true_labels = [], []
for batch in prediction_dataloader:
    # Move every tensor of the batch to the model's device, then unpack.
    (b_input_ids, b_input_mask, b_token_types,
     b_position_ids, b_labels, b_task_ids) = (t.to(device) for t in batch)

    with torch.no_grad():
        outputs, embed = model(
            b_input_ids,
            token_type_ids=b_token_types,
            attention_mask=b_input_mask,
            position_ids=b_position_ids,
            task_ids=b_task_ids,
        )
    # The first element of `outputs` is treated as the per-class logits.
    logits = outputs[0].detach().cpu().numpy()
    label_ids = b_labels.cpu().numpy()
    predictions.append(logits)
    true_labels.append(label_ids)

print('    DONE.')

flat_predictions = np.argmax(np.concatenate(predictions, axis=0), axis=1).flatten()
flat_true_labels = np.concatenate(true_labels, axis=0)
print(classification_report(flat_true_labels, flat_predictions))

cm = confusion_matrix(flat_true_labels, flat_predictions)
ConfusionMatrixDisplay(cm).plot()

### Relation prediction on the predicted attachments

In [None]:
# Relation classification on *predicted* attachments: the pairs come from a
# pickled attachment-prediction file rather than from the gold data.
# NOTE(review): pickle.load can execute arbitrary code stored in the file —
# only load prediction files produced by this project.
with open(params.data_path + 'linear_pred_stac.pkl', 'rb') as pred_file:
    test_pred = pickle.load(pred_file)

(input_ids, attention_masks, token_type_ids, tokens,
 labels_attach, labels, raw) = input_format(test_data, relations=True, attach_preds=test_pred)
position_ids = position_ids_compute(input_ids, raw, labels)
# Every row goes to task id 1 (the relation head); float32 ones, one per example.
task_ids = torch.ones(len(input_ids))

# Unlike the gold-attachment cell, no label tensor is packed into the dataset:
# only model inputs are batched here.
prediction_data = TensorDataset(
    input_ids,
    attention_masks,
    token_type_ids,
    position_ids,
    task_ids,
)
prediction_sampler = SequentialSampler(prediction_data)
prediction_dataloader = DataLoader(
    prediction_data,
    sampler=prediction_sampler,
    batch_size=params.batch_size_bert,
)


model.eval()

# NOTE(review): true_labels is reset but never filled in this cell.
predictions, true_labels = [], []
for batch in prediction_dataloader:
    (b_input_ids, b_input_mask, b_token_types,
     b_position_ids, b_task_ids) = (t.to(device) for t in batch)

    with torch.no_grad():
        outputs, embed = model(
            b_input_ids,
            token_type_ids=b_token_types,
            attention_mask=b_input_mask,
            position_ids=b_position_ids,
            task_ids=b_task_ids,
        )
    # The first element of `outputs` is treated as the per-class logits.
    logits = outputs[0].detach().cpu().numpy()
    predictions.append(logits)

print('    DONE.')

# Raw stacked logits, then one predicted relation id per example (read by the
# evaluation cell below).
flat_prediction = np.concatenate(predictions, axis=0)
flat_predictions = flat_prediction.argmax(axis=1).flatten()

In [None]:
# Attach each predicted relation id to its (x, y) attachment pair.
# New lists are built from pair[:2] instead of appending in place: appending
# mutated the pickled data, so re-running this cell added a second label to
# every pair and no prediction could ever match a gold (x, y, type) triple.
n_pairs = sum(len(dialog) for dialog in test_pred)
assert n_pairs == len(flat_predictions), (
    f'{n_pairs} predicted attachments but {len(flat_predictions)} relation predictions'
)
counter = 0
labeled_pred = []
for dialog in test_pred:
    labeled_dialog = []
    for pair in dialog:
        labeled_dialog.append(list(pair[:2]) + [int(flat_predictions[counter])])
        counter += 1
    labeled_pred.append(labeled_dialog)
test_pred = labeled_pred

# Gold (x, y, type) triples, one list per dialogue. Single-EDU dialogues are
# skipped, so test_truth[i] must correspond to test_pred[i].
# NOTE(review): this assumes the attachment predictions also omit single-EDU
# dialogues — confirm against the code that produced linear_pred_stac.pkl.
test_truth = []
for dialogue in test_data:
    if len(dialogue['edus']) == 1:
        continue
    test_truth.append([(rel['x'], rel['y'], rel['type']) for rel in dialogue['relations']])
assert len(test_truth) == len(test_pred), (
    f'{len(test_pred)} predicted dialogues vs {len(test_truth)} gold dialogues'
)

# Predicted count: every pair plus one extra per dialogue (presumably the
# unattached first EDU, mirroring the gold count below).
cnt_pred = sum(len(dialog) + 1 for dialog in test_pred)
print('nb of total predictions : ', cnt_pred)

# Gold count: all relations plus every EDU that has no incoming relation.
# NOTE(review): this loop includes single-EDU dialogues, which test_truth and
# the predicted/correct counts skip — confirm that is intended.
cnt_golden = 0
for dialog in test_data:
    has_incoming = [False] * len(dialog['edus'])
    for rel in dialog['relations']:
        has_incoming[rel['y']] = True
    cnt_golden += len(dialog['relations']) + has_incoming.count(False)
print('nb of relations : ', cnt_golden)

# Correct count: exact (x, y, type) matches, plus one per dialogue to match the
# extra unit counted in cnt_pred.
cnt_correct = 0
for dialog_pred, truth in zip(test_pred, test_truth):
    gold = set(truth)
    cnt_correct += 1 + sum(tuple(pred) in gold for pred in dialog_pred)
print('nb of correct predictions : ', cnt_correct)

# Guard the divisions so empty inputs or zero correct predictions report 0.0
# instead of raising ZeroDivisionError.
precision = cnt_correct / cnt_pred if cnt_pred else 0.0
recall = cnt_correct / cnt_golden if cnt_golden else 0.0
f1 = 2 * (precision * recall / (precision + recall)) if precision + recall else 0.0
print('precision : ', precision)
print('recall : ', recall)
print('F1 score : ', f1)