In [None]:
# Standard library imports
import math
import time

# Data handling and numerical processing
import numpy as np
import pandas as pd

# Machine Learning and Deep Learning
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, mean_absolute_percentage_error

# Keras and TensorFlow
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K

# Plotting and Visualization
import matplotlib.pyplot as plt
from pandas.plotting import register_matplotlib_converters
register_matplotlib_converters()


In [None]:
# Load and preprocess the data

In [None]:
# Load the raw log and keep rows 1000..36999.
df_origin = pd.read_csv('Original file location')  # TODO: replace with the real CSV path
df_main = df_origin[1000:37000]

# Rescale every column to [0, 1].
# NOTE(review): the scaler is fit on all of df_main, including the rows that later
# become the validation split, so validation statistics leak into the scaling.
scaler = MinMaxScaler(feature_range=(0, 1))
df_scaled = scaler.fit_transform(df_main)

column_names = [
    'time', 'slip rate', 'accumlated slip', 'shortening', 'normal stress',
    'shear stree', 'friction', 'temperature', 'short av', 'normal av',
    'shear av', 'friction av', 'temp av', 'slip av', 'energy',
]
df_mm = pd.DataFrame(df_scaled, columns=column_names)

In [None]:
# NOTE(review): 'inp_parameter_label' is a placeholder. It is not one of the df_mm
# column names, so replace it with the input feature to model before running.
scaled_df = df_mm.loc[:, ['inp_parameter_label', 'friction av']]

In [None]:
# Quantise the [0, 1] features into integer tokens 0..1000 (truncation, not rounding)
# so they can be looked up by the Embedding layers.
# NOTE(review): the padding-mask helpers treat token 0 as padding, so the minimum
# value of each column is indistinguishable from padding.
df_inp_origin = scaled_df
df_inp_n = df_inp_origin.mul(1000)
df_inp = df_inp_n.astype(int)
df_inp

In [None]:
data = df_inp
dataset = data.values
# The first 80 % of rows (rounded up) form the training split.
training_data_len = math.ceil(0.8 * len(dataset))
training_data = dataset[:training_data_len, :]

In [None]:
# Slide a window over the training rows. For each end index i:
#   inputs  = rows [i - 15, i - 5) of every column except the last
#   targets = rows [i - 9, i]      of the last column
# Both windows are set_train_time rows long, so they can be stacked side by side.
set_train_time = 10
set_pred_time = 5
window_span = set_train_time + set_pred_time

x_train_or = []
y_train = []
for i in range(window_span, training_data.shape[0]):
    x_train_or.append(training_data[i - window_span:i - set_pred_time, 0:-1])
    y_train.append(training_data[i - (set_train_time - 1):i + 1, -1])

x_train_or = np.array(x_train_or)
y_train = np.array(y_train)

# Append the targets as the last feature column: (windows, set_train_time, n_columns).
y_train_expanded = np.expand_dims(y_train, axis=-1)
train_data = np.concatenate((x_train_or, y_train_expanded), axis=-1)
train_data = train_data.astype(np.int64)

# Expected: (len(training_data) - window_span, set_train_time, n_columns);
# with the two selected columns that is (N - 15, 10, 2).
print(train_data.shape)

In [None]:
def tensor_change(x, y):
    """Return the (inputs, targets) pair as int64 tensors; used inside Dataset.map."""
    return (tf.convert_to_tensor(x, dtype=tf.int64),
            tf.convert_to_tensor(y, dtype=tf.int64))

BUFFER_SIZE, BATCH_SIZE = 1000, 256   # shuffle buffer size / mini-batch size
# Window parameters consumed by the make_batches_* slicing
# (encoder slice 4:train_time+5, decoder slice pred_time-2:-1).
train_time = pred_time = 5

def make_batches_train(train_data):
    """Build the shuffled, batched training pipeline of (encoder tokens, decoder tokens).

    Encoder tokens are window positions 4..train_time+4 of feature column 0; decoder
    tokens are positions pred_time-2 .. second-to-last of the last (target) column.
    """
    enc_tokens = train_data[:, 4:train_time + 5, 0]
    dec_tokens = train_data[:, pred_time - 2:-1, -1]

    pipeline = tf.data.Dataset.from_tensor_slices((enc_tokens, dec_tokens)).cache()
    pipeline = pipeline.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
    pipeline = pipeline.map(tensor_change, num_parallel_calls=tf.data.AUTOTUNE)
    return pipeline.prefetch(tf.data.AUTOTUNE)

