Goal: Load a MEGNet model from a `.pth` checkpoint, remove all layers after the first output MLP layer, and run inference on the dataset. For each structure, save the vector produced by that MLP layer to its own file, named after the structure's MP ID.

In [None]:
import os
import pandas as pd
import json
import yaml
import matplotlib.pyplot as plt
import glob

##torch imports
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.data import DataLoader, Dataset, Data, InMemoryDataset
from torch_geometric.utils import dense_to_sparse, degree, add_self_loops
from torch_geometric.nn.models import meta

import ase
from ase import io
import pymatgen as pmg
from matdeeplearn.models.megnet import MEGNet
from matdeeplearn.training.training import evaluate

from tqdm import tqdm

# Checkpoint to reload, plus the MatDeepLearn config whose hyperparameters
# are used to rebuild the model before loading the weights.
MODEL_PATH = "matdeeplearn/MEGNet_allmats.pth"
MDL_CONFIG_PATH = "matdeeplearn/config.yml"
# Destination directory for the per-structure "<id>_repr.pt" files.
OUT_DIR = "mdl_data/representations"

# Dataset root, and the sub-directory under it that holds data.pt.
data_path = "mdl_data/BGML_data/BGML_train_noe/"
processed_path = "processed"

# Which entry of data.y[0] GetY keeps as the target (-1 keeps all of them).
target_index = 0

In [None]:
class StructureDataset(InMemoryDataset):
    """In-memory PyG dataset read from an already-processed ``data.pt`` file.

    The collated ``(data, slices)`` pair is loaded from
    ``<data_path>/<processed_path>/data.pt``. No raw files are declared and
    no ``process`` method is defined here, so that file must already exist.

    Args:
        data_path: Dataset root directory (passed to the base class as root).
        processed_path: Sub-directory of ``data_path`` that holds ``data.pt``.
        transform: Optional callable applied to each item on access.
        pre_transform: Forwarded unchanged to the base class.
    """

    def __init__(self, data_path, processed_path="processed", transform=None, pre_transform=None):
        # These must be set before the base constructor runs, because
        # ``processed_dir`` reads them (and the base class presumably
        # consults ``processed_dir`` during initialisation).
        self.data_path = data_path
        self.processed_path = processed_path
        super().__init__(data_path, transform, pre_transform)
        stored = torch.load(self.processed_paths[0])
        self.data, self.slices = stored

    @property
    def raw_file_names(self):
        # Nothing raw to look for: the dataset is consumed pre-processed.
        return []

    @property
    def processed_dir(self):
        return os.path.join(self.data_path, self.processed_path)

    @property
    def processed_file_names(self):
        return ["data.pt"]

class GetY:
    """Transform that narrows ``data.y`` to a single target value.

    With ``index == -1`` the object is returned untouched. Otherwise
    ``data.y`` is replaced in place by ``data.y[0][index]``, and the same
    object is returned.
    """

    def __init__(self, index=0):
        self.index = index

    def __call__(self, data):
        # -1 is the sentinel for "keep every target".
        if self.index == -1:
            return data
        data.y = data.y[0][self.index]
        return data

In [None]:
# Load the pre-processed dataset and build a sequential loader over it.
transforms = GetY(index=target_index)
processed_file = os.path.join(data_path, processed_path, "data.pt")
if not os.path.exists(processed_file):
    # Fail early with a clear message. Without this check, `dataset` would be
    # left undefined and the DataLoader line below would raise a NameError.
    raise FileNotFoundError(f"Processed dataset not found at {processed_file!r}")

dataset = StructureDataset(
    data_path,
    processed_path,
    transforms,
)

# shuffle=False keeps batches in dataset order (the same order as
# data_structure_ids below).
loader = DataLoader(dataset, batch_size=64, shuffle=False)

# One structure ID per dataset entry, in dataset order.
data_structure_ids = [x.structure_id[0][0] for x in dataset]
df_data_ids = pd.DataFrame(data_structure_ids)



In [None]:
# Read in MEGNet config (MatDeepLearn/config.yml)
with open(MDL_CONFIG_PATH, 'r') as file:
    config = yaml.load(file, Loader=yaml.FullLoader)
c = config['Models']['MEGNet_demo']

# Make the MEGNet. These hyperparameters must match the architecture stored
# in MODEL_PATH, otherwise load_state_dict below raises on mismatched keys.
# Fall back to CPU so the notebook also runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MEGNet(dataset, c['dim1'], c['dim2'], c['dim3'], c['pre_fc_count'],
               c['gc_count'], c['gc_fc_count'], c['post_fc_count'],
               c['pool'], c['pool_order'], c['batch_norm'], c['batch_track_stats'],
               c['act'], c['dropout_rate']
              ).to(device)

# Reload parameters. map_location lets a checkpoint saved on a GPU be
# loaded onto whichever device was selected above.
d = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(d['model_state_dict'])

# Keep only the first post-FC layer: every later entry of post_lin_list and
# the final lin_out layer are replaced by identity maps. For post_fc_count == 3
# this is exactly the original "replace indices 1 and 2" trimming.
# NOTE(review): this assumes MEGNet.forward runs post_lin_list in order and
# then lin_out. If forward applies an activation after each post-FC layer,
# that activation is re-applied to unchanged values, which is only harmless
# for idempotent activations such as relu. Confirm for c['act'].
for layer_idx in range(1, len(model.post_lin_list)):
    model.post_lin_list[layer_idx] = nn.Identity()
model.lin_out = nn.Identity()

In [None]:
# Evaluate on some data and write one representation file per structure.
model.eval()
# torch.save does not create missing directories.
os.makedirs(OUT_DIR, exist_ok=True)

with torch.no_grad():
    for data in tqdm(loader):
        data = data.to(device)
        # Move the batch to CPU once, so the saved files can be loaded
        # without a GPU.
        out = model(data).cpu()
        # save the representation vectors
        for i, name in enumerate(data.structure_id):
            fn = os.path.join(OUT_DIR, name[0][0] + '_repr.pt')
            # clone(): out[i] is a view, and torch.save serializes a view's
            # entire underlying storage (the whole batch), not just this row.
            torch.save(out[i].clone(), fn)

100%|████████████████████████████████████████████████████████████████████████████████| 290/290 [00:49<00:00,  5.92it/s]
