In [None]:
import os
import numpy as np
import pandas as pd
import json
import yaml
from easydict import EasyDict
from glob import glob
from moftransformer.utils import prepare_data
import torch
import torch.nn as nn
import functools
from typing import Optional
from torch.utils.data import DataLoader
from pytorch_lightning import LightningDataModule
from moftransformer.datamodules.dataset import Dataset
from tqdm import tqdm
import sys
import copy
import warnings
from pathlib import Path
import shutil
import pytorch_lightning as pl
from moftransformer.config import ex
from moftransformer.config import config as _config
from moftransformer.datamodules.datamodule import Datamodule
from moftransformer.modules.module import Module
from moftransformer.utils.validation import (
    get_valid_config,
    get_num_devices,
    ConfigurationError,
)
from simpletransformers.classification import ClassificationModel, ClassificationArgs
from protonmof.model.model import MOFProtonModel
from protonmof.data.dataset_finetune import MOFProtonDataset
from protonmof.utils import split_by_cif, Normalizer, calculate_con
# Silence PyTorch Lightning's noisy batch-size-inference warning.
warnings.filterwarnings(
    "ignore", ".*Trying to infer the `batch_size` from an ambiguous collection.*"
)

_IS_INTERACTIVE = hasattr(sys, "ps1")

# Restrict this process to physical GPU 1. CUDA_VISIBLE_DEVICES only takes effect
# if it is set before CUDA is first initialised, so keep it ahead of any CUDA call.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
print(torch.cuda.is_available())
# With a single visible GPU, that GPU is addressed as cuda:0.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Turn off HuggingFace tokenizers' internal parallelism (commonly done to avoid
# fork-related warnings when DataLoader workers are used). `os` is already
# imported in the first cell, so no re-import is needed here.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
# Inputs to predict for. The per-sample lists are parallel: entry i of each list
# describes sample i.
target_cif_dir = 'cif/'
cif_names = ['RIMROX']
# N,N-dimethylformamide (DMF) is CN(C)C=O. The previous value 'CN(C)CO' dropped the
# carbonyl double bond and encoded a different molecule.
# NOTE(review): confirm the guest features used at training time were computed
# from the same SMILES.
target_guest_smiles = ['CN(C)C=O']
target_guest_names = ['DMF']
# Temperature and relative humidity per sample (units presumably K and %; confirm).
T, RH = [298], [98]

# Checkpoint file of the fine-tuned proton-conductivity model.
ckpt_dir = '../ckpt/best.ckpt'

assert len({len(cif_names), len(target_guest_smiles), len(target_guest_names), len(T), len(RH)}) == 1, \
    'cif_names, target_guest_smiles, target_guest_names, T and RH must all have the same length'


In [None]:
# Inference table: one row per (MOF, guest, temperature, RH) sample. The
# 'proton conductivity' column is only a zero placeholder, since that value is what
# gets predicted (presumably MOFProtonDataset needs the column to exist; confirm).
exp_data = pd.DataFrame(
    {
        'Name': cif_names,
        'Temperature': T,
        'RH': RH,
        'Guest': target_guest_names,
        'proton conductivity': [0] * len(cif_names),
    }
)
# Optional: persist with exp_data.to_csv('exp_data.csv', index=None)

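### MOFTransformer

Convert the target CIFs with MOFTransformer's `prepare_data`, then save each structure's pretrained `cls_feats` embedding to `mof_features_eval.json`.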

In [None]:
# Convert the target CIFs into MOFTransformer input files (including the `.grid`
# files listed under `target/test/`). Every structure goes to the test split
# because this notebook only runs inference.
target_data_dir = 'target'
train_fraction, test_fraction = 0, 1
try:
    prepare_data(
        target_cif_dir,
        target_data_dir,
        downstream='example',
        train_fraction=train_fraction,
        test_fraction=test_fraction,
    )
except Exception as err:
    # Best-effort: report the failure and keep going.
    print(err)

In [None]:
# Write the test-split index read by the MOFTransformer Datamodule: a JSON object
# mapping each cif_id that has a `.grid` file to a dummy target value of 1.
# Sorting makes the index (and therefore the dataset order) deterministic.
cif_tmp_json = {
    name[:-len('.grid')]: 1
    for name in sorted(os.listdir(f'{target_data_dir}/test/'))
    if name.endswith('.grid')
}
# prepare_data errors are only printed in the previous cell, so fail clearly here
# rather than with an opaque error during feature extraction.
if not cif_tmp_json:
    raise FileNotFoundError(
        f"No .grid files found in {target_data_dir}/test/; did prepare_data succeed?"
    )
with open(f'{target_data_dir}/test.json', 'w') as f:
    json.dump(cif_tmp_json, f)

In [None]:
# Re-import the config factory on every run: `_config` is rebound to a dict at the
# end of this cell, so without this import a second execution would fail.
from moftransformer.config import config as _config
import moftransformer

downstream = None  # NOTE(review): unused; config["downstream"] is set directly below
log_dir = "logs/"
test_only = True

# Start from MOFTransformer's default configuration and override only what this
# inference run needs.
config = copy.deepcopy(_config())
config["root_dataset"] = f'{target_data_dir}/'
# Empty downstream; presumably makes the Datamodule read the plain test.json
# written in the previous cell (confirm).
config["downstream"] = ''
config["log_dir"] = log_dir
config["test_only"] = test_only
# Locate the pretrained weights inside the installed moftransformer package rather
# than through a hard-coded, machine-specific absolute path.
config['load_path'] = os.path.join(
    os.path.dirname(moftransformer.__file__), 'database', 'moftransformer.ckpt'
)
if not os.path.isfile(config['load_path']):
    raise FileNotFoundError(
        f"Pretrained MOFTransformer checkpoint not found at {config['load_path']}; "
        "download it into the moftransformer package's database/ directory first."
    )

_config = copy.deepcopy(config)
pl.seed_everything(_config["seed"])

_config = get_valid_config(_config)

dm = Datamodule(_config)
model = Module(_config)
exp_name = f"{_config['exp_name']}"
dm.setup(stage='test')
dataloader = dm.test_dataloader()

In [None]:
# Run the pretrained MOFTransformer over the test split and save each structure's
# `cls_feats` embedding, keyed by cif_id, to mof_features_eval.json.
# NOTE: `model` must be the MOFTransformer Module built in the previous cell.
model.eval()
all_features = []
cif_ids = []
with torch.no_grad():
    for batch in tqdm(dataloader):
        output = model(batch)
        # .cpu() keeps the numpy conversion working if the model/batch live on a GPU.
        all_features.append(output['cls_feats'].detach().cpu().numpy())
        cif_ids += batch['cif_id']

if not all_features:
    raise RuntimeError(
        'The MOFTransformer test dataloader yielded no batches; check test.json and the prepared data.'
    )
all_features = np.concatenate(all_features)

mof_desc_dict = {cif_id: feat.tolist() for cif_id, feat in zip(cif_ids, all_features)}
with open('./mof_features_eval.json', 'w') as f:
    json.dump(mof_desc_dict, f)

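### ChemBERT

Encode each guest molecule's SMILES with the pretrained ChemBERTa (RoBERTa) model and save its first-token embedding to `guest_features_eval.json`.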

In [None]:
# Embed each guest SMILES with the pretrained ChemBERTa (RoBERTa) encoder and save
# the first-token (<s>) hidden state per guest name to guest_features_eval.json.
# The training-related args are not exercised here; this cell only extracts features.
chembert = ClassificationModel('roberta', 'seyonec/PubChem10M_SMILES_BPE_396_250',
                               num_labels=1,
                               args={'evaluate_each_epoch': True,
                                     'evaluate_during_training_verbose': True,
                                     'no_save': False, 'num_train_epochs': 10,
                                     'regression': True,
                                     'auto_weights': True})  # You can set class weights by using the optional weight argument
# Use a dedicated name rather than `model`, which holds the MOFTransformer Module
# used for MOF feature extraction.
chembert_model = chembert.model
tokens = chembert.tokenizer(target_guest_smiles, add_special_tokens=True, truncation=True,
                            max_length=256, padding="max_length",
                            return_tensors='pt',
                            return_offsets_mapping=False)
# return_tensors='pt' already yields tensors. Move them instead of re-wrapping with
# torch.tensor(), which copies and emits a "copy construct from a tensor" warning.
tokens = {key: value.to(chembert_model.device, dtype=torch.long) for key, value in tokens.items()}

chembert_model.eval()
with torch.no_grad():
    last_hidden = chembert_model.roberta(input_ids=tokens['input_ids'],
                                         attention_mask=tokens['attention_mask'])[0]
    # .cpu() so the numpy conversion also works when the model sits on a GPU.
    outputs = last_hidden[:, 0, :].detach().cpu().numpy()

smiles_feat = {name: feat.tolist() for name, feat in zip(target_guest_names, outputs)}

with open('guest_features_eval.json', 'w') as json_file:
    json.dump(smiles_feat, json_file)

In [None]:
# Load the fine-tuning configuration used by the scalers, model and data loader
# below. Note: this rebinds `config`, replacing the MOFTransformer config dict.
with open('config.yml') as config_file:
    config = EasyDict(yaml.safe_load(config_file))

In [None]:
# Normalizers for the target property, temperature and relative humidity, built
# from the dataset statistics stored in config.yml.
prop_scaler = Normalizer(mean=config.dataset.mean,
                         std=config.dataset.std)
t_scaler = Normalizer(mean=config.dataset.t_mean,
                      std=config.dataset.t_std)
rh_scaler = Normalizer(mean=config.dataset.rh_mean,
                       std=config.dataset.rh_std)

# strict=False: missing/unexpected state_dict keys are ignored instead of raising.
proton_model = MOFProtonModel.load_from_checkpoint(
    ckpt_dir,
    config=config,
    scaler=prop_scaler,
    strict=False,
)

In [None]:


# Wrap the inference table in the project dataset (passing the normalizers built
# above) and iterate over it one sample at a time, in order.
test_data = MOFProtonDataset(
    proton_data=exp_data,
    config=config,
    scaler=prop_scaler,
    t_scaler=t_scaler,
    rh_scaler=rh_scaler,
)

test_loader = DataLoader(
    test_data,
    batch_size=1,
    shuffle=False,
    num_workers=config.train.num_workers,
)


In [None]:
# Predict proton conductivity for every sample. This is inference only, so autograd
# is disabled (consistent with the MOF feature-extraction cell); otherwise every
# batch would build and keep a gradient graph.
proton_model.eval()
all_con_pred = []
with torch.no_grad():
    for batch in test_loader:
        output = proton_model(batch)
        con_pred = calculate_con(output, batch, proton_model.arr)
        # Go through .tolist() so values are float64, exactly as before.
        all_con_pred.extend(np.asarray(con_pred.cpu().tolist()).reshape(-1))

# Undo the property normalization (proton_model.scaler is prop_scaler).
all_con_pred = proton_model.scaler.decode(np.array(all_con_pred))

In [None]:
# Display the decoded predictions; Jupyter renders the cell's last expression.
all_con_pred

In [None]:
# Plain-text summary of the predictions (print's default ' ' separator reproduces
# the original formatted line exactly).
print('Proton Conductivity (predicted):', all_con_pred)