In [1]:
from transformers import BigBirdModel, RobertaTokenizerFast

tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

In [2]:
import json


all_spans = {}

with open("../../mfc_v4.0/spans_with_context.json", "r") as f:
    all_spans = json.load(f)
    
articles = all_spans["articles"]
spans = all_spans["spans"]

labels = set(e[0] for e in spans)

spans[0]

[13.0, 34, 149, 'climate_change1.0-1']

In [3]:
import torch
import torch.nn as NN
from torch.nn.functional import normalize


class FullContextSpanClassifier(NN.Module):
    def __init__(self, labels, reporting=False):
        
        super().__init__()
        self.transformer = BigBirdModel.from_pretrained("google/bigbird-roberta-base")
        for params in self.transformer.parameters():
            params.requires_grad = False
        self.transformer.eval()
#         self.fc = NN.Linear(768, len(labels))
#         self.logits = NN.Softmax()
#         self.labels = labels
        self.reporting=reporting
    
    def forward(self, x):
        tokens = x[0]
        indices = x[1]
        dims = list(indices.shape)
        indices = torch.flatten(indices)
        
        self.report("Data unpacked. running bigbird...")
        
        x = self.transformer(**tokens).last_hidden_state.detach()
        
        self.report("bigbird run. applying mask and summing...")
        
        x = torch.reshape(x, (dims[0]*dims[1], 768)).detach()
        self.report("mask shape:", indices.shape, "data shape:", x.shape)
        
        x = (x.t()*indices).t().detach()
        self.report("after masking, data is of shape", x.shape)
        x = torch.reshape(x, (dims[0], dims[1], 768))
        
        x = torch.sum(x, dim=1).detach()
        self.report("after summing, data is of shape", x.shape)
        
        x = normalize(x, dim=1).detach()
        
#         self.report("mask applied. running classifier...")
        
#         x = self.fc(x)
        
#         x = self.logits(x)
        
#         self.report("classifier run.")
        
        return x.detach()
    
    def report(self,*args):
        if self.reporting:
            print("(FullContextSpanClassifier): ", " ".join([str(x) for x in args]))

In [4]:
import random

In [5]:
def calc_annotation_mask(offset_mapping, batch_bounds):
    token_spans = []
    for i, inp in enumerate(offset_mapping):

        start_idx = -1
        end_idx = -1

        for j, span in enumerate(inp):
            tok_start = span[0]
            tok_end = span[1]
            annotation_start = batch_bounds[i][0]
            annotation_end = batch_bounds[i][1]
            if tok_end > annotation_start and start_idx == -1:
                start_idx = j
            if tok_end > annotation_end:
                end_idx = j
                break
        token_spans.append([1 if i >= start_idx and i < end_idx else 0 for i in range(len(inp))])
    return token_spans

In [6]:
import time

data = spans

batch_size = 6

keys = [float(i+1) for i in range(15)]

data = [ [[1 if i == d[0] else 0 for i in keys],d[1],d[2],d[3]] for d in data]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

random.shuffle(data)

loss_fn = torch.nn.MSELoss(reduction='sum')
learning_rate = 5e-4

model = FullContextSpanClassifier(list(labels)).to(device)
model.eval()

train_losses = []
test_losses = []

optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)


# test_x = [d[1:] for d in data[:batch_size]]
# test_x_tokens = tokenizer([articles[x[-1]] for x in test_x], padding=True, return_offsets_mapping=True, return_tensors='pt')
# test_annotation_mask = calc_annotation_mask(test_x_tokens["offset_mapping"], test_x)
# test_slice_tensor = torch.tensor(test_annotation_mask, dtype=torch.float)
# del test_x_tokens["offset_mapping"]
# test_model_input = [test_x_tokens.to(device), test_slice_tensor.to(device)]

test_y = torch.tensor([d[0] for d in data[:batch_size]], dtype=torch.float).to(device)

train_data = data[batch_size:]

all_embeddings = []

