In [7]:
# Imports and global notebook setup (logging, plotting fonts, RDKit, paths).
import sys, os
import math, random, itertools, pickle
from collections import defaultdict, OrderedDict
import logging
from tqdm import tqdm
import yaml, psutil
from addict import Dict
import numpy as np, pandas as pd
import concurrent.futures as cf
# Make project packages such as `tools` importable. NOTE(review): the second
# path is relative, so it depends on the kernel's working directory.
sys.path.append('/workspace')
sys.path.append("../..")
from tools.logger import add_stream_handler
# getLogger() with no name returns the root logger; the project helper
# presumably attaches a stream handler at DEBUG level to it.
logger = logging.getLogger()
add_stream_handler(logger, logging.DEBUG)

import matplotlib.pyplot as plt
from tools.graph import tightargs, errorbarargs, COLORS2, setfont2, get_scatter_style
setfont2()  # project helper; presumably configures matplotlib fonts

import rdkit
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D as rm
from tools.rdkit import ignore_warning
import py3Dmol
ignore_warning()  # project helper; presumably silences RDKit warnings
print(rdkit.__version__)

# Root of the training-run result directories (relative to the notebook's cwd).
train_dir = "../../training/results"


2024.03.1


## 1. 生成物の評価

### 2次元画像で保存

In [12]:
from rdkit.Chem import rdDepictor
from rdkit.Geometry.rdGeometry import Point3D
def draw_mols(sname, step, protocol='241019', n=10):
    """Save 2D depictions of randomly sampled generated molecules as PNG files.

    Reads SMILES from ``./results/{protocol}/{sname}/{step}/mol/string.txt``
    (one per line) and per-molecule coordinates from
    ``.../mol/coord/{idx}.csv`` (header-less, one row of three values per atom).
    Each atom is placed at its generated position and the drawing is written to
    ``./graph/{protocol}/{sname}/mol/{step}/{idx}.png`` (300 x 300 px).
    Molecules with invalid SMILES, too few coordinate rows or NaN coordinates
    are reported and skipped.

    Args:
        sname: name of the training run / sample set.
        step: training step whose generations are drawn.
        protocol: generation protocol sub-directory.
        n: maximum number of molecules; if more SMILES exist, ``n`` distinct
            indices are sampled with a fixed seed (0).
    """
    rdir = f"./results/{protocol}/{sname}/{step}/mol"
    fdir = f"./graph/{protocol}/{sname}/mol/{step}"
    os.makedirs(fdir, exist_ok=True)
    rng = np.random.default_rng(0)
    with open(f"{rdir}/string.txt") as f:
        smiles = f.read().splitlines()

    idxs = np.arange(len(smiles))
    if n < len(idxs):
        # replace=False: the default samples WITH replacement, which can pick the
        # same molecule twice (overwriting its PNG) and draw fewer than n molecules.
        idxs = rng.choice(idxs, n, replace=False)
    for idx in idxs:
        smi = smiles[idx]
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            print(f"{idx}: SMILES is invalid")
            continue
        n_atoms = mol.GetNumAtoms()
        coord = pd.read_csv(f"{rdir}/coord/{idx}.csv", header=None).values
        if len(coord) < n_atoms:
            # Report the count that is actually compared (previously heavy atoms).
            print(f"{idx}: Atom num mismatch: {n_atoms} vs {len(coord)}")
            continue
        coord = coord[:n_atoms]  # rows beyond the molecule's atoms are unused
        if np.isnan(coord).any():
            # Mirrors the NaN check in draw3D: NaN positions cannot be drawn.
            print(f"{idx}: Nan in coord: {smi}")
            continue

        # Compute2DCoords attaches a conformer; every position is then
        # overwritten with the generated coordinates.
        rdDepictor.Compute2DCoords(mol)
        conf = mol.GetConformer()
        for atom_idx, xyz in enumerate(coord):
            conf.SetAtomPosition(atom_idx, Point3D(*xyz))

        drawer = rm.MolDraw2DCairo(300, 300, 300, 300)
        drawer.DrawMolecule(mol)
        drawer.FinishDrawing()
        drawer.WriteDrawingText(f"{fdir}/{idx}.png")

In [13]:
# Draw up to 25 randomly sampled generations of run 241101_all at step 43000.
draw_mols(sname='241101_all', step=43000, n=25)

4: SMILES is invalid


### py3Dmolで描画

In [None]:

