In [1]:
#SETTINGS.PY

import os
from os.path import abspath, dirname, join


PROJ_DIR = join(abspath('./cs145-pst'))
DATA_DIR = join(PROJ_DIR, "data")
OUT_DIR = join(PROJ_DIR, "out")

os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(OUT_DIR, exist_ok=True)
DATA_TRACE_DIR = DATA_DIR


In [2]:
#UTILS.PY

from os.path import join
import json
import numpy as np
import pickle
from collections import defaultdict as dd
from bs4 import BeautifulSoup
from datetime import datetime

import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')  # include timestamp


def load_json(rfdir, rfname):
    logger.info('loading %s ...', rfname)
    with open(join(rfdir, rfname), 'r', encoding='utf-8') as rf:
        data = json.load(rf)
        logger.info('%s loaded', rfname)
        return data


def dump_json(obj, wfdir, wfname):
    logger.info('dumping %s ...', wfname)
    with open(join(wfdir, wfname), 'w', encoding='utf-8') as wf:
        json.dump(obj, wf, indent=4, ensure_ascii=False)
    logger.info('%s dumped.', wfname)


def serialize_embedding(embedding):
    return pickle.dumps(embedding)


def deserialize_embedding(s):
    return pickle.loads(s)


def find_bib_context(xml, dist=125):
    bs = BeautifulSoup(xml, "xml")
    bib_to_context = dd(list)
    bibr_strs_to_bid_id = {}
    for item in bs.find_all(type='bibr'):
        if "target" not in item.attrs:
            continue
        bib_id = item.attrs["target"][1:]
        item_str = "<ref type=\"bibr\" target=\"{}\">{}</ref>".format(item.attrs["target"], item.get_text())
        bibr_strs_to_bid_id[item_str] = bib_id

    for item_str in bibr_strs_to_bid_id:
        bib_id = bibr_strs_to_bid_id[item_str]
        cur_bib_context_pos_start = [ii for ii in range(len(xml)) if xml.startswith(item_str, ii)]
        for pos in cur_bib_context_pos_start:
            bib_to_context[bib_id].append(xml[pos - dist: pos + dist].replace("\n", " ").replace("\r", " ").strip())
    return bib_to_context


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


class Log:
    def __init__(self, file_path):
        self.file_path = file_path
        self.f = open(file_path, 'w+')

    def log(self, s):
        self.f.write(str(datetime.now()) + "\t" + s + '\n')
        self.f.flush()

In [3]:
import os
from os.path import join
from tqdm import tqdm
from collections import defaultdict as dd
from bs4 import BeautifulSoup
from fuzzywuzzy import fuzz
import numpy as np
import torch
from transformers import AutoTokenizer
from transformers import BertForSequenceClassification, get_linear_schedule_with_warmup
from transformers.optimization import AdamW
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from tqdm import trange
from sklearn.metrics import classification_report, precision_recall_fscore_support, average_precision_score
import logging





In [4]:


logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO)
logger = logging.getLogger(__name__)


MAX_SEQ_LENGTH=512

In [5]:



