### Prepare data

In [None]:
import os
import numpy as np
import json
from glob import glob

In [None]:
from moftransformer.examples import example_path
from moftransformer.utils import prepare_data

# Look up the example locations exposed by moftransformer.examples.
root_cifs, root_dataset, downstream = (
    example_path[name] for name in ('root_cif', 'root_dataset', 'downstream')
)

# Split fractions handed to prepare_data: nothing for training, all for test.
train_fraction, test_fraction = 0, 1

In [None]:
# Run moftransformer's prepare_data on the project's own CIF folder, writing
# the results under ../database/features/dataset with the split fractions
# chosen above.
prepare_data(
    '../database/structures/cif/',
    '../database/features/dataset',
    downstream=downstream,
    train_fraction=train_fraction,
    test_fraction=test_fraction,
)


In [None]:
# Build a label file for the test split: every structure that has a '.grid'
# file gets the dummy target value 1.
# The listing reads the `total/` folder of the dataset directory that
# prepare_data was pointed at, so it uses the same root as the test.json
# written below rather than a path relative to the working directory.
# NOTE(review): assumes prepare_data stores processed files in
# `<root_dataset>/total/` -- confirm against the installed moftransformer.
cif_tmp_json = {
    cif[:-len('.grid')]: 1
    for cif in os.listdir('../database/features/dataset/total/')
    if cif.endswith('.grid')
}
with open('../database/features/dataset/test.json', 'w') as f:
    json.dump(cif_tmp_json, f)

### MOFTransformer

In [None]:
import os
import torch

# Environment switches for this process: expose only CUDA device index 1 and
# turn off tokenizers parallelism.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Report whether CUDA is usable, then pick the matching torch device.
print(torch.cuda.is_available())
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
import torch
import torch.nn as nn
import functools
from typing import Optional

from torch.utils.data import DataLoader

from pytorch_lightning import LightningDataModule
from moftransformer.datamodules.dataset import Dataset
from tqdm import tqdm

In [None]:
import sys
import os
import copy
import warnings
from pathlib import Path
import shutil

import pytorch_lightning as pl

from moftransformer.config import ex
from moftransformer.config import config as _config
from moftransformer.datamodules.datamodule import Datamodule
from moftransformer.modules.module import Module
from moftransformer.utils.validation import (
    get_valid_config,
    get_num_devices,
    ConfigurationError,
)

# Suppress any warning whose message matches the ambiguous-`batch_size`
# inference text.
warnings.filterwarnings(
    action="ignore",
    message=".*Trying to infer the `batch_size` from an ambiguous collection.*",
)

# True when the interpreter defines sys.ps1 (the interactive prompt string).
_IS_INTERACTIVE = hasattr(sys, "ps1")

In [None]:
# `_config` is rebound to a plain config object further down, so import the
# factory again here to keep this cell re-runnable.
from moftransformer.config import config as _config

# Notebook-level settings.
downstream = None
log_dir = "logs/"
test_only = True

# Take a private copy of the package defaults and override what this run needs.
config = copy.deepcopy(_config())
config["root_dataset"] = '../database/features/dataset/'
config["downstream"] = ''
config["log_dir"] = log_dir
config["test_only"] = test_only
# NOTE(review): this is a plain string, not an f-string, so the
# '{moftransformer_path}' placeholder is passed through literally -- replace
# it with the real checkpoint location before running.
config['load_path'] = '{moftransformer_path}/moftransformer.ckpt'

# Work on an independent copy so `config` keeps the raw overrides.
_config = copy.deepcopy(config)
pl.seed_everything(_config["seed"])

_config = get_valid_config(_config)

# Build the data module and model from the validated config, then prepare the
# test split and grab its dataloader.
dm = Datamodule(_config)
model = Module(_config)
exp_name = f"{_config['exp_name']}"
dm.setup(stage='test')
dataloader = dm.test_dataloader()

In [None]:
# Run the model in eval mode over every batch of the test dataloader and
# collect the 'cls_feats' output together with the matching cif ids.
# The loop no longer stops after the first batch, so features are gathered
# for the whole dataloader.
model.eval()
all_features = []
cif_ids = []
with torch.no_grad():
    for batch in tqdm(dataloader):
        output = model(batch)
        features = output['cls_feats']
        # .cpu() is a no-op for CPU tensors and keeps .numpy() valid if the
        # model output lives on a GPU.
        all_features.append(features.detach().cpu().numpy())
        cif_ids.extend(batch['cif_id'])

# Stack the per-batch arrays along the first axis.
# NOTE(review): presumably one row per cif id -- confirm the 'cls_feats' shape.
all_features = np.concatenate(all_features)

In [None]:
# Map each feature row to the part of its cif id that precedes the first '_'.
mof_desc_dict = dict(
    zip((cif_id.partition('_')[0] for cif_id in cif_ids), all_features)
)

In [None]:
# Convert every stored feature to a plain Python list via .tolist() so the
# mapping can be serialised with json.
mof_desc_dict = {
    mof_name: feature_vec.tolist()
    for mof_name, feature_vec in mof_desc_dict.items()
}

### Save MOF Features

In [None]:
# Write the name -> feature-list mapping to disk as JSON.
with open('../database/features/moftransformer.json', 'w') as f:
    f.write(json.dumps(mof_desc_dict))