## Imports

In [1]:
!git clone https://github.com/rmnigm/qber-forecasting.git
!git clone https://github.com/thuml/Autoformer.git

Cloning into 'qber-forecasting'...
remote: Enumerating objects: 456, done.[K
remote: Counting objects: 100% (280/280), done.[K
remote: Compressing objects: 100% (211/211), done.[K
remote: Total 456 (delta 141), reused 177 (delta 58), pack-reused 176[K
Receiving objects: 100% (456/456), 47.13 MiB | 15.12 MiB/s, done.
Resolving deltas: 100% (220/220), done.
Updating files: 100% (29/29), done.
Cloning into 'Autoformer'...
remote: Enumerating objects: 371, done.[K
remote: Counting objects: 100% (204/204), done.[K
remote: Compressing objects: 100% (60/60), done.[K
remote: Total 371 (delta 157), reused 151 (delta 144), pack-reused 167[K
Receiving objects: 100% (371/371), 2.20 MiB | 5.94 MiB/s, done.
Resolving deltas: 100% (220/220), done.


In [2]:
import collections
import math
import os
import pathlib
import random


import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import seaborn as sns
import scipy.stats as sps
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_percentage_error

import torch
from torch import nn, Tensor
from torch.utils.data import Dataset, DataLoader
from tempfile import TemporaryDirectory

from tqdm.notebook import tqdm

In [3]:
def seed_everything(seed: int) -> None:
    """Fix all the random seeds we can for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Datasets

In [6]:
qber_path = pathlib.Path('qber-forecasting')

In [7]:
dataframe = pl.read_csv(qber_path / 'datasets' / 'data.csv')
dataframe.head()

id,e_mu_current,e_mu_estimated,e_nu_1,e_nu_2,q_mu,q_nu1,q_nu2
i64,f64,f64,f64,f64,f64,f64,f64
1506053531,0.01298,0.01164,0.01904,0.17794,0.550377,0.164911,0.008094
1506053531,0.01283,0.00961,0.01672,0.20868,0.564295,0.167629,0.006639
1506053531,0.01268,0.0059,0.01337,0.20442,0.564179,0.16411,0.007052
1506053531,0.01129,0.00988,0.01637,0.18453,0.573555,0.167174,0.006663
1506053531,0.01169,0.01338,0.01783,0.11478,0.569296,0.169658,0.006823


In [8]:
class PolarsDataset(Dataset):
    def __init__(self,
                 data_path: str | pathlib.Path,
                 window_size: int,
                 data_format: str = 'csv',
                 dtype: torch.dtype = torch.float32,
                 columns: list[str] | None = None,
                 device: torch.device = None,
                 offset: int | float = None,
                 limit: int | float = None
                 ):
        assert window_size is not None, data_path is not None
        self.data_format = data_format
        self.window_size = window_size
        self.build_dataset(data_path, offset, limit, columns)
        self.device = device or torch.device('cpu')
        self.dtype = dtype

    def build_dataset(self, data_path, offset, limit, columns) -> None:
        assert self.data_format in ('csv', 'parquet')
        if self.data_format == 'csv':
          dataframe = pl.scan_csv(data_path)
        elif self.data_format == 'parquet':
          dataframe = pl.scan_parquet(data_path)
        length = dataframe.select(pl.count()).collect().item()
        offset, limit = self.calculate_offset_limit(offset, limit, length)
        columns = columns or dataframe.columns
        dataframe = (
            dataframe
            .select(columns)
            .slice(offset, limit)
        )
        self.data_array = dataframe.collect().to_numpy()
        self.dataset = np.lib.stride_tricks.sliding_window_view(
            self.data_array,
            self.window_size + 1,
            axis=0
            )
        self.shape = self.dataset.shape[1:]

    @staticmethod
    def calculate_offset_limit(offset, limit, length) -> tuple[int, int]:
        if offset is None:
          offset = 0
        else:
          offset = offset if offset >= 1 else int(offset * length)
        if limit is None:
          limit = length
        else:
          limit = limit if limit >= 1 else int(limit * length)
        return offset, limit


    def __len__(self) -> int:
        return len(self.dataset)

    def __repr__(self) -> str:
        return f'PolarsDataset(len={self.__len__()})'

    def __getitem__(self, idx) -> tuple[Tensor, Tensor]:
        x = torch.tensor(self.dataset[idx][:, :-1].T, dtype=self.dtype).to(self.device)
        y = torch.tensor(self.dataset[idx][:, -1], dtype=self.dtype).to(self.device)
        return x, y

In [9]:
columns = [
    'e_mu_current',
    'e_nu_1',
    'e_nu_2',
    'q_mu',
    'q_nu1',
    'q_nu2'
]

In [55]:
train_dataset = PolarsDataset(qber_path / 'datasets' / 'data.csv',
                              window_size=30,
                              columns=columns,
                              device=device,
                              limit=0.75)
test_dataset = PolarsDataset(qber_path / 'datasets' / 'data.csv',
                             window_size=30,
                             columns=columns,
                             device=device,
                             offset=0.75)


batch_size = 256
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=batch_size
    )
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    shuffle=True,
    batch_size=batch_size
    )

## Autoformer Ideas

In [None]:
import torch
import torch.nn as nn
import math


class AutoCorrelation(nn.Module):
    """
    AutoCorrelation Mechanism with the following two phases:
    (1) period-based dependencies discovery
    (2) time delay aggregation
    This block can replace the self-attention family mechanism seamlessly.
    """
    def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False):
        super(AutoCorrelation, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def time_delay_agg_training(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the training phase.
        """
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # find top k
        top_k = int(self.factor * math.log(length))
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
        weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            pattern = torch.roll(tmp_values, -int(index[i]), -1)
            delays_agg = delays_agg + pattern * \
                         (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
        return delays_agg

    def time_delay_agg_inference(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the inference phase.
        """
        batch = values.shape[0]
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # index init
        init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0)\
            .repeat(batch, head, channel, 1).to(values.device)
        # find top k
        top_k = int(self.factor * math.log(length))
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        weights, delay = torch.topk(mean_value, top_k, dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * \
                         (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
        return delays_agg

    def time_delay_agg_full(self, values, corr):
        """
        Standard version of Autocorrelation
        """
        batch = values.shape[0]
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # index init
        init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0)\
            .repeat(batch, head, channel, 1).to(values.device)
        # find top k
        top_k = int(self.factor * math.log(length))
        weights, delay = torch.topk(corr, top_k, dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[..., i].unsqueeze(-1)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))
        return delays_agg

    def forward(self, queries, keys, values, attn_mask):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        if L > S:
            zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
            values = torch.cat([values, zeros], dim=1)
            keys = torch.cat([keys, zeros], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]

        # period-based dependencies
        q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
        k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
        res = q_fft * torch.conj(k_fft)
        corr = torch.fft.irfft(res, n=L, dim=-1)

        # time delay agg
        if self.training:
            V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
        else:
            V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)

        if self.output_attention:
            return (V.contiguous(), corr.permute(0, 3, 1, 2))
        else:
            return (V.contiguous(), None)


class AutoCorrelationLayer(nn.Module):
    def __init__(self, correlation, d_model, n_heads, d_keys=None,
                 d_values=None):
        super(AutoCorrelationLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_correlation = correlation
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, attn = self.inner_correlation(
            queries,
            keys,
            values,
            attn_mask
        )
        out = out.view(B, L, -1)

        return self.out_projection(out), attn

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class my_Layernorm(nn.Module):
    """
    Special designed layernorm for the seasonal part
    """
    def __init__(self, channels):
        super(my_Layernorm, self).__init__()
        self.layernorm = nn.LayerNorm(channels)

    def forward(self, x):
        x_hat = self.layernorm(x)
        bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1)
        return x_hat - bias


class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series
    """
    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # padding on the both ends of time series
        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        x = torch.cat([front, x, end], dim=1)
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x


class series_decomp(nn.Module):
    """
    Series decomposition block
    """
    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        moving_mean = self.moving_avg(x)
        res = x - moving_mean
        return res, moving_mean


class EncoderLayer(nn.Module):
    """
    Autoformer encoder layer with the progressive decomposition architecture
    """
    def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):
        new_x, attn = self.attention(
            x, x, x,
            attn_mask=attn_mask
        )
        x = x + self.dropout(new_x)
        x, _ = self.decomp1(x)
        y = x
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))
        res, _ = self.decomp2(x + y)
        return res, attn


class Encoder(nn.Module):
    """
    Autoformer encoder
    """
    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        attns = []
        if self.conv_layers is not None:
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                x, attn = attn_layer(x, attn_mask=attn_mask)
                x = conv_layer(x)
                attns.append(attn)
            x, attn = self.attn_layers[-1](x)
            attns.append(attn)
        else:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask)
                attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns

In [None]:
import dataclasses

In [None]:
@dataclasses.dataclass
class Config:
    factor: int
    dropout: float
    output_attention: bool
    n_heads: int
    d_model: int
    d_hid: int
    kernel_size_mavg: int
    e_layers: int
    activation: str
    d_window: int

In [None]:
config = Config(
    factor=1,
    dropout=0.1,
    output_attention=False,
    n_heads=2,
    d_model=6,
    d_hid=128,
    kernel_size_mavg=9,
    e_layers=4,
    activation='gelu',
    d_window=30,
)

In [None]:
encoder = Encoder(
    [
        EncoderLayer(
            AutoCorrelationLayer(
                    AutoCorrelation(False, config.factor,
                                    attention_dropout=config.dropout,
                                    output_attention=config.output_attention
                                    ),
                    config.d_model,
                    config.n_heads),
                config.d_model,
                config.d_hid,
                moving_avg=config.kernel_size_mavg,
                dropout=config.dropout,
                activation=config.activation
            ) for l in range(config.e_layers)
        ],
        norm_layer=my_Layernorm(config.d_model)
    ).to(device)

In [None]:
class AutoCorrelationEncoder(nn.Module):
    def __init__(self,
                 config: Config,
                 device: torch.device) -> None:
      super().__init__()
      self.d_model = config.d_model
      self.d_window = config.d_window
      self.device = device
      self.encoder = Encoder(
          [
              EncoderLayer(
              AutoCorrelationLayer(
                  AutoCorrelation(False,
                                  config.factor,
                                  attention_dropout=config.dropout,
                                  output_attention=config.output_attention
                                  ),
                  config.d_model,
                  config.n_heads),
              config.d_model,
              config.d_hid,
              moving_avg=config.kernel_size_mavg,
              dropout=config.dropout,
              activation=config.activation
              ) for l in range(config.e_layers)
          ],
          norm_layer=my_Layernorm(config.d_model)
          ).to(self.device)
      self.mlp = nn.Linear(config.d_model * config.d_window, 1).to(self.device)

      self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.mlp.bias.data.zero_()
        self.mlp.weight.data.uniform_(-initrange, initrange)

    def make_mask(self, length):
        nn.Transformer.generate_square_subsequent_mask(length).to(self.device)


    def forward(self, x: Tensor) -> Tensor:
        mask = self.make_mask(length=x.shape[1])
        outp = self.encoder(x, mask)[0]
        outp = outp.view(-1, self.d_window * self.d_model)
        outp = self.mlp(outp)
        return outp.flatten()

## Transformer

In [125]:
class SimpleTransformer(nn.Module):
    def __init__(self,
                 d_model: int,
                 d_window: int,
                 nhead: int,
                 d_hid: int,
                 nlayers: int,
                 dropout: float = 0.1) -> None:
      super().__init__()
      encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, d_hid, dropout, batch_first=True)
      self.transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers)
      self.d_input = d_model
      self.mlp = nn.Linear(d_model, d_model)
      self.linear = nn.Linear(d_window, d_model)
      self.act = nn.ReLU()
      self.final = nn.Linear(2 * d_model, 1)

      self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        for layer in (self.mlp, self.linear, self.final):
            layer.bias.data.zero_()
            layer.weight.data.uniform_(-initrange, initrange)


    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
        src_mask = src_mask or nn.Transformer.generate_square_subsequent_mask(
            src.shape[1]
            ).to(device)
        linear_outp = self.act(self.linear(src[:, :, 0]))
        attn_outp = self.transformer_encoder(src, src_mask)
        mlp_outp = self.act(self.mlp(attn_outp)[:, -1])
        features = torch.cat((mlp_outp, linear_outp), axis=1)
        output = self.final(features)
        return output

## Training

In [126]:
import time

def train(model: nn.Module, train_loader, epoch) -> None:
    model.train()  # turn on train mode
    total_loss = 0.
    log_interval = 100
    start_time = time.time()

    for i, batch in enumerate(train_loader):
        x, y = batch
        output = model(x).flatten()
        targets = y[:, 0]
        loss = criterion(output, targets)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        if i % log_interval == 0 and i > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            cur_loss *= 10e6
            print(f'epoch {epoch:3d} | {i:5d}/{len(train_loader):5d} batches | '
                  f'lr {lr:2.5f} | ms/batch {ms_per_batch:5.2f} | '
                  f'loss {cur_loss:0.4f} * 10^(-6)')
            total_loss = 0
            start_time = time.time()

def evaluate(model: nn.Module, test_loader) -> float:
    model.eval()  # turn on evaluation mode
    total_loss = 0.
    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            x, y = batch
            output = model(x).flatten()
            targets = y[:, 0]
            total_loss += criterion(output, targets).item()
    return total_loss / (len(test_loader))

In [133]:
seed_everything(123456)

model = SimpleTransformer(
    d_model=train_dataset.shape[0],
    d_window=30,
    nhead=2,
    d_hid=512,
    nlayers=2
    ).to(device)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.9)

In [134]:
best_val_loss = float('inf')
epochs = 5


with TemporaryDirectory() as tempdir:
    best_model_params_path = os.path.join(tempdir, "best_model_params.pt")

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        train(model, train_loader, epoch)
        val_loss = evaluate(model, test_loader)
        elapsed = time.time() - epoch_start_time
        print('-' * 89)
        print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
            f'valid loss {10e6 * val_loss:0.4f} * 10^(-6)')
        print('-' * 89)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), best_model_params_path)

        scheduler.step()
    model.load_state_dict(torch.load(best_model_params_path)) # load best model states

