In [None]:
from collections import defaultdict
from dataclasses import dataclass
import os
import random
import time
from typing import Callable, Dict, List, Generator, Tuple
from data_pre_process import *
from model import *
from data_loader import *
from validation import *
import gc

import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm

import torch
from torch import nn, optim
import torch.cuda.amp 
from pathlib import Path
from torch.cuda.amp import GradScaler as scaler

from torch.utils.data import Dataset, Subset, DataLoader

from transformers import BertTokenizer, AdamW, BertModel, get_linear_schedule_with_warmup, BertPreTrainedModel

In [None]:
init_start_time = time.time()

bert_model = 'bert-base-uncased'
do_lower_case = 'uncased' in bert_model
# Fall back to CPU when no GPU is visible so the notebook still runs (slowly) instead of
# failing later at model.to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

data_dir_t = Path('data_2/v1.0/train')
data_path_t = data_dir_t / 'nq-train-02.jsonl.gz'

data_dir_v = Path('data_2/v1.0/dev')
data_path_v = data_dir_v / 'nq-dev-00.jsonl.gz'

# Alternative data layout (presumably the single-file v1.0 train/dev dumps):
# data_dir_t = Path('data')
# data_path_t = data_dir_t/'v1.0_train.jsonl.gz'
# data_dir_v = Path('data')
# data_path_v = data_dir_v/'v1.0_dev.jsonl.gz'

In [None]:
"""
Hyperparameters to convert a specific examples into multiple examples and the function that does it. 
Chunksize is the number of examples. Max sequence length is the size of a specific example will be broken down, and
the overall content will be broken down in strides of 128.
"""

chunksize = 1000
max_seq_len = 384
max_question_len = 64
doc_stride = 128
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case='uncased' in 'bert-base-uncased')

convert_func = functools.partial(convert_data,
                                 tokenizer=tokenizer,
                                 max_seq_len=max_seq_len,
                                 max_question_len=max_question_len,
                                 doc_stride=doc_stride,
                                 val=True) #this is to use document_html

In [None]:
"""
Opens a specific file and loads in all the examples which are per line, and calls jsonlreader which is an iterable 
used later on to train the model
"""

def open_file(data_path_t):
    start = time.time()
    with gzip.open(data_path_t, "rb") as f:
        data = f.read()
    x = data.splitlines()
    data_reader = JsonlReader(x, convert_func, chunksize=chunksize)
    end = time.time()
    train_size = len(x)
    print("Loading Data:", end - start, "seconds")
    return data_reader, train_size

In [None]:
"""
Hyperparameters that will be training my model on. Was only able to have a batchsize of 16 because low vram
"""

num_labels = 5
n_epochs = 1
lr = 2e-5
warmup = 0.05
batch_size = 16
accumulation_steps = 4

In [None]:
"""
Initilization of model, paramters, optimizer and schedular
"""
desired_data_train_files = 6 #number of files you want to train on

model = BertForQuestionAnswering.from_pretrained(bert_model, num_labels=5)
model = model.to(device)

average_file_size = 6000
train_size = average_file_size*desired_data_train_files
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]
train_optimization_steps = int(n_epochs * train_size / batch_size / accumulation_steps)
warmup_steps = int(train_optimization_steps * warmup)

optimizer = AdamW(optimizer_grouped_parameters, lr=lr, correct_bias=False)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=train_optimization_steps)

model.zero_grad()
model = model.train()

In [None]:
"""
Training Loop that goes through multiple directories of files provided, then calls open_file, then iterates through
each chunksize and loads the data which is broken to batches. The files loop through randomly, and the x_batch is a
random batch of examples that were subsets of the same larger example.
"""

def train_per_file(data_reader, train_size):
    running_loss = 0.0
    global_step = 0
    for examples in tqdm(data_reader, total=int(np.floor(train_size/chunksize))):
        examples = list(filter(lambda example: len(example) != 0, examples))
        train_dataset = TextDataset(examples)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
        for x_batch, y_batch in train_loader:
            x_batch, attention_mask, token_type_ids = x_batch
            y_batch = (y.to(device) for y in y_batch)

            y_pred = model(x_batch.to(device),
                           attention_mask=attention_mask.to(device),
                           token_type_ids=token_type_ids.to(device))

            loss = loss_fn(y_pred, y_batch)
            loss.backward()
            running_loss += loss.item()
            
            if (global_step + 1) % accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                model.zero_grad()
            global_step += 1

    torch.save(model.state_dict(), 'bert_pytorch_all_t.bin')
    torch.save(optimizer.state_dict(), 'bert_pytorch_optimizer_all_t.bin')
    
    del examples, train_dataset, train_loader
    return running_loss/train_size

