In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import csv
import random
import re


def fix_name(name: str) -> str:
    """Normalise a path-like name taken from the descriptions CSV.

    First removes every ``#<word>`` tag. Then, if the name ends in
    ``/<stem>.sdf``, the first ``/<segment>_pocket`` component is replaced by
    ``/<stem>_pocket``.

    Args:
        name: Raw value of the ``name`` column.

    Returns:
        The rewritten name. If the name does not end in ``/<file>.sdf``, the
        tag-stripped name is returned unchanged.
    """
    name = re.sub(r"\#\w+", "", name)
    sdf_match = re.search(r"/([^/]+)\.sdf$", name)
    if sdf_match is None:
        # No trailing "/<file>.sdf" component: nothing to rewrite.
        return name
    stem = sdf_match.group(1)
    # A callable replacement inserts `stem` literally; a string replacement
    # would parse backslashes in the file name as regex escapes/group refs.
    return re.sub(r"/([^/]+)_pocket", lambda _m: f"/{stem}_pocket", name, count=1)

# Input CSVs
SRC_FILES = [
    "/mnt/STORAGE3/sebastian2/DiffSBDD/data/descriptions_generate.csv",
    # add more files here if needed
]
DST_CSV = "/mnt/STORAGE3/sebastian2/DiffSBDD/data/training_descriptions_mixed.csv"

# Possible source columns for text
USE_COLUMNS = ["text_func", "text_llm"]

# `text_func` value that carries no information; such rows always use `text_llm`.
NO_FUNC_GROUPS = "No prominent functional groups identified."

# Seeded RNG so the per-row column choice and the final shuffle are reproducible.
SEED = 42
rng = random.Random(SEED)

rows = []

# Collect rows from all source CSVs
for src in SRC_FILES:
    # newline="" lets the csv module handle line breaks inside quoted text fields.
    with open(src, newline="") as f_in:
        for row in csv.DictReader(f_in):
            if row["text_func"] != NO_FUNC_GROUPS:
                # Randomly choose one of the available description columns.
                col = rng.choice(USE_COLUMNS)
            else:
                # The functional-group text is uninformative here; fall back to the LLM text.
                col = "text_llm"
            rows.append({"name": fix_name(row["name"]), "text": row[col]})

# Shuffle them
rng.shuffle(rows)

# Write to destination
with open(DST_CSV, "w", newline="") as f_out:
    writer = csv.DictWriter(f_out, fieldnames=["name", "text"])
    writer.writeheader()
    writer.writerows(rows)

print(f"Wrote {len(rows)} rows to {DST_CSV}")


##### Train from scratch

In [None]:
import os, sys, yaml
from argparse import Namespace
from pathlib import Path

import torch
import pytorch_lightning as pl
import numpy as np

# Make the repository importable from the notebook (needed for `lightning_modules`).
# The path is appended, so entries already on sys.path keep precedence.
REPO_ROOT = "/mnt/STORAGE3/sebastian2/DiffSBDD"
if not sys.path.count(REPO_ROOT):
    sys.path.append(REPO_ROOT)

from lightning_modules import LigandPocketDDPM

# ===== Notebook settings (edit here) =====
# Paths: YAML config, processed dataset (train.npz / val.npz / test.npz + size_distribution.npy), output root
CONFIG_YML = "/mnt/STORAGE3/sebastian2/DiffSBDD/configs/crossdock_fullatom_cond.yml"
DATADIR = "/mnt/STORAGE3/sebastian2/DiffSBDD/data/crossdocked_pocket10_proc_rerun"
LOGDIR = "/mnt/STORAGE3/sebastian2/DiffSBDD/logs"

# Text conditioning: model name passed to LigandPocketDDPM, and the CSV written by the mixing cell above
TEXT_MODEL = "GT4SD/multitask-text-and-chemistry-t5-base-standard"
TEXT_CSV = "/mnt/STORAGE3/sebastian2/DiffSBDD/data/training_descriptions_mixed.csv"

# Run settings
RUN_NAME = "notebook_try_text_cond"
MAX_EPOCHS = 2  # short smoke-test run
GPUS = 1  # 0 forces CPU
# =========================================

# Helper: convert dict -> Namespace recursively
def to_ns(d):
    """Recursively turn nested dicts into argparse.Namespace objects.

    Used on the parsed YAML config so keys can be read as attributes
    (``args.lr`` instead of ``cfg["lr"]``). Any non-dict value, including
    lists, is returned unchanged.
    """
    if not isinstance(d, dict):
        return d
    converted = {}
    for key, value in d.items():
        converted[key] = to_ns(value)
    return Namespace(**converted)