train_batches = make_batches_train(train_data)
# Peek at one batch to check shapes and dtypes.
x_train_batch, y_train_batch = next(iter(train_batches))
(x_train_batch, y_train_batch)

In [None]:
# Validation windows use the SAME layout as the training windows (set_train_time /
# set_pred_time), so the make_batches_* slicing (4:10 and 3:-1) works on both.
# Previously train_time/pred_time (5, 5) were used here, which produced 5-row
# windows that the shared slicing could not handle.
window_span = set_train_time + set_pred_time

# Start window_span rows before the split so the first validation window ends at
# row training_data_len, directly after the last training window.
val_source = dataset[training_data_len - window_span:, :]

x_val_or = []
y_val = []
for i in range(window_span, val_source.shape[0]):
    x_val_or.append(val_source[i - window_span:i - set_pred_time, 0:-1])
    y_val.append(val_source[i - (set_train_time - 1):i + 1, -1])

x_val_or = np.array(x_val_or)
y_val = np.array(y_val)

y_val_expanded = np.expand_dims(y_val, axis=-1)

val_data = np.concatenate((x_val_or, y_val_expanded), axis=-1)

val_data = val_data.astype(np.int64)

# (num_windows, set_train_time, n_columns), same trailing dimensions as train_data.
print(val_data.shape)

In [None]:
def make_batches_val(val_data):
    """Build the validation pipeline from the windowed ``val_data`` array.

    Uses the same slicing as make_batches_train: encoder tokens are window positions
    4..train_time+4 of column 0, and decoder tokens are positions pred_time-2 ..
    second-to-last of the last column.
    """
    # Slice the argument. The previous version sliced the global train_data, so the
    # "validation" pipeline silently re-used training samples.
    x_data = val_data[:, 4:(train_time + 5), 0]
    y_data = val_data[:, (pred_time - 2):-1, -1]

    ds = tf.data.Dataset.from_tensor_slices((x_data, y_data))

    return (
        ds
        .cache()
        .shuffle(BUFFER_SIZE)
        .batch(BATCH_SIZE)
        .map(tensor_change, num_parallel_calls=tf.data.AUTOTUNE)
        .prefetch(tf.data.AUTOTUNE))

val_batches = make_batches_val(val_data)
# Peek at one validation batch to check shapes and dtypes.
x_val_batch, y_val_batch = next(iter(val_batches))
(x_val_batch, y_val_batch)

In [None]:
# Build the Transformer architecture

