In [1]:
# for protein structural modelling
from Bio.SVDSuperimposer import SVDSuperimposer
import numpy as np
import biovec
import glob

# from utils functions
from utils.encoder_decoder_no_proline import *
from utils.sequence import *
from utils.reward import *
from utils.environment_no_proline import *

# for deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# for environment creation
import gymnasium as gym
from gymnasium import Env
from gymnasium.spaces import MultiDiscrete
from gymnasium.spaces import Discrete
from gymnasium.spaces import Box

#for reading PDB files and processing them
from biopandas.pdb import PandasPdb
import pandas as pd
from utils.sequence import *

# for generating structures through esm instead of modeller
import esm
import biotite.structure as struc
import biotite.structure.io as strucio

# for general utility
import random
import os
import subprocess
import time
import matplotlib.pyplot as plt
from datetime import datetime

In [2]:
class PolicyNetwork:
    """REINFORCE policy: a one-hidden-layer MLP mapping a state to action probabilities.

    Args:
        n_state: length of the state vector fed to the first linear layer.
        n_action: number of discrete actions; the network outputs one
            probability per action (softmax over the last dimension).
        n_hidden: width of the hidden layer.
        lr: Adam learning rate.
        entropy_weight: weight of the entropy bonus in the objective
            (0 disables it). Larger values push the policy towards more
            uniform, exploratory action distributions.
    """

    def __init__(self, n_state, n_action, n_hidden=50, lr=0.001, entropy_weight=0.01):
        self.model = nn.Sequential(
            nn.Linear(n_state, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_action),
            nn.Softmax(dim=-1),
        )
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr)
        self.entropy_weight = entropy_weight

    def predict(self, s):
        """Return the action-probability tensor for state ``s`` (converted with ``torch.Tensor``)."""
        return self.model(torch.Tensor(s))

    def update(self, returns, log_probs, entropies):
        """Take one optimizer step on the REINFORCE loss of a finished episode.

        The minimized loss is ``sum_t(-log_prob_t * G_t - entropy_weight * entropy_t)``.
        The first term raises the log-probability of actions with high return.
        The entropy term is *subtracted*, so minimizing the loss raises policy
        entropy (encouraging exploration) rather than lowering it.

        Args:
            returns: per-step returns G_t, aligned with ``log_probs``.
            log_probs: per-step log-probabilities of the taken actions (with autograd graph).
            entropies: per-step entropies of the action distribution (with autograd graph).
        """
        loss_terms = [
            -log_prob * Gt - self.entropy_weight * entropy
            for log_prob, Gt, entropy in zip(log_probs, returns, entropies)
        ]
        loss = torch.stack(loss_terms).sum()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def get_action(self, s):
        """Sample one action from the policy at state ``s``.

        Returns:
            (action, log_prob, entropy): the sampled action index as a Python int,
            the log-probability of that action, and the entropy of the whole
            action distribution. Both tensors keep their graph for ``update``.
        """
        probs = self.predict(s)
        action = torch.multinomial(probs, 1).item()
        log_prob = torch.log(probs[action])
        # The 1e-9 keeps log() finite for probabilities that underflow to 0.
        entropy = -torch.sum(probs * torch.log(probs + 1e-9))
        return action, log_prob, entropy

In [8]:

def _discounted_returns(rewards, gamma):
    """Return [G_0, ..., G_{T-1}] with G_t = r_t + gamma * G_{t+1} and G_T = 0."""
    returns = []
    running = 0.0
    for step_reward in reversed(rewards):
        running = step_reward + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


def reinforce(env, estimator, n_episode, gamma=1.0, verbose=True):
    """Train ``estimator`` on ``env`` with REINFORCE (Monte-Carlo policy gradient).

    Args:
        env: environment whose ``reset()`` returns three values (the first is
            the state; the other two are ignored here) and whose ``step(action)``
            returns ``(next_state, reward, terminated, truncated, info)``.
        estimator: object providing ``get_action(state) -> (action, log_prob, entropy)``
            and ``update(returns, log_probs, entropies)``.
        n_episode: number of episodes to run.
        gamma: discount factor for the returns G_t = r_t + gamma * G_{t+1}.
        verbose: if True (default), print the ``info`` returned by every ``env.step``.

    Returns:
        list with the undiscounted total reward of each episode.
    """
    total_reward_episode = [0] * n_episode

    for episode in range(n_episode):
        log_probs = []
        rewards = []
        entropies = []
        state, info, _ = env.reset()
        while True:
            action, log_prob, entropy = estimator.get_action(state)
            next_state, reward, terminated, truncated, info = env.step(action)
            if verbose:
                print(info)
            total_reward_episode[episode] += reward
            log_probs.append(log_prob)
            rewards.append(reward)
            entropies.append(entropy)
            if terminated or truncated:
                break
            state = next_state

        returns = torch.tensor(_discounted_returns(rewards, gamma), dtype=torch.float32)
        # Standardize to reduce gradient variance. std() is undefined (NaN) for a
        # single step and normalization is meaningless when all returns are equal.
        if len(returns) > 1 and returns.std() > 0:
            returns = (returns - returns.mean()) / (returns.std() + 1e-9)
        estimator.update(returns, log_probs, entropies)
        print('Episode: {}, total reward: {}'.format(episode, total_reward_episode[episode]))

    return total_reward_episode
                
                    



