## Import Libraries

In [1]:
import importlib
import torch
from transformers import BertTokenizer

import model.biencoder_dataset
import model.biencoder
import model.bert_encoder
import model.biencoder_trainer

# Re-execute the project modules, in the same order they were imported, so
# that source edits made since the kernel started are picked up without a
# kernel restart.
for _project_module in (
    model.biencoder_dataset,
    model.biencoder,
    model.bert_encoder,
    model.biencoder_trainer,
):
    importlib.reload(_project_module)

# Import the class names only after the reloads so that they bind to the
# freshly re-executed definitions rather than to stale ones.
from model.biencoder_dataset import BiEncoderDataset
from model.biencoder import BiEncoder
from model.bert_encoder import BertEncoder
from model.biencoder_trainer import BiEncoderTrainer

  from .autonotebook import tqdm as notebook_tqdm


## Check GPU Availability

In [2]:
# Report whether PyTorch can use a CUDA device; the bare expression is the
# cell's displayed value.
torch.cuda.is_available()

False

## Load Training Data

In [2]:
# Tokenizer loaded from the pretrained "bert-base-uncased" vocabulary.
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Options forwarded unchanged to BiEncoderDataset. Their exact semantics are
# defined in model/biencoder_dataset.py, which is not visible from this notebook.
train_dataset_options = {
    "claim_file_path": "./data/train-claims.json",
    "evidence_file_path": "./data/evidence.json",
    "tokenizer": bert_tokenizer,
    # presumably the token length each text is padded to -- TODO confirm
    "max_padding_length": 200,
    # presumably the number of negative evidences drawn per claim -- TODO confirm
    "neg_evidence_num": 2,
    "rand_seed": 1,
}

train_dataset = BiEncoderDataset(**train_dataset_options)

## Create Model

In [3]:
# Build the query-side encoder with its pooling layer enabled. The load log
# printed under this cell shows it being initialized from the
# bert-base-uncased checkpoint, with the pretraining heads (cls.*) left unused.
query_encoder = BertEncoder(add_pooling_layer=True)

# Only the query encoder is passed in. NOTE(review): how BiEncoder obtains the
# encoder for the evidence side is not visible here -- confirm in
# model/biencoder.py.
biencoder = BiEncoder(query_model=query_encoder)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertEncoder: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Train Model

In [4]:
# Wrap the bi-encoder in the project's trainer with batch_size=8.
biencoder_trainer = BiEncoderTrainer(model=biencoder, batch_size=8)

# Hyperparameters for this run, forwarded unchanged to BiEncoderTrainer.train.
# NOTE(review): learning_rate=0.001 is far above the 2e-5 to 5e-5 range
# commonly recommended for fine-tuning BERT with Adam -- confirm it is intended.
# NOTE(review): no torch seed is set in the visible cells; unless the trainer
# seeds torch internally, results may differ between runs -- verify.
train_options = {
    "max_epoch": 2,
    "loss_func_type": "nll_loss",
    "similarity_func_type": "dot",
    "optimizer_type": "adam",
    "learning_rate": 0.001,
}

biencoder_history = biencoder_trainer.train(train_data=train_dataset,
                                            **train_options)