In [1]:
import pandas as pd
from sklearn.metrics import mean_absolute_error
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem import MACCSkeys
from rdkit.Chem import AllChem
from sklearn.model_selection import train_test_split
import lightgbm as lgb
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, classification_report
import pickle

# Space groups and Z values retained for modelling; both lists are kept sorted.
sg99 = sorted([
    14, 2, 19, 4, 61, 15, 33, 29, 9, 5, 1, 60, 7, 18, 56, 43,
    88, 13, 145, 92, 144, 76, 96, 169, 78, 170, 20, 86, 41, 45, 114, 146,
])
Z_list = sorted([4, 2, 8, 1, 16, 6, 3, 9])

# Load the crystal table, then report (shape, number of distinct space groups)
# after loading and after each of the two filters.
data = pd.read_csv('../crystal-info_CSD_filtered.csv', index_col=0)
print(data.shape, len(set(data['sg_number'].tolist())))
data = data.loc[data['sg_number'].isin(sg99)]
print(data.shape, len(set(data['sg_number'].tolist())))
data = data.loc[data['Z-value'].isin(Z_list)]
print(data.shape, len(set(data['sg_number'].tolist())))

(170278, 15) 87
(169657, 15) 32
(169657, 15) 32


# Prepare data for ML

In [2]:
# Number of retained structures per space group (displayed as the cell output).
value_counts = data['sg_number'].value_counts()
value_counts

sg_number
14     66061
2      29667
19     28668
4      15625
61      8063
15      6990
33      3096
29      1554
5       1532
9       1532
1        872
60       760
7        694
18       660
56       421
43       361
88       340
13       263
145      248
92       241
144      227
76       222
96       219
169      208
78       199
170      187
20       153
86       126
41       125
45       117
114      114
146      112
Name: count, dtype: int64

In [3]:
# Number of retained structures per Z value; this reuses (overwrites) the
# `value_counts` name from the space-group cell.
value_counts = data['Z-value'].value_counts()
value_counts

Z-value
4.0     103783
2.0      45977
8.0      17340
1.0        874
16.0       701
3.0        482
6.0        395
9.0        105
Name: count, dtype: int64

In [4]:
def sg2system(sg):
    """Map a space-group number to the name of its crystal system.

    Parameters
    ----------
    sg : int
        Space-group number, expected in the range 1-230.

    Returns
    -------
    str
        One of 'triclinic', 'monoclinic', 'orthorhombic', 'tetragonal',
        'trigonal', 'hexagonal' or 'cubic'.

    Raises
    ------
    ValueError
        If ``sg`` does not fall inside any of the ranges below (for example
        0, 231 or NaN). Previously this case failed with an unrelated
        UnboundLocalError at the ``return`` statement.
    """
    if 1 <= sg <= 2:
        system = 'triclinic'
    elif 3 <= sg <= 15:
        system = 'monoclinic'
    elif 16 <= sg <= 74:
        system = 'orthorhombic'
    elif 75 <= sg <= 142:
        system = 'tetragonal'
    elif 143 <= sg <= 167:
        system = 'trigonal'
    elif 168 <= sg <= 194:
        system = 'hexagonal'
    elif 195 <= sg <= 230:
        system = 'cubic'
    else:
        raise ValueError(f'Invalid space group number: {sg!r} (expected 1-230)')
    return system

# Derive the crystal system from the space-group number and parse each SMILES
# into an RDKit molecule; rows whose SMILES could not be parsed (Mol is None)
# are dropped and the index is renumbered from 0.
data['crystal_system'] = data['sg_number'].map(sg2system)
data['Mol'] = data['SMILES'].map(Chem.MolFromSmiles)
data = data.loc[data['Mol'].notna()].reset_index(drop=True)
data.shape

[10:10:01] Can't kekulize mol.  Unkekulized atoms: 6 7 8 9 10 11 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 81 82 83 84 85 86 87


(169656, 17)

