In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn.utils as torch_utils

import numpy as np
import pandas as pd
import math
import matplotlib.pyplot as plt
from scipy.signal import spectrogram
from tqdm.auto import tqdm
# Use the first CUDA GPU when one is present; otherwise fall back to the CPU so
# every later `.to(device)` call still works on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Uncomment the two lines below to download the dataset (when running on Google Colab)
# !gdown 1Y42nOM606No8IXlJB2Ezizu6H0QMprHT
# filepath = "/content/short_data.pkl"

filepath = "short_data.pkl"

In [None]:
class EMGsignals(Dataset):
    """Paired raw / preprocessed samples with a contiguous train/validation split.

    The object loaded from ``filename`` must provide ``"raw"`` and
    ``"preprocess"`` entries and a ``shape`` attribute whose first element is
    the number of rows ``n``.

    ``split`` is the fraction of rows used for training: the training view
    (``train=True``) covers rows ``0 .. int(n * split) - 1`` and the validation
    view (``train=False``) covers every remaining row, so the two views never
    overlap and together span the whole file when built with the same ``split``.
    """

    def __init__(self, filename, train, split=0.5):
        # filename: path handed to pd.read_pickle
        # train:    True for the training view, False for the validation view
        # split:    fraction of rows that belong to the training view
        # NOTE(review): pd.read_pickle unpickles arbitrary objects; only load trusted files.
        self.data = pd.read_pickle(filename)
        self.train = train
        self.raw = self.data["raw"]
        self.preprocess = self.data["preprocess"]
        self.n = self.data.shape[0]
        self.split = split

    def _train_size(self):
        # Number of leading rows reserved for training; also the offset of the
        # first validation row.
        return int(self.n * self.split)

    def __getitem__(self, index):
        # Returns (raw, preprocess) as float32 tensors placed on `device`.
        # Raises IndexError when `index` is outside this view, so the training
        # view can never hand out validation rows (and vice versa).
        if index < 0 or index >= len(self):
            raise IndexError(f"index {index} is out of range for a dataset of size {len(self)}")

        # Validation rows start right after the training rows.
        row = index if self.train else index + self._train_size()

        raw_tensor = torch.tensor(self.raw[row], dtype=torch.float32).to(device)
        preprocess_tensor = torch.tensor(self.preprocess[row], dtype=torch.float32).to(device)
        return raw_tensor, preprocess_tensor

    def __len__(self):
        # Derive the validation size by subtraction instead of n * (1 - split):
        # the latter can lose a row to float rounding (10 * (1 - 0.8) -> 1.99...).
        train_size = self._train_size()
        return train_size if self.train else self.n - train_size

In [None]:
# EMGsignals interprets `split` as the fraction of rows used for training: its
# __len__ returns n * split for the training view and n * (1 - split) for the
# validation view. Both views therefore take the SAME value so that their sizes
# add up to the full dataset (80% train / 20% validation here).
train_fraction = 0.8
batch_size = 32

train_data = EMGsignals(filepath, train=True, split=train_fraction)
train_dataloader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

val_data = EMGsignals(filepath, train=False, split=train_fraction)
# The validation loss is averaged over all batches, so shuffling adds nothing;
# keep the order deterministic.
val_dataloader = DataLoader(dataset=val_data, batch_size=batch_size, shuffle=False)

In [None]:
class AttentionHead(nn.Module):
    """A single scaled dot-product self-attention head."""

    def __init__(self, dim: int, n_hidden: int):
        # dim:      size of the last axis of the input
        # n_hidden: size of the key, query and value projections
        super().__init__()

        self.W_K = nn.Linear(dim, n_hidden)  # key projection
        self.W_Q = nn.Linear(dim, n_hidden)  # query projection
        self.W_V = nn.Linear(dim, n_hidden)  # value projection
        self.n_hidden = n_hidden

    def forward(self, x, attn_mask):
        # x          inputs, shape (B x T x dim)
        # attn_mask  None to attend everywhere; otherwise mask[b, i, j] is 1 when
        #            token i may attend to token j and 0 when it may not,
        #            shape (B x T x T)
        #
        # Returns the attended values, shape (B x T x n_hidden).
        queries, keys, values = self.W_Q(x), self.W_K(x), self.W_V(x)

        # Raw (unscaled) similarity between every query and every key.
        scores = queries @ keys.transpose(-2, -1)

        if attn_mask is not None:
            # A large negative score drives the softmax weight of blocked pairs to ~0.
            scores = scores.masked_fill(attn_mask == 0, -1e6)

        scale = torch.sqrt(torch.tensor(self.n_hidden, dtype=torch.float))
        weights = torch.softmax(scores / scale, dim=-1)

        return weights @ values


