In [11]:
import torch
from transformers import BertTokenizer, BertConfig
from torch.utils.data import DataLoader
from torch import nn, optim
from functions import *
from model import JointBert
from utils import *


config = BertConfig.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

train_dev_data = load_data('dataset/ATIS/train.json')
test_data = load_data('dataset/ATIS/test.json')
train_raw, dev_raw, test_raw, y_train, y_dev, y_test = get_splits(train_dev_data, test_data)

corpus = train_raw + dev_raw + test_raw # We do not wat unk labels, 
                                        # however this depends on the research purpose
words = sum([x['utterance'].split() for x in train_raw], []) # No set() since we want to compute
slots = set(sum([line['slots'].split() for line in corpus],[]))
intents = set([line['intent'] for line in corpus])

lang = Lang(words, intents, slots, cutoff=0)
out_slot = len(lang.slot2id)
out_int = len(lang.intent2id)


train_dataset = IntentsAndSlots(train_raw, lang, tokenizer)
dev_dataset = IntentsAndSlots(dev_raw, lang, tokenizer)
test_dataset = IntentsAndSlots(test_raw, lang, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=32, collate_fn=collate_fn,  shuffle=True)
dev_loader = DataLoader(dev_dataset, batch_size=64, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=64, collate_fn=collate_fn)

lr = 2e-5 # learning rate
clip = 1.0 # Clip the gradient
dropout = 0.1 # Dropout rate
model = JointBert(config, out_slot, out_int, dropout=dropout).to(device)
model.bert.resize_token_embeddings(len(tokenizer))

optimizer = optim.AdamW(model.parameters(), lr=lr)
criterion_slots = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
criterion_intents = nn.CrossEntropyLoss() # Because we do not have the pad token



In [12]:
tgt = [17, 21, 25]
path = 'bin/model{}.pth'




In [22]:
from conll import evaluate
import numpy as np
def eval_loop(data, criterion_slots, criterion_intents, model, lang):
    model.eval()
    loss_array = []
    
    ref_intents = []
    hyp_intents = []
    
    ref_slots = []
    hyp_slots = []
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    special_tokens = [tokenizer.pad_token_id, tokenizer.cls_token_id, tokenizer.sep_token_id]
    #softmax = nn.Softmax(dim=1) # Use Softmax if you need the actual probability
    with torch.no_grad(): # It used to avoid the creation of computational graph
        for sample in data:
            input_ids, attention_mask, token_type_ids = sample['input_ids'].to(device), sample['attention_mask'].to(device), sample['token_type_ids'].to(device)
            slot_labels = sample['slot_labels'].to(device)
            intent_labels = sample['intent_labels'].to(device)
            slots, intents = model(input_ids, attention_mask, token_type_ids)
            loss_intent = criterion_intents(intents, intent_labels)
            loss_slot = criterion_slots(slots, slot_labels)
            loss = loss_intent + loss_slot 
            loss_array.append(loss.item())

            # Intent inference
            # Get the highest probable class
            out_intents = [lang.id2intent[x] 
                           for x in torch.argmax(intents, dim=1).tolist()] 
            gt_intents = [lang.id2intent[x] for x in sample['intent_labels'].tolist()]
            ref_intents.extend(gt_intents)
            hyp_intents.extend(out_intents)
            
            # Slot inference 
            output_slots = torch.argmax(slots, dim=1)
            if -1 in output_slots:
                print("Error in slot labels")
                print(output_slots.shape)
                print(output_slots)

            # check this part
            for id_seq, seq in enumerate(output_slots):
                mask = ~np.isin(slot_labels[id_seq].cpu().numpy(), special_tokens)
                indices = list(np.where(mask)[0])
                utt_ids = [input_ids[id_seq][i].item() for i in indices]
                gt_ids = [slot_labels[id_seq][i].item() for i in indices]
                gt_slots = [lang.id2slot[x] for x in gt_ids]
                utterance = tokenizer.convert_ids_to_tokens(utt_ids)
                to_decode = [seq[i].item() for i in indices]

                ref_slots.append([(utterance[id_el], elem) for id_el, elem in enumerate(gt_slots)])
                tmp_seq = []
                for id_el, elem in enumerate(to_decode):
                    tmp_seq.append((utterance[id_el], lang.id2slot[elem]))
                hyp_slots.append(tmp_seq)
                if len(ref_slots[id_seq]) != len(hyp_slots[id_seq]):
                    print("Error in slot labels")

    try:            
        results = evaluate(ref_slots, hyp_slots)
    except Exception as ex:
        # Sometimes the model predicts a class that is not in REF
        print("Warning:", ex)
        ref_s = set([x[1] for x in ref_slots])
        hyp_s = set([x[1] for x in hyp_slots])
        print(hyp_s.difference(ref_s))
        results = {"total":{"f":0}}
        
    report_intent = classification_report(ref_intents, hyp_intents, 
                                          zero_division=False, output_dict=True)
    return results, report_intent, loss_array

In [18]:
from model import JointBert
state_dict = torch.load(path.format(1))['model']
model = JointBert(config, out_slot, out_int, dropout=dropout).to(device)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [36]:
from model import JointBert
F1_scores = []
intent_accuracies = []

for i in range(0, 30):
    model_state_dict, optim = torch.load(path.format(i))
    
    model.eval()
    
    results_test, intent_results, _ = eval_loop(dev_loader, criterion_slots, criterion_intents, model, lang)

    F1_scores.append(results_test['total']['f'])
    intent_accuracies.append(intent_results['accuracy'])

In [38]:
intent_results


{'abbreviation': {'precision': 0.0,
  'recall': 0.0,
  'f1-score': 0.0,
  'support': 15.0},
 'aircraft': {'precision': 0.0,
  'recall': 0.0,
  'f1-score': 0.0,
  'support': 8.0},
 'aircraft+flight+flight_no': {'precision': 0.0,
  'recall': 0.0,
  'f1-score': 0.0,
  'support': 0.0},
 'airfare': {'precision': 0.0,
  'recall': 0.0,
  'f1-score': 0.0,
  'support': 42.0},
 'airfare+flight': {'precision': 0.0,
  'recall': 0.0,
  'f1-score': 0.0,
  'support': 0.0},
 'airfare+flight_time': {'precision': 0.0,
  'recall': 0.0,
  'f1-score': 0.0,
  'support': 0.0},
 'airline': {'precision': 0.0,
  'recall': 0.0,
  'f1-score': 0.0,
  'support': 16.0},
 'airline+flight_no': {'precision': 0.0,
  'recall': 0.0,
  'f1-score': 0.0,
  'support': 0.0},
 'airport': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 2.0},
 'capacity': {'precision': 0.0,
  'recall': 0.0,
  'f1-score': 0.0,
  'support': 2.0},
 'city': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 2.0},
 'distance':

In [37]:
F1_scores , intent_accuracies

([0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902,
  0.011847361895577902],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0])