In [1]:
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as opt
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from rdkit import Chem
from rdkit.Chem import QED
from rdkit.Chem import rdFingerprintGenerator
from rdkit.Chem import Crippen
from rdkit.Chem import Descriptors, QED
from rdkit.Contrib.SA_Score import sascorer

import hyp
import replay_buffer
import utils
from environment import Molecule

In [2]:
import os
import subprocess
from rdkit import Chem
from rdkit.Chem import AllChem

def smiles_to_pdbqt(smiles: str, output_file: str):
    """Convert a SMILES string into a PDBQT file using RDKit and Open Babel.

    RDKit parses the SMILES, adds explicit hydrogens and attempts a 3D
    embedding. The molecule is written to a temporary MOL file, which ``obabel``
    converts to ``output_file``; ``--gen3d`` asks Open Babel to generate 3D
    coordinates itself.

    Args:
        smiles: SMILES string of the ligand.
        output_file: destination path of the PDBQT file.

    Raises:
        ValueError: if RDKit cannot parse ``smiles``.
        subprocess.CalledProcessError: if ``obabel`` exits with a non-zero status.
        FileNotFoundError: if the ``obabel`` executable cannot be found.
    """
    import tempfile

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Previously this fell through and AddHs(None) failed with an opaque error.
        raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")
    mol = Chem.AddHs(mol)
    # EmbedMolecule returns -1 on failure. That is tolerated, as before:
    # presumably harmless because obabel --gen3d regenerates coordinates.
    AllChem.EmbedMolecule(mol)

    # A private temporary directory avoids clobbering a shared "temp.mol" in the
    # working directory. It is removed even if the conversion fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_mol = os.path.join(tmp_dir, "ligand.mol")
        Chem.MolToMolFile(mol, temp_mol)
        # Argument list instead of a shell string: paths containing spaces or
        # shell metacharacters are passed through verbatim.
        subprocess.run(["obabel", temp_mol, "-O", output_file, "--gen3d"], check=True)

def _parse_vina_affinity(text: str) -> Optional[float]:
    """Return the best binding affinity (kcal/mol) found in AutoDock Vina output.

    Two formats are recognised:

    * a score-only style line, e.g. ``Affinity: -7.29 (kcal/mol)``;
    * the docking results table, where the row for mode ``1`` follows a
      ``-----+`` separator line, e.g. ``   1   -7.3   0.000   0.000``.

    Returns ``None`` when neither is present.
    """
    lines = text.splitlines()
    for line in lines:
        fields = line.split()
        if len(fields) >= 2 and fields[0] == "Affinity:":
            return float(fields[1])

    in_table = False
    for line in lines:
        stripped = line.strip()
        if stripped.startswith("-----+"):
            in_table = True
            continue
        if in_table:
            fields = stripped.split()
            # Mode 1 is the top-ranked pose; column 2 is its affinity.
            if len(fields) >= 2 and fields[0] == "1":
                return float(fields[1])
    return None


