# Example script: loading ViscoNet models

This demo shows how to load a trained ViscoNet model (its `config.yml` and `model.pt` weights) from the artifacts directory and use it for inference on example data.

In [1]:
import yaml
import torch
from model import SimpleANN
from utils import VEDatasetV2
from torch.utils.data import DataLoader

In [2]:
# Select which trained artifact to load (alternatives noted inline).
experiment = "thermoplastic2thermoplastic"  # alt: "thermoplastic2thermoset"
ve_response = "storage_modulus"  # alt: "tan_delta"
model_code = "VE256WDNN3"  # alt: "VE128np", "VE128", "VE256", "VE256WDNN5"
# Artifacts live under artifacts/<experiment>/<ve_response>/<model_code>/.
base_path = "/".join(["artifacts", experiment, ve_response, model_code])

In [3]:
# Load the hyper-parameters saved alongside the trained weights.
with open(f"{base_path}/config.yml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
# Build the network from that same config so its layer sizes match the checkpoint.
model = SimpleANN(**config).to(torch.device("cpu"))
# Load the saved state dict onto the CPU. weights_only=True keeps torch.load from
# unpickling arbitrary Python objects; a state dict only needs tensors and plain
# containers. (Requires PyTorch >= 1.13.)
model_wt = torch.load(f"{base_path}/model.pt", map_location="cpu", weights_only=True)
model.load_state_dict(model_wt)
# Switch to evaluation mode so the Dropout layers are inactive during inference.
# As the cell's last expression, this also displays the model architecture.
model.eval()

SimpleANN(
  (activation): GELU()
  (linear_relu_stack_with_dropout): Sequential(
    (0): Linear(in_features=37, out_features=256, bias=True)
    (1): GELU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=256, out_features=256, bias=True)
    (4): GELU()
    (5): Dropout(p=0.2, inplace=False)
  )
  (last_hidden): Linear(in_features=259, out_features=30, bias=True)
)

In [4]:
# DataLoader settings; num_workers=0 loads data in the main process.
BATCH_SIZE, NUM_WORKERS = 10, 0
# Example input file. The notebook lists "example_tand.json" as the alternative,
# presumably to be used with ve_response = "tan_delta" (TODO confirm).
data_files = ["example_data/example_ep.json"]
# The dataset is constructed from the same config used to build the model.
dataset = VEDatasetV2(data_files, **config)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

In [5]:
# Run the model over EVERY batch in the loader and stack the outputs row-wise.
# torch.no_grad() skips autograd bookkeeping, which inference does not need.
with torch.no_grad():
    results = torch.cat([model(batch["input"]) for batch in dataloader])
# The model emits raw (scaled) predictions that must be mapped back.
print("Raw prediction:")
print(results)
# Undo the scaling with the dataset's own routine.
# NOTE(review): ve_id=0 is hard-coded; confirm which response it selects inside
# VEDatasetV2.scale_back and whether it must change with ve_response.
scaled_back = dataset.scale_back(results, ve_id=0)
print("Scaled-back prediction:")
print(scaled_back)

Raw prediction:
tensor([[0.4003, 0.4125, 0.4267,  ..., 0.9740, 0.9914, 1.0063],
        [0.2918, 0.3109, 0.3323,  ..., 0.9252, 0.9379, 0.9505],
        [0.2371, 0.2515, 0.2661,  ..., 0.9846, 1.0037, 1.0205],
        ...,
        [0.2419, 0.2574, 0.2735,  ..., 0.9848, 0.9980, 1.0090],
        [0.1950, 0.2166, 0.2377,  ..., 0.9940, 1.0092, 1.0216],
        [0.0493, 0.0713, 0.0924,  ..., 1.0010, 1.0162, 1.0281]])
Scaled-back prediction:
tensor([[6.8151, 6.8620, 6.9167,  ..., 9.0179, 9.0848, 9.1421],
        [6.3986, 6.4721, 6.5541,  ..., 8.8305, 8.8794, 8.9278],
        [6.1886, 6.2439, 6.2999,  ..., 9.0586, 9.1322, 9.1964],
        ...,
        [6.2071, 6.2666, 6.3284,  ..., 9.0594, 9.1101, 9.1526],
        [6.0271, 6.1097, 6.1907,  ..., 9.0946, 9.1533, 9.2008],
        [5.4677, 5.5518, 5.6331,  ..., 9.1218, 9.1800, 9.2258]],
       dtype=torch.float64)