In [9]:
# Peptide-evolution environment without proline, generating structures via the
# 'esm_sse' generator, validation disabled, at most 15 mutations per episode.
env = PeptideEvolutionNoProline(
    folder_containing_pdb_files='../DrugResistance/folder_for_machine_learning/30_test/',
    unique_path_to_give_for_file='unique_1',
    structure_generator='esm_sse',
    validation=False,
    reward_cutoff=50,
    maximum_number_of_allowed_mutations_per_episode=15,
    folder_to_save_validation_files='validation_structures',
)


In [10]:
# Network sizes follow the environment: one input per observation feature,
# one output per discrete action.
n_state = env.observation_space.shape[0]
n_action = env.action_space.n

# Training hyperparameters.
n_hidden = 128          # hidden-layer width of the policy MLP
lr = 0.0007             # Adam learning rate
gamma = 0.95            # discount factor for returns
entropic_factor = 0.0   # entropy-bonus weight (0 disables the bonus)
n_episode = 8000        # number of training episodes

# Seed the Python, NumPy and torch RNGs so policy initialisation and action
# sampling are reproducible across runs.
# NOTE(review): any randomness inside the environment / structure generator is
# not necessarily controlled by these seeds — confirm.
random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)


In [11]:
# Build the policy network from the hyperparameters defined above.
policy_net = PolicyNetwork(
    n_state=n_state,
    n_action=n_action,
    n_hidden=n_hidden,
    lr=lr,
    entropy_weight=entropic_factor,
)

In [12]:
# Run REINFORCE training; the result holds the total reward of each episode.
total_reward_episode = reinforce(env, policy_net, n_episode, gamma)

[15, 'N']
[14, 'L']
[29, 'L']
[26, 'I']
[3, 'K']
[0, 'A']
[1, 'G']
[21, 'F']
[20, 'T']
[9, 'L']
[29, 'I']
[8, 'C']
[20, 'K']
[0, 'K']
[1, 'E']
Episode: 0, total reward: -25.14
[7, 'A']
[27, 'A']
[0, 'R']
[3, 'L']
[3, 'H']
[10, 'Q']
[18, 'N']
[29, 'I']
[19, 'C']
[5, 'W']
[16, 'D']
[15, 'A']
[8, 'S']
[16, 'R']
[17, 'F']
Episode: 1, total reward: -25.14
[13, 'E']
[18, 'E']
[28, 'A']
[28, 'A']
[5, 'A']
[14, 'A']
[0, 'A']
[4, 'V']
[29, 'D']
[2, 'W']
[12, 'W']
[18, 'T']
[12, 'M']
[0, 'V']
[28, 'F']
Episode: 2, total reward: -25.14
[19, 'K']
[20, 'D']
[24, 'D']
[29, 'T']
[2, 'K']
[15, 'V']
[25, 'V']
[29, 'Q']
[11, 'E']
[15, 'M']
[11, 'D']
[9, 'L']
[9, 'W']
[12, 'C']
[28, 'E']
Episode: 3, total reward: -25.14
[17, 'V']
[2, 'R']
[6, 'K']
[20, 'Y']
[20, 'N']
[28, 'Y']
[27, 'M']
[21, 'L']
[3, 'W']
[24, 'L']
[9, 'W']
[5, 'D']
[22, 'Y']
[21, 'N']
[19, 'L']
Episode: 4, total reward: -25.14
[9, 'G']
[14, 'T']
[22, 'M']
[12, 'L']
[27, 'E']
[17, 'A']
[11, 'Q']
[16, 'S']
[12, 'W']
[25, 'I']
[5, 'E']
[27

[13, 'N']
[20, 'R']
[2, 'T']
[28, 'T']
[15, 'K']


KeyboardInterrupt: 

In [9]:
# Persist the trained policy weights; the filename encodes the hyperparameters.
# Create the output directory first so saving does not fail on a fresh checkout.
os.makedirs('saved_models', exist_ok=True)
torch.save(
    policy_net.model.state_dict(),
    f'saved_models/saved_rl_model_1_lr_{lr}_gamma_{gamma}_ep_{n_episode}_entropic_factor_{entropic_factor}_no_proline.pth',
)

In [10]:
# Persist the per-episode total rewards alongside the model weights.
# Create the output directory first so np.savetxt does not raise FileNotFoundError.
os.makedirs('saved_models', exist_ok=True)
np.savetxt(
    f'saved_models/saved_rl_model_1_lr_{lr}_gamma_{gamma}_ep_{n_episode}_entropic_factor_{entropic_factor}_no_proline.txt',
    total_reward_episode,
)