# **SETUP**

## Import

In [1]:
import os
import numpy as np
from openrec.tf1.legacy import ImplicitModelTrainer
from openrec.tf1.legacy.utils import ImplicitDataset
from openrec.tf1.legacy.recommenders import CML, BPR
from openrec.tf1.legacy.utils.evaluators import AUC
from openrec.tf1.legacy.utils.samplers import PairwiseSampler


## Initialize random seed and data folder

In [2]:
# Seed NumPy's legacy global RNG so NumPy-based randomness in this notebook is reproducible.
# NOTE(review): TensorFlow and the sampler's worker processes may keep their own RNG
# state; confirm whether they also need explicit seeding for fully reproducible runs.
seed = 76424236
np.random.seed(seed=seed)

# Folder holding the input .npy splits; later cells also write evaluation output
# and the saved model under this prefix.
folder_name = "./Dataset/"

# exist_ok=True makes this cell idempotent on re-run and avoids the
# exists()-then-makedirs() race of a separate existence check.
os.makedirs(folder_name, exist_ok=True)

## Load training and validation datasets

In [3]:
# Raw interaction arrays for the two splits plus the sizes of the user/item id spaces.
# NOTE(review): 15401 / 1001 look like "count + 1" sizes (e.g. 15400 users, 1000 items)
# so that ids can be used directly as indices; confirm against how the .npy files were built.
raw_data = {
    'train_data': np.load(folder_name + "training_arr.npy"),
    'val_data': np.load(folder_name + "validation_arr.npy"),
    'max_user': 15401,
    'max_item': 1001,
}

# Training batch size, evaluation batch size, and the interval later passed to
# model_trainer.train(display_itr=...).
batch_size, test_batch_size, display_itr = 8000, 1000, 1000

# Wrap each split in an ImplicitDataset sharing the same user/item id-space sizes.
train_dataset = ImplicitDataset(raw_data['train_data'],
                                raw_data['max_user'],
                                raw_data['max_item'],
                                name='Train')
val_dataset = ImplicitDataset(raw_data['val_data'],
                              raw_data['max_user'],
                              raw_data['max_item'],
                              name='Val')

# **TRAIN MODEL**

## Define model

In [None]:
# Reset TensorFlow's default graph so that re-running this cell builds the model in a
# fresh graph instead of adding a second set of model variables to the previous one.
import tensorflow as tf
tf.compat.v1.reset_default_graph()

# Hyperparameters, named so they can be inspected and tuned in one place.
EMBED_DIM = 50             # passed to CML as dim_embed
L2_REG = 0.001             # passed to CML as l2_reg
NUM_SAMPLER_PROCESSES = 4  # passed to PairwiseSampler as num_process
ITEM_SERVING_SIZE = 500    # passed to ImplicitModelTrainer as item_serving_size

cml_model = CML(batch_size=batch_size, max_user=train_dataset.max_user(), max_item=train_dataset.max_item(),
                dim_embed=EMBED_DIM, l2_reg=L2_REG, opt='Adam', sess_config=None)

# Draws training batches from train_dataset for the pairwise model.
sampler = PairwiseSampler(batch_size=batch_size, dataset=train_dataset, num_process=NUM_SAMPLER_PROCESSES)

# Evaluation output is written under the data folder with the "yahoo" prefix.
model_trainer = ImplicitModelTrainer(batch_size=batch_size, test_batch_size=test_batch_size,
                                     train_dataset=train_dataset, model=cml_model, sampler=sampler,
                                     eval_save_prefix=folder_name + "yahoo",
                                     item_serving_size=ITEM_SERVING_SIZE)
auc_evaluator = AUC()

## Train model

In [None]:
# Train the CML model and evaluate AUC on the validation split.
# NOTE(review): num_itr=10001 presumably makes iteration 10000 (a multiple of display_itr)
# the last one run; num_negatives presumably controls negatives sampled per user during
# evaluation. Confirm both against openrec's ImplicitModelTrainer.train.
model_trainer.train(
    num_itr=10001,
    display_itr=display_itr,
    eval_datasets=[val_dataset],
    evaluators=[auc_evaluator],
    num_negatives=200,
)

# **SAVE MODEL**

In [None]:
# Persist the trained model under the data folder as "cml-yahoo".
# NOTE(review): the second positional argument is passed as None; it is likely a
# global-step value; confirm against openrec's save() signature.
cml_model.save(f"{folder_name}cml-yahoo", None)