### Load Facet Data

In [173]:
import json
def read_facet_results(file_name):
    with open(file_name, "r") as f:
        lines = f.readlines()
    facet_results = []
    for line in lines:
        facet_results.append(json.loads(line))
    return facet_results
faceted_results = read_facet_results("cs5_abstract-tag.json")



In [174]:
# this is used temporarily for old versions of the facet results
for result in faceted_results:
    result['sent'] = [sent for sent in result['sent'] if "* keywords" not in sent]
    facets = result['gpt_annotation'].replace('.', '').split(';')
    new_facets = []
    for i, facet in enumerate(facets):
        assert str(i) in facet
        new_facets.append(facet.replace(str(i), '').strip())
    try:
        assert len(new_facets) == len(result['sent'])
    except:
        print(len(new_facets), len(result['sent']))
        print(result['sent'])
        print(new_facets)
        
    result['facets'] = new_facets


In [175]:
faceted_results[0]

{'article_id': '1604.01592',
 'sent': [' matrix data sets are common nowadays like in biomedical imaging where the diffusion tensor magnetic resonance imaging ( dt - mri ) modality produces data sets of 3d symmetric positive definite matrices anchored at voxel positions capturing the anisotropic diffusion properties of water molecules in biological tissues . ',
  ' the space of symmetric matrices can be partially ordered using the lwner ordering , and computing extremal matrices dominating a given set of matrices is a basic primitive used in matrix - valued signal processing . in this letter , we design a fast and easy - to - implement iterative algorithm to approximate arbitrarily finely these extremal matrices . ',
  ' finally , we discuss on extensions to matrix clustering .    '],
 'gpt_annotation': '0 Background; 1 Method; 2 Value',
 'prompt_tokens': 367,
 'completion_tokens': 10,
 'total_tokens': 377,
 'facets': ['Background', 'Method', 'Value']}

## Preprocess data

In [97]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")



In [157]:
from torch.utils.data import Dataset, DataLoader
import torch
class FacetedSummaryDataset(Dataset):
    def __init__(self, faceted_results):
        self.faceted_results = faceted_results
        self.facet2idx = {"Background": 0, "Method": 1, "Result": 2,  "Value": 3, "Others": 4}
        
    def __len__(self):
        return len(self.faceted_results)
    
    def __getitem__(self, idx):
        example = self.faceted_results[idx]
        model_input = {'input_ids': None, 'attention_mask': None, 'token_type_ids': None, 'labels': None}
        # tokenize a list of sentences
        tokenized_sentences = [tokenizer.tokenize(sent) + [tokenizer.sep_token] for sent in example['sent']]
        token_type_ids = [len(tokens) for tokens in tokenized_sentences]
        sent_lens = [len(tokens) for tokens in tokenized_sentences]
        token_type_ids = [[i]*length for i, length in enumerate(sent_lens)]
        token_type_ids = [item for sublist in token_type_ids for item in sublist]
        
        # add the cls token
        tokenized_sentences = [tokenizer.cls_token] + [token for sent in tokenized_sentences for token in sent]
        token_type_ids = [0] + token_type_ids

        # convert the tokens to indices
        # model_input["tokenized_sentences"] = tokenized_sentences
        model_input["input_ids"] = tokenizer.convert_tokens_to_ids(tokenized_sentences)
        model_input['attention_mask'] = [1] * len(model_input['input_ids'])
        model_input['token_type_ids'] = token_type_ids
        
        # convert the facets to indices
        facet_indices = []
        for facet in example['facets']:
            facet_indices.append(self.facet2idx[facet])
        model_input['labels'] = facet_indices
        return model_input['input_ids'], model_input['attention_mask'], model_input['token_type_ids'], model_input['labels']

In [158]:
dataset = FacetedSummaryDataset(faceted_results)
# collate function for padding the input
def collate_fn(batch):
    input_ids, attention_mask, token_type_ids, lst_labels = zip(*batch)
    
    max_len = max([len(ids) for ids in input_ids])
    input_ids = [ids + [0] * (max_len - len(ids)) for ids in input_ids]
    attention_mask = [mask + [0] * (max_len - len(mask)) for mask in attention_mask]
    token_type_ids = [ids + [0] * (max_len - len(ids)) for ids in token_type_ids]
    lst_labels = [item for sublist in lst_labels for item in sublist]
    return input_ids, attention_mask, token_type_ids, lst_labels

dataloader = DataLoader(dataset, batch_size=3, collate_fn=collate_fn)


## Building Local Models

In [168]:

from transformers import DistilBertModel, BertModel
import torch.nn as nn
from torch.nn import CrossEntropyLoss
import torch

# init
num_facets = 5 # four facets and "others"
model_name = "bert-base-uncased"
plm = BertModel.from_pretrained(model_name)
dropout = nn.Dropout(0.1) # 0.1 is the dropout rate in the implementation of DistilBertForQuestionAnswering
classifier = nn.Linear(plm.config.hidden_size, num_facets) 



### Train the local model

In [161]:
for batch in dataloader:
    input_ids = torch.tensor(batch[0], dtype=torch.long)
    attention_mask = torch.tensor(batch[1], dtype=torch.long)
    token_type_ids = torch.tensor(batch[2], dtype=torch.long)
    labels = torch.tensor(batch[3], dtype=torch.long)

    break

    

In [171]:
# forward pass
lm_output = plm(
    input_ids=input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids
)


IndexError: index out of range in self

In [None]:
hidden_states = lm_output[0]  # (bs, max_len, dim)
hidden_states = dropout(hidden_states)  # (bs, max_len, dim)
logits = classifier(hidden_states)  # (bs, max_len, num_facets)

# only calculate the loss on [SEP] tokens inserted at the end of each sentence
sep_indices = (inputs["input_ids"] == tokenizer.sep_token_id).nonzero()
# mask all the other tokens except [SEP] tokens
sep_mask = torch.zeros_like(inputs["input_ids"]).bool()
sep_mask[sep_indices[:, 0], sep_indices[:, 1]] = True


# gather logits for [SEP] tokens
sep_logits = logits[sep_mask].view(-1, num_facets)  # (nu_sep1+num_sep2+..., num_facets)
labels = None