In [12]:
import torch
from torch import nn
from torchrl.collectors import SyncDataCollector
from tqdm import tqdm

from afa_rl.afa_env import AFAMDP
from afa_rl.agents import ShimQAgent
from afa_rl.datasets import CubeDataset, get_dataset_fn
from afa_rl.models import ReadProcessEncoder, MLPClassifier
from afa_rl.models import ShimEmbedderClassifier, ShimEmbedder

In [None]:
# Load a pretrained embedder and classifier.
# Use the GPU when one is available and fall back to CPU otherwise. The checkpoint
# tensors are mapped onto that device so a checkpoint saved on a GPU can also be
# opened on a CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

checkpoint = torch.load(
    "../checkpoints/shim_embedder_classifier.ckpt",
    map_location=device,
    weights_only=True,
)

# Number of features per sample; reused by the dataset, env and cost cells below.
n_features = 20

embedder_and_classifier = ShimEmbedderClassifier(
    embedder=ShimEmbedder(
        encoder=ReadProcessEncoder(
            feature_size=n_features + 1,  # state contains one value and one index
            # NOTE(review): presumably must equal the classifier's input size and
            # the agent's embedding_size (both 16 in this notebook) — confirm.
            output_size=16,
            reading_block_cells=[32, 32],
            writing_block_cells=[32, 32],
            memory_size=16,
            processing_steps=5,
        ),
    ),
    # NOTE(review): args look like (input_size, n_classes, hidden_layers) — confirm
    # against MLPClassifier's signature.
    classifier=MLPClassifier(16, 8, [32, 32]),
    lr=1e-4,
)
# The architecture above must match the one the checkpoint was trained with,
# otherwise load_state_dict raises on missing/unexpected keys.
embedder_and_classifier.load_state_dict(checkpoint["state_dict"])
embedder_and_classifier = embedder_and_classifier.to(device)

In [3]:
# Generate the synthetic cube dataset (fixed seed for reproducibility) and wrap its
# features/labels in the dataset_fn callable that AFAMDP consumes.
cube_config = dict(n_features=n_features, data_points=100_000, sigma=0.01, seed=42)
dataset = CubeDataset(**cube_config)
dataset_fn = get_dataset_fn(dataset.features, dataset.labels)

In [4]:
# Assemble the acquisition MDP from the dataset, the pretrained embedder and the
# pretrained classifier (used as the task model).
# Every feature costs the same, 1 / n_features, so acquiring all features costs 1 in total.
uniform_costs = torch.ones((n_features,), dtype=torch.float32, device=device) / n_features
env = AFAMDP(
    dataset_fn=dataset_fn,
    embedder=embedder_and_classifier.embedder,
    task_model=embedder_and_classifier.classifier,
    loss_fn=nn.CrossEntropyLoss(reduction="none"),  # per-sample loss, no reduction
    acquisition_costs=uniform_costs,
    device=device,
    batch_size=torch.Size((4,)),  # a batch of 4 episodes stepped together
)

In [5]:
# Smoke test: reset the environment and take one random action to check it runs.
td = env.rand_step(env.reset())

In [6]:
# Q-learning agent that acts on the embeddings produced by the pretrained encoder.
agent_hparams = dict(
    embedding_size=16,  # presumably must match the encoder's output_size — confirm
    lr=1e-3,
    update_tau=1e-3,  # presumably the soft target-network update rate — confirm
    eps=0.1,  # presumably the epsilon-greedy exploration rate — confirm
)
agent = ShimQAgent(action_spec=env.action_spec, device=device, **agent_hparams)

In [7]:
# Smoke test: from a fresh reset, let the agent pick an action and step the env once.
td = env.step(agent.policy(env.reset()))
print(td)

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4]), device=cuda:0, dtype=torch.int64, is_shared=True),
        action_value: Tensor(shape=torch.Size([4, 21]), device=cuda:0, dtype=torch.float32, is_shared=True),
        all_features: Tensor(shape=torch.Size([4, 20]), device=cuda:0, dtype=torch.float32, is_shared=True),
        chosen_action_value: Tensor(shape=torch.Size([4, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
        done: Tensor(shape=torch.Size([4, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
        embedding: Tensor(shape=torch.Size([4, 16]), device=cuda:0, dtype=torch.float32, is_shared=True),
        fa_reward: Tensor(shape=torch.Size([4, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
        feature_indices: Tensor(shape=torch.Size([4, 20]), device=cuda:0, dtype=torch.bool, is_shared=True),
        feature_values: Tensor(shape=torch.Size([4, 20]), device=cuda:0, dtype=torch.float32, is_shared=True),
        label: Te

In [8]:
# Collect rollouts with the current policy.
# SyncDataCollector steps every batched environment at once, so frames_per_batch
# should be a multiple of env.batch_size.numel(). The previous value of 2 was
# rounded up to 4 (with a warning), so tying it to the batch size keeps the same
# 4 frames per iteration and states it explicitly.
collector = SyncDataCollector(
    env,
    agent.policy,
    frames_per_batch=env.batch_size.numel(),  # one step of every batched env per iteration
    total_frames=1_000,
    device=device,
)



In [13]:
# Train the Q-network on each freshly collected batch.
# NOTE(review): batches go straight from the collector into the loss with no replay
# buffer, so consecutive updates see small, highly correlated sets of transitions.
max_grad_norm = None  # set to a float (e.g. 1.0) to enable gradient-norm clipping

progress = tqdm(collector)
for batch in progress:
    agent.optim.zero_grad()

    loss = agent.loss_module(batch)
    loss["loss"].backward()

    if max_grad_norm is not None:
        # Clip exactly the parameters the optimizer updates
        # (assumes agent.optim is a torch.optim.Optimizer).
        params = [p for group in agent.optim.param_groups for p in group["params"]]
        torch.nn.utils.clip_grad_norm_(params, max_grad_norm)

    agent.optim.step()

    # Update target network
    agent.updater.step()

    # Surface the training signal in the progress bar.
    progress.set_postfix(loss=loss["loss"].item())

  0%|          | 0/250 [00:00<?, ?it/s]