def prepare_bert_input():
    x_train = []
    y_train = []
    x_valid = []
    y_valid = []
    print()
    data_dir = join(DATA_TRACE_DIR, "PST")
    papers = load_json(data_dir, "paper_source_trace_train_ans.json")




    chain_of_thought = "Task: identify whether each reference to the paper is either"




    # with open(join(data_dir, "bib_context_train.txt"), "w", encoding="utf-8") as f:
    #         f.write(chain_of_thought + "\n")

    n_papers = len(papers)
    papers = sorted(papers, key=lambda x: x["_id"])
    n_train = int(n_papers * 2 / 3)
    # n_valid = n_papers - n_train

    papers_train = papers[:n_train]
    papers_valid = papers[n_train:]

    pids_train = {p["_id"] for p in papers_train}
    pids_valid = {p["_id"] for p in papers_valid}

    in_dir = join(data_dir, "paper-xml")
    files = []
    for f in os.listdir(in_dir):
        if f.endswith(".xml"):
            files.append(f)

    pid_to_source_titles = dd(list)
    for paper in tqdm(papers):
        pid = paper["_id"]
        for ref in paper["refs_trace"]:
            pid_to_source_titles[pid].append(ref["title"].lower())

    # print(pids_train)

    # files = sorted(files)
    # for file in tqdm(files):
    for cur_pid in tqdm(pids_train | pids_valid):
        # cur_pid = file.split(".")[0]
        # if cur_pid not in pids_train and cur_pid not in pids_valid:
            # continue
        f = open(join(in_dir, cur_pid + ".xml"), encoding='utf-8')
        xml = f.read()
        bs = BeautifulSoup(xml, "xml")

        source_titles = pid_to_source_titles[cur_pid]
        if len(source_titles) == 0:
            continue

        references = bs.find_all("biblStruct")
        bid_to_title = {}
        n_refs = 0
        for ref in references:
            if "xml:id" not in ref.attrs:
                continue
            bid = ref.attrs["xml:id"]
            if ref.analytic is None:
                continue
            if ref.analytic.title is None:
                continue
            bid_to_title[bid] = ref.analytic.title.text.lower()
            b_idx = int(bid[1:]) + 1
            if b_idx > n_refs:
                n_refs = b_idx

        flag = False

        cur_pos_bib = set()

        for bid in bid_to_title:
            cur_ref_title = bid_to_title[bid]
            for label_title in source_titles:
                if fuzz.ratio(cur_ref_title, label_title) >= 80:
                    flag = True
                    cur_pos_bib.add(bid)

        cur_neg_bib = set(bid_to_title.keys()) - cur_pos_bib


        if not flag:
            continue

        if len(cur_pos_bib) == 0 or len(cur_neg_bib) == 0:
            continue

        bib_to_contexts = find_bib_context(xml)

        n_pos = len(cur_pos_bib)
        n_neg = n_pos * 10
        cur_neg_bib_sample = np.random.choice(list(cur_neg_bib), n_neg, replace=True)

        is_train = False

        if cur_pid in pids_train:
            cur_x = x_train
            cur_y = y_train
            is_train = True
        elif cur_pid in pids_valid:
            cur_x = x_valid
            cur_y = y_valid
        else:
            continue
            # raise Exception("cur_pid not in train/valid/test")
        label_bids = np.random.choice(list(cur_pos_bib), 3,replace=True)

        for bib in cur_pos_bib:
            cur_context = " ".join(bib_to_contexts[bib])
            if bib in label_bids and is_train:
              cur_context+="- this ref was cited by "+ cur_pid +"as a good paper-source,\
              this ref, " + bib + " is a major source of inspiration for " + cur_pid + ". " + cur_pid + " either used the same core concepts, same investigative or algorithmic methods to achieve, or was inspired by ideas as " + bib + ". Without this, " + cur_pid + " would not exist"
            cur_x.append(cur_context)
            cur_y.append(1)


            # print(cur_context)
            # print("good")




        # print()
        # print(cur_pid)
        for bib in cur_neg_bib_sample:
            cur_context = " ".join(bib_to_contexts[bib])
            if bib in label_bids and is_train:
              cur_context+="- this ref is a bad ref-source, it is likely true that this ref,"+bib+"did not inspire, or contribute greatly in terms of its data, to" + cur_pid+ " and there is a very small relationship between the methodologies, algorithms, or concepts used between the two refs "

            cur_x.append(cur_context)
            cur_y.append(0)
            # print(cur_context)
            # print("bad")

    print("len(x_train)", len(x_train), "len(x_valid)", len(x_valid))



    with open(join(data_dir, "bib_context_train.txt"), "w", encoding="utf-8") as f:
        for line in x_train:
            f.write(line + "\n")

    with open(join(data_dir, "bib_context_valid.txt"), "w", encoding="utf-8") as f:
        for line in x_valid:
            f.write(line + "\n")

    with open(join(data_dir, "bib_context_train_label.txt"), "w", encoding="utf-8") as f:
        for line in y_train:
            f.write(str(line) + "\n")

    with open(join(data_dir, "bib_context_valid_label.txt"), "w", encoding="utf-8") as f:
        for line in y_valid:
            f.write(str(line) + "\n")
# prepare_bert_input()

In [6]:


class BertInputItem(object):
    """An item with all the necessary attributes for finetuning BERT."""

    def __init__(self, text, input_ids, input_mask, segment_ids, label_id):
        self.text = text
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id



In [7]:

def convert_examples_to_inputs(example_texts, example_labels, max_seq_length, tokenizer, verbose=0):
    """Loads a data file into a list of `InputBatch`s."""

    input_items = []
    examples = zip(example_texts, example_labels)
    for (ex_index, (text, label)) in enumerate(examples):

        # Create a list of token ids
        input_ids = tokenizer.encode(f"[CLS] {text} [SEP]")
        if len(input_ids) > max_seq_length:
            input_ids = input_ids[:max_seq_length]

        # All our tokens are in the first input segment (id 0).
        segment_ids = [0] * len(input_ids)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding = [0] * (max_seq_length - len(input_ids))
        input_ids += padding
        input_mask += padding
        segment_ids += padding

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        label_id = label

        input_items.append(
            BertInputItem(text=text,
                          input_ids=input_ids,
                          input_mask=input_mask,
                          segment_ids=segment_ids,
                          label_id=label_id))

    return input_items


In [8]:


def get_data_loader(features, max_seq_length, batch_size, shuffle=True):

    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
    data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)

    dataloader = DataLoader(data, shuffle=shuffle, batch_size=batch_size)
    return dataloader



In [9]:

def evaluate(model, dataloader, device, criterion):
    model.eval()

    eval_loss = 0
    nb_eval_steps = 0
    predicted_labels, correct_labels = [], []

    for step, batch in enumerate(tqdm(dataloader, desc="Evaluation iteration")):
        batch = tuple(t.to(device) for t in batch)
        input_ids, input_mask, segment_ids, label_ids = batch

        with torch.no_grad():
            r = model(input_ids, attention_mask=input_mask,
                                          token_type_ids=segment_ids, labels=label_ids)
            # tmp_eval_loss = r[0]
            logits = r[1]
            # print("logits", logits)
            tmp_eval_loss = criterion(logits, label_ids)

        outputs = np.argmax(logits.to('cpu'), axis=1)
        label_ids = label_ids.to('cpu').numpy()

        predicted_labels += list(outputs)
        correct_labels += list(label_ids)

        eval_loss += tmp_eval_loss.mean().item()
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps

    correct_labels = np.array(correct_labels)
    predicted_labels = np.array(predicted_labels)

    return eval_loss, correct_labels, predicted_labels



In [10]:

def train(year=2023, model_name="scibert"):
    print("model name", model_name)
    train_texts = []
    dev_texts = []
    train_labels = []
    dev_labels = []
    data_year_dir = join(DATA_TRACE_DIR, "PST")
    print("data_year_dir", data_year_dir)

    with open(join(data_year_dir, "bib_context_train.txt"), "r", encoding="utf-8") as f:
        for line in f:
            train_texts.append(line.strip())
    with open(join(data_year_dir, "bib_context_valid.txt"), "r", encoding="utf-8") as f:
        for line in f:
            dev_texts.append(line.strip())

    with open(join(data_year_dir, "bib_context_train_label.txt"), "r", encoding="utf-8") as f:
        for line in f:
            train_labels.append(int(line.strip()))
    with open(join(data_year_dir, "bib_context_valid_label.txt"), "r", encoding="utf-8") as f:
        for line in f:
            dev_labels.append(int(line.strip()))


    print("Train size:", len(train_texts))
    print("Dev size:", len(dev_texts))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    class_weight = len(train_labels) / (2 * np.bincount(train_labels))
    class_weight = torch.Tensor(class_weight).to(device)
    print("Class weight:", class_weight)

    if model_name == "bert":
        BERT_MODEL = "bert-base-uncased"
    elif model_name == "scibert":
        BERT_MODEL = "allenai/scibert_scivocab_uncased"
    else:
        raise NotImplementedError
    tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL)

    model = BertForSequenceClassification.from_pretrained(BERT_MODEL, num_labels = 2)
    model.to(device)

    criterion = torch.nn.CrossEntropyLoss(weight=class_weight)

    '''
    ##### Sampling start
    import random

    # # Set your desired sample size
    SAMPLE_SIZE = 100

    # # Randomly select a subset of your data
    train_texts_sample = random.sample(train_texts, SAMPLE_SIZE) # train_texts sampling instead
    train_labels_sample = random.sample(train_labels, SAMPLE_SIZE)

    train_features = convert_examples_to_inputs(train_texts_sample, train_labels_sample, MAX_SEQ_LENGTH, tokenizer, verbose=0)
    dev_features = convert_examples_to_inputs(dev_texts, dev_labels, MAX_SEQ_LENGTH, tokenizer)

    BATCH_SIZE = 16
    train_dataloader = get_data_loader(train_features, MAX_SEQ_LENGTH, BATCH_SIZE, shuffle=True)
    dev_dataloader = get_data_loader(dev_features, MAX_SEQ_LENGTH, BATCH_SIZE, shuffle=False)

    #### Sampling end
    '''
    # train_features_sample = convert_examples_to_inputs(train_texts_sample, train_labels_sample, MAX_SEQ_LENGTH, tokenizer, verbose=0)
    # train_dataloader_sample = get_data_loader(train_features_sample, MAX_SEQ_LENGTH, BATCH_SIZE, shuffle=True)

    ### OLD CODE:

    train_features = convert_examples_to_inputs(train_texts, train_labels, MAX_SEQ_LENGTH, tokenizer, verbose=0)
    dev_features = convert_examples_to_inputs(dev_texts, dev_labels, MAX_SEQ_LENGTH, tokenizer)

    BATCH_SIZE = 16
    train_dataloader = get_data_loader(train_features, MAX_SEQ_LENGTH, BATCH_SIZE, shuffle=True)
    dev_dataloader = get_data_loader(dev_features, MAX_SEQ_LENGTH, BATCH_SIZE, shuffle=False)

    ####

    GRADIENT_ACCUMULATION_STEPS = 1
    NUM_TRAIN_EPOCHS = 20
    LEARNING_RATE = 5e-5
    WARMUP_PROPORTION = 0.1
    MAX_GRAD_NORM = 5

    num_train_steps = int(len(train_dataloader.dataset) / BATCH_SIZE / GRADIENT_ACCUMULATION_STEPS * NUM_TRAIN_EPOCHS)
    num_warmup_steps = int(WARMUP_PROPORTION * num_train_steps)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=LEARNING_RATE, correct_bias=False)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_train_steps)

    OUTPUT_DIR = join(OUT_DIR, "kddcup", model_name)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    MODEL_FILE_NAME = "pytorch_model.bin"
    PATIENCE = 5

    loss_history = []
    no_improvement = 0
    for _ in trange(int(NUM_TRAIN_EPOCHS), desc="Epoch"):
        model.train()
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(tqdm(train_dataloader, desc="Training iteration")):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch

            outputs = model(input_ids, attention_mask=input_mask, token_type_ids=segment_ids, labels=label_ids)
            # loss = outputs[0]
            logits = outputs[1]

            loss = criterion(logits, label_ids)

            if GRADIENT_ACCUMULATION_STEPS > 1:
                loss = loss / GRADIENT_ACCUMULATION_STEPS

            loss.backward()
            tr_loss += loss.item()

            if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)

                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

        dev_loss, _, _ = evaluate(model, dev_dataloader, device, criterion)

        print("Loss history:", loss_history)
        print("Dev loss:", dev_loss)

        if len(loss_history) == 0 or dev_loss < min(loss_history):
            no_improvement = 0
            model_to_save = model.module if hasattr(model, 'module') else model
            output_model_file = os.path.join(OUTPUT_DIR, MODEL_FILE_NAME)
            torch.save(model_to_save.state_dict(), output_model_file)
        else:
            no_improvement += 1

        if no_improvement >= PATIENCE:
            print("No improvement on development set. Finish training.")
            break

        loss_history.append(dev_loss)


