In [1]:
!pip install pytorch-lightning ray[tune]



In [2]:
!pip install -U ipywidgets

Collecting ipywidgets
  Downloading ipywidgets-8.1.5-py3-none-any.whl.metadata (2.3 kB)
Collecting widgetsnbextension~=4.0.12 (from ipywidgets)
  Downloading widgetsnbextension-4.0.13-py3-none-any.whl.metadata (1.6 kB)
Collecting jupyterlab-widgets~=3.0.12 (from ipywidgets)
  Downloading jupyterlab_widgets-3.0.13-py3-none-any.whl.metadata (4.1 kB)
Downloading ipywidgets-8.1.5-py3-none-any.whl (139 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m139.8/139.8 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jupyterlab_widgets-3.0.13-py3-none-any.whl (214 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m214.4/214.4 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading widgetsnbextension-4.0.13-py3-none-any.whl (2.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m47.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: widgetsnbextension, jupyterlab-widgets, ipywidge

In [3]:
import os

# Collect the path of every file below the Kaggle input directory.
# os.walk yields directories and files in an arbitrary, filesystem-dependent order, while the
# train/test split in ActivationDataset slices this list by position; sorting makes that split
# (and the per-file sampling order) reproducible across sessions.
FILE_NAMES = sorted(
    os.path.join(dirname, filename)
    for dirname, _, filenames in os.walk('/kaggle/input')
    for filename in filenames
)

In [4]:
# print(FILE_NAMES)

In [5]:
# import torch
# torch.cuda.is_available()

In [6]:
# !nvidia-smi

In [7]:
import os
from pathlib import Path
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import LearningRateMonitor, EarlyStopping
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback
from ray import train, tune
from ray.tune import Tuner, with_resources
# from ray_lightning.tune import TuneReportCallback
from ray.train import RunConfig

In [8]:
# Global parameters shared by the dataset, model and training code below.
# scale_factor: divisor applied to the raw activations in ActivationDataset before they are
# rescaled by sqrt(n_features).
# NOTE(review): presumably a precomputed norm statistic of the activations — confirm provenance.
scale_factor = 11.888623072966611
# Width of the autoencoder's input/output layer (passed to SparseAutoencoder as input_dim).
input_dim = 3072
# sample_rate = 0.1 # 10% of data
sample_size = 8192 # rows randomly drawn from each file (file_size * sample_rate per the original note)
num_epochs = 5  # Upper bound on training epochs (Trainer max_epochs); EarlyStopping may end the run sooner

# Dataset with Subsampling
class ActivationDataset(Dataset):
    """Serve pre-batched, normalised activation rows loaded from per-file NumPy arrays.

    Every file contributes ``multi = sample_size // batch_size`` items. For a given file a
    fixed random subsample of ``sample_size`` rows is drawn (seeded by ``seed`` plus the file
    index), and item ``k`` of that file is the ``k``-th consecutive slice of ``batch_size``
    rows of that subsample. Items of the same file are therefore disjoint, and the same
    index always returns the same rows.

    Args:
        f_type: ``"train"`` or ``"test"``; selects which side of the positional file split
            this dataset serves.
        test_fraction: Fraction of the file list (taken from its end) reserved for "test".
        scale_factor: Divisor applied to the activations before rescaling by
            ``sqrt(n_features)``.
        batch_size: Number of rows in each returned item.
        seed: Base seed for the per-file subsampling.
        file_names: Optional explicit list of file paths. Defaults to the notebook-level
            ``FILE_NAMES``.
        subsample_size: Optional number of rows to subsample per file. Defaults to the
            notebook-level ``sample_size``.

    Raises:
        ValueError: If ``f_type`` is invalid, or ``batch_size`` is not in
            ``1..subsample_size`` (which would otherwise yield a silently empty dataset).
    """

    def __init__(self, f_type, test_fraction=0.01, scale_factor=1.0, batch_size=2048, seed=42,
                 file_names=None, subsample_size=None):
        if f_type not in ["train", "test"]:
            raise ValueError("f_type must be 'train' or 'test'")

        # The notebook globals are only consulted when no explicit override is passed.
        self.sample_size = sample_size if subsample_size is None else subsample_size
        if batch_size <= 0 or batch_size > self.sample_size:
            raise ValueError(
                f"batch_size must be between 1 and sample_size ({self.sample_size}), got {batch_size}"
            )

        self.test_fraction = test_fraction
        self.scale_factor = scale_factor
        self.batch_size = batch_size
        self.multi = self.sample_size // batch_size  # items (batches) served per file
        self.seed = seed
        self.f_type = f_type

        all_files = list(FILE_NAMES if file_names is None else file_names)
        split_at = int(len(all_files) * (1 - self.test_fraction))
        self.file_names = all_files[:split_at] if f_type == "train" else all_files[split_at:]

    def __len__(self):
        return len(self.file_names) * self.multi

    def __getitem__(self, idx):
        """Return item ``idx`` as a float32 tensor of ``batch_size`` rows.

        ``np.random.Generator.choice`` raises ``ValueError`` if the file has fewer rows than
        the subsample size.
        """
        # idx enumerates (file, batch-within-file) pairs in row-major order.
        f_ix, batch_i = divmod(idx, self.multi)
        activations = np.load(self.file_names[f_ix])

        # Seed per FILE (not per item) with a local generator: all items of one file slice the
        # same permutation, so they never overlap, and the global NumPy RNG is left untouched.
        rng = np.random.default_rng(self.seed + f_ix)
        subsample_indices = rng.choice(activations.shape[0], self.sample_size, replace=False)

        # Select this item's rows first so only batch_size rows are copied and normalised.
        start = batch_i * self.batch_size
        rows = subsample_indices[start:start + self.batch_size]
        batch = activations[rows, :-3]  # drop the last three (metadata) columns

        # Normalize
        batch = batch / self.scale_factor * np.sqrt(batch.shape[1])

        return torch.tensor(batch, dtype=torch.float32)

# Model Definition
class SparseAutoencoder(pl.LightningModule):
    """One-hidden-layer sparse autoencoder trained with MSE plus a weighted L1 sparsity penalty.

    The encoder is ``Linear -> ReLU`` and the decoder a single ``Linear``. The sparsity term
    weights each hidden activation by the L2 norm of its decoder column, sums over hidden
    features for every sample and averages over samples.

    Batches may carry extra leading dimensions (the notebook's DataLoader wraps each
    pre-batched item in an outer batch of size 1, producing ``(1, B, D)`` tensors); all
    reductions below are written over the feature axis so the result is the same as for a
    plain ``(B, D)`` batch.

    Args:
        input_dim: Size of the input / reconstruction vectors.
        hidden_dim: Number of hidden (dictionary) features.
        l1_lambda: Weight of the sparsity term in the total loss.
        lr: Adam learning rate.
    """

    def __init__(self, input_dim, hidden_dim, l1_lambda, lr):
        super().__init__()
        self.save_hyperparameters()
        self.encoder = nn.Linear(input_dim, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, input_dim)
        self.criterion = nn.MSELoss()
        self.l1_lambda = l1_lambda
        self.lr = lr

    def forward(self, x):
        """Return ``(decoded, encoded)`` for input ``x`` (features on the last axis)."""
        encoded = torch.relu(self.encoder(x))
        decoded = self.decoder(encoded)
        return decoded, encoded

    def compute_loss(self, batch, decoded, encoded):
        """Return ``(mse_loss, l1_loss)`` as scalar tensors.

        ``l1_loss`` is the mean over samples of ``sum_i encoded_i * ||decoder column i||_2``.
        The sum runs over the last (feature) axis; summing over axis 1 would add up samples
        instead of features whenever the batch has a leading size-1 dimension.
        """
        mse_loss = self.criterion(decoded, batch)
        # decoder.weight is (input_dim, hidden_dim): one column per hidden feature.
        decoder_weight_norms = torch.norm(self.decoder.weight, p=2, dim=0)
        l1_terms = encoded * decoder_weight_norms  # broadcasts over all leading dims
        l1_loss = torch.mean(torch.sum(l1_terms, dim=-1))
        return mse_loss, l1_loss

    @staticmethod
    def _active_feature_pct(encoded):
        """Percentage of hidden features that fire (> 0) for at least one sample."""
        per_sample = (encoded > 0).reshape(-1, encoded.shape[-1])
        return per_sample.any(dim=0).float().mean().item() * 100

    def _shared_step(self, batch):
        """Run the forward pass and return ``(total, mse, l1, active_feature_pct)``."""
        # Follow the module's own device instead of assuming CUDA, so CPU runs work too.
        batch = batch.to(self.device)
        decoded, encoded = self(batch)
        mse_loss, l1_loss = self.compute_loss(batch, decoded, encoded)
        total_loss = mse_loss + self.l1_lambda * l1_loss
        return total_loss, mse_loss, l1_loss, self._active_feature_pct(encoded)

    def training_step(self, batch, batch_idx):
        total_loss, mse_loss, l1_loss, active_features = self._shared_step(batch)

        # Log metrics
        self.log("train_loss", total_loss, on_step=True, on_epoch=True)
        self.log("train_mse_loss", mse_loss, on_step=True, on_epoch=True)
        self.log("train_l1_loss", l1_loss, on_step=True, on_epoch=True)
        self.log("active_features", active_features, on_step=True, on_epoch=True)
        # Placeholder so a "val_loss" key exists while training batches are being reported.
        # NOTE(review): presumably required by the per-batch Tune report in train_model, which
        # lists "val_loss"; it makes those per-batch reports carry val_loss == 0. Consider
        # removing it together with "val_loss" in that callback.
        self.log("val_loss", 0, on_step=True, on_epoch=True)
        return total_loss

    def validation_step(self, batch, batch_idx):
        total_loss, mse_loss, l1_loss, active_features = self._shared_step(batch)

        # Log metrics (epoch-level aggregates only)
        self.log("val_loss", total_loss, on_step=False, on_epoch=True)
        self.log("val_mse_loss", mse_loss, on_step=False, on_epoch=True)
        self.log("val_l1_loss", l1_loss, on_step=False, on_epoch=True)
        self.log("val_active_features", active_features, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        """Adam with a StepLR schedule (x0.1 every 10 scheduler steps).

        NOTE(review): if the scheduler is stepped once per epoch (believed to be Lightning's
        default), a 5-epoch budget never reaches the first decay — confirm this is intended.
        """
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

# DataLoader Creation
def create_data_loaders(batch_size):
    """Build the training and validation DataLoaders.

    Both loaders wrap an ActivationDataset (1% of the files held out for "test", seed 42,
    global ``scale_factor``). The dataset already returns blocks of ``batch_size`` rows, so
    the loaders use an outer batch size of 1 and keep the dataset order (no shuffling).

    Args:
        batch_size: Number of rows per dataset item.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    datasets = {
        split: ActivationDataset(split, 0.01, scale_factor, batch_size, 42)
        for split in ("train", "test")
    }

    # Options common to both loaders; the outer batch size stays 1 for per-file sampling.
    shared_options = dict(batch_size=1, shuffle=False, pin_memory=False)

    train_loader = DataLoader(
        datasets["train"],
        num_workers=3,  # Adjust based on CPU availability
        persistent_workers=True,
        **shared_options,
    )
    val_loader = DataLoader(
        datasets["test"],
        num_workers=1,  # Smaller for validation
        **shared_options,
    )
    return train_loader, val_loader

# Training Function with Ray Tune
def train_model(config):
    """Ray Tune trainable: fit one SparseAutoencoder for the given configuration.

    Args:
        config: Trial configuration with keys ``"HB"`` (a dict holding ``"hidden_dim"`` and
            ``"batch_size"``), ``"l1_lambda"`` and ``"lr"``.

    Metrics are reported to Tune after every training batch (without checkpoints) and at the
    end of every validation run (with the callback's default checkpoint settings).
    """
    hb = config["HB"]
    train_loader, val_loader = create_data_loaders(hb["batch_size"])
    model = SparseAutoencoder(
        input_dim,
        hidden_dim=hb["hidden_dim"],
        l1_lambda=config["l1_lambda"],
        lr=config["lr"],
    )

    tb_logger = TensorBoardLogger("tb_logs", name="SparseAutoencoder")
    device_kind = "gpu" if torch.cuda.is_available() else "cpu"

    # Tune metric name -> Lightning metric name (identical on both sides).
    batch_metrics = {
        key: key
        for key in ("train_loss", "train_mse_loss", "train_l1_loss", "active_features", "val_loss")
    }
    validation_metrics = {
        key: key
        for key in ("val_loss", "val_mse_loss", "val_l1_loss", "val_active_features")
    }

    per_batch_report = TuneReportCheckpointCallback(
        batch_metrics,
        filename="none",  # Do not save checkpoints
        save_checkpoints = False,
        on="train_batch_end",
    )
    per_validation_report = TuneReportCheckpointCallback(
        validation_metrics,
        on="validation_end",
    )

    trainer = pl.Trainer(
        max_epochs=num_epochs,
        accelerator=device_kind,
        logger=tb_logger,
        # NOTE(review): the advanced profiler stays on for the whole run — confirm it is still
        # wanted outside of debugging sessions.
        profiler="advanced",
        enable_progress_bar=True, # Show progress bar
        callbacks=[
            LearningRateMonitor(logging_interval="step"),
            EarlyStopping(monitor="val_loss", patience=3, mode="min"),
            per_batch_report,
            per_validation_report,
        ],
    )
    trainer.fit(model, train_loader, val_loader)

# Ray Tune Hyperparameter Search
def tune_hyperparameters():
    """Run the Ray Tune search over ``train_model`` and return the result grid.

    The search space combines every (hidden_dim, batch_size) pair that passes the VRAM
    heuristic with the candidate ``l1_lambda`` and ``lr`` values. Results are stored under
    ``./results/hyperparameter_search`` and the best configuration (lowest ``val_loss``) is
    printed.

    Returns:
        The object returned by ``Tuner.fit()``.

    Raises:
        ValueError: If no (hidden_dim, batch_size) candidate passes the VRAM heuristic.
    """
    # Heuristic VRAM limit on hidden_dim * batch_size.
    max_hidden_times_batch = 441_000_000

    # Currently a single configuration. Earlier sweeps used hidden dims
    # [4096, 8192, 16384, 20000, 32768] and batch sizes [512, 1024, 2048, 4096, 8192].
    possible_hidden_dims = [65536]
    possible_batch_sizes = [2048]

    valid_hb_pairs = [
        {"hidden_dim": hidden_dim, "batch_size": batch_size}
        for hidden_dim in possible_hidden_dims
        for batch_size in possible_batch_sizes
        if hidden_dim * batch_size <= max_hidden_times_batch
    ]
    if not valid_hb_pairs:
        # Fail here with a clear message instead of handing tune.choice an empty list.
        raise ValueError(
            "No (hidden_dim, batch_size) combination satisfies "
            f"hidden_dim * batch_size <= {max_hidden_times_batch}"
        )

    # Fixed l1_lambda / lr values; earlier sweeps drew both from tune.loguniform(1e-4, 1e-2).
    search_space = {
        "HB": tune.choice(valid_hb_pairs),
        "l1_lambda": tune.choice([0.00597965]),
        "lr": tune.choice([2.5011e-05]),
    }

    # An ASHA scheduler (max_t=num_epochs, grace_period=1, reduction_factor=2) was tried
    # previously and can be passed via TuneConfig(scheduler=...) for multi-trial sweeps.

    trainable_with_resources = with_resources(
        train_model,
        {"cpu": 4, "gpu": 1}  # Adjust based on your available resources
    )

    tuner = Tuner(
        trainable=trainable_with_resources,
        param_space=search_space,
        tune_config=tune.TuneConfig(
            metric="val_loss",
            mode="min",
            num_samples=1,  # Number of hyperparameter sets to try
            max_concurrent_trials=1,  # Number of trials to run concurrently
        ),
        run_config=RunConfig(
            name="hyperparameter_search",
            storage_path=str(Path("./results").resolve()),
        ),
    )
    results = tuner.fit()
    best_result = results.get_best_result(metric="val_loss", mode="min")
    print("Best Hyperparameters Found:")
    print(best_result.config)
    return results

# # Run Hyperparameter Search
# if __name__ == "__main__":
#     tune_hyperparameters()

In [9]:
import multiprocessing
# Force the "spawn" start method for Python multiprocessing in this (notebook) process.
# NOTE(review): the recorded trial log still shows an os.fork() warning coming from the
# train_model worker, so this setting may not reach the Ray trial process — confirm it is needed.
multiprocessing.set_start_method("spawn", force=True)


In [10]:
# Launch the Ray Tune search; the recorded run of this cell took roughly two hours.
results = tune_hyperparameters()

0,1
Current time:,2024-12-04 18:55:50
Running for:,01:59:07.48
Memory:,7.3/31.4 GiB

Trial name,status,loc,HB,l1_lambda,lr,iter,total time (s)
train_model_b91fe_00000,TERMINATED,172.19.2.2:376,{'hidden_dim': _5100,0.00597965,2.5011e-05,2036,7088.79


[36m(train_model pid=376)[0m /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
[36m(train_model pid=376)[0m   self.pid = os.fork()


Sanity Checking: |          | 0/? [00:00<?, ?it/s]
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  1.31it/s]
Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:05<00:00,  0.35it/s]
                                                                           
Epoch 0:   0%|          | 0/508 [00:00<?, ?it/s] 
Epoch 0:   0%|          | 1/508 [00:14<2:06:08,  0.07it/s, v_num=0]
Epoch 0:   0%|          | 2/508 [00:15<1:06:22,  0.13it/s, v_num=0]
Epoch 0:   1%|          | 3/508 [00:16<46:27,  0.18it/s, v_num=0]
Epoch 0:   1%|          | 4/508 [00:20<43:18,  0.19it/s, v_num=0]
Epoch 0:   1%|          | 5/508 [00:27<46:16,  0.18it/s, v_num=0]
Epoch 0:   1%|          | 6/508 [00:28<39:37,  0.21it/s, v_num=0]
Epoch 0:   1%|▏         | 7/508 [00:29<34:52,  0.24it/s, v_num=0]
Epoch 0:   2%|▏         | 8/508 [00:35<36:29,  0.23it/s, v_num=0]
Epoch 0:   2%|▏         | 9/508 [00:38<35:42,  0.23it/s, v_num=0]
Epoch

[36m(train_model pid=376)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/kaggle/working/results/hyperparameter_search/train_model_b91fe_00000_0_HB=hidden_dim_65536_batch_size_2048,l1_lambda=0.0060,lr=0.0000_2024-12-04_16-56-42/checkpoint_000000)


[36m(train_model pid=376)[0m 
Epoch 0: 100%|██████████| 508/508 [29:21<00:00,  0.29it/s, v_num=0]




Epoch 1:   0%|          | 0/508 [00:00<?, ?it/s, v_num=0]
Epoch 1:   0%|          | 1/508 [00:13<1:58:14,  0.07it/s, v_num=0]
Epoch 1:   0%|          | 2/508 [00:14<1:02:26,  0.14it/s, v_num=0]
Epoch 1:   1%|          | 3/508 [00:15<43:48,  0.19it/s, v_num=0]  
Epoch 1:   1%|          | 3/508 [00:15<43:49,  0.19it/s, v_num=0]
Epoch 1:   1%|          | 4/508 [00:20<42:04,  0.20it/s, v_num=0]
Epoch 1:   1%|          | 5/508 [00:25<43:09,  0.19it/s, v_num=0]
Epoch 1:   1%|          | 6/508 [00:28<40:13,  0.21it/s, v_num=0]
Epoch 1:   1%|▏         | 7/508 [00:29<35:22,  0.24it/s, v_num=0]
Epoch 1:   2%|▏         | 8/508 [00:32<33:34,  0.25it/s, v_num=0]
Epoch 1:   2%|▏         | 9/508 [00:41<38:09,  0.22it/s, v_num=0]
Epoch 1:   2%|▏         | 10/508 [00:42<34:56,  0.24it/s, v_num=0]
Epoch 1:   2%|▏         | 11/508 [00:42<32:19,  0.26it/s, v_num=0]
Epoch 1:   2%|▏         | 12/508 [00:46<32:19,  0.26it/s, v_num=0]
Epoch 1:   3%|▎         | 13/508 [00:52<33:00,  0.25it/s, v_num=0]
Epoch 1:

[36m(train_model pid=376)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/kaggle/working/results/hyperparameter_search/train_model_b91fe_00000_0_HB=hidden_dim_65536_batch_size_2048,l1_lambda=0.0060,lr=0.0000_2024-12-04_16-56-42/checkpoint_000001)


[36m(train_model pid=376)[0m 
Epoch 1: 100%|██████████| 508/508 [29:16<00:00,  0.29it/s, v_num=0]
Epoch 2:   0%|          | 0/508 [00:00<?, ?it/s, v_num=0]
Epoch 2:   0%|          | 1/508 [00:14<2:00:40,  0.07it/s, v_num=0]
Epoch 2:   0%|          | 2/508 [00:15<1:03:39,  0.13it/s, v_num=0]
Epoch 2:   1%|          | 3/508 [00:15<44:38,  0.19it/s, v_num=0]
Epoch 2:   1%|          | 4/508 [00:20<43:05,  0.19it/s, v_num=0]
Epoch 2:   1%|          | 5/508 [00:24<41:17,  0.20it/s, v_num=0]
Epoch 2:   1%|          | 6/508 [00:27<38:50,  0.22it/s, v_num=0]
Epoch 2:   1%|▏         | 7/508 [00:28<34:10,  0.24it/s, v_num=0]
Epoch 2:   2%|▏         | 8/508 [00:31<32:34,  0.26it/s, v_num=0]
Epoch 2:   2%|▏         | 9/508 [00:39<36:15,  0.23it/s, v_num=0]
Epoch 2:   2%|▏         | 9/508 [00:39<36:15,  0.23it/s, v_num=0]
Epoch 2:   2%|▏         | 10/508 [00:41<34:13,  0.24it/s, v_num=0]
Epoch 2:   2%|▏         | 11/508 [00:42<31:40,  0.26it/s, v_num=0]
Epoch 2:   2%|▏         | 12/508 [00:45<31:3

[36m(train_model pid=376)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/kaggle/working/results/hyperparameter_search/train_model_b91fe_00000_0_HB=hidden_dim_65536_batch_size_2048,l1_lambda=0.0060,lr=0.0000_2024-12-04_16-56-42/checkpoint_000002)


[36m(train_model pid=376)[0m 
Epoch 2: 100%|██████████| 508/508 [28:49<00:00,  0.29it/s, v_num=0]
Epoch 3:   0%|          | 0/508 [00:00<?, ?it/s, v_num=0]
Epoch 3:   0%|          | 1/508 [00:12<1:41:31,  0.08it/s, v_num=0]
Epoch 3:   0%|          | 2/508 [00:14<59:42,  0.14it/s, v_num=0]
Epoch 3:   1%|          | 3/508 [00:14<41:59,  0.20it/s, v_num=0]
Epoch 3:   1%|          | 4/508 [00:18<38:21,  0.22it/s, v_num=0]
Epoch 3:   1%|          | 5/508 [00:26<44:42,  0.19it/s, v_num=0]
Epoch 3:   1%|          | 6/508 [00:27<38:19,  0.22it/s, v_num=0]
Epoch 3:   1%|▏         | 7/508 [00:28<33:45,  0.25it/s, v_num=0]
Epoch 3:   2%|▏         | 8/508 [00:32<33:43,  0.25it/s, v_num=0]
Epoch 3:   2%|▏         | 9/508 [00:39<36:45,  0.23it/s, v_num=0]
Epoch 3:   2%|▏         | 10/508 [00:40<33:41,  0.25it/s, v_num=0]
Epoch 3:   2%|▏         | 11/508 [00:41<31:10,  0.27it/s, v_num=0]
Epoch 3:   2%|▏         | 12/508 [00:45<31:13,  0.26it/s, v_num=0]
Epoch 3:   3%|▎         | 13/508 [00:50<32:08

[36m(train_model pid=376)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/kaggle/working/results/hyperparameter_search/train_model_b91fe_00000_0_HB=hidden_dim_65536_batch_size_2048,l1_lambda=0.0060,lr=0.0000_2024-12-04_16-56-42/checkpoint_000003)


[36m(train_model pid=376)[0m 
Epoch 3: 100%|██████████| 508/508 [28:45<00:00,  0.29it/s, v_num=0]
Epoch 3: 100%|██████████| 508/508 [29:23<00:00,  0.29it/s, v_num=0]


2024-12-04 18:55:50,094	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/kaggle/working/results/hyperparameter_search' in 0.0062s.
2024-12-04 18:55:50,099	INFO tune.py:1041 -- Total run time: 7153.70 seconds (7147.47 seconds for the tuning loop).


Best Hyperparameters Found:
{'HB': {'hidden_dim': 65536, 'batch_size': 2048}, 'l1_lambda': 0.00597965, 'lr': 2.5011e-05}
