## Import Libraries

In [None]:
import json
import torch
from transformers import BertTokenizer

from model.dpr.biencoder_dataset import BiEncoderDataset
from model.dpr.evidence_dataset import EvidenceDataset
from model.dpr.biencoder import BiEncoder
from model.dpr.encoder.bert_encoder import BertEncoder
from model.dpr.biencoder_trainer import BiEncoderTrainer

from model.classifier.bert_classifier import BertClassifier
from model.classifier.bert_classifier_dataset import BertClassifierDataset
from model.classifier.bert_classifier_trainer import BertClassifierTrainer

## Setup

### Check GPU Availability

In [None]:
# Ask PyTorch whether a CUDA device is available. This is deliberately left as
# a bare expression: as the last statement of its notebook cell, the resulting
# boolean is what the cell displays.
torch.cuda.is_available()

### Create Tokenizer

In [None]:
# Load the pretrained "bert-base-uncased" tokenizer once; this single instance
# is the one handed to the dataset constructors further down via `tokenizer=`.
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

## Information Retriever

### Load Data

In [None]:
# Bi-encoder training split: labelled claims paired with the evidence corpus.
# evidence_num=6 and rand_seed=1 are forwarded as-is to BiEncoderDataset.
biencoder_train_dataset = BiEncoderDataset(
    claim_file_path="./data/train-claims.json",
    evidence_file_path="./data/evidence.json",
    data_type="train",
    tokenizer=bert_tokenizer,
    lower_case=False,
    max_padding_length=128,
    evidence_num=6,
    rand_seed=1,
)

# Bi-encoder prediction split: unlabelled test claims (data_type="predict").
# evidence_num=0 here, unlike the training split which passes 6.
biencoder_test_dataset = BiEncoderDataset(
    claim_file_path="./data/test-claims-unlabelled.json",
    evidence_file_path="./data/evidence.json",
    data_type="predict",
    tokenizer=bert_tokenizer,
    lower_case=False,
    max_padding_length=128,
    evidence_num=0,
    rand_seed=1,
)

# Evidence corpus on its own, built with the same tokenizer and the same
# max_padding_length=128 used for the claim datasets.
evidence_dataset = EvidenceDataset(
    evidence_file_path="./data/evidence.json",
    tokenizer=bert_tokenizer,
    lower_case=False,
    max_padding_length=128,
    rand_seed=1,
)

### Create BiEncoder Model

In [None]:
# Encoder that is passed to the bi-encoder as its query model.
# NOTE(review): add_pooling_layer=True presumably keeps BERT's pooling head on
# top of the encoder output — confirm against BertEncoder.
query_encoder = BertEncoder(add_pooling_layer=True)

# Only `query_model` is supplied; how the evidence side is encoded is decided
# inside BiEncoder and cannot be told from this notebook — TODO confirm.
biencoder = BiEncoder(query_model=query_encoder)

### Train BiEncoder Model

In [None]:
# Trainer wrapping the bi-encoder model, configured with batch_size=16.
biencoder_trainer = BiEncoderTrainer(model=biencoder, batch_size=16)

In [None]:
# Train the bi-encoder on the training split and keep whatever `train`
# returns under `biencoder_history`.
biencoder_history = biencoder_trainer.train(
    train_data=biencoder_train_dataset,
    shuffle=True,
    max_epoch=5,
    loss_func_type="nll_loss",
    similarity_func_type="dot",
    optimizer_type="adam",
    learning_rate=1e-4,
)

### Retrieve Information

In [None]:
# Encode the evidence dataset with the bi-encoder (batch_size=1000).
# `output_file_path` is the same JSON file the next cell reads back from disk.
evid_embed_data = biencoder.get_evidence_embed(
    evidence_dataset=evidence_dataset,
    batch_size=1000,
    output_file_path="./data/output/embed-evidence.json",
)

In [None]:
# Read the evidence embeddings back from disk. The file is consumed as a JSON
# object mapping each evidence tag to its embedding.
# A context manager is used so the file handle is closed as soon as parsing
# finishes instead of being left open until garbage collection.
with open("./data/output/embed-evidence.json") as embed_file:
    raw_embed_evid = json.load(embed_file)

# Split the mapping into two parallel lists. A dict iterates its keys and its
# values in the same (insertion) order, so evid_tags[i] still belongs to
# evid_embeds[i].
evid_tags = list(raw_embed_evid)
evid_embeds = list(raw_embed_evid.values())

# Layout handed to `biencoder.retrieve` through its `embed_evid_data` argument.
embed_evid_data = {"tag": evid_tags, "evidence": evid_embeds}

In [None]:
# Run retrieval for the unlabelled test claims against the evidence embeddings
# loaded above (k=5, batch_size=128).
# NOTE(review): `predict_output_path` is ./data/output/retrieve-outcome.json,
# while the classifier test dataset reads
# ./data/output/retrieval-claim-prediction.json — confirm which file is the
# intended hand-off between the two stages.
retrieve_output = biencoder.retrieve(
    biencoder_test_dataset,
    embed_evid_data=embed_evid_data,
    k=5,
    batch_size=128,
    predict_output_path="./data/output/retrieve-outcome.json",
)

## Classifier

### Load Data

In [None]:
# Classifier training split: labelled claims plus the evidence corpus, with
# max_padding_length=512 (the retriever datasets pass 128).
classifier_train_dataset = BertClassifierDataset(
    claim_file_path="./data/train-claims.json",
    evidence_file_path="./data/evidence.json",
    data_type="train",
    tokenizer=bert_tokenizer,
    lower_case=False,
    max_padding_length=512,
    rand_seed=1,
)

# Classifier prediction split: unlabelled test claims, with the evidence for
# each claim taken from `predict_evidence_file_path`.
# NOTE(review): that path (retrieval-claim-prediction.json) differs from the
# `predict_output_path` given to `biencoder.retrieve`
# (retrieve-outcome.json) — confirm the file is produced elsewhere or align
# the two paths.
classifier_test_dataset = BertClassifierDataset(
    claim_file_path="./data/test-claims-unlabelled.json",
    evidence_file_path="./data/evidence.json",
    predict_evidence_file_path="./data/output/retrieval-claim-prediction.json",
    data_type="predict",
    tokenizer=bert_tokenizer,
    lower_case=False,
    max_padding_length=512,
    rand_seed=1,
)

### Create Classifier

In [None]:
# Instantiate the classifier with its default constructor arguments; the same
# object is handed to the trainer and later used for prediction.
bert_classifier = BertClassifier()

### Train Classifier

In [None]:
# Trainer wrapping the classifier, configured with batch_size=4.
bert_classifier_trainer = BertClassifierTrainer(bert_classifier, batch_size=4)

In [None]:
# Train the classifier on the training split and keep whatever `train` returns.
# NOTE(review): the variable is spelled `ber_...` (likely meant `bert_...`);
# the name is kept unchanged because other cells may refer to it.
ber_classifier_history = bert_classifier_trainer.train(
    train_dataset=classifier_train_dataset,
    shuffle=True,
    max_epoch=10,
    loss_func_type="cross_entropy",
    optimizer_type="adam",
    learning_rate=1e-4,
)

### Predict

In [None]:
# Run the trained classifier over the prediction split (batch_size=32);
# `output_file_path` names the JSON file passed to `predict` for its results.
classification_output = bert_classifier.predict(
    classifier_test_dataset,
    batch_size=32,
    output_file_path="./data/output/class-prediction.json",
)