In [11]:
import tensorflow as tf
import torch
from transformers import BertTokenizer
import pickle
import os
import keras_bert
import zipfile
from sklearn.datasets import load_files
import chardet
from keras.preprocessing.sequence import pad_sequences

In [12]:
# def standardize_to_utf8(encoding):
#     """
#     standardize to utf-8 if necessary.
#     NOTE: mainly used to use utf-8 if ASCII is detected, as
#     BERT performance suffers otherwise.
#     """
#     encoding = 'utf-8' if encoding.lower() in ['ascii', 'utf8', 'utf-8'] else encoding
#     return encoding

In [13]:
# Where the corpus lives and which class sub-folders to read.
datadir = "dataset_up"
classes = ["class0", "class1"]

# Split folder names; the first is the training split, the second the test split.
train_test_names = ["train", "test"]
train_str, test_str = train_test_names

# Load both splits with identical settings (train first, then test).
# shuffle=False keeps the documents in the order load_files produces them.
train_b, test_b = (
    load_files(os.path.join(datadir, split_name), shuffle=False, categories=classes)
    for split_name in (train_str, test_str)
)

# Raw documents and their integer class targets, per split.
x_train, y_train = train_b.data, train_b.target
x_test, y_test = test_b.data, test_b.target

In [14]:
# lst = [chardet.detect(doc)['encoding'] for doc in x_train[:32]]
# encoding = max(set(lst), key=lst.count)
# encoding = standardize_to_utf8(encoding)
# x_train = [x.decode(encoding) for x in x_train]
# x_test = [x.decode(encoding) for x in x_test]

In [15]:
# len(x_train)

In [16]:
# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

In [17]:
# def get_token_ids(x_train, x_test):
    
#     token_tr = []
#     token_tst = []
#     count = 0
#     for sent in x_train :
#         tokens = tokenizer.encode(sent, add_special_tokens = True, max_length=512)
#         token_tr.append(tokens)
#         count+=1
#         if(count%1000==0):
#             print(count)
    
#     for sent1 in x_test :
#         tokens1 = tokenizer.encode(sent1, add_special_tokens = True, max_length=512)
#         token_tst.append(tokens1)
#         count+=1
#         if(count%1000==0):
#             print(count)
            
#     return token_tr, token_tst 

In [18]:
# xtr_token, xtst_token = get_token_ids(x_train, x_test)

In [19]:
# xtr_token = pad_sequences(xtr_token, maxlen=512, dtype="long", 
#                           value=0, truncating="post", padding="post")
# xtst_token = pad_sequences(xtst_token, maxlen=512, dtype="long", 
#                           value=0, truncating="post", padding="post")

In [20]:
# attention_mask_tr = []
# attention_mask_tst = []
# for sent in xtr_token:
#     att_mask = [int(token_id > 0) for token_id in sent]
#     attention_mask_tr.append(att_mask)

In [21]:
# for sent in xtst_token:
#     att_mask = [int(token_id > 0) for token_id in sent]
#     attention_mask_tst.append(att_mask)

In [22]:
# pickle.dump(xtr_token, open("pickles/train_input_tokens.pkl", "wb"))
# pickle.dump(xtst_token, open("pickles/test_input_tokens.pkl", "wb"))
# pickle.dump(attention_mask_tr, open("pickles/attention_mask_train.pkl" , "wb"))
# pickle.dump(attention_mask_tst, open("pickles/attention_mask_test.pkl", "wb"))

In [23]:
def _load_pickle(path):
    """Return the object pickled at ``path``, closing the file afterwards."""
    # NOTE(review): pickle.load can execute arbitrary code; only load files
    # that were produced by this project.
    with open(path, "rb") as fh:
        return pickle.load(fh)


# Cached token ids and attention masks for the train and test splits.
xtr_token = _load_pickle("pickles/train_input_tokens.pkl")
xtst_token = _load_pickle("pickles/test_input_tokens.pkl")
attention_mask_tr = _load_pickle("pickles/attention_mask_train.pkl")
attention_mask_tst = _load_pickle("pickles/attention_mask_test.pkl")

In [24]:
# Convert token ids, labels and attention masks to tensors, train/test side by side.
train_input, test_input = torch.tensor(xtr_token), torch.tensor(xtst_token)
train_label, test_label = torch.tensor(y_train), torch.tensor(y_test)
train_mask, test_mask = torch.tensor(attention_mask_tr), torch.tensor(attention_mask_tst)

In [25]:
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
batch_size = 32

# Training split: (token ids, attention mask, label) triples, visited in random order.
train_data = TensorDataset(train_input, train_mask, train_label)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(
    train_data,
    batch_size=batch_size,
    sampler=train_sampler,
)

# Test split: same triple layout, visited sequentially.
test_data = TensorDataset(test_input, test_mask, test_label)
test_sampler = SequentialSampler(test_data)
test_dataloader = DataLoader(
    test_data,
    batch_size=batch_size,
    sampler=test_sampler,
)

In [26]:
import torch.nn as nn
from transformers import BertModel

