In [1]:
from pathlib import Path

import numpy as np
import pandas as pd

import torch
from transformers import LlamaForCausalLM, AutoTokenizer, AutoModelForCausalLM

from dicl import dicl
%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from huggingface_hub import login
login(token="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")

In [None]:
llm_model = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
#llm_model = "Qwen/Qwen2-1.5B-Instruct"
#llm_model = "gpt2"
#llm_model = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    use_fast=True,
)
print(tokenizer)
model = AutoModelForCausalLM.from_pretrained(
    llm_model,
    device_map="auto",
)
model.eval()

PreTrainedTokenizerFast(name_or_path='meta-llama/Llama-3.2-1B', vocab_size=128000, model_max_length=131072, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|begin_of_text|>', 'eos_token': '<|end_of_text|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	128000: AddedToken("<|begin_of_text|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128001: AddedToken("<|end_of_text|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128002: AddedToken("<|reserved_special_token_0|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128003: AddedToken("<|reserved_special_token_1|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128004: AddedToken("<|finetune_right_pad_id|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128005: AddedToken("<|reserved_special_token_2|>", rst

Downloading shards: 100%|██████████| 8/8 [04:18<00:00, 32.28s/it]
Loading checkpoint shards:  50%|█████     | 4/8 [00:27<00:27,  6.86s/it]

In [None]:
env_name = "HalfCheetah"
n_actions = 6  # number of actions in the HalfCheetah system
n_observations = 17  # number of observations in the HalfCheetah system
data_label = "expert"
data_path = Path("src") / "dicl" / "data" / f"D4RL_{env_name}_{data_label}.csv"

# ICL parameters
context_length = 300
rescale_factor = 7.0
up_shift = 1.5

In [None]:
# to use DICL-(s) or vICL, set include_actions to False.
# to use DICL-(s,a), set include_actions to True
include_actions = True
if include_actions:
    n_features = n_observations + n_actions
else:
    n_features = n_observations

In [None]:
# load data to get a sample episode
X = pd.read_csv(data_path, index_col=0)
X = X.values.astype("float")

# find episodes beginnings. the restart column is equal to 1 at the start of
# an episode, 0 otherwise.
restart_index = n_observations + n_actions + 1
restarts = X[:, restart_index]
episode_starts = np.where(restarts)[0]

# sample an episode and extract time series
episode = np.random.choice(episode_starts)
time_series = X[episode : episode + context_length, :n_features]

print(
    f"The in-context trajectory shape is (n_timestamps:{time_series.shape[0]}, "
    f"n_features:{time_series.shape[1]})"
)

In [None]:
# to use vICL, set vanilla_icl to True.
# to use DICL-(s,a) or DICL-(s), set vanilla_icl to False
vanilla_icl = False
if vanilla_icl:
    DICL = dicl.vICL(
        n_features=n_features,
        model=model,
        tokenizer=tokenizer,
        rescale_factor=rescale_factor,
        up_shift=up_shift,
    )
else:
    DICL = dicl.DICL_PCA(
        n_features=n_features,
        n_components=int(n_features / 2),
        model=model,
        tokenizer=tokenizer,
        rescale_factor=rescale_factor,
        up_shift=up_shift,
    )

In [None]:
DICL.fit_disentangler(X=time_series)

In [None]:
mean, mode, lb, ub = DICL.predict_single_step(X=time_series)


In [None]:
DICL.__dict__.keys()

In [None]:
len(DICL.icl_object[3].PDF_list)

In [None]:


# print metrics
burnin = 0
single_step_metrics = DICL.compute_metrics(burnin=burnin)
print(f"after burnin steps: {burnin}")
print(
    "average_agg_squared_error: "
    f"{single_step_metrics['average_agg_squared_error']:.3f}"
)
print(f"agg_ks: {single_step_metrics['agg_ks']}")

In [None]:
DICL.plot_single_step(xlim=[0, context_length])

In [None]:
DICL.plot_calibration(burnin=0)

In [None]:
prediction_horizon = 20
mean, mode, lb, ub = DICL.predict_multi_step(
    X=time_series,
    prediction_horizon=prediction_horizon,
    stochastic=True,
    if_true_mean_else_mode=True,
)

In [None]:
# print metrics
burnin = context_length - prediction_horizon  # only keeps multi-step prediction
single_step_metrics = DICL.compute_metrics(burnin=burnin)
print(f"after burnin steps: {burnin}")
print(
    "average_agg_squared_error: "
    f"{single_step_metrics['average_agg_squared_error']:.3f}"
)
print(f"agg_ks: {single_step_metrics['agg_ks']}")

In [None]:
DICL.plot_multi_step(xlim=[context_length - 50, context_length])