In [2]:
from utils import  load_data
from linear_format import encode_data, extract_labels, NeuralNetwork, get_batches
import params
import random, torch, pickle
from torch import nn
import numpy as np
from transformers import BertForSequenceClassification
from tqdm import tqdm

In [3]:
# Torch device named by the project configuration.
device = torch.device(params.device)

# Read the test split with the project's `load_data` helper.
test_data = load_data(
    params.data_path + "test_data.json",
    map_relations=params.map_relations,
)

# Convert the dialogues into model inputs using the `linear_format` helpers.
input_text, labels, positions = encode_data(test_data, test=True)
(
    input_ids,
    attention_masks,
    token_type_ids,
    tokens,
    position_ids,
) = extract_labels(input_text, labels, positions, token=True)

# Detached copies of every position-id tensor, placed on the target device.
position_ids = [ids.clone().detach().to(device) for ids in position_ids]

model_path = params.model_path + params.bert_name + '.pth'
print(model_path)

# BERT sequence classifier configured to return hidden states (the later
# inference cell reads `output.hidden_states`) and with both dropout
# probabilities set to zero.
embedder = BertForSequenceClassification.from_pretrained(
    params.model_name,
    output_attentions=False,
    output_hidden_states=True,
    attention_probs_dropout_prob=0,
    hidden_dropout_prob=0,
)

# Replace the pretrained weights with the ones stored in the checkpoint at
# `model_path`, then move the model to the target device.
checkpoint = torch.load(model_path, map_location=params.device)
embedder.load_state_dict(checkpoint['model_state_dict'])
embedder.to(device)
print('Loaded finetuned Bert model')

Loading data: data/stac_data/test_data.json
109 dialogs, 1156 edus, 1126 relations, 8 backward relations
77 edus have multiple parents
models/stac/bert_finetuned.pth


Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

Loaded finetuned Bert model


In [4]:
# Per-candidate metadata: for every entry `lb` in an item's label list keep
# the pair [lb[2], lb[2] - lb[1]].
meta_data = [
    [[lb[2], lb[2] - lb[1]] for lb in item_labels]
    for item_labels in labels
]

# Restore the trained linear head from its checkpoint and switch it to
# evaluation mode.
model_path = params.model_path + params.linear_name + '.pth'
linear = NeuralNetwork().to(device)
checkpoint = torch.load(model_path, map_location=params.device)
linear.load_state_dict(checkpoint['model_state_dict'])
linear.to(device)
linear.eval()

# Consider only the 10 preceding EDUs: `indices` records how many leading
# entries are dropped from each item, so a position inside the truncated
# window can be shifted back to a position in the full sequence.
# NOTE(review): this cell overwrites its own inputs; running it a second time
# would recompute `indices` from the already-truncated lists (all zeros).
# NOTE(review): `meta_data` and `labels` are not sliced with [-10:] here -
# confirm that the consumer indexes them consistently with the window.
indices = [max(len(seq) - 10, 0) for seq in input_ids]
input_ids, attention_masks, token_type_ids, position_ids = (
    [seq[-10:] for seq in field]
    for field in (input_ids, attention_masks, token_type_ids, position_ids)
)

batches = get_batches(len(input_ids), params.batch_size_linear)


In [5]:
# Sigmoid score a candidate must exceed to be predicted as a link.
# (The original line carried a trailing "# 0.95" comment next to this value,
# presumably an earlier setting.)
link_threshold = 0.81

predictions = []
targets = []  # not populated in this cell; kept so the name stays defined
pred_dict = {}  # str(dialog id) -> list of [x, y] predicted links
sigmoid = nn.Sigmoid()

