In [1]:
with open('/home/marko/data/ttt.txt', 'r') as f:
    selfies = f.read().split('\n')

In [2]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('/home/marko/gselfies/mofid_llm_tokenizer')

In [3]:
from datasets import Dataset, DatasetDict
datasets = Dataset.from_text('/home/marko/data/ttt.txt')
datasets.train_test_split(test_size=0.2).save_to_disk('test_datasets')

Found cached dataset text (/home/marko/.cache/huggingface/datasets/text/default-ec4e2832e7949f97/0.0.0)


Saving the dataset (0/1 shards):   0%|          | 0/8000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2000 [00:00<?, ? examples/s]

In [4]:
dataset = DatasetDict.load_from_disk('test_datasets')
trainset = dataset['train']

In [5]:
trainset[:3]

{'text': ['[C] [:0benzene] [Branch] [N] [Branch] [C] [C] [=N] [C] [C] [N] [Ring1] [Branch] [pop] [:0benzene] [Ring2] [O] [pop] [pop] ',
  '[C] [N] [Branch] [C] [pop] [C] [=Branch] [=O] [pop] [N] [:0benzene] [Ring2] [Cl] [pop] [Ring1] [Cl] [pop] ',
  '[C] [O] [:0benzene] [Branch] [S] [=Branch] [=O] [pop] [=Branch] [=O] [pop] [N] [C] [C] [N] [Branch] [C] [C] [Branch] [C] [=Branch] [=O] [pop] [N] [O] [pop] [Ring1] [=Branch] [pop] [C] [=Branch] [=O] [pop] [O] [C] [:0benzene] [pop] ']}

In [32]:
import torch
def my_collator(examples):
    output = tokenizer(
        [e['text'] for e in examples],
        truncation=True,
        max_length=40,
        return_tensors='pt',
        padding=True,
    )
    # output['labels'] = torch.tensor([e['label'] for e in examples])
    return output
    

In [33]:
from torch.utils.data import DataLoader
tloader = DataLoader(
    trainset,
    batch_size=4,
    num_workers=4,
    shuffle=True,
    collate_fn=my_collator
)

In [34]:
_iter = iter(tloader)


In [35]:
batch = next(_iter)

In [36]:
batch['input_ids'].shape

torch.Size([4, 40])

### Transformer

In [38]:
from transformers import GPT2Config, GPT2Model
from ml_collections import ConfigDict
cfg = ConfigDict()
cfg.gptcfg = gpt2cfg = ConfigDict()
gpt2cfg.n_embd = 32
gpt2cfg.n_layer = 4
gpt2cfg.n_head = 4
gpt2cfg.vocab_size = tokenizer.vocab_size
gpt2cfg.n_positions = 40
gpt2cfg = GPT2Config(**{k:v for k,v in cfg.gptcfg.items() if k in GPT2Config().to_dict()})
model = GPT2Model(gpt2cfg)

In [39]:
from torch import nn
emb2nrg = nn.Linear(32,1)

In [40]:
batch['input_ids'].shape

torch.Size([4, 40])

In [42]:
emb = model(**batch)[0][:,0,:]
nrg = emb2nrg(emb)

In [50]:
import torch
import pytorch_lightning as pl
from transformers import GPT2Model
from transformers import GPT2Config
from transformers import get_cosine_schedule_with_warmup

