In [1]:
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import Parameter as TorchParam
from torch import Tensor
from typing import List, Tuple

In [2]:
# All tensors and the model live on the CPU in this notebook.
device = torch.device("cpu")

# Temperatures the model predicts (these also form its hidden state).
target_cols = ["pm", "stator_yoke", "stator_tooth", "stator_winding"]

# Every temperature node: the targets plus the two temperatures that are model inputs.
temperature_cols = [*target_cols, "ambient", "coolant"]

# Presumably the raw data set columns -- not referenced by the other visible cells.
data_columns = [
    "u_q",
    "coolant",
    "stator_winding",
    "u_d",
    "stator_tooth",
    "motor_speed",
    "i_d",
    "i_q",
    "pm",
    "stator_yoke",
    "ambient",
    "torque",
    "profile_id",
]

# Model input features, in the column order used to slice the data tensors.
input_cols = [
    "u_q",
    "coolant",
    "u_d",
    "motor_speed",
    "i_d",
    "i_q",
    "ambient",
    "torque",
    "i_s",
    "u_s",
]

In [3]:
class DiffEqLayer(nn.Module):
    """Unrolls a single-step cell along the first dimension of a sequence.

    The wrapped cell only has to map ``(input_t, state)`` to
    ``(output_t, new_state)``, so the same wrapper could drive other
    recurrent cells (RNN, LSTM, GRU) as well.
    """

    def __init__(self, cell, *cell_args):
        super().__init__()
        # `cell` is a class (or factory); it is instantiated here with `cell_args`
        self.cell = cell(*cell_args)

    def forward(self, input: Tensor, state: Tensor) -> Tuple[Tensor, Tensor]:
        """Feed every slice along dim 0 to the cell, carrying the state forward.

        Returns the per-step outputs stacked along a new leading dimension
        together with the state after the last step.
        """
        step_outputs: List[Tensor] = []
        for step_input in input.unbind(0):
            step_out, state = self.cell(step_input, state)
            step_outputs.append(step_out)
        return torch.stack(step_outputs), state

In [4]:
# number of temperature nodes, and of unordered node pairs (one conductance per pair)
n_temps = len(temperature_cols)
n_conds = n_temps * (n_temps - 1) // 2


In [5]:
# number of predicted temperatures: one per entry of target_cols
output_size = len(target_cols)

In [6]:
# Scratch copy of the conductance sub-NN: a single linear layer over the
# inputs plus the previous outputs, squashed into (0, 1) by a sigmoid.
conductance_net = nn.Sequential(
    nn.Linear(in_features=len(input_cols) + output_size, out_features=n_conds),
    nn.Sigmoid(),
)
# shape check on a random batch of 10 rows
conductance_net(torch.randn((10, len(input_cols) + output_size))).shape

torch.Size([10, 15])

In [7]:
# Scratch copy of the power-loss sub-NN: one hidden layer of 16 tanh units,
# mapping inputs plus previous outputs to one value per target.
ploss = nn.Sequential(
    nn.Linear(in_features=len(input_cols) + output_size, out_features=16),
    nn.Tanh(),
    nn.Linear(in_features=16, out_features=output_size),
)
ploss

Sequential(
  (0): Linear(in_features=14, out_features=16, bias=True)
  (1): Tanh()
  (2): Linear(in_features=16, out_features=4, bias=True)
)

In [8]:
# push a random batch of 10 rows through the scratch power-loss net
x = torch.randn((10, len(input_cols) + output_size))
ploss(x)

tensor([[ 0.1954,  0.0485, -0.0908,  0.5545],
        [-0.0385, -0.2710, -0.1677,  0.1686],
        [-0.1077, -0.1104, -0.1087,  0.2152],
        [-0.3682, -0.1495,  0.2252, -0.0116],
        [-0.2821, -0.4265,  0.1511,  0.1539],
        [ 0.3642, -0.1110, -0.3849, -0.4641],
        [ 0.3801,  0.0870, -0.4116,  0.1856],
        [ 0.0149, -0.1003, -0.1166, -0.0596],
        [-0.4610, -0.2944,  0.3059,  0.2666],
        [-0.4504, -0.2748,  0.1927,  0.2906]], grad_fn=<AddmmBackward0>)