# Inference only: no_grad() avoids recording an autograd graph for every
# forward pass through BERT and the linear head.
with torch.no_grad():
    for batch in tqdm(batches):
        for i in batch:
            output = embedder(
                input_ids[i].to(device),
                token_type_ids=token_type_ids[i].to(device),
                attention_mask=attention_masks[i].to(device),
                position_ids=position_ids[i].to(device),
                return_dict=True,
            )

            # One feature vector per candidate row: the first-position vector
            # of the last hidden-state layer concatenated with that
            # candidate's metadata pair.
            # NOTE(review): `meta_data[i]` and `labels[i]` are indexed with the
            # row position inside the model input, whereas the predicted x
            # below is shifted by `indices[i]` - confirm this is intended.
            last_layer = output.hidden_states[-1]
            n_candidates = len(output.hidden_states[0])
            H_embed = torch.stack([
                torch.cat(
                    (last_layer[cand][0], torch.tensor(meta_data[i][cand]).to(device)),
                    0,
                )
                for cand in range(n_candidates)
            ])

            logits = linear(H_embed).unsqueeze(0)
            probs = sigmoid(logits).squeeze(-1).cpu().tolist()[0]

            # Keep every candidate above the threshold; if none qualifies,
            # fall back to the single highest-scoring candidate. The fallback
            # index is cast to a plain int so the stored predictions do not
            # mix numpy and Python integers.
            xs = [k for k, prob in enumerate(probs) if prob > link_threshold]
            if not xs:
                xs = [int(np.argmax([float(p[0]) for p in logits[0].cpu()]))]

            dialog_key = str(labels[i][0][0])
            for pred_x in xs:
                pred_y = labels[i][pred_x][2]
                pred_dict.setdefault(dialog_key, []).append([pred_x + indices[i], pred_y])

# Built once after the loop (the original rebuilt it after every sample):
# one list of [x, y] links per dialogue, in first-seen order.
predictions = list(pred_dict.values())




100%|██████████| 66/66 [00:16<00:00,  4.00it/s]


In [6]:
test_pred = predictions

# Compute the F1 score.
# Gold links per dialogue as (x, y, type) tuples; dialogues with a single EDU
# are skipped.
test_truth = [
    [(rel['x'], rel['y'], rel['type']) for rel in dialogue['relations']]
    for dialogue in test_data
    if len(dialogue['edus']) != 1
]

# Every predicted dialogue contributes its links plus one extra unit; the
# same "+1" is added to the correct count below.
cnt_pred = sum(len(dialog) + 1 for dialog in test_pred)
print('nb of total predictions : ', cnt_pred)

# Gold count per dialogue: its relations plus one for every EDU that is never
# the target (`y`) of a relation.
cnt_golden = 0
for dialog in test_data:
    incoming = [0] * len(dialog["edus"])
    for r in dialog["relations"]:
        incoming[r["y"]] += 1
    cnt_golden += len(dialog['relations']) + incoming.count(0)
print('nb of relations : ',cnt_golden) 

# A predicted [x, y] pair is correct when the same (x, y) occurs among the
# gold links of the dialogue at the same list position.
# NOTE(review): `test_pred` and `test_truth` are matched purely by position -
# confirm both lists enumerate the same dialogues in the same order.
cnt_correct = 0
for i, dialog_pred in enumerate(test_pred):
    truth_links = {link[:-1] for link in test_truth[i]}
    hits = sum(1 for pred in dialog_pred if tuple(pred) in truth_links)
    cnt_correct += hits + 1
print('nb of correct predictions : ', cnt_correct)

# Guard the divisions so an empty prediction set or an empty gold set yields
# 0.0 instead of raising ZeroDivisionError.
precision = cnt_correct / cnt_pred if cnt_pred else 0.0
recall = cnt_correct / cnt_golden if cnt_golden else 0.0
f1 = 2 * (precision * recall / (precision + recall)) if (precision + recall) else 0.0
print('Precition :' , precision, '   Recall : ', recall)
print('F1 score : ' , f1)

nb of total predictions :  1200
nb of relations :  1239
nb of correct predictions :  891
Precition : 0.7425    Recall :  0.7191283292978208
F1 score :  0.7306273062730627


In [6]:
# Serialise the per-dialogue predictions next to the dataset files.
with open(params.data_path + 'linear_pred_stac.pkl', 'wb') as pred_file:
    pickle.dump(predictions, pred_file)