epoch   1 |   100/  542 batches | lr 0.00100 | ms/batch 24.42 | loss 146.7478 * 10^(-6)
epoch   1 |   200/  542 batches | lr 0.00100 | ms/batch 22.61 | loss 152.0607 * 10^(-6)
epoch   1 |   300/  542 batches | lr 0.00100 | ms/batch 26.62 | loss 151.5342 * 10^(-6)
epoch   1 |   400/  542 batches | lr 0.00100 | ms/batch 23.49 | loss 174.3120 * 10^(-6)
epoch   1 |   500/  542 batches | lr 0.00100 | ms/batch 21.79 | loss 92.2760 * 10^(-6)
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 16.31s | valid loss 463.7904 * 10^(-6)
-----------------------------------------------------------------------------------------
epoch   2 |   100/  542 batches | lr 0.00090 | ms/batch 25.62 | loss 145.7522 * 10^(-6)
epoch   2 |   200/  542 batches | lr 0.00090 | ms/batch 26.16 | loss 68.2979 * 10^(-6)
epoch   2 |   300/  542 batches | lr 0.00090 | ms/batch 22.53 | loss 57.6073 * 10^(-6)
epoch   2 |   400/  542 batches | lr 0.00090 | ms/bat

In [135]:
targets, predictions = [], []

