## Protein Design with Guided Discrete Diffusion

In [None]:
# Load the Hydra configuration for this tutorial and set up experiment logging.
from omegaconf import OmegaConf
import hydra
from cortex.logging import wandb_setup

# Compose the `4_guided_diffusion` config from the `./hydra` config directory.
with hydra.initialize(config_path="./hydra"):
    cfg = hydra.compose(config_name="4_guided_diffusion")
    # Turn off struct mode so keys that are not already in the config can be added later.
    OmegaConf.set_struct(cfg, False)

# NOTE(review): presumably initializes Weights & Biases logging from `cfg` — confirm.
wandb_setup(cfg)

In [None]:
from cortex.data.dataset import TAPEFluorescenceDataset


# Load the TAPE fluorescence training split (presumably downloaded into ./.cache
# on first use).
dataset = TAPEFluorescenceDataset(
    root='./.cache',
    download=True,
    train=True,
)

# Pick a single seed sequence: the row at the middle position after sorting by
# `log_fluorescence` (the upper median when the row count is even).
# NOTE(review): `_data` is a private attribute of the dataset; confirm there is
# no public accessor for the underlying DataFrame.
sorted_df = dataset._data.sort_values("log_fluorescence")
if sorted_df.empty:
    raise ValueError("TAPE fluorescence training data is empty; cannot choose a seed sequence.")

# Derive the position from the frame that is actually sliced (not `len(dataset)`),
# so it is always a valid position in `sorted_df`.
med_idx = len(sorted_df) // 2

# Replicate the seed row once per candidate solution. This matches sampling a
# one-row frame with replacement, but it is deterministic by construction and does
# not consume the global RNG (which is only seeded in a later cell).
init_df = sorted_df.iloc[[med_idx] * cfg.optim.max_num_solutions]


In [None]:
import lightning as L

# set random seed (done before instantiating the model so construction is reproducible);
# `workers=True` also derives seeds for dataloader worker processes
L.seed_everything(seed=cfg.random_seed, workers=True)

# instantiate model from the `tree` config section, then build it
# (skip_task_setup=False: task setup is performed as part of the build)
model = hydra.utils.instantiate(cfg.tree)
model.build_tree(cfg, skip_task_setup=False)

# instantiate trainer, set logger
# NOTE(review): the logger presumably comes from the `trainer` config section — confirm.
trainer = hydra.utils.instantiate(cfg.trainer)

In [None]:
# Train the model on its "train" split, validating on its "val" split.
trainer.fit(
    model,
    train_dataloaders=model.get_dataloader(split="train"),
    val_dataloaders=model.get_dataloader(split="val"),
)

In [None]:
# construct guidance objective used to steer the diffusion sampler
# Starting candidates: the `tokenized_seq` column of every row in `init_df`.
initial_solution = init_df["tokenized_seq"].values
# Compute the objective's runtime arguments from the trained model and the starting
# candidates; the result is unpacked below, so it must be a mapping of kwargs.
acq_fn_runtime_kwargs = hydra.utils.call(
    cfg.guidance_objective.runtime_kwargs, model=model, candidate_points=initial_solution
)
# Combine the static config arguments with the runtime ones to build the objective.
acq_fn = hydra.utils.instantiate(cfg.guidance_objective.static_kwargs, **acq_fn_runtime_kwargs)

In [None]:
# Fetch the evaluation-time transform of the `protein_seq` root node; its first
# element exposes the tokenizer.
tokenizer_transform = model.root_nodes["protein_seq"].eval_transform
tokenizer = tokenizer_transform[0].tokenizer

# Tokenize the starting candidates and compute which token positions may be corrupted
# NOTE(review): presumably special tokens are excluded from the mask — confirm.
tok_idxs = tokenizer_transform(initial_solution)
is_mutable = tokenizer.get_corruptible_mask(tok_idxs)
# display the mask as the cell output
is_mutable


In [None]:
import torch
# Score the starting candidates under the guidance objective, without autograd tracking.
# NOTE(review): `trainer.fit` may leave the model in train mode; confirm that
# `call_from_str_array` handles eval mode (e.g. dropout) or call `model.eval()` first.
with torch.inference_mode():
    # corrupt_frac=0.0: presumably evaluates the sequences without any corruption
    tree_output = model.call_from_str_array(initial_solution, corrupt_frac=0.0)
    init_obj_vals = acq_fn.get_objective_vals(tree_output)
# display the initial objective values as the cell output
init_obj_vals

In [None]:

# Build the optimizer from the `optim` config section, starting from the tokenized
# candidates and scoring them with `acq_fn`.
# NOTE(review): presumably only positions where `is_mutable` is set get edited — confirm.
optimizer = hydra.utils.instantiate(
    cfg.optim,
    params=tok_idxs,
    is_mutable=is_mutable,
    model=model,
    objective=acq_fn,
    constraint_fn=None,  # no design constraints are applied
)
# Run a fixed number (`cfg.num_steps`) of optimizer steps.
for _ in range(cfg.num_steps):
    optimizer.step()


In [None]:
# Retrieve the best designs found; the result is indexed by a "protein_seq" column below.
new_designs = optimizer.get_best_solutions()

In [None]:
# Score the optimized designs exactly as the starting candidates were scored, for comparison.
with torch.inference_mode():
    tree_output = model.call_from_str_array(new_designs["protein_seq"].values, corrupt_frac=0.0)
    final_obj_vals = acq_fn.get_objective_vals(tree_output)
# display the final objective values as the cell output
final_obj_vals

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Optimization trajectory: median objective value of the buffered candidates at
# each iteration.
# NOTE(review): `_buffer` is private and assumed to be a DataFrame with
# `iteration` and `obj_val` columns — confirm.
history = optimizer._buffer
per_iteration = history.groupby("iteration")
med_obj_val = per_iteration["obj_val"].median()

sns.set_theme(style="whitegrid", font_scale=1.75)

# Plot the per-iteration medians (x axis is the iteration index of the Series).
plt.plot(med_obj_val)
plt.xlabel("Diffusion Iteration")
plt.ylabel("Median Acq. Value")

In [None]:
# Compare the distribution of predicted objective values for the final designs with
# the predicted value of the starting sequence.
# `.cpu()` lets this work when the model ran on an accelerator, and `reshape`
# (unlike `view`) also accepts non-contiguous tensors.
final_vals_np = final_obj_vals.cpu().reshape(-1).numpy()
# All starting candidates are copies of one seed row, so the first value represents them all.
init_val_np = init_obj_vals[0].cpu().numpy()

sns.kdeplot(final_vals_np, fill=True, alpha=0.5, cut=0)
ylim = plt.ylim()
plt.vlines(init_val_np, *ylim, color="black", linestyle="--", label="Initial Value")
plt.xlabel("Predicted Log Fluorescence")
plt.legend()