In [9]:
# display the model input features in column order
input_cols

['u_q',
 'coolant',
 'u_d',
 'motor_speed',
 'i_d',
 'i_q',
 'ambient',
 'torque',
 'i_s',
 'u_s']

In [10]:
# positions within input_cols of the features that are temperatures
temp_idcs = [idx for idx, col in enumerate(input_cols) if col in temperature_cols]
temp_idcs

[1, 6]

In [11]:
# positions within input_cols of the features that are temperatures
# NOTE(review): this cell recomputes the value the preceding cell already
# produced (same output); it looks like a leftover re-run and could be dropped.
temp_idcs = [i for i, x in enumerate(input_cols) if x in temperature_cols]
temp_idcs

[1, 6]

In [12]:
# positions within input_cols of the features that are neither a temperature
# nor the profile identifier
nontemp_idcs = [
    idx
    for idx, col in enumerate(input_cols)
    if col != "profile_id" and col not in temperature_cols
]
nontemp_idcs

[0, 2, 3, 4, 5, 7, 8, 9]

In [13]:
class TNNCell(nn.Module):
    """Single-step cell holding the main TNN logic.

    The constructor creates the two sub-NNs (conductances and power losses) and the
    learnable per-target parameter ``caps`` (the thermal capacitance term).
    ``forward`` advances the LPTN ODE by one step of the explicit Euler method.

    The class reads the notebook-level names ``target_cols``, ``temperature_cols``,
    ``input_cols`` and ``device``; they must be defined before instantiation.
    """

    def __init__(self):
        """Build parameters, sub-networks and the conductance index matrix."""
        super().__init__()
        self.sample_time = 0.5  # in s; step width of the Euler update in forward()
        self.output_size = len(target_cols)
        # torch.Tensor(n) allocates uninitialized memory; the values are set by normal_ below.
        # forward() uses exp(caps), so the effective factor is always positive.
        self.caps = TorchParam(torch.Tensor(self.output_size).to(device))
        nn.init.normal_(
            self.caps, mean=-9.2, std=0.5
        )  # hand-picked init mean, might be application-dependent (exp(-9.2) is about 1e-4)
        n_temps = len(temperature_cols)  # number of temperatures (targets and input)
        # one conductance per unordered pair of temperature nodes
        n_conds = int(0.5 * n_temps * (n_temps - 1))  # number of thermal conductances
        # conductance net sub-NN: inputs plus previous outputs -> n_conds values in (0, 1)
        self.conductance_net = nn.Sequential(
            nn.Linear(len(input_cols) + self.output_size, n_conds), nn.Sigmoid()
        )
        # populate adjacency matrix. It is used for indexing the conductance sub-NN output
        self.adj_mat = np.zeros((n_temps, n_temps), dtype=int)
        adj_idx_arr = np.ones_like(self.adj_mat)
        triu_idx = np.triu_indices(n_temps, 1)  # strict upper triangle (no diagonal)
        adj_idx_arr = adj_idx_arr[triu_idx].ravel()
        # enumerate the upper-triangle entries 0 .. n_conds - 1
        self.adj_mat[triu_idx] = np.cumsum(adj_idx_arr) - 1
        # mirror to the lower triangle so (i, j) and (j, i) share one conductance index;
        # the diagonal stays 0
        self.adj_mat += self.adj_mat.T
        # keep only the rows of the first output_size nodes -> shape (output_size, n_temps)
        self.adj_mat = torch.from_numpy(self.adj_mat[: self.output_size, :]).type(
            torch.int64
        )  # crop
        # NOTE(review): adj_mat is a plain tensor attribute, not a registered buffer,
        # so Module.to(...) will not move it.
        self.n_temps = n_temps

        # power loss sub-NN: inputs plus previous outputs -> one value per target
        self.ploss = nn.Sequential(
            nn.Linear(len(input_cols) + self.output_size, 16),
            nn.Tanh(),
            nn.Linear(16, self.output_size),
        )

        # positions within input_cols of the features that are temperatures
        self.temp_idcs = [i for i, x in enumerate(input_cols) if x in temperature_cols]
        # positions of all remaining input features (not read by forward() below)
        self.nontemp_idcs = [
            i
            for i, x in enumerate(input_cols)
            if x not in temperature_cols + ["profile_id"]
        ]

    def forward(self, inp: Tensor, hidden: Tensor) -> Tuple[Tensor, Tensor]:
        """Advance the temperature state by one explicit-Euler step.

        Args:
            inp: input features of the current step, indexed as a 2-D tensor
                (rows, features); columns are assumed to follow ``input_cols``.
            hidden: state from the previous step, concatenated with ``inp``
                along dim 1 (presumably (rows, output_size) -- confirm with caller).

        Returns:
            A tuple ``(output, new_state)``: the output is the state that was
            passed in, the new state is the Euler update clipped to [-1, 5].
        """
        prev_out = hidden
        # all temperature nodes: the state entries first, then the input temperatures
        temps = torch.cat([prev_out, inp[:, self.temp_idcs]], dim=1)
        sub_nn_inp = torch.cat([inp, prev_out], dim=1)
        # NOTE(review): the net already ends in a Sigmoid, so abs() is a no-op here
        conducts = torch.abs(self.conductance_net(sub_nn_inp))
        # abs() forces the power-loss term to be non-negative
        power_loss = torch.abs(self.ploss(sub_nn_inp))
        # for every state node: sum over all nodes of (T_other - T_node) * conductance;
        # the diagonal term vanishes because its temperature difference is zero
        temp_diffs = torch.sum(
            (temps.unsqueeze(1) - prev_out.unsqueeze(-1)) * conducts[:, self.adj_mat],
            dim=-1,
        )
        out = prev_out + self.sample_time * torch.exp(self.caps) * (
            temp_diffs + power_loss
        )
        # the step output is the incoming state; the clipped update becomes the next state
        # (clip bounds are presumably in normalized units -- predictions are multiplied
        # by 200 in the evaluation cell; confirm)
        return prev_out, torch.clip(out, -1, 5)


