# 1.  PyTorch implementation of BERT

 * pytorch-pretrained-BERT
 - [Link](https://github.com/huggingface/pytorch-pretrained-BERT)

In [None]:
# BERT configuration: a 32000-entry vocabulary, 12 hidden layers with
# 12 attention heads each, 768 hidden units and a 3072-unit intermediate layer.
from pytorch_pretrained_bert import BertConfig

config = BertConfig(
    vocab_size_or_config_json_file=32000,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
)

In [None]:
# BERT encoder built directly from `config` (not through `from_pretrained`).
from pytorch_pretrained_bert import BertModel

bert = BertModel(config=config)

# 2. Extending BERT for SQuAD task

 * See the class `pytorch_pretrained_bert.modeling.BertForQuestionAnswering`

### 2.1 Define a linear layer

In [None]:
# Define a new fully connected layer with two outputs and try it on a toy batch
import torch
import torch.nn as nn

# Toy inputs: a batch of two sequences with three positions each.
input_ids = torch.tensor([[31, 51, 99], [15, 5, 0]], dtype=torch.long)
input_mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.long)
token_type_ids = torch.tensor([[0, 0, 1], [0, 1, 0]], dtype=torch.long)

# Linear head that maps every hidden vector to two values.
qa_outputs = nn.Linear(in_features=config.hidden_size, out_features=2)

# Request only the final encoder layer; the second return value is not needed.
last_encoding_layer, _ = bert(
    input_ids,
    token_type_ids=token_type_ids,
    attention_mask=input_mask,
    output_all_encoded_layers=False,
)
logits = qa_outputs(last_encoding_layer)

# Show the encoder output and the head output together with their shapes.
for shown in (last_encoding_layer,
              last_encoding_layer.size(),
              logits,
              logits.size()):
    print(shown)

In [None]:
# Split the last dimension of `logits` into its two columns and drop the
# resulting size-1 axis from each: one becomes the start logits, the other
# the end logits.
start_logits, end_logits = (
    column.squeeze(-1) for column in logits.split(1, dim=-1)
)
print(start_logits, '\n', end_logits)

### 2.2 Defining loss

In [None]:
# Loss definition

# Target start / end indices for the two sequences of the toy batch.
start_positions = torch.tensor([1, 1], dtype=torch.long)
end_positions = torch.tensor([2, 2], dtype=torch.long)

# Cross-entropy of the start and end logits against their target indices.
start_loss = nn.functional.cross_entropy(input=start_logits,
                                         target=start_positions)
end_loss = nn.functional.cross_entropy(input=end_logits,
                                       target=end_positions)

# The total loss is the average of the start loss and the end loss.
total_loss = (start_loss + end_loss) / 2

for value in (start_loss, end_loss, total_loss):
    print(value)

# 3. Fine-Tuning BERT for SQuAD

### 3.1 Make DataLoader for training

In [None]:
# Training hyper-parameters.
# Total batch size for training
train_batch_size = 6
num_train_epochs = 1.0
# Fine-tuning a pretrained BERT needs a small step size; 0.1 is several
# orders of magnitude too large and makes training diverge. 5e-5 matches the
# rate the optimizer in this notebook is constructed with.
learning_rate = 5e-5
seed = 42

# Seed the default (CPU) generator as well as every CUDA device: the random
# sampler that shuffles the training data draws from the CPU generator, so
# seeding CUDA alone does not make the run reproducible.
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
    



In [None]:
import pickle
# Not referenced by name below; presumably needed so that pickle can resolve
# the class of the cached objects - confirm against squad_example.
from squad_example import InputFeatures

# Counter of optimizer updates, advanced by the training loop.
global_step = 0
cached_train_features_file = (
    'data/tranin-v1.1_sample.json_bert-base-multilingual-cased_128_64_64'
)

# Load the cached features.
# NOTE: unpickling can execute arbitrary code; only load cache files that
# come from a trusted source.
with open(cached_train_features_file, "rb") as reader:
    train_features = pickle.load(reader)
print("total number of input features loaded : ", len(train_features))


def _long_column(attribute):
    """Gather `attribute` from every loaded feature into one int64 tensor."""
    return torch.tensor(
        [getattr(feature, attribute) for feature in train_features],
        dtype=torch.long,
    )


# Input tensors, one row per cached feature.
all_input_ids = _long_column("input_ids")
all_input_mask = _long_column("input_mask")
all_segment_ids = _long_column("segment_ids")
all_start_positions = _long_column("start_position")
all_end_positions = _long_column("end_position")