class MultiHeadedAttention(nn.Module):
    """Runs several AttentionHead modules in parallel and projects their
    concatenated outputs back to the input width."""

    def __init__(self, dim: int, n_hidden: int, num_heads: int):
        # dim:       the dimension of the input
        # n_hidden:  the hidden dimension of each attention head
        # num_heads: the number of attention heads
        super().__init__()

        self.dim = dim
        self.n_hidden = n_hidden
        self.num_heads = num_heads

        self.heads = nn.ModuleList(
            [AttentionHead(self.dim, self.n_hidden) for _ in range(self.num_heads)]
        )

        # Maps the concatenated head outputs (num_heads * n_hidden) back to dim.
        self.project = nn.Linear(self.num_heads * self.n_hidden, self.dim)

    def forward(self, x, attn_mask):
        # x          the inputs. shape: (B x T x dim)
        # attn_mask  an attention mask. If None, ignore. If not None, then
        #            mask[b, i, j] contains 1 if (in batch b) token i should
        #            attend on token j and 0 otherwise. shape: (B x T x T)
        #
        # Returns the multi-headed self-attention output, shape (B x T x dim).

        # Call each head as a module (not head.forward) so that any hooks
        # registered on the heads are executed.
        head_outputs = [head(x, attn_mask) for head in self.heads]

        return self.project(torch.cat(head_outputs, dim=-1))