class ContextEmbeddings(nn.Module):
    """BERT encoder followed by two linear layers that emit one logit per sequence.

    The first-position ([CLS]) vector of BERT's first output is projected
    768 -> 256 by ``embedding_layer`` and 256 -> 1 by ``cls_layer``. No
    activation is applied between the two linear layers and no sigmoid is
    applied to the result, so the output is a raw logit.
    """

    def __init__(self, freeze_bert = True):
        '''
        Inputs:
            -freeze_bert : if True, requires_grad is switched off for every BERT parameter, leaving only the two linear layers trainable
        '''
        super().__init__()
        self.bert_layer = BertModel.from_pretrained('bert-base-uncased')

        if freeze_bert:
            for p in self.bert_layer.parameters():
                p.requires_grad = False

        self.embedding_layer = nn.Linear(768, 256)
        self.cls_layer = nn.Linear(256, 1)

    def forward(self, seq, attn_masks):
        '''
        Inputs:
            -seq : Tensor of shape [B, T] containing token ids of sequences
            -attn_masks : Tensor of shape [B, T] containing attention masks to be used to avoid contribution of PAD tokens
        Returns:
            Tensor of shape [B, 1] containing one raw (pre-sigmoid) logit per sequence
        '''

        #Feeding the input to BERT model to obtain contextualized representations
        bert_outputs = self.bert_layer(seq, attention_mask = attn_masks)

        #Index instead of tuple-unpacking: unpacking only works when exactly two
        #items come back, and on dict-like model outputs it yields the key names
        #rather than tensors. Element 0 is the per-token hidden states either way.
        cont_reps = bert_outputs[0]

        #Obtaining the representation of [CLS] head
        embedding_rep = cont_reps[:, 0]
        embeds = self.embedding_layer(embedding_rep)

        #Feeding the projected [CLS] representation to the classifier layer
        logits = self.cls_layer(embeds)

        return logits

In [27]:
# Instantiate the classifier, asking it to freeze the BERT encoder's parameters.
net = ContextEmbeddings(freeze_bert=True)

In [28]:
import torch.nn as nn
import torch.optim as optim

# Binary cross-entropy computed directly on raw logits.
criterion = nn.BCEWithLogitsLoss()

# Adam over every parameter exposed by the network.
opti = optim.Adam(params=net.parameters(), lr=2e-5)

In [29]:
def get_accuracy_from_logits(logits, labels):
    """Return the fraction of binary predictions that match ``labels``.

    Inputs:
        -logits : Tensor of raw (pre-sigmoid) scores, one per example; any
                  shape holding one value per example (e.g. [B] or [B, 1])
        -labels : Tensor of 0/1 targets with the same number of elements
    Returns:
        0-dim float Tensor holding the mean accuracy over the batch
    """
    probs = torch.sigmoid(logits)
    # Flatten both sides so that [B, 1] vs [B] inputs are compared
    # element-wise instead of being broadcast to a [B, B] matrix.
    preds = (probs > 0.5).long().reshape(-1)
    acc = (preds == labels.reshape(-1)).float().mean()
    return acc

In [30]:
def train(net, criterion, opti, train_loader, test_loader, max_eps, print_every = 1):
    """Run ``max_eps`` epochs of gradient updates over ``train_loader``.

    Inputs:
        -net : model called as net(seq, attn_masks) and expected to return logits
        -criterion : loss called as criterion(logits.squeeze(-1), labels.float())
        -opti : optimizer whose zero_grad() / step() are invoked once per batch
        -train_loader : iterable yielding (seq, attn_masks, labels) batches
        -test_loader : accepted for interface compatibility; not used by this function
        -max_eps : number of passes over train_loader
        -print_every : log loss/accuracy every this many iterations
                       (default 1 = every iteration; values <= 0 disable logging)
    Returns:
        None. Progress is reported through print().
    """

    for ep in range(max_eps):

        for it, (seq, attn_masks, labels) in enumerate(train_loader):
            #Clear gradients
            opti.zero_grad()
            #Converting these to cuda tensors
#             seq, attn_masks, labels = seq.cuda(), attn_masks.cuda(), labels.cuda()

            #Obtaining the logits from the model
            logits = net(seq, attn_masks)

            #Computing loss
            loss = criterion(logits.squeeze(-1), labels.float())

            #Backpropagating the gradients
            loss.backward()

            #Optimization step
            opti.step()

            if print_every > 0 and (it + 1) % print_every == 0:
                #The metric needs no gradient tracking, so work on detached logits
                acc = get_accuracy_from_logits(logits.detach(), labels)
                #.item() so the log shows a plain number rather than tensor(...)
                print("Iteration {} of epoch {} complete. Loss : {} Accuracy : {}".format(it+1, ep+1, loss.item(), acc.item()))

In [32]:
# Train for a single epoch.
train(
    net,
    criterion,
    opti,
    train_dataloader,
    test_dataloader,
    max_eps=1,
)

KeyboardInterrupt: 

In [17]:
# Persist only the learned parameters. Create the target directory first so
# the save does not fail when "models/" is missing.
os.makedirs("models", exist_ok=True)
torch.save(net.state_dict(), "models/model_state_dict")

In [19]:
# Serialize the whole module object (class reference plus weights) to disk.
torch.save(obj=net, f="models/model")

  "type " + obj.__name__ + ". It won't be checked "
