In [None]:
import os
import gc
import re
import time
import random
import numpy as np
import csv
import copy
import glob
import math
import joblib
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import AdamW
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from transformers import BertTokenizer, VisualBertForQuestionAnswering, VisualBertConfig
import scipy
from scipy.stats import mannwhitneyu
from scipy.signal import savgol_filter
from scipy.ndimage import gaussian_filter
from scipy.interpolate import CubicSpline, interp1d
from sklearn import svm
from sklearn.linear_model import SGDClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import KFold
from sklearn.multioutput import MultiOutputRegressor
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, RocCurveDisplay, roc_curve, classification_report, top_k_accuracy_score, coverage_error
import warnings
warnings.filterwarnings("ignore")

def scherrer_fwhm(crystal_size, theta, wavelength=1.5406, shape_factor=0.9):
    """
    Peak broadening from the Scherrer equation, K * lambda / (L * cos(theta_B)).

    ``theta`` is a 2-theta angle in degrees (scalar or numpy array); it is halved before
    the cosine is taken. The result is in radians when ``wavelength`` and ``crystal_size``
    share a length unit; no radian-to-degree conversion is applied here.
    NOTE(review): the default wavelength 1.5406 is presumably Cu K-alpha1 in angstrom;
    confirm the unit callers use for ``crystal_size``.
    """
    bragg_angle = np.deg2rad(theta / 2)
    numerator = shape_factor * wavelength
    return numerator / (crystal_size * np.cos(bragg_angle))

def load_plt_setting():
    '''
    Apply the notebook's global Matplotlib look (white seaborn style, Arial, large fonts).

    Only mutates matplotlib's global style/rcParams; call it before creating figures.
    '''
    # The seaborn styles were renamed 'seaborn-v0_8-*' in Matplotlib 3.6 and the old names
    # were removed in 3.8, where plt.style.use('seaborn-white') raises. Use whichever exists.
    for style_name in ('seaborn-v0_8-white', 'seaborn-white'):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            break
    mpl.rcParams['font.sans-serif'] = "Arial"
    mpl.rcParams['font.family'] = "sans-serif"
    mpl.rcParams['axes.linewidth'] = 2
    mpl.rc('font', size=32)
    mpl.rcParams['xtick.major.pad'] = '8'
    mpl.rcParams['ytick.major.pad'] = '8'
    plt.rcParams["font.weight"] = "normal"
    plt.rcParams["axes.labelweight"] = "normal"
    # Keep text as editable text (not paths) in exported SVG files.
    plt.rcParams['svg.fonttype'] = 'none'

# Uncased BERT tokenizer extended with an extra 'pb' token; the returned
# number of newly added tokens is not needed, so it is discarded into `_`.
tokenizer = BertTokenizer.from_pretrained(
    "google-bert/bert-base-uncased",
)
_ = tokenizer.add_tokens(['pb'])

# 1. Dataloader

