In [None]:
import wandb
from train import DEFAULT_CONFIG, TRANSFORMERS_LIB, train, test
from data.datasets import get_data_loaders
from utils.utils import seed_everything
from datetime import datetime

In [None]:
# Weights & Biases settings shared by every per-seed run launched below.
# Flip "disabled" to True to run the experiment without logging to W&B.
WANDB_CONFIG = dict(
    entity="j-getzner",
    project="Trends & Innovations Classifier",
    disabled=False,
)

In [None]:
# Model to fine-tune; it is used both as the TRANSFORMERS_LIB key and as the from_pretrained identifier.
# Set to a falsy value (e.g. "") to keep the model_name from DEFAULT_CONFIG.
model_name = "distilbert-base-uncased"

In [None]:
# Train and evaluate the model once per seed, for seeds
# initial_seed .. initial_seed + num_seeds - 1 (both taken from DEFAULT_CONFIG).
# Each seed gets:
#   - its own checkpoint directory, and
#   - its own W&B run.
# W&B runs are grouped by launch time and model name.
current_time = datetime.now().strftime("%Y.%m.%d-%H:%M:%S")
current_config = DEFAULT_CONFIG.copy()  # shallow copy; only top-level keys are overridden below
if model_name:
    current_config["model_name"] = model_name

# Read the base checkpoint directory once. Deriving each seed's directory from the value
# written by the previous iteration would nest them (<base>/<model>/seed_0/<model>/seed_1/...).
base_save_model_dir = current_config["save_model_dir"]

for seed_offset in range(current_config["num_seeds"]):
    # seed
    current_config["seed"] = current_config["initial_seed"] + seed_offset
    seed_everything(current_config["seed"])
    # per-seed checkpoint directory, always rooted at the original base directory
    current_config["save_model_dir"] = (
        f"{base_save_model_dir}/{current_config['model_name']}/seed_{current_config['seed']}"
    )
    # init model
    current_model = TRANSFORMERS_LIB[current_config["model_name"]].from_pretrained(
        current_config["model_name"],
        num_labels=current_config["num_labels"]
    ).to(current_config["device"])

    # load data
    data_loaders = get_data_loaders(current_config)

    wandb.init(
        entity=WANDB_CONFIG["entity"],
        project=WANDB_CONFIG["project"],
        config=current_config,
        mode="disabled" if WANDB_CONFIG["disabled"] else "online",
        group=f"{current_time}-{current_config['model_name']}",
        job_type="train",
        name="seed_"+str(current_config["seed"])
    )

    # Close every run explicitly so the next seed starts a fresh run and the last one
    # is not left open in the kernel. A run interrupted by an error is marked as failed.
    try:
        train(current_model, data_loaders["train"], data_loaders["val"], current_config)
        test(current_model, data_loaders["test"], current_config)
    except BaseException:
        wandb.finish(exit_code=1)
        raise
    else:
        wandb.finish()