In [1]:
import os
import torch
import pandas as pd
from polymerlearn.utils import get_IV_add, GraphDataset

# Load data from local path:
data = pd.read_csv(os.path.join('/Users/owenqueen/Desktop/eastman_project-confidential/Eastman_Project/CombinedData', 
            'pub_data.csv'))

add = get_IV_add(data)

dataset = GraphDataset(
    data = data,
    structure_dir = '../Structures/AG/xyz',
    Y_target=['IV'],
    test_size = 0.2,
    add_features=add
)

  result = getattr(ufunc, method)(*inputs, **kwargs)


In [2]:
from polymerlearn.models.gnn import PolymerGNN_IV
from polymerlearn.utils import train

model_kwargs = {
    'input_feat': 6,         # How many input features on each node; don't change this
    'hidden_channels': 32,   # How many intermediate dimensions to use in model
                            # Can change this ^^
    'num_additional': 4      # How many additional resin properties to include in the prediction
                            # Corresponds to the number in get_IV_add
}

model = PolymerGNN_IV(**model_kwargs)

optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001)
criterion = torch.nn.MSELoss()

train(
    model,
    optimizer = optimizer,
    criterion = criterion,
    dataset = dataset,
    batch_size = 64,
    epochs = 800
)

  return self.test_data, torch.tensor(self.Ytest).float(), torch.tensor(self.add_test).float()


Epoch: 0, 	 Train r2: -0.8997 	 Train Loss: 4.9989 	 Test r2: -0.1446 	 Test Loss 0.0437
Epoch: 10, 	 Train r2: 0.1496 	 Train Loss: 2.2337 	 Test r2: 0.4057 	 Test Loss 0.0227
Epoch: 20, 	 Train r2: 0.3836 	 Train Loss: 2.1899 	 Test r2: 0.5019 	 Test Loss 0.0190
Epoch: 30, 	 Train r2: 0.3142 	 Train Loss: 1.9694 	 Test r2: 0.5087 	 Test Loss 0.0188
Epoch: 40, 	 Train r2: 0.3454 	 Train Loss: 2.0457 	 Test r2: 0.5328 	 Test Loss 0.0178
Epoch: 50, 	 Train r2: 0.4531 	 Train Loss: 2.4713 	 Test r2: 0.5583 	 Test Loss 0.0169
Epoch: 60, 	 Train r2: 0.3992 	 Train Loss: 1.3919 	 Test r2: 0.5811 	 Test Loss 0.0160
Epoch: 70, 	 Train r2: 0.4733 	 Train Loss: 1.4838 	 Test r2: 0.5910 	 Test Loss 0.0156
Epoch: 80, 	 Train r2: 0.5291 	 Train Loss: 1.4367 	 Test r2: 0.6107 	 Test Loss 0.0149
Epoch: 90, 	 Train r2: 0.1988 	 Train Loss: 1.6848 	 Test r2: 0.6455 	 Test Loss 0.0135
Epoch: 100, 	 Train r2: 0.4306 	 Train Loss: 1.2140 	 Test r2: 0.6425 	 Test Loss 0.0136
Epoch: 110, 	 Train r2: 0.4484

In [4]:
from polymerlearn.explain import PolymerGNN_IV_EXPLAIN, PolymerGNNExplainer

mexplain = PolymerGNN_IV_EXPLAIN(**model_kwargs)
mexplain.load_state_dict(model.state_dict()) # Load weights from trained model over to explaining one

explainer = PolymerGNNExplainer(mexplain)

test_batch, Ytest, add_test = dataset.get_test()
test_inds = dataset.test_mask

exp_summary = []

for i in range(Ytest.shape[0]):
    scores = explainer.get_explanation(test_batch[i], add_test[i])
    scores['A'] = torch.sum(scores['A'], dim = 1)
    scores['G'] = torch.sum(scores['G'], dim = 1)
    scores['table_ind'] = test_inds[i]

    exp_summary.append(scores)

  torch.tensor(add_test).float())


In [6]:
import numpy as np

data_mask = data.loc[data['IV'].notna(),:]

for i in range(len(test_inds)):
    j = test_inds[i]
    mw_table = data_mask['Mw (PS)'].iloc[j]
    # Test mw from add_test
    mw_add_test = add_test[i,0]

    print('Table: {} \t log(Table): {} \t Add test: {}'.format(
        mw_table, np.log(mw_table), mw_add_test.item()
    ))

