In [1]:
# edited by Dongyu Zhang
from os import makedirs
from os.path import join, basename
import logging
import numpy as np
import torch
import random
from args import define_new_main_parser
import json

from transformers import Trainer, TrainingArguments, EarlyStoppingCallback

from dataset.ibm import IBMDataset
from dataset.ibm_time_static import IBMWithTimePosAndStaticSplitDataset
from dataset.ibm_time_pos import IBMWithTimePosDataset
from dataset.ibm_static import IBMWithStaticSplitDataset
from models.modules import TabFormerBertLM, TabFormerBertForClassification, TabFormerBertModel, TabStaticFormerBert, \
    TabStaticFormerBertLM, TabStaticFormerBertClassification
from misc.utils import ordered_split_dataset, compute_cls_metrics
from dataset.datacollator import *
from main_ibm import main

# Console logging: INFO-level records with a timestamp, level and logger name.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
)

# `logger` and `log` are two names bound to the same notebook-level logger.
log = logger = logging.getLogger(__name__)

import os
# NOTE(review): presumably read by the Weights & Biases integration to turn
# reporting off — confirm against the installed transformers version.
os.environ["WANDB_DISABLED"] = "true"

In [7]:
# ---------------------------------------------------------------------------
# Run configuration. Every name below is interpolated into the command-line
# string built in the next cell, so keep them as plain top-level assignments.
# ---------------------------------------------------------------------------

# Dataset selection
data = "credit_card"                  # sub-directory of ./data/ (data root and vocab path)
dt = "IBM"                            # forwarded as --data_type
time_pos_type = "regular_position"    # forwarded as --time_pos_type
fextension = False                    # falsy: no --fextension flag, plain "vocab_ob" vocab path

# Input file names
fname = "card_transaction_train"      # --data_fname
val_fname = "card_transaction_val"    # --data_val_fname
test_fname = "card_transaction_test"  # --data_test_fname
preload_fextension = "preload-test"   # --preload_fextension

# Batching and checkpoint cadence
bs = 32                               # used for both --train_batch_size and --eval_batch_size
nb = 10                               # --nbatches
save_steps = 1000                     # --save_steps
eval_steps = 1000                     # --eval_steps

# Resampling
resample_method = "downsample"        # None omits the --resample_method flag
resample_ratio = 10                   # --resample_ratio
resample_seed = 100                   # --resample_seed

# Run control
external_val = False                  # truthy adds the bare --external_val flag
output_dir = ""                       # --output_dir; empty as shipped — set a path before running
checkpoint = None                     # None omits the --checkpoint flag

In [14]:
# Build the command line that is handed to the argument parser.
#
# The result is kept as a single string (`arg_str`) because it is consumed
# with `arg_str.split()`. Since that split is on whitespace, every value
# must render to exactly one token: an empty value would leave its flag
# without an argument, and a value containing whitespace would be broken
# into several tokens. Both cases are rejected here with a clear message
# instead of surfacing later as a confusing parser error.


def _cli_option(flag, value):
    """Return ``[flag, str(value)]`` for a flag that takes one argument.

    Args:
        flag: option name including the leading dashes, e.g. ``"--stride"``.
        value: configuration value; it is rendered with ``str()``.

    Returns:
        A two-element list ``[flag, rendered_value]``.

    Raises:
        ValueError: if the rendered value is empty or contains whitespace,
            i.e. it would not survive a whitespace split as one token.
    """
    rendered = str(value)
    if len(rendered.split()) != 1:
        raise ValueError(
            f"Value for {flag} must be a single non-empty token without "
            f"whitespace, got {rendered!r}; set it in the configuration cell."
        )
    return [flag, rendered]


# Entries are either a bare flag (str) or a (flag, value) pair. The order
# matches the original command line.
cli_spec = [
    "--do_train",
    "--do_eval",
    "--cls_task",
    "--long_and_sort",
    "--pad_seq_first",
    "--get_rids",
    "--field_ce",
    ("--lm_type", "bert"),
    ("--field_hs", 64),
    ("--data_type", dt),
    ("--stride", 5),
    ("--data_root", f"./data/{data}/"),
    ("--train_batch_size", bs),
    ("--eval_batch_size", bs),
    ("--save_steps", save_steps),
    ("--eval_steps", eval_steps),
    ("--nbatches", nb),
    ("--data_fname", fname),
    ("--data_val_fname", val_fname),
    ("--data_test_fname", test_fname),
    "--vocab_cached",
    "--user_level_cached",
    ("--preload_fextension", preload_fextension),
    ("--output_dir", output_dir),
    ("--time_pos_type", time_pos_type),
    ("--resample_ratio", resample_ratio),
    ("--resample_seed", resample_seed),
]

