In [None]:
# os.environ["WANDB_DISABLED"] = "true"

In [None]:
# import logging
# log = logging.getLogger()
# log.handlers.clear()
# log.addHandler(logging.StreamHandler())
# log.setLevel(logging.WARNING)

In [None]:
from pathlib import Path

import datasets
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datasets import Dataset
from torch.utils.data import DataLoader

# from medcat.cat import CAT
# from foresight.models.lucid_transformers import LucidLM2HF
from transformers import SchedulerType, Trainer, TrainingArguments

# from medcat.cdb import CDB
from foresight.datasets.data_collator import CollataAndPad
from foresight.datasets.data_collator_v2 import (
    DataCollatorForLanguageModelingMaskStaticVariables,
)
from foresight.metrics.next_concept_prediction import (
    ComputePrecisionHF,
    metrics_data2df,
    precision,
)
from foresight.models.custom_GPT2 import CustomGPT2Config, CustomGPT2LMHeadModel
from foresight.tokenizers import PreTrainedTokenizerFastWithPositionIDPadding
from foresight.tokenizers.simple_map_tokenizer import SimpleMapTokenizer
from foresight.utils import pickle

In [None]:
import time

import datasets
from torch.utils.data import DataLoader

In [None]:
# Everything this notebook reads or writes lives under ./outputs, resolved
# against the current working directory.
OUTPUT_DIR = Path.cwd().joinpath("outputs")
SAVE_TOKENIZER_PATH = OUTPUT_DIR.joinpath("tokenizer")
SAVE_ENCODED_DATASET_PATH = OUTPUT_DIR.joinpath("encoded_dataset")

# One timestamped log directory per execution of this cell; the final model
# is stored in a sub-directory of it.
MODEL_LOGS_DIR = OUTPUT_DIR.joinpath(
    "model_logs", time.strftime("%Y_%m_%d_%H_%M_%S")
)
FINAL_MODEL_DIR = MODEL_LOGS_DIR.joinpath("final_model")

# Forwarded to the training collator as `num_static_variables`.
NUM_STATIC_VARIABLES = 4

In [None]:
# Load the pre-encoded dataset stored under SAVE_ENCODED_DATASET_PATH.
encoded_dataset = datasets.load_from_disk(SAVE_ENCODED_DATASET_PATH)
# Bare expression so the notebook displays the loaded object; later cells index
# its "train" and "test" entries.
encoded_dataset

In [None]:
# Restore the tokenizer saved under SAVE_TOKENIZER_PATH.
tokenizer = PreTrainedTokenizerFastWithPositionIDPadding.from_pretrained(
    SAVE_TOKENIZER_PATH
)
# Collator used for training batches. It receives the tokenizer, `mlm=False`
# and the number of static variables configured above.
# NOTE(review): presumably the leading NUM_STATIC_VARIABLES tokens are excluded
# from the loss — confirm in DataCollatorForLanguageModelingMaskStaticVariables.
training_data_collator = DataCollatorForLanguageModelingMaskStaticVariables(
    tokenizer=tokenizer, mlm=False, num_static_variables=NUM_STATIC_VARIABLES
)

In [None]:
# DataLoader over the training split, batched (1000 examples, no shuffling)
# through the training collator. Despite its name this is a DataLoader, not a
# dataset; the Trainer below is given `encoded_dataset["train"]` directly.
dataset_train = DataLoader(
    encoded_dataset["train"],
    batch_size=1000,
    shuffle=False,
    collate_fn=training_data_collator,
)

# Create GPT2

In [None]:
# Make a new model
config = CustomGPT2Config(
    vocab_size=tokenizer.vocab_size,
    n_positions=100,
    n_embd=16,
    n_layer=4,
    n_head=4,
    bos_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.pad_token_id,
    sep_token_id=tokenizer.sep_token_id,
)
model = CustomGPT2LMHeadModel(config)
model.generation_config.max_length = 100
model.generation_config.pad_token_id = tokenizer.pad_token_id
model.generation_config

In [None]:
# Smoke test: push the first collated training batch through the freshly built
# model and display the shape of the returned logits.
model(**next(iter(dataset_train))).logits.shape

# Trainer