In [None]:
def get_angles(pos, i, d_model):
    """Angle argument of the sinusoidal encoding: pos / 10000**(2*(i//2)/d_model)."""
    exponent = 2 * (i // 2) / np.float32(d_model)
    return pos * (1 / np.power(10000, exponent))

def positional_encoding(position, d_model):
    """Sinusoidal position table of shape (1, position, d_model) as float32.

    Even feature indices hold sin(angle), odd indices hold cos(angle).
    """
    angles = get_angles(np.arange(position)[:, np.newaxis],
                        np.arange(d_model)[np.newaxis, :],
                        d_model)
    even_cols, odd_cols = angles[:, 0::2], angles[:, 1::2]
    angles[:, 0::2] = np.sin(even_cols)
    angles[:, 1::2] = np.cos(odd_cols)
    return tf.cast(np.expand_dims(angles, 0), dtype=tf.float32)

def tar_create_look_ahead_mask(size):
    """(size, size) float mask: 1 above the diagonal (future positions), 0 on/below it."""
    lower_triangle = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return 1 - lower_triangle

def tar_create_padding_mask(seq):
    """Return 1.0 where ``seq == 0`` (treated as padding), reshaped to (batch, 1, 1, seq_len)."""
    is_padding = tf.cast(tf.math.equal(seq, 0), tf.float32)
    return is_padding[:, tf.newaxis, tf.newaxis, :]

def inp_create_look_ahead_mask(size):
    """Encoder-side alias of tar_create_look_ahead_mask (identical computation)."""
    return tar_create_look_ahead_mask(size)

def inp_create_padding_mask(seq):
    """Encoder-side alias of tar_create_padding_mask (identical computation)."""
    return tar_create_padding_mask(seq)

def scaled_dot_product_attention(q, k, v, mask):
    """softmax(q k^T / sqrt(d_k) + mask * -1e9) v.

    Returns (output, attention_weights); the weights have shape (..., seq_q, seq_k).
    A mask value of 1 suppresses the corresponding key position.
    """
    depth = tf.cast(tf.shape(k)[-1], tf.float32)
    logits = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(depth)
    if mask is not None:
        logits += mask * -1e9
    weights = tf.nn.softmax(logits, axis=-1)
    return tf.matmul(weights, v), weights


In [None]:
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head scaled dot-product attention with learned Q/K/V/output projections."""

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        assert d_model % self.num_heads == 0

        self.depth = d_model // self.num_heads

        # Query / key / value projections, then the output projection.
        self.wq, self.wk, self.wv = (tf.keras.layers.Dense(d_model) for _ in range(3))
        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Reshape to (batch, num_heads, seq, depth); assumes x is (batch, seq, d_model)."""
        per_head = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(per_head, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask):
        """Attend from ``q`` over ``k``/``v``; returns (output, attention_weights)."""
        batch_size = tf.shape(q)[0]

        q = self.split_heads(self.wq(q), batch_size)
        k = self.split_heads(self.wk(k), batch_size)
        v = self.split_heads(self.wv(v), batch_size)

        scaled_attention, attention_weights = scaled_dot_product_attention(q, k, v, mask)

        # Undo split_heads: back to (batch, seq, d_model).
        merged = tf.reshape(tf.transpose(scaled_attention, perm=[0, 2, 1, 3]),
                            (batch_size, -1, self.d_model))
        return self.dense(merged), attention_weights

In [None]:
def point_wise_feed_forward_network(d_model, dff):
    """Position-wise MLP: Dense(dff, relu) followed by Dense(d_model)."""
    hidden = tf.keras.layers.Dense(dff, activation='relu')
    projection = tf.keras.layers.Dense(d_model)
    return tf.keras.Sequential([hidden, projection])

In [None]:
class EncoderLayer(tf.keras.layers.Layer):
    """Self-attention + feed-forward sub-layers, each with dropout, residual and LayerNorm."""

    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super().__init__()

        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        attended, _ = self.mha(x, x, x, mask)
        out1 = self.layernorm1(x + self.dropout1(attended, training=training))

        transformed = self.dropout2(self.ffn(out1), training=training)
        return self.layernorm2(out1 + transformed)

In [None]:
class DecoderLayer(tf.keras.layers.Layer):
    """Masked self-attention, encoder cross-attention and feed-forward sub-layers.

    call() returns (output, self-attention weights, cross-attention weights).
    """

    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super().__init__()

        self.mha1 = MultiHeadAttention(d_model, num_heads)
        self.mha2 = MultiHeadAttention(d_model, num_heads)

        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)
        self.dropout3 = tf.keras.layers.Dropout(rate)

    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        # Masked self-attention over the decoder sequence.
        self_attn, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)
        out1 = self.layernorm1(self.dropout1(self_attn, training=training) + x)

        # Cross-attention: queries from out1, keys/values from the encoder output.
        cross_attn, attn_weights_block2 = self.mha2(enc_output, enc_output, out1, padding_mask)
        out2 = self.layernorm2(self.dropout2(cross_attn, training=training) + out1)

        ffn_output = self.dropout3(self.ffn(out2), training=training)
        return self.layernorm3(ffn_output + out2), attn_weights_block1, attn_weights_block2


In [None]:
class Encoder(tf.keras.layers.Layer):
    """Token embedding (scaled by sqrt(d_model)) + sinusoidal positions + EncoderLayer stack."""

    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, maximum_position_encoding, rate=0.1):
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding, self.d_model)
        self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        seq_len = tf.shape(x)[1]
        scale = tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        # Sequences longer than maximum_position_encoding would not broadcast here.
        x = self.embedding(x) * scale + self.pos_encoding[:, :seq_len, :]
        x = self.dropout(x, training=training)

        for layer in self.enc_layers:
            x = layer(x, training, mask)
        return x