def draw3D(sname, step, protocol='241019', n=10, ncols=5):
    """Build a py3Dmol grid showing up to ``n`` randomly chosen generated molecules.

    Molecules come from ``./results/{protocol}/{sname}/{step}/mol/string.txt``
    (one SMILES per line) with coordinates in ``.../mol/coord/{idx}.csv``
    (header-less, one row per atom). Invalid SMILES, coordinate files with fewer
    rows than atoms, and coordinates containing NaN are reported and skipped.
    Each accepted molecule is centred on the mean of its coordinates.

    Args:
        sname: name of the training run / sample set.
        step: training step whose generations are shown.
        protocol: generation protocol sub-directory.
        n: number of molecules to show (default 10).
        ncols: columns of the viewer grid; rows = ceil(n / ncols). The defaults
            give the original 2 x 5 grid of 300 x 300 px panels.

    Returns:
        The ``py3Dmol.view`` object (call ``.show()`` to render it).
    """
    rng = np.random.default_rng(0)
    rdir = f"./results/{protocol}/{sname}/{step}/mol"
    with open(f"{rdir}/string.txt") as f:
        smiles = f.read().splitlines()

    mols = []
    idxs = np.arange(len(smiles))
    rng.shuffle(idxs)  # seeded, so the same molecules are picked on every run
    for idx in idxs:
        smi = smiles[idx]
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            print(f"{idx}: SMILES is invalid")
            continue
        n_atoms = mol.GetNumAtoms()
        coord = pd.read_csv(f"{rdir}/coord/{idx}.csv", header=None).values
        if len(coord) < n_atoms:
            # Report the count that is actually compared (previously heavy atoms).
            print(f"{idx}: Atom num mismatch: {n_atoms} vs {len(coord)}")
            continue
        coord = coord[:n_atoms]  # rows beyond the molecule's atoms are unused
        if np.isnan(coord).any():
            print(f"Nan in coord: {smi}")
            continue
        coord = coord - coord.mean(axis=0, keepdims=True)

        # Compute2DCoords attaches a conformer; every position is then
        # overwritten with the (centred) generated coordinates.
        rdDepictor.Compute2DCoords(mol)
        conf = mol.GetConformer()
        for atom_idx, xyz in enumerate(coord):
            conf.SetAtomPosition(atom_idx, Point3D(*xyz))
        mols.append(mol)
        if len(mols) == n:
            break

    nrows = -(-n // ncols)  # ceiling division
    view = py3Dmol.view(width=300 * ncols, height=300 * nrows,
                        viewergrid=(nrows, ncols), linked=True)
    for i, mol in enumerate(mols):
        # divmod gives the (row, column) cell of the i-th molecule.
        view.addModel(Chem.MolToMolBlock(mol), 'sdf', viewer=divmod(i, ncols))
    view.setStyle('stick')

    return view

In [9]:
# Interactive 3D grid of generations from run 241101_all at step 43000.
view = draw3D(sname="241101_all", step=43000)
view.show()
view.png()  # static image of the rendered viewer


SMILES is invalid


## 7. デコードしたタンパク質を描画

In [None]:
# Training run and checkpoint step whose decoded proteins are plotted.
sname, step = '241019_protein_only', 54400
def draw_protein(sname, step, protocol='241019', n=25,
                 base_dir="/workspace/cplm/training/results"):
    """Save a 3D scatter plot of the decoded coordinates of each generated protein.

    Reads ``{base_dir}/{sname}/generate/{protocol}_generate/{step}/prot.txt``
    (one comma-separated list of atom names per line) and the matching
    ``.../prot_coord/{i}.csv`` for each of the first ``n`` lines. It plots as
    many coordinate rows as the line has ``CA`` atoms and writes
    ``.../prot_graph/{i}.png``. Entries with too few coordinate rows are skipped.

    Args:
        sname: name of the training run.
        step: training step whose generations are plotted.
        protocol: generation protocol; the sub-directory is ``{protocol}_generate``.
        n: maximum number of proteins to plot (default 25).
        base_dir: root of the training results directory.
    """
    gdir = f"{base_dir}/{sname}/generate/{protocol}_generate/{step}"
    fdir = f"{gdir}/prot_graph"
    os.makedirs(fdir, exist_ok=True)
    with open(f"{gdir}/prot.txt") as f:
        atomss = f.read().splitlines()
    # Slicing (instead of range(25)) avoids an IndexError when fewer lines exist.
    for i, line in enumerate(atomss[:n]):
        atoms = np.array(line.split(','))
        n_ca = int(np.count_nonzero(atoms == 'CA'))
        # header=None, like the molecule coordinate readers: without it the first
        # coordinate row would be consumed as column names.
        coords = pd.read_csv(f"{gdir}/prot_coord/{i}.csv", header=None).values
        if len(coords) < n_ca:
            continue
        # NOTE(review): plots the FIRST n_ca rows. That is correct only if the
        # coordinate file holds CA coordinates alone. If it holds one row per atom
        # in `atoms`, use the CA mask instead. TODO confirm.
        coords = coords[:n_ca]

        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')
        ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2])
        # No labelled artists exist, so ax.legend() only emitted a warning.
        ax.set(xlabel='x', ylabel='y', zlabel='z', title=f"{sname} step {step} #{i}")
        fig.savefig(f"{fdir}/{i}.png", **tightargs)
        plt.close(fig)



## -1. experiments

### 241019 なぜ241018_protein_onlyで生成できないのか?