In [None]:
# compute_metrics = ComputePrecisionHF(tokenizer.id2tkn,
#                                      prediction_scope='time_range',
#                                      topk=1,
#                                      start=0,
#                                      return_all_metrics=False,
#                                      batch_size=1000,
#                                      select_token_types=all_types,
#                                      type_data=test_set_to_use['token_type'],
#                                      token_type2tokens=tokenizer.token_type2tokens,
#                                      time_data=test_set_to_use['time'],
#                                      time_range=30*24*60*60,
#                                      ignore_label_status=False,
#                                      min_time_left=24*60*60)

In [None]:
from foresight.metrics.timeline import TimelineMetrics

# Metric helper bound to this notebook's tokenizer.
timeline_metrics = TimelineMetrics(tokenizer)

# Use the bound method itself instead of a lambda assigned to a name: the
# lambda only forwarded its single `eval_preds` argument, so the callable's
# signature and result are unchanged.
compute_metrics = timeline_metrics.compute_micro_precision_recall_f1

In [None]:
# Create this run's timestamped log directory up front.
MODEL_LOGS_DIR.mkdir(parents=True, exist_ok=True)

training_args = TrainingArguments(
    # `output_dir` is declared as a str in TrainingArguments; convert the Path
    # explicitly instead of relying on implicit coercion.
    output_dir=str(MODEL_LOGS_DIR),  # output directory
    num_train_epochs=20,  # total number of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=16,  # batch size for evaluation
    # weight_decay=1e-2,               # strength of weight decay
    # logging_dir='./logs',            # directory for storing logs
    # warmup_ratio=0.01,
    learning_rate=2e-03,
    # eval_accumulation_steps=1,
    # gradient_accumulation_steps=16,
    # Evaluate and checkpoint on the same schedule (once per epoch) so the
    # best checkpoint by eval loss can be restored when training finishes.
    do_eval=True,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    metric_for_best_model="eval_loss",
    warmup_ratio=0.1,
    load_best_model_at_end=True,
    # lr_scheduler_type=SchedulerType.LINEAR,
    # use_cpu=True
)

In [None]:
# import wandb

In [None]:
# wandb.init(project='timecat', entity='wish', name=RUN_NAME + '-gpt-16-16_1day_no_base_data')

In [None]:
# Wire the model, its training arguments, the data splits and the training
# collator into a Hugging Face Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=training_data_collator,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["test"],
    # Options left disabled:
    # compute_metrics=compute_metrics,
    # prediction_loss_only=True
    # tokenizer=None,
)

In [None]:
# trainer.train()
# trainer.save_model(FINAL_MODEL_DIR)

In [None]:
# Replace the freshly built model with a previously trained checkpoint.
# NOTE(review): the path names one specific past run and is relative to the
# current working directory; the commented line is the alternative that loads
# the model saved by this session. Rebinding `model` does not change the
# object already handed to `trainer`.
# model = CustomGPT2LMHeadModel.from_pretrained(FINAL_MODEL_DIR)
model = CustomGPT2LMHeadModel.from_pretrained(
    "./outputs/model_logs/2024_02_07_16_18_57/final_model"
)
# model.to("cuda")
print(model)

In [None]:
# A collator is invoked with a list of examples (that is how DataLoader calls
# its `collate_fn`); calling it with no arguments fails. Collate two test
# examples to inspect what a training batch looks like.
training_data_collator([encoded_dataset["test"][i] for i in range(2)])

In [None]:
# Turn the first test example into a batch of one: each field is wrapped in a
# list before conversion, so every tensor gains a leading dimension of size 1.
sample = {k: torch.tensor([v]) for k, v in encoded_dataset["test"][0].items()}

In [None]:
from transformers import DataCollatorWithPadding

# NOTE(review): this mutates the shared `tokenizer` object in place. The same
# instance was passed to `training_data_collator`, so batches it pads after
# this cell runs may be affected as well — confirm that is intended.
tokenizer.padding_side = "left"
# Plain padding collator used for inference batches.
inference_data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [None]:
# Collate the first two test examples into a single padded batch.
batch = inference_data_collator(
    encoded_dataset["test"][:2],
)
# Quick sanity check: name and shape of every entry in the batch.
print([(k, v.shape) for k, v in batch.items()])