In [None]:
class Decoder(tf.keras.layers.Layer):
    """Target embedding + positions + DecoderLayer stack.

    call() returns (output, attention_weights), where the dict is keyed
    'decoder_layer{n}_block1' (self-attention) and 'decoder_layer{n}_block2'
    (cross-attention) for n = 1..num_layers.
    """

    def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size,
                 maximum_position_encoding, rate=0.1):
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)

        self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        seq_len = tf.shape(x)[1]
        attention_weights = {}

        scale = tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x = self.embedding(x) * scale + self.pos_encoding[:, :seq_len, :]
        x = self.dropout(x, training=training)

        for layer_no, dec_layer in enumerate(self.dec_layers, start=1):
            x, self_attn, cross_attn = dec_layer(x, enc_output, training, look_ahead_mask, padding_mask)
            attention_weights[f'decoder_layer{layer_no}_block1'] = self_attn
            attention_weights[f'decoder_layer{layer_no}_block2'] = cross_attn

        return x, attention_weights

In [None]:
class Transformer(tf.keras.Model):
    """Encoder-decoder Transformer producing per-position logits over target_vocab_size."""

    def __init__(self, num_layers, d_model,  num_heads, dff, input_vocab_size, target_vocab_size, pe_input, pe_target, rate=0.1):
        super().__init__()
        self.encoder = Encoder(num_layers, d_model, num_heads, dff, input_vocab_size, pe_input, rate)
        self.decoder = Decoder(num_layers, d_model, num_heads, dff, target_vocab_size, pe_target, rate)
        self.final_layer = tf.keras.layers.Dense(target_vocab_size)

    def call(self, inputs, training):
        """``inputs`` is an (encoder tokens, decoder tokens) pair; returns (logits, attention dict)."""
        inp, tar = inputs
        enc_mask, combined_mask, dec_mask = self.create_masks(inp, tar)
        memory = self.encoder(inp, training, enc_mask)
        dec_output, attention_weights = self.decoder(tar, memory, training, combined_mask, dec_mask)
        return self.final_layer(dec_output), attention_weights

    def create_masks(self, inp, tar):
        """Build (encoder padding, combined look-ahead+target padding, cross-attention padding) masks."""
        enc_padding_mask = inp_create_padding_mask(inp)
        dec_padding_mask = inp_create_padding_mask(inp)
        combined_mask = tf.maximum(tar_create_padding_mask(tar),
                                   tar_create_look_ahead_mask(tf.shape(tar)[1]))
        return enc_padding_mask, combined_mask, dec_padding_mask

In [None]:
# Set hyperparameters

In [None]:
# Reference values; tune as needed. (The empty placeholder assignments were a SyntaxError.)
num_layers = 6            # ref = 6; the attention-plot cells read the last decoder layer
d_model = 256             # ref = 256 or 512
dff = 1024                # ref = 1024 or 2048
num_heads = 8             # ref = 6 to 8; must divide d_model (MultiHeadAttention asserts this)
dropout_rate = 0.1        # ref = 0.1 to 0.2
input_vocab_size = 1001   # tokens 0..1000 from the x1000 quantisation
target_vocab_size = 1001  # tokens 0..1000
# The positional tables are sliced to the sequence length, so they must be at least as
# long as the longest sequence: 6 encoder tokens and up to 6 decoder tokens when decoding.
pe_input = 6              # ref = 1 to 6
pe_target = 6             # ref = 1 to 6

In [None]:
# Set training conditions (learning-rate schedule, optimizer, loss)