cfg.lr = 1e-3
cfg.warmup_steps = 1000
class NRGPredictor(pl.LightningModule):
    def __init__(self, cfg, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        self.cfg = cfg
        self.save_hyperparameters()
        gpt2cfg = GPT2Config(**{k:v for k,v in cfg.gptcfg.items() if k in GPT2Config().to_dict()})
        self.model = GPT2Model(gpt2cfg)
        self.emb2nrg = nn.Linear(32,1)
        self.loss_fn = nn.MSELoss()
        

    def forward(self, **batch):
        emb = self.model(**batch)[0][:,0,:]
        nrg = self.emb2nrg(emb)
        return nrg

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = outputs.mean()
        self.log("train_loss", loss, on_step=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = outputs.mean()
        self.log("val_loss", loss, on_step=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.cfg.lr)
        scheduler = get_cosine_schedule_with_warmup(optimizer, self.cfg.warmup_steps, self.cfg.warmup_steps*5)
        return [optimizer], {"scheduler": scheduler, "interval": "step"}



In [52]:
import os
os.environ["WANDB_API_KEY"] ="xxx"

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

wandb_logger = WandbLogger(
    project='test',
    name='asdadfasff',
    log_model=True,
)


trainer = Trainer(
    logger = wandb_logger,
    accelerator='auto',
    gradient_clip_algorithm='norm',
    gradient_clip_val=1.0,
    devices=1,
    max_epochs=10,
    check_val_every_n_epoch=1,
    log_every_n_steps=1,
    callbacks=[
        EarlyStopping(
            monitor="val_loss",
            mode="min",
            patience=10,
        ),
        LearningRateMonitor(logging_interval='step'),
        # ModelCheckpoint(
        #     dirpath=f'{setupparams.base_dir}/{setupparams.experiment_name}_checkpoints/{run_name}',
        #     monitor='val_loss',
        #     every_n_epochs=1,
        #     save_top_k=2,
        # ),
    ]
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [53]:
model = NRGPredictor(cfg, tokenizer)
trainer.fit(model, tloader)

  rank_zero_warn(
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name    | Type      | Params
--------------------------------------
0 | model   | GPT2Model | 71.7 K
1 | emb2nrg | Linear    | 33    
2 | loss_fn | MSELoss   | 0     
--------------------------------------
71.7 K    Trainable params
0         Non-trainable params
71.7 K    Total params
0.287     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [44]:
from constellatio.datasets.pubchemqc import random_cids, get_cids

In [54]:
cids = get_cids(1000, offset=0)

In [55]:
aaa = [c for c in cids]
aaaa = aaa[1:10]

In [63]:
import requests


def fetch_molecule_data(cids):
    base_url = "https://pcqc.matter.toronto.edu/pm6opt_chon300nosalt"
    molecule_data_list = []

    # Convert the list of CIDs into a comma-separated string
    cid_list = ",".join(map(str, cids))

    # Use the `any` operator in the query parameters
    params = {"select": "*", "and": f"(cid.any.{cid_list})"}

    # Debugging: print out the final URL and parameters
    print("Requesting URL:", base_url)
    print("With parameters:", params)

    response = requests.get(base_url, params=params)
    
    if response.status_code == 200:
        molecule_data = response.json()
        if molecule_data:  # Check if data is not empty
            molecule_data_list.extend(molecule_data)  # Add all fetched records to the list
    else:
        print(f"Failed to fetch data with status code {response.status_code}: {response.text}")

    return molecule_data_list

In [64]:
data = fetch_molecule_data(aaaa)

Requesting URL: https://pcqc.matter.toronto.edu/pm6opt_chon300nosalt
With parameters: {'select': '*', 'and': '(cid.eq(any).3,4,7,8,12,15,16,17,18)'}
Failed to fetch data with status code 400: {"code":"PGRST100","details":"unexpected \",\" expecting letter, digit, \"-\", \"->>\", \"->\" or delimiter (.)","hint":null,"message":"\"failed to parse logic tree ((cid.eq(any).3,4,7,8,12,15,16,17,18))\" (line 1, column 20)"}


In [58]:
aaa

[1,
 3,
 4,
 7,
 8,
 12,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 25,
 26,
 28,
 29,
 30,
 32,
 36,
 37,
 38,
 39,
 40,
 42,
 43,
 45,
 47,
 48,
 49,
 50,
 51,
 58,
 63,
 67,
 69,
 70,
 71,
 72,
 75,
 76,
 77,
 79,
 80,
 86,
 87,
 91,
 93,
 95,
 96,
 101,
 102,
 104,
 107,
 108,
 111,
 114,
 116,
 117,
 118,
 119,
 120,
 122,
 123,
 124,
 125,
 126,
 127,
 128,
 129,
 130,
 135,
 137,
 138,
 142,
 144,
 145,
 155,
 156,
 157,
 166,
 167,
 171,
 173,
 174,
 176,
 177,
 178,
 179,
 180,
 185,
 189,
 190,
 191,
 196,
 199,
 203,
 204,
 205,
 206,
 207,
 211,
 215,
 218,
 219,
 221,
 222,
 225,
 227,
 229,
 232,
 236,
 239,
 240,
 241,
 243,
 244,
 247,
 252,
 254,
 261,
 262,
 263,
 264,
 273,
 275,
 277,
 279,
 280,
 281,
 282,
 284,
 286,
 288,
 289,
 296,
 297,
 306,
 307,
 308,
 309,
 310,
 311,
 320,
 322,
 323,
 325,
 326,
 328,
 331,
 332,
 335,
 336,
 337,
 338,
 340,
 342,
 344,
 345,
 346,
 347,
 349,
 354,
 355,
 356,
 358,
 359,
 361,
 362,
 363,
 364,
 366,
 368,
 370,
 