In [1]:
import os
import torch
import pandas as pd
from polymerlearn.utils import get_IV_add, GraphDataset

# Load data from local path:
data = pd.read_csv(os.path.join('/Users/owenqueen/Desktop/eastman_project-confidential/Eastman_Project/CombinedData', 
            'pub_data.csv'))

add = get_IV_add(data)

dataset = GraphDataset(
    data = data,
    structure_dir = '../Structures/AG/xyz',
    Y_target=['IV'],
    test_size = 0.2,
    add_features=add
)

  result = getattr(ufunc, method)(*inputs, **kwargs)


In [2]:
from polymerlearn.models.gnn import PolymerGNN_IV
from polymerlearn.utils import train

model_kwargs = {
    'input_feat': 6,         # How many input features on each node; don't change this
    'hidden_channels': 32,   # How many intermediate dimensions to use in model
                            # Can change this ^^
    'num_additional': 4      # How many additional resin properties to include in the prediction
                            # Corresponds to the number in get_IV_add
}

model = PolymerGNN_IV(**model_kwargs)

optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001)
criterion = torch.nn.MSELoss()

train(
    model,
    optimizer = optimizer,
    criterion = criterion,
    dataset = dataset,
    batch_size = 64,
    epochs = 800
)

  return self.test_data, torch.tensor(self.Ytest).float(), torch.tensor(self.add_test).float()


Epoch: 0, 	 Train r2: -0.8997 	 Train Loss: 4.9989 	 Test r2: -0.1446 	 Test Loss 0.0437
Epoch: 10, 	 Train r2: 0.1496 	 Train Loss: 2.2337 	 Test r2: 0.4057 	 Test Loss 0.0227
Epoch: 20, 	 Train r2: 0.3836 	 Train Loss: 2.1899 	 Test r2: 0.5019 	 Test Loss 0.0190
Epoch: 30, 	 Train r2: 0.3142 	 Train Loss: 1.9694 	 Test r2: 0.5087 	 Test Loss 0.0188
Epoch: 40, 	 Train r2: 0.3454 	 Train Loss: 2.0457 	 Test r2: 0.5328 	 Test Loss 0.0178
Epoch: 50, 	 Train r2: 0.4531 	 Train Loss: 2.4713 	 Test r2: 0.5583 	 Test Loss 0.0169
Epoch: 60, 	 Train r2: 0.3992 	 Train Loss: 1.3919 	 Test r2: 0.5811 	 Test Loss 0.0160
Epoch: 70, 	 Train r2: 0.4733 	 Train Loss: 1.4838 	 Test r2: 0.5910 	 Test Loss 0.0156
Epoch: 80, 	 Train r2: 0.5291 	 Train Loss: 1.4367 	 Test r2: 0.6107 	 Test Loss 0.0149
Epoch: 90, 	 Train r2: 0.1988 	 Train Loss: 1.6848 	 Test r2: 0.6455 	 Test Loss 0.0135
Epoch: 100, 	 Train r2: 0.4306 	 Train Loss: 1.2140 	 Test r2: 0.6425 	 Test Loss 0.0136
Epoch: 110, 	 Train r2: 0.4484

In [None]:
from polymerlearn.explain import PolymerGNN_EXPLAIN, PolymerGNNExplainer

mexplain = PolymerGNN_EXPLAIN(**model_kwargs)
mexplain.load_state_dict(model.state_dict()) # Load weights from trained model over to explaining one

explainer = PolymerGNNExplainer(mexplain)

test_batch, Ytest, add_test = dataset.get_test()
test_inds = dataset.test_maks

exp_summary = []

for i in range(Ytest.shape[0]):
    scores = explainer.get_explanation(test_batch[i], add_test[i])
    scores['A'] = torch.sum(scores['A'], dim = 1)
    scores['G'] = torch.sum(scores['G'], dim = 1)
    scores['table_ind'] = test_inds[i]

    exp_summary.append(scores)

  torch.tensor(add_test).float())


In [9]:
import numpy as np

data_mask = data.loc[data['IV'].notna(),:]

for i in range(len(test_inds)):
    j = test_inds[i]
    mw_table = data_mask['Mw (PS)'].iloc[j]
    # Test mw from add_test
    mw_add_test = add_test[i,0]

    print('Table: {} \t log(Table): {} \t Add test: {}'.format(
        mw_table, np.log(mw_table), mw_add_test.item()
    ))