In [11]:

def eval_test_papers_bert(year=2023, model_name="scibert"):
    print("model name", model_name)
    np.random.seed()
    torch.manual_seed(0)
    data_dir = join(DATA_TRACE_DIR, "PST")
    papers_test = load_json(data_dir, "paper_source_trace_train_ans.json")
    pids_test = {p["_id"] for p in papers_test}

    in_dir = join(data_dir, "paper-xml")
    files = []
    for f in os.listdir(in_dir):
        cur_pid = f.split(".")[0]
        if f.endswith(".xml") and cur_pid in pids_test:
            files.append(f)

    truths = papers_test
    pid_to_source_titles = dd(list)
    for paper in tqdm(truths):
        pid = paper["_id"]
        for ref in paper["refs_trace"]:
            pid_to_source_titles[pid].append(ref["title"].lower())

    if model_name == "bert":
        BERT_MODEL = "bert-base-uncased"
    elif model_name == "scibert":
        BERT_MODEL = "allenai/scibert_scivocab_uncased"
    else:
        raise NotImplementedError
    tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device", device)
    torch.cuda.seed()
    model = BertForSequenceClassification.from_pretrained(BERT_MODEL, num_labels = 2)
    model.load_state_dict(torch.load(join(OUT_DIR, "kddcup", model_name, "pytorch_model.bin")))
    model.to(device)
    model.eval()

    BATCH_SIZE = 16
    metrics = []
    f_idx = 0

    xml_dir = join(data_dir, "paper-xml")

    for paper in tqdm(papers_test):
        cur_pid = paper["_id"]
        file = join(xml_dir, cur_pid + ".xml")
        f = open(file, encoding='utf-8')

        xml = f.read()
        bs = BeautifulSoup(xml, "xml")
        f.close()

        source_titles = pid_to_source_titles[cur_pid]
        if len(source_titles) == 0:
            continue

        references = bs.find_all("biblStruct")
        bid_to_title = {}
        n_refs = 0
        for ref in references:
            if "xml:id" not in ref.attrs:
                continue
            bid = ref.attrs["xml:id"]
            if ref.analytic is None:
                continue
            if ref.analytic.title is None:
                continue
            bid_to_title[bid] = ref.analytic.title.text.lower()
            b_idx = int(bid[1:]) + 1
            if b_idx > n_refs:
                n_refs = b_idx

        bib_to_contexts = find_bib_context(xml)
        bib_sorted = sorted(bib_to_contexts.keys())

        for bib in bib_sorted:
            cur_bib_idx = int(bib[1:])
            if cur_bib_idx + 1 > n_refs:
                n_refs = cur_bib_idx + 1

        y_true = [0] * n_refs
        y_score = [0] * n_refs

        flag = False
        for bid in bid_to_title:
            cur_ref_title = bid_to_title[bid]
            for label_title in source_titles:
                if fuzz.ratio(cur_ref_title, label_title) >= 80:
                    flag = True
                    b_idx = int(bid[1:])
                    y_true[b_idx] = 1

        if not flag:
            continue

        contexts_sorted = [" ".join(bib_to_contexts[bib]) for bib in bib_sorted]

        test_features = convert_examples_to_inputs(contexts_sorted, y_score, MAX_SEQ_LENGTH, tokenizer)
        test_dataloader = get_data_loader(test_features, MAX_SEQ_LENGTH, BATCH_SIZE, shuffle=False)

        predicted_scores = []
        num_votes = 2

        for step, batch in enumerate(test_dataloader):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch

            with torch.no_grad():
                r = model(input_ids, attention_mask=input_mask,
                                            token_type_ids=segment_ids, labels=label_ids)
                tmp_eval_loss = r[0]
                logits = r[1]

            cur_pred_scores = logits[:, 1].to('cpu').numpy()
            predicted_scores.extend(cur_pred_scores)
        try:
            for ii in range(len(predicted_scores)):
                bib_idx = int(bib_sorted[ii][1:])
                # print("bib_idx", bib_idx)
                y_score[bib_idx] = predicted_scores[ii]
        except IndexError as e:
            metrics.append(0)
            continue
        print()
        print(pid)
        print(y_true)
        print(y_score)

        cur_map = average_precision_score(y_true, y_score)
        print(cur_map)
        metrics.append(cur_map)
        f_idx += 1
        if f_idx % 20 == 0:
            print("map until now", np.mean(metrics), len(metrics), cur_map)

    print("bert average map", np.mean(metrics), len(metrics))