for epoch in range(1):
    
    print("Starting training epoch", epoch)
    
    random.shuffle(train_data)

    train_x = [d[1:] for d in train_data]
    train_y = [d[0] for d in train_data]

    x_batches = []
    y_batches = []
    
    print("constructing batches...")
    
    for i in range(len(train_data)//batch_size):
        batch_x = train_x[i*batch_size:(i+1)*batch_size]
        batch_y = torch.tensor(train_y[i*batch_size:(i+1)*batch_size], dtype=torch.float).to(device)
        x_batches.append(batch_x)
        y_batches.append(batch_y)
    
    t1 = time.time()
    print("batches constructed. starting training steps...")

    for step in range(len(x_batches)):
        
        if step > 0 and step % 100 == 0:
            diff = time.time() - t1
            proj_end = (diff/step) * len(x_batches)
            print("done with step {} of {}".format(step, len(x_batches)))
            print("current time:", diff/60, "est finish: ", proj_end/60)
        
        #print("tokenizing batch {} of {}...".format(step, len(x_batches)))
        
        batch_tokens = tokenizer([articles[x[-1]] for x in x_batches[step]], padding="max_length", truncation=True, max_length=2048, return_offsets_mapping=True, return_tensors='pt')
        
        #print("tokenized. calculating annotation mask...")
        
        annotation_mask = calc_annotation_mask(batch_tokens["offset_mapping"], x_batches[step])
        slice_tensor = torch.tensor(annotation_mask, dtype=torch.float)
        del batch_tokens["offset_mapping"]
        model_input = [batch_tokens.to(device), slice_tensor.to(device)]
        
        #print("mask calculated. running model...")
        
        with torch.no_grad():
            embeddings = model(model_input).detach().cpu().numpy().astype('float32')
        all_embeddings.append(embeddings)
        del model_input
        torch.cuda.empty_cache()
#         y_train_pred = model(model_input)
#         train_loss = loss_fn(y_train_pred, y_batches[step])
        
#         train_losses.append(train_loss.item())

#         print("Training Loss at step",step,":",train_loss.item())

#         optimizer.zero_grad()

#         train_loss.backward()

#         optimizer.step()

#         if step % 10 == 0:
#             y_test_pred = model(test_model_input)
#             test_loss = loss_fn(y_test_pred, test_y).item()
#             test_losses.append(test_loss)
#             print("Test Loss",test_loss,"\n")

Some weights of the model checkpoint at google/bigbird-roberta-base were not used when initializing BigBirdModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BigBirdModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BigBirdModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Starting training epoch 0
constructing batches...
batches constructed. starting training steps...


  * num_indices_to_pick_from


done with step 100 of 94215
current time: 1.8279147187868754 est finish:  1722.1698523050545
done with step 200 of 94215
current time: 3.7025567293167114 est finish:  1744.1819112628698
done with step 300 of 94215
current time: 5.574326856931051 est finish:  1750.6173494191967


KeyboardInterrupt: 

In [9]:
count = 0
for param in model.parameters():
    count += 1
    print(param.requires_grad)
count

True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True


199

In [7]:
torch.cuda.empty_cache()

In [8]:
lengths = dict({})
for i in range(14):
    lengths[2**i] = 0

for n, title in enumerate(articles.keys()):
    if n % 1000 == 0:
        print("Done with {} of {} articles".format(n, len(articles)))
    length = len(tokenizer(articles[title])['input_ids'])
    for i in range(1,15):
        if length < 2**i:
            lengths[2**(i-1)] += 1
            break
            

Done with 0 of 32014 articles
Done with 1000 of 32014 articles
Done with 2000 of 32014 articles
Done with 3000 of 32014 articles
Done with 4000 of 32014 articles
Done with 5000 of 32014 articles
Done with 6000 of 32014 articles
Done with 7000 of 32014 articles
Done with 8000 of 32014 articles
Done with 9000 of 32014 articles
Done with 10000 of 32014 articles
Done with 11000 of 32014 articles
Done with 12000 of 32014 articles
Done with 13000 of 32014 articles
Done with 14000 of 32014 articles
Done with 15000 of 32014 articles
Done with 16000 of 32014 articles
Done with 17000 of 32014 articles
Done with 18000 of 32014 articles
Done with 19000 of 32014 articles
Done with 20000 of 32014 articles
Done with 21000 of 32014 articles
Done with 22000 of 32014 articles
Done with 23000 of 32014 articles
Done with 24000 of 32014 articles
Done with 25000 of 32014 articles
Done with 26000 of 32014 articles
Done with 27000 of 32014 articles
Done with 28000 of 32014 articles
Done with 29000 of 32014 ar

In [9]:
lengths

{1: 0,
 2: 0,
 4: 0,
 8: 0,
 16: 0,
 32: 0,
 64: 20,
 128: 2092,
 256: 29581,
 512: 240,
 1024: 75,
 2048: 6,
 4096: 0,
 8192: 0}