In [1]:
!pip install transformers
!pip install jsonlines
!pip install python-levenshtein
!pip install datasets



In [2]:
import Levenshtein as Lev
import torch
from torch.utils.data import Dataset as TorchDataset
from typing import Sequence, Dict, Any, List
import json
from tqdm.notebook import tqdm
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset
import torch.nn.functional as F
import torch.nn as nn
from datasets import load_dataset

In [3]:
from transformers import AutoTokenizer
from transformers import BertLMHeadModel
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
maskedlm = BertLMHeadModel.from_pretrained('bert-base-uncased')

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertLMHeadModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertLMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertLMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
train_dataset = load_dataset('trec', split='train')
train_dataset = train_dataset.map(
    lambda e: tokenizer(e['text'], truncation=True, padding='do_not_pad'),
    batched=True
    )
test_dataset = load_dataset('trec', split='test')
test_dataset = test_dataset.map(
    lambda e: tokenizer(e['text'], truncation=True, padding='do_not_pad'),
    batched=True
    )

def collate_fn(batch):
    attention_mask, input_ids, label_coarse = [], [], []
    for b in batch:
        attention_mask.append(torch.tensor(b['attention_mask']))
        input_ids.append(torch.tensor(b['input_ids']))
        label_coarse.append(torch.tensor(b['label-coarse']))
    return {'attention_mask':nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=0.0),
            'input_ids':nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=0.0),
            'label-coarse':label_coarse}

trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, drop_last=True, collate_fn=collate_fn, shuffle=True)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=64, drop_last=True, collate_fn=collate_fn, shuffle=True)

Using custom data configuration default
Reusing dataset trec (/root/.cache/huggingface/datasets/trec/default/1.1.0/1902c380fe66cc215f989888b1b35e8da7e79a3a97520f00dce753fd1f8f5c48)
Loading cached processed dataset at /root/.cache/huggingface/datasets/trec/default/1.1.0/1902c380fe66cc215f989888b1b35e8da7e79a3a97520f00dce753fd1f8f5c48/cache-f4a1d4d2393fdd96.arrow
Using custom data configuration default
Reusing dataset trec (/root/.cache/huggingface/datasets/trec/default/1.1.0/1902c380fe66cc215f989888b1b35e8da7e79a3a97520f00dce753fd1f8f5c48)
Loading cached processed dataset at /root/.cache/huggingface/datasets/trec/default/1.1.0/1902c380fe66cc215f989888b1b35e8da7e79a3a97520f00dce753fd1f8f5c48/cache-b27f26d24927a79a.arrow


In [5]:
class Deep_lev(torch.nn.Module):

    def __init__(self, vocab_size=30522, embedding_dim=128, hidden_dim=128) :
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.encoder = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.linear = nn.Linear(hidden_dim * 3, 1)


    def encode_sequence(self, sequence):
        embedded_sequence = self.embeddings(sequence)
        out, (ht, ct) = self.encoder(embedded_sequence)
        return ht[-1]

    def forward(self, sequence_a, sequence_b, distance = None):
        embedded_sequence_a = self.encode_sequence(sequence_a)
        embedded_sequence_b = self.encode_sequence(sequence_b)
        diff = torch.abs(embedded_sequence_a - embedded_sequence_b)
        representation = torch.cat([embedded_sequence_a, embedded_sequence_b, diff], dim=-1)

        approx_distance = self.linear(representation)

        return approx_distance


class TextCNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, 
                 dropout, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        self.convs = nn.ModuleList([nn.Conv1d(in_channels=embedding_dim, out_channels=n_filters, kernel_size=fs) for fs in filter_sizes])
        self.fc = nn.Linear(len(filter_sizes)*n_filters, output_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, text):
        embs = self.embedding(text)
        embs = embs.permute(0, 2, 1)
        out = [F.relu(c(embs)) for c in self.convs]
        out_pool = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in out]
        cat = self.dropout(torch.cat(out_pool, dim=1))
        final = self.fc(cat)
        return final




In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
deep_lev = Deep_lev()
deep_lev.load_state_dict(torch.load('deep_levenstein.pt'))
deep_lev = deep_lev.to(device)

In [7]:
classifier = TextCNN(vocab_size=tokenizer.vocab_size,
                embedding_dim=100,
                n_filters=8,
                filter_sizes=[3,4,5],
                output_dim=6,
                dropout=0.1,
                pad_idx=tokenizer.pad_token_id)

classifier = classifier.to(device)
classifier.eval()
classifier.load_state_dict(torch.load('textcnn_trec.pt'))
maskedlm.to(device)

BertLMHeadModel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr

In [8]:
def dilma_loss(preds, approx_distance, beta=1.0):
    '''
    log((1-Classifier(x'))) - beta * (1-DL(x, x'))**2 
    '''
    global device
    pred = torch.tensor([torch.softmax(preds, dim=1)[i, :].max() for i in range(64)])
    one1 = torch.ones_like(pred, device=device)
    one2 = torch.ones_like(approx_distance, device=device)
    return (-((torch.log(one1 - pred.to(device))).mean() + beta * ((one2 - approx_distance) ** 2))).sum()

In [9]:
import numpy as np
import random 