In [5]:
def calculate_MorganFP(mol, n_bits=2048):
    """Return the Morgan fingerprint bits of ``mol`` as a pandas Series.

    The fingerprint is requested from RDKit with the positional argument 2
    (the radius in RDKit's API) and ``nBits=n_bits``; each bit becomes one
    element of the returned Series.
    """
    bit_vector = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=n_bits)
    return pd.Series(list(bit_vector))

# MACCS Keys
def calculate_MACCS_keys(mol):
    """Return the MACCS key bits of ``mol`` as a pandas Series.

    The first bit of the RDKit fingerprint is skipped, so the Series holds
    one element fewer than the fingerprint itself.
    """
    maccs_bits = list(MACCSkeys.GenMACCSKeys(mol))
    # Drop element 0 (presumably RDKit's unused placeholder bit — confirm).
    return pd.Series(maccs_bits[1:])

# Feature matrix: one row of MACCS bits per molecule, sharing the index of `data`.
descriptor_data = data['Mol'].apply(calculate_MACCS_keys)
# Alternative featurisation with Morgan fingerprints (currently disabled):
# descriptor_data = data['Mol'].apply(calculate_MorganFP)

In [6]:
# Shuffle all rows with a fixed seed and cut at 80 % to form train/test frames.
# The feature matrix is shuffled separately with the same seed, so the seed
# used here must stay in sync with that cell.
train_ratio = 0.8
n_rows = len(data)
train_size = int(n_rows * train_ratio)
df_shuffled = data.sample(frac=1, random_state=0).reset_index(drop=True)
train_df, test_df = df_shuffled.iloc[:train_size], df_shuffled.iloc[train_size:]
print(train_df.shape, test_df.shape, sep='\n')

(135724, 17)
(33932, 17)


In [7]:
# train_df.to_csv('dataset_train.csv')
# test_df.to_csv('dataset_test.csv')

In [8]:
# Split the feature matrix the same way as train_df / test_df.
# The rows only line up because `data` and `X` are shuffled with the same
# random_state (0) and have identical length and index; the assertions below
# turn a silent feature/label misalignment into an immediate failure.
X = descriptor_data
assert X.index.equals(data.index), \
    'descriptor_data is not aligned with data; recompute the descriptors'
X_shuffled = X.sample(frac=1, random_state=0).reset_index(drop=True)
X_train = X_shuffled.iloc[:train_size]
X_test = X_shuffled.iloc[train_size:]
assert len(X_train) == len(train_df) and len(X_test) == len(test_df), \
    'feature split does not match the train/test row split'

# Space group prediction

In [10]:
# Targets for the space-group classifier (overwrites any earlier y_train / y_test).
y_train = train_df['sg_number']
y_test = test_df['sg_number']
# class_weight='balanced' is used because the space-group counts are highly
# skewed (see the value counts displayed earlier in the notebook).
model_sg = lgb.LGBMClassifier(class_weight='balanced', n_estimators=1000, verbosity=-1)
model_sg.fit(X_train, y_train)

In [11]:
# Prediction: most probable space group for every test molecule
y_pred_classes = model_sg.predict(X_test)

# Model evaluation: overall accuracy plus per-class precision/recall/F1
print('Accuracy:', accuracy_score(y_test, y_pred_classes))
print(classification_report(y_test, y_pred_classes))

# Saving is disabled here, but later cells load this file with pickle.load,
# so it must already exist on disk for a fresh Run All to succeed.
file = 'Trained_LightGBM_sg_over100.pkl'
# pickle.dump(model_sg, open(file, 'wb'))

Accuracy: 0.30375456795944833
              precision    recall  f1-score   support

           1       0.21      0.17      0.19       189
           2       0.31      0.38      0.34      5908
           4       0.25      0.32      0.28      3092
           5       0.09      0.10      0.10       329
           7       0.11      0.15      0.13       152
           9       0.07      0.10      0.08       324
          13       0.14      0.10      0.11        52
          14       0.51      0.30      0.38     13164
          15       0.09      0.15      0.11      1407
          18       0.16      0.10      0.12       137
          19       0.40      0.37      0.38      5716
          20       0.35      0.20      0.25        41
          29       0.07      0.15      0.09       329
          33       0.08      0.19      0.12       627
          41       0.06      0.04      0.05        23
          43       0.10      0.08      0.09        65
          45       0.18      0.12      0.15        