Table: 18022.0 	 log(Table): 9.799348512794982 	 Add test: 9.799348831176758
Table: 41426.0 	 log(Table): 10.631663982015468 	 Add test: 10.631664276123047
Table: nan 	 log(Table): nan 	 Add test: 10.427446365356445
Table: 45044.0 	 log(Table): 10.715395068816916 	 Add test: 10.715394973754883
Table: 97437.0 	 log(Table): 11.486961294292419 	 Add test: 11.486961364746094
Table: 47245.0 	 log(Table): 10.763102107216726 	 Add test: 10.763102531433105
Table: nan 	 log(Table): nan 	 Add test: 10.427446365356445
Table: 17656.0 	 log(Table): 9.778830947936573 	 Add test: 9.778830528259277
Table: 33774.0 	 log(Table): 10.427446554692082 	 Add test: 10.427446365356445
Table: 60549.0 	 log(Table): 11.01120823356823 	 Add test: 11.011208534240723
Table: 21043.0 	 log(Table): 9.954323242238624 	 Add test: 9.954322814941406
Table: 65814.0 	 log(Table): 11.09458786063939 	 Add test: 11.094588279724121
Table: 29554.0 	 log(Table): 10.293974377463579 	 Add test: 10.293973922729492
Table: 20882.0 	 lo

In [23]:
# Summarize importance scores:

# Mw summary:
acid_scores = []
glycol_scores = []

mw_scores = []
an_scores = []
ohn_scores = []
tmp_scores = []

def get_AG_info(data, ac = (20,33), gc = (34,46)):

    # Decompose the data into included names
    acid_names = pd.Series([c[1:] for c in data.columns[ac[0]:ac[1]].tolist()])
    glycol_names = pd.Series([c[1:] for c in data.columns[gc[0]:gc[1]].tolist()])

    # Holds all names of acids and glycols
    acid_included = []
    glycol_included = []

    # Keep track of percents in each acid, glycol
    acid_pcts = []
    glycol_pcts = []

    # Get relevant names and percentages of acid/glycols
    for i in range(data.shape[0]):

        acid_hit = (data.iloc[i,ac[0]:ac[1]].to_numpy() > 0)
        glycol_hit = (data.iloc[i,gc[0]:gc[1]].to_numpy() > 0)

        # Add to percentage lists:
        acid_pcts.append(data.iloc[i,ac[0]:ac[1]][acid_hit].tolist())
        glycol_pcts.append(data.iloc[i,gc[0]:gc[1]][glycol_hit].tolist()) 

        acid_pos = acid_names[np.argwhere(acid_hit).flatten()].tolist()
        glycol_pos = glycol_names[np.argwhere(glycol_hit).flatten()].tolist()

        acid_included.append(acid_pos)
        glycol_included.append(glycol_pos)

    return acid_included, glycol_included, acid_pcts, glycol_pcts

acid_names = pd.Series([c[1:] for c in data_mask.columns[20:33].tolist()])
glycol_names = pd.Series([c[1:] for c in data_mask.columns[34:46].tolist()])
acids, glycols, _, _ = get_AG_info(data_mask)

acid_key = {a:[] for a in acid_names}
glycol_key = {g:[] for g in glycol_names}

for i in range(len(exp_summary)):

    df_ind = exp_summary[i]['table_ind']

    for a in range(len(acids[df_ind])):
        acid_key[acids[df_ind][a]].append(exp_summary[i]['A'][a].item()) 
    
    for g in range(len(glycols[df_ind])):
        glycol_key[glycols[df_ind][g]].append(exp_summary[i]['G'][g].item()) 

    acid_scores.append(torch.sum(exp_summary[i]['A']).item())
    glycol_scores.append(torch.sum(exp_summary[i]['G']).item())

    mw_scores.append(exp_summary[i]['add'][0].item())
    an_scores.append(exp_summary[i]['add'][1].item())
    ohn_scores.append(exp_summary[i]['add'][2].item())
    tmp_scores.append(exp_summary[i]['add'][3].item())

print(acid_scores)
print(glycol_scores)

print(mw_scores)
print(an_scores)
print(ohn_scores)
print(tmp_scores)

print(acid_key)
print(glycol_key)


[-0.8878821134567261, -1.0357295274734497, -0.4294568598270416, -2.0798838138580322, -1.716792345046997, -1.7568775415420532, -0.47344210743904114, -0.6670295000076294, -0.8664342761039734, -1.5350826978683472, -0.7446939945220947, -2.2649717330932617, -0.899152398109436, -0.632922887802124, -1.637803554534912, -0.6790949106216431, -1.4522514343261719, -1.697495937347412, -2.3722074031829834, -1.516371488571167, -0.6643494367599487, -0.5550310611724854, -0.7443832755088806, -0.6352322697639465, -0.4272572100162506, -1.1757748126983643, -0.371645987033844, -2.1001482009887695, -0.5719149708747864, -0.5653843283653259, -1.7718409299850464, -0.5770406723022461, -0.5178261399269104, -0.9155601263046265, -2.273362398147583, -0.9710496664047241, -0.49641841650009155, -1.1031956672668457, -0.7165290117263794, -1.0446813106536865, -0.653941810131073, -1.13113534450531, -2.244901180267334, -1.4169310331344604, -2.206786870956421, -0.695833146572113, -0.4037015736103058, -0.8395153284072876, -1.