# Baseline
python: 3.8.*

use ```Ctrl + ]``` to collapse all section :)

Download our starter pack (3~5 min)

In [1]:
!gdown --folder 1T6jpOtdf_i6XNYA6F_lqU4mRRh1xYPcl
!mv baseline/* ./

Retrieving folder list
Retrieving folder 15cJ_K2Tm95UkHgnvvPdMqGQy0_At6lRI checkpoints
Retrieving folder 1jIWaLk4VLXEEm7au12mH80bKqFS3-ZKk claim_verification
Retrieving folder 1DVfRginiE0thVjKELz7DCsEMYltRaTDG e20_bs32_7e-05_top5
Processing file 1AGMfElEghEnYWLYU0YURZrBzmskYuovm val_acc=0.4259_model.750.pt
Retrieving folder 19-pRgA1elB6U2UklW6DmUUm9R4qraAu8 sent_retrieval
Retrieving folder 15zhuj7t3s5vG4JtFS_PNRTtE5l9mv0nZ e1_bs64_2e-05_neg0.03_top5
Processing file 1PBBfgu9lV8_hHdmZRT99bAk_Zl51MZcr model.50.pt
Retrieving folder 1r742uPeVnqUm04qUEZpzzpxXZFbj4gNd data
Processing file 1hBMys30E2Tw-QCt8FDKKnu-eF_G1tGio dev_doc5sent5.jsonl
Processing file 1iHSNpooXDLurizn9sxUbY1Fpk-h9La3i dev_doc5sent5.pkl
Processing file 1a0e__D64L8CWhZsvuYZZuXiOZ7a99Bnj hanlp_con_results.pkl
Processing file 1BFeoxCm5n0fauW9rbNqy6Se9dK9RnbEJ hanlp_con_test_results.pkl
Processing file 1rTvOEQAZG7Hs5aVlTvpnZSW89JWNTUzK public_test.jsonl
Processing file 1TbIsMs71WZP2kRpgnn2U383mPavu1G-b public_train.jsonl
Pro

In [None]:
# %pip install -r requirements.txt

notebook1
## PART 1. Document retrieval

Prepare the environment and import all library we need

In [1]:
# built-in libs
import json
import pickle
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Set, Tuple, Union
from tqdm import tqdm
# 3rd party libs
import hanlp
import opencc
import pandas as pd
import wikipedia
from hanlp.components.pipeline import Pipeline
from pandarallel import pandarallel
from transformers import AutoTokenizer,AutoModel
import torch
# our own libs
from utils import (
    generate_evidence_to_wiki_pages_mapping,
    jsonl_dir_to_df,
    load_json,
    load_model,
    save_checkpoint,
    set_lr_scheduler,
)
from torch.multiprocessing import Pool, Process, set_start_method

pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=5)
wikipedia.set_lang("zh")



Data, Model setting

In [2]:
TRAIN_DATA = load_json("data/public_train_all.jsonl")
TEST_DATA = load_json("data/private/public_private_combine_test_data.jsonl")
CONVERTER_T2S = opencc.OpenCC("t2s.json")
CONVERTER_S2T = opencc.OpenCC("s2t.json")

simcse_model=AutoModel.from_pretrained("IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese")
simcse_tok=AutoTokenizer.from_pretrained("IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese")
device = torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cpu")

simcse_model.to(device)


Some weights of the model checkpoint at IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese were not used when initializing BertModel: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(21128, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [3]:
wiki_pages = jsonl_dir_to_df("data/wiki-pages")
mapping = generate_evidence_to_wiki_pages_mapping(wiki_pages)
del wiki_pages

Reading and concatenating jsonl files in data/wiki-pages
Generate parse mapping


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=39592), Label(value='0 / 39592')))…

Transform to id to evidence_map mapping


Data class for type hinting

In [4]:
@dataclass
class Claim:
    data: str

@dataclass
class AnnotationID:
    id: int

@dataclass
class EvidenceID:
    id: int

@dataclass
class PageTitle:
    title: str

@dataclass
class SentenceID:
    id: int

@dataclass
class Evidence:
    data: List[List[Tuple[AnnotationID, EvidenceID, PageTitle, SentenceID]]]

### Helper function

For the sake of consistency, we convert traditional to simplified Chinese first before converting it back to traditional Chinese.  This is due to some errors occuring when converting traditional to traditional Chinese.

In [5]:
def do_st_corrections(text: str) -> str:
    simplified = CONVERTER_T2S.convert(text)

    return CONVERTER_S2T.convert(simplified)

We use constituency parsing to separate part of speeches or so called constituent to extract noun phrases.  In the later stages, we will use the noun phrases as the query to search for relevant documents.  

In [6]:
def get_nps_hanlp(
    predictor: Pipeline,
    d: Dict[str, Union[int, Claim, Evidence]],
) -> List[str]:
    claim = d["claim"]
    tree = predictor(CONVERTER_T2S.convert(claim))["con"]
    nps = [
        do_st_corrections("".join(subtree.leaves()))
        for subtree in tree.subtrees(lambda t: ("NP" in t.label() or "obj" in t.label() or "IP" in t.label()))
    ]

    return nps

Precision refers to how many related documents are retrieved.  Recall refers to how many relevant documents are retrieved.  

In [7]:
def calculate_precision(
    data: List[Dict[str, Union[int, Claim, Evidence]]],
    predictions: pd.Series,
) -> None:
    precision = 0
    count = 0

    for i, d in enumerate(data):
        if d["label"] == "NOT ENOUGH INFO":
            continue

        # Extract all ground truth of titles of the wikipedia pages
        # evidence[2] refers to the title of the wikipedia page
        gt_pages = set([
            evidence[2]
            for evidence_set in d["evidence"]
            for evidence in evidence_set
        ])

        predicted_pages = set(predictions.iloc[i][:10])
        hits = predicted_pages.intersection(gt_pages)
        # print(predicted_pages,hits)
        if len(predicted_pages) != 0:
            precision += len(hits) / len(predicted_pages)
            # precision += len(hits) / 5

        count += 1

    print(f"Precision: {precision / count}")


def calculate_recall(
    data: List[Dict[str, Union[int, Claim, Evidence]]],
    predictions: pd.Series,
) -> None:
    recall = 0
    count = 0

    for i, d in enumerate(data):
        if d["label"] == "NOT ENOUGH INFO":
            continue

        gt_pages = set([
            evidence[2]
            for evidence_set in d["evidence"]
            for evidence in evidence_set
        ])
        predicted_pages = set(predictions.iloc[i][:10])

        hits = predicted_pages.intersection(gt_pages)

        recall += len(hits) / len(gt_pages)
        count += 1
  
    print(f"Recall: {recall / count}")

The default amount of documents retrieved is at most five documents.  This `num_pred_doc` can be adjusted based on your objective.  Save data in jsonl format.

In [8]:
def save_doc(
    data: List[Dict[str, Union[int, Claim, Evidence]]],
    predictions: pd.Series,
    mode: str = "train",
    num_pred_doc: int = 5,
) -> None:
    with open(
        f"data/{mode}_doc{num_pred_doc}.jsonl",
        "w",
        encoding="utf8",
    ) as f:
        for i, d in enumerate(data):
            d["predicted_pages"] = list(predictions.iloc[i])
            f.write(json.dumps(d, ensure_ascii=False) + "\n")

### Main function for document retrieval

In [9]:
def get_pred_pages(series_data):
    nps = series_data["hanlp_results"]


    wiki_search_results = [
        do_st_corrections(j) for w in nps for j in wikipedia.search(w)
    ]
    wiki_search_results = list(set(wiki_search_results))
    tmp_wiki_search_results=[]
    sentence=[]
    for page in wiki_search_results:
        try:
            context = mapping[page.replace(" ","_").replace("-", "")].values()
            tt=[]
            for setn in context:
                if "|" not in setn or "=" not in setn :
                    tt.append(setn)
            sentence.append(" ".join(tt))
            tmp_wiki_search_results.append(page)
        except:
            continue
    wiki_search_results = tmp_wiki_search_results
    return [wiki_search_results,sentence]

### Step 1. Get noun phrases from hanlp consituency parsing tree

Setup [HanLP](https://github.com/hankcs/HanLP) predictor (1 min)

In [10]:
predictor = (hanlp.pipeline().append(
    hanlp.load("MSR_TOK_ELECTRA_BASE_CRF"),
    output_key="tok",
).append(
    hanlp.load("CTB9_CON_FULL_TAG_ERNIE_GRAM"),
    output_key="con",
    input_key="tok",
))

Building model [5m[33m...[0m[0m

KeyboardInterrupt: 

We will skip this process which for creating parsing tree when demo on class

In [None]:
hanlp_file = f"data/hanlp_con_results.pkl"
if Path(hanlp_file).exists():
    with open(hanlp_file, "rb") as f:
        hanlp_results = pickle.load(f)
else:
    hanlp_results = [get_nps_hanlp(predictor, d) for d in TRAIN_DATA]
    with open(hanlp_file, "wb") as f:
        pickle.dump(hanlp_results, f)

Get pages via wiki online api

In [11]:
doc_path = f"data/train_doc5_all_page_other_predictor.jsonl"
if Path(doc_path).exists():
    train_df = pd.DataFrame(load_json(doc_path))
else:
    train_df = pd.DataFrame(TRAIN_DATA)
    train_df.loc[:, "hanlp_results"] = hanlp_results
    result = train_df.parallel_apply(get_pred_pages, axis=1)
    predicted_results=[]
    predicted_sentence=[]
    for elmt in result:
        predicted_results.append(elmt[0])
        predicted_sentence.append(elmt[1])
    train_df.loc[:,"all_page"]=predicted_results
    train_df.loc[:,"all_sentences"]=predicted_sentence
    train_df[["id","label", "claim","evidence","hanlp_results", "all_page","all_sentences"]].to_json(
        "data/train_doc5_all_page_other_predictor.jsonl",
        orient="records",
        lines=True,
        force_ascii=False,
    )


In [13]:
cos = torch.nn.CosineSimilarity(dim=1)
simcse_model.eval()
results=[]
batch=16
for i, elmt in tqdm(train_df.iterrows()):
    page_score={}
    tmp=[]
    for title in elmt['all_page']:
        if title in elmt["claim"] or title.replace('·',"") in elmt["claim"]:
            tmp.append(title)
    claim_input=simcse_tok(CONVERTER_T2S.convert(elmt["claim"]),return_tensors="pt").to(device)
    claim_output = simcse_model(**claim_input).pooler_output
    for i in range(0,len(elmt['all_page']),batch):

        branch_1_term_label=[]
        branch_1_term=[]
        branch_1_sentence=[]

        branch_2_term_label=[]
        branch_2_term_sentence=[]
        for term,sentence in zip(elmt['all_page'][i:i+batch],elmt['all_sentences'][i:i+batch]):
            if term in tmp:
                branch_1_term_label.append(term)
                branch_1_term.append(simcse_tok(CONVERTER_T2S.convert(term),max_length=256,truncation=True)["input_ids"])
                branch_1_sentence.append(simcse_tok(CONVERTER_T2S.convert(sentence),max_length=256,truncation=True)["input_ids"])
            else:
                branch_2_term_label.append(term)
                branch_2_term_sentence.append(simcse_tok(CONVERTER_T2S.convert(term+" : "+sentence),max_length=256,truncation=True)["input_ids"])
        if branch_1_term_label!=[]:
            term_max_len= max(map(len,branch_1_term))
            sentences_max_len= max(map(len,branch_1_sentence))
            term_pad=[]
            sentence_pad=[]
            for i in range(len(branch_1_term)):
                if len(branch_1_term[i])<term_max_len:
                    term_pad.append(branch_1_term[i]+[simcse_tok.pad_token_id]*(term_max_len-len(branch_1_term[i])))
                else:
                    term_pad.append(branch_1_term[i])

                if len(branch_1_sentence[i])<sentences_max_len:
                    sentence_pad.append(branch_1_sentence[i]+[simcse_tok.pad_token_id]*(sentences_max_len-len(branch_1_sentence[i])))
                else:
                    sentence_pad.append(branch_1_sentence[i])
            term_pad = torch.tensor(term_pad).to(device)
            term_attention_mask = (term_pad!=simcse_tok.pad_token_id).to(device)

            sentence_pad = torch.tensor(sentence_pad).to(device)
            sen_attention_mask = (sentence_pad!=simcse_tok.pad_token_id).to(device)
            term_output = simcse_model(
                input_ids = term_pad,
                attention_mask = term_attention_mask
            ).pooler_output
            sen_output = simcse_model(
                input_ids = sentence_pad,
                attention_mask = sen_attention_mask
            ).pooler_output
            score = ((cos(claim_output,term_output)+cos(claim_output,sen_output)*1.3)/2).tolist()
            for i,t in enumerate(branch_1_term_label):
                page_score[t]=score[i]

        if branch_2_term_label!=[]:
            combine_pad=[]
            combine_max_len= max(map(len,branch_2_term_sentence))
            for i in range(len(branch_2_term_sentence)):
                if len(branch_2_term_sentence[i])<combine_max_len:
                    combine_pad.append(branch_2_term_sentence[i]+[simcse_tok.pad_token_id]*(combine_max_len-len(branch_2_term_sentence[i])))
                else:
                    combine_pad.append(branch_2_term_sentence[i])
            combine_pad = torch.tensor(combine_pad).to(device)
            combine_attention_mask = (combine_pad!=simcse_tok.pad_token_id).to(device)
            combine_output = simcse_model(
                input_ids = combine_pad,
                attention_mask = combine_attention_mask
            ).pooler_output
            score = cos(claim_output,combine_output).tolist()
            for i,t in  enumerate(branch_2_term_label):
                page_score[t]=score[i]

    results.append([key.replace(" ","_").replace("-", "") for key,_ in sorted(page_score.items(), key=lambda item: item[1],reverse=True)[:10]])
    # break    

3038it [45:57,  1.10it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 48.00 MiB (GPU 1; 10.91 GiB total capacity; 5.83 GiB already allocated; 18.06 MiB free; 8.80 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [23]:
# results[236]
train_df_t = pd.DataFrame(TRAIN_DATA)
train_df_t["predicted_pages"]=results
train_df_t[["id","label", "claim","evidence", "predicted_pages"]].to_json(
    "data/train_doc10_all_method.jsonl",
    orient="records",
    lines=True,
    force_ascii=False,
)

### Step 2. Calculate our results

In [27]:
train_df_t = pd.DataFrame(load_json('data/train_doc10_all_method.jsonl'))
train_df_t.head()

Unnamed: 0,id,label,claim,evidence,predicted_pages
0,2663,refutes,天衛三軌道在天王星內部的磁層，以《 仲夏夜之夢 》作者緹坦妮雅命名。,"[[[4209, 4331, 天衛三, 2]]]","[天衛三, 天王星, 磁層, 天衛二十三, 天衛四, 冥衛三, 軌道, 緹坦妮雅, 天衛二,..."
1,2399,refutes,信天翁科的活動範圍位於北冰洋以及南太平洋，牠的翼展可達到3.7米，是世界上現存的翼展最大的鳥類。,"[[[2719, 2928, 信天翁科, 2]]]","[漂泊信天翁, 北冰洋, 翼展, 信天翁科, 鳥, 大鰹鳥屬, 太平洋, 南太平洋, 牠, ..."
2,8075,NOT ENOUGH INFO,F.I.R. 的 團員有主唱Faye飛 （ 詹雯婷 ） 、 吉他手Real阿沁 （ 黃漢青 ...,"[[7208, None, None, None]]","[F.I.R.飛兒樂團, 樂團, 主唱, 飛_(消歧義), SpeXial, 阿沁, 閃靈樂..."
3,8931,NOT ENOUGH INFO,香港國際機場全年24小時運作，它從2001年起一直躋身世界最佳機場 ， 並8度獲評級爲全宇宙...,"[[8162, None, None, None]]","[香港國際機場, 香港航空業, 香港之最, 啓德機場, 國泰貨運, 香港國際機場旅客捷運系統..."
4,332,NOT ENOUGH INFO,北理工是歷史上最後一批副部級高校，黨委書記和校長列入中央管理的高校 ， 簡稱中管高校 ， 俗...,"[[204, None, None, None]]","[黨委書記和校長列入中央管理的高校, 中央部屬高校, 北京航空航天大學, 中央軍事委員會訓練..."


In [31]:
calculate_precision(TRAIN_DATA, train_df_t['predicted_pages'])
calculate_recall(TRAIN_DATA, train_df_t['predicted_pages'])

Precision: 0.10352898811668668
Recall: 0.9305945664786142


### Step 3. Repeat the same process on test set
Create parsing tree

In [27]:
TEST_DATA = load_json("data/private/public_private_combine_test_data.jsonl")
hanlp_test_file = f"data/hanlp_con_combine_test_results_imp.pkl"
if Path(hanlp_test_file).exists():
    with open(hanlp_test_file, "rb") as f:
        hanlp_results = pickle.load(f)
else:
    hanlp_results = [get_nps_hanlp(predictor, d) for d in TEST_DATA]
    with open(hanlp_test_file, "wb") as f:
        pickle.dump(hanlp_results, f)

Get pages via wiki online api

In [20]:
test_doc_path = f"data/combine_test_doc5_all_page_with_sentences.jsonl"
if Path(test_doc_path).exists():
    test_df = pd.DataFrame(load_json(test_doc_path))
else:
    test_df = pd.DataFrame(TEST_DATA)
    test_df.loc[:, "hanlp_results"] = hanlp_results
    test_results = test_df.parallel_apply(get_pred_pages, axis=1)
    predicted_results=[]
    predicted_sentence=[]
    for elmt in test_results:
        predicted_results.append(elmt[0])
        predicted_sentence.append(elmt[1])
    test_df.loc[:,"all_page"]=predicted_results
    test_df.loc[:,"all_sentences"]=predicted_sentence
    test_df[["id", "claim","hanlp_results", "all_page","all_sentences"]].to_json(
        "data/combine_test_doc5_all_page_with_sentences.jsonl",
        orient="records",
        lines=True,
        force_ascii=False,
    )

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=302), Label(value='0 / 302'))), HB…

In [4]:
cos = torch.nn.CosineSimilarity(dim=1)
simcse_model.eval()
results=[]
batch=16
for i, elmt in tqdm(test_df.iterrows()):

    page_score={}
    tmp=[]
    for title in elmt['all_page']:
        if title in elmt["claim"] or title.replace('·',"") in elmt["claim"]:
            tmp.append(title)
    claim_input=simcse_tok(CONVERTER_T2S.convert(elmt["claim"]),return_tensors="pt").to("cuda:0")
    claim_output = simcse_model(**claim_input).pooler_output
    for i in range(0,len(elmt['all_page']),batch):

        branch_1_term_label=[]
        branch_1_term=[]
        branch_1_sentence=[]

        branch_2_term_label=[]
        branch_2_term_sentence=[]
        for term,sentence in zip(elmt['all_page'][i:i+batch],elmt['all_sentences'][i:i+batch]):
            if term in tmp:
                branch_1_term_label.append(term)
                branch_1_term.append(simcse_tok(CONVERTER_T2S.convert(term),max_length=256,truncation=True)["input_ids"])
                branch_1_sentence.append(simcse_tok(CONVERTER_T2S.convert(sentence),max_length=256,truncation=True)["input_ids"])
            else:
                branch_2_term_label.append(term)
                branch_2_term_sentence.append(simcse_tok(CONVERTER_T2S.convert(term+" : "+sentence),max_length=256,truncation=True)["input_ids"])
        if branch_1_term_label!=[]:
            term_max_len= max(map(len,branch_1_term))
            sentences_max_len= max(map(len,branch_1_sentence))
            term_pad=[]
            sentence_pad=[]
            for i in range(len(branch_1_term)):
                if len(branch_1_term[i])<term_max_len:
                    term_pad.append(branch_1_term[i]+[simcse_tok.pad_token_id]*(term_max_len-len(branch_1_term[i])))
                else:
                    term_pad.append(branch_1_term[i])

                if len(branch_1_sentence[i])<sentences_max_len:
                    sentence_pad.append(branch_1_sentence[i]+[simcse_tok.pad_token_id]*(sentences_max_len-len(branch_1_sentence[i])))
                else:
                    sentence_pad.append(branch_1_sentence[i])
            term_pad = torch.tensor(term_pad).to("cuda:0")
            term_attention_mask = (term_pad!=simcse_tok.pad_token_id).to("cuda:0")

            sentence_pad = torch.tensor(sentence_pad).to("cuda:0")
            sen_attention_mask = (sentence_pad!=simcse_tok.pad_token_id).to("cuda:0")
            term_output = simcse_model(
                input_ids = term_pad,
                attention_mask = term_attention_mask
            ).pooler_output
            sen_output = simcse_model(
                input_ids = sentence_pad,
                attention_mask = sen_attention_mask
            ).pooler_output
            score = ((cos(claim_output,term_output)+cos(claim_output,sen_output)*1.3)/2).tolist()
            for i,t in enumerate(branch_1_term_label):
                page_score[t]=score[i]
        if branch_2_term_label!=[]:
            combine_pad=[]
            combine_max_len= max(map(len,branch_2_term_sentence))
            for i in range(len(branch_2_term_sentence)):
                if len(branch_2_term_sentence[i])<combine_max_len:
                    combine_pad.append(branch_2_term_sentence[i]+[simcse_tok.pad_token_id]*(combine_max_len-len(branch_2_term_sentence[i])))
                else:
                    combine_pad.append(branch_2_term_sentence[i])
            combine_pad = torch.tensor(combine_pad).to("cuda:0")
            combine_attention_mask = (combine_pad!=simcse_tok.pad_token_id).to("cuda:0")
            combine_output = simcse_model(
                input_ids = combine_pad,
                attention_mask = combine_attention_mask
            ).pooler_output
            score = cos(claim_output,combine_output).tolist()
            for i,t in  enumerate(branch_2_term_label):
                page_score[t]=score[i]
    
    results.append([key.replace(" ","_").replace("-", "") for key,val in sorted(page_score.items(), key=lambda item: item[1],reverse=True)[:10]])
    

4090it [1:05:21,  1.04it/s]


KeyboardInterrupt: 

In [8]:
TEST_DATA = load_json("data/private/public_private_combine_test_data.jsonl")
test_df_t = pd.DataFrame(TEST_DATA)
test_df_t.loc[:, "predicted_pages"] = results

In [7]:
test_df_t[["id", "claim","predicted_pages"]].to_json(
    "data/combine_test_doc10.jsonl",
    orient="records",
    lines=True,
    force_ascii=False,
)

notebook2
## PART 2. Sentence retrieval

Import some libs

In [None]:
# built-in libs
from pathlib import Path
from typing import Dict, List, Set, Tuple, Union
from dataclasses import dataclass
# third-party libs
import numpy as np
import pandas as pd
from pandarallel import pandarallel
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import (
    #AutoModelForSequenceClassification,
    AutoTokenizer,
    get_scheduler,
)
from modeling_bert import BertForSequenceClassification
from dataset import BERTDataset, Dataset

# local libs
from utils import (
    generate_evidence_to_wiki_pages_mapping,
    jsonl_dir_to_df,
    load_json,
    load_model,
    save_checkpoint,
    set_lr_scheduler,
)

pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=10)

In [3]:
@dataclass
class Claim:
    data: str

@dataclass
class AnnotationID:
    id: int

@dataclass
class EvidenceID:
    id: int

@dataclass
class PageTitle:
    title: str

@dataclass
class SentenceID:
    id: int

@dataclass
class Evidence:
    data: List[List[Tuple[AnnotationID, EvidenceID, PageTitle, SentenceID]]]

Global variable

In [4]:
SEED = 40

TRAIN_DATA = load_json("data/public_train_all.jsonl")
TEST_DATA = load_json("data/public_test.jsonl")
DOC_DATA = load_json("data/train_doc10_all_method.jsonl")

LABEL2ID: Dict[str, int] = {
    "supports": 0,
    "refutes": 1,
    "NOT ENOUGH INFO": 2,
}
ID2LABEL: Dict[int, str] = {v: k for k, v in LABEL2ID.items()}

_y = [LABEL2ID[data["label"]] for data in TRAIN_DATA]
# GT means Ground Truth
TRAIN_GT, DEV_GT = train_test_split(
    DOC_DATA,
    shuffle=True,
    test_size=0.2,
    random_state=SEED,
    stratify=_y,
)

Preload wiki database (1 min)

In [4]:
wiki_pages = jsonl_dir_to_df("data/wiki-pages")
mapping = generate_evidence_to_wiki_pages_mapping(wiki_pages)
del wiki_pages

Reading and concatenating jsonl files in data/wiki-pages
Generate parse mapping


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=118776), Label(value='0 / 118776')…

Transform to id to evidence_map mapping


### Helper function

Calculate precision for sentence retrieval

In [7]:
def evdi_calculate_precision(
    data: List[Dict[str, Union[int, Claim, Evidence]]],
    predictions: pd.Series,
) -> None:
    precision = 0
    count = 0

    for i, d in enumerate(data):
        if d["label"] == "NOT ENOUGH INFO":
            continue

        # Extract all ground truth of titles of the wikipedia pages
        # evidence[2] refers to the title of the wikipedia page
        gt_pages = set([
            (evidence[2],evidence[3])
            for evidence_set in d["evidence"]
            for evidence in evidence_set
        ])

        predicted_pages = set([(t[0],t[1]) for id,t in enumerate(predictions.iloc[i][0:2])])

        hits = predicted_pages.intersection(gt_pages)

        if len(predicted_pages) != 0:
            precision += len(hits) / len(predicted_pages)
        count += 1

    # Macro precision
    print(f"Precision: {precision / count}")


def evdi_calculate_recall(
    data: List[Dict[str, Union[int, Claim, Evidence]]],
    predictions: pd.Series,
) -> None:
    recall = 0
    count = 0

    for i, d in enumerate(data):
        if d["label"] == "NOT ENOUGH INFO":
            continue

        gt_pages = set([
            (evidence[2],evidence[3])
            for evidence_set in d["evidence"]
            for evidence in evidence_set
        ])
        predicted_pages = set([(t[0],t[1]) for id,t in enumerate(predictions.iloc[i][0:2])])

        hits = predicted_pages.intersection(gt_pages)

        recall += len(hits) / len(gt_pages)
        count += 1
    print(f"Recall: {recall / count}")
    return recall / count

In [7]:
def evidence_macro_precision(
    instance: Dict,
    top_rows: pd.DataFrame,
) -> Tuple[float, float]:
    """Calculate precision for sentence retrieval
    This function is modified from fever-scorer.
    https://github.com/sheffieldnlp/fever-scorer/blob/master/src/fever/scorer.py

    Args:
        instance (dict): a row of the dev set (dev.jsonl) of test set (test.jsonl)
        top_rows (pd.DataFrame): our predictions with the top probabilities

        IMPORTANT!!!
        instance (dict) should have the key of `evidence`.
        top_rows (pd.DataFrame) should have a column `predicted_evidence`.

    Returns:
        Tuple[float, float]:
        [1]: relevant and retrieved (numerator of precision)
        [2]: retrieved (denominator of precision)
    """
    this_precision = 0.0
    this_precision_hits = 0.0
    # print(top_rows)
    # Return 0, 0 if label is not enough info since not enough info does not
    # contain any evidence.
    if instance["label"].upper() != "NOT ENOUGH INFO":
        # e[2] is the page title, e[3] is the sentence index
        all_evi = [[e[2], e[3]]
                   for eg in instance["evidence"]
                   for e in eg
                   if e[3] is not None]
        claim = instance["claim"]
        predicted_evidence = top_rows[top_rows["claim"] ==
                                      claim]["predicted_evidence"].tolist()

        for prediction in predicted_evidence:
            if prediction in all_evi:
                this_precision += 1.0
            this_precision_hits += 1.0

        return (this_precision /
                this_precision_hits) if this_precision_hits > 0 else 1.0, 1.0

    return 0.0, 0.0

Calculate recall for sentence retrieval

In [8]:
def evidence_macro_recall(
    instance: Dict,
    top_rows: pd.DataFrame,
) -> Tuple[float, float]:
    """Calculate recall for sentence retrieval
    This function is modified from fever-scorer.
    https://github.com/sheffieldnlp/fever-scorer/blob/master/src/fever/scorer.py

    Args:
        instance (dict): a row of the dev set (dev.jsonl) of test set (test.jsonl)
        top_rows (pd.DataFrame): our predictions with the top probabilities

        IMPORTANT!!!
        instance (dict) should have the key of `evidence`.
        top_rows (pd.DataFrame) should have a column `predicted_evidence`.

    Returns:
        Tuple[float, float]:
        [1]: relevant and retrieved (numerator of recall)
        [2]: relevant (denominator of recall)
    """
    # We only want to score F1/Precision/Recall of recalled evidence for NEI claims
    if instance["label"].upper() != "NOT ENOUGH INFO":
        # If there's no evidence to predict, return 1
        if len(instance["evidence"]) == 0 or all(
            [len(eg) == 0 for eg in instance]):
            return 1.0, 1.0

        claim = instance["claim"]

        predicted_evidence = top_rows[top_rows["claim"] ==
                                      claim]["predicted_evidence"].tolist()

        for evidence_group in instance["evidence"]:
            evidence = [[e[2], e[3]] for e in evidence_group]
            if all([item in predicted_evidence for item in evidence]):
                # We only want to score complete groups of evidence. Incomplete
                # groups are worthless.
                return 1.0, 1.0
        return 0.0, 1.0
    return 0.0, 0.0

Calculate the scores of sentence retrieval

In [9]:
def evaluate_retrieval(
    probs: np.ndarray,
    df_evidences: pd.DataFrame,
    ground_truths: pd.DataFrame,
    top_n: int = 5,
    cal_scores: bool = True,
    save_name: str = None,
) -> Dict[str, float]:
    """Calculate the scores of sentence retrieval

    Args:
        probs (np.ndarray): probabilities of the candidate retrieved sentences
        df_evidences (pd.DataFrame): the candiate evidence sentences paired with claims
        ground_truths (pd.DataFrame): the loaded data of dev.jsonl or test.jsonl
        top_n (int, optional): the number of the retrieved sentences. Defaults to 2.

    Returns:
        Dict[str, float]: F1 score, precision, and recall
    """
    df_evidences["prob"] = probs
    top_rows = (
        df_evidences.groupby("claim").apply(
        lambda x: x.nlargest(top_n, "prob"))
        .reset_index(drop=True)
    )

    if cal_scores:
        macro_precision = 0
        macro_precision_hits = 0
        macro_recall = 0
        macro_recall_hits = 0

        for i, instance in enumerate(ground_truths):
            macro_prec = evidence_macro_precision(instance, top_rows)
            macro_precision += macro_prec[0]
            macro_precision_hits += macro_prec[1]

            macro_rec = evidence_macro_recall(instance, top_rows)
            macro_recall += macro_rec[0]
            macro_recall_hits += macro_rec[1]
            # break
        pr = (macro_precision /
              macro_precision_hits) if macro_precision_hits > 0 else 1.0
        rec = (macro_recall /
               macro_recall_hits) if macro_recall_hits > 0 else 0.0
        f1 = 2.0 * pr * rec / (pr + rec)
    import json
    if save_name is not None:
        # write doc7_sent5 file
        with open(f"data/{save_name}", "w") as f:
            for instance in ground_truths:
                claim = instance["claim"]
                predicted_evidence = top_rows[
                    top_rows["claim"] == claim
                ]["predicted_evidence"].tolist()
                
                instance["predicted_evidence"] = predicted_evidence
                f.write(json.dumps(instance, ensure_ascii=False) + "\n")

    if cal_scores:
        return {"F1 score": f1, "Precision": pr, "Recall": rec}

Inference script to get probabilites for the candidate evidence sentences

In [10]:
def get_predicted_probs(
    model: nn.Module,
    dataloader: Dataset,
    device: torch.device,
) -> np.ndarray:
    """Inference script to get probabilites for the candidate evidence sentences

    Args:
        model: the one from HuggingFace Transformers
        dataloader: devset or testset in torch dataloader

    Returns:
        np.ndarray: probabilites of the candidate evidence sentences
    """
    model.eval()
    probs = []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            logits = outputs.logits
            probs.extend(torch.softmax(logits, dim=1)[:, 1].tolist())
            # probs.extend(torch.argmax(torch.softmax(logits, dim=1),dim=1).tolist())
            # break
    return np.array(probs)

AicupTopkEvidenceBERTDataset class for AICUP dataset with top-k evidence sentences

In [11]:
class SentRetrievalBERTDataset(BERTDataset):
    """AicupTopkEvidenceBERTDataset class for AICUP dataset with top-k evidence sentences."""

    def __getitem__(
        self,
        idx: int,
        **kwargs,
    ) -> Tuple[Dict[str, torch.Tensor], int]:
        item = self.data.iloc[idx]
        sentA = item["claim"]
        sentB = item["text"]
        # print(self.tokenizer(sentA))
        # print(sentB)
        # claim [SEP] text
        concat = self.tokenizer(
            sentA,
            sentB,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        )
        # print(concat.items())
        concat_ten = {k: torch.tensor(v) for k, v in concat.items()}
        # print("label" in item)
        if "label" in item:
            concat_ten["labels"] = torch.tensor(item["label"])

        return concat_ten

### Main function for sentence retrieval

In [11]:
def pair_with_wiki_sentences(
    mapping: Dict[str, Dict[int, str]],
    df: pd.DataFrame,
    negative_ratio: float,
) -> pd.DataFrame:
    """Only for creating train sentences."""
    claims = []
    sentences = []
    labels = []

    # positive
    for i in range(len(df)):
        # if df["label"].iloc[i] == "NOT ENOUGH INFO":
        #     continue

        claim = df["claim"].iloc[i]
        evidence_sets = df["evidence"].iloc[i]
        for evidence_set in evidence_sets:
            sents = []
            for evidence in evidence_set:
                # evidence[2] is the page title
                # print(evidence)
                if not isinstance(evidence, list):
                    continue
                # print(evidence)
                page = evidence[2].replace(" ", "_")
                # the only page with weird name
                if page == "臺灣海峽危機#第二次臺灣海峽危機（1958）":
                    continue
                # evidence[3] is in form of int however, mapping requires str
                sent_idx = str(evidence[3])
                if len(evidence_set)!=1:
                    claims.append(claim)
                    sentences.append(mapping[page][sent_idx])
                    labels.append(1)

                sents.append(mapping[page][sent_idx])

            whole_evidence = " ".join(sents)

            claims.append(claim)
            sentences.append(whole_evidence)
            labels.append(1)



    # negative
    for i in range(len(df)):
        # if df["label"].iloc[i] == "NOT ENOUGH INFO":
        #     continue
        claim = df["claim"].iloc[i]

        evidence_set = set([(evidence[2], str(evidence[3]))
                            for evidences in df["evidence"][i]
                            for evidence in evidences if isinstance(evidence, list)])
        predicted_pages = df["predicted_pages"][i]
        # predicted_pages = predicted_pages + evidence_set
        for page in predicted_pages:
            page = page.replace(" ", "_")
            try:
                page_sent_id_pairs = [
                    (page, sent_idx) for sent_idx in mapping[page].keys()
                ]
            except KeyError:
                # print(f"{page} is not in our Wiki db.")
                continue

            for pair in page_sent_id_pairs:
                text = mapping[page][pair[1]]
                if pair in evidence_set:
                    continue
                # `np.random.rand(1) <= 0.05`: Control not to add too many negative samples
                if text != "" and ('|' not in text or '=' not in text):
                    if np.random.rand(1) <= negative_ratio:
                        claims.append(claim)
                        sentences.append(text)
                        labels.append(0)


    return pd.DataFrame({"claim": claims, "text": sentences, "label": labels})


def pair_with_wiki_sentences_eval(
    mapping: Dict[str, Dict[int, str]],
    df: pd.DataFrame,
    is_testset: bool = False,
) -> pd.DataFrame:
    """Only for creating dev and test sentences."""
    claims = []
    sentences = []
    evidence = []
    predicted_evidence = []

    # negative
    for i in range(len(df)):
        # if df["label"].iloc[i] == "NOT ENOUGH INFO":
            # continue
        claim = df["claim"].iloc[i]

        predicted_pages = df["predicted_pages"][i]
        for page in predicted_pages:
            page = page.replace(" ", "_")
            try:
                page_sent_id_pairs = [(page, k) for k in mapping[page]]
            except KeyError:
                # print(f"{page} is not in our Wiki db.")
                continue

            for page_name, sentence_id in page_sent_id_pairs:
                text = mapping[page][sentence_id]
                if text != "":
                    claims.append(claim)
                    sentences.append(text)
                    if not is_testset:
                        evidence.append(df["evidence"].iloc[i])
                    predicted_evidence.append([page_name, int(sentence_id)])

    return pd.DataFrame({
        "claim": claims,
        "text": sentences,
        "evidence": evidence if not is_testset else None,
        "predicted_evidence": predicted_evidence,
    })

### Step 1. Setup training environment

Hyperparams

In [16]:
#@title  { display-mode: "form" }

MODEL_NAME = "hfl/chinese-bert-wwm-ext"  #@param {type:"string"}
NUM_EPOCHS = 10  #@param {type:"integer"}
LR = 2e-5  #@param {type:"number"}
TRAIN_BATCH_SIZE = 16  #@param {type:"integer"}
TEST_BATCH_SIZE = 64  #@param {type:"integer"}
NEGATIVE_RATIO = 0.085  #@param {type:"number"}
VALIDATION_STEP = 300  #@param {type:"integer"}
TOP_N = 5  #@param {type:"integer"}

Experiment Directory

In [17]:
EXP_DIR = f"sent_retrieval/e{NUM_EPOCHS}_bs{TRAIN_BATCH_SIZE}_" + f"{LR}_neg{NEGATIVE_RATIO}_top{TOP_N}_all_ext"
LOG_DIR = "logs/" + EXP_DIR
CKPT_DIR = "checkpoints/" + EXP_DIR

if not Path(LOG_DIR).exists():
    Path(LOG_DIR).mkdir(parents=True)

if not Path(CKPT_DIR).exists():
    Path(CKPT_DIR).mkdir(parents=True)

### Step 2. Combine claims and evidences

In [14]:
train_df = pair_with_wiki_sentences(
    mapping,
    pd.DataFrame(TRAIN_GT),
    NEGATIVE_RATIO,
)
counts = train_df["label"].value_counts()
print("Now using the following train data with 0 (Negative) and 1 (Positive)")
print(counts)

dev_evidences = pair_with_wiki_sentences_eval(mapping, pd.DataFrame(DEV_GT))

Now using the following train data with 0 (Negative) and 1 (Positive)
0    45362
1    15619
Name: label, dtype: int64


### Step 3. Start training

Dataloader things

In [15]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

train_dataset = SentRetrievalBERTDataset(train_df, tokenizer=tokenizer)
val_dataset = SentRetrievalBERTDataset(dev_evidences, tokenizer=tokenizer)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=TRAIN_BATCH_SIZE,
)
eval_dataloader = DataLoader(val_dataset, batch_size=TEST_BATCH_SIZE)

Save your memory.

Trainer

In [16]:
device = torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cpu")
# device = torch.device("cpu")
# model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

model = BertForSequenceClassification.from_pretrained(MODEL_NAME)
model.load_state_dict(torch.load("checkpoints/sent_retrieval/e10_bs32_2e-05_neg0.085_top5_all_ext/model.1899.pt"))
model.to(device)

optimizer = AdamW(model.parameters(), lr=LR)
num_training_steps = NUM_EPOCHS * len(train_dataloader)
lr_scheduler = set_lr_scheduler(optimizer, num_training_steps)

writer = SummaryWriter(LOG_DIR)

Some weights of the model checkpoint at yechen/bert-large-chinese were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not 

Please make sure that you are using gpu when training (5 min)

In [18]:
progress_bar = tqdm(range(num_training_steps))
current_steps = 0
step_cnt=0
max=0
for epoch in range(10):
    model.train()
    
    for i,batch in enumerate(train_dataloader):
        
        batch = {k: v.to(device) for k, v in batch.items()}
        # print(batch)
        # break
        outputs = model(**batch)
        loss = outputs.loss/4
        loss.backward()
        step_cnt+=1
        if step_cnt==4 or i==len(train_dataloader)-1:
            step_cnt=0
            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()

        progress_bar.update(1)
        progress_bar.set_postfix(loss=loss.item()*4)
        writer.add_scalar("training_loss", loss.item(), current_steps)

        y_pred = torch.argmax(outputs.logits, dim=1).tolist()
        y_true = batch["labels"].tolist()

        current_steps += 1

        if i==len(train_dataloader)-1:#
            print("Start validation")
            probs = get_predicted_probs(model, eval_dataloader, device)

            val_results = evaluate_retrieval(
                probs=probs,
                df_evidences=dev_evidences,
                ground_truths=DEV_GT,
                top_n=TOP_N,
            )
            print(val_results)
            if val_results['F1 score']>max:
                
                val_results = evaluate_retrieval(
                    probs=probs,
                    df_evidences=dev_evidences,
                    ground_truths=DEV_GT,
                    top_n=TOP_N,
                    save_name=f"dev_doc10sent{TOP_N}_large.jsonl",
                )
                max = val_results['F1 score']
            # log each metric separately to TensorBoard
            for metric_name, metric_value in val_results.items():
                writer.add_scalar(
                    f"dev_{metric_name}",
                    metric_value,
                    current_steps,
                )

            save_checkpoint(
                model,val_results, CKPT_DIR, current_steps
                mark=f"val_F1={val_results['F1 score']:.4f}",
            )

print("Finished training!")

  0%|          | 0/38120 [00:00<?, ?it/s]

Start validation


  0%|          | 0/2534 [00:00<?, ?it/s]

{'F1 score': 0.4101963121667645, 'Precision': 0.2748650269945951, 'Recall': 0.8080383923215357}
Start validation


  0%|          | 0/2534 [00:00<?, ?it/s]

{'F1 score': 0.41383052983595964, 'Precision': 0.2775044991001741, 'Recall': 0.8134373125374925}
Start validation


  0%|          | 0/2534 [00:00<?, ?it/s]

{'F1 score': 0.4067793865298869, 'Precision': 0.27354529094180596, 'Recall': 0.7930413917216557}
Start validation


  0%|          | 0/2534 [00:00<?, ?it/s]

{'F1 score': 0.4061312905655683, 'Precision': 0.27246550689861454, 'Recall': 0.7972405518896221}
Start validation


  0%|          | 0/2534 [00:00<?, ?it/s]

{'F1 score': 0.40796349284318956, 'Precision': 0.27390521895620307, 'Recall': 0.7990401919616077}
Start validation


  0%|          | 0/2534 [00:00<?, ?it/s]

{'F1 score': 0.40286000902406266, 'Precision': 0.27078584283142804, 'Recall': 0.7864427114577085}
Start validation


  0%|          | 0/2534 [00:00<?, ?it/s]

{'F1 score': 0.38729683040664004, 'Precision': 0.26070785842830896, 'Recall': 0.7528494301139772}
Start validation


  0%|          | 0/2534 [00:00<?, ?it/s]

{'F1 score': 0.3983884680702607, 'Precision': 0.2680263947210502, 'Recall': 0.7756448710257948}
Start validation


  0%|          | 0/2534 [00:00<?, ?it/s]

{'F1 score': 0.3982023177069879, 'Precision': 0.2677864427114521, 'Recall': 0.7762447510497901}
Start validation


  0%|          | 0/2534 [00:00<?, ?it/s]

{'F1 score': 0.4037980952051185, 'Precision': 0.270785842831428, 'Recall': 0.7936412717456509}
Finished training!


In [19]:
print("Start validation")
probs = get_predicted_probs(model, eval_dataloader, device)

val_results = evaluate_retrieval(
    probs=probs,
    df_evidences=dev_evidences,
    ground_truths=DEV_GT,
    top_n=TOP_N,
)
print(val_results)
if val_results['F1 score']>max:
    
    val_results = evaluate_retrieval(
        probs=probs,
        df_evidences=dev_evidences,
        ground_truths=DEV_GT,
        top_n=TOP_N,
        save_name=f"dev_doc10sent{TOP_N}_large.jsonl",
    )
    max = val_results['F1 score']
# log each metric separately to TensorBoard
for metric_name, metric_value in val_results.items():
    writer.add_scalar(
        f"dev_{metric_name}",
        metric_value,
        current_steps,
    )


Start validation


  0%|          | 0/2534 [00:00<?, ?it/s]

{'F1 score': 0.40026761344894174, 'Precision': 0.26922615476904066, 'Recall': 0.7798440311937612}


In [None]:
%load_ext tensorboard
%tensorboard --logdir logs

Validation part (15 mins)

In [20]:
# ckpt_name = "model.5409.pt"  #@param {type:"string"}
# model = load_model(model, ckpt_name, CKPT_DIR)
model.load_state_dict(torch.load("checkpoints/sent_retrieval/e10_bs32_2e-05_neg0.085_top5_all_ext/model.1899.pt", map_location=device))

print("Start final evaluations and write prediction files.")

train_evidences = pair_with_wiki_sentences_eval(
    mapping=mapping,
    df=pd.DataFrame(TRAIN_GT),
)
train_set = SentRetrievalBERTDataset(train_evidences, tokenizer)
train_dataloader = DataLoader(train_set, batch_size=TEST_BATCH_SIZE)

print("Start calculating training scores")
probs = get_predicted_probs(model, train_dataloader, device)
train_results = evaluate_retrieval(
    probs=probs,
    df_evidences=train_evidences,
    ground_truths=TRAIN_GT,
    top_n=TOP_N,
    save_name=f"train_doc10sent{TOP_N}_all_ext.jsonl",
)
print(f"Training scores => {train_results}")

# print("Start validation")
probs = get_predicted_probs(model, eval_dataloader, device)
val_results = evaluate_retrieval(
    probs=probs,
    df_evidences=dev_evidences,
    ground_truths=DEV_GT,
    top_n=TOP_N,
    save_name=f"dev_doc10sent{TOP_N}_all_ext.jsonl",
)

print(f"Validation scores => {val_results}")

Start final evaluations and write prediction files.
Start calculating training scores


  0%|          | 0/10173 [00:00<?, ?it/s]

Training scores => {'F1 score': 0.43625463261664854, 'Precision': 0.29348739495800735, 'Recall': 0.8494897959183674}


  0%|          | 0/2534 [00:00<?, ?it/s]

Validation scores => {'F1 score': 0.42019263533901646, 'Precision': 0.2813437312537435, 'Recall': 0.829634073185363}


In [22]:
top_rows = pd.DataFrame(load_json("data/train_doc10sent5_all_ext.jsonl"))

evdi_calculate_precision(TRAIN_GT, top_rows['predicted_evidence'])
evdi_calculate_recall(TRAIN_GT, top_rows['predicted_evidence'])

Precision: 0.5774309723889556
Recall: 0.7736191798147838


In [5]:
from transformers import AutoTokenizer,AutoModel
import torch
from torch import optim
model=AutoModel.from_pretrained("IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese")
simcse_tok=AutoTokenizer.from_pretrained("IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese")

TRAIN_GT = load_json("data/train_doc10sent5_all_ext.jsonl")
DEV_GT = load_json("data/dev_doc10sent5_all_ext.jsonl")
train_df = pd.DataFrame(TRAIN_GT)
dev_df = pd.DataFrame(DEV_GT)

Some weights of the model checkpoint at IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese were not used when initializing BertModel: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
class ContrastiveModel(nn.Module):
    def __init__(self,model=None):
        super().__init__()
        self.model = model

        self.transition_energy_net = nn.Linear(768, 768,bias=True)

    def forward(self, input_ids,attention_mask):
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        energy = self.transition_energy_net(output['pooler_output'])

        return energy

simcse_model = ContrastiveModel(model)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
simcse_model.to(device)
NUM_EPOCHS=20
optimizer = optim.AdamW(simcse_model.parameters(), lr=3e-5)
num_training_steps = NUM_EPOCHS * len(TRAIN_GT)
lr_scheduler = set_lr_scheduler(optimizer, num_training_steps)

myModel(
  (model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
   

In [14]:
from tqdm.auto import tqdm
import math
tqdm.pandas()
import opencc
CONVERTER_T2S = opencc.OpenCC("t2s.json")
CONVERTER_S2T = opencc.OpenCC("s2t.json")
cos = torch.nn.CosineSimilarity(dim=1)
result=[]
step_cnt=0
progress_bar = tqdm(train_df.iterrows())
max_recall=0
current_steps=0
for epoch in range(NUM_EPOCHS):
    simcse_model.train()
    optimizer.zero_grad()
    for _,row in train_df.iterrows():
        # while 1:
        page_score={}
        if row["label"] == "NOT ENOUGH INFO":
            continue

        gt_pages = [
            [evidence[2],evidence[3]]
            for evidence_set in row["evidence"]
            for evidence in evidence_set
        ]
        claim = row["claim"]
        claim_tok = simcse_tok(claim,return_tensors="pt").to(device)
        claim_output = simcse_model(
            input_ids = claim_tok["input_ids"],
            attention_mask = claim_tok["attention_mask"]
        )
        evd_list = row["predicted_evidence"]
        evd_content = []
        simple_term_sentence=[]
        pos_id=[]
        neg_id=[]
        for i,(evd,sent_id) in enumerate(evd_list):
            if [evd,sent_id] in gt_pages:
                pos_id.append(i)
            else:
                neg_id.append(i)
            simple_term_sentence.append(simcse_tok(evd+" : "+mapping[evd][str(sent_id)])["input_ids"])
        if pos_id==[] or neg_id == []:
            continue
        combine_pad=[]
        combine_max_len= max(map(len,simple_term_sentence))
        for i in range(len(simple_term_sentence)):
            if len(simple_term_sentence[i])<combine_max_len:
                combine_pad.append(simple_term_sentence[i]+[simcse_tok.pad_token_id]*(combine_max_len-len(simple_term_sentence[i])))
            else:
                combine_pad.append(simple_term_sentence[i])
        combine_pad = torch.tensor(combine_pad).to(device)
        combine_attention_mask = (combine_pad!=simcse_tok.pad_token_id).to(device)
        simple_output = simcse_model(
            input_ids = combine_pad,
            attention_mask = combine_attention_mask
        )

        # cos_sim = torch.matmul(claim_output,simple_output.T).squeeze(0)
        cos_sim = cos(claim_output,simple_output)

        pos_id = torch.tensor(pos_id)
        neg_id = torch.tensor(neg_id)
        pos_energy = cos_sim[pos_id]
        neg_energy = cos_sim[neg_id]
        rank_loss = torch.tensor([0.0], dtype=torch.float32,requires_grad=True,device=pos_energy.device)

        for i in pos_energy:
            neg_energy_seg = neg_energy[torch.where(neg_energy>i,True,False)]
            if neg_energy_seg.size(0)==0:
                continue
            pos_energy_seg = i.expand_as(neg_energy_seg)

            y = torch.ones_like(neg_energy_seg)  # 正樣本的標籤都是1
            rank_loss_func = torch.nn.MarginRankingLoss(
                math.exp(1/pos_id.size(0))
            )
            # print(neg_energy[:5])
            rank_loss=rank_loss+rank_loss_func(pos_energy_seg, neg_energy_seg, y)
        rank_loss = rank_loss/pos_id.size(0)
        rank_loss = rank_loss/4
        rank_loss.backward()
        step_cnt+=1
        if step_cnt==4:
            step_cnt=0
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
        progress_bar.update(1)
        progress_bar.set_postfix(loss=rank_loss.item()*4)
        current_steps+=1
    # break
    simcse_model.eval()
    result=[]
    for _,row in tqdm(dev_df.iterrows()):
    
        page_score={}

        claim = row["claim"]
        claim_tok = simcse_tok(claim,return_tensors="pt").to(device)
        claim_output = simcse_model(
            input_ids = claim_tok["input_ids"],
            attention_mask = claim_tok["attention_mask"]
        )
        evd_list = row["predicted_evidence"]
        evd_content = []

        for evd,sent_id in evd_list:
            simple_tok = simcse_tok(evd+" : "+mapping[evd][str(sent_id)],return_tensors="pt").to(device)
            simple_output = simcse_model(
                input_ids = simple_tok["input_ids"],
                attention_mask = simple_tok["attention_mask"]
            )
            page_score[evd+" "+str(sent_id)] = cos(claim_output,simple_output).item()
            # page_score[evd+" "+str(sent_id)] = torch.matmul(claim_output,simple_output.T).squeeze(0)
        result.append([[key.split()[0],int(key.split()[1])] for key,_ in sorted(page_score.items(), key=lambda item: item[1],reverse=True)])
    rerank_dev_df= dev_df.copy()
    rerank_dev_df = rerank_dev_df.drop(["predicted_evidence"],axis=1)
    rerank_dev_df.loc[:,"predicted_evidence"]=result
    evdi_calculate_precision(DEV_GT, rerank_dev_df['predicted_evidence'])
    recall = evdi_calculate_recall(DEV_GT, rerank_dev_df['predicted_evidence'])
    if recall>max_recall:
        max_recall=recall
        save_checkpoint(simcse_model, "checkpoints/rank_sent_retrieval", current_steps,mark=f"val_recall={recall:.4f}")


0it [00:00, ?it/s]

0it [00:00, ?it/s]

Precision: 0.5512445095168375
Recall: 0.7443142996583701


0it [00:00, ?it/s]

Precision: 0.5505124450951684
Recall: 0.7438262567105908


In [37]:
from tqdm import tqdm
import opencc
CONVERTER_T2S = opencc.OpenCC("t2s.json")
CONVERTER_S2T = opencc.OpenCC("s2t.json")
cos = torch.nn.CosineSimilarity(dim=1)
result=[]
simcse_model.eval()
for _,row in tqdm(train_df.iterrows()):

    page_score={}

    claim = row["claim"]
    claim_tok = simcse_tok(claim,return_tensors="pt").to(device)
    claim_output = simcse_model(
        input_ids = claim_tok["input_ids"],
        attention_mask = claim_tok["attention_mask"]
    )
    evd_list = row["predicted_evidence"]
    evd_content = []

    for evd,sent_id in evd_list:
        simple_tok = simcse_tok(evd+" : "+mapping[evd][str(sent_id)],return_tensors="pt").to(device)
        simple_output = simcse_model(
            input_ids = simple_tok["input_ids"],
            attention_mask = simple_tok["attention_mask"]
        )
        page_score[evd+" "+str(sent_id)] = cos(claim_output,simple_output).item()

    result.append([[key.split()[0],int(key.split()[1])] for key,_ in sorted(page_score.items(), key=lambda item: item[1],reverse=True)])
rerank_train_top_rows= train_df.copy()
rerank_train_top_rows = rerank_train_top_rows.drop(["predicted_evidence"],axis=1)
rerank_train_top_rows.loc[:,"predicted_evidence"]=result


2it [00:00, 12.37it/s]

[[[6324, 6133, '北平_(消歧義)', 2], [6324, 6133, '北平_(五代)', 0], [6324, 6133, '北平_(五代)', 1], [6324, 6133, '北平_(五代)', 2], [6324, 6133, '北平_(五代)', 4], [6324, 6133, '北平_(五代)', 5]]]
[[[11832, 10509, '抄襲', 3]]]
[[[3346, 3510, '上座部佛教', 8], [3346, 3510, '根本分裂', 0]]]


4it [00:00, 13.13it/s]

[[[9745, 8882, '李天命', 1]]]
[[[13362, 11516, '陰陽', 0], [13362, 11516, '陰陽', 1], [13362, 11516, '陰陽', 4]]]
[[[9184, 8408, '斯登衝鋒槍', 0], [9184, 8408, '斯登衝鋒槍', 1]]]


6it [00:00, 12.90it/s]

[[12496, None, None, None]]





TypeError: 'int' object is not subscriptable

In [None]:
from tqdm import tqdm
import opencc
CONVERTER_T2S = opencc.OpenCC("t2s.json")
CONVERTER_S2T = opencc.OpenCC("s2t.json")
cos = torch.nn.CosineSimilarity(dim=1)
result=[]
simcse_model.eval()
for _,row in tqdm(dev_df.iterrows()):

    page_score={}

    claim = row["claim"]
    claim_tok = simcse_tok(claim,return_tensors="pt").to(device)
    claim_output = simcse_model(
        input_ids = claim_tok["input_ids"],
        attention_mask = claim_tok["attention_mask"]
    )
    evd_list = row["predicted_evidence"]
    evd_content = []

    for evd,sent_id in evd_list:
        simple_tok = simcse_tok(evd+" : "+mapping[evd][str(sent_id)],return_tensors="pt").to(device)
        simple_output = simcse_model(
            input_ids = simple_tok["input_ids"],
            attention_mask = simple_tok["attention_mask"]
        )
        page_score[evd+" "+str(sent_id)] = cos(claim_output,simple_output).item()

    result.append([[key.split()[0],int(key.split()[1])] for key,_ in sorted(page_score.items(), key=lambda item: item[1],reverse=True)])
rerank_dev_top_rows = dev_df.copy()
rerank_dev_top_rows = rerank_dev_top_rows.drop(["predicted_evidence"],axis=1)
rerank_dev_top_rows.loc[:,"predicted_evidence"]=result


In [29]:
evdi_calculate_precision(TRAIN_GT, rerank_train_top_rows['predicted_evidence'])
evdi_calculate_recall(TRAIN_GT, rerank_train_top_rows['predicted_evidence'])
print("=====")
evdi_calculate_precision(DEV_GT, rerank_dev_top_rows['predicted_evidence'])
evdi_calculate_recall(DEV_GT, rerank_dev_top_rows['predicted_evidence'])

### Step 4. Check on our test data
(5 min)

In [16]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
device = torch.device("cuda:2") if torch.cuda.is_available() else torch.device("cpu")
model = BertForSequenceClassification.from_pretrained(MODEL_NAME)
model.load_state_dict(torch.load("checkpoints/sent_retrieval/e10_bs32_2e-05_neg0.085_top5_all_ext/model.1899.pt"))
model.to(device)
test_data = load_json("data/combine_test_doc10_all.jsonl")

test_evidences = pair_with_wiki_sentences_eval(
    mapping,
    pd.DataFrame(test_data),
    is_testset=True,
)
test_set = SentRetrievalBERTDataset(test_evidences, tokenizer)
test_dataloader = DataLoader(test_set, batch_size=256)

print("Start predicting the test data")
probs = get_predicted_probs(model, test_dataloader, device)
evaluate_retrieval(
    probs=probs,
    df_evidences=test_evidences,
    ground_truths=test_data,
    top_n=TOP_N,
    cal_scores=False,
    save_name=f"combine_test_doc10sent{TOP_N}_ext.jsonl",
)

Some weights of the model checkpoint at hfl/chinese-bert-wwm-ext were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkp

Start predicting the test data


  0%|          | 0/9135 [00:00<?, ?it/s]

In [10]:
from tqdm import tqdm
import opencc
test_df = pd.DataFrame(load_json("data/combine_test_doc10sent5_ext.jsonl"))
CONVERTER_T2S = opencc.OpenCC("t2s.json")
CONVERTER_S2T = opencc.OpenCC("s2t.json")
cos = torch.nn.CosineSimilarity(dim=1)
result=[]
simcse_model.eval()
for _,row in tqdm(test_df.iterrows()):

    page_score={}

    claim = row["claim"]
    claim_tok = simcse_tok(claim,return_tensors="pt").to(device)
    claim_output = simcse_model(
        input_ids = claim_tok["input_ids"],
        attention_mask = claim_tok["attention_mask"]
    )
    evd_list = row["predicted_evidence"]
    evd_content = []

    for evd,sent_id in evd_list:
        simple_tok = simcse_tok(evd+" : "+mapping[evd][str(sent_id)],return_tensors="pt").to(device)
        simple_output = simcse_model(
            input_ids = simple_tok["input_ids"],
            attention_mask = simple_tok["attention_mask"]
        )
        page_score[evd+" "+str(sent_id)] = cos(claim_output,simple_output).item()

    result.append([[key.split()[0],int(key.split()[1])] for key,_ in sorted(page_score.items(), key=lambda item: item[1],reverse=True)])
rerank_test_top_rows= test_df.copy()
rerank_test_top_rows = rerank_test_top_rows.drop(["predicted_evidence"],axis=1)
rerank_test_top_rows.loc[:,"predicted_evidence"]=result


9038it [08:44, 17.24it/s]


In [12]:
# test reranking results
rerank_test_top_rows[["id", "claim","predicted_pages","predicted_evidence"]].to_json(
    "data/combine_test_doc10sent5_reranlk_ext.jsonl",
    orient="records",
    lines=True,
    force_ascii=False,
)


notebook3
## PART 3. Claim verification

import libs

In [1]:
import pickle
from pathlib import Path
from typing import Dict, Tuple

import numpy as np
import pandas as pd
from pandarallel import pandarallel
from tqdm.auto import tqdm
import json
import torch
from sklearn.metrics import accuracy_score
from torch.optim import AdamW,Adam
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import (
    #AutoModelForSequenceClassification,
    AutoTokenizer,
    get_scheduler,
)
from modeling_bert import BertForSequenceClassification,BertForMaskedLM,BertForNextSentencePrediction
from dataset import BERTDataset
from utils import (
    generate_evidence_to_wiki_pages_mapping,
    jsonl_dir_to_df,
    load_json,
    load_model,
    save_checkpoint,
    set_lr_scheduler,
)
import opencc
CONVERTER_T2S = opencc.OpenCC("t2s.json")
CONVERTER_S2T = opencc.OpenCC("s2t.json")

pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=20)

Global variables

In [27]:
LABEL2ID: Dict[str, int] = {
    "supports": 0,
    "refutes": 1,
    "NOT ENOUGH INFO": 2,
}
LABEL2TEXT: Dict[str, int] = {
    "supports": "對",
    "refutes": "錯",
    "NOT ENOUGH INFO": "無",
}
TEXT2LABEL: Dict[str, int] = {
    "對": "supports",
    "錯": "refutes",
    "無": "NOT ENOUGH INFO",
}
ID2LABEL: Dict[str, int] = {
    "0": "supports",
    "1": "refutes",
    "2": "NOT ENOUGH INFO",
}
NUMBER2CHINESE: Dict[str, int] = {
    0: "零",
    1: "一",
    2: "二",
    3: "三",
    4: "四",
    5: "五"
}
ID2LABEL: Dict[int, str] = {v: k for k, v in LABEL2ID.items()}

TRAIN_DATA = load_json("data/train_doc10sent5_all_rerank_ext.jsonl")
DEV_DATA = load_json("data/dev_doc10sent5_all_rerank_ext.jsonl")

TRAIN_PKL_FILE = Path("data/train_doc10sent5_all_rerank_ext.pkl")
DEV_PKL_FILE = Path("data/dev_doc10sent5_all_rerank_ext.pkl")

Preload wiki database (same as part 2.)

In [3]:
wiki_pages = jsonl_dir_to_df("data/wiki-pages")
mapping = generate_evidence_to_wiki_pages_mapping(wiki_pages,)
del wiki_pages

Reading and concatenating jsonl files in data/wiki-pages
Generate parse mapping


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=59388), Label(value='0 / 59388')))…

Transform to id to evidence_map mapping


### Helper function

AICUP dataset with top-k evidence sentences.

In [26]:
class AicupTopkEvidenceBERTDataset(BERTDataset):
    """AICUP dataset with top-k evidence sentences."""

    def __getitem__(
        self,
        idx: int,
        **kwargs,
    ) -> Tuple[Dict[str, torch.Tensor], int]:
        item = self.data.iloc[idx]
        claim = item["claim"]
        evidence = item["evidence_list"]
        pad = ["[PAD]"] * (self.topk - len(evidence))
        evidence += pad

        true_evidence=[]
        for i,evd in enumerate(evidence):
            if evd!='[PAD]':
                true_evidence.append(evd)

        # 錯誤 正確 沒有足夠證據
        # 光學顯微鏡是以電磁學原理來將不可見或難見的微小物放大至肉眼可見的儀器。
        # evid1 evid2
        # 你主張 : 光學顯微鏡是以電磁學原理來將不可見或難見的微小物放大至肉眼可見的儀器。 因為證據有 : evid1  evid2 ... 所以我認為 [MASK]   
        evidence_text=""
        # tmp_str=""
        for i,text in enumerate(true_evidence):
            evidence_text+=str(i+1)+". "+text+" "

        if evidence_text=="":
            evidence_text='"無"。'

        if "label" in item: 
            # concat_claim_evidence = "我認為"+'"'+LABEL2TEXT[item["label"]]+'"'+"，因為你認為 : "+claim+"，而證據有 : "+evidence_text
            concat_claim_evidence = "我覺得"+'"'+LABEL2TEXT[item["label"]]+'"'+"，因為你認為 : "+claim+"而證據有 : "+evidence_text
            # concat_claim_evidence = "我認為"+'"'+LABEL2TEXT[item["label"]]+'"'+"，因為結論為 : "+claim+"，而證明該結論有 : "+evidence_text
            # concat_claim_evidence = "這讀起來"+'"'+LABEL2TEXT[item["label"]]+'"'+": "+claim+"，"+evidence_text
            # concat_claim_evidence = '你主張"'+claim+'"，因證據有'+NUMBER2CHINESE[x]+'個: '+evidence_text+"所以我覺得"+'"'+LABEL2TEXT[item["label"]]+'"'
        else:
            # concat_claim_evidence = "我認為"+'"'+"[MASK]"+'"'+"，因為你認為 : "+claim+"，而證據有 : "+evidence_text
            concat_claim_evidence = "我覺得"+'"'+"[MASK]"+'"'+"，因為你認為 : "+claim+"而證據有 : "+evidence_text
            # concat_claim_evidence = "我認為"+'"'+"[MASK]"+'"'+"，因為結論為 : "+claim+"，而證明該結論有 : "+evidence_text
            # concat_claim_evidence = "這讀起來"+'"'+"[MASK]"+'"'+"，結論: "+claim+" 原因: "+evidence_text
            # concat_claim_evidence = '你主張"'+claim+'"，因證據有'+NUMBER2CHINESE[x]+'個: '+evidence_text+"所以我覺得"+'"'+"[MASK]"+'"'

        concat = self.tokenizer.encode(
            concat_claim_evidence,
            max_length=self.max_length,
            truncation=True,
        )
        # prompt後
        # maintain_label = (self.max_length-len(concat))
        # concat = concat + [tokenizer.pad_token_id]*maintain_label
        # concat = torch.tensor(concat)
        # attention_mask = (concat!=tokenizer.pad_token_id)
        # concat_ten={}
        # if "label" in item: 
        #     label = torch.ones_like(concat).detach()*-100
        #     maintain_label = -maintain_label-3
        #     maintain = concat[maintain_label]
        #     label[maintain_label] = maintain
        #     concat[maintain_label]=tokenizer.mask_token_id
        #     concat_ten["labels"] = label.squeeze(0)
        #     target = LABEL2ID[item["label"]] if "label" in item else -1
        #     concat_ten["target"] = torch.tensor(target)
        # concat_ten["input_ids"] = concat
        # concat_ten["attention_mask"] = attention_mask
        
        # prompt前
        concat = concat + [tokenizer.pad_token_id]*(self.max_length-len(concat))
        concat = torch.tensor(concat)
        attention_mask = (concat!=tokenizer.pad_token_id)
        concat_ten={}
        if "label" in item: 
            label = concat.detach().clone()
            maintain = label[5]
            label.masked_fill_(label!=maintain,-100)
            concat[5]=tokenizer.mask_token_id
            concat_ten["labels"] = label.squeeze(0)
            target = LABEL2ID[item["label"]] if "label" in item else -1
            concat_ten["target"] = torch.tensor(target)
        concat_ten["input_ids"] = concat
        concat_ten["attention_mask"] = attention_mask
        
        return concat_ten

Evaluation function

In [5]:
def run_evaluation(model: torch.nn.Module, dataloader: DataLoader, device):
    model.eval()

    loss = 0
    y_true = []
    y_pred = []

    with torch.no_grad():
        progress_bar = tqdm(dataloader)
        for i,batch in enumerate(dataloader):
            y_true.extend(batch["labels"].tolist())

            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss += outputs.loss.item()
            logits = outputs.logits
            y_pred.extend(torch.argmax(logits, dim=1).tolist())
            progress_bar.set_postfix(loss=loss/(i+1),accuracy=accuracy_score(y_true, y_pred))
            progress_bar.update(1)
    acc = accuracy_score(y_true, y_pred)

    return {"val_loss": loss / len(dataloader), "val_acc": acc}

In [5]:
def run_evaluation_prompt(model: torch.nn.Module,tok, dataloader: DataLoader, device):
    model.eval()

    loss = 0
    y_true = []
    y_pred = []

    with torch.no_grad():
        progress_bar = tqdm(dataloader)
        for i,batch in enumerate(dataloader):
            y_true.extend(batch["target"].tolist())
            batch = {k: v.to(device) for k, v in batch.items() if k!="target"}
            outputs = model(**batch)
            mask_token_index = torch.where(batch["input_ids"]==tok.mask_token_id,True,False)
            mask_token_logits = outputs.logits[mask_token_index]
            extract_label = tok.batch_encode_plus(list(TEXT2LABEL.keys()),add_special_tokens=False,return_tensors="pt")['input_ids'].squeeze(1).unsqueeze(0)
            extract_label_repeat = torch.repeat_interleave(extract_label,mask_token_logits.size(0),dim=0)
            extract_label_output = torch.gather(mask_token_logits,1,extract_label_repeat.to(device))
            max_label = torch.argmax(extract_label_output,dim=1).tolist()
            y_pred.extend(max_label)
            loss += outputs.loss.item()

            progress_bar.set_postfix(loss=loss/(i+1),accuracy=accuracy_score(y_true, y_pred))
            progress_bar.update(1)
    acc = accuracy_score(y_true, y_pred)

    return {"val_loss": loss / len(dataloader), "val_acc": acc}

Prediction

In [7]:
def run_predict(model: torch.nn.Module, test_dl: DataLoader, device) -> list:
    model.eval()

    preds = []
    for batch in tqdm(test_dl,
                      total=len(test_dl),
                      leave=False,
                      desc="Predicting"):
        batch = {k: v.to(device) for k, v in batch.items()}
        pred = model(**batch).logits
        pred = torch.argmax(pred, dim=1)
        preds.extend(pred.tolist())
    return preds

In [6]:
def run_predict_prompt(model: torch.nn.Module,tok, test_dl: DataLoader, device) -> list:
    model.eval()

    preds = []
    for batch in tqdm(test_dl,
                      total=len(test_dl),
                      leave=False,
                      desc="Predicting"):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        mask_token_index = torch.where(batch["input_ids"]==tok.mask_token_id,True,False)
        mask_token_logits = outputs.logits[mask_token_index]
        extract_label = tok.batch_encode_plus(list(TEXT2LABEL.keys()),add_special_tokens=False,return_tensors="pt")['input_ids'].squeeze(1).unsqueeze(0)
        extract_label_repeat = torch.repeat_interleave(extract_label,mask_token_logits.size(0),dim=0)
        extract_label_output = torch.gather(mask_token_logits,1,extract_label_repeat.to(device))
        max_label = torch.argmax(extract_label_output,dim=1).tolist()
        preds.extend(max_label)
    return preds

### Main function

In [7]:
def join_with_topk_evidence(
    df: pd.DataFrame,
    mapping: dict,
    mode: str = "train",
    topk: int = 5,
) -> pd.DataFrame:
    """join_with_topk_evidence join the dataset with topk evidence.

    Note:
        After extraction, the dataset will be like this:
               id     label         claim                           evidence            evidence_list
        0    4604  supports       高行健...     [[[3393, 3552, 高行健, 0], [...  [高行健 （ ）江西赣州出...
        ..    ...       ...            ...                                ...                     ...
        945  2095  supports       美國總...  [[[1879, 2032, 吉米·卡特, 16], [...  [卸任后 ， 卡特積極參與...
        停各种战争及人質危機的斡旋工作 ， 反对美国小布什政府攻打伊拉克...

        [946 rows x 5 columns]

    Args:
        df (pd.DataFrame): The dataset with evidence.
        wiki_pages (pd.DataFrame): The wiki pages dataframe
        topk (int, optional): The topk evidence. Defaults to 5.
        cache(Union[Path, str], optional): The cache file path. Defaults to None.
            If cache is None, return the result directly.

    Returns:
        pd.DataFrame: The dataset with topk evidence_list.
            The `evidence_list` column will be: List[str]
    """

    # format evidence column to List[List[Tuple[str, str, str, str]]]
    if "evidence" in df.columns:
        df["evidence"] = df["evidence"].parallel_map(
            lambda x: [[x]] if not isinstance(x[0], list) else [x]
            if not isinstance(x[0][0], list) else x)
    return df

### Step 1. Setup training environment

Hyperparams

In [8]:
#@title  { display-mode: "form" }

MODEL_NAME = "hfl/chinese-bert-wwm-ext"  #@param {type:"string"}
TRAIN_BATCH_SIZE = 4  #@param {type:"integer"}
TEST_BATCH_SIZE = 16  #@param {type:"integer"}
SEED = 42  #@param {type:"integer"}
LR = 7e-5  #@param {type:"number"}
NUM_EPOCHS = 5  #@param {type:"integer"}
MAX_SEQ_LEN = 300  #@param {type:"integer"}
EVIDENCE_TOPK = 5  #@param {type:"integer"}
VALIDATION_STEP = 500  #@param {type:"integer"}


Experiment Directory

In [9]:
OUTPUT_FILENAME = "submission.jsonl"

EXP_DIR = f"claim_verification/e{NUM_EPOCHS}_bs{TRAIN_BATCH_SIZE}_" + f"{LR}_top{EVIDENCE_TOPK}_ext_all_evi2_rerank5"
LOG_DIR = "logs/" + EXP_DIR
CKPT_DIR = "checkpoints/" + EXP_DIR

if not Path(LOG_DIR).exists():
    Path(LOG_DIR).mkdir(parents=True)

if not Path(CKPT_DIR).exists():
    Path(CKPT_DIR).mkdir(parents=True)

### Step 2. Concat claim and evidences
join topk evidence

In [31]:
if not TRAIN_PKL_FILE.exists():
    train_df = join_with_topk_evidence(
        pd.DataFrame(TRAIN_DATA),
        mapping,
        topk=EVIDENCE_TOPK,
    )
    train_df.to_pickle(TRAIN_PKL_FILE, protocol=5)
else:
    with open(TRAIN_PKL_FILE, "rb") as f:
        train_df = pickle.load(f)

if not DEV_PKL_FILE.exists():
    dev_df = join_with_topk_evidence(
        pd.DataFrame(DEV_DATA),
        mapping,
        mode="eval",
        topk=EVIDENCE_TOPK,
    )
    dev_df.to_pickle(DEV_PKL_FILE, protocol=5)
else:
    with open(DEV_PKL_FILE, "rb") as f:
        dev_df = pickle.load(f)

In [32]:
total_list=[]
for i,row in tqdm(enumerate(train_df.iterrows())):
    tmp=[]
    for evi_id, evi_idx in row[1]['predicted_evidence'][:3]:

        if evi_id == "臺灣海峽危機#第二次臺灣海峽危機（1958）":
            continue
        tmp.append(mapping[evi_id][str(evi_idx)])

    total_list.append(tmp)
#train_df = train_df.drop("evidence_list",axis=1)
train_df.loc[:, "evidence_list"] = total_list

total_list=[]
for i,row in tqdm(enumerate(dev_df.iterrows())):
    tmp=[]
    for evi_id, evi_idx in row[1]['predicted_evidence'][:3]:
        if evi_id == "臺灣海峽危機#第二次臺灣海峽危機（1958）":
            continue
        tmp.append(mapping[evi_id][str(evi_idx)])
    total_list.append(tmp)
#dev_df = dev_df.drop("evidence_list",axis=1)
dev_df.loc[:, "evidence_list"] = total_list

0it [00:00, ?it/s]

### Step 3. Training

Prevent CUDA out of memory

In [34]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

train_dataset = AicupTopkEvidenceBERTDataset(
    train_df,
    tokenizer=tokenizer,
    max_length=MAX_SEQ_LEN,
)
val_dataset = AicupTopkEvidenceBERTDataset(
    dev_df,
    tokenizer=tokenizer,
    max_length=MAX_SEQ_LEN,
)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=TRAIN_BATCH_SIZE,
)
eval_dataloader = DataLoader(val_dataset, batch_size=TEST_BATCH_SIZE)

In [None]:
# 可以不用執行 check train dataset時使用
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

train_dataset = AicupTopkEvidenceBERTDataset(
    train_df,
    tokenizer=tokenizer,
    max_length=MAX_SEQ_LEN,
)
train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=TRAIN_BATCH_SIZE,
)



In [25]:
device = torch.device("cuda:2") if torch.cuda.is_available() else torch.device("cpu")

#model = AutoModelForSequenceClassification.from_pretrained(
model = BertForMaskedLM.from_pretrained(
    MODEL_NAME,
    # num_labels=len(LABEL2ID),
    # type_vocab_size=7,
    # ignore_mismatched_sizes=True
)
# model.load_state_dict(torch.load("checkpoints/claim_verification/e20_bs8_7e-05_top5_ext_all_evi3/val_acc=0.7146_model.7500_bs0.577.pt"),strict=False)

model.to(device)

optimizer = AdamW(model.parameters(), lr=LR)
num_training_steps = NUM_EPOCHS * len(train_dataloader)
lr_scheduler = set_lr_scheduler(optimizer, num_training_steps)

writer = SummaryWriter(LOG_DIR)

Some weights of the model checkpoint at hfl/chinese-bert-wwm were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Training (30 mins)

In [26]:
progress_bar = tqdm(range(num_training_steps))
current_steps = 0
step_cnt=0

for epoch in range(NUM_EPOCHS):
    model.train()

    for i,batch in enumerate(train_dataloader):
        batch = {k: v.to(device) for k, v in batch.items() if k!="target"}
        outputs = model(**batch)
        # break
        loss = outputs.loss
        loss = loss/8
        loss.backward()
        step_cnt+=1
        if step_cnt==8 or i==len(train_dataloader)-1:
            step_cnt=0
            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()
        progress_bar.set_postfix(loss=loss.item()*8)
        progress_bar.update(1)
        writer.add_scalar("training_loss", loss.item(), current_steps)


        current_steps += 1

        if current_steps%VALIDATION_STEP==0 or i==len(train_dataloader)-1:
            # print("Start validation")
            # val_results = run_evaluation(model, eval_dataloader, device)
            val_results = run_evaluation_prompt(model,tokenizer, eval_dataloader, device)

            # # log each metric separately to TensorBoard
            for metric_name, metric_value in val_results.items():
                print(current_steps,f"{metric_name}: {metric_value}")
                writer.add_scalar(f"{metric_name}", metric_value, current_steps)

            save_checkpoint(
                model,
                CKPT_DIR,
                current_steps,
                mark=f"val_acc={val_results['val_acc']:.4f}",
            )

print("Finished training!")

  0%|          | 0/11650 [00:00<?, ?it/s]

  0%|          | 0/146 [00:00<?, ?it/s]

500 val_loss: 0.8928324329118206
500 val_acc: 0.6927038626609442


  0%|          | 0/146 [00:00<?, ?it/s]

1000 val_loss: 0.8243427129640971
1000 val_acc: 0.7072961373390558


  0%|          | 0/146 [00:00<?, ?it/s]

1500 val_loss: 0.8726550680521417
1500 val_acc: 0.7030042918454935


  0%|          | 0/146 [00:00<?, ?it/s]

2000 val_loss: 0.8944448342878525
2000 val_acc: 0.711587982832618


  0%|          | 0/146 [00:00<?, ?it/s]

2330 val_loss: 0.8293357702763113
2330 val_acc: 0.7055793991416309


  0%|          | 0/146 [00:00<?, ?it/s]

2500 val_loss: 0.9718346252833328
2500 val_acc: 0.7107296137339055


  0%|          | 0/146 [00:00<?, ?it/s]

3000 val_loss: 1.0134965023153448
3000 val_acc: 0.7188841201716738


  0%|          | 0/146 [00:00<?, ?it/s]

3500 val_loss: 0.9797987083662046
3500 val_acc: 0.7133047210300429


  0%|          | 0/146 [00:00<?, ?it/s]

4000 val_loss: 0.9656425836968096
4000 val_acc: 0.7128755364806867


  0%|          | 0/146 [00:00<?, ?it/s]

4500 val_loss: 1.0100778537253812
4500 val_acc: 0.7064377682403433


  0%|          | 0/146 [00:00<?, ?it/s]

4660 val_loss: 0.8147847758580561
4660 val_acc: 0.6922746781115879


  0%|          | 0/146 [00:00<?, ?it/s]

5000 val_loss: 1.1248167616455522
5000 val_acc: 0.7085836909871245


  0%|          | 0/146 [00:00<?, ?it/s]

5500 val_loss: 0.9105253639082386
5500 val_acc: 0.7103004291845494


  0%|          | 0/146 [00:00<?, ?it/s]

6000 val_loss: 0.9633634368034258
6000 val_acc: 0.6922746781115879


  0%|          | 0/146 [00:00<?, ?it/s]

6500 val_loss: 0.9244728543578762
6500 val_acc: 0.6690987124463519


  0%|          | 0/146 [00:00<?, ?it/s]

6990 val_loss: 1.0046413462251833
6990 val_acc: 0.6742489270386266


  0%|          | 0/146 [00:00<?, ?it/s]

7000 val_loss: 0.9053670584338985
7000 val_acc: 0.6875536480686695


  0%|          | 0/146 [00:00<?, ?it/s]

7500 val_loss: 1.0314801910968676
7500 val_acc: 0.663519313304721


  0%|          | 0/146 [00:00<?, ?it/s]

8000 val_loss: 1.1157817099600622
8000 val_acc: 0.6639484978540773


  0%|          | 0/146 [00:00<?, ?it/s]

8500 val_loss: 0.912447160237456
8500 val_acc: 0.6888412017167382


  0%|          | 0/146 [00:00<?, ?it/s]

9000 val_loss: 0.8141130118133271
9000 val_acc: 0.7090128755364807


  0%|          | 0/146 [00:00<?, ?it/s]

9320 val_loss: 0.8948295680831556
9320 val_acc: 0.7124463519313304


  0%|          | 0/146 [00:00<?, ?it/s]

9500 val_loss: 1.2874118169284847
9500 val_acc: 0.6828326180257511


  0%|          | 0/146 [00:00<?, ?it/s]

10000 val_loss: 0.8175377103767983
10000 val_acc: 0.6909871244635193


### Step 4. Make your submission

In [37]:
# import json
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
TEST_DATA = load_json("data/combine_test_doc10sent5_rerank_ext.jsonl")
# TEST_PKL_FILE = Path("data/test_doc5sent5_imp.pkl")

# # if not TEST_PKL_FILE.exists():
test_df = join_with_topk_evidence(
    pd.DataFrame(TEST_DATA),
    mapping,
    mode="eval",
    topk=EVIDENCE_TOPK,
)
total_list=[]
for i,row in tqdm(enumerate(test_df.iterrows())):

    tmp=[]
    for evi_id, evi_idx in row[1]['predicted_evidence'][:3]:

        if evi_id == "臺灣海峽危機#第二次臺灣海峽危機（1958）":
            continue
        tmp.append(mapping[evi_id][str(evi_idx)])

    # print(tmp)
    total_list.append(tmp)
    # break

test_df.loc[:, "evidence_list"] = total_list

test_dataset = AicupTopkEvidenceBERTDataset(
    test_df,
    tokenizer=tokenizer,
    max_length=MAX_SEQ_LEN,
)
test_dataloader = DataLoader(test_dataset, batch_size=10)

Extracting evidence_list for the eval mode ...


0it [00:00, ?it/s]

Prediction

In [None]:
# ckpt_name = "val_acc=0.8174_model.5000.pt"  #@param {type:"string"}
device = torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cpu")
model = BertForMaskedLM.from_pretrained(
    MODEL_NAME,
    # num_labels=len(LABEL2ID),
    # type_vocab_size=7,
    # ignore_mismatched_sizes=True
)
model.load_state_dict(torch.load("checkpoints/claim_verification/e5_bs4_0.002_top5_ext_all_evi2_rerank5/val_acc=0.7189_model.3000.pt", map_location="cpu"))
model.to(device)
# # predicted_label = run_predict(model, test_dataloader, device)
predicted_label = run_predict_prompt(model,tokenizer, test_dataloader, device)

Write files

In [39]:
predict_dataset = test_df.copy()
predict_dataset["predicted_label"] = list(map(ID2LABEL.get, predicted_label))
for i,row in predict_dataset.iterrows():
    if row['predicted_evidence'] == []:
        predict_dataset.at[i,"predicted_label"] = "NOT ENOUGH INFO"
        
predict_dataset[["id", "predicted_label", "predicted_evidence"]].to_json(
    "submission_combine_val_acc=0.7189_model.3000.jsonl",
    orient="records",
    lines=True,
    force_ascii=False,
)