In [14]:
# wrap a TNNCell in the sequence layer, move it to `device` and compile it with TorchScript
model = torch.jit.script(DiffEqLayer(TNNCell).to(device))
model

RecursiveScriptModule(
  original_name=DiffEqLayer
  (cell): RecursiveScriptModule(
    original_name=TNNCell
    (conductance_net): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(original_name=Linear)
      (1): RecursiveScriptModule(original_name=Sigmoid)
    )
    (ploss): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(original_name=Linear)
      (1): RecursiveScriptModule(original_name=Tanh)
      (2): RecursiveScriptModule(original_name=Linear)
    )
  )
)

In [25]:
# Load the pre-built data tensors and their per-sample weights.
# The first dimension is unrolled as the sequence by the model; the last dimension holds
# the input features first and the targets last (see the slicing in the cells below).
# torch.load unpickles the files -- only use it on files you trust.
train_tensor, train_sample_weights, test_tensor, test_sample_weights = (
    torch.load(file_name)
    for file_name in (
        "train_tensor.pt",
        "train_sample_weights.pt",
        "test_tensor.pt",
        "test_sample_weights.pt",
    )
)
train_tensor.shape, test_tensor.shape

(torch.Size([43971, 66, 14]), torch.Size([25600, 3, 14]))

In [26]:
# training hyper-parameters
n_epochs = 100
tbptt_size = 512  # sequence steps per chunk; the state is detached between chunks
n_batches = np.ceil(len(train_tensor) / tbptt_size).astype(int)

# element-wise loss, because sample weights are applied before reducing
loss_func = nn.MSELoss(reduction="none")
opt = optim.Adam(model.parameters(), lr=0.001)

In [28]:
# display the predicted temperature columns
target_cols

['pm', 'stator_yoke', 'stator_tooth', 'stator_winding']

In [27]:
# initial state: the target columns (last len(target_cols) features) at the first step
hidden = train_tensor[0, :, -len(target_cols) :]

In [29]:
# shape of the input part of the first training chunk
i = 0
train_tensor[i * tbptt_size : (i + 1) * tbptt_size][:, :, : len(input_cols)].shape

