In [None]:
# edited by Dongyu Zhang
from os import makedirs
from os.path import join, basename
import logging
import numpy as np
import torch
import random
from args import define_new_main_parser
import json

from transformers import Trainer, TrainingArguments, EarlyStoppingCallback

from dataset.amazon import AmazonDataset
from dataset.amazon_time_static import AmazonWithTimePosAndStaticSplitDataset
from dataset.amazon_time_pos import AmazonWithTimePosDataset
from dataset.amazon_static import AmazonWithStaticSplitDataset
from models.modules import TabFormerBertLM, TabFormerBertForClassification, TabFormerBertModel, TabStaticFormerBert, \
    TabStaticFormerBertLM, TabStaticFormerBertClassification
from misc.utils import ordered_split_dataset, compute_cls_metrics
from dataset.datacollator_mask_label import *
from main_amazon import main

# Notebook-level logger; `log` is a second name bound to the very same object.
log = logger = logging.getLogger(__name__)

# Root logging configuration: timestamped INFO-and-above records.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
)

# Set the WANDB_DISABLED environment variable for this process.
import os
os.environ["WANDB_DISABLED"] = "true"

In [None]:
# ---- Run configuration (read by the argument-building cell below) ----

# Dataset selection: `data` is the folder name under ./data/, `dt` is sent as --data_type.
data = "amazon_movie"
dt = "amazon"

# File names for the train / validation / test splits.
fname = "Movies_and_TV_train"
val_fname = "Movies_and_TV_val"
test_fname = "Movies_and_TV_test"

# A falsy `fextension` selects the plain `vocab_ob` vocabulary path; a truthy value is
# forwarded as --fextension and appended to the vocabulary file name.
fextension = False
preload_fextension = "preload-test"

# Value forwarded as --time_pos_type.
time_pos_type = "regular_position"

# Batching and schedule.
bs = 32                # used for both --train_batch_size and --eval_batch_size
nb = 10                # forwarded as --nbatches
num_train_epochs = 3
save_steps = 1000
eval_steps = 1000

# Paths forwarded as --pretrained_dir / --output_dir; both are empty placeholders here.
pretrained_dir = ""
output_dir = ""

# Adds the --external_val flag when true.
external_val = True

# Forwarded as --checkpoint only when it is not None.
checkpoint = None

In [None]:
# Assemble the command-line string that a later cell splits on whitespace and hands to
# the project's argument parser.
#
# Tokens are collected in a list and joined with single spaces, so the separation
# between options no longer depends on stray blanks inside multi-line f-strings.
# After `.split()` the token sequence is the same as before for every non-empty value.
arg_tokens = [
    "--do_train",
    "--do_eval",
    "--cls_task",
    "--long_and_sort",
    "--pad_seq_first",
    "--get_rids",
    "--field_ce",
    "--lm_type", "bert",
    "--field_hs", "64",
    "--data_type", f"{dt}",
    "--stride", "1",
    "--data_root", f"./data/{data}/",
    "--train_batch_size", f"{bs}",
    "--eval_batch_size", f"{bs}",
    "--save_steps", f"{save_steps}",
    "--eval_steps", f"{eval_steps}",
    "--nbatches", f"{nb}",
    "--num_train_epochs", f"{num_train_epochs}",
    "--data_fname", f"{fname}",
    "--data_val_fname", f"{val_fname}",
    "--data_test_fname", f"{test_fname}",
    "--user_level_cached",
    "--vocab_cached",
    "--preload_fextension", f"{preload_fextension}",
]

# Path options may be left unset (None or blank) in the configuration cell. Emitting the
# bare flag would leave it without a value once the string is split, which argparse
# rejects for options that take one argument. An unset option is therefore omitted so
# the parser's own default applies.
# NOTE(review): the defaults live in `define_new_main_parser` (not visible here) —
# confirm they are acceptable for --pretrained_dir / --output_dir.
for _flag, _value in (("--pretrained_dir", pretrained_dir), ("--output_dir", output_dir)):
    if _value is None or not str(_value).strip():
        logger.warning("%s is unset; omitting it so the parser default is used.", _flag)
        continue
    arg_tokens += [_flag, f"{_value}"]

arg_tokens += ["--time_pos_type", f"{time_pos_type}"]

# Vocabulary path: suffixed with the extension when one is configured.
if fextension:
    arg_tokens += [
        "--fextension", f"{fextension}",
        "--external_vocab_path", f"./data/{data}/vocab_ob_{fextension}",
    ]
else:
    arg_tokens += ["--external_vocab_path", f"./data/{data}/vocab_ob"]

if external_val:
    arg_tokens.append("--external_val")

if checkpoint is not None:
    arg_tokens += ["--checkpoint", f"{checkpoint}"]

# Kept as a single string because the parsing cell calls `arg_str.split()`.
arg_str = " ".join(arg_tokens)

In [None]:
# Create the project's argument parser with --data_type limited to the four Amazon
# dataset variants, then parse the whitespace-separated argument string built above.
parser = define_new_main_parser(
    data_type_choices=[
        "amazon",
        "amazon_time_pos",
        "amazon_time_static",
        "amazon_static",
    ]
)
opts = parser.parse_args(arg_str.split())

In [None]:
# Finalise the parsed options, validate them, attach a log file, and start the run.
#
# Validation happens before any directory or log file is created, so a bad
# configuration fails without leaving side effects on disk.

# Derived options (attribute assignments only, no I/O).
opts.log_dir = join(opts.output_dir, "logs")
opts.cls_exp_task = opts.cls_task or opts.export_task

# Each dataset variant accepts only specific positional-encoding types.
# Raised as ValueError (not `assert`) so the check survives `python -O` and
# reports which combination was rejected.
if opts.data_type in ["amazon_time_pos", "amazon_time_static"]:
    allowed_time_pos_types = ['time_aware_sin_cos_position']
elif opts.data_type in ["amazon", "amazon_static"]:
    allowed_time_pos_types = ['sin_cos_position', 'regular_position']
else:
    allowed_time_pos_types = None  # no constraint is enforced here for other data types

if allowed_time_pos_types is not None and opts.time_pos_type not in allowed_time_pos_types:
    raise ValueError(
        f"--time_pos_type {opts.time_pos_type!r} is not valid for --data_type "
        f"{opts.data_type!r}; expected one of {allowed_time_pos_types}.")

if opts.mlm and opts.lm_type == "gpt2":
    raise Exception(
        "Error: GPT2 doesn't need '--mlm' option. Please re-run with this flag removed.")

if (not opts.mlm) and (not opts.cls_exp_task) and opts.lm_type == "bert":
    raise Exception(
        "Error: Bert needs either '--mlm', '--cls_task' or '--export_task' option. Please re-run with this flag "
        "included.")

# Output locations.
makedirs(opts.output_dir, exist_ok=True)
makedirs(opts.log_dir, exist_ok=True)

# Re-running this cell must not stack file handlers on the notebook logger: every
# extra handler would duplicate each record and keep another file handle open.
# Detach and close file handlers left from earlier runs before adding a new one.
for stale_handler in [h for h in logger.handlers if isinstance(h, logging.FileHandler)]:
    logger.removeHandler(stale_handler)
    stale_handler.close()

# NOTE(review): the handler is attached to this notebook's `logger` only; records from
# loggers that are neither this logger nor its children are not written to output.log —
# confirm that is intended.
file_handler = logging.FileHandler(
    join(opts.log_dir, 'output.log'), 'w', 'utf-8')
logger.addHandler(file_handler)

main(opts)