In [1]:
from transformers import AutoTokenizer, BertModel, BertConfig
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
def idx_to_tag(idx, entities, type_labels):
    
    for entity in entities:
        if idx in entity['span']:
            return type_labels.index(entity['type'])
    
    if idx == None:
        return -100
        
    if set_none:
        return type_labels.index('None')
    else:
        return -100

def entity_to_tag(encoding, entities, type_labels):
    
    y = encoding.word_ids[1:]
    y = [idx_to_tag(index, entities, type_labels) for index in y]
            
    return y
        

In [60]:
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import pandas as pd
import json


class MyData(Dataset):
    
    def __init__(self, filename):
        self.df = None
        self.type_labels = None
        self.intent_labels = None

        with open(filename) as file:
            data = json.load(file)
            self.df = pd.DataFrame(data)
            types = pd.json_normalize(data, record_path='entities')
            _, type_labels = pd.factorize(types['type'])
            self.type_labels = type_labels.to_numpy().tolist()

        self.df['intent'], intent_labels = pd.factorize(self.df['intent'])
        self.intent_labels = intent_labels.to_numpy().tolist()
        if set_none:
            self.type_labels.append('None')
        
        
        self.y_label = self.df['intent']
        self.x = self.df['sentence']
        self.y = self.df['entities']
        
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
        
        
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self,idx):
        return self.x[idx], (self.y_label[idx], self.y[idx])
    

class MyCollate(object):
    
    def __init__(self, type_labels, tokenizer):
        self.type_labels = type_labels
        self.tokenizer = tokenizer
        
    def __call__(self, batch):
        
        x = [entry[0] for entry in batch]
        y = [entry[1][1] for entry in batch]
        y_label = [entry[1][0] for entry in batch]
        
        out_y = []
        
        out_x = self.tokenizer(x, padding=True, return_tensors='pt')
        for index in range(len(x)):
            out_y.append(entity_to_tag(out_x[index], y[index], self.type_labels))
        
        y_label = torch.tensor(y_label)
        out_y = torch.tensor(out_y)
        return out_x, (y_label, out_y)    

myData = MyData('./archive/slurp/train.json')
myCollate = MyCollate(myData.type_labels, myData.tokenizer)

train_dataloader = DataLoader(myData, batch_size=64, shuffle=True, collate_fn=myCollate)

In [61]:
class Model(nn.Module):
    
    def __init__(self, model_name, num_intent_classes, num_tag_classes):
        super(Model, self).__init__()
              
        self.bert = BertModel.from_pretrained(model_name)
        self.config = BertConfig.from_pretrained(model_name)       
        self.intent_classifier = nn.Linear(in_features=self.config.hidden_size, out_features=num_intent_classes)   
        
        torch.nn.init.xavier_normal_(self.intent_classifier.weight.data)
        torch.nn.init.uniform_(self.intent_classifier.bias.data)
        
        
    def forward(self,inputs):
        output = self.bert(**inputs)
        intent = self.intent_classifier(output.last_hidden_state[:,0])
        intent = F.sigmoid(intent)

        return intent

In [6]:
count_y_label = [0] * len(myData.intent_labels)
count_y_tag = [0] * len(myData.type_labels)

for train_x, (train_y_label, train_y_tag) in train_dataloader:
    for idx in range(len(train_y_label)):
        count_y_label[train_y_label[idx]] += 1 
        for k in range(len(train_y_tag[idx])):
            if train_y_tag[idx][k] == -100:
                continue
            count_y_tag[train_y_tag[idx][k]] += 1

In [62]:
learning_rate = 1e-6
num_epoches = 5
epoch = 0


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


myModel = Model("bert-base-cased", len(myData.intent_labels), len(myData.type_labels))

intent_criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(params=myModel.parameters(), lr=learning_rate, weight_decay=1e-5)

intent_error = []


cuda:0


Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [63]:
myModel.train()
myModel.to(device)

while epoch < num_epoches:
     
    print('Epoch : %d' % (epoch+1))
    intent_error.append(0)
    num_iter = 0
    
    for train_x, (train_y_label, train_y_tag) in train_dataloader:
        
        train_x = train_x.to(device)
        train_y_label = train_y_label.to(device)
        
        
        intent = myModel(train_x)       
        
        intent_loss = intent_criterion(intent, train_y_label)       
        
        
        optimizer.zero_grad()
        intent_loss.backward()
        optimizer.step()
        
        num_iter+=1
        
        intent_error[-1] += intent_loss.item()        
    
    intent_error[-1] /= num_iter
    print('intent error: %f' % (intent_error[-1]))
          
    epoch+=1
    