torch.Size([512, 66, 10])

In [None]:
n_inputs = len(input_cols)
n_targets = len(target_cols)

with tqdm(desc="Training", total=n_epochs) as pbar:
    for epoch in range(n_epochs):
        # every epoch starts from the ground-truth temperatures of the first step
        hidden = train_tensor[0, :, -n_targets:]

        # walk through the sequence chunk by chunk (the last chunk may be shorter)
        for i in range(n_batches):
            chunk = slice(i * tbptt_size, (i + 1) * tbptt_size)

            model.zero_grad()
            # detach the carried state so gradients stop at the chunk boundary
            output, hidden = model(train_tensor[chunk, :, :n_inputs], hidden.detach())
            loss = loss_func(output, train_tensor[chunk, :, -n_targets:])

            # weight every sample, normalised by the total weight of the chunk
            weights = train_sample_weights[chunk, :]
            loss = (loss * weights[:, :, None] / weights.sum()).sum().mean()

            loss.backward()
            opt.step()

        # halve the learning rate once, after epoch index 75
        if epoch == 75:
            for group in opt.param_groups:
                group["lr"] *= 0.5
        pbar.update()
        # the reported value is the loss of the last chunk of the epoch
        pbar.set_postfix_str(f"loss: {loss.item():.2e}")

Training: 100%|██████████| 100/100 [35:22<00:00, 21.22s/it, loss: 4.16e-04]


In [None]:
# models are stored under <current working directory>/data/models
mdl_path = Path.cwd().joinpath("data", "models")
mdl_path.mkdir(parents=True, exist_ok=True)
mdl_file_path = mdl_path.joinpath("tnn_jit_torch.pt")

# round-trip the scripted model through disk, then switch it to eval mode
model.save(mdl_file_path)
model = torch.jit.load(mdl_file_path)
model.eval()

RecursiveScriptModule(
  original_name=DiffEqLayer
  (cell): RecursiveScriptModule(
    original_name=TNNCell
    (conductance_net): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(original_name=Linear)
      (1): RecursiveScriptModule(original_name=Sigmoid)
    )
    (ploss): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(original_name=Linear)
      (1): RecursiveScriptModule(original_name=Tanh)
      (2): RecursiveScriptModule(original_name=Linear)
    )
  )
)

In [34]:
# Run the model over the whole test sequence without tracking gradients,
# starting from the ground-truth temperatures of the first step.
with torch.no_grad():
    test_inputs = test_tensor[:, :, : len(input_cols)]
    initial_state = test_tensor[0, :, -len(target_cols) :]
    pred, hidden = model(test_inputs, initial_state)
    pred = pred.cpu().numpy() * 200  # denormalize
pred

array([[[35.52539  , 38.01852  , 37.8663   , 38.895035 ],
        [23.14468  , 22.525713 , 22.396318 , 23.351286 ],
        [68.55495  , 71.66812  , 72.12173  , 75.44503  ]],

       [[35.527523 , 38.025578 , 37.871094 , 38.89906  ],
        [23.147926 , 22.53506  , 22.40196  , 23.35643  ],
        [68.54954  , 71.660484 , 72.114655 , 75.42669  ]],

       [[35.530346 , 38.032356 , 37.878452 , 38.910778 ],
        [23.15106  , 22.544796 , 22.408064 , 23.363768 ],
        [68.54413  , 71.652885 , 72.107574 , 75.40843  ]],

       ...,

       [[ 7.4406514,  8.139296 ,  7.9792724,  8.642113 ],
        [31.240984 , 28.761036 , 28.543165 , 29.268423 ],
        [52.113033 , 10.481602 , 10.515229 , 13.537882 ]],

       [[ 7.439987 ,  8.139255 ,  7.9792285,  8.642031 ],
        [31.240957 , 28.761036 , 28.54316  , 29.268412 ],
        [52.100952 , 10.480215 , 10.513332 , 13.5350685]],

       [[ 7.4393244,  8.139214 ,  7.979185 ,  8.641949 ],
        [31.24094  , 28.761091 , 28.543186 , 29.2