# Example script for loading ViscoNet models

This demo shows how to load a ViscoNet model artifact (its `config.yml` and `model.pt` weights), run it on example data, and scale the raw predictions back to the original scale.

In [1]:
import yaml
import torch
from model import SimpleANN
from utils import VEDatasetV2
from torch.utils.data import DataLoader

In [2]:
# Choose which pretrained artifact to load; alternatives for each setting are
# listed in the trailing comments.
experiment = "thermoplastic2thermoplastic"  # or "thermoplastic2thermoset"
ve_response = "storage_modulus"  # or "tan_delta"
model_code = "VE256WDNN3"  # or "VE128np", "VE128", "VE256", "VE256WDNN5"
# Artifact directory layout: artifacts/<experiment>/<ve_response>/<model_code>
base_path = "/".join(("artifacts", experiment, ve_response, model_code))

In [3]:
# Load the hyper-parameters stored alongside the weights. The same dict is
# unpacked into SimpleANN here and into VEDatasetV2 in the data cell.
# An explicit encoding keeps this independent of the platform's locale default.
with open(f"{base_path}/config.yml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
# Build the network from the config and place it on the CPU.
model = SimpleANN(**config).to(torch.device('cpu'))
# Load the saved state_dict, remapping any device-tagged tensors to CPU.
# NOTE(review): torch.load unpickles the file, so only load trusted artifacts.
# On torch >= 1.13, passing weights_only=True restricts unpickling to tensors.
model_wt = torch.load(f"{base_path}/model.pt", map_location='cpu')
model.load_state_dict(model_wt)
# Switch to evaluation mode (disables the Dropout layers). As the cell's last
# expression, this also makes the notebook display the model architecture.
model.eval()

SimpleANN(
  (activation): GELU(approximate='none')
  (linear_relu_stack_with_dropout): Sequential(
    (0): Linear(in_features=37, out_features=256, bias=True)
    (1): GELU(approximate='none')
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=256, out_features=256, bias=True)
    (4): GELU(approximate='none')
    (5): Dropout(p=0.2, inplace=False)
  )
  (last_hidden): Linear(in_features=259, out_features=30, bias=True)
)

In [4]:
# Wrap the example data in the project dataset. The dataset receives the same
# config used to build the model (presumably so its preprocessing matches).
BATCH_SIZE = 10
NUM_WORKERS = 0  # 0 => batches are loaded in the main process
dataset = VEDatasetV2(
    ["example_data/example_ep.json"],  # alternative file: "example_tand.json" (presumably for tan_delta)
    **config,
)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

In [5]:
# `display` is injected by IPython/Jupyter; fall back to `print` so this cell
# also runs as a plain Python script.
try:
    display
except NameError:
    display = print

results = []
with torch.no_grad():  # inference only: skip autograd bookkeeping
    # Run the model over every batch (not just one), showing each input batch.
    for batch in dataloader:
        display(batch)
        results.append(model(batch["input"]))

# torch.cat on an empty list gives an opaque error; fail with a clear message.
if not results:
    raise RuntimeError("dataloader yielded no batches; check the example data path")

# The model produces raw (scaled) predictions that must be mapped back.
results = torch.cat(results)
print("Raw prediction:")
print(results)
# Scale back with the dataset's own routine.
# NOTE(review): ve_id=0 is hard-coded; presumably it must correspond to the VE
# response selected above — confirm against VEDatasetV2.scale_back.
scaled_back = dataset.scale_back(results, ve_id=0)
print("Scaled-back prediction:")
print(scaled_back)