In [None]:
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Warm-up then inverse-sqrt decay:

    lr(step) = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
    """

    def __init__(self, d_model, warmup_steps=6000):
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, dtype=tf.float32)
        decay = tf.math.rsqrt(step)
        warmup = step * (self.warmup_steps ** -1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(decay, warmup)

In [None]:
learning_rate = CustomSchedule(d_model)

optimizer = tf.keras.optimizers.Adam(
    learning_rate,
    beta_1=0.9,
    beta_2=0.98,
    epsilon=1e-9,
    clipvalue=0.5,
)

# reduction='none' keeps per-token losses so loss_function can mask them.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

In [None]:
def loss_function(real, pred):
  """Masked sparse-categorical cross-entropy, averaged over positions where ``real != 0``.

  Args:
    real: int target tokens; token 0 is treated as padding and excluded.
    pred: logits accepted by ``loss_object``.
  Returns:
    Scalar mean loss over unmasked positions; 0.0 if every position is masked.
  """
  mask = tf.math.logical_not(tf.math.equal(real, 0))
  loss_ = loss_object(real, pred)

  mask = tf.cast(mask, dtype=loss_.dtype)
  loss_ *= mask

  # divide_no_nan: an all-zero target batch would otherwise give 0/0 = NaN, and the
  # NaN gradients would corrupt every weight.
  return tf.math.divide_no_nan(tf.reduce_sum(loss_), tf.reduce_sum(mask))

def accuracy_function(real, pred):
  """Fraction of positions with ``real != 0`` where argmax(pred, axis=2) equals ``real``.

  Args:
    real: int64 target tokens (compared against tf.argmax(pred, axis=2)).
    pred: logits with the vocabulary on axis 2.
  Returns:
    Scalar float32 accuracy; 0.0 if every position is masked.
  """
  accuracies = tf.equal(real, tf.argmax(pred, axis=2))

  mask = tf.math.logical_not(tf.math.equal(real, 0))
  accuracies = tf.math.logical_and(mask, accuracies)

  accuracies = tf.cast(accuracies, dtype=tf.float32)
  mask = tf.cast(mask, dtype=tf.float32)
  # divide_no_nan: an all-zero target batch would otherwise yield NaN and poison the
  # running train_accuracy mean for the rest of the epoch.
  return tf.math.divide_no_nan(tf.reduce_sum(accuracies), tf.reduce_sum(mask))

In [None]:
def loss_function_val(real, pred):
    """Mean squared error between two token sequences, computed in float32.

    Inputs are cast first. Callers pass int64 token tensors, and Keras may otherwise
    average the squared errors in integer arithmetic, which truncates the result.
    """
    mse = tf.keras.losses.MeanSquaredError()
    loss = mse(tf.cast(real, tf.float32), tf.cast(pred, tf.float32))
    return tf.reduce_mean(loss)

In [None]:
def accuracy_function_val(real, pred):
    """Mean absolute percentage error (scikit-learn), wrapped in tf.reduce_mean.

    NOTE(review): not called in the cells shown; MAPE is unstable when ``real`` has zeros.
    """
    return tf.reduce_mean(mean_absolute_percentage_error(real, pred))

In [None]:
# Running means over the batches of an epoch (reset in the training loop).
train_loss, train_accuracy = (tf.keras.metrics.Mean(name=metric_name)
                              for metric_name in ('train_loss', 'train_accuracy'))

In [None]:
# Instantiate the model from the hyperparameter cell.
transformer = Transformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    dff=dff,
    input_vocab_size=input_vocab_size,
    # Was input_vocab_size, which silently ignored the target_vocab_size hyperparameter.
    target_vocab_size=target_vocab_size,
    pe_input=pe_input,
    pe_target=pe_target,
    rate=dropout_rate)

In [None]:
checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(transformer=transformer, optimizer=optimizer)
# The training loop saves once per epoch; keep up to 400 checkpoints.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=400)

In [None]:
# Both arguments are 2-D int64 token tensors. TODO: specify num_feature.
train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),  # encoder tokens
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),  # decoder tokens
]


@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
    """Run one teacher-forced optimisation step and update train_loss / train_accuracy.

    The decoder is fed tar[:, :-1] and trained to predict tar[:, 1:].
    """
    decoder_input, decoder_target = tar[:, :-1], tar[:, 1:]

    with tf.GradientTape() as tape:
        predictions, _ = transformer([inp, decoder_input], training=True)
        loss = loss_function(decoder_target, predictions)

    variables = transformer.trainable_variables
    optimizer.apply_gradients(zip(tape.gradient(loss, variables), variables))

    train_loss(loss)
    train_accuracy(accuracy_function(decoder_target, predictions))

In [None]:
def diff(tensor, axis=-1):
    """First difference ``x[k+1] - x[k]`` along ``axis`` (numpy arrays or tf tensors).

    The previous version ignored ``axis`` and always differenced the last axis.

    Args:
        tensor: array-like supporting tuple-of-slices indexing and ``.shape``.
        axis: axis to difference; negative values count from the end.
    Returns:
        Same type as ``tensor``, one element shorter along ``axis``.
    Raises:
        ValueError: if ``axis`` is out of range for ``tensor``.
    """
    ndim = len(tensor.shape)
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of range for a rank-{ndim} tensor")
    axis %= ndim
    later = [slice(None)] * ndim
    earlier = [slice(None)] * ndim
    later[axis] = slice(1, None)
    earlier[axis] = slice(None, -1)
    return tensor[tuple(later)] - tensor[tuple(earlier)]

In [None]:
class Translator(tf.Module):
    """Greedy autoregressive decoder for a single sample, seeded with the target's first token."""

    def __init__(self, transformer):
        self.transformer = transformer

    def __call__(self, inp_single, tar_single, max_length=None):
        """Generate ``max_length`` tokens after the start token ``tar_single[0, 0]``.

        Args:
            inp_single: encoder tokens for one sample (batch dimension of 1).
            tar_single: target tokens for that sample; only ``[0, 0]`` is read.
            max_length: number of tokens to generate (required).
        Returns:
            (output, attention_weights): ``output`` is (1, max_length) int64 without the
            start token. ``attention_weights`` comes from the decoder pass over
            [start, p1, ..., p_{max_length-1}], i.e. the input that produced ``output``.
        Raises:
            ValueError: if ``max_length`` is None.
        """
        if max_length is None:
            raise ValueError("max_length (number of tokens to generate) must be given")

        encoder_input = inp_single
        start = tf.expand_dims(tar_single[0, 0], axis=0)

        output_array = tf.TensorArray(dtype=tf.int64, size=0, dynamic_size=True)
        output_array = output_array.write(0, start)

        for i in tf.range(max_length):
            decoder_input = tf.transpose(output_array.stack())
            predictions, _ = self.transformer([encoder_input, decoder_input], training=False)
            predicted_id = tf.argmax(predictions[:, -1:, :], axis=-1)
            output_array = output_array.write(i + 1, predicted_id[0])

        full_sequence = tf.transpose(output_array.stack())  # [start, p1, ..., p_max_length]

        # Recompute attention on the decoder input that actually produced the predictions.
        # Previously the start token was dropped first, so the weights came from a
        # sequence the model never decoded.
        _, attention_weights = self.transformer([encoder_input, full_sequence[:, :-1]], training=False)

        output = full_sequence[:, 1:]
        return output, attention_weights