model.eval()
for x, y in test_loader:
    predictions += list(model(x).flatten().cpu().detach().numpy())
    targets += list(y[:, 0].cpu().numpy())

In [137]:
mse = mean_squared_error(targets, predictions)
mape = mean_absolute_percentage_error(targets, predictions)
r2 = r2_score(targets, predictions)

print(f'MAPE = {mape:5.5f} | MSE = {mse * 1e6:.4f} * 10^(-6) | MSE improved: {mse < 0.00001518} | R^2 = {r2:5.3f}')

MAPE = 0.10723 | MSE = 23.8087 * 10^(-6) | MSE improved: False | R^2 = 0.654


## Interpretation

In [None]:
!pip install captum

In [138]:
from captum.attr import (
    GradientShap,
    DeepLift,
    DeepLiftShap,
    IntegratedGradients,
    LayerConductance,
    NeuronConductance,
    NoiseTunnel,
)

from pprint import pprint

In [139]:
x, _ = next(iter(test_loader))
input = x
baseline = torch.zeros_like(input)

In [150]:
model = model.to(torch.device('cpu'))
input = input.to(torch.device('cpu'))
baseline = baseline.to(torch.device('cpu'))

In [None]:
ig = IntegratedGradients(model)
nt = NoiseTunnel(ig)
attributions, delta = nt.attribute(input,
                                   nt_type='smoothgrad',
                                   stdevs=0.02,
                                   nt_samples=4,
                                   baselines=baseline,
                                   target=0,
                                   return_convergence_delta=True
                                   )

In [None]:
print(f'Maximum convergence delta: {torch.max(torch.abs(delta))}')

In [None]:
attributions.mean(axis=0).T