# eval_test_papers_bert(model_name="scibert")

# 5d64ff713a55acf547f20de0
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1]
# [-0.4278647, 0.90334725, -0.42784128, -0.4278638, -0.427869, -0.42787018, -0.4278634, -0.42786795, -0.4278641, -0.4278637, -0.42786762, -0.42786834, -0.42786673, 0, -0.42786798, -0.42786488, -0.42786017, -0.42786357, -0.4278686, -0.4278579, -0.42778033, -0.4278597, 0.93695736, -0.42786908, -0.42786768, -0.42785612, -0.42786503, -0.42785165, -0.42786703, -0.4278665, -0.4278632, -0.42786083, -0.42785737, -0.42786503, -0.42786512, -0.42785725]
# 0.1369047619047619
#   5%|▍         | 38/788 [00:52<20:40,  1.65s/it]
# 5d64ff713a55acf547f20de0
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
# [-0.4278641, -0.4278689, -0.4278667, -0.4278675, -0.42786357, -0.42786267, 0.9295018, -0.42786026, -0.42786428, -0.42786786, -0.42786148, -0.4278604, -0.42786232, -0.42785642, -0.42786238, -0.4278625, -0.4278623, -0.4278668, -0.4278668, -0.42786846, -0.42786422, -0.42786428, -0.42786404, -0.4278691, -0.42786577, -0.42786565, -0.42786655, -0.4278649, -0.42787036, -0.42785692, -0.42786875, -0.42786264, -0.4278636, -0.42785904, -0.42785874, -0.42785788, -0.42785254, -0.42786062, -0.42786333, -0.4278675, -0.42786503, -0.42786565, -0.42786527]
# 0.025
#   5%|▍         | 39/788 [00:54<20:54,  1.68s/it]
# 5d64ff713a55acf547f20de0
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
# [-0.42786792, -0.4278617, -0.42786357, -0.42786372, -0.42786276, -0.42786008, -0.42786318, -0.42784268, -0.4278629, -0.4278631, -0.42786226, -0.42786106, -0.4278614, -0.42786056, 0, -0.42786267, -0.42786226, -0.42786205, -0.427862, -0.4278546, -0.42786208, -0.42785814, -0.42785802, -0.4278607, -0.42786166, -0.42786142, 0.9328686, -0.42786363, -0.4278657, -0.42786095, -0.42786294, -0.42786387, -0.42786157, -0.42786455, -0.42786342, -0.42786512, -0.42786393, -0.42786592, -0.4278669, -0.42786554, -0.42786416, -0.42786455, -0.42786884, -0.42786983, -0.42786354, -0.42786333]
# 0.03125
#   5%|▌         | 40/788 [00:55<19:32,  1.57s/it]
# 5d64ff713a55acf547f20de0
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]
# [-0.42786705, -0.42785317, -0.42786416, -0.42786226, -0.42784348, -0.42786118, -0.42786083, -0.42786524, -0.42786404, -0.4278606, -0.42786357, -0.42786172, -0.42786714, -0.427869, -0.42784533, -0.42785648, -0.42786285, -0.42786732, -0.4278643, -0.4278582, -0.4278653, -0.42786318, -0.42786482, -0.4278584, -0.42786312, -0.4278637, -0.42786145, -0.42786247, -0.42786458, -0.4278674, -0.42786643, -0.4278582, -0.42786565, -0.42786163, -0.4278613, 0.9137402, -0.4278613, -0.42786562, -0.42786422]
# 0.6428571428571428
# map until now 0.30086186532183345 40 0.6428571428571428
#   5%|▌         | 41/788 [00:57<20:42,  1.66s/it]
# 5d64ff713a55acf547f20de0
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
# [-0.42786184, -0.42785436, -0.42786497, -0.4278639, -0.42786363, -0.4278668, -0.42785478, -0.4278653, -0.4278662, -0.4278629, -0.42786625, -0.4278662, -0.42784128, -0.42786628, -0.42786375, 0.9282429, -0.42786157, -0.42786124, -0.42786247, -0.42785513, -0.42784622, -0.42786673, -0.42786387, -0.42786548, -0.42786312, -0.4278602, -0.42786297, -0.42785898, -0.4278649, -0.42786685, -0.42786133, 0.9111156, 0.9354466, -0.42786112, -0.42785728, -0.42786348, -0.42786467, -0.42785886, -0.42777646, -0.42783096, -0.42785648, 0, -0.4278538, -0.4278632, -0.42786378, -0.42786476]
# 0.021739130434782608