In [None]:
# Display the padded token ids of the two-example batch.
batch["input_ids"]

In [None]:
# 1 where a token is the separator, 0 everywhere else.
sep_mask = batch["input_ids"].eq(tokenizer.sep_token_id).long()
# Exclusive running total along the sequence axis: the number of separators
# strictly before each position.
sep_count = torch.cumsum(sep_mask, dim=-1) - sep_mask
# Show that count at every separator position.
sep_count[sep_mask.bool()]

In [None]:
# Generate continuations for the two-example batch and display the resulting
# token ids.
output_tokens = model.generate(**batch)
output_tokens

In [None]:
# Map the generated ids of the first sequence back to token strings.
tokenizer.convert_ids_to_tokens(output_tokens[0])

In [None]:
# The original cell read `logits`, which no cell in this notebook assigns, so
# it raised NameError on a fresh kernel. Run an explicit forward pass first.
# Eight test examples are collated so that indexing the batch at position 3
# afterwards stays valid; gradients are not needed for inspection.
eval_batch = inference_data_collator(encoded_dataset["test"][:8])
with torch.no_grad():
    outputs = model(**eval_batch)

# Highest-scoring token id at every position of every sequence.
pred_labels = torch.argmax(outputs.logits, dim=-1)
# A bare expression in the middle of a cell is not displayed, so print the shape.
print(pred_labels.shape)
pred_labels

In [None]:
# Predicted token ids for the sequence at batch index 3.
# NOTE(review): requires `pred_labels` to hold at least four sequences.
pred_labels[3]

In [None]:
# Decode ids 0-3 to see which tokens occupy the start of the vocabulary.
tokenizer.decode([0, 1, 2, 3])

In [None]:
# Forward pass on a single test example. The original passed the raw dataset
# row straight to the model; `sample` holds that same row converted to tensors
# with a batch dimension, which is what the forward pass expects.
model(**sample)

# Hyperparameter search

In [None]:
import ray
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import PopulationBasedTraining

In [None]:
# Metric callable for the hyperparameter search; this rebinds `compute_metrics`,
# replacing the TimelineMetrics-based one defined earlier.
# NOTE(review): `id2tkn` and `id2type` are not defined anywhere in the visible
# notebook (elsewhere the mapping is reached as `tokenizer.id2tkn`) — define
# them or adapt this call before running.
compute_metrics = ComputePrecisionHF(
    id2tkn, id2type, prediction_scope="age", topk=1, start=0, batch_size=2000
)

In [None]:
# Settings for the hyperparameter search below.
NUM_TRIALS = 20  # passed as `n_trials`
N_GPU_PER_TRIAL = 1  # GPU count requested in `resources_per_trial`
METRIC_TO_OPTIMIZE = "eval_precision"  # scheduler metric and checkpoint score attribute

In [None]:
def get_model(params):
    """Build a GPT-2 language model for one hyperparameter-search trial.

    Used as the ``model_init`` callback of the search ``Trainer``.

    Args:
        params: Mapping of trial hyperparameters, or ``None``. The keys read
            are ``n_embd`` (default 300), ``n_layer`` (default 1), ``n_head``
            (default 1) and ``load_weights`` (default 0).

    Returns:
        A newly constructed ``GPT2LMHeadModel``. When ``load_weights`` is
        truthy, its token-embedding matrix is loaded from the global
        ``embeddings`` and kept trainable.
    """
    # GPT2Config / GPT2LMHeadModel are not imported anywhere at the top of the
    # notebook, so import them here to keep the function self-contained.
    from transformers import GPT2Config, GPT2LMHeadModel

    # NOTE(review): `embeddings`, `MAX_SEQ_LEN` and `tkn2id` are read as
    # globals but are not defined in the visible notebook — they must exist
    # before the search runs.
    torch.cuda.empty_cache()
    if params is None:
        params = {}

    config = GPT2Config(
        vocab_size=len(embeddings),
        n_positions=MAX_SEQ_LEN + 1,
        n_ctx=MAX_SEQ_LEN + 1,
        n_embd=params.get("n_embd", 300),
        n_layer=params.get("n_layer", 1),
        n_head=params.get("n_head", 1),
        # BOS and EOS are both mapped to the padding token.
        bos_token_id=tkn2id["<PAD>"],
        eos_token_id=tkn2id["<PAD>"],
    )
    model = GPT2LMHeadModel(config)

    if params.get("load_weights", 0):
        # Initialise the token embeddings from the global matrix and make sure
        # they stay trainable.
        model.transformer.wte.load_state_dict(
            {"weight": torch.tensor(embeddings, dtype=torch.float32)}
        )
        model.transformer.wte.weight.requires_grad = True

    return model

