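# RLCT estimation

Estimates the RLCT of the in-context-learning (`icl`) model by SGLD sampling with `devinterp`. **TODO:** the sampling cell currently fails with a `TypeError` (recorded below). The later cells are scratch experiments:
- a `DiscreteTaskDistribution` shape check;
- a timing benchmark of three task-sampling strategies;
- an inspection of a W&B sweep configuration.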

In [1]:
import torch

from devinterp.slt.sampler import Sampler, SamplerConfig
from devinterp.evals import SamplerEvaluator
from devinterp.ops.storage import CheckpointerConfig
from devinterp.optim.optimizers import OptimizerConfig

from icl.config import get_config
from icl.evals import ICLEvaluator
from icl.utils import set_seed




In [2]:
import yaml

# Experiment config for a 64-task pretraining set; echo it as YAML so the
# run's full settings are visible alongside the results in this notebook.
config = get_config(task_config={"num_tasks": 64})
config_yaml = yaml.dump(config.model_dump())
print(config_yaml)

batch_size: 256
checkpointer_config:
  bucket_name: devinterp
  device: cpu
  local_root: null
  project_dir: icl/ntasks-64-task-a0acbd-opt-f7c569-sched-c4766a
criterion: cross_entropy
device: cpu
eval_batch_size: 2048
logger_config:
  entity: null
  metrics: null
  out_file: null
  project: null
  run_id: null
  stdout: false
  use_df: false
num_steps: 500000
num_training_samples: 128000000
optimizer_config:
  betas: !!python/tuple
  - 0.9
  - 0.999
  elasticity: null
  lr: 0.001
  momentum: null
  noise_level: null
  num_samples: null
  optimizer_type: Adam
  temperature: null
  weight_decay: 0.0
scheduler_config:
  T_max: null
  anneal_strategy: linear
  cycle_momentum: false
  div_factor: 249999.0
  eta_min: null
  final_div_factor: 249999.0
  gamma: null
  last_epoch: -1
  lr_lambda: null
  max_lr: 0.001
  milestones: null
  pct_start: 0.5
  scheduler_type: OneCycleLR
  step_size: null
  total_steps: 500000
task_config:
  embed_size: 128
  max_examples: 16
  mlp_size: 128
  model_



In [None]:
# Fetch this run's checkpoints and order their ids numerically.
# Ids are converted with int() first, so e.g. 1000 sorts after 200.
checkpointer = config.checkpointer_config.factory().sync()
checkpointer.file_ids = sorted(int(file_id) for file_id in checkpointer.file_ids)
checkpointer.file_ids

In [3]:
# Build the model and the two task distributions on the configured device.
task_cfg = config.task_config
target_device = config.device

model = task_cfg.model_factory().to(target_device)

# 'pretraining' data source: the fixed task set used for training
pretrain_dist = task_cfg.pretrain_dist_factory().to(target_device)
# 'true' data source: used for evaluation, including tasks unseen in pretraining
true_dist = task_cfg.true_dist_factory().to(target_device)

# Evaluator scoring the model on both distributions
evaluator = ICLEvaluator(
    pretrain_dist=pretrain_dist,
    true_dist=true_dist,
    max_examples=task_cfg.max_examples,
    eval_batch_size=config.eval_batch_size,
)

# NOTE(review): checkpoint loading is disabled, so the metrics below (and the
# RLCT cell that follows) use the model exactly as returned by model_factory().
# To analyse the final checkpoint, run the checkpointer cell and re-enable:
# model.load_state_dict(checkpointer[-1]["model"])

evaluator(model)

{'pretrain/mse': 8.867645263671875,
 'pretrain/delta_dmmse': tensor(7.3911),
 'pretrain/delta_ridge': tensor(5.8186),
 'pretrain/token/0': 8.977218627929688,
 'pretrain/token/1': 8.977317810058594,
 'pretrain/token/2': 8.686665534973145,
 'pretrain/token/3': 8.426115989685059,
 'pretrain/token/4': 8.815361022949219,
 'pretrain/token/5': 8.985779762268066,
 'pretrain/token/6': 8.918869018554688,
 'pretrain/token/7': 9.33285140991211,
 'pretrain/token/8': 8.752054214477539,
 'pretrain/token/9': 8.697751998901367,
 'pretrain/token/10': 8.631465911865234,
 'pretrain/token/11': 8.951107025146484,
 'pretrain/token/12': 8.75244140625,
 'pretrain/token/13': 9.560930252075195,
 'pretrain/token/14': 8.623815536499023,
 'pretrain/token/15': 8.792564392089844,
 'true/mse': 8.72768783569336,
 'true/delta_dmmse': tensor(5.5109),
 'true/delta_ridge': tensor(5.7320),
 'true/token/0': 8.675100326538086,
 'true/token/1': 8.348591804504395,
 'true/token/2': 8.403005599975586,
 'true/token/3': 8.807651519

In [4]:
# SGLD sampling over the pretraining examples, used to estimate the RLCT.
# NOTE(review): as recorded below, this cell raised
#   TypeError: forward() missing 1 required positional argument: 'ys'
# Presumably the sampler passes only the first dataset tensor (xs) to the
# model, whose forward also needs ys -- TODO fix before trusting any estimate.
sgld_optimizer_config = OptimizerConfig(
    optimizer_type="SGLD",
    lr=1e-5,
    noise_level=1.,
    temperature="adaptive",
    num_samples=len(evaluator.pretrain_xs),
    elasticity=1.,
)
sampler_config = SamplerConfig(
    optimizer_config=sgld_optimizer_config,
    criterion="mse_loss",
    num_burnin_steps=0,
    num_draws_per_chain=100,
    num_steps_bw_draws=100,
    num_chains=5,
)
pretrain_dataset = torch.utils.data.TensorDataset(evaluator.pretrain_xs, evaluator.pretrain_ys)
sampler = sampler_config.factory(model, pretrain_dataset)

rlct_evaluator = SamplerEvaluator.create_rlct_evaluator(sampler)
evals = rlct_evaluator(model, None, None)
evals

TypeError: forward() missing 1 required positional argument: 'ys'

In [2]:
from icl.tasks import DiscreteTaskDistribution

# Quick shape check: draw 20 tasks from a small discrete distribution.
# The recorded output torch.Size([20, 8]) shows the first constructor
# argument (8) is the task dimension; presumably 4 is the number of tasks -- TODO confirm.
dist = DiscreteTaskDistribution(8, 4)
sampled_tasks = dist.sample_tasks(20)
sampled_tasks.shape

torch.Size([20, 8])

In [7]:
# Importing required modules
import torch
from torch.utils import data
import time

# Task-sampling strategies being benchmarked (methods 1-3), plus a Dataset for reading saved tasks

# Method 1: Generate tasks on-the-fly
class Method1:
    """Sample tasks lazily: each task vector is regenerated on demand from its index.

    Task ``idx`` is produced by reseeding a dedicated ``torch.Generator`` with
    ``idx`` and drawing ``task_size`` standard-normal values. The same index
    therefore always yields the same task, and no task table is stored.
    """

    def __init__(self, task_size: int, num_tasks: int, device='cuda'):
        self.task_size = task_size
        self.num_tasks = num_tasks
        self.device = device
        # Private generator: per-task reseeding leaves the global torch RNG alone.
        # NOTE(review): the default device 'cuda' needs a CUDA-enabled torch;
        # pass device='cpu' otherwise.
        self.generator = torch.Generator(device=self.device)

    def sample_task(self, idx: int):
        """Return task ``idx`` as a tensor of shape ``(task_size,)``."""
        self.generator.manual_seed(idx)
        return torch.normal(
            mean=0.,
            std=1.,
            size=(self.task_size,),
            generator=self.generator,
            device=self.device,
        )

    def sample_tasks(self, n: int):
        """Return ``n`` tasks whose indices are drawn uniformly (with replacement).

        Returns a tensor of shape ``(n, task_size)``. For ``n == 0`` an empty
        ``(0, task_size)`` tensor is returned rather than raising, because
        ``torch.stack`` rejects an empty list. This matches
        ``Method2.sample_tasks(0)``.
        """
        # Indices come from the global RNG; the task values come from self.generator.
        task_selection = torch.randint(
            high=self.num_tasks,
            size=(n,),
            device=self.device,
        )
        if task_selection.numel() == 0:
            return torch.empty((0, self.task_size), device=self.device)
        return torch.stack([
            self.sample_task(int(i))
            for i in task_selection
        ])

# Method 2: Pre-generate all tasks
class Method2:
    """Sample tasks from a table that is fully materialised up front.

    All ``num_tasks`` task vectors are drawn once in ``__init__`` from the global
    torch RNG and kept in ``self.tasks`` with shape ``(num_tasks, task_size)``.
    Sampling is then just random row indexing.
    """

    def __init__(self, task_size: int, num_tasks: int, device='cuda'):
        self.task_size, self.num_tasks, self.device = task_size, num_tasks, device
        table_shape = (num_tasks, task_size)
        self.tasks = torch.normal(mean=0., std=1., size=table_shape, device=device)

    def sample_tasks(self, n: int):
        """Return ``n`` rows of ``self.tasks`` chosen uniformly with replacement; shape ``(n, task_size)``."""
        row_ids = torch.randint(self.num_tasks, (n,), device=self.device)
        return self.tasks[row_ids]


# Save pre-generated tasks to disk
class Method3_Save:
    """Draw a ``(num_tasks, task_size)`` standard-normal task table and write it with ``torch.save``.

    The table is also kept on ``self.tasks``. The file at ``filename`` is
    written immediately, inside ``__init__``.
    """

    def __init__(self, task_size: int, num_tasks: int, filename='tasks.pt', device='cpu'):
        self.task_size = task_size
        self.num_tasks = num_tasks
        self.device = device
        self.filename = filename
        table = torch.normal(mean=0., std=1., size=(num_tasks, task_size), device=device)
        self.tasks = table
        torch.save(table, filename)

# DataLoader for reading tasks from disk
class DiskTaskDataset(data.Dataset):
    """Map-style dataset over a task table read with ``torch.load`` (e.g. one saved by ``Method3_Save``).

    The whole file is read into memory once, in ``__init__``. Indexing is
    delegated to the loaded object (row indexing when it is a tensor).
    """

    def __init__(self, filename):
        # NOTE(review): depending on the torch version's weights_only default,
        # torch.load may unpickle arbitrary objects -- only load trusted files.
        loaded_tasks = torch.load(filename)
        self.tasks = loaded_tasks

    def __len__(self):
        return len(self.tasks)

    def __getitem__(self, index):
        task = self.tasks[index]
        return task

# Parameters
task_size = 1000
num_tasks = 10000
n_sample = 128   # tasks requested per trial
n_trials = 100
device = 'cpu'


def _time_trials(fn, trials):
    """Call ``fn()`` ``trials`` times and return ``(elapsed_seconds, last_result)``.

    Uses ``time.perf_counter`` (monotonic, high resolution), which suits
    benchmarking. ``time.time`` can jump when the wall clock is adjusted.
    """
    result = None
    begin = time.perf_counter()
    for _ in range(trials):
        result = fn()
    return time.perf_counter() - begin, result


# Initialize objects
method1 = Method1(task_size, num_tasks, device)
method2 = Method2(task_size, num_tasks, device)

# Method 1: regenerate each sampled task from its seed
time_method1, tasks1 = _time_trials(lambda: method1.sample_tasks(n_sample), n_trials)

# Method 2: index into the pre-generated task table
time_method2, tasks2 = _time_trials(lambda: method2.sample_tasks(n_sample), n_trials)

# Method 3: write a task table to disk, then time reloading it.
# NOTE(review): the saved table holds only num_tasks // 10 tasks, and each trial
# reloads the whole file without sampling n_sample tasks, so time_method3 is not
# directly comparable with time_method1 / time_method2 -- TODO decide what to measure.
method3_save = Method3_Save(task_size, num_tasks // 10)
filename = method3_save.filename  # single source of truth for the saved path
batch_size = n_sample  # NOTE(review): unused; presumably meant for a DataLoader over DiskTaskDataset
time_method3 = _time_trials(lambda: torch.load(filename), n_trials)[0]

time_method1, time_method2, time_method3


(0.41481709480285645, 0.021686553955078125, 0.13994932174682617)

In [1]:
from pprint import pp
from typing import List

import wandb

from icl.config import ICLConfig

# W&B sweep whose configuration we want to inspect.
SWEEP_ID = "devinterp/icl/qoe1mlpn"

api = wandb.Api()
sweep = api.sweep(SWEEP_ID)
sweep.config




{'entity': 'devinterp',
 'method': 'grid',
 'name': 'icl-config-sweep',
 'parameters': {'eval_batch_size': {'value': 2048},
  'task_config': {'parameters': {'embed_size': {'value': 128},
    'max_examples': {'value': 16},
    'mlp_size': {'value': 128},
    'model_seed': {'value': 0},
    'noise_variance': {'value': 0.25},
    'num_heads': {'value': 2},
    'num_layers': {'value': 8},
    'num_tasks': {'values': [1,
      2,
      4,
      8,
      16,
      32,
      64,
      128,
      256,
      512,
      1024,
      2048,
      4096,
      8192,
      16384,
      32768,
      65536,
      131072,
      262144,
      524288,
      1048576]},
    'pretrain_seed': {'value': 1},
    'sampling_seed': {'value': 3},
    'task_size': {'value': 8},
    'true_seed': {'value': 2}}}},
 'program': 'icl',
 'project': 'icl'}