In [1]:
import torch
assert torch.__version__>='1.2.0', 'Expect PyTorch>=1.2.0 but get {}'.format(torch.__version__)
from torch import nn
import torch.nn.functional as F

import numpy as np

import time
import os
import sys
import pickle

# Project directories used by the rest of the notebook. Both are pushed onto
# the module search path at position 1; because each insert lands at the same
# position, data_dir ends up ahead of imp_dir.
imp_dir, data_dir = '../Implementations', '../Data/criteo'
for _extra_path in (imp_dir, data_dir):
    sys.path.insert(1, _extra_path)

import logging
import importlib
# Re-execute the logging module before configuring it, then start from a
# fresh log file on every run of this cell.
importlib.reload(logging)

log_path = 'DCN_notebook.log'
if os.path.isfile(log_path):
    os.remove(log_path)

# One shared format for both destinations: time of day, level, message.
formatter = logging.Formatter(fmt='%(asctime)s %(levelname)-s: %(message)s', datefmt='%H:%M:%S')

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# INFO-and-above records go to the log file and to the notebook's stdout.
fh = logging.FileHandler(log_path)
sh = logging.StreamHandler(sys.stdout)
for _handler in (fh, sh):
    _handler.setLevel(logging.INFO)
    _handler.setFormatter(formatter)
    logger.addHandler(_handler)

In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info('Device in Use: {}'.format(DEVICE))

if DEVICE.type == 'cuda':
    torch.cuda.empty_cache()
    # torch.cuda.memory_cached is deprecated in favour of memory_reserved;
    # fall back to the old name only on PyTorch builds that lack the new one.
    _reserved_bytes = getattr(torch.cuda, 'memory_reserved', None) or torch.cuda.memory_cached
    # Byte counts are divided by 1024**3 before being reported.
    t = torch.cuda.get_device_properties(DEVICE).total_memory/1024**3
    c = _reserved_bytes(DEVICE)/1024**3
    a = torch.cuda.memory_allocated(DEVICE)/1024**3
    logger.info('CUDA Memory: Total {:.2f} GB, Cached {:.2f} GB, Allocated {:.2f} GB'.format(t,c,a))
else:
    # The CUDA queries above raise for a CPU device, so skip the report here.
    logger.info('CUDA is not available; skipping GPU memory report')

20:32:54 INFO: Device in Use: cuda
20:32:54 INFO: CUDA Memory: Total 11.17 GB, Cached 0.00 GB, Allocated 0.00 GB


## Prepare Data

### Load Dict

In [3]:
# Location of the pickled categorical feature map inside the data directory.
embedding_map_dict_pkl_path = os.path.join(
    data_dir,
    'criteo_feature_dict_artifact/categorical_feature_map_dict.pkl',
)

# NOTE(review): pickle.load can execute arbitrary code; only point this at an
# artifact produced by a trusted preprocessing run.
with open(embedding_map_dict_pkl_path, mode='rb') as pkl_file:
    embedding_map_dict = pickle.load(pkl_file)

### List all available files

In [4]:
def _select_artifacts(filenames, kind):
    """Return the names in ``filenames`` of the given ``kind``, ordered by shard number.

    A name matches when its second ``-``-separated token equals ``kind``.
    Matches are sorted by the integer found in the last ``-``-separated token
    of the text before the first ``.``. Names that contain no ``-`` at all
    (for example hidden files) are skipped instead of raising IndexError.
    """
    def _kind_of(name):
        tokens = name.split('-')
        return tokens[1] if len(tokens) > 1 else None

    def _shard_number(name):
        return int(name.split('.')[0].split('-')[-1])

    return sorted((name for name in filenames if _kind_of(name) == kind), key=_shard_number)


np_artifact_dir = os.path.join(data_dir, 'criteo_train_numpy_artifact')
# List the directory once and split the listing into the three artifact kinds.
_artifact_files = os.listdir(np_artifact_dir)
index_artifact = _select_artifacts(_artifact_files, 'index')
value_artifact = _select_artifacts(_artifact_files, 'value')
label_artifact = _select_artifacts(_artifact_files, 'label')

In [5]:
def _stack_shards(shard_names):
    """Load each named file from ``np_artifact_dir`` with ``np.load`` and combine them with ``np.vstack``."""
    return np.vstack([np.load(os.path.join(np_artifact_dir, name)) for name in shard_names])


# The first 10 shards of each artifact kind form the training set; all
# remaining shards form the test set. Each set is an (index, value, label) tuple.
_artifact_groups = (index_artifact, value_artifact, label_artifact)

start = time.time()
train_data = tuple(_stack_shards(group[:10]) for group in _artifact_groups)
logger.info('Training data loaded after {:.2f}s'.format(time.time()-start))

start = time.time()
test_data = tuple(_stack_shards(group[10:]) for group in _artifact_groups)
logger.info('Test data loaded after {:.2f}s'.format(time.time()-start))

21:33:32 INFO: Training data loaded after 3634.80s
21:33:41 INFO: Test data loaded after 8.91s


## Create Model

In [6]:
import execution
import DCN_BinClf_Torch

<module 'DCN_BinClf_Torch' from '../Implementations/DCN_BinClf_Torch.py'>

In [7]:
# Build the DCN model and move it to the selected device. Arguments are passed
# positionally; their meanings are defined by DCN_BinClf_Torch.DCN_Layer.
# NOTE(review): 26 and 13 presumably match Criteo's categorical / numeric
# feature counts, and the two lists look like per-layer widths and dropout
# rates (2 vs 3 entries) — confirm against the DCN_Layer signature.
DCN = DCN_BinClf_Torch.DCN_Layer(
    len(embedding_map_dict) + 60,
    20,
    26,
    13,
    [1024] * 2,
    [0.5] * 3,
    6,
).to(DEVICE)

## Sample Run

In [8]:
# Checkpoint settings: a file-name prefix and a 'DCN_artifact' directory
# located under the current working directory.
checkpoint_prefix = 'DCN'
cwd = os.getcwd()
checkpoint_dir = os.path.join(cwd, 'DCN_artifact')

In [None]:
# Loss: binary cross-entropy computed on logits. Optimizer: Adam over all
# model parameters with its default settings.
criterion = F.binary_cross_entropy_with_logits
optimizer = torch.optim.Adam(DCN.parameters())

# Hand the model, both data tuples, and the checkpoint settings to the
# project's training routine.
execution.train_model_separate_inp(
    DCN,
    train_data,
    test_data,
    criterion,
    optimizer,
    DEVICE,
    checkpoint_dir,
    checkpoint_prefix,
    logger=logger,
)