In [None]:
# Base arguments for the hyperparameter-search trials. This rebinds
# `training_args`, replacing the configuration used for the main run.
training_args = TrainingArguments(
    # Where results and logs are written.
    output_dir="./results",
    logging_dir="./logs",
    # Optimisation.
    num_train_epochs=5,
    learning_rate=5e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=128,
    # Log and evaluate every 200 steps.
    logging_steps=200,
    do_eval=True,
    evaluation_strategy="steps",
    eval_steps=200,
    eval_accumulation_steps=1,
    skip_memory_metrics=True,
)

In [None]:
# Attach model-shape hyperparameters to `training_args` as plain attributes.
# The names mirror the keys `get_model` reads from its `params` argument, and
# the values match that function's defaults.
# NOTE(review): presumably required so the search backend can override them
# per trial — confirm against the Trainer hyperparameter-search implementation.
training_args.n_head = 1
training_args.n_layer = 1
training_args.n_embd = 300
training_args.load_weights = 0

In [None]:
# Hold out 10% of the training split for evaluating search trials. The split
# is random, so fix the seed to get the same partition on every run.
tune_dataset = encoded_dataset["train"].train_test_split(test_size=0.1, seed=42)

In [None]:
# Unpack the two halves of the tuning split.
tune_train_dataset, tune_test_dataset = (
    tune_dataset["train"],
    tune_dataset["test"],
)

In [None]:
# Trainer for the hyperparameter search (rebinds `trainer`). No model instance
# is passed; `model_init=get_model` builds a model instead.
# NOTE(review): `collate_fn` is not defined in the visible notebook — the
# collator created earlier is named `training_data_collator`; confirm which
# one is intended here.
trainer = Trainer(
    #    model=model,  # left unset on purpose: `model_init` is used instead
    args=training_args,  # training arguments, defined above
    train_dataset=tune_train_dataset,  # training dataset
    eval_dataset=tune_test_dataset,  # evaluation dataset
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
    tokenizer=None,
    model_init=get_model,
)

In [None]:
# Search space handed to every trial: a fixed epoch count and a choice of
# attention-head counts.
tune_config = {
    "num_train_epochs": tune.choice([5]),
    "n_head": tune.choice([2, 4, 6]),
}
# Population Based Training scheduler: maximises METRIC_TO_OPTIMIZE, measures
# time in `training_iteration` units with a perturbation interval of 1, and
# mutates the hyperparameters listed below.
scheduler = PopulationBasedTraining(
    time_attr="training_iteration",
    metric=METRIC_TO_OPTIMIZE,
    mode="max",
    perturbation_interval=1,
    hyperparam_mutations={
        "weight_decay": tune.uniform(0.0, 0.3),
        "learning_rate": tune.uniform(1e-5, 5e-5),
        "per_device_train_batch_size": [16, 32, 64, 128],
        "n_layer": tune.choice([2, 4, 6, 8]),
        #       "n_embd": tune.choice([256, 512]),
        "load_weights": tune.choice([0, 1]),
        "warmup_steps": tune.choice([20, 40, 60, 100]),
    },
)

In [None]:
import copy


def compute_objective(metrics):
    """Return the value the hyperparameter search should maximise.

    Args:
        metrics: Mapping of evaluation metric names to values; it must contain
            the key ``"eval_precision"``.

    Returns:
        The ``"eval_precision"`` entry of ``metrics``.

    Raises:
        KeyError: If ``"eval_precision"`` is missing.
    """
    # A plain lookup leaves `metrics` untouched, so no defensive deep copy (and
    # no pop from that copy) is needed.
    return metrics["eval_precision"]