def evaluate_validation_batches(translator, val_batches, num_val_batches=None, max_length=None):
    """Greedy-decode validation batches and report MSE of values and of first differences.

    Each sample's target row is [start, t1, ..., tn]. The translator is seeded with
    ``start`` and its k-th output predicts t_{k+1}, so predictions are compared with
    tar[:, 1:]. (Previously the start token was included, shifting the comparison by one.)

    Args:
        translator: callable ``(inp, tar, max_length) -> (tokens, attention)``.
        val_batches: tf.data.Dataset of (inp, tar) batches.
        num_val_batches: number of batches to evaluate; None means all of them.
        max_length: tokens generated per sample.
    Returns:
        (mean value MSE, mean first-difference MSE) over the evaluated batches.
    """
    val_loss_value = tf.keras.metrics.Mean(name='val_loss_value')
    val_loss_delta = tf.keras.metrics.Mean(name='val_loss_delta')

    # Dataset.take(None) raises, so None means "use every batch".
    batches = val_batches if num_val_batches is None else val_batches.take(num_val_batches)

    for inp, tar in batches:
        predictions = tf.concat(
            [translator(inp[i:i + 1], tar[i:i + 1], max_length)[0] for i in range(inp.shape[0])],
            axis=0)
        print(f'prediction value {predictions.numpy()}')
        print(f'true value {tar.numpy()}')

        tar_real = tar[:, 1:]  # tar[:, 0] is the given start token, not a prediction target
        horizon = min(tar_real.shape[1], predictions.shape[1])
        # Float32 so the MSE is not computed in integer arithmetic.
        y_true = tf.cast(tar_real[:, :horizon], tf.float32)
        y_pred = tf.cast(predictions[:, :horizon], tf.float32)

        validation1 = val_loss_value(loss_function_val(y_true, y_pred))
        print(f'val_loss_value {validation1}')

        # First differences need at least two steps.
        if horizon > 1:
            loss_del = loss_function_val(diff(y_true, axis=1), diff(y_pred, axis=1))
            validation2 = val_loss_delta(loss_del)
            print(f'val_loss_delta {validation2}')

    return val_loss_value.result(), val_loss_delta.result()


In [None]:
# Training