In [12]:

def gen_kddcup_valid_submission_bert(model_name="scibert", num_votes=3):
    print("model name", model_name)
    data_dir = join(DATA_TRACE_DIR, "PST")
    papers = load_json(data_dir, "paper_source_trace_valid_wo_ans.json")

    if model_name == "bert":
        BERT_MODEL = "bert-base-uncased"
    elif model_name == "scibert":
        BERT_MODEL = "allenai/scibert_scivocab_uncased"
    else:
        raise NotImplementedError
    tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL)

    sub_example_dict = load_json(data_dir, "submission_example_valid.json")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device", device)
    model = BertForSequenceClassification.from_pretrained(BERT_MODEL, num_labels = 2)
    model.load_state_dict(torch.load(join(OUT_DIR, "kddcup", model_name, "pytorch_model.bin")))

    model.to(device)
    model.eval()

    BATCH_SIZE = 16
    # metrics = []
    # f_idx = 0

    xml_dir = join(data_dir, "paper-xml")
    sub_dict = {}

    for paper in tqdm(papers):
        cur_pid = paper["_id"]
        file = join(xml_dir, cur_pid + ".xml")
        f = open(file, encoding='utf-8')
        xml = f.read()
        bs = BeautifulSoup(xml, "xml")
        f.close()

        references = bs.find_all("biblStruct")
        bid_to_title = {}
        n_refs = 0
        for ref in references:
            if "xml:id" not in ref.attrs:
                continue
            bid = ref.attrs["xml:id"]
            if ref.analytic is None:
                continue
            if ref.analytic.title is None:
                continue
            bid_to_title[bid] = ref.analytic.title.text.lower()
            b_idx = int(bid[1:]) + 1
            if b_idx > n_refs:
                n_refs = b_idx

        bib_to_contexts = find_bib_context(xml)
        # bib_sorted = sorted(bib_to_contexts.keys())
        bib_sorted = ["b" + str(ii) for ii in range(n_refs)]

        y_score = [0] * n_refs

        assert len(sub_example_dict[cur_pid]) == n_refs
        # continue

        contexts_sorted = [" ".join(bib_to_contexts[bib]) for bib in bib_sorted]

        test_features = convert_examples_to_inputs(contexts_sorted, y_score, MAX_SEQ_LENGTH, tokenizer)
        test_dataloader = get_data_loader(test_features, MAX_SEQ_LENGTH, BATCH_SIZE, shuffle=False)

        predicted_scores = []
        for step, batch in enumerate(test_dataloader):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch

            with torch.no_grad():
                r = model(input_ids, attention_mask=input_mask,
                                            token_type_ids=segment_ids, labels=label_ids)
                tmp_eval_loss = r[0]
                logits = r[1]

            cur_pred_scores = logits[:, 1].to('cpu').numpy()
            predicted_scores.extend(cur_pred_scores)

        for ii in range(len(predicted_scores)):
            bib_idx = int(bib_sorted[ii][1:])
            # print("bib_idx", bib_idx)
            y_score[bib_idx] = float(sigmoid(predicted_scores[ii]))

        sub_dict[cur_pid] = y_score
    dump_json(sub_dict, join(OUT_DIR, "kddcup", model_name), "valid_submission_scibert.json")