In [None]:
from torch.utils.data import (
    DataLoader,
    RandomSampler,
    TensorDataset,
)

# Bundle the five per-feature tensors so that one index yields one example:
# (input ids, input mask, segment ids, start position, end position).
train_data = TensorDataset(
    all_input_ids,
    all_input_mask,
    all_segment_ids,
    all_start_positions,
    all_end_positions,
)

# Visit the examples in random order, `train_batch_size` at a time.
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(
    train_data,
    batch_size=train_batch_size,
    sampler=train_sampler,
)




In [None]:
from squad_example import InputFeatures

# The maximum total input sequence length after WordPiece tokenization
max_seq_length = 512
# When splitting up a long document into chunks, how much stride to take
# between chunks
doc_stride = 128
# The maximum number of tokens for the question.
max_query_length = 64
# NOTE(review): the cached feature file loaded earlier has a name ending in
# `_128_64_64`, which may encode different values for these three settings -
# confirm which ones the cache was actually built with.

output_dir = '/tmp/squad'

# Total number of batches the training loop iterates over: the loop runs
# `int(num_train_epochs)` epochs over `train_dataloader`, and the loader also
# yields the final, partially filled batch. Flooring
# `len(train_features) / train_batch_size` drops that last batch, so the
# schedule length handed to the optimizer came out one step short whenever
# the feature count is not a multiple of the batch size.
num_train_steps = len(train_dataloader) * int(num_train_epochs)

print("Number of train steps : ", num_train_steps)

In [None]:
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering

# Name of the pretrained checkpoint to fine-tune.
bert_model = 'bert-base-multilingual-cased'

# Load the question-answering model, caching the download in a sub-directory
# of the library's default cache location.
_cache_dir = PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(-1)
model = BertForQuestionAnswering.from_pretrained(bert_model,
                                                 cache_dir=_cache_dir)

# Move the weights to the GPU, then wrap the model in DataParallel.
model.to("cuda")
model = torch.nn.DataParallel(model)

# Prepare the (name, parameter) pairs for the optimizer.
# Hack: leave out the "pooler" parameters, which are not used.
param_optimizer = [
    (name, param)
    for name, param in model.named_parameters()
    if 'pooler' not in name
]

In [None]:
from pytorch_pretrained_bert.optimization import BertAdam

# Parameters whose name contains one of these fragments (biases and
# LayerNorm parameters) are excluded from weight decay.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    # Every other parameter is regularised with weight decay.
    {'params': [p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    # The key has to be spelled exactly 'weight_decay': a misspelled key is
    # stored but never read, so this group would silently fall back to the
    # optimizer's default decay instead of 0.0.
    {'params': [p for n, p in param_optimizer
                if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},
]

# BertAdam with a linear warm-up over the first 10% of `num_train_steps`.
optimizer = BertAdam(optimizer_grouped_parameters,
                     lr=5e-5,
                     warmup=0.1,
                     t_total=num_train_steps)


### 3.2 Train


In [None]:
import os

from tqdm import tqdm, trange

# Switch the model to training mode.
model.train()

for _ in trange(int(num_train_epochs), desc="Epoch"):
    for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):

        # Move every tensor of the batch to the GPU.
        batch = tuple(t.to("cuda") for t in batch)

        # Same field order as the TensorDataset the loader was built from.
        (input_ids,
         input_mask,
         segment_ids,
         start_positions,
         end_positions) = batch

        # The model is called with the segment ids before the attention
        # mask; given the start/end positions it returns the training loss.
        loss = model(input_ids,
                     segment_ids,
                     input_mask,
                     start_positions,
                     end_positions)

        # Under DataParallel the loss may come back as one value per GPU;
        # reduce it to a scalar before back-propagating.
        loss = loss.mean()

        loss.backward()

        # Exactly one optimizer update per batch. These calls used to sit
        # inside a loop over `optimizer.param_groups`, so they ran once per
        # parameter group: the repeated update saw already-zeroed gradients
        # and `global_step` advanced several times per batch. That loop also
        # overwrote every group's `lr` with a constant, discarding the rate
        # the optimizer was constructed with; BertAdam applies its warm-up
        # schedule itself, so no manual override is needed.
        optimizer.step()
        optimizer.zero_grad()
        global_step += 1

# Unwrap DataParallel (if present) so the saved keys carry no "module." prefix.
model_to_save = model.module if hasattr(model, 'module') else model
# Make sure the target directory exists before writing the checkpoint.
os.makedirs(output_dir, exist_ok=True)
output_model_file = os.path.join(output_dir, "pytorch_model.bin")
torch.save(model_to_save.state_dict(), output_model_file)