# Global seed so weight init, data shuffling and sampling are reproducible across runs.
SEED = 42
pl.seed_everything(SEED)

# Load config
with open(CONFIG_YML, "r") as f:
    cfg = yaml.safe_load(f)
args = to_ns(cfg)

# Minimal adjustments for notebook run
args.datadir = DATADIR
args.logdir = LOGDIR
args.run_name = RUN_NAME
args.wandb_params.mode = "disabled"  # disable wandb in notebook runs
args.enable_progress_bar = True
args.gpus = GPUS
args.n_epochs = MAX_EPOCHS

# Required histogram file from the processed dataset
histogram_file = Path(args.datadir, "size_distribution.npy")
histogram = np.load(histogram_file).tolist()

# All outputs of this run (module outputs and Trainer checkpoints) go here.
run_dir = Path(args.logdir, args.run_name)

# Build LightningModule with text conditioning
pl_module = LigandPocketDDPM(
    outdir=run_dir,
    dataset=args.dataset,
    datadir=args.datadir,
    batch_size=args.batch_size,
    lr=args.lr,
    egnn_params=args.egnn_params,
    diffusion_params=args.diffusion_params,
    num_workers=args.num_workers,
    augment_noise=args.augment_noise,
    augment_rotation=args.augment_rotation,
    clip_grad=args.clip_grad,
    eval_epochs=args.eval_epochs,
    eval_params=args.eval_params,
    visualize_sample_epoch=args.visualize_sample_epoch,
    visualize_chain_epoch=args.visualize_chain_epoch,
    auxiliary_loss=args.auxiliary_loss,
    loss_params=args.loss_params,
    mode=args.mode,
    node_histogram=histogram,
    pocket_representation=args.pocket_representation,
    text_model_name=TEXT_MODEL,
    text_csv=TEXT_CSV,
)

# Trainer: DDP only makes sense when we really train on more than one GPU;
# if CUDA is unavailable we fall back to a single-process CPU run.
use_gpu = bool(GPUS) and torch.cuda.is_available()
accelerator = "gpu" if use_gpu else "cpu"
use_ddp = use_gpu and args.gpus > 1
trainer = pl.Trainer(
    max_epochs=args.n_epochs,
    enable_progress_bar=args.enable_progress_bar,
    num_sanity_val_steps=args.num_sanity_val_steps,
    accelerator=accelerator,
    devices=GPUS if use_gpu else None,
    logger=False,
    # With logger=False, default checkpoints are written under default_root_dir
    # (the current working directory if unset); keep them with this run.
    default_root_dir=run_dir,
    strategy=("ddp_notebook" if use_ddp else None),
)

# Train
trainer.fit(model=pl_module)
print("Training finished.")

##### Fine-tune from a pretrained checkpoint

In [None]:
import os, sys, yaml
from argparse import Namespace
from pathlib import Path

import torch
import pytorch_lightning as pl
import numpy as np

# Make the repository importable from the notebook (needed for `lightning_modules`).
# The path is appended, so entries already on sys.path keep precedence.
REPO_ROOT = "/mnt/STORAGE3/sebastian2/DiffSBDD"
if not sys.path.count(REPO_ROOT):
    sys.path.append(REPO_ROOT)

from lightning_modules import LigandPocketDDPM

# ---- User inputs ----
CONFIG_YML = "/mnt/STORAGE3/sebastian2/DiffSBDD/configs/FT_crossdock_fullatom_cond.yml"
# NOTE(review): this data dir (and the checkpoint loaded later in this cell) is CA-only, while
# the config file name says "fullatom" -- confirm the YAML's pocket_representation matches.
DATADIR    = "/mnt/STORAGE3/sebastian2/DiffSBDD/data/crossdocked_pocket10_proc_ca_only"  # folder with train.npz / val.npz / test.npz
LOGDIR     = "/mnt/STORAGE3/sebastian2/DiffSBDD/logs"
# Distinct from the from-scratch run so the two runs do not write into the same output directory.
RUN_NAME   = "notebook_try_text_cond_finetune"
# Not read by this cell: the module below takes text_model_name / text_embeddings_path from the
# YAML config. Kept only for reference alongside the from-scratch cell.
TEXT_CSV   = "/mnt/STORAGE3/sebastian2/DiffSBDD/data/training_descriptions_mixed.csv"  # from step 1
TEXT_MODEL = "GT4SD/multitask-text-and-chemistry-t5-base-standard"
GPUS       = 1  # set 0 for CPU
MAX_EPOCHS = 1001  # full fine-tuning run (not a quick test)
# ---------------------

