In [1]:
from transformers import RobertaModel, RobertaTokenizerFast

tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

In [2]:
import json


all_spans = {}

with open("../../mfc_v4.0/spans_with_context.json", "r") as f:
    all_spans = json.load(f)
    
articles = all_spans["articles"]
spans = all_spans["spans"]

a_deleted = 0
for article in list(articles.keys()):
    if len(tokenizer(articles[article])['input_ids']) >= 512:
        del articles[article]
        a_deleted += 1
        
s_deleted = 0
filtered_spans = []
for span in spans:
    if span[-1] not in articles:
        s_deleted += 1
        continue
    filtered_spans.append(span)
    
spans = filtered_spans
labels = set(e[0] for e in spans)

print("Deleted {} articles and {} spans to fit bert input size".format(a_deleted, s_deleted))

Token indices sequence length is longer than the specified maximum sequence length for this model (3016 > 512). Running this sequence through the model will result in indexing errors


Deleted 321 articles and 8147 spans to fit bert input size


In [3]:
import torch
import torch.nn as NN
from torch.nn.functional import normalize


class FullContextSpanClassifier(NN.Module):
    def __init__(self, labels, reporting=False):
        
        super().__init__()
        self.transformer = RobertaModel.from_pretrained("roberta-base")
        for params in self.transformer.parameters():
            params.requires_grad = False
        self.transformer.eval()
        self.fc = NN.Linear(768, len(labels))
        self.logits = NN.Softmax()
        self.labels = labels
        self.reporting=reporting
    
    def forward(self, x):
        tokens = x[0]
        indices = x[1]
        dims = list(indices.shape)
        indices = torch.flatten(indices)
        
        self.report("Data unpacked. running bigbird...")
        
        x = self.transformer(**tokens).last_hidden_state
        
        self.report("bigbird run. applying mask and summing...")
        
        x = torch.reshape(x, (dims[0]*dims[1], 768))
        self.report("mask shape:", indices.shape, "data shape:", x.shape)
        
        x = (x.t()*indices).t()
        self.report("after masking, data is of shape", x.shape)
        x = torch.reshape(x, (dims[0], dims[1], 768))
        
        x = torch.sum(x, dim=1)
        self.report("after summing, data is of shape", x.shape)
        
        x = normalize(x, dim=1)
        
        x = self.fc(x)
        
        x = self.logits(x)
        
        self.report("classifier run.")
        
        return x
    
    def report(self,*args):
        if self.reporting:
            print("(FullContextSpanClassifier): ", " ".join([str(x) for x in args]))

In [4]:
import random

In [5]:
def calc_annotation_mask(offset_mapping, batch_bounds):
    token_spans = []
    for i, inp in enumerate(offset_mapping):

        start_idx = -1
        end_idx = -1

        for j, span in enumerate(inp):
            tok_start = span[0]
            tok_end = span[1]
            annotation_start = batch_bounds[i][0]
            annotation_end = batch_bounds[i][1]
            if tok_end > annotation_start and start_idx == -1:
                start_idx = j
            if tok_end > annotation_end:
                end_idx = j
                break
        token_spans.append([1 if i >= start_idx and i < end_idx else 0 for i in range(len(inp))])
    return token_spans

In [None]:
import time


subset_size = 5000

data = spans[:subset_size]

batch_size = 16

keys = [float(i+1) for i in range(15)]

data = [ [[1 if i == d[0] else 0 for i in keys],d[1],d[2],d[3]] for d in data]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

random.shuffle(data)

loss_fn = torch.nn.MSELoss(reduction='sum')
learning_rate = 5e-4

model = FullContextSpanClassifier(list(labels)).to(device)
model.train()

train_losses = []
test_losses = []

optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)


test_x = [d[1:] for d in data[:batch_size]]
test_x_tokens = tokenizer([articles[x[-1]] for x in test_x], padding=True, return_offsets_mapping=True, return_tensors='pt')
test_annotation_mask = calc_annotation_mask(test_x_tokens["offset_mapping"], test_x)
test_slice_tensor = torch.tensor(test_annotation_mask, dtype=torch.float)
del test_x_tokens["offset_mapping"]
test_model_input = [test_x_tokens.to(device), test_slice_tensor.to(device)]