In [None]:
    prepare_bert_input()
    train(model_name="scibert")
    eval_test_papers_bert(model_name="scibert")
    # gen_kddcup_valid_submission_bert(model_name="scibert")




100%|██████████| 788/788 [00:00<00:00, 468345.13it/s]
100%|██████████| 788/788 [10:23<00:00,  1.26it/s]


len(x_train) 7634 len(x_valid) 4037
model name scibert
data_year_dir /content/cs145-pst/data/PST
Train size: 7634
Dev size: 4037
Class weight: tensor([0.5500, 5.5000], device='cuda:0')


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at allenai/scibert_scivocab_uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch:   0%|          | 0/20 [00:00<?, ?it/s]
Training iteration:   0%|          | 0/478 [00:00<?, ?it/s][A
Training iteration:   0%|          | 1/478 [00:01<13:09,  1.65s/it][A
Training iteration:   0%|          | 2/478 [00:02<08:49,  1.11s/it][A
Training iteration:   1%|    

Loss history: []
Dev loss: 1.2533905156354306


Epoch:   5%|▌         | 1/20 [07:28<2:22:09, 448.90s/it]
Training iteration:   0%|          | 0/478 [00:00<?, ?it/s][A
Training iteration:   0%|          | 1/478 [00:00<06:00,  1.32it/s][A
Training iteration:   0%|          | 2/478 [00:01<06:10,  1.28it/s][A
Training iteration:   1%|          | 3/478 [00:02<06:09,  1.28it/s][A
Training iteration:   1%|          | 4/478 [00:03<06:08,  1.29it/s][A
Training iteration:   1%|          | 5/478 [00:03<06:09,  1.28it/s][A
Training iteration:   1%|▏         | 6/478 [00:04<06:09,  1.28it/s][A
Training iteration:   1%|▏         | 7/478 [00:05<06:08,  1.28it/s][A
Training iteration:   2%|▏         | 8/478 [00:06<06:07,  1.28it/s][A
Training iteration:   2%|▏         | 9/478 [00:07<06:07,  1.28it/s][A
Training iteration:   2%|▏         | 10/478 [00:07<06:07,  1.27it/s][A
Training iteration:   2%|▏         | 11/478 [00:08<06:05,  1.28it/s][A
Training iteration:   3%|▎         | 12/478 [00:09<06:03,  1.28it/s][A
Training iteration:   3%|

Loss history: [1.2533905156354306]
Dev loss: 0.9472815813901632


Epoch:  10%|█         | 2/20 [15:00<2:15:12, 450.68s/it]
Training iteration:   0%|          | 0/478 [00:00<?, ?it/s][A
Training iteration:   0%|          | 1/478 [00:00<05:59,  1.33it/s][A
Training iteration:   0%|          | 2/478 [00:01<06:09,  1.29it/s][A
Training iteration:   1%|          | 3/478 [00:02<06:10,  1.28it/s][A
Training iteration:   1%|          | 4/478 [00:03<06:07,  1.29it/s][A
Training iteration:   1%|          | 5/478 [00:03<06:06,  1.29it/s][A
Training iteration:   1%|▏         | 6/478 [00:04<06:06,  1.29it/s][A
Training iteration:   1%|▏         | 7/478 [00:05<06:08,  1.28it/s][A
Training iteration:   2%|▏         | 8/478 [00:06<06:06,  1.28it/s][A
Training iteration:   2%|▏         | 9/478 [00:07<06:06,  1.28it/s][A
Training iteration:   2%|▏         | 10/478 [00:07<06:07,  1.27it/s][A
Training iteration:   2%|▏         | 11/478 [00:08<06:06,  1.27it/s][A
Training iteration:   3%|▎         | 12/478 [00:09<06:05,  1.27it/s][A
Training iteration:   3%|

Loss history: [1.2533905156354306, 0.9472815813901632]
Dev loss: 1.3025968597958917



Training iteration:   0%|          | 0/478 [00:00<?, ?it/s][A
Training iteration:   0%|          | 1/478 [00:00<06:08,  1.30it/s][A
Training iteration:   0%|          | 2/478 [00:01<06:10,  1.29it/s][A
Training iteration:   1%|          | 3/478 [00:02<06:09,  1.28it/s][A
Training iteration:   1%|          | 4/478 [00:03<06:08,  1.29it/s][A
Training iteration:   1%|          | 5/478 [00:03<06:07,  1.29it/s][A
Training iteration:   1%|▏         | 6/478 [00:04<06:09,  1.28it/s][A
Training iteration:   1%|▏         | 7/478 [00:05<06:09,  1.27it/s][A
Training iteration:   2%|▏         | 8/478 [00:06<06:07,  1.28it/s][A
Training iteration:   2%|▏         | 9/478 [00:07<06:06,  1.28it/s][A
Training iteration:   2%|▏         | 10/478 [00:07<06:05,  1.28it/s][A
Training iteration:   2%|▏         | 11/478 [00:08<06:04,  1.28it/s][A
Training iteration:   3%|▎         | 12/478 [00:09<06:04,  1.28it/s][A
Training iteration:   3%|▎         | 13/478 [00:10<06:03,  1.28it/s][A
Training 

Loss history: [1.2533905156354306, 0.9472815813901632, 1.3025968597958917]
Dev loss: 1.421339905375372



Training iteration:   0%|          | 0/478 [00:00<?, ?it/s][A
Training iteration:   0%|          | 1/478 [00:00<06:00,  1.32it/s][A
Training iteration:   0%|          | 2/478 [00:01<06:09,  1.29it/s][A
Training iteration:   1%|          | 3/478 [00:02<06:10,  1.28it/s][A
Training iteration:   1%|          | 4/478 [00:03<06:10,  1.28it/s][A
Training iteration:   1%|          | 5/478 [00:03<06:08,  1.28it/s][A
Training iteration:   1%|▏         | 6/478 [00:04<06:09,  1.28it/s][A
Training iteration:   1%|▏         | 7/478 [00:05<06:07,  1.28it/s][A
Training iteration:   2%|▏         | 8/478 [00:06<06:07,  1.28it/s][A
Training iteration:   2%|▏         | 9/478 [00:07<06:04,  1.29it/s][A
Training iteration:   2%|▏         | 10/478 [00:07<06:06,  1.28it/s][A
Training iteration:   2%|▏         | 11/478 [00:08<06:05,  1.28it/s][A
Training iteration:   3%|▎         | 12/478 [00:09<06:05,  1.28it/s][A
Training iteration:   3%|▎         | 13/478 [00:10<06:04,  1.28it/s][A
Training 

Loss history: [1.2533905156354306, 0.9472815813901632, 1.3025968597958917, 1.421339905375372]
Dev loss: 1.3098287425361015



Training iteration:   0%|          | 0/478 [00:00<?, ?it/s][A
Training iteration:   0%|          | 1/478 [00:00<06:06,  1.30it/s][A
Training iteration:   0%|          | 2/478 [00:01<06:10,  1.29it/s][A
Training iteration:   1%|          | 3/478 [00:02<06:10,  1.28it/s][A
Training iteration:   1%|          | 4/478 [00:03<06:10,  1.28it/s][A
Training iteration:   1%|          | 5/478 [00:03<06:08,  1.28it/s][A
Training iteration:   1%|▏         | 6/478 [00:04<06:08,  1.28it/s][A
Training iteration:   1%|▏         | 7/478 [00:05<06:08,  1.28it/s][A
Training iteration:   2%|▏         | 8/478 [00:06<06:08,  1.28it/s][A
Training iteration:   2%|▏         | 9/478 [00:07<06:09,  1.27it/s][A
Training iteration:   2%|▏         | 10/478 [00:07<06:09,  1.27it/s][A
Training iteration:   2%|▏         | 11/478 [00:08<06:08,  1.27it/s][A
Training iteration:   3%|▎         | 12/478 [00:09<06:06,  1.27it/s][A
Training iteration:   3%|▎         | 13/478 [00:10<06:05,  1.27it/s][A
Training 

Loss history: [1.2533905156354306, 0.9472815813901632, 1.3025968597958917, 1.421339905375372, 1.3098287425361015]
Dev loss: 1.3805538772461796



Training iteration:   0%|          | 0/478 [00:00<?, ?it/s][A
Training iteration:   0%|          | 1/478 [00:00<06:04,  1.31it/s][A
Training iteration:   0%|          | 2/478 [00:01<06:08,  1.29it/s][A
Training iteration:   1%|          | 3/478 [00:02<06:08,  1.29it/s][A
Training iteration:   1%|          | 4/478 [00:03<06:08,  1.29it/s][A
Training iteration:   1%|          | 5/478 [00:03<06:07,  1.29it/s][A
Training iteration:   1%|▏         | 6/478 [00:04<06:09,  1.28it/s][A
Training iteration:   1%|▏         | 7/478 [00:05<06:08,  1.28it/s][A
Training iteration:   2%|▏         | 8/478 [00:06<06:07,  1.28it/s][A
Training iteration:   2%|▏         | 9/478 [00:07<06:07,  1.28it/s][A
Training iteration:   2%|▏         | 10/478 [00:07<06:07,  1.28it/s][A
Training iteration:   2%|▏         | 11/478 [00:08<06:05,  1.28it/s][A
Training iteration:   3%|▎         | 12/478 [00:09<06:05,  1.28it/s][A
Training iteration:   3%|▎         | 13/478 [00:10<06:04,  1.28it/s][A
Training 