In [None]:
# Launch the Ray-backed hyperparameter search. Every trial receives the same
# `tune_config` search space (the lambda ignores its argument), the objective
# is maximised, and each trial requests one CPU plus N_GPU_PER_TRIAL GPUs.
# NOTE(review): `RESULTS_HYPERPARAM` is not defined in the visible notebook,
# and the run name is a fixed date string.
best_model = trainer.hyperparameter_search(
    hp_space=lambda _: tune_config,
    backend="ray",
    n_trials=NUM_TRIALS,
    direction="maximize",
    compute_objective=compute_objective,
    resources_per_trial={"cpu": 1, "gpu": N_GPU_PER_TRIAL},
    scheduler=scheduler,
    keep_checkpoints_num=1,
    checkpoint_score_attr=METRIC_TO_OPTIMIZE,
    stop=None,
    local_dir=RESULTS_HYPERPARAM,
    name="21_May_2021",
    log_to_file=False,
    loggers=None,  # (WandbLogger, ),
)

In [None]:
# Display the search result.
# NOTE(review): despite the name, this is whatever `hyperparameter_search`
# returned — presumably a summary of the best run rather than a model; confirm.
best_model

# Saliency 

In [None]:
import ecco

In [None]:
# Wrap the Trainer's current model and the tokenizer in an ecco language-model
# object, registered under the "gpt2" model name.
# NOTE(review): `trainer.model` is not necessarily the checkpoint loaded into
# `model` earlier — confirm which model should be analysed.
lm = ecco.LM(trainer.model, tokenizer, model_name="gpt2")

In [None]:
ind = 49
# Map the example's ids to tokens once and reuse the list, instead of indexing
# the dataset and translating the ids twice. The loop variable no longer
# shadows the builtin `id`.
tokens = [
    tokenizer.id2tkn[token_id]
    for token_id in encoded_dataset["test"][ind]["input_ids"]
]
print("~~".join(tokens))
# Prompt text: the same sequence without its first and last token.
text = "~~".join(tokens[1:-1])

In [None]:
# Sample a 10-token continuation of `text` at temperature 1.
output = lm.generate(text, generate=10, do_sample=True, temperature=1)

In [None]:
# Render the saliency view of the generated output in the "detailed" style.
output.saliency(style="detailed")

# Probability prediction

In [None]:
from foresight.sight import Sight

In [None]:
# Switch the model to evaluation mode; binding the result to `_` keeps the
# cell from displaying the returned module.
_ = model.eval()

In [None]:
# Prediction helper built around the tokenizer and model.
# NOTE(review): the device is hard-coded to "cuda" while the `model.to("cuda")`
# call after loading is commented out — confirm that Sight moves the model
# itself and that a GPU is available.
sight = Sight(tokenizer=tokenizer, device="cuda", model=model)

In [None]:
# Look up the entry stored under the name "muscle~pain".
# NOTE(review): `cdb` is not defined in the visible notebook (the CDB import at
# the top is commented out) — it has to be loaded before this cell runs.
cdb.name2cuis["muscle~pain"]

In [None]:
# Ask the concept database for the name associated with "pain".
# NOTE(review): `cdb` is not defined in the visible notebook — it has to be
# loaded before this cell runs.
cdb.get_name("pain")

In [None]:
# Example input as a list of tokens: ethnicity, sex and age markers with their
# values, followed by two numeric codes (presumably concept identifiers — they
# are passed to `cdb.get_name` below). This rebinds `text` from a string to a
# list.
text = "<ETHNICITY>~~White~~<SEX>~~Male~~<AGE>~~23~~49727002~~386661006".split("~~")

In [None]:
# Small with WD
# Request 40 next-concept candidates restricted to type "T-11".
r = sight.next_concepts(
    text,
    type_ids=["T-11"],
    n=40,
    p_new=True,
    create_position_ids=False,
)
# First the names of the input tokens, then one line per candidate: its first
# two fields followed by the name of the first field.
print([cdb.get_name(token) for token in text])
for candidate in r:
    print(candidate[0], candidate[1], cdb.get_name(candidate[0]))