Epoch : 1
intent error: 4.476026
Epoch : 2
intent error: 4.291128
Epoch : 3
intent error: 4.182344
Epoch : 4
intent error: 4.107389
Epoch : 5
intent error: 4.047712


In [64]:
myModel.eval()

test_input = "what time is it?"
test_input_2 = "who are you"
test_input = myData.tokenizer(test_input, padding=True, return_tensors='pt').to(device)
test_input_2 = myData.tokenizer(test_input_2, padding=True, return_tensors='pt').to(device)

intent = myModel(test_input)
intent2 = myModel(test_input_2)
print(myData.intent_labels[intent.argmax()], myData.intent_labels[intent2.argmax()])

datetime_query calendar_query


In [38]:
def fill_slot(tokenizer, token_tag, tokens):  
    start_idx = None
    prev_tag = 'None'
    ret = []  
    
    for idx in range(len(token_tag)):
        
        if token_tag[idx] == prev_tag:
            continue
        
        elif prev_tag == 'None':
            start_idx = idx
        
        elif prev_tag != 'None':
            
            if token_tag[idx] == 'None':
                ret.append({prev_tag : tokenizer.convert_tokens_to_string(tokens[start_idx:idx])})
            
            else:
                ret.append({prev_tag : tokenizer.convert_tokens_to_string(tokens[start_idx:idx])})
                start_idx = idx
        
        prev_tag = token_tag[idx]
        
        
    return ret   

In [40]:
myModel.eval()
myModel.to(device)

test_input = "who are you asdffe"

test_input = myData.tokenizer(test_input, padding=True, return_tensors='pt').to(device)


intent = myModel(test_input)
print(intent)
print(myData.intent_labels[intent[i].argmax()])


"""
out_label=[]
out_tag=[]
for i in range(intent.shape[0]):
    out_label.append(myData.intent_labels[intent[i].argmax()])
    out_tag_token = [myData.type_labels[k.argmax()] for k in tag[i]]
    print(out_tag_token)
    out_tag.append(fill_slot(myData.tokenizer, out_tag_token, test_input[i].tokens[1:]))
    
    #out_tag.append([myData.type_labels[k.argmax()] for k in tag[i]])
    
print(out_label, out_tag)
"""

tensor([[ 1.6555e+00,  1.1889e-03, -6.2424e-01,  1.3413e+00, -1.8278e-01,
         -2.9840e-01,  3.1567e-01,  1.1413e+00,  3.0547e-01, -1.1709e-01,
          2.3932e-01,  1.6241e+00,  5.2523e-01,  7.2117e-01, -6.0465e-01,
          2.1642e-01, -2.6634e-01, -8.9576e-01,  1.2211e+00,  1.5316e+00,
         -1.0550e-01,  1.4576e+00,  1.4376e+00, -2.3883e-01,  2.4447e-01,
         -5.1494e-01, -1.8882e+00,  6.1722e-01, -7.7090e-02,  6.4463e-01,
         -6.3426e-01, -2.7871e-01,  8.3683e-01, -1.4347e-01,  5.0025e-01,
         -9.2164e-01,  9.4688e-01, -6.7938e-01, -2.0005e-01,  4.9957e-01,
         -5.6478e-01, -1.3447e-01,  4.3729e-02,  1.0152e+00,  3.2196e-01,
         -1.9983e-01,  2.0087e-01,  1.3855e-01,  8.7111e-03,  2.1781e-01,
         -7.3992e-01, -6.3566e-01, -1.6820e+00, -2.7653e+00, -8.6095e-01,
         -2.1530e-01, -2.1369e+00, -1.0667e+00, -7.2223e-01, -1.2805e+00,
         -3.2830e+00, -2.3454e+00, -3.4593e+00, -3.8096e+00, -2.9950e+00,
         -4.1624e+00, -3.0887e+00, -2.

'\nout_label=[]\nout_tag=[]\nfor i in range(intent.shape[0]):\n    out_label.append(myData.intent_labels[intent[i].argmax()])\n    out_tag_token = [myData.type_labels[k.argmax()] for k in tag[i]]\n    print(out_tag_token)\n    out_tag.append(fill_slot(myData.tokenizer, out_tag_token, test_input[i].tokens[1:]))\n    \n    #out_tag.append([myData.type_labels[k.argmax()] for k in tag[i]])\n    \nprint(out_label, out_tag)\n'