In [1]:
from transformers import AutoTokenizer, BertModel, BertConfig
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
def idx_to_tag(idx, entities, type_labels):
    
    for entity in entities:
        if idx in entity['span']:
            return type_labels.index(entity['type'])
    
    if idx == None:
        return -100
        
    if set_none:
        return type_labels.index('None')
    else:
        return -100

def entity_to_tag(encoding, entities, type_labels):
    
    y = encoding.word_ids[1:]
    y = [idx_to_tag(index, entities, type_labels) for index in y]
            
    return y
        

In [32]:
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import pandas as pd
import json


class MyData(Dataset):
    
    def __init__(self, filename):
        self.df = None
        self.type_labels = None
        self.intent_labels = None

        with open(filename) as file:
            data = json.load(file)
            self.df = pd.DataFrame(data)
            types = pd.json_normalize(data, record_path='entities')
            _, type_labels = pd.factorize(types['type'])
            self.type_labels = type_labels.to_numpy().tolist()

        self.df['intent'], intent_labels = pd.factorize(self.df['intent'])
        self.intent_labels = intent_labels.to_numpy().tolist()
        if set_none:
            self.type_labels.append('None')
        
        
        self.y_label = self.df['intent']
        self.x = self.df['sentence']
        self.y = self.df['entities']
        
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
        
        
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self,idx):
        return self.x[idx], (self.y_label[idx], self.y[idx])
    

class MyCollate(object):
    
    def __init__(self, type_labels, tokenizer):
        self.type_labels = type_labels
        self.tokenizer = tokenizer
        
    def __call__(self, batch):
        
        x = [entry[0] for entry in batch]
        y = [entry[1][1] for entry in batch]
        y_label = [entry[1][0] for entry in batch]
        
        out_y = []
        
        out_x = self.tokenizer(x, padding=True, return_tensors='pt')
        for index in range(len(x)):
            out_y.append(entity_to_tag(out_x[index], y[index], self.type_labels))
        
        y_label = torch.tensor(y_label)
        out_y = torch.tensor(out_y)
        return out_x, (y_label, out_y)    

myData = MyData('./archive/slurp/train.json')
myCollate = MyCollate(myData.type_labels, myData.tokenizer)

train_dataloader = DataLoader(myData, batch_size=64, shuffle=True, collate_fn=myCollate)

In [34]:
class Model(nn.Module):
    
    def __init__(self, model_name, num_intent_classes, num_tag_classes):
        super(Model, self).__init__()
              
        self.bert = BertModel.from_pretrained(model_name)
        self.config = BertConfig.from_pretrained(model_name)       
        self.intent_classifier = nn.Linear(in_features=self.config.hidden_size, out_features=num_intent_classes)   
        
        torch.nn.init.xavier_normal_(self.intent_classifier.weight.data)
        torch.nn.init.uniform_(self.intent_classifier.bias.data)
        
        
    def forward(self,inputs):
        output = self.bert(**inputs)
        intent = self.intent_classifier(output.last_hidden_state[:,0])
        intent = F.sigmoid(intent)

        return intent

In [6]:
count_y_label = [0] * len(myData.intent_labels)
count_y_tag = [0] * len(myData.type_labels)

for train_x, (train_y_label, train_y_tag) in train_dataloader:
    for idx in range(len(train_y_label)):
        count_y_label[train_y_label[idx]] += 1 
        for k in range(len(train_y_tag[idx])):
            if train_y_tag[idx][k] == -100:
                continue
            count_y_tag[train_y_tag[idx][k]] += 1

In [35]:
learning_rate = 1e-3
num_epoches = 5
epoch = 0


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


myModel = Model("bert-base-uncased", len(myData.intent_labels), len(myData.type_labels))

intent_criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(params=myModel.parameters(), lr=learning_rate, weight_decay=1e-5)

intent_error = []


cuda:0


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
myModel.train()
myModel.to(device)