In [None]:
# Train for EPOCHS epochs, validating and checkpointing after each one.
# (The blank placeholder assignments were a SyntaxError; `import time` is at the top.)
EPOCHS = 200
# Evaluate every validation batch (reference: number of validation windows / batch size).
num_val_batches = math.ceil(len(val_data) / BATCH_SIZE)
max_length = 6  # tokens generated per validation sample (reference value)
translator = Translator(transformer)

for epoch in range(EPOCHS):
    start = time.time()
    # reset_state() is the current Keras Metric API; reset_states() is a deprecated
    # alias that Keras 3 removed.
    train_loss.reset_state()
    train_accuracy.reset_state()

    for (batch, (inp, tar)) in enumerate(train_batches):
        train_step(inp, tar)

        if batch % 50 == 0:
            print(f'Epoch {epoch + 1} Batch {batch} Loss {train_loss.result():.4f} Accuracy {train_accuracy.result():.4f}')

    val_loss_value, val_loss_delta = evaluate_validation_batches(translator, val_batches, num_val_batches, max_length)
    print(f'Validation Loss Value: {val_loss_value:.4f}')
    print(f'Validation Loss Delta: {val_loss_delta:.4f}')

    ckpt_save_path = ckpt_manager.save()
    print(f'Saving checkpoint for epoch {epoch+1} at {ckpt_save_path}')

    print(f'Epoch {epoch + 1} Loss {train_loss.result():.4f} Accuracy {train_accuracy.result():.4f}')

    print(f'Time taken for 1 epoch: {time.time() - start:.2f} secs\n')


In [None]:
# Load a saved model checkpoint

In [None]:
# TODO: point this at the checkpoint you want to evaluate.
checkpoint_path_specific = "./checkpoints/train/ckpt-26"

ckpt = tf.train.Checkpoint(transformer=transformer, optimizer=optimizer)
ckpt.restore(checkpoint_path_specific)
print('Specific checkpoint restored!!')

In [None]:
# Test the loaded model

In [None]:
def make_batches_sequential(data):
    """Unshuffled pipeline: same slicing as make_batches_train, so batches keep the window order of ``data``."""
    inputs = data[:, 4:train_time + 5, 0]
    targets = data[:, pred_time - 2:-1, -1]
    batched = tf.data.Dataset.from_tensor_slices((inputs, targets)).cache().batch(BATCH_SIZE)
    return batched.map(tensor_change, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)

In [None]:
# Ordered (unshuffled) validation batches for prediction and inspection.
val_batches_seq = make_batches_sequential(data=val_data)

In [None]:
translator = Translator(transformer)


def predict_batches(translator, val_batches, max_length=None):
    """Greedy-decode every sample of ``val_batches`` in order and stack the outputs.

    Args:
        translator: callable ``(inp, tar, max_length) -> (tokens, attention)``.
        val_batches: dataset of (inp, tar) batches to decode.
        max_length: tokens generated per sample.
    Returns:
        Tensor of shape (num_samples, max_length).
    Raises:
        ValueError: if ``val_batches`` yields no samples.
    """
    predictions_list = []

    # Iterate the argument. The previous version read the global val_batches_seq and
    # ignored val_batches.
    for inp, tar in val_batches:
        for i in range(inp.shape[0]):
            outputs, _ = translator(inp[i:i + 1], tar[i:i + 1], max_length)
            predictions_list.append(outputs)

    if not predictions_list:
        raise ValueError("val_batches yielded no samples to predict")
    return tf.concat(predictions_list, axis=0)


predictions = predict_batches(translator, val_batches_seq, max_length=6)
print(predictions)

In [None]:
# Estimate attention weights

In [None]:
class Trans(tf.Module):
    """Greedy decoder for attention inspection; seeded with the LAST encoder token."""

    def __init__(self, transformer):
        self.transformer = transformer

    def __call__(self, inp_single, max_length=None):
        """Generate ``max_length`` tokens after the start token ``inp_single[0, -1]``.

        Returns:
            (output, attention_weights): ``output`` is (1, max_length) int64 without the
            start token. ``attention_weights`` comes from the decoder pass over
            [start, p1, ..., p_{max_length-1}], i.e. the input that produced ``output``.
        Raises:
            ValueError: if ``max_length`` is None.
        """
        if max_length is None:
            raise ValueError("max_length (number of tokens to generate) must be given")

        encoder_input = inp_single
        start = tf.expand_dims(inp_single[0, -1], axis=0)

        output_array = tf.TensorArray(dtype=tf.int64, size=0, dynamic_size=True)
        output_array = output_array.write(0, start)

        for i in tf.range(max_length):
            decoder_input = tf.transpose(output_array.stack())
            predictions, _ = self.transformer([encoder_input, decoder_input], training=False)
            predicted_id = tf.argmax(predictions[:, -1:, :], axis=-1)
            output_array = output_array.write(i + 1, predicted_id[0])

        full_sequence = tf.transpose(output_array.stack())  # [start, p1, ..., p_max_length]

        # Attention from the decoder input that actually produced the predictions.
        # Previously the start token was dropped before this pass, so the plotted
        # weights came from a sequence the model never decoded.
        _, attention_weights = self.transformer([encoder_input, full_sequence[:, :-1]], training=False)

        output = full_sequence[:, 1:]
        return output, attention_weights

