## Model training
Notebook to test model training

### Imports

In [8]:
import glob
from pathlib import Path

import tensorflow as tf
from transformers import TFDistilBertModel
from transformers import create_optimizer

from reddit.utils import (load_tfrecord, pad_and_stack,
                          split_dataset)
from reddit.models import BatchTransformer
from reddit.losses import TripletLossBase
from reddit.training import Trainer

In [9]:
# Directory handed to the Trainer as its `log_path`; created up front
# (including missing parents) so later writes do not fail on a missing folder.
METRICS_PATH = Path('..', 'logs', 'sample_output')
METRICS_PATH.mkdir(exist_ok=True, parents=True)

### Strategy

In [None]:
# Ask TensorFlow which physical GPUs this process can see; the list is
# reused below when restricting the visible devices.
gpus = tf.config.list_physical_devices(device_type='GPU')
print("Num GPUs Available: ", len(gpus))

In [None]:
  # Restrict TensorFlow to (at most) the first two physical GPUs.
  # `set_visible_devices` raises RuntimeError when the TensorFlow runtime has
  # already been initialized (e.g. the cell is re-run in a live kernel); in
  # that case the current device visibility is kept and the error is reported.
  try:
    tf.config.set_visible_devices(gpus[:2], 'GPU')
  except RuntimeError as e:
    print(e)
  # Query the logical devices outside the `try` so that `logical_gpus` is
  # always defined for the strategy cell, even when the call above failed.
  logical_gpus = tf.config.list_logical_devices('GPU')
  print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPU")

In [None]:
# `MirroredStrategy(devices=...)` is documented to take device *strings*
# (e.g. "/gpu:0"), so pass each logical device's `.name` rather than the
# `LogicalDevice` objects themselves.
device_names = [gpu.name for gpu in logical_gpus]
# With no logical GPU available, pass `None` so TensorFlow picks its default
# device set instead of receiving an empty device list.
strategy = tf.distribute.MirroredStrategy(devices=device_names or None)

### Dataset
Load dataset, pad to desired length, batch and distribute

In [2]:
# Dataset shaping parameters: padding targets for the anchor / positive /
# negative slots of each triplet example, plus the batch size.
ds_params = dict(
    n_anchor=20,
    n_pos=1,
    n_neg=1,
    batch_size=8,
)

In [4]:
# Collect the triplet TFRecord files. `glob.glob` returns paths in arbitrary
# (filesystem) order, so sort them to make the dataset order reproducible
# across runs and machines.
fs = sorted(glob.glob('../reddit/data/datasets/triplet/*'))
if not fs:
    # Fail early with a clear message instead of building an empty dataset.
    raise FileNotFoundError(
        "No dataset files found under '../reddit/data/datasets/triplet/'")

# Load the records, pad every example to the [anchor, positive, negative]
# sizes configured in `ds_params`, then batch.
ds = load_tfrecord(fs)
ds = pad_and_stack(ds, pad_to=[ds_params['n_anchor'],
                               ds_params['n_pos'],
                               ds_params['n_neg']])
ds = ds.batch(ds_params['batch_size'])

In [6]:
# Carve train / validation / test subsets out of the batched dataset.
# NOTE(review): presumably `size` is the total number of dataset elements and
# the `perc_*` values are fractions of it -- confirm against
# `reddit.utils.split_dataset`.
ds_train, ds_val, ds_test = split_dataset(
    ds,
    size=1000,
    perc_train=0.1,
    perc_val=0.01,
    perc_test=0.02,
)

In [None]:
# Hand the train and test splits to the distribution strategy (train first,
# then test) so they can be consumed by the distributed training loop.
ds_train_distributed, ds_test_distributed = map(
    strategy.experimental_distribute_dataset,
    (ds_train, ds_test),
)

### Initialize training parameters

In [None]:
# Every tunable of the training run, gathered in one place and consumed by
# the optimizer / model / loss / trainer cells below.
# Note: 'optimizer_n_train_steps' currently equals
# n_epochs * steps_per_epoch (1 * 100).
train_params = dict(
    # Pretrained checkpoint name and the transformer class used to load it.
    weights='distilbert-base-uncased',
    model=TFDistilBertModel,
    # Optimizer settings.
    optimizer_learning_rate=2e-5,
    optimizer_n_train_steps=100,
    optimizer_n_warmup_steps=10,
    # Margin passed to the triplet loss.
    loss_margin=1,
    # Length of the training loop.
    n_epochs=1,
    steps_per_epoch=100,
    # Names of the variables passed to the Trainer as `train_vars` / `test_vars`.
    train_vars=['losses', 'metrics',
                'dist_pos', 'dist_neg',
                'dist_anchor'],
    test_vars=['test_losses', 'test_metrics',
               'test_dist_pos', 'test_dist_neg',
               'test_dist_anchor'],
    log_every=50,
)

### Initialize optimizer, model, loss, and trainer object

In [None]:
# Optimizer with warmup, configured from `train_params`.
optimizer = create_optimizer(train_params['optimizer_learning_rate'],
                             num_train_steps=train_params['optimizer_n_train_steps'],
                             num_warmup_steps=train_params['optimizer_n_warmup_steps'])
# Recent `transformers` releases return an `(optimizer, lr_schedule)` tuple
# from `create_optimizer`, older ones return only the optimizer. Keep just
# the optimizer so the Trainer receives the same object either way.
if isinstance(optimizer, tuple):
    optimizer = optimizer[0]

# Model wrapper around the configured transformer class and checkpoint.
model = BatchTransformer(train_params['model'],
                         train_params['weights'])

# Triplet loss; the positive / negative counts must match the padding used
# when the dataset was built (`ds_params`).
loss = TripletLossBase(train_params['loss_margin'],
                       n_pos=ds_params['n_pos'],
                       n_neg=ds_params['n_neg'])

In [None]:
# Assemble the Trainer from the model, loss and optimizer built above,
# running under the distribution strategy and logging to METRICS_PATH.
trainer = Trainer(model,
                  loss,
                  optimizer,
                  strategy=strategy,
                  # The key defined in `train_params` is 'n_epochs'
                  # ('n_epoch' raised a KeyError).
                  n_epochs=train_params['n_epochs'],
                  steps_per_epoch=train_params['steps_per_epoch'],
                  log_every=train_params['log_every'],
                  train_vars=train_params['train_vars'],
                  test_vars=train_params['test_vars'],
                  log_path=str(METRICS_PATH),
                  checkpoint_device=None,
                  distributed=True)

### Train!

In [None]:
# Run the training loop on the distributed train split, using the
# distributed test split as the `dataset_test` argument.
trainer.train(
    dataset_train=ds_train_distributed,
    dataset_test=ds_test_distributed,
)