def run_vina_docking(
    protein_pdbqt: str,
    ligand_pdbqt: str,
    center: tuple,
    size: tuple = (20, 20, 20),
    exhaustiveness: int = 8,
) -> Optional[float]:
    """Dock a ligand into a receptor with AutoDock Vina and return the best affinity.

    The function writes ``config.txt`` in the current working directory; poses
    go to ``result.pdbqt``. It then runs ``vina --config config.txt --log log.txt``
    and parses the top-ranked affinity from ``log.txt``, falling back to Vina's
    captured stdout.

    Args:
        protein_pdbqt: receptor PDBQT path.
        ligand_pdbqt: ligand PDBQT path.
        center: (x, y, z) centre of the search box.
        size: (x, y, z) edge lengths of the search box.
        exhaustiveness: Vina search exhaustiveness (previously hard-coded to 8).

    Returns:
        The best-pose affinity in kcal/mol, or ``None`` if it cannot be found.

    Raises:
        RuntimeError: if ``vina`` exits with a non-zero status.
    """
    config_lines = [
        f"receptor = {protein_pdbqt}",
        f"ligand = {ligand_pdbqt}",
        "out = result.pdbqt",
        f"center_x = {center[0]}",
        f"center_y = {center[1]}",
        f"center_z = {center[2]}",
        f"size_x = {size[0]}",
        f"size_y = {size[1]}",
        f"size_z = {size[2]}",
        f"exhaustiveness = {exhaustiveness}",
    ]
    with open("config.txt", "w") as f:
        f.write("\n".join(config_lines) + "\n")

    # Remove any log from a previous call. Otherwise a failed run could be
    # scored with the previous ligand's affinity.
    if os.path.exists("log.txt"):
        os.remove("log.txt")

    result = subprocess.run(
        ["vina", "--config", "config.txt", "--log", "log.txt"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        raise RuntimeError(
            f"vina failed with exit code {result.returncode}: {result.stderr.strip()}"
        )

    log_text = ""
    if os.path.exists("log.txt"):
        with open("log.txt", "r") as f:
            log_text = f.read()
    affinity = _parse_vina_affinity(log_text)
    if affinity is None:
        affinity = _parse_vina_affinity(result.stdout)
    return affinity

In [3]:
import warnings

# Silence every warning, then put a DeprecationWarning -> error rule in front.
# Filters added later take precedence, so deprecations raise while all other
# warnings stay quiet.
# NOTE(review): the blanket ignore also hides potentially useful library
# warnings; consider narrowing it to specific categories/modules.
warnings.simplefilter("ignore")
warnings.simplefilter("error", DeprecationWarning)

In [4]:
class MolDQN(nn.Module):
    """Fully connected Q-network: input_length -> 1024 -> 512 -> 128 -> 32 -> output_length.

    A ReLU follows every hidden layer. The last layer is purely linear, so the
    output (one Q-value per output unit) is unbounded. Layers are registered as
    ``linear_1`` ... ``linear_5``; these names form the state_dict keys.
    """

    def __init__(self, input_length, output_length):
        super().__init__()
        widths = [input_length, 1024, 512, 128, 32, output_length]
        # Build the layers strictly in order so weight initialisation draws from
        # the RNG exactly like five consecutive nn.Linear constructions.
        for index, (fan_in, fan_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"linear_{index}", nn.Linear(fan_in, fan_out))
        self.activation = nn.ReLU()

    def forward(self, x):
        hidden = x
        for layer in (self.linear_1, self.linear_2, self.linear_3, self.linear_4):
            hidden = self.activation(layer(hidden))
        return self.linear_5(hidden)

In [5]:
from joblib import dump, load
REPLAY_BUFFER_CAPACITY = hyp.replay_buffer_size

# Pre-trained property classifiers used by the reward function, which reads
# column 1 of their predict_proba output.
# NOTE: joblib.load unpickles arbitrary objects — only load trusted files.
irritation_model, melanin_model = (
    load(model_path)
    for model_path in ("model_irrit.joblib", "model_melanin.joblib")
)

def get_fingerprint(molecule, radius=None, fingerprint_length=None):
    """Return the Morgan fingerprint of an RDKit molecule as a NumPy vector.

    Args:
        molecule: RDKit ``Mol``, or ``None``.
        radius: Morgan radius; defaults to ``hyp.fingerprint_radius``.
        fingerprint_length: number of bits; defaults to ``hyp.fingerprint_length``.

    Returns:
        1-D array of length ``fingerprint_length``. For ``None`` it is an
        all-zero (float) vector; otherwise it is built from the bit vector.
    """
    if fingerprint_length is None:
        fingerprint_length = hyp.fingerprint_length
    # Handle a missing molecule before constructing the (then unneeded) generator.
    if molecule is None:
        return np.zeros((fingerprint_length,))
    if radius is None:
        radius = hyp.fingerprint_radius
    generator = rdFingerprintGenerator.GetMorganGenerator(
        radius=radius, fpSize=fingerprint_length
    )
    return np.array(generator.GetFingerprint(molecule))

class QEDRewardMolecule(Molecule):
    """Molecule environment whose reward mixes drug-likeness and property models.

    For the current SMILES state after ``t`` steps the reward is::

        (QED - 0.7 * SA_score - P_irrit + 0.5 * P_melanin) * discount_factor ** t

    ``P_irrit`` and ``P_melanin`` are the class-1 probabilities that
    ``irritation_model`` and ``melanin_model`` assign to the molecule's Morgan
    fingerprint. Invalid SMILES score 0.0.
    """

    def __init__(self, discount_factor, **kwargs):
        super(QEDRewardMolecule, self).__init__(**kwargs)
        self.discount_factor = discount_factor

    def _reward(self):
        """Score the current state (``self._state``, a SMILES string)."""
        molecule = Chem.MolFromSmiles(self._state)
        if molecule is None:
            return 0.0
        qed = QED.qed(molecule)
        # One fingerprint, shaped (1, n_bits), feeds both classifiers.
        features = np.expand_dims(get_fingerprint(molecule), axis=0)
        irrit_proba = irritation_model.predict_proba(features)[0, 1]
        melanin_proba = melanin_model.predict_proba(features)[0, 1]
        sa_score = sascorer.calculateScore(molecule)
        # Fixed: this term previously multiplied the model object
        # (``melanin_model``) instead of its predicted probability.
        score = qed - 0.7 * sa_score - irrit_proba + 0.5 * melanin_proba
        # NOTE(review): the exponent grows with num_steps_taken, so later steps
        # are discounted more. Reference MolDQN code uses
        # (max_steps - num_steps_taken) — confirm which is intended.
        return score * self.discount_factor ** (self.num_steps_taken)


In [6]:
# Sanity check: class-1 probability from the irritation model for "CC#CC" (but-2-yne).
irritation_model.predict_proba(
    get_fingerprint(Chem.MolFromSmiles("CC#CC"))[np.newaxis, :]
)[0, 1]

0.9959458154732174

In [7]:
class Agent(object):
    """DQN agent that scores candidate actions given as (fingerprint, steps_left) rows.

    ``dqn`` is the online network trained by ``update_params``. ``target_dqn``
    starts as a copy of it, receives no gradients, and is softly blended towards
    ``dqn`` every ``TARGET_UPDATE_EVERY`` optimisation steps. Its outputs supply
    the bootstrapped targets.
    """

    # Number of optimiser steps between soft (Polyak) target-network updates.
    TARGET_UPDATE_EVERY = 10

    def __init__(self, input_length, output_length, device):
        self.device = device
        self.dqn, self.target_dqn = (
            MolDQN(input_length, output_length).to(self.device),
            MolDQN(input_length, output_length).to(self.device),
        )
        # Start the target from the online weights (standard DQN practice).
        # Otherwise it begins as an unrelated random network.
        self.target_dqn.load_state_dict(self.dqn.state_dict())
        for p in self.target_dqn.parameters():
            p.requires_grad = False
        self.replay_buffer = replay_buffer.ReplayBuffer(REPLAY_BUFFER_CAPACITY)
        self.optimizer = getattr(opt, hyp.optimizer)(
            self.dqn.parameters(), lr=hyp.learning_rate
        )
        self.times_of_update = 0

    def get_action(self, observations, epsilon_threshold):
        """Pick an action index epsilon-greedily.

        Args:
            observations: tensor with one row per candidate action.
            epsilon_threshold: probability of choosing a uniformly random row.

        Returns:
            Index of the chosen row. Greedy picks come back as a Python ``int``.
        """
        if np.random.uniform() < epsilon_threshold:
            return np.random.randint(0, observations.shape[0])
        # Inference only: no autograd graph is needed to pick an action.
        with torch.no_grad():
            q_value = self.dqn(observations.to(self.device))
        return int(torch.argmax(q_value))

    def update_params(self, batch_size, gamma, polyak):
        """Run one optimisation step on a minibatch sampled from the replay buffer.

        Args:
            batch_size: number of transitions to sample.
            gamma: discount applied to the bootstrapped next-state value.
            polyak: weight kept on the old target parameters during the soft
                update (``target = polyak * target + (1 - polyak) * online``).

        Returns:
            The scalar Huber loss tensor; call ``.item()`` for a float.
        """
        states, _, rewards, next_states, dones = self.replay_buffer.sample(batch_size)
        feature_dim = hyp.fingerprint_length + 1

        def to_rows(array):
            # Reshape one stored observation to (num_rows, feature_dim) on the device.
            return torch.as_tensor(
                np.reshape(array, (-1, feature_dim)),
                dtype=torch.float32,
                device=self.device,
            )

        # Each stored state is a single row, so the whole batch can go through
        # the online network in one forward pass.
        q_t = self.dqn(torch.cat([to_rows(states[i]) for i in range(batch_size)]))

        # Next states have a variable number of candidate actions, so the max
        # over target scores is taken per sample. These are constants for the loss.
        with torch.no_grad():
            v_tp1 = torch.stack(
                [self.target_dqn(to_rows(next_states[i])).max() for i in range(batch_size)]
            ).reshape(q_t.shape)

        rewards = torch.as_tensor(
            np.asarray(rewards, dtype=np.float32), device=self.device
        ).reshape(q_t.shape)
        dones = torch.as_tensor(
            np.asarray(dones, dtype=np.float32), device=self.device
        ).reshape(q_t.shape)

        # Bellman target; terminal transitions do not bootstrap.
        q_t_target = rewards + gamma * (1 - dones) * v_tp1
        # Huber loss with delta = 1, the same value as the previous explicit
        # piecewise formula, averaged over the batch.
        q_loss = F.smooth_l1_loss(q_t, q_t_target)

        self.optimizer.zero_grad()
        q_loss.backward()
        self.optimizer.step()

        if self.times_of_update % self.TARGET_UPDATE_EVERY == 0:
            with torch.no_grad():
                for p, p_targ in zip(self.dqn.parameters(), self.target_dqn.parameters()):
                    p_targ.mul_(polyak).add_((1 - polyak) * p)
        # The counter must advance on every call. It used to be incremented only
        # inside the branch above, which froze it at 1 after the first blend.
        self.times_of_update += 1
        return q_loss
        

In [None]:
# Run configuration.
TENSORBOARD_LOG = True
TB_LOG_PATH = "./runs/dqn/run2"
SEED = None  # set to an int for a reproducible run (seeds random, numpy, torch)
EPS_MIN = 0.1  # lower bound on the epsilon-greedy exploration rate
EPS_DECAY = 0.01  # subtracted from epsilon after every finished episode
episodes = 0
iterations = 4000
update_interval = 2
batch_size = 512
num_updates_per_it = 1

if SEED is not None:
    import random

    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def build_action_observations(actions, steps_left):
    """Stack one row per action: its fingerprint with ``steps_left`` appended."""
    return np.vstack(
        [
            np.append(
                utils.get_fingerprint(act, hyp.fingerprint_length, hyp.fingerprint_radius),
                steps_left,
            )
            for act in actions
        ]
    )


environment = QEDRewardMolecule(
    discount_factor=hyp.discount_factor,
    atom_types=set(hyp.atom_types),
    init_mol=hyp.start_molecule,
    allow_removal=hyp.allow_removal,
    allow_no_modification=hyp.allow_no_modification,
    allow_bonds_between_rings=hyp.allow_bonds_between_rings,
    allowed_ring_sizes=set(hyp.allowed_ring_sizes),
    max_steps=hyp.max_steps_per_episode,
)

# DQN Inputs and Outputs:
# input: appended action (fingerprint_length + 1) .
# Output size is (1).

agent = Agent(hyp.fingerprint_length + 1, 1, device)

writer = SummaryWriter(TB_LOG_PATH) if TENSORBOARD_LOG else None

environment.initialize()

eps_threshold = 1.0
batch_losses = []

try:
    for it in range(iterations):
        # Each candidate action is tagged with the steps left *before* acting.
        steps_left = hyp.max_steps_per_episode - environment.num_steps_taken
        valid_actions = list(environment.get_valid_actions())
        observations_tensor = torch.Tensor(build_action_observations(valid_actions, steps_left))

        a = agent.get_action(observations_tensor, max(EPS_MIN, eps_threshold))
        action = valid_actions[a]
        result = environment.step(action)
        _, reward, done = result

        # Replay-buffer state: the observation row of the chosen action.
        action_fingerprint = np.append(
            utils.get_fingerprint(action, hyp.fingerprint_length, hyp.fingerprint_radius),
            steps_left,
        )
        # Replay-buffer next state: all actions available from the new molecule,
        # tagged with the updated steps_left. Agent.update_params maxes over them.
        steps_left = hyp.max_steps_per_episode - environment.num_steps_taken
        action_fingerprints = build_action_observations(
            environment.get_valid_actions(), steps_left
        )

        agent.replay_buffer.add(
            obs_t=action_fingerprint,  # (fingerprint_length + 1)
            action=0,  # No use
            reward=reward,
            obs_tp1=action_fingerprints,  # (num_actions, fingerprint_length + 1)
            done=float(result.terminated),
        )

        if done:
            final_reward = reward
            if episodes != 0 and writer is not None and batch_losses:
                writer.add_scalar("episode_reward", final_reward, episodes)
                writer.add_scalar("episode_loss", np.array(batch_losses).mean(), episodes)
            if episodes != 0 and episodes % 2 == 0 and batch_losses:
                final_mol = Chem.MolFromSmiles(environment._state)
                print(
                    "reward of final molecule at episode {} is {}, qed is {},sa is {}, molecule is {}".format(
                        episodes, final_reward, QED.qed(final_mol), sascorer.calculateScore(final_mol), environment._state
                    )
                )
                print(
                    "mean loss in episode {} is {}".format(
                        episodes, np.array(batch_losses).mean()
                    )
                )
            episodes += 1
            eps_threshold -= EPS_DECAY
            batch_losses = []
            environment.initialize()

        if it % update_interval == 0 and len(agent.replay_buffer) >= batch_size:
            for _ in range(num_updates_per_it):
                batch_losses.append(agent.update_params(batch_size, hyp.gamma, hyp.polyak).item())
finally:
    # Flush pending TensorBoard events even if training is interrupted.
    if writer is not None:
        writer.close()

reward of final molecule at episode 12 is -2.135732820598522, qed is 0.09107870243907015,sa is 6.103670245120666, molecule is CC(N=N)C(O)=C=NNN(C)ON(O)N(N=O)OC=N
mean loss in episode 12 is 1.7132644653320312
reward of final molecule at episode 14 is -2.126356398780658, qed is 0.2585405335274589,sa is 6.286700660187834, molecule is CNN1NOC2=C(N)C(O)(ONN2)O1
mean loss in episode 14 is 0.11283107250928878
reward of final molecule at episode 16 is -2.3554398777265635, qed is 0.04217036102538414,sa is 6.658719634376916, molecule is C#CC1C(N=C(O)N=NC(N=N)NN)N1NN(C)NN(N)N
mean loss in episode 16 is 0.03935618922114372
reward of final molecule at episode 18 is -1.8765669687002655, qed is 0.32837061363055725,sa is 5.673435187372101, molecule is C=C(CC)NN1C(CO)NC(C=O)(N=O)C2=NC21CCO
mean loss in episode 18 is 0.00892090166453272
reward of final molecule at episode 20 is -1.8829772350842686, qed is 0.09590969158174215,sa is 5.400837104628774, molecule is C=NOC(=O)C(OC(O)CN)N(N)NN
mean loss in epi

In [None]:
generated_molecules = []
num_molecules_to_generate = 10
agent.dqn.eval()
eps_threshold = 0.03  # small residual exploration so repeated rollouts can differ

# Inference only: no gradients are needed while rolling out the policy.
with torch.no_grad():
    for _ in range(num_molecules_to_generate):
        done = False
        environment.initialize()
        while not done:
            steps_left = hyp.max_steps_per_episode - environment.num_steps_taken
            valid_actions = list(environment.get_valid_actions())
            observations = np.vstack(
                [
                    np.append(
                        utils.get_fingerprint(
                            act, hyp.fingerprint_length, hyp.fingerprint_radius
                        ),
                        steps_left,
                    )
                    for act in valid_actions
                ]
            )
            a = agent.get_action(torch.Tensor(observations), eps_threshold)
            # Only the termination flag matters here. The fingerprints of the
            # action, the next state and all next actions are no longer computed,
            # because nothing consumed them during generation.
            _, _, done = environment.step(valid_actions[a])

        generated_molecules.append(environment._state)
        print(generated_molecules[-1])

In [None]:
# Display the generated SMILES strings (last expression of the cell).
generated_molecules

In [None]:
from rdkit import Chem

# Keep only the SMILES strings that RDKit can parse back into a molecule.
valid_smiles = [
    smiles
    for smiles in generated_molecules
    if Chem.MolFromSmiles(smiles) is not None
]

# Prints "Valid molecules generated: <valid>/<total>".
print(f"Сгенерировано валидных молекул: {len(valid_smiles)}/{len(generated_molecules)}")

In [None]:
# Prints "Sample of generated molecules:" followed by up to five valid SMILES, one per line.
print("Пример сгенерированных молекул:")
preview = valid_smiles[:5]
if preview:
    print("\n".join(map(str, preview)))