In [1]:
from s4dd import S4forDenovoDesign
import torch
# --- Run configuration -------------------------------------------------------
# Training and sampling are stochastic, so fix torch's global seed up front to
# make a "Restart Kernel -> Run All" as repeatable as possible.
# NOTE(review): this only covers randomness drawn from torch's global RNG; if
# s4dd also uses `random`/numpy internally, seed those as well — confirm.
SEED = 42
torch.manual_seed(SEED)

# Keep `device` a plain string on BOTH branches. The original mixed a
# `torch.device` object (CUDA branch) with a `str` (CPU branch), so downstream
# code saw two different types depending on the machine.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)


cuda


In [2]:

# Build a demo-sized S4 model. The settings are (almost) those of the paper;
# the two values below are deliberately reduced so the notebook runs quickly.
s4 = S4forDenovoDesign(
    device=device,  # Use "cpu" here if PyTorch was installed without CUDA support.
    batch_size=64,  # Demonstration value only; the paper uses 2048.
    n_max_epochs=1,  # Demonstration value only; set (much) higher for real training. Default: 400.
)

# Pretraining data (ChEMBL).
chembl_data = {
    # A 50K subsample of the ChEMBL training set, for quick(er) testing.
    "training_molecules_path": "./datasets/chemblv31/mini_train.zip",
    "val_molecules_path": "./datasets/chemblv31/valid.zip",
}

# Pretrain on ChEMBL. The call is the last expression of the cell, so its
# return value (the loss history) is displayed as the cell output.
s4.train(**chembl_data)


100%|██████████| 782/782 [02:01<00:00,  6.45it/s]


Epoch:0	Loss: 0.5900923684048835, Val Loss: 0.4607980233770262


{'train_loss': [0.5900923684048835], 'val_loss': [0.4607980233770262]}

In [3]:
# Fine-tuning data: bioactive molecules for PKM2.
pkm2_data = {
    "training_molecules_path": "./datasets/pkm2/train.zip",
    "val_molecules_path": "./datasets/pkm2/valid.zip",
}

# Continue training the same `s4` object on the PKM2 set. The call is left as
# the final expression so the returned loss history is shown as cell output.
s4.train(**pkm2_data)


100%|██████████| 7/7 [00:01<00:00,  6.42it/s]


Epoch:0	Loss: 0.45284183962004526, Val Loss: 0.3751584589481354


{'train_loss': [0.45284183962004526], 'val_loss': [0.3751584589481354]}

In [4]:
# Sample new molecules from the trained model.
# `designs` and `lls` are inspected by the cells that follow.
sampling_options = dict(n_designs=32, batch_size=16, temperature=1.0)
designs, lls = s4.design_molecules(**sampling_options)

In [5]:
# Show the generated designs (first value returned by `s4.design_molecules`).
# NOTE(review): several of the strings in the recorded output have unbalanced
# parentheses, so the designs are presumably returned unfiltered — confirm, and
# validate them (e.g. with a cheminformatics toolkit) before using them.
designs

['O=C(c1cccc(CNC(=O)Nc2ccccc2)o1',
 'COc1ccc(CSc1ccc2c(CNC(=O)CC(C)(C(=O)N2CCN(CCCN3CCCC2=O)c2=O)cc1',
 'CN1CCN(S(=O)(=O)Nc2ccco2)cc1',
 'Cc1ccc(S(=O)(=O)N)cc1)c1cc(C(=O)c2ccncc12',
 'N#Cc1nn2ncnn1C)c1nc2ccccc2no1',
 'CSc1nnc(C2(C)c1NS(=O)(=O)c1cccc(F)cc2)nc1',
 'COc1ccc(C[N+](=O)[O-])c1CCN(CC(=O)c1ccccc1C',
 'CCc1c(Cl)c2c([N+](C)(C)c1C(=O)NCc1ccccc1',
 'Oc1ccc(C=NNS(=O)(=O)c2ccc(S(=O)(=O)N(Cc2ncccc2)c1',
 'COC(=O)c1ccc2c(c1)Oc1ccc(C=Cc1ccncc1)c1ccc(F)cc1',
 'COc1cccc(S(=O)(=O)N2CCOc2ccccc3c(=O)c2C)o1',
 'c1nonn2SCC3(c1nc(SCC)c3ccc(OCCCC)oc21',
 'COc1ccnc2c1c(=O)n2c(C)ncc3c(n1',
 'COc1ccc(S(=O)(=O)N2CCN(C(=O)c2ccc(NC(=O)N(CC(=O)NC)ccc12',
 'CCSc1nn(Cc2ccc(S(=O)(=O)F)cc1',
 'Cc1ccccc1NS(=O)(=O)c1ccc(Cl)cc1',
 'Cc1ccc(S(=O)(=O)c2ccccc2)o1',
 'Cc1ccc(S(=O)(=O)N1CCC23CCCC2CC2)n1',
 'CCOc1ccc2C(C(=O)Nc1ccc([N+](=O)[O-])c1',
 'Cc1ccc(Cc2cc(C2=C(NC(=O)c1ccc(NC)c1)C(O)(c1)c(=O)o1',
 'COc1ccc(C(=O)NS(=O)(=O)c2ccc3c(C(=O)Nc2ccccc2Cl)cc1',
 'Cc1ccc(CSc1nnn2CCOC1CS(=O)(=O)c2ccccc2)c1',
 'COc1cc2c(

In [7]:
# Sanity check: the number of returned designs should equal the `n_designs`
# value requested from `s4.design_molecules` (32 in this notebook).
len(designs)

32

In [6]:
# Show the second value returned by `s4.design_molecules`: one float per design.
# NOTE(review): presumably these are log-likelihoods of the designs under the
# model (the name and the negative values suggest so) — confirm the exact
# definition (e.g. per-token average vs. sequence total) against the s4dd docs.
lls

[-0.5611103674548148,
 -0.8342084684368012,
 -0.509624955108583,
 -0.8085676003507126,
 -1.2420427195273775,
 -0.828296768052326,
 -0.672862124495985,
 -0.8720166819088213,
 -0.6371822301915416,
 -0.6862884598765752,
 -0.6628748380436041,
 -1.7075688785958867,
 -1.1411158117279843,
 -0.6137186634436909,
 -0.8762567389095735,
 -0.3722511489711929,
 -0.3624927165616213,
 -0.7695154276973861,
 -0.7061847419979134,
 -1.1477409358258166,
 -0.4657707691095364,
 -0.7432080386922458,
 -0.7453101329787115,
 -0.5484992251830975,
 -0.8210023049486206,
 -0.5624696326567653,
 -0.536329639943703,
 -0.704730388319334,
 -1.3122901610064628,
 -1.187093605413841,
 -0.8356709592856877,
 -0.6879055349116143]