test_y = torch.tensor([d[0] for d in data[:batch_size]], dtype=torch.float).to(device)

train_data = data[batch_size:]

for epoch in range(5):
    
    print("Starting training epoch", epoch)
    
    random.shuffle(train_data)

    train_x = [d[1:] for d in train_data]
    train_y = [d[0] for d in train_data]

    x_batches = []
    y_batches = []
    
    print("constructing batches...")
    
    for i in range(len(train_data)//batch_size):
        batch_x = train_x[i*batch_size:(i+1)*batch_size]
        batch_y = torch.tensor(train_y[i*batch_size:(i+1)*batch_size], dtype=torch.float).to(device)
        x_batches.append(batch_x)
        y_batches.append(batch_y)
    
    t1 = time.time()
    print("batches constructed. starting training steps...")

    for step in range(len(x_batches)):
        
        #print("tokenizing batch {} of {}...".format(step, len(x_batches)))
        
        batch_tokens = tokenizer([articles[x[-1]] for x in x_batches[step]], padding=True, truncation=True, return_offsets_mapping=True, return_tensors='pt')
        
        #print("tokenized. calculating annotation mask...")
        
        annotation_mask = calc_annotation_mask(batch_tokens["offset_mapping"], x_batches[step])
        slice_tensor = torch.tensor(annotation_mask, dtype=torch.float)
        del batch_tokens["offset_mapping"]
        model_input = [batch_tokens.to(device), slice_tensor.to(device)]
        
        #print("mask calculated. running model...")
        
#         with torch.no_grad():
#             embeddings = model(model_input).detach().cpu().numpy().astype('float32')
#         all_embeddings.append(embeddings)
        y_train_pred = model(model_input)
        train_loss = loss_fn(y_train_pred, y_batches[step])
        
        train_losses.append(train_loss.item())

        print("Training Loss at step",step,":",train_loss.item())

        optimizer.zero_grad()

        train_loss.backward()

        optimizer.step()

        if step % 10 == 0:
            y_test_pred = model(test_model_input)
            test_loss = loss_fn(y_test_pred, test_y).item()
            test_losses.append(test_loss)
            print("Test Loss",test_loss,"\n")
            
            diff = time.time() - t1
            proj_end = (diff/(step+1)) * len(x_batches)
            print("done with step {} of {}".format(step, len(x_batches)))
            print("current time:", diff/60, "est epoch finish: ", proj_end/60)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Starting training epoch 0
constructing batches...
batches constructed. starting training steps...
Training Loss at step 0 : 14.959721565246582
Test Loss 14.867717742919922 

done with step 0 of 311
current time: 0.011228454113006592 est epoch finish:  3.4920492291450502
Training Loss at step 1 : 14.894336700439453
Training Loss at step 2 : 14.858057975769043
Training Loss at step 3 : 14.871723175048828
Training Loss at step 4 : 14.857890129089355
Training Loss at step 5 : 14.851085662841797
Training Loss at step 6 : 14.791114807128906
Training Loss at step 7 : 14.797819137573242
Training Loss at step 8 : 14.750685691833496
Training Loss at step 9 : 14.730365753173828
Training Loss at step 10 : 14.79826545715332
Test Loss 14.77617073059082 

done with step 10 of 311
current time: 0.07496475378672282 est epoch finish:  2.1194580388791633
Training Loss at step 11 : 14.66836166381836
Training Loss at step 12 : 14.718295097351074
Training Loss at step 13 : 14.719673156738281
Training Loss a

Training Loss at step 138 : 14.141304969787598
Training Loss at step 139 : 13.964757919311523
Training Loss at step 140 : 13.989887237548828
Test Loss 14.484991073608398 

done with step 140 of 311
current time: 0.9137767990430196 est epoch finish:  2.015493507109072
Training Loss at step 141 : 14.272119522094727
Training Loss at step 142 : 14.026771545410156
Training Loss at step 143 : 14.262874603271484
Training Loss at step 144 : 13.904277801513672
Training Loss at step 145 : 13.968494415283203
Training Loss at step 146 : 14.079180717468262
Training Loss at step 147 : 14.188714981079102
Training Loss at step 148 : 14.207623481750488
Training Loss at step 149 : 14.084989547729492
Training Loss at step 150 : 13.898887634277344
Test Loss 14.469326972961426 

done with step 150 of 311
current time: 0.9755862792332967 est epoch finish:  2.0093200850434125
Training Loss at step 151 : 14.138829231262207
Training Loss at step 152 : 13.954305648803711
Training Loss at step 153 : 14.029413223

Training Loss at step 275 : 13.759088516235352
Training Loss at step 276 : 13.976300239562988
Training Loss at step 277 : 14.11009693145752
Training Loss at step 278 : 13.846817016601562
Training Loss at step 279 : 13.529170989990234
Training Loss at step 280 : 14.044897079467773
Test Loss 14.435081481933594 

done with step 280 of 311
current time: 1.819424311319987 est epoch finish:  2.0136688997171386
Training Loss at step 281 : 13.88992977142334
Training Loss at step 282 : 13.429948806762695
Training Loss at step 283 : 13.854198455810547
Training Loss at step 284 : 12.915481567382812
Training Loss at step 285 : 13.7476806640625
Training Loss at step 286 : 14.205766677856445
Training Loss at step 287 : 13.832663536071777
Training Loss at step 288 : 13.262690544128418
Training Loss at step 289 : 13.662324905395508
Training Loss at step 290 : 13.929732322692871
Test Loss 14.420700073242188 

done with step 290 of 311
current time: 1.8861392339070637 est epoch finish:  2.01577079637490

Training Loss at step 101 : 13.081175804138184
Training Loss at step 102 : 13.608415603637695


In [9]:
from matplotlib import pyplot as plt

train_domain = list(range(len(train_losses)))
test_domain = ([i*10 for i in range(len(test_losses))])

plt.plot(train_domain, train_losses, label="train")
plt.plot(test_domain, test_losses, label="test")

plt.legend()

plt.show()

True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True


199

In [7]:
torch.save(model.state_dict(), "./distilberta-mfc-with-context.pt")

In [8]:
lengths = dict({})
for i in range(14):
    lengths[2**i] = 0

for n, title in enumerate(articles.keys()):
    if n % 1000 == 0:
        print("Done with {} of {} articles".format(n, len(articles)))
    length = len(tokenizer(articles[title])['input_ids'])
    for i in range(1,15):
        if length < 2**i:
            lengths[2**(i-1)] += 1
            break
            

Done with 0 of 32014 articles
Done with 1000 of 32014 articles
Done with 2000 of 32014 articles
Done with 3000 of 32014 articles
Done with 4000 of 32014 articles
Done with 5000 of 32014 articles
Done with 6000 of 32014 articles
Done with 7000 of 32014 articles
Done with 8000 of 32014 articles
Done with 9000 of 32014 articles
Done with 10000 of 32014 articles
Done with 11000 of 32014 articles
Done with 12000 of 32014 articles
Done with 13000 of 32014 articles
Done with 14000 of 32014 articles
Done with 15000 of 32014 articles
Done with 16000 of 32014 articles
Done with 17000 of 32014 articles
Done with 18000 of 32014 articles
Done with 19000 of 32014 articles
Done with 20000 of 32014 articles
Done with 21000 of 32014 articles
Done with 22000 of 32014 articles
Done with 23000 of 32014 articles
Done with 24000 of 32014 articles
Done with 25000 of 32014 articles
Done with 26000 of 32014 articles
Done with 27000 of 32014 articles
Done with 28000 of 32014 articles
Done with 29000 of 32014 ar

In [9]:
lengths

{1: 0,
 2: 0,
 4: 0,
 8: 0,
 16: 0,
 32: 0,
 64: 20,
 128: 2092,
 256: 29581,
 512: 240,
 1024: 75,
 2048: 6,
 4096: 0,
 8192: 0}