In [12]:
# Calculate weighted reference probability
import numpy as np
unique, counts = np.unique(y_test, return_counts=True)
total_samples = len(y_test)
class_probs = counts / total_samples
random_accuracy = np.sum(class_probs ** 2)

random_accuracy

0.22219624915279104

In [259]:
# Reload the saved space-group model. The file is opened in a `with` block so
# the handle is closed again.
# NOTE: pickle.load can execute arbitrary code; only load files you created.
with open('Trained_LightGBM_sg_over100.pkl', 'rb') as model_file:
    model_sg = pickle.load(model_file)
y_test = test_df['sg_number']
# Obtain predicted probabilities on test set (one column per model_sg.classes_ entry)
probs = model_sg.predict_proba(X_test)

# Define the probability threshold
threshold = 0.00000001  # You can adjust this value

# correct_predictions[i] is 1 when the true class of sample i receives a
# probability >= threshold; num_valid_sg[i] counts the classes of sample i
# that pass the threshold.
correct_predictions, num_valid_sg = [], []

# Convert the labels once; calling y_test.tolist() inside the loop rebuilt the
# full list on every iteration (quadratic in the test-set size).
true_classes = y_test.tolist()

# Iterate over each sample in the test set
for i, true_class in enumerate(true_classes):
    row = probs[i]
    index = np.where(model_sg.classes_ == true_class)[0][0]
    predicted_prob = row[index]
    num_valid_sg.append(int(np.count_nonzero(row >= threshold)))

    # Check if the predicted probability for the true class exceeds the threshold
    correct_predictions.append(1 if predicted_prob >= threshold else 0)

# Fraction of samples whose true class passes the threshold
custom_accuracy = np.mean(correct_predictions)
print(f"Custom Accuracy with threshold {threshold}: {custom_accuracy:.2f}")

Custom Accuracy with threshold 1e-08: 0.99


In [13]:
smiles = 'O=C1OC(O)c2ccccc21' # NISNAE

def calc_proba_trained_model(smiles, model_path='Trained_LightGBM_sg_over100.pkl'):
    """Return the class probabilities a pickled model predicts for one SMILES.

    Parameters
    ----------
    smiles : str
        SMILES string of the molecule.
    model_path : str, optional
        Path of the pickled classifier; defaults to the space-group model file.

    Returns
    -------
    numpy.ndarray
        First row of ``predict_proba``: one probability per model class.

    Raises
    ------
    ValueError
        If RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None.
        raise ValueError(f'Could not parse SMILES: {smiles!r}')
    # NOTE: pickle.load can execute arbitrary code; only load trusted model files.
    with open(model_path, 'rb') as model_file:
        trained_model = pickle.load(model_file)
    feature = calculate_MACCS_keys(mol)
    # One-row frame: the bits become columns, matching the training layout.
    probabilities = trained_model.predict_proba(pd.DataFrame(feature).transpose())
    return probabilities[0]

def get_valid_classes(probs, classes, threshold=0.01):
    """Return the entries of ``classes`` whose probability is >= ``threshold``.

    ``probs[k]`` is taken as the probability of ``classes[k]``; the input
    order is preserved in the result.
    """
    return [classes[pos] for pos, p in enumerate(probs) if p >= threshold]

# Show how the list of candidate space groups shrinks as the threshold rises.
# NOTE(review): labels are read from sg99, which assumes the model's class
# order equals the sorted sg99 list — confirm against the model's classes_.
prob = calc_proba_trained_model(smiles)
for threshold in [1e-10, 1e-8, 1e-6, 1e-4, 1e-2]:
    list_valid = get_valid_classes(prob, sg99, threshold)
    
    print(threshold, len(list_valid), list_valid)

1e-10 31 [1, 2, 4, 5, 7, 9, 13, 14, 15, 18, 19, 20, 29, 33, 41, 43, 45, 56, 60, 61, 76, 78, 86, 88, 92, 96, 114, 144, 145, 169, 170]
1e-08 24 [1, 2, 4, 5, 7, 9, 13, 14, 15, 18, 19, 29, 33, 43, 56, 60, 61, 76, 86, 88, 92, 144, 145, 170]
1e-06 20 [1, 2, 4, 5, 7, 9, 14, 15, 18, 19, 29, 33, 43, 56, 60, 61, 86, 88, 92, 145]
0.0001 16 [1, 2, 4, 5, 7, 9, 14, 15, 18, 19, 29, 33, 43, 56, 60, 61]
0.01 8 [2, 4, 14, 15, 19, 29, 33, 61]


# Density prediction

In [14]:
# Targets for the density regressor (overwrites the space-group y_train / y_test).
y_train = train_df['density']
y_test = test_df['density']
model_density = lgb.LGBMRegressor(n_estimators=1000, verbosity=-1)
model_density.fit(X_train, y_train)

y_pred = model_density.predict(X_test)

# MAE on the held-out test set
mae = mean_absolute_error(y_test, y_pred)
print(f'MAE: {mae:.3f} (g/cm^3)')

# Model save (disabled; only the target file name is defined)
file = 'Trained_LightGBM_density_over100.pkl'
# pickle.dump(model_density, open(file, 'wb'))

MAE: 0.049 (g/cm^3)


In [16]:
# Baseline model (mean model)
mae = mean_absolute_error(np.ones(len(y_test))*np.mean(y_train), y_test)
print(f'MAE: {mae:.3f} (g/cm^3)')

MAE: 0.125 (g/cm^3)


In [15]:
# Save inference results
y_train_pred = model_density.predict(X_train)
y_test_pred = model_density.predict(X_test)

dens_train_pred = pd.DataFrame({
    'refcode': train_df['refcode'].tolist(),
    'y_exp': train_df['density'].tolist(),
    'y_pred': list(y_train_pred)
})
dens_test_pred = pd.DataFrame({
    'refcode': test_df['refcode'].tolist(),
    'y_exp': test_df['density'].tolist(),
    'y_pred': list(y_test_pred)
})
# dens_train_pred.to_csv('density_train.csv')
# dens_test_pred.to_csv('density_test.csv')

In [165]:
# Alternative regressor: a multilayer perceptron with four hidden layers.
# NOTE(review): y_train is whatever the most recently executed cell assigned
# (presumably density, given this cell's position) — confirm before relying on it.
from sklearn.neural_network import MLPRegressor

model = MLPRegressor(
    hidden_layer_sizes=(100, 100, 20, 20),
    random_state=42
)

model.fit(X_train, y_train)

In [1]:
# from sklearn.gaussian_process import GaussianProcessRegressor
# from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C

# # Define the Gaussian process regression model
# kernel = C(1.0, (1e-3, 1e3)) * RBF([1.0] * X.shape[1], (1e-2, 1e2))
# model = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=10, random_state=42)

# # Train the model
# model.fit(X_train, y_train)

# Z-value prediction

In [18]:
# Targets for the Z-value classifier (overwrites the density y_train / y_test).
y_train = train_df['Z-value']
y_test = test_df['Z-value']

# Same LightGBM settings as the space-group classifier.
model_Z = lgb.LGBMClassifier(class_weight='balanced', n_estimators=1000, verbosity=-1)
model_Z.fit(X_train, y_train)

# Prediction and evaluation on the test set
y_pred_classes = model_Z.predict(X_test)
print('Accuracy:', accuracy_score(y_test, y_pred_classes))
print(classification_report(y_test, y_pred_classes))

# Saving is disabled here, but a later cell loads this file with pickle.load,
# so it must already exist on disk.
file = 'Trained_LightGBM_Z_over100.pkl'
# pickle.dump(model_Z, open(file, 'wb'))

Accuracy: 0.4275610043616645
              precision    recall  f1-score   support

         1.0       0.14      0.19      0.16       189
         2.0       0.36      0.45      0.40      9151
         3.0       0.18      0.16      0.17        86
         4.0       0.66      0.44      0.53     20754
         6.0       0.21      0.14      0.17        87
         8.0       0.15      0.35      0.21      3516
         9.0       0.09      0.05      0.06        22
        16.0       0.07      0.11      0.09       127

    accuracy                           0.43     33932
   macro avg       0.23      0.24      0.22     33932
weighted avg       0.52      0.43      0.45     33932



In [24]:
# Calculate weighted reference probability
unique, counts = np.unique(y_test, return_counts=True)
total_samples = len(y_test)
class_probs = counts / total_samples
random_accuracy = np.sum(class_probs ** 2)

random_accuracy

0.4576235356206585

In [20]:
# Obtain predicted probabilities on test set
probs = model_Z.predict_proba(X_test)

# Define the probability threshold
threshold = 0.01  # You can adjust this value

# Initialize a list to store whether each prediction is considered correct
correct_predictions, num_valid_Z = [], []

# Iterate over each sample in the test set
for i in range(len(y_test)):
    true_class = y_test.tolist()[i]
    index = np.where(model_Z.classes_ == true_class)[0][0]
    predicted_prob = probs[i][index]
    num_valid_Z.append(len(probs[i][probs[i]>=threshold]))
    
    # Check if the predicted probability for the true class exceeds the threshold
    if predicted_prob >= threshold:
        correct_predictions.append(1)  # Correct
    else:
        correct_predictions.append(0)  # Incorrect

# Compute the custom accuracy
custom_accuracy = np.mean(correct_predictions)
print(f"Custom Accuracy with threshold {threshold}: {custom_accuracy:.2f}")

Custom Accuracy with threshold 0.01: 0.99


In [22]:
smiles = 'O=C1OC(O)c2ccccc21' # NISNAE

def calc_proba_trained_model(smiles, model_path='Trained_LightGBM_Z_over100.pkl'):
    """Return the class probabilities a pickled model predicts for one SMILES.

    This definition replaces any earlier function of the same name in the
    notebook namespace; its default model is the Z-value classifier.

    Parameters
    ----------
    smiles : str
        SMILES string of the molecule.
    model_path : str, optional
        Path of the pickled classifier; defaults to the Z-value model file.

    Returns
    -------
    numpy.ndarray
        First row of ``predict_proba``: one probability per model class.

    Raises
    ------
    ValueError
        If RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None.
        raise ValueError(f'Could not parse SMILES: {smiles!r}')
    # NOTE: pickle.load can execute arbitrary code; only load trusted model files.
    with open(model_path, 'rb') as model_file:
        trained_model = pickle.load(model_file)
    feature = calculate_MACCS_keys(mol)
    # One-row frame: the bits become columns, matching the training layout.
    probabilities = trained_model.predict_proba(pd.DataFrame(feature).transpose())
    return probabilities[0]

def get_valid_classes(probs, classes, threshold=0.01):
    """Select the class labels whose probability reaches ``threshold``.

    ``probs`` and ``classes`` are read position by position; labels are
    returned in their original order.
    """
    selected = []
    for position, probability in enumerate(probs):
        if probability >= threshold:
            selected.append(classes[position])
    return selected

# Candidate Z values for the example molecule at a probability threshold of 1e-6.
# NOTE(review): labels are read from Z_list, which assumes the model's class
# order equals the sorted Z_list — confirm against the model's classes_.
prob = calc_proba_trained_model(smiles)
list_valid = get_valid_classes(prob, Z_list, 0.000001)

print(len(list_valid), list_valid)

7 [1, 2, 3, 4, 6, 8, 16]