def translator1(trans, val_batches, max_length=None):
    """Run ``trans`` on one input slice and return its (tokens, attention_weights) pair."""
    tokens, weights = trans(val_batches, max_length)
    return tokens, weights


In [None]:
# Choose one validation window to inspect (1-based position in the ordered validation set).
# (The blank placeholder assignments were a SyntaxError.)
data_index = 1  # TODO: set to the window you want to visualise
loc1 = data_index - 1
loc2 = data_index
trans = Trans(transformer)

# x_val_seq_batch / y_val_seq_batch were referenced but never defined. Rebuild them by
# concatenating every batch of the unshuffled val_batches_seq pipeline.
seq_batches = list(val_batches_seq)
x_val_seq_batch = tf.concat([x for x, _ in seq_batches], axis=0)
y_val_seq_batch = tf.concat([y for _, y in seq_batches], axis=0)

im_input = x_val_seq_batch[loc1:loc2]
val_data1 = y_val_seq_batch[loc1:loc2]
print(val_data1)
print(im_input)

In [None]:
# Decode the selected window and collect its attention maps.
output, attention_weights = translator1(trans=trans, val_batches=im_input, max_length=6)
print(output)
print(attention_weights)

In [None]:
head = 0
# Cross-attention (block2) of the LAST decoder layer. The Decoder stores two entries per
# layer, so the layer count is derived instead of hard-coding 'decoder_layer6'.
last_decoder_layer = len(attention_weights) // 2
attention_heads = tf.squeeze(attention_weights[f'decoder_layer{last_decoder_layer}_block2'], 0)
attention = attention_heads[head]
print(attention)

In [None]:
def plot_attention_head(im_input, output, attention, filename="plot.png"):
    """Draw one attention head as an annotated heat map, save it as a 300-dpi image, and show it.

    Args:
        im_input: encoder input tokens (not used in the drawing; kept for interface compatibility).
        output: decoded tokens (not used in the drawing; kept for interface compatibility).
        attention: 2-D weights; rows are decoder (query) positions and columns are
            encoder (key) positions.
        filename: path passed to ``savefig``.
    """
    attention = np.asarray(attention, dtype=float)

    # Rescale to [0, 10] for display. A constant matrix would divide by zero, so map it to 0.
    value_range = attention.max() - attention.min()
    if value_range > 0:
        attention = 10 * (attention - attention.min()) / value_range
    else:
        attention = np.zeros_like(attention)

    fig, ax = plt.subplots(figsize=(8, 8))
    cax = ax.matshow(attention, cmap='viridis')

    # 1-based position labels sized from the matrix. The previous labels were hard-coded
    # to 6 columns x 5 rows.
    n_rows, n_cols = attention.shape
    ax.set_xticks(range(n_cols))
    ax.set_yticks(range(n_rows))
    ax.set_xticklabels(range(1, n_cols + 1))
    ax.set_yticklabels(range(1, n_rows + 1))

    # Annotate every cell with its rescaled value.
    for i in range(n_rows):
        for j in range(n_cols):
            ax.text(j, i, f'{attention[i, j]:.2f}', ha="center", va="center", color="w")

    # Colorbar in a dedicated axes to the right of the plot.
    cbar_ax = fig.add_axes([0.92, 0.15, 0.03, 0.7])
    fig.colorbar(cax, cax=cbar_ax)

    fig.savefig(filename, dpi=300)
    plt.show()


In [None]:
# Render the selected head and write it to disk.
plot_attention_head(im_input=im_input, output=output, attention=attention, filename="imp921.png")