{'input': tensor([[2.0000e+01, 5.0000e+00, 2.0000e+00, 3.0000e+00, 1.0000e+00, 4.0009e-01,
          5.9830e-01, 0.0000e+00, 2.0355e-02, 3.9034e-02, 5.6774e-02, 7.3916e-02,
          9.0619e-02, 1.0370e-01, 1.1760e-01, 1.3244e-01, 1.4792e-01, 1.6376e-01,
          1.8100e-01, 1.9977e-01, 2.2143e-01, 2.4665e-01, 2.7929e-01, 3.1569e-01,
          3.6741e-01, 4.3149e-01, 5.0499e-01, 5.8314e-01, 6.6704e-01, 7.5120e-01,
          8.2261e-01, 8.7056e-01, 9.0827e-01, 9.3930e-01, 9.6410e-01, 9.8411e-01,
          1.0000e+00],
         [2.0000e+01, 1.0000e+01, 3.0000e+00, 3.0000e+00, 1.0000e+00, 5.1408e-02,
          7.4026e-01, 0.0000e+00, 2.0355e-02, 3.9034e-02, 5.6774e-02, 7.3916e-02,
          9.0619e-02, 1.0370e-01, 1.1760e-01, 1.3244e-01, 1.4792e-01, 1.6376e-01,
          1.8100e-01, 1.9977e-01, 2.2143e-01, 2.4665e-01, 2.7929e-01, 3.1569e-01,
          3.6741e-01, 4.3149e-01, 5.0499e-01, 5.8314e-01, 6.6704e-01, 7.5120e-01,
          8.2261e-01, 8.7056e-01, 9.0827e-01, 9.3930e-01, 9.6410e-

{'input': tensor([[3.5000e+01, 2.0000e+01, 2.0000e+00, 1.0000e+00, 1.0000e+00, 3.0385e-01,
          5.9703e-01, 0.0000e+00, 2.0355e-02, 3.9034e-02, 5.6774e-02, 7.3916e-02,
          9.0619e-02, 1.0370e-01, 1.1760e-01, 1.3244e-01, 1.4792e-01, 1.6376e-01,
          1.8100e-01, 1.9977e-01, 2.2143e-01, 2.4665e-01, 2.7929e-01, 3.1569e-01,
          3.6741e-01, 4.3149e-01, 5.0499e-01, 5.8314e-01, 6.6704e-01, 7.5120e-01,
          8.2261e-01, 8.7056e-01, 9.0827e-01, 9.3930e-01, 9.6410e-01, 9.8411e-01,
          1.0000e+00],
         [5.0000e+01, 1.0000e+01, 1.0000e+00, 2.0000e+00, 1.0000e+00, 3.5809e-01,
          6.3763e-01, 0.0000e+00, 2.0355e-02, 3.9034e-02, 5.6774e-02, 7.3916e-02,
          9.0619e-02, 1.0370e-01, 1.1760e-01, 1.3244e-01, 1.4792e-01, 1.6376e-01,
          1.8100e-01, 1.9977e-01, 2.2143e-01, 2.4665e-01, 2.7929e-01, 3.1569e-01,
          3.6741e-01, 4.3149e-01, 5.0499e-01, 5.8314e-01, 6.6704e-01, 7.5120e-01,
          8.2261e-01, 8.7056e-01, 9.0827e-01, 9.3930e-01, 9.6410e-

{'input': tensor([[3.5000e+01, 2.0000e+01, 3.0000e+00, 1.0000e+00, 0.0000e+00, 1.5122e-01,
          4.4094e-01, 0.0000e+00, 2.0355e-02, 3.9034e-02, 5.6774e-02, 7.3916e-02,
          9.0619e-02, 1.0370e-01, 1.1760e-01, 1.3244e-01, 1.4792e-01, 1.6376e-01,
          1.8100e-01, 1.9977e-01, 2.2143e-01, 2.4665e-01, 2.7929e-01, 3.1569e-01,
          3.6741e-01, 4.3149e-01, 5.0499e-01, 5.8314e-01, 6.6704e-01, 7.5120e-01,
          8.2261e-01, 8.7056e-01, 9.0827e-01, 9.3930e-01, 9.6410e-01, 9.8411e-01,
          1.0000e+00],
         [5.0000e+01, 1.0000e+01, 2.0000e+00, 1.0000e+00, 1.0000e+00, 2.1250e-01,
          4.4996e-01, 0.0000e+00, 2.0355e-02, 3.9034e-02, 5.6774e-02, 7.3916e-02,
          9.0619e-02, 1.0370e-01, 1.1760e-01, 1.3244e-01, 1.4792e-01, 1.6376e-01,
          1.8100e-01, 1.9977e-01, 2.2143e-01, 2.4665e-01, 2.7929e-01, 3.1569e-01,
          3.6741e-01, 4.3149e-01, 5.0499e-01, 5.8314e-01, 6.6704e-01, 7.5120e-01,
          8.2261e-01, 8.7056e-01, 9.0827e-01, 9.3930e-01, 9.6410e-

{'input': tensor([[3.5000e+01, 2.0000e+01, 3.0000e+00, 1.0000e+00, 1.0000e+00, 1.1291e-01,
          5.8506e-01, 0.0000e+00, 2.0355e-02, 3.9034e-02, 5.6774e-02, 7.3916e-02,
          9.0619e-02, 1.0370e-01, 1.1760e-01, 1.3244e-01, 1.4792e-01, 1.6376e-01,
          1.8100e-01, 1.9977e-01, 2.2143e-01, 2.4665e-01, 2.7929e-01, 3.1569e-01,
          3.6741e-01, 4.3149e-01, 5.0499e-01, 5.8314e-01, 6.6704e-01, 7.5120e-01,
          8.2261e-01, 8.7056e-01, 9.0827e-01, 9.3930e-01, 9.6410e-01, 9.8411e-01,
          1.0000e+00],
         [5.0000e+00, 1.0000e+00, 3.0000e+00, 1.0000e+00, 0.0000e+00, 1.0001e-01,
          4.5345e-01, 0.0000e+00, 2.0355e-02, 3.9034e-02, 5.6774e-02, 7.3916e-02,
          9.0619e-02, 1.0370e-01, 1.1760e-01, 1.3244e-01, 1.4792e-01, 1.6376e-01,
          1.8100e-01, 1.9977e-01, 2.2143e-01, 2.4665e-01, 2.7929e-01, 3.1569e-01,
          3.6741e-01, 4.3149e-01, 5.0499e-01, 5.8314e-01, 6.6704e-01, 7.5120e-01,
          8.2261e-01, 8.7056e-01, 9.0827e-01, 9.3930e-01, 9.6410e-

{'input': tensor([[1.0000e+01, 3.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0012e-01,
          8.8220e-01, 0.0000e+00, 2.0355e-02, 3.9034e-02, 5.6774e-02, 7.3916e-02,
          9.0619e-02, 1.0370e-01, 1.1760e-01, 1.3244e-01, 1.4792e-01, 1.6376e-01,
          1.8100e-01, 1.9977e-01, 2.2143e-01, 2.4665e-01, 2.7929e-01, 3.1569e-01,
          3.6741e-01, 4.3149e-01, 5.0499e-01, 5.8314e-01, 6.6704e-01, 7.5120e-01,
          8.2261e-01, 8.7056e-01, 9.0827e-01, 9.3930e-01, 9.6410e-01, 9.8411e-01,
          1.0000e+00],
         [5.0000e+01, 3.0000e+01, 3.0000e+00, 3.0000e+00, 0.0000e+00, 5.0504e-02,
          3.0828e-02, 0.0000e+00, 2.0355e-02, 3.9034e-02, 5.6774e-02, 7.3916e-02,
          9.0619e-02, 1.0370e-01, 1.1760e-01, 1.3244e-01, 1.4792e-01, 1.6376e-01,
          1.8100e-01, 1.9977e-01, 2.2143e-01, 2.4665e-01, 2.7929e-01, 3.1569e-01,
          3.6741e-01, 4.3149e-01, 5.0499e-01, 5.8314e-01, 6.6704e-01, 7.5120e-01,
          8.2261e-01, 8.7056e-01, 9.0827e-01, 9.3930e-01, 9.6410e-

{'input': tensor([[5.0000e+00, 3.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 3.6078e-01,
          6.3909e-01, 0.0000e+00, 2.0355e-02, 3.9034e-02, 5.6774e-02, 7.3916e-02,
          9.0619e-02, 1.0370e-01, 1.1760e-01, 1.3244e-01, 1.4792e-01, 1.6376e-01,
          1.8100e-01, 1.9977e-01, 2.2143e-01, 2.4665e-01, 2.7929e-01, 3.1569e-01,
          3.6741e-01, 4.3149e-01, 5.0499e-01, 5.8314e-01, 6.6704e-01, 7.5120e-01,
          8.2261e-01, 8.7056e-01, 9.0827e-01, 9.3930e-01, 9.6410e-01, 9.8411e-01,
          1.0000e+00],
         [1.0000e+01, 5.0000e+00, 2.0000e+00, 3.0000e+00, 1.0000e+00, 2.0029e-01,
          7.9941e-01, 0.0000e+00, 2.0355e-02, 3.9034e-02, 5.6774e-02, 7.3916e-02,
          9.0619e-02, 1.0370e-01, 1.1760e-01, 1.3244e-01, 1.4792e-01, 1.6376e-01,
          1.8100e-01, 1.9977e-01, 2.2143e-01, 2.4665e-01, 2.7929e-01, 3.1569e-01,
          3.6741e-01, 4.3149e-01, 5.0499e-01, 5.8314e-01, 6.6704e-01, 7.5120e-01,
          8.2261e-01, 8.7056e-01, 9.0827e-01, 9.3930e-01, 9.6410e-

{'input': tensor([[1.0000e+01, 5.0000e+00, 3.0000e+00, 2.0000e+00, 0.0000e+00, 2.0029e-01,
          3.6084e-01, 0.0000e+00, 2.0355e-02, 3.9034e-02, 5.6774e-02, 7.3916e-02,
          9.0619e-02, 1.0370e-01, 1.1760e-01, 1.3244e-01, 1.4792e-01, 1.6376e-01,
          1.8100e-01, 1.9977e-01, 2.2143e-01, 2.4665e-01, 2.7929e-01, 3.1569e-01,
          3.6741e-01, 4.3149e-01, 5.0499e-01, 5.8314e-01, 6.6704e-01, 7.5120e-01,
          8.2261e-01, 8.7056e-01, 9.0827e-01, 9.3930e-01, 9.6410e-01, 9.8411e-01,
          1.0000e+00],
         [5.0000e+00, 1.0000e+00, 1.0000e+00, 3.0000e+00, 1.0000e+00, 1.0008e-01,
          8.9944e-01, 0.0000e+00, 2.0355e-02, 3.9034e-02, 5.6774e-02, 7.3916e-02,
          9.0619e-02, 1.0370e-01, 1.1760e-01, 1.3244e-01, 1.4792e-01, 1.6376e-01,
          1.8100e-01, 1.9977e-01, 2.2143e-01, 2.4665e-01, 2.7929e-01, 3.1569e-01,
          3.6741e-01, 4.3149e-01, 5.0499e-01, 5.8314e-01, 6.6704e-01, 7.5120e-01,
          8.2261e-01, 8.7056e-01, 9.0827e-01, 9.3930e-01, 9.6410e-

{'input': tensor([[5.0000e+01, 3.0000e+01, 2.0000e+00, 1.0000e+00, 0.0000e+00, 7.5280e-02,
          2.4512e-01, 0.0000e+00, 2.0355e-02, 3.9034e-02, 5.6774e-02, 7.3916e-02,
          9.0619e-02, 1.0370e-01, 1.1760e-01, 1.3244e-01, 1.4792e-01, 1.6376e-01,
          1.8100e-01, 1.9977e-01, 2.2143e-01, 2.4665e-01, 2.7929e-01, 3.1569e-01,
          3.6741e-01, 4.3149e-01, 5.0499e-01, 5.8314e-01, 6.6704e-01, 7.5120e-01,
          8.2261e-01, 8.7056e-01, 9.0827e-01, 9.3930e-01, 9.6410e-01, 9.8411e-01,
          1.0000e+00],
         [2.0000e+01, 5.0000e+00, 1.0000e+00, 2.0000e+00, 1.0000e+00, 1.0019e-01,
          8.9066e-01, 0.0000e+00, 2.0355e-02, 3.9034e-02, 5.6774e-02, 7.3916e-02,
          9.0619e-02, 1.0370e-01, 1.1760e-01, 1.3244e-01, 1.4792e-01, 1.6376e-01,
          1.8100e-01, 1.9977e-01, 2.2143e-01, 2.4665e-01, 2.7929e-01, 3.1569e-01,
          3.6741e-01, 4.3149e-01, 5.0499e-01, 5.8314e-01, 6.6704e-01, 7.5120e-01,
          8.2261e-01, 8.7056e-01, 9.0827e-01, 9.3930e-01, 9.6410e-

{'input': tensor([[5.0000e+01, 2.0000e+01, 1.0000e+00, 3.0000e+00, 1.0000e+00, 3.5332e-01,
          5.8339e-01, 0.0000e+00, 2.0355e-02, 3.9034e-02, 5.6774e-02, 7.3916e-02,
          9.0619e-02, 1.0370e-01, 1.1760e-01, 1.3244e-01, 1.4792e-01, 1.6376e-01,
          1.8100e-01, 1.9977e-01, 2.2143e-01, 2.4665e-01, 2.7929e-01, 3.1569e-01,
          3.6741e-01, 4.3149e-01, 5.0499e-01, 5.8314e-01, 6.6704e-01, 7.5120e-01,
          8.2261e-01, 8.7056e-01, 9.0827e-01, 9.3930e-01, 9.6410e-01, 9.8411e-01,
          1.0000e+00],
         [2.0000e+01, 1.0000e+01, 2.0000e+00, 1.0000e+00, 1.0000e+00, 1.5124e-01,
          6.6940e-01, 0.0000e+00, 2.0355e-02, 3.9034e-02, 5.6774e-02, 7.3916e-02,
          9.0619e-02, 1.0370e-01, 1.1760e-01, 1.3244e-01, 1.4792e-01, 1.6376e-01,
          1.8100e-01, 1.9977e-01, 2.2143e-01, 2.4665e-01, 2.7929e-01, 3.1569e-01,
          3.6741e-01, 4.3149e-01, 5.0499e-01, 5.8314e-01, 6.6704e-01, 7.5120e-01,
          8.2261e-01, 8.7056e-01, 9.0827e-01, 9.3930e-01, 9.6410e-

{'input': tensor([[2.0000e+01, 1.0000e+01, 1.0000e+00, 2.0000e+00, 1.0000e+00, 3.5033e-01,
          5.0128e-01, 0.0000e+00, 2.0355e-02, 3.9034e-02, 5.6774e-02, 7.3916e-02,
          9.0619e-02, 1.0370e-01, 1.1760e-01, 1.3244e-01, 1.4792e-01, 1.6376e-01,
          1.8100e-01, 1.9977e-01, 2.2143e-01, 2.4665e-01, 2.7929e-01, 3.1569e-01,
          3.6741e-01, 4.3149e-01, 5.0499e-01, 5.8314e-01, 6.6704e-01, 7.5120e-01,
          8.2261e-01, 8.7056e-01, 9.0827e-01, 9.3930e-01, 9.6410e-01, 9.8411e-01,
          1.0000e+00],
         [1.0000e+01, 3.0000e+00, 3.0000e+00, 1.0000e+00, 1.0000e+00, 2.0017e-01,
          7.9981e-01, 0.0000e+00, 2.0355e-02, 3.9034e-02, 5.6774e-02, 7.3916e-02,
          9.0619e-02, 1.0370e-01, 1.1760e-01, 1.3244e-01, 1.4792e-01, 1.6376e-01,
          1.8100e-01, 1.9977e-01, 2.2143e-01, 2.4665e-01, 2.7929e-01, 3.1569e-01,
          3.6741e-01, 4.3149e-01, 5.0499e-01, 5.8314e-01, 6.6704e-01, 7.5120e-01,
          8.2261e-01, 8.7056e-01, 9.0827e-01, 9.3930e-01, 9.6410e-

Raw prediction:
tensor([[0.4005, 0.4107, 0.4226,  ..., 0.9592, 0.9769, 0.9923],
        [0.3088, 0.3273, 0.3483,  ..., 0.9209, 0.9345, 0.9478],
        [0.2433, 0.2579, 0.2724,  ..., 0.9777, 0.9955, 1.0104],
        ...,
        [0.2473, 0.2629, 0.2797,  ..., 0.9910, 1.0040, 1.0151],
        [0.2011, 0.2237, 0.2454,  ..., 0.9930, 1.0088, 1.0214],
        [0.0472, 0.0694, 0.0906,  ..., 0.9958, 1.0113, 1.0232]])
Scaled-back prediction:
tensor([[6.8160, 6.8552, 6.9007,  ..., 8.9613, 9.0292, 9.0881],
        [6.4638, 6.5349, 6.6156,  ..., 8.8139, 8.8663, 8.9173],
        [6.2125, 6.2684, 6.3242,  ..., 9.0322, 9.1006, 9.1578],
        ...,
        [6.2277, 6.2877, 6.3520,  ..., 9.0834, 9.1334, 9.1756],
        [6.0504, 6.1372, 6.2203,  ..., 9.0909, 9.1517, 9.2001],
        [5.4594, 5.5447, 5.6261,  ..., 9.1017, 9.1611, 9.2068]],
       dtype=torch.float64)
