In [1]:
from transformers import AutoTokenizer, BertGenerationDecoder, BertGenerationConfig, EncoderDecoderModel
from transformers import AutoTokenizer, BertModel, BertConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from torch.utils.data import DataLoader
from torch.utils.data import Dataset


In [10]:
model_name = "bert-base-uncased"
learning_rate = 1e-6
batch_size = 64
epoches = 5
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [11]:
class MyData(Dataset):
    
    def __init__(self,dataframe):
        self.df = dataframe
        self.x = df[df.columns[0]]
        self.y = torch.tensor(df[df.columns[1]].values)
        
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self,idx):
        return self.x[idx],self.y[idx]


In [12]:
#df = pd.read_json('./archive/is_train.json')
df = pd.read_json('./archive/slurp/train.json')
df[df.columns[1]], labels = pd.factorize(df[df.columns[1]])

num_classes = len(labels)
train_dataloader = DataLoader(MyData(df), batch_size=batch_size, shuffle=True)

In [13]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

class BertClassification(nn.Module):
    def __init__(self, model_name, num_classes):
        super(BertClassification, self).__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.config = BertConfig.from_pretrained(model_name)
        self.classifier = nn.Linear(in_features=self.config.hidden_size, out_features=num_classes)
        
        torch.nn.init.xavier_normal_(self.classifier.weight.data)
        torch.nn.init.uniform_(self.classifier.bias.data)
        
    def forward(self,inputs):
        output = self.bert(**inputs)
        output = self.classifier(output.last_hidden_state[:,0])
        output = F.sigmoid(output)
        
        return output
        
        

In [None]:
classifier = torch.load('./classifier.pt')

In [14]:
classifier = BertClassification(model_name, num_classes)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [15]:
classifier.to(device)
classifier.train()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=classifier.parameters(), lr=learning_rate, weight_decay=1e-5)

train_loss_avg = []

print(device)

for epoch in range(epoches):
    train_loss_avg.append(0)
    num_batches = 0
    
    for train_x, train_y in train_dataloader:
        inputs = tokenizer(train_x, padding=True, return_tensors='pt').to(device)
        train_y = train_y.to(device)
        
        output = classifier(inputs)
    
        loss = criterion(output, train_y)
    
        optimizer.zero_grad()
        loss.backward()
    
        optimizer.step()
        
        train_loss_avg[-1] += loss.item()
        num_batches += 1
    
        
    train_loss_avg[-1] /= num_batches
    print('Epoch [%d / %d] average cross entropy error: %f' % (epoch+1, epoches, train_loss_avg[-1]))
    torch.save(classifier, './classifier.pt')
    
classifier.to('cpu')

cuda:0
Epoch [1 / 5] average cross entropy error: 4.185834
Epoch [2 / 5] average cross entropy error: 4.172719
Epoch [3 / 5] average cross entropy error: 4.166778
Epoch [4 / 5] average cross entropy error: 4.163884
Epoch [5 / 5] average cross entropy error: 4.164282


BertClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwis

In [38]:
torch.cuda.empty_cache()

In [43]:
torch.save(classifier, './classifier.pt')

In [19]:
classifier.eval()
classifier.to('cpu')
text = "what time is it"
output = torch.argmax(classifier(tokenizer(text, return_tensors='pt'))).item()
print(labels)
print(output, labels[output])

Index(['calendar_set', 'audio_volume_up', 'iot_hue_lightup', 'weather_query',
       'iot_coffee', 'audio_volume_mute', 'lists_remove', 'email_query',
       'alarm_set', 'alarm_query', 'qa_stock', 'play_music',
       'recommendation_events', 'qa_definition', 'alarm_remove',
       'play_podcasts', 'social_query', 'email_addcontact', 'news_query',
       'calendar_query', 'music_likeness', 'general_quirky', 'qa_factoid',
       'takeaway_order', 'play_audiobook', 'iot_cleaning', 'general_greet',
       'transport_query', 'transport_taxi', 'email_sendemail', 'general_joke',
       'qa_maths', 'social_post', 'transport_ticket', 'cooking_recipe',
       'music_settings', 'calendar_remove', 'iot_wemo_on',
       'iot_hue_lightchange', 'play_radio', 'email_querycontact',
       'transport_traffic', 'qa_currency', 'datetime_query',
       'iot_hue_lightoff', 'takeaway_query', 'lists_createoradd',
       'music_query', 'recommendation_locations', 'lists_query',
       'recommendation_movies'