In [39]:
# Investigate the loss scale: rebuild the validation data pipeline and the model
# from the run's saved config, then load a trained checkpoint.
# NOTE(review): MoleculeProteinTokenizer, CoordTransform, ProteinDataset, Model,
# remove_module and `device` are not defined or imported anywhere in this
# notebook. They exist only via hidden kernel state, so "Restart & Run All"
# fails here. TODO: import them explicitly from the project package.
import torch
from torch import nn
from torch.utils.data import DataLoader

sname = '241018_protein_only'
train_subset = 'valid'
WORKDIR = "/workspace"
batch_first = False
rdir = f"{train_dir}/{sname}"

# The context manager closes the config file handle.
# yaml.Loader can construct arbitrary Python objects: only load trusted configs.
with open(f"{rdir}/config.yaml") as config_file:
    args = Dict(yaml.load(config_file, yaml.Loader))
tokenizer = MoleculeProteinTokenizer(coord_min=-args.coord_range, coord_sup=args.coord_range)
coord_transform = CoordTransform(args.seed, args.normalize_coord, args.random_rotate, args.coord_noise_std)


train_prot_data = ProteinDataset(f"{WORKDIR}/cheminfodata/unimol/pockets/{train_subset}.lmdb", tokenizer, coord_transform)
train_data = train_prot_data
# batch_size is left at its default of 1: samples are fetched one at a time and
# packed into token-budget batches manually.
train_loader = DataLoader(train_data, shuffle=True, num_workers=28, pin_memory=args.pin_memory)
train_iter = iter(train_loader)
next_item = None  # look-ahead sample carried between runs of the batching code

# NOTE(review): hyper-parameters are hard-coded; confirm they match config.yaml.
model = Model(8, 768, 12, 4, 0.1, 'gelu', True,
        tokenizer.voc_size, tokenizer.pad_token)

# load_state_dict returns the missing / unexpected keys; print them.
print(model.load_state_dict(remove_module(torch.load(f"{rdir}/models/20000.pth"))))
model.to(torch.bfloat16)
model.to(device)
# reduction='sum' so the loss can be divided by a token count afterwards.
criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=tokenizer.pad_token)

  model.load_state_dict(remove_module(torch.load(f"{rdir}/models/20000.pth")))


In [40]:
import logging
from time import time
from torch.nn.utils.rnn import pad_sequence
logger = logging.getLogger()


# Build ONE batch the same way as a token-budget loader. Keep taking samples while
# (number of samples) x (longest sample), i.e. the padded token count, stays
# within args.token_per_batch. `next_item` is the look-ahead sample that did not
# fit. It is kept for the next run, so re-running this cell yields the next batch.
items = []
max_length = 0
n_accum_token = 0  # un-padded tokens in the batch
while True:
    if next_item is None:
        try:
            # squeeze(0) drops the leading dim, presumably the loader's batch dim of 1.
            next_item = next(train_iter).squeeze(0)
        except StopIteration:
            break
    padded_length = max(max_length, len(next_item))
    if (len(items) + 1) * padded_length > args.token_per_batch:
        break
    items.append(next_item)
    max_length = padded_length
    n_accum_token += len(next_item)
    next_item = None
if not items:
    # pad_sequence([]) would fail with an opaque error; state the cause instead.
    raise RuntimeError(
        "Empty batch: the loader is exhausted or the next sample alone exceeds "
        "args.token_per_batch")
# With batch_first=False, dim 0 of the padded tensor is the sequence axis.
batch = pad_sequence(items, batch_first=batch_first,
        padding_value=tokenizer.pad_token).to(device=device, dtype=torch.long)

# Inputs are positions [0, T-1) and targets [1, T) along dim 0 (next-token loss).
# no_grad: the loss is only inspected, never back-propagated, so no autograd
# graph is built. NOTE(review): model.eval() is never called in this notebook,
# so any dropout stays active. Presumably this is intended to match training-time
# loss; confirm.
with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
    pred = model(batch[:-1])
    loss = criterion(pred.reshape(-1, tokenizer.voc_size), batch[1:].ravel())

In [44]:
# Per-step training log of run 241018_protein_only; per-token loss = logged loss / logged token count.
df = pd.read_csv("/workspace/cplm/training/results/241018_protein_only/step_data/0.csv")
loss_per_token = df['loss'].div(df['n_token']).to_numpy()

In [31]:
# With the initial weights the values are about the same: per-token loss on one
# batch vs. the first 10 logged training steps.
# NOTE(review): this cell ran (In [31]) before the checkpoint-loading cell
# (In [39]). Re-run now, it evaluates whatever model is currently loaded.
print(loss.item() / n_accum_token, loss_per_token[:10], sep='\n')

7.400406483640164
[7.4856306  7.4989764  7.48326189 7.48724936 7.48952367 7.47931202
 7.47376649 7.48790399 7.48443156 7.49268806]


In [47]:
# Same comparison for the step-20000 checkpoint: per-token loss on one batch vs.
# the final logged training steps (rows 19690 onward).
print(loss.item() / n_accum_token, loss_per_token[19690:], sep='\n')

12.778897545919904
[2.32423619 2.33471101 2.33297505 2.3349361  2.3288591  2.33813169
 2.32481359 2.32874027 2.34877777 2.33083808]