# Helper: convert dict -> Namespace recursively
def to_ns(d):
    """Recursively turn nested dicts into argparse.Namespace objects.

    Used on the parsed YAML config so keys can be read as attributes
    (``args.lr`` instead of ``cfg["lr"]``). Any non-dict value, including
    lists, is returned unchanged.
    """
    if not isinstance(d, dict):
        return d
    converted = {}
    for key, value in d.items():
        converted[key] = to_ns(value)
    return Namespace(**converted)

# Global seed so initialisation of new layers, data shuffling and sampling are reproducible.
SEED = 42
pl.seed_everything(SEED)

# Fine-tuning settings
CKPT_PATH = "/mnt/STORAGE3/sebastian2/DiffSBDD/checkpoints/crossdocked_ca_joint.ckpt"  # set None to skip loading weights
FT_LR = 1e-5  # overrides the lr from the YAML config

# Load config
with open(CONFIG_YML, "r") as f:
    cfg = yaml.safe_load(f)
args = to_ns(cfg)

# Minimal adjustments for notebook run
# (args.wandb_params is never read in this cell; the Trainer below uses logger=False.)
args.datadir = DATADIR
args.logdir = LOGDIR
args.run_name = RUN_NAME
args.enable_progress_bar = True
args.gpus = GPUS
args.n_epochs = MAX_EPOCHS
# Pass the fine-tune LR through the constructor rather than patching pl_module.lr afterwards,
# so that any hyperparameters the module records reflect the LR actually used.
args.lr = FT_LR

# Required histogram file from the processed dataset
histogram_file = Path(args.datadir, "size_distribution.npy")
histogram = np.load(histogram_file).tolist()

# All outputs of this run (module outputs and Trainer checkpoints) go here.
run_dir = Path(args.logdir, args.run_name)

# Build LightningModule with text conditioning
pl_module = LigandPocketDDPM(
    outdir=run_dir,
    dataset=args.dataset,
    datadir=args.datadir,
    batch_size=args.batch_size,
    lr=args.lr,
    egnn_params=args.egnn_params,
    diffusion_params=args.diffusion_params,
    num_workers=args.num_workers,
    augment_noise=args.augment_noise,
    augment_rotation=args.augment_rotation,
    clip_grad=args.clip_grad,
    eval_epochs=args.eval_epochs,
    eval_params=args.eval_params,
    visualize_sample_epoch=args.visualize_sample_epoch,
    visualize_chain_epoch=args.visualize_chain_epoch,
    auxiliary_loss=args.auxiliary_loss,
    loss_params=args.loss_params,
    mode=args.mode,
    node_histogram=histogram,
    pocket_representation=args.pocket_representation,
    # Text conditioning is configured from the YAML (text_embeddings_path); text_csv is not passed here.
    text_model_name=args.text_model_name,
    text_embeddings_path=args.text_embeddings_path,
)

# Load pretrained weights. strict=False keeps layers that are absent from the checkpoint
# (presumably the new text-conditioning layers) at their fresh initialisation; they are
# reported as missing keys.
# NOTE: torch.load unpickles the file -- only point CKPT_PATH at trusted checkpoints.
if CKPT_PATH is not None:
    checkpoint = torch.load(CKPT_PATH, map_location="cpu")
    missing, unexpected = pl_module.load_state_dict(checkpoint["state_dict"], strict=False)
    print(f"Loaded checkpoint with missing keys: {missing}, unexpected keys: {unexpected}")

# Trainer: DDP only makes sense when we really train on more than one GPU;
# if CUDA is unavailable we fall back to a single-process CPU run.
use_gpu = bool(GPUS) and torch.cuda.is_available()
accelerator = "gpu" if use_gpu else "cpu"
use_ddp = use_gpu and args.gpus > 1
trainer = pl.Trainer(
    max_epochs=args.n_epochs,
    enable_progress_bar=args.enable_progress_bar,
    num_sanity_val_steps=args.num_sanity_val_steps,
    accelerator=accelerator,
    devices=GPUS if use_gpu else None,
    logger=False,
    # With logger=False, default checkpoints are written under default_root_dir
    # (the current working directory if unset); keep them with this run.
    default_root_dir=run_dir,
    strategy=("ddp_notebook" if use_ddp else None),
)

# Train (fine-tune)
trainer.fit(model=pl_module)
print("Training finished.")