In [None]:
# Train on up to `desired_data_train_files` non-empty files from the training directory and
# record the value train_per_file returns for each one.
num_data_train_files = 0
loss_per_file = []
data_reader = None
# Sorted so the same subset of files is used on every run; a separate loop variable keeps
# `data_dir_t` pointing at the directory, so this cell can be re-run.
for train_file in sorted(data_dir_t.iterdir()):
    if not train_file.is_file():
        continue  # skip subdirectories
    data_reader, train_size = open_file(train_file)
    if train_size == 0:
        continue  # empty file: nothing to train on
    num_data_train_files += 1
    loss = train_per_file(data_reader, train_size)
    loss_per_file.append(loss)
    if num_data_train_files >= desired_data_train_files:
        break

print(loss_per_file)
# Drop the last file's in-memory lines before evaluation.
data_reader = None
gc.collect()

In [None]:
"""
Plotting Loss Graph
"""
import matplotlib.pyplot as plt

if len(loss_per_file) == 0:
    loss_per_file = list(range(desired_data_train_files)) #so it will continue to next cell;

plt.plot(list(range(desired_data_train_files)), loss_per_file)
plt.xlabel("Number of Train Files")
plt.ylabel("Running Loss")
plt.title("Loss vs File")
plt.savefig("Loss.png", dpi=300)
plt.show()

In [None]:
"""
Loads a previous model if needed.
"""
load = False
if load:
    model.load_state_dict(torch.load("bert_pytorch.bin"))
    optimizer.load_state_dict(torch.load("bert_pytorch_optimizer.bin"))

In [None]:
data_reader_v, val_size = open_file(data_path_t=data_path_v)  # dev-set reader and its line count

In [None]:
import itertools  # explicit: otherwise only available via the star imports

# Quick validation pass: score at most `max_valid_features` features from the first chunk
# produced by the dev reader.
eval_start_time = time.time()

max_valid_features = 15000

# data_reader_v was built in open_file with the convert_func that existed then, so re-binding
# an identical convert_func here would have no effect; it is no longer redefined.
# NOTE(review): next() advances data_reader_v (if it is its own iterator), so re-running this
# cell without re-running the open_file cell evaluates a different chunk.
valid_data = next(data_reader_v)
# Flatten one level: each chunk item is presumably one example's list of features.
valid_data = list(itertools.chain.from_iterable(valid_data))
# Clamp the cap so a short chunk cannot raise IndexError inside the loader.
valid_dataset = Subset(valid_data, range(min(max_valid_features, len(valid_data))))
# Fixed order keeps the validation run deterministic.
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, collate_fn=eval_collate_fn)
valid_scores = eval_model(model, valid_loader, device=device)
print(f'calculate validation score done in {(time.time() - eval_start_time) / 60:.1f} minutes.')

In [None]:
# Unpack the triples as (score, precision, recall), following the names used in this notebook.
long_score, long_p, long_recall = valid_scores['long_score']
short_score, short_p, short_recall = valid_scores['short_score']
overall_score = valid_scores['overall_score']
print('validation scores:')
print(f'\tlong score      : {long_score:.4f}')
print(f'\tlong precision  : {long_p:.4f}')
print(f'\tlong recall     : {long_recall:.4f}')
print(f'\tshort score     : {short_score:.4f}')
print(f'\tshort precision : {short_p:.4f}')
print(f'\tshort recall    : {short_recall:.4f}')
print(f'\toverall score   : {overall_score:.4f}')
print(f'all process done in {(time.time() - init_start_time) / 3600:.1f} hours.')

In [None]:
# Full pass: score every feature of the loaded validation chunk, in a fixed order.
valid_dataset = Subset(valid_data, range(len(valid_data)))
valid_loader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=eval_collate_fn,
)
valid_scores = eval_model(model, valid_loader, device=device)

In [None]:
# Unpack the triples as (score, precision, recall), following the names used in this notebook.
long_score, long_p, long_recall = valid_scores['long_score']
short_score, short_p, short_recall = valid_scores['short_score']
overall_score = valid_scores['overall_score']
print('validation scores:')
print(f'\tlong score      : {long_score:.4f}')
print(f'\tlong precision  : {long_p:.4f}')
print(f'\tlong recall     : {long_recall:.4f}')
print(f'\tshort score     : {short_score:.4f}')
print(f'\tshort precision : {short_p:.4f}')
print(f'\tshort recall    : {short_recall:.4f}')
print(f'\toverall score   : {overall_score:.4f}')
print(f'all process done in {(time.time() - init_start_time) / 3600:.1f} hours.')