Table: nan 	 log(Table): nan 	 Add test: 10.427446365356445
Table: 23976.0 	 log(Table): 10.084808608996498 	 Add test: 10.084808349609375
Table: 94705.0 	 log(Table): 11.458522076090842 	 Add test: 11.458521842956543
Table: 24619.0 	 log(Table): 10.111273781529295 	 Add test: 10.111273765563965
Table: 47381.0 	 log(Table): 10.765976583441928 	 Add test: 10.765976905822754
Table: 24860.0 	 log(Table): 10.121015365064702 	 Add test: 10.121015548706055
Table: nan 	 log(Table): nan 	 Add test: 10.427446365356445
Table: 14098.0 	 log(Table): 9.553788222333822 	 Add test: 9.553788185119629
Table: 68358.0 	 log(Table): 11.13251387992617 	 Add test: 11.132513999938965
Table: 97437.0 	 log(Table): 11.486961294292419 	 Add test: 11.486961364746094
Table: 56224.0 	 log(Table): 10.937098990986824 	 Add test: 10.93709945678711
Table: 52715.0 	 log(Table): 10.87265532401105 	 Add test: 10.872654914855957
Table: 38227.0 	 log(Table): 10.551297351207467 	 Add test: 10.551297187805176
Table: 27730.0 	

In [7]:
# Summarize importance scores:

# Mw summary:
acid_scores = []
glycol_scores = []

mw_scores = []
an_scores = []
ohn_scores = []
tmp_scores = []

def get_AG_info(data, ac = (20,33), gc = (34,46)):

    # Decompose the data into included names
    acid_names = pd.Series([c[1:] for c in data.columns[ac[0]:ac[1]].tolist()])
    glycol_names = pd.Series([c[1:] for c in data.columns[gc[0]:gc[1]].tolist()])

    # Holds all names of acids and glycols
    acid_included = []
    glycol_included = []

    # Keep track of percents in each acid, glycol
    acid_pcts = []
    glycol_pcts = []

    # Get relevant names and percentages of acid/glycols
    for i in range(data.shape[0]):

        acid_hit = (data.iloc[i,ac[0]:ac[1]].to_numpy() > 0)
        glycol_hit = (data.iloc[i,gc[0]:gc[1]].to_numpy() > 0)

        # Add to percentage lists:
        acid_pcts.append(data.iloc[i,ac[0]:ac[1]][acid_hit].tolist())
        glycol_pcts.append(data.iloc[i,gc[0]:gc[1]][glycol_hit].tolist()) 

        acid_pos = acid_names[np.argwhere(acid_hit).flatten()].tolist()
        glycol_pos = glycol_names[np.argwhere(glycol_hit).flatten()].tolist()

        acid_included.append(acid_pos)
        glycol_included.append(glycol_pos)

    return acid_included, glycol_included, acid_pcts, glycol_pcts

acid_names = pd.Series([c[1:] for c in data_mask.columns[20:33].tolist()])
glycol_names = pd.Series([c[1:] for c in data_mask.columns[34:46].tolist()])
acids, glycols, _, _ = get_AG_info(data_mask)

acid_key = {a:[] for a in acid_names}
glycol_key = {g:[] for g in glycol_names}

for i in range(len(exp_summary)):

    df_ind = exp_summary[i]['table_ind']

    for a in range(len(acids[df_ind])):
        acid_key[acids[df_ind][a]].append(exp_summary[i]['A'][a].item()) 
    
    for g in range(len(glycols[df_ind])):
        glycol_key[glycols[df_ind][g]].append(exp_summary[i]['G'][g].item()) 

    acid_scores.append(torch.sum(exp_summary[i]['A']).item())
    glycol_scores.append(torch.sum(exp_summary[i]['G']).item())

    mw_scores.append(exp_summary[i]['add'][0].item())
    an_scores.append(exp_summary[i]['add'][1].item())
    ohn_scores.append(exp_summary[i]['add'][2].item())
    tmp_scores.append(exp_summary[i]['add'][3].item())

print(acid_scores)
print(glycol_scores)

print(mw_scores)
print(an_scores)
print(ohn_scores)
print(tmp_scores)

print(acid_key)
print(glycol_key)


[-0.5220816731452942, -0.871830940246582, -1.372873067855835, -0.7051889300346375, -1.479537844657898, -0.8235787749290466, -0.49047982692718506, -0.4644251763820648, -1.7646515369415283, -1.5553362369537354, -1.144914150238037, -2.20985746383667, -1.4301812648773193, -0.5704778432846069, -0.9277279376983643, -0.6290164589881897, -0.811713457107544, -0.4857237935066223, -1.2340717315673828, -0.7832549214363098, -2.0962419509887695, -0.3163483738899231, -0.6890739798545837, -1.1152818202972412, -0.6931972503662109, -0.519429087638855, -0.70504230260849, -0.5535339117050171, -0.7099688053131104, -0.5150319337844849, -1.429985761642456, -0.777635931968689, -0.45038893818855286, -0.8560777902603149, -0.5053524971008301, -0.9263220429420471, -0.41691508889198303, -1.2905853986740112, -1.724608302116394, -1.264465093612671, -1.064122200012207, -0.8285700678825378, -0.9279764890670776, -0.7006344199180603, -2.206166982650757, -0.6776403188705444, -2.2430825233459473, -0.9685702323913574, -2.1