# The vocabulary file name carries the file extension suffix when one is set.
if fextension:
    cli_spec.append(("--fextension", fextension))
    cli_spec.append(("--external_vocab_path", f"./data/{data}/vocab_ob_{fextension}"))
else:
    cli_spec.append(("--external_vocab_path", f"./data/{data}/vocab_ob"))

# Optional flags are only emitted when configured.
if resample_method is not None:
    cli_spec.append(("--resample_method", resample_method))
if external_val:
    cli_spec.append("--external_val")
if checkpoint is not None:
    cli_spec.append(("--checkpoint", checkpoint))

cli_tokens = []
for cli_entry in cli_spec:
    if isinstance(cli_entry, str):
        cli_tokens.append(cli_entry)
    else:
        cli_tokens.extend(_cli_option(*cli_entry))

# Single-space join: `arg_str.split()` yields exactly `cli_tokens`.
arg_str = " ".join(cli_tokens)

In [15]:
# Create the project's argument parser, limited to the IBM dataset variants,
# and parse the command line assembled in the previous cell.
parser = define_new_main_parser(
    data_type_choices=[
        "IBM",
        "IBM_time_pos",
        "IBM_time_static",
        "IBM_static",
    ],
)
# Whitespace-split the command-line string into the token list argparse expects.
opts = parser.parse_args(args=arg_str.split())

In [None]:
# Validate the parsed options, prepare the output/log directories, attach a
# file log handler, and launch the run via `main(opts)`.

# Derived flag: truthy when either the classification task or the export
# task was requested.
opts.cls_exp_task = opts.cls_task or opts.export_task

# --- Validate first, so an invalid configuration leaves no directories or
# --- log files behind.
if not opts.output_dir:
    # makedirs('') raises a bare FileNotFoundError; fail with a clear message.
    raise ValueError("output_dir must be set to a non-empty path.")

if opts.data_type in ["IBM_time_pos", "IBM_time_static"]:
    if opts.time_pos_type != 'time_aware_sin_cos_position':
        raise ValueError(
            f"data_type {opts.data_type!r} requires time_pos_type "
            f"'time_aware_sin_cos_position', got {opts.time_pos_type!r}.")
elif opts.data_type in ["IBM", "IBM_static"]:
    if opts.time_pos_type not in ['sin_cos_position', 'regular_position']:
        raise ValueError(
            f"data_type {opts.data_type!r} requires time_pos_type "
            f"'sin_cos_position' or 'regular_position', got {opts.time_pos_type!r}.")

if opts.mlm and opts.lm_type == "gpt2":
    raise Exception(
        "Error: GPT2 doesn't need '--mlm' option. Please re-run with this flag removed.")

if (not opts.mlm) and (not opts.cls_exp_task) and opts.lm_type == "bert":
    raise Exception(
        "Error: Bert needs either '--mlm', '--cls_task' or '--export_task' option. Please re-run with this flag "
        "included.")

# --- Filesystem setup.
opts.log_dir = join(opts.output_dir, "logs")
makedirs(opts.output_dir, exist_ok=True)
makedirs(opts.log_dir, exist_ok=True)

# Re-running this cell must not stack file handlers on the logger: drop and
# close any attached by a previous execution before adding a fresh one.
for stale_handler in [h for h in logger.handlers if isinstance(h, logging.FileHandler)]:
    logger.removeHandler(stale_handler)
    stale_handler.close()

# Mode 'w' truncates output.log on every run.
# NOTE(review): the handler is attached to this notebook's logger only, so
# records from loggers in other modules reach the file only if they are
# children of this logger — confirm that is intended.
file_handler = logging.FileHandler(
    join(opts.log_dir, 'output.log'), 'w', 'utf-8')
logger.addHandler(file_handler)

main(opts)