class FFN(nn.Module):
    """Feed-forward block: LayerNorm -> Linear -> GELU -> Linear."""

    def __init__(self, dim: int, n_hidden: int):
        # dim       width of the input and of the output
        # n_hidden  width of the intermediate linear layer
        super().__init__()

        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, n_hidden),
            nn.GELU(),
            nn.Linear(n_hidden, dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x        input of shape (B x T x dim)
        #
        # Returns the network output, shape (B x T x dim).
        out = self.net(x)
        return out

class AttentionResidual(nn.Module):
    """Multi-headed attention followed by an FFN, each wrapped in a residual connection."""

    def __init__(self, dim: int, attn_dim: int, mlp_dim: int, num_heads: int):
        # dim        width of the input
        # attn_dim   hidden width of the attention layer
        # mlp_dim    hidden width of the FFN
        # num_heads  number of heads in the attention layer
        super().__init__()
        self.attn = MultiHeadedAttention(dim, attn_dim, num_heads)
        self.ffn = FFN(dim, mlp_dim)

    def forward(self, x, attn_mask):
        # x          inputs, shape (B x T x dim)
        # attn_mask  None, or a (B x T x T) mask where mask[b, i, j] is 1 when
        #            token i may attend to token j and 0 otherwise
        #
        # Returns a tensor of shape (B x T x dim).

        # First residual branch: attention.
        hidden = x + self.attn(x=x, attn_mask=attn_mask)
        # Second residual branch: feed-forward network.
        return hidden + self.ffn(hidden)

class TransformerLayer(nn.Module):
    """A stack of `num_layers` AttentionResidual blocks applied one after another."""

    def __init__(self, dim, num_heads, num_layers, attn_dim=16, mlp_dim=16):
        # dim         width of the input
        # num_heads   number of heads in every attention layer
        # num_layers  number of AttentionResidual blocks in the stack
        # attn_dim    hidden width of the attention layers
        # mlp_dim     hidden width of the FFNs
        super().__init__()

        self.dim = dim
        self.attn_dim = attn_dim
        self.mlp_dim = mlp_dim
        self.num_heads = num_heads
        self.num_layers = num_layers

        self.residuals = nn.ModuleList(
            [AttentionResidual(dim, attn_dim, mlp_dim, num_heads) for _ in range(num_layers)]
        )

    def forward(self, x, attn_mask):
        # x          inputs, shape (B x T x dim)
        # attn_mask  attention mask handed unchanged to every block,
        #            shape (B x T x T)
        #
        # Returns a tensor of shape (B x T x dim).
        hidden = x
        for block in self.residuals:
            hidden = block(hidden, attn_mask)
        return hidden

class Encoder(nn.Module):
    """Transformer stack followed by a stack of linear layers that maps the
    attended features to the latent width."""

    def __init__(self, input_dim, num_heads, num_attn_layers, num_linear, lin_dim, latent_dim):
        # input_dim        width of the input features
        # num_heads        heads per attention layer
        # num_attn_layers  number of attention blocks in the transformer
        # num_linear       number of linear layers after the transformer (>= 2)
        # lin_dim          width of the intermediate linear layers
        # latent_dim       width of the encoder output
        #
        # Raises AssertionError when num_linear < 2.
        super(Encoder, self).__init__()

        # Validate before building any sub-module.
        if num_linear < 2:
            raise AssertionError("Number of linear layers must be at least 2")

        self.transformer = TransformerLayer(input_dim, num_heads, num_attn_layers)
        self.linear = nn.Sequential()

        # input_dim -> lin_dim -> ... -> lin_dim -> latent_dim, with a ReLU after
        # every layer except the last one.
        last = num_linear - 1
        for i in range(num_linear):
            in_features = input_dim if i == 0 else lin_dim
            out_features = latent_dim if i == last else lin_dim
            self.linear.add_module(f"linear_{i}", nn.Linear(in_features, out_features))
            if i != last:
                self.linear.add_module(f"relu_{i}", nn.ReLU())

    def forward(self, x):
        # x: inputs whose last axis has width input_dim.
        # Returns the latent representation (last axis has width latent_dim).
        attended = self.transformer(x, attn_mask=None)
        # Feed the transformer output (not the raw input) to the linear stack;
        # otherwise the attention layers would have no effect on the result.
        output = self.linear(attended)
        return output

In [None]:
class Decoder(nn.Module):
    """Stack of linear layers, each followed by a ReLU, mapping a latent vector
    to the output width. The ReLU after the last layer makes every output
    value non-negative."""

    def __init__(self, input_size, output_size, inner_layers=1):
        # input_size    width of the latent input (also used for all hidden layers)
        # output_size   width of the decoder output
        # inner_layers  number of hidden input_size -> input_size layers (>= 1)
        #
        # Raises AssertionError when inner_layers < 1.
        super(Decoder, self).__init__()

        self.lin1 = nn.Linear(input_size, input_size)

        if inner_layers < 1:
            raise AssertionError("Number of inner linear layers must be at least 1")

        self.inner = nn.ModuleList(
            [nn.Linear(input_size, input_size) for _ in range(inner_layers)]
        )

        self.linear = nn.Linear(input_size, output_size)
        self.nonlinear = nn.ReLU()

    def forward(self, x):
        # x: tensor whose last axis has width input_size.
        # Returns a tensor whose last axis has width output_size.
        hidden = self.nonlinear(self.lin1(x))

        for layer in self.inner:
            hidden = self.nonlinear(layer(hidden))

        return self.nonlinear(self.linear(hidden))

In [None]:
class CustomLoss(nn.Module):
    """Mean absolute error in which over-predictions (guess > truth) are
    multiplied by `overshoot_penalty`."""

    def __init__(self, overshoot_penalty):
        # overshoot_penalty: factor applied to the absolute error of every
        #                    element where the guess exceeds the truth
        super(CustomLoss, self).__init__()
        self.overshoot_penalty = overshoot_penalty

    def forward(self, guess, truth):
        # guess, truth: tensors of identical shape (checked by the assert).
        # Returns the weighted absolute error averaged over all elements.
        assert guess.shape == truth.shape, "Input shapes do not match."

        abs_err = (guess - truth).abs()
        overshoots = guess > truth
        weighted = torch.where(overshoots, abs_err * self.overshoot_penalty, abs_err)

        return weighted.sum() / weighted.numel()

In [None]:
def train_autoencoder(train_dataloader, val_dataloader, enc, dec, lr=2e-5,
                      epochs=None, overshoot_penalty=2.25, patience=5, clip_norm=0.5):
    r"""
    Jointly train the encoder `enc` and decoder `dec` so that `dec(enc(raw))`
    matches the preprocessed target, using `CustomLoss` (absolute error with
    over-predictions weighted by `overshoot_penalty`).

    Each epoch runs one pass over `train_dataloader` with Adam, gradient-norm
    clipping and a StepLR schedule (learning rate halved every 5 epochs), then
    evaluates on `val_dataloader`. Training stops early once the validation
    loss has not improved for `patience` consecutive epochs.

    Args:
        train_dataloader: iterable of (raw_batch, preprocess_batch) pairs.
        val_dataloader: iterable of (raw_batch, preprocess_batch) pairs.
        enc, dec: the modules to train; they are updated in place.
        lr: initial Adam learning rate.
        epochs: maximum number of epochs. When None, the notebook-level
            `num_epochs` variable is used (the original behaviour).
        overshoot_penalty: factor passed to `CustomLoss`.
        patience: epochs without validation improvement before stopping.
        clip_norm: maximum gradient norm.

    Returns:
        (train_losses, val_losses): two lists with one entry per completed
        epoch, holding the mean per-batch training loss and the mean per-batch
        validation loss. An entry is NaN when the corresponding loader yielded
        no batches.
    """

    # Fall back to the global so that existing calls without `epochs` still work.
    total_epochs = num_epochs if epochs is None else epochs

    # Collected once; reused by the optimizer and by gradient clipping.
    params = list(enc.parameters()) + list(dec.parameters())
    # Named `optimizer` so it does not shadow the imported `torch.optim` module alias.
    optimizer = torch.optim.Adam(params, lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)  # Learning rate scheduler

    # Built once, outside the loops, so it also exists when the training loader is empty.
    loss_func = CustomLoss(overshoot_penalty)

    best_val_loss = float('inf')
    counter = 0

    train_losses = []
    val_losses = []

    for epoch in tqdm(range(total_epochs), desc=f"{total_epochs} epochs total"):
        enc.train()
        dec.train()
        train_total, train_batches = 0.0, 0
        for raw_batch, preprocess_batch in tqdm(train_dataloader, desc="Training Batches", leave=False):
            optimizer.zero_grad()

            output = dec(enc(raw_batch))

            loss = loss_func(output, preprocess_batch)
            loss.backward()

            torch_utils.clip_grad_norm_(params, clip_norm)
            optimizer.step()

            train_total += loss.item()
            train_batches += 1

        scheduler.step()

        enc.eval()
        dec.eval()
        val_total, val_batches = 0.0, 0
        with torch.no_grad():
            for raw_batch_val, preprocess_batch_val in val_dataloader:
                output_val = dec(enc(raw_batch_val))
                val_total += loss_func(output_val, preprocess_batch_val).item()
                val_batches += 1

        # Report epoch means for both splits so the two curves are comparable
        # (the last batch alone is a noisy estimate of the training loss).
        train_loss = train_total / train_batches if train_batches else float('nan')
        val_loss = val_total / val_batches if val_batches else float('nan')

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        print(f"[Autoencoder] epoch {epoch + 1: 4d}   Training Loss = {train_loss:.4g}   Validation Loss = {val_loss:.4g}")

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            counter = 0
        elif val_batches:
            # Only epochs that actually produced a validation loss count
            # towards the patience budget.
            counter += 1
            if counter >= patience:
                print(f"Early stopping triggered after {counter} epochs without improvement.")
                break

    return train_losses, val_losses

In [None]:
# Width of the encoder input and of the decoder output.
input_size, output_size = 56, 56

# Attention stage of the encoder: heads per layer and number of layers.
num_heads, num_attn_layers = 4, 2

# Linear stage of the encoder: number of layers, hidden width, latent width.
num_linear, lin_dim, latent_dim = 2, 42, 35

In [None]:
# Epoch budget; train_autoencoder reads this module-level name.
num_epochs = 25

# The encoder compresses input_size features down to latent_dim.
enc = Encoder(
    input_dim=input_size,
    num_heads=num_heads,
    num_attn_layers=num_attn_layers,
    num_linear=num_linear,
    lin_dim=lin_dim,
    latent_dim=latent_dim,
).to(device)

# The decoder expands the latent vector back to output_size features.
dec = Decoder(input_size=latent_dim, output_size=output_size, inner_layers=1).to(device)

train_losses, val_losses = train_autoencoder(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    enc=enc,
    dec=dec,
)
print("Training finished.")

In [None]:
# Plot both loss histories against the 1-based epoch number.
fig, ax = plt.subplots()
for curve_label, history in (('Training Loss', train_losses), ('Validation Loss', val_losses)):
    ax.plot(range(1, len(history) + 1), history, label=curve_label)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
plt.show()