In [None]:
class DataLoader():
    '''
    Generator of synthetic (single- or multi-phase) powder XRD training patterns.

    [Input]
    dataset_path:      Path to dataset generated from preprocess.py (read with joblib).
                       dataset: dict of sample -> {'formula', 'element_list', 'space_group', 'xrd_list'}
                                'xrd_list': list of (X,Y) = (2-theta peak positions, peak intensities);
                                len(xrd_list) depends on the strain setting

    [load_data output]
    Xs:                Array of min-max normalised intensity in (n, len(twotheta), 1), n = batch_size
    Ys:                Array of classification labels in (n, len(dataset)); 1 for every mixed sample
    element_list:      List of elements for each sample, n = len(element_list)
    formula_list:      Human labels (formulas of the mixed samples), n = len(formula_list)
    n_mix_list:        Number of compounds mixed into each sample
    '''
    def __init__(self, dataset_path):

        try:
            with open(dataset_path, 'rb') as handle:
                self.dataset = joblib.load(handle)
                print('Loading dataset successful.')
        except OSError:
            # A bare `except` used to swallow this and the constructor then crashed on the
            # missing `self.dataset`; surface the real cause instead.
            print("Missing dataset.")
            raise

        self.sample_list, self.sample_formula_list, self.combination_list = [], [], []
        self.multiphase = {}
        self.len = len(self.dataset)

        # Sort key per sample: non-digit runs followed by digit runs of its formula;
        # a hyphenated formula 'A-B' is keyed as 'B-A'.
        for sample, data in self.dataset.items():
            self.sample_list.append(sample)
            formula = data['formula']
            if '-' in formula:
                parts = formula.split('-')
                formula = parts[-1] + '-' + parts[0]
            self.sample_formula_list.append(re.findall(r'\D+', formula) + re.findall(r'\d+', formula))

        self.sample_list = [x for _, x in sorted(zip(self.sample_formula_list, self.sample_list))]

        # One candidate system per distinct element set. Sorting the elements makes the key
        # independent of set iteration order, so samples sharing an element set share a key.
        for i, sample in enumerate(self.sample_list):
            elements = self.dataset[sample]['element_list']
            self.multiphase[''.join(sorted(set(elements)))] = {'elements': elements, 'samples': []}
            print(i, self.dataset[sample]['formula'], elements)

        # A system collects every sample whose elements are a subset of the system's elements.
        for sample in self.sample_list:
            sample_elements = set(self.dataset[sample]['element_list'])
            for system in self.multiphase.values():
                if sample_elements.issubset(set(system['elements'])):
                    system['samples'].append(sample)

        # Only systems with at least two member samples can yield a multi-phase mix.
        self.multiphase = {k: v for k, v in self.multiphase.items() if len(v.get('samples', [])) >= 2}
        print(self.multiphase)

    def load_data(self, batch_size=10, twotheta=np.arange(5.00, 60.01, 0.01),
                  n_mix=[1,2,3], resonable_mixing=False, min_mixing_ratio=0.05,
                  high_orientation_probability=0.2, crystal_size_range=(5, 20), intensity_variation_range=(0.2, 1),
                  noise_sigma_list=np.logspace(-4,-2,num=101)):
        '''
        Generate a batch of synthetic XRD patterns.

        [Input]
        batch_size:                    Number of patterns n.
        twotheta:                      Output 2-theta grid in degrees (assumed uniformly spaced).
        n_mix:                         Candidate numbers of phases per pattern (drawn per sample).
        resonable_mixing:              If True (and n_mix != [1]), mix only samples belonging to a
                                       single entry of self.multiphase.
        min_mixing_ratio:              Lower bound of each phase's weight before normalisation.
        high_orientation_probability:  Probability that a phase keeps only one of its 3 strongest peaks.
        crystal_size_range:            Range of the crystal size passed to scherrer_fwhm.
        intensity_variation_range:     Range of the per-point multiplicative intensity jitter.
        noise_sigma_list:              Candidate standard deviations of the additive Gaussian noise.

        [Output]
        Xs, Ys, element_list, formula_list, n_mix_list (see class docstring).
        '''
        self.n_mix_list = np.random.choice(n_mix, batch_size, replace=True)

        self.Xs = np.zeros((batch_size, len(twotheta), 1))
        self.Ys = np.zeros((batch_size, self.len))
        self.element_list, self.formula_list = [], []

        # Grid spacing in degrees, used to express the broadening width in grid points
        # (previously hard-coded as *100, i.e. valid only for a 0.01-degree grid).
        step = (twotheta[-1] - twotheta[0]) / (len(twotheta) - 1) if len(twotheta) > 1 else 0.01
        # np.atleast_1d so tuples/arrays compare like the list default does.
        mix_by_system = resonable_mixing and np.atleast_1d(n_mix).tolist() != [1]

        i = 0
        while i < batch_size:
            if mix_by_system:
                self.sample_idxs = []
                samples = self.multiphase[np.random.choice(list(self.multiphase.keys()))]['samples']
                self.n_mix_list[i] = min(self.n_mix_list[i], len(samples))
                samples = np.random.choice(samples, self.n_mix_list[i], replace=False)
                for sample in samples:
                    self.sample_idxs.append(self.sample_list.index(sample))
            else:
                self.sample_idxs = np.random.choice(self.len, size=self.n_mix_list[i], replace=False)
            self.formulas, self.elements = [], []

            self.mixing_ratio = np.random.uniform(min_mixing_ratio, 1, len(self.sample_idxs))
            self.mixing_ratio = self.mixing_ratio / np.sum(self.mixing_ratio)

            valid = True
            for j, sample_idx in enumerate(self.sample_idxs):
                self.data = self.dataset[self.sample_list[sample_idx]]
                self.formulas.append(self.data['formula'])
                for element in self.data['element_list']:
                    if element not in self.elements:
                        self.elements.append(element)
                self.Ys[i, sample_idx] += 1

                self.twotheta_short, self.X_short = self.data['xrd_list'][np.random.randint(0, len(self.data['xrd_list']))]
                if np.random.binomial(1, high_orientation_probability) == 1 and len(self.X_short) > 0:
                    # Preferred orientation: keep one peak chosen among the (up to) 3 strongest.
                    # Reads self.X_short (the original referenced the global `dataloader`).
                    n_top = min(3, len(self.X_short))
                    high_orientation_peak_index = np.random.choice(np.argpartition(self.X_short, -n_top)[-n_top:], 1)[0]
                    self.X_short = np.array([self.X_short[high_orientation_peak_index]])
                    self.twotheta_short = np.array([self.twotheta_short[high_orientation_peak_index]])

                # Place the stick pattern on the grid (peaks assumed to lie within twotheta).
                self.X = np.zeros(twotheta.shape)
                self.X[np.searchsorted(twotheta, self.twotheta_short)] = self.X_short

                # Intensity variation
                self.X = self.X * np.random.uniform(*intensity_variation_range, self.X.shape[0])
                # Crystal size broadening: Gaussian whose FWHM is the grid-averaged Scherrer value.
                # NOTE(review): scherrer_fwhm applies no rad->deg conversion, yet the width is
                # treated here as degrees when divided by the grid step.
                fwhm = scherrer_fwhm(np.random.uniform(*crystal_size_range), twotheta)
                sigma = np.mean(fwhm) / (2 * np.sqrt(2 * np.log(2)))
                self.X = gaussian_filter(self.X, sigma=sigma / step)

                peak_max = np.max(self.X)
                if peak_max > 0:
                    self.Xs[i, :, 0] += self.X / peak_max * self.mixing_ratio[j]
                else:
                    # A phase with no intensity would divide by zero (NaN pattern); redraw the sample.
                    valid = False
                    break

            if not valid or np.max(self.Xs[i, :, 0]) == 0:
                # Discard and retry this slot (i is not advanced).
                self.Xs[i, :, 0] = 0
                self.Ys[i, :] = 0
            else:
                pattern = self.Xs[i, :, 0] / np.max(self.Xs[i, :, 0])
                noise_sigma = np.random.choice(noise_sigma_list)
                pattern = pattern + np.random.normal(0, noise_sigma, len(twotheta))
                self.Xs[i, :, 0] = (pattern - np.min(pattern)) / (np.max(pattern) - np.min(pattern))
                self.formula_list.append(self.formulas)
                self.element_list.append(list(set(np.array(self.elements).flatten())))
                i += 1

        return self.Xs, self.Ys, self.element_list, self.formula_list, self.n_mix_list

    def load_ref(self, sample_idx=0, twotheta=np.arange(5.00, 60.01, 0.01),):
        '''
        Unaugmented stick pattern of one sample on the twotheta grid, scaled to a maximum of 1.

        Uses the middle entry (index len // 2) of the sample's 'xrd_list'. An all-zero
        pattern is returned unscaled instead of becoming NaN.
        '''
        self.data = self.dataset[self.sample_list[sample_idx]]
        xrd_list = self.data['xrd_list']
        # Middle entry. The original int(len/2 + 0.5) is one past the middle for odd
        # lengths (IndexError for a single entry) and read the global `dataloader`.
        self.twotheta_short, self.X_short = xrd_list[len(xrd_list) // 2]
        self.X = np.zeros(twotheta.shape)
        self.X[np.searchsorted(twotheta, self.twotheta_short)] = self.X_short
        peak_max = np.max(self.X)
        if peak_max > 0:
            self.X = self.X / peak_max
        return self.X

In [None]:
# Paths resolve from the kernel's working directory; the preprocessed dataset lives in ./cif.
folder = os.getcwd()
cif_folder = os.path.join(
    folder,
    'cif',
)
dataloader = DataLoader(
    os.path.join(cif_folder, 'dataset.npy')
)

# 2. SVM

In [None]:
# Train one linear SVM per training-set size: SGDClassifier (default hinge loss) fitted
# incrementally with partial_fit on freshly simulated single-phase patterns, in mini-batches
# of at most 1000 (`num_of_epochs` counts mini-batches, not passes over a fixed dataset).
svm_dir = os.path.join(folder, 'baselines', 'SVM')
os.makedirs(svm_dir, exist_ok=True)
# Class ids are the dataset sample indices (columns of Ys); derived from the dataset
# instead of the previously hard-coded 66 classes.
svm_classes = np.arange(dataloader.len)

for sample_size in [1e2,3e2,1e3,3e3,1e4,3e4,1e5,3e5,1e6,3e6,1e7]:
    batch_size = int(sample_size)
    if sample_size > 1000:
        num_of_epochs = int(batch_size/1000)
        batch_size = 1000
    else:
        num_of_epochs = 1

    clf = SGDClassifier(learning_rate='adaptive', eta0=0.01)

    time_start = time.time()
    for i in range(num_of_epochs):
        Xs, Ys, element_list, formula_list, _ = dataloader.load_data(batch_size=batch_size, n_mix=[1],
                                                              high_orientation_probability=0.2, crystal_size_range=(5, 20),
                                                              intensity_variation_range=(0.01, 1))
        # With n_mix=[1] every row of Ys is one-hot, so argmax is the class index.
        Ys_int = Ys.argmax(axis=1)

        # The last grid point is dropped, as in the evaluation cell.
        clf.partial_fit(Xs[:,:-1,0], Ys_int, classes=svm_classes)

        if (i+1)%10 == 0:
            print('{}\t{}'.format(i+1, time.time()-time_start))

    with open(os.path.join(svm_dir, 'SVM_{}.pkl'.format(int(sample_size))),'wb') as f:
        pickle.dump(clf,f)

In [None]:
# Evaluate every saved SVM on the fixed single-phase test set and save a confusion-matrix figure.
with open(os.path.join(cif_folder, 'test_dataset_1phase_orientation.npy'), 'rb') as handle:
    test_dataset = joblib.load(handle)

Xs, Ys, element_list, formula_list = test_dataset['Xs'], test_dataset['Ys'], test_dataset['element_list'], test_dataset['formula_list']
Ys_int = [np.where(x==1)[0][0] for x in Ys]
# Use the full label set so the report and matrix stay aligned with dataloader.sample_list
# even when some classes never occur among the true or predicted labels.
class_labels = np.arange(len(dataloader.sample_list))

load_plt_setting()

for sample_size in [1e2,3e2,1e3,3e3,1e4,3e4,1e5,3e5,1e6,3e6,1e7]:

    # pickle can execute arbitrary code: only load checkpoints written by the training cell.
    with open(os.path.join(folder, 'baselines', 'SVM', 'SVM_{}.pkl'.format(int(sample_size))), 'rb') as f:
        clf = pickle.load(f)

    Ys_pred_int = clf.predict(Xs[:,:-1,0])

    report = classification_report(Ys_int, Ys_pred_int, labels=class_labels,
                                   target_names=dataloader.sample_list, output_dict=True)
    # Computed directly: depending on the scikit-learn version, the report may expose
    # 'micro avg' instead of 'accuracy' once `labels` is given.
    accuracy = np.mean(np.asarray(Ys_pred_int) == np.asarray(Ys_int))
    matrix = confusion_matrix(Ys_int, Ys_pred_int, labels=class_labels)

    fig, ax = plt.subplots(figsize=(12, 10))

    mat = ax.matshow(matrix, cmap='hot')
    ax.xaxis.set_ticks_position('bottom')
    ax.xaxis.set_major_locator(mpl.ticker.MaxNLocator(nbins=10, steps=[1, 2, 5, 10]))
    ax.xaxis.set_minor_locator(mpl.ticker.AutoMinorLocator(2))
    ax.yaxis.set_major_locator(mpl.ticker.MaxNLocator(nbins=10, steps=[1, 2, 5, 10]))
    ax.yaxis.set_minor_locator(mpl.ticker.AutoMinorLocator(2))
    ax.tick_params(axis='both',direction='out',length=8,width=2,pad=10,color='black',labelsize=28)
    ax.tick_params(axis='both',which='minor',direction='out',length=4,width=2,pad=10,color='black',labelsize=28)

    cbar = fig.colorbar(mat, ax=ax)
    cbar.ax.tick_params(axis='y', direction='out',length=8,width=3,pad=5,labelsize=28)
    cbar.ax.set_ylabel('Number of samples',labelpad=40, rotation=-90)
    cbar.ax.yaxis.set_major_locator(mpl.ticker.MaxNLocator(nbins=4, integer=True, steps=[1, 2, 5, 10]))
    cbar.ax.yaxis.set_minor_locator(mpl.ticker.AutoMinorLocator(2))
    cbar.ax.tick_params(axis='both',which='minor',direction='out',length=4,width=3,pad=5)

    ax.set_xlabel(r'Predicted class', labelpad=20, fontsize=34)
    ax.set_ylabel(r'True class', labelpad=20, fontsize=34)

    ax.text(0.95, 0.91, 'Accuracy: {:.1f}%'.format(accuracy*100), fontsize=32,
            transform=ax.transAxes, color='#FFFFFF', horizontalalignment='right')
    # Human-readable training-set size: e.g. 3M, 300k, or the raw count below 1k.
    if sample_size%1000000 == 0:
        data_size = str(int(sample_size//1000000))+'M'
    elif sample_size%1000 == 0:
        data_size = str(int(sample_size//1000))+'k'
    else:
        data_size = int(sample_size)
    ax.text(0.05, 0.06, 'Data size: {}'.format(data_size), fontsize=32, transform=ax.transAxes, color='#FFFFFF', horizontalalignment='left')

    fig.savefig(os.path.join(folder, 'baselines', 'SVM', 'SVM_{}.png'.format(int(sample_size))), format='png', dpi=300, transparent=True, bbox_inches='tight')
    # Render inline, then release the figure instead of keeping all of them open.
    plt.show()
    plt.close(fig)

# 3. CNN

In [None]:
class Model(nn.Module):
    '''
    1D CNN baseline: two conv + max-pool stages followed by three fully connected layers.

    [Input]
    input_dim:       Length of the input pattern. Unused: all layers are lazy and infer their
                     input sizes on the first forward pass (kept for interface compatibility).
    n_class:         Number of output logits.
    use_activation:  Apply ReLU after each conv and hidden linear layer. Without it the three
                     linear layers collapse into a single linear map.

    forward(x): x is (batch, channels, length) as Conv1d requires; returns (batch, n_class) logits.
    '''
    def __init__(self, input_dim, n_class, use_activation=True):
        super(Model, self).__init__()
        self.use_activation = use_activation
        self.con1 = nn.LazyConv1d(64, 50, stride=2)
        self.poo1 = nn.MaxPool1d(3,stride=2)
        self.con2 = nn.LazyConv1d(64, 25, stride=3)
        self.poo2 = nn.MaxPool1d(2,stride=3)
        self.fc1 = nn.LazyLinear(2000)
        self.fc2 = nn.LazyLinear(500)
        self.fc3 = nn.LazyLinear(n_class)

    def forward(self, x):
        # Modules pickled with torch.save before `use_activation` existed lack the attribute;
        # they were trained without activations, so keep that behaviour for them.
        act = F.relu if getattr(self, 'use_activation', False) else (lambda t: t)
        x = self.poo1(act(self.con1(x)))
        x = self.poo2(act(self.con2(x)))
        x = act(self.fc1(torch.flatten(x, start_dim=1)))
        x = act(self.fc2(x))
        return self.fc3(x)

In [None]:
# Train one CNN per training-set size on freshly simulated single-phase patterns, in
# mini-batches of 20 (`num_of_epochs` counts mini-batches, not passes over fixed data).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cnn_dir = os.path.join(folder, 'baselines', 'CNN')
os.makedirs(cnn_dir, exist_ok=True)
criterion = nn.CrossEntropyLoss()
batch_size = 20

for sample_size in [1e2,3e2,1e3,3e3,1e4,3e4,1e5,3e5,1e6,3e6,1e7]:
    # Reset per run: the history used to accumulate across sample sizes, so every
    # train_loss_<size>.csv also contained the losses of all earlier runs.
    train_losses = []

    # One throw-away sample fixes the input length and the number of classes.
    Xs, Ys, element_list, formula_list, _ = dataloader.load_data(batch_size=1, n_mix=[1], high_orientation_probability=0.2,
                                                                 crystal_size_range=(5, 20),intensity_variation_range=(0.01, 1))
    dummy_input = torch.swapaxes(torch.tensor(Xs[:,:-1,:], dtype=torch.float32),1,2).to(device)

    model = Model(Xs.shape[1]-1,Ys.shape[1]).to(device)
    # Dry run so the Lazy* layers materialise their parameters before the optimizer sees them.
    with torch.no_grad():
        model(dummy_input)

    optimizer = AdamW(model.parameters(), lr=1e-5)
    num_of_epochs = int(sample_size/batch_size)

    for i in range(num_of_epochs):
        Xs, Ys, element_list, formula_list, _ = dataloader.load_data(batch_size=batch_size, n_mix=[1],high_orientation_probability=0.2,
                                                                     crystal_size_range=(5, 20),intensity_variation_range=(0.01, 1))
        inputs = torch.swapaxes(torch.tensor(Xs[:,:-1,:], dtype=torch.float32),1,2).to(device)
        # With n_mix=[1] each row of Ys is one-hot; it is used as a class-probability target.
        targets = torch.tensor(Ys, dtype=torch.float32).to(device)

        Y_predict = model(inputs)
        loss = criterion(Y_predict, targets)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        loss_value = loss.item()
        train_losses.append(loss_value)

        if (i+1)%10 == 0:
            y_pred = Y_predict.detach().cpu().numpy()
            print('{}\t{:.5f}\t{}'.format(i+1, loss_value, Ys.argmax(-1)[:15]-y_pred.argmax(-1)[:15]))

    torch.save(model, os.path.join(cnn_dir, '{}.pt'.format(int(sample_size))))
    np.savetxt(os.path.join(cnn_dir, 'train_loss_{}.csv'.format(int(sample_size))), train_losses, delimiter=',')

    # Free the model and its optimizer state before the next, larger run.
    del model, optimizer
    if device == 'cuda':
        torch.cuda.empty_cache()

In [None]:
# Evaluate every saved CNN on the fixed single-phase test set and save a confusion-matrix figure.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

with open(os.path.join(cif_folder, 'test_dataset_1phase_orientation.npy'), 'rb') as handle:
    test_dataset = joblib.load(handle)

Xs, Ys, element_list, formula_list = test_dataset['Xs'], test_dataset['Ys'], test_dataset['element_list'], test_dataset['formula_list']
Ys_int = [np.where(x==1)[0][0] for x in Ys]
# Use the full label set so the report and matrix stay aligned with dataloader.sample_list
# even when some classes never occur among the true or predicted labels.
class_labels = np.arange(len(dataloader.sample_list))
# Conv1d input with the last grid point dropped, as in training.
# Assumes Xs is (n, length, 1) like load_data's output — TODO confirm for the test file.
test_inputs = torch.swapaxes(torch.tensor(Xs[:,:-1,:], dtype=torch.float32),1,2)

load_plt_setting()

for sample_size in [1e2,3e2,1e3,3e3,1e4,3e4,1e5,3e5,1e6,3e6,1e7]:

    # Checkpoints are whole pickled modules (torch.save(model)); PyTorch >= 2.6 defaults to
    # weights_only=True, which rejects them. Only load trusted, locally produced files.
    model = torch.load(os.path.join(folder, 'baselines', 'CNN', '{}.pt'.format(int(sample_size))),
                       map_location=device, weights_only=False)
    model.eval()

    # Inference only: no autograd graph, and chunked to bound device memory.
    with torch.no_grad():
        Ys_pred = torch.cat([model(chunk.to(device)).cpu() for chunk in torch.split(test_inputs, 256)]).numpy()
    Ys_pred_int = np.argmax(Ys_pred, axis=1)

    report = classification_report(Ys_int, Ys_pred_int, labels=class_labels,
                                   target_names=dataloader.sample_list, output_dict=True)
    # Computed directly: depending on the scikit-learn version, the report may expose
    # 'micro avg' instead of 'accuracy' once `labels` is given.
    accuracy = np.mean(np.asarray(Ys_pred_int) == np.asarray(Ys_int))
    matrix = confusion_matrix(Ys_int, Ys_pred_int, labels=class_labels)

    fig, ax = plt.subplots(figsize=(12, 10))

    mat = ax.matshow(matrix, cmap='hot')
    ax.xaxis.set_ticks_position('bottom')
    ax.xaxis.set_major_locator(mpl.ticker.MaxNLocator(nbins=10, steps=[1, 2, 5, 10]))
    ax.xaxis.set_minor_locator(mpl.ticker.AutoMinorLocator(2))
    ax.yaxis.set_major_locator(mpl.ticker.MaxNLocator(nbins=10, steps=[1, 2, 5, 10]))
    ax.yaxis.set_minor_locator(mpl.ticker.AutoMinorLocator(2))
    ax.tick_params(axis='both',direction='out',length=8,width=2,pad=10,color='black',labelsize=28)
    ax.tick_params(axis='both',which='minor',direction='out',length=4,width=2,pad=10,color='black',labelsize=28)

    cbar = fig.colorbar(mat, ax=ax)
    cbar.ax.tick_params(axis='y', direction='out',length=8,width=3,pad=5,labelsize=28)
    cbar.ax.set_ylabel('Number of samples',labelpad=40, rotation=-90)
    cbar.ax.yaxis.set_major_locator(mpl.ticker.MaxNLocator(nbins=4, integer=True, steps=[1, 2, 5, 10]))
    cbar.ax.yaxis.set_minor_locator(mpl.ticker.AutoMinorLocator(2))
    cbar.ax.tick_params(axis='both',which='minor',direction='out',length=4,width=3,pad=5)

    ax.set_xlabel(r'Predicted class', labelpad=20, fontsize=34)
    ax.set_ylabel(r'True class', labelpad=20, fontsize=34)

    ax.text(0.95, 0.91, 'Accuracy: {:.1f}%'.format(accuracy*100), fontsize=32,
            transform=ax.transAxes, color='#FFFFFF', horizontalalignment='right')
    # Human-readable training-set size: e.g. 3M, 300k, or the raw count below 1k.
    if sample_size%1000000 == 0:
        data_size = str(int(sample_size//1000000))+'M'
    elif sample_size%1000 == 0:
        data_size = str(int(sample_size//1000))+'k'
    else:
        data_size = int(sample_size)
    ax.text(0.05, 0.06, 'Data size: {}'.format(data_size), fontsize=32, transform=ax.transAxes, color='#FFFFFF', horizontalalignment='left')

    fig.savefig(os.path.join(folder, 'baselines', 'CNN', 'CNN_{}.png'.format(int(sample_size))), format='png', dpi=300, transparent=True, bbox_inches='tight')
    # Render inline, then release the figure and the model before the next checkpoint.
    plt.show()
    plt.close(fig)
    del model