def mask_tokens(batch, tokenizer):
    batch_masked, masked_inds = [], []
    for b in batch:
        inds_to_mask = np.random.choice(np.arange(1, b[b!=tokenizer.pad_token_id].shape[0] -1), size=1)
        masked_inds.append(inds_to_mask)
        b_new = b.clone()
        b_new[inds_to_mask] = tokenizer.mask_token_id
        batch_masked.append(b_new)
    return torch.stack(batch_masked), masked_inds

In [10]:
def categorical_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """
    max_preds = preds.argmax(dim=1, keepdim=True) # get the index of the max probability
    correct = max_preds.squeeze(1).eq(y)
    correct = correct.detach().to('cpu')
    return correct.sum() / torch.FloatTensor([y.shape[0]])

In [11]:
optimizer = torch.optim.Adam(maskedlm.parameters(), lr=0.01)

In [12]:
losses = []
maskedlm.train()
classifier.eval()
deep_lev.train()
for epoch in range(10):
    epoch_acc, epoch_adv_acc = [], []
    for batch in trainloader:
        optimizer.zero_grad()
        maskedlm.zero_grad()
        b_input_ids = batch['input_ids'].to(device)
        b_input_mask = batch['attention_mask'].to(device)
        b_masked, masked_inds = mask_tokens(b_input_ids, tokenizer)
        b_masked = b_masked.to(device)
        logits = maskedlm(b_masked, attention_mask=b_input_mask)
        x_adv = torch.argmax(torch.softmax(logits['logits'], dim=2), dim=2)*b_input_mask
        approx_distance = deep_lev(x_adv.to(device), b_input_ids)
        scores = classifier(x_adv.to(device))
        scores_orig = classifier(b_input_ids)
        loss = dilma_loss(torch.softmax(scores, dim=1), approx_distance)
        loss.backward()
        optimizer.step()
        epoch_acc.append(categorical_accuracy(torch.softmax(scores_orig, dim=1), torch.stack(batch['label-coarse']).to(device)).item())
        epoch_adv_acc.append(categorical_accuracy(torch.softmax(scores, dim=1), torch.stack(batch['label-coarse']).to(device)).item())
        losses.append(loss.item() / len(trainloader))
    print(f'Epoch: {epoch}/10, Loss: {np.mean(losses)}, mean original accuracy: {round(np.mean(epoch_acc),3)}, mean adversarial accuracy: {round(np.mean(epoch_adv_acc), 3)}')

Epoch: 0/10, Loss: -19.948984594642088, mean original accuracy: 0.999, mean adversarial accuracy: 0.923
Epoch: 1/10, Loss: -19.94919175936689, mean original accuracy: 0.999, mean adversarial accuracy: 0.926
Epoch: 2/10, Loss: -19.949398850877667, mean original accuracy: 0.999, mean adversarial accuracy: 0.921
Epoch: 3/10, Loss: -19.94950260360348, mean original accuracy: 0.999, mean adversarial accuracy: 0.917
Epoch: 4/10, Loss: -19.949439476778764, mean original accuracy: 0.999, mean adversarial accuracy: 0.919
Epoch: 5/10, Loss: -19.949398470727907, mean original accuracy: 0.999, mean adversarial accuracy: 0.924
Epoch: 6/10, Loss: -19.94927094516575, mean original accuracy: 0.999, mean adversarial accuracy: 0.929
Epoch: 7/10, Loss: -19.94922961805931, mean original accuracy: 0.999, mean adversarial accuracy: 0.922
Epoch: 8/10, Loss: -19.949162673686203, mean original accuracy: 0.999, mean adversarial accuracy: 0.927
Epoch: 9/10, Loss: -19.949167281101317, mean original accuracy: 0.99

In [13]:
losses = []
maskedlm.eval()
classifier.eval()
deep_lev.eval()
epoch_acc, epoch_adv_acc = [], []
for batch in testloader:
    b_input_ids = batch['input_ids'].to(device)
    b_input_mask = batch['attention_mask'].to(device)
    b_masked, masked_inds = mask_tokens(b_input_ids, tokenizer)
    b_masked = b_masked.to(device)
    logits = maskedlm(b_masked, attention_mask=b_input_mask)
    x_adv = torch.argmax(torch.softmax(logits['logits'], dim=2), dim=2)*b_input_mask
    approx_distance = deep_lev(x_adv.to(device), b_input_ids)
    scores = classifier(x_adv.to(device))
    scores_orig = classifier(b_input_ids)
    loss = dilma_loss(torch.softmax(scores, dim=1), approx_distance)
    epoch_acc.append(categorical_accuracy(torch.softmax(scores_orig, dim=1), torch.stack(batch['label-coarse']).to(device)).item())
    epoch_adv_acc.append(categorical_accuracy(torch.softmax(scores, dim=1), torch.stack(batch['label-coarse']).to(device)).item())
    losses.append(loss.item() / len(testloader))
    print(f'mean original accuracy: {round(np.mean(epoch_acc),3)}, mean adversarial accuracy: {round(np.mean(epoch_adv_acc), 3)}')

mean original accuracy: 0.844, mean adversarial accuracy: 0.703
mean original accuracy: 0.852, mean adversarial accuracy: 0.719
mean original accuracy: 0.854, mean adversarial accuracy: 0.688
mean original accuracy: 0.855, mean adversarial accuracy: 0.715
mean original accuracy: 0.85, mean adversarial accuracy: 0.706
mean original accuracy: 0.841, mean adversarial accuracy: 0.719
mean original accuracy: 0.853, mean adversarial accuracy: 0.732