while epoch < num_epoches:
     
    print('Epoch : %d' % (epoch+1))
    intent_error.append(0)
    num_iter = 0
    
    for train_x, (train_y_label, train_y_tag) in train_dataloader:
        
        train_x = train_x.to(device)
        train_y_label = train_y_label.to(device)
        
        
        intent = myModel(train_x)
        
        intent_loss = intent_criterion(intent, train_y_label)       
        
        
        optimizer.zero_grad()
        intent_loss.backward()
        optimizer.step()
        
        num_iter+=1
        
        intent_error[-1] += intent_loss.item()        
    
    intent_error[-1] /= num_iter
    print('intent error: %f' % (intent_error[-1]))
          
    epoch+=1
    

Epoch : 1
[8, 29, 33, 7, 5, 16, 22, 61, 22, 3, 32, 47, 3, 0, 45, 55, 7, 42, 3, 12, 27, 44, 8, 24, 42, 3, 13, 27, 50, 32, 27, 28, 11, 48, 15, 18, 21, 4, 7, 7, 3, 7, 13, 15, 3, 22, 12, 43, 49, 20, 32, 25, 43, 7, 39, 58, 48, 13, 12, 16, 19, 4, 29, 29]
[18, 0, 4, 20, 29, 11, 43, 2, 3, 39, 21, 23, 22, 28, 40, 11, 36, 23, 42, 18, 19, 39, 7, 11, 31, 35, 18, 11, 18, 4, 22, 35, 22, 15, 41, 21, 20, 49, 50, 10, 29, 51, 50, 39, 45, 8, 52, 34, 21, 0, 11, 0, 8, 8, 7, 23, 0, 16, 22, 11, 24, 51, 46, 44]
[21, 3, 11, 18, 15, 33, 26, 11, 31, 2, 19, 6, 15, 35, 5, 6, 35, 27, 38, 19, 19, 0, 10, 11, 21, 18, 36, 78, 36, 19, 41, 12, 18, 22, 0, 18, 3, 22, 38, 23, 40, 40, 18, 19, 46, 19, 29, 34, 0, 48, 9, 0, 0, 29, 0, 34, 11, 39, 7, 11, 39, 11, 24, 39]
[40, 13, 8, 13, 10, 12, 39, 40, 32, 29, 11, 21, 8, 11, 36, 43, 19, 29, 0, 11, 27, 44, 29, 32, 57, 29, 22, 15, 12, 28, 3, 32, 58, 18, 29, 43, 8, 55, 7, 22, 15, 41, 11, 22, 36, 7, 39, 13, 32, 11, 46, 0, 39, 31, 22, 5, 26, 55, 45, 43, 27, 43, 32, 21]
[3, 39, 38, 7, 1

[15, 46, 22, 49, 32, 43, 6, 16, 19, 22, 43, 44, 43, 22, 50, 11, 12, 22, 0, 47, 12, 36, 37, 39, 3, 0, 21, 3, 16, 9, 11, 39, 3, 7, 15, 21, 22, 57, 22, 12, 11, 20, 0, 1, 29, 0, 0, 20, 0, 39, 12, 32, 19, 18, 19, 47, 46, 43, 13, 36, 6, 67, 29, 24]
[0, 0, 39, 5, 21, 43, 10, 7, 10, 46, 49, 5, 3, 21, 22, 22, 23, 21, 40, 32, 3, 7, 24, 51, 0, 11, 47, 47, 18, 61, 6, 11, 39, 13, 31, 42, 39, 39, 21, 49, 0, 22, 34, 0, 48, 22, 48, 10, 40, 0, 18, 39, 40, 3, 44, 9, 22, 32, 11, 40, 7, 41, 35, 18]
[29, 42, 43, 30, 10, 18, 3, 36, 23, 52, 12, 7, 22, 19, 19, 32, 33, 43, 13, 35, 39, 29, 21, 48, 8, 43, 18, 21, 39, 18, 18, 9, 28, 32, 15, 43, 23, 18, 20, 22, 0, 32, 33, 11, 18, 23, 2, 0, 1, 32, 21, 22, 32, 36, 17, 3, 43, 34, 10, 7, 28, 7, 0, 25]
[13, 38, 54, 7, 57, 21, 23, 29, 36, 18, 21, 23, 23, 34, 59, 22, 22, 3, 22, 43, 10, 19, 49, 0, 35, 45, 32, 36, 29, 21, 33, 15, 33, 22, 11, 6, 18, 39, 11, 18, 0, 54, 15, 18, 16, 55, 0, 34, 21, 11, 3, 3, 28, 18, 39, 54, 22, 7, 11, 39, 18, 22, 0, 24]
[14, 11, 19, 0, 19, 8, 2

[43, 22, 49, 13, 47, 21, 42, 11, 18, 14, 7, 18, 38, 11, 9, 36, 37, 41, 43, 0, 34, 11, 5, 19, 32, 46, 61, 3, 37, 7, 0, 11, 19, 45, 11, 11, 46, 43, 41, 39, 5, 0, 0, 19, 36, 3, 31, 22, 33, 0, 41, 19, 13, 12, 22, 24, 42, 2, 24, 6, 42, 11, 3, 14]
[11, 55, 39, 11, 26, 3, 3, 21, 19, 19, 23, 34, 0, 33, 39, 29, 18, 84, 7, 43, 32, 54, 3, 39, 19, 15, 12, 0, 15, 3, 3, 19, 4, 50, 0, 46, 1, 45, 41, 0, 3, 20, 6, 0, 23, 21, 3, 36, 3, 3, 25, 21, 6, 3, 30, 6, 34, 18, 43, 24, 19, 13, 25, 36]
[22, 19, 23, 43, 14, 0, 0, 43, 34, 42, 27, 29, 36, 59, 41, 24, 9, 18, 8, 38, 32, 47, 46, 34, 59, 57, 39, 21, 21, 48, 19, 7, 18, 3, 19, 9, 12, 21, 22, 30, 19, 4, 47, 18, 11, 44, 18, 43, 37, 7, 0, 15, 21, 22, 24, 21, 18, 46, 0, 77, 18, 58, 36, 28]
[18, 3, 38, 32, 15, 34, 19, 0, 8, 7, 7, 27, 1, 22, 42, 18, 39, 39, 33, 32, 19, 7, 55, 49, 41, 13, 12, 19, 46, 11, 0, 43, 49, 38, 19, 0, 12, 11, 29, 15, 24, 13, 10, 38, 39, 0, 21, 0, 58, 0, 48, 18, 1, 57, 19, 8, 54, 18, 2, 7, 41, 32, 18, 24]
[31, 20, 55, 43, 18, 18, 22, 22, 43

[7, 3, 48, 42, 30, 0, 32, 4, 27, 23, 24, 19, 18, 41, 21, 21, 3, 13, 24, 10, 3, 22, 3, 44, 18, 13, 11, 22, 36, 44, 36, 49, 22, 7, 19, 31, 22, 9, 71, 58, 21, 11, 42, 21, 18, 31, 12, 51, 28, 54, 0, 66, 55, 3, 49, 38, 18, 11, 18, 37, 25, 12, 8, 18]
[20, 11, 10, 11, 21, 47, 12, 32, 29, 43, 45, 41, 15, 11, 43, 24, 43, 19, 18, 2, 36, 27, 3, 42, 67, 19, 35, 8, 22, 32, 11, 39, 18, 39, 28, 0, 11, 3, 27, 18, 34, 24, 49, 38, 46, 7, 29, 0, 23, 34, 11, 22, 18, 27, 39, 22, 21, 18, 16, 18, 23, 11, 0, 22]
[29, 47, 19, 22, 13, 11, 45, 43, 17, 36, 0, 24, 0, 4, 9, 19, 0, 45, 44, 36, 39, 19, 11, 18, 0, 11, 18, 48, 83, 18, 38, 21, 33, 29, 36, 21, 19, 18, 21, 8, 46, 32, 17, 13, 3, 20, 11, 40, 8, 3, 34, 40, 54, 47, 20, 39, 41, 23, 49, 45, 44, 0, 29, 47]
[47, 66, 34, 11, 47, 39, 11, 3, 39, 76, 29, 23, 0, 18, 39, 45, 29, 2, 32, 8, 3, 47, 5, 34, 46, 36, 21, 28, 3, 21, 44, 11, 8, 21, 39, 0, 48, 21, 10, 42, 22, 17, 31, 21, 22, 22, 77, 18, 0, 32, 7, 44, 25, 29, 29, 7, 18, 34, 38, 27, 48, 0, 50, 44]
[11, 34, 17, 17,

[26, 43, 15, 36, 34, 39, 3, 47, 27, 27, 13, 46, 0, 0, 43, 23, 18, 49, 66, 43, 25, 21, 38, 20, 36, 52, 6, 37, 3, 5, 25, 25, 51, 10, 0, 3, 0, 49, 22, 7, 43, 0, 3, 34, 27, 47, 29, 48, 49, 28, 19, 10, 7, 18, 45, 22, 34, 8, 32, 15, 11, 8, 2, 0]
[22, 21, 21, 27, 49, 48, 11, 3, 4, 21, 9, 15, 19, 49, 18, 49, 43, 19, 11, 22, 39, 51, 21, 30, 40, 45, 18, 22, 19, 9, 5, 4, 11, 29, 18, 29, 38, 34, 54, 14, 43, 0, 16, 3, 11, 11, 27, 11, 20, 24, 0, 7, 22, 44, 15, 22, 11, 36, 5, 19, 31, 29, 9, 51]
[0, 20, 41, 0, 4, 49, 45, 5, 36, 48, 34, 24, 47, 11, 21, 22, 49, 48, 7, 0, 29, 4, 41, 11, 19, 32, 13, 7, 31, 19, 39, 22, 40, 24, 29, 0, 32, 20, 42, 19, 1, 9, 36, 22, 22, 19, 25, 48, 24, 29, 22, 8, 43, 0, 50, 8, 13, 24, 51, 7, 11, 55, 16, 47]
[16, 19, 46, 10, 43, 0, 9, 8, 8, 43, 34, 49, 47, 9, 40, 21, 11, 21, 22, 55, 18, 11, 48, 46, 64, 33, 21, 23, 0, 18, 24, 3, 5, 22, 36, 1, 21, 6, 21, 0, 0, 36, 31, 11, 19, 24, 47, 43, 28, 29, 21, 48, 7, 30, 11, 16, 58, 0, 11, 42, 8, 32, 24, 33]
[5, 19, 3, 18, 32, 0, 28, 49, 2

[22, 21, 55, 59, 32, 28, 43, 39, 28, 44, 29, 7, 27, 19, 12, 43, 18, 46, 8, 18, 49, 22, 9, 3, 6, 49, 19, 7, 48, 43, 29, 28, 19, 49, 49, 52, 42, 22, 57, 3, 11, 55, 39, 3, 21, 0, 19, 40, 34, 51, 5, 19, 34, 11, 34, 23, 0, 3, 30, 0, 22, 36, 5, 44]
[11, 32, 21, 46, 37, 38, 7, 3, 19, 32, 11, 12, 18, 11, 44, 33, 7, 11, 29, 15, 29, 15, 22, 32, 51, 3, 18, 47, 21, 42, 3, 44, 47, 27, 32, 0, 7, 0, 3, 48, 58, 12, 47, 0, 29, 0, 17, 34, 0, 36, 39, 18, 0, 52, 28, 12, 25, 0, 3, 21, 3, 32, 3, 4]
[27, 3, 32, 3, 9, 21, 6, 58, 54, 28, 21, 3, 5, 32, 18, 11, 41, 22, 7, 32, 32, 10, 12, 22, 3, 44, 21, 11, 55, 7, 39, 19, 43, 11, 33, 28, 21, 0, 26, 18, 0, 19, 4, 46, 45, 22, 19, 9, 36, 5, 43, 39, 21, 32, 43, 4, 19, 3, 28, 0, 0, 7, 7, 9]
[8, 19, 40, 22, 42, 8, 18, 52, 0, 21, 55, 34, 23, 40, 21, 18, 27, 48, 3, 0, 11, 34, 8, 8, 7, 11, 21, 0, 6, 73, 13, 1, 18, 8, 18, 29, 18, 7, 23, 37, 36, 32, 15, 42, 32, 45, 3, 18, 18, 12, 17, 23, 41, 29, 13, 21, 0, 21, 29, 22, 0, 34, 29, 44]
[34, 24, 19, 43, 0, 3, 13, 0, 5, 11, 16, 

[8, 15, 36, 43, 0, 38, 38, 36, 25, 43, 21, 39, 11, 7, 50, 29, 0, 45, 0, 43, 8, 13, 11, 18, 43, 38, 31, 27, 4, 3, 18, 6, 35, 0, 50, 40, 13, 85, 6, 32, 36, 24, 36, 30, 13, 22, 37, 11, 26, 19, 21, 18, 59, 21, 36, 43, 19, 32, 11, 34, 8, 0, 22, 11]
[12, 0, 43, 2, 39, 12, 22, 49, 29, 5, 39, 19, 6, 21, 11, 55, 18, 47, 11, 46, 46, 19, 11, 20, 45, 23, 34, 29, 34, 19, 6, 5, 59, 55, 37, 3, 3, 23, 33, 5, 36, 25, 43, 11, 0, 43, 18, 18, 1, 22, 19, 3, 18, 22, 15, 8, 46, 61, 22, 57, 65, 15, 46, 54]
[0, 55, 11, 67, 48, 17, 0, 50, 29, 22, 48, 30, 7, 21, 29, 41, 20, 19, 40, 40, 39, 18, 40, 11, 21, 26, 16, 34, 54, 39, 3, 22, 3, 32, 11, 19, 32, 4, 24, 4, 50, 46, 39, 32, 19, 45, 33, 22, 15, 0, 21, 0, 27, 44, 54, 10, 43, 21, 6, 49, 38, 49, 6, 27]
[11, 22, 32, 36, 22, 13, 0, 11, 36, 56, 44, 21, 29, 0, 48, 50, 36, 34, 21, 43, 43, 21, 49, 22, 24, 21, 11, 18, 18, 21, 59, 78, 18, 10, 27, 19, 13, 22, 21, 33, 44, 31, 48, 31, 0, 32, 0, 43, 7, 11, 46, 50, 10, 27, 36, 7, 19, 6, 52, 7, 6, 0, 34, 8]
[58, 22, 20, 18, 32,

[44, 39, 5, 11, 11, 38, 29, 26, 38, 36, 38, 19, 49, 32, 18, 39, 7, 29, 21, 50, 0, 32, 11, 32, 1, 18, 12, 15, 42, 34, 45, 14, 34, 47, 39, 43, 7, 3, 49, 29, 36, 43, 57, 7, 19, 21, 48, 13, 19, 47, 41, 18, 0, 55, 61, 0, 13, 13, 20, 7, 30, 3, 3, 3]
[11, 33, 49, 11, 40, 11, 19, 0, 21, 7, 56, 48, 18, 7, 21, 37, 42, 23, 22, 21, 48, 3, 36, 39, 34, 15, 29, 18, 18, 54, 51, 7, 3, 32, 7, 37, 21, 12, 18, 55, 0, 69, 11, 0, 11, 30, 39, 29, 21, 49, 19, 49, 7, 11, 34, 7, 32, 40, 19, 13, 32, 29, 0, 18]
[18, 28, 3, 55, 24, 25, 29, 3, 21, 11, 7, 13, 13, 18, 14, 19, 17, 9, 5, 21, 12, 16, 18, 11, 44, 29, 43, 3, 5, 11, 11, 11, 21, 1, 11, 39, 7, 27, 21, 47, 44, 8, 49, 43, 20, 47, 29, 34, 16, 51, 39, 7, 23, 0, 4, 2, 41, 19, 23, 36, 3, 78, 0, 29]
[22, 13, 21, 21, 43, 7, 39, 44, 49, 43, 5, 43, 36, 4, 49, 31, 20, 55, 6, 34, 42, 45, 11, 32, 15, 0, 0, 11, 41, 3, 11, 22, 7, 32, 41, 29, 36, 29, 11, 19, 10, 0, 39, 12, 48, 22, 43, 18, 0, 39, 2, 22, 36, 43, 31, 28, 2, 28, 22, 23, 41, 8, 43, 43]
[22, 4, 42, 57, 1, 43, 18,

In [23]:
myModel.eval()

test_input = "where am i?"
test_input_2 = "who are you"
test_input = myData.tokenizer(test_input, padding=True, return_tensors='pt').to(device)
test_input_2 = myData.tokenizer(test_input_2, padding=True, return_tensors='pt').to(device)

intent = myModel(test_input)
intent2 = myModel(test_input_2)
print(myData.intent_labels[intent.argmax()], myData.intent_labels[intent2.argmax()])

calendar_query calendar_query


In [38]:
def fill_slot(tokenizer, token_tag, tokens):  
    start_idx = None
    prev_tag = 'None'
    ret = []  
    
    for idx in range(len(token_tag)):
        
        if token_tag[idx] == prev_tag:
            continue
        
        elif prev_tag == 'None':
            start_idx = idx
        
        elif prev_tag != 'None':
            
            if token_tag[idx] == 'None':
                ret.append({prev_tag : tokenizer.convert_tokens_to_string(tokens[start_idx:idx])})
            
            else:
                ret.append({prev_tag : tokenizer.convert_tokens_to_string(tokens[start_idx:idx])})
                start_idx = idx
        
        prev_tag = token_tag[idx]
        
        
    return ret   

In [40]:
myModel.eval()
myModel.to(device)

test_input = "who are you asdffe"

test_input = myData.tokenizer(test_input, padding=True, return_tensors='pt').to(device)


intent = myModel(test_input)
print(intent)
print(myData.intent_labels[intent[i].argmax()])


"""
out_label=[]
out_tag=[]
for i in range(intent.shape[0]):
    out_label.append(myData.intent_labels[intent[i].argmax()])
    out_tag_token = [myData.type_labels[k.argmax()] for k in tag[i]]
    print(out_tag_token)
    out_tag.append(fill_slot(myData.tokenizer, out_tag_token, test_input[i].tokens[1:]))
    
    #out_tag.append([myData.type_labels[k.argmax()] for k in tag[i]])
    
print(out_label, out_tag)
"""

tensor([[ 1.6555e+00,  1.1889e-03, -6.2424e-01,  1.3413e+00, -1.8278e-01,
         -2.9840e-01,  3.1567e-01,  1.1413e+00,  3.0547e-01, -1.1709e-01,
          2.3932e-01,  1.6241e+00,  5.2523e-01,  7.2117e-01, -6.0465e-01,
          2.1642e-01, -2.6634e-01, -8.9576e-01,  1.2211e+00,  1.5316e+00,
         -1.0550e-01,  1.4576e+00,  1.4376e+00, -2.3883e-01,  2.4447e-01,
         -5.1494e-01, -1.8882e+00,  6.1722e-01, -7.7090e-02,  6.4463e-01,
         -6.3426e-01, -2.7871e-01,  8.3683e-01, -1.4347e-01,  5.0025e-01,
         -9.2164e-01,  9.4688e-01, -6.7938e-01, -2.0005e-01,  4.9957e-01,
         -5.6478e-01, -1.3447e-01,  4.3729e-02,  1.0152e+00,  3.2196e-01,
         -1.9983e-01,  2.0087e-01,  1.3855e-01,  8.7111e-03,  2.1781e-01,
         -7.3992e-01, -6.3566e-01, -1.6820e+00, -2.7653e+00, -8.6095e-01,
         -2.1530e-01, -2.1369e+00, -1.0667e+00, -7.2223e-01, -1.2805e+00,
         -3.2830e+00, -2.3454e+00, -3.4593e+00, -3.8096e+00, -2.9950e+00,
         -4.1624e+00, -3.0887e+00, -2.

'\nout_label=[]\nout_tag=[]\nfor i in range(intent.shape[0]):\n    out_label.append(myData.intent_labels[intent[i].argmax()])\n    out_tag_token = [myData.type_labels[k.argmax()] for k in tag[i]]\n    print(out_tag_token)\n    out_tag.append(fill_slot(myData.tokenizer, out_tag_token, test_input[i].tokens[1:]))\n    \n    #out_tag.append([myData.type_labels[k.argmax()] for k in tag[i]])\n    \nprint(out_label, out_tag)\n'