In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, TensorDataset, DataLoader

In [None]:
class positionalEncoder(nn.Module):
  """Learned positional encoding concatenated onto each frame's features.

  Every one of the ``frame_length`` positions owns a trainable vector of size
  ``encoding_length``; ``forward`` appends that vector to the feature axis of
  the input, so the output has ``features + encoding_length`` features.

  Args:
    frame_length: number of positions (frames) the input sequence must have.
    encoding_length: size of the learned encoding appended to every frame.
  """

  def __init__(self, frame_length, encoding_length):
    super().__init__()

    # Registered as a submodule so its weights are trained by the optimizer,
    # saved in state_dict and moved by .to(device). A local nn.Embedding whose
    # output is cached at construction is untrainable and breaks on the second
    # backward() pass through the cached graph.
    self.embedding = nn.Embedding(frame_length, encoding_length)
    # Position indices 0..frame_length-1. Non-persistent buffer: it follows the
    # module's device but is not written to state_dict.
    self.register_buffer("positions", torch.arange(frame_length), persistent=False)

  @property
  def pe(self):
    """The ``(frame_length, encoding_length)`` encoding table from the current weights."""
    return self.embedding(self.positions)

  def forward(self, x):
    """Concatenate the positional encoding onto the last (feature) axis of ``x``.

    Args:
      x: ``(frame_length, features)`` for a single sequence, or
        ``(frame_length, batch, features)`` for the sequence-first batched
        layout (``batch_first=False``).

    Returns:
      Tensor with the same leading shape as ``x`` and
      ``features + encoding_length`` features.
    """
    pe = self.pe
    if x.dim() == 3:
      # Share each position's encoding across every sequence in the batch.
      pe = pe.unsqueeze(1).expand(-1, x.size(1), -1)
    return torch.cat((x, pe), -1)

class classifierNetwork(nn.Module):
  """Transformer encoder/decoder producing class probabilities for every frame.

  Pipeline:
    1. Append a learned positional encoding to the input features.
    2. Run a ``TransformerEncoder``.
    3. Run a ``TransformerDecoder`` with the positionally-encoded input as
       target and the encoder output as memory.
    4. Project to ``outFeatCount`` classes with a two-layer MLP.
    5. Apply a softmax over the last axis.

  Args:
    inFeatCount: number of input features per frame.
    num_T_layers: number of layers in both the encoder and the decoder.
    num_frames: sequence length; sizes the positional-encoding table.
    num_fc_layers: currently unused (the head is always two linear layers);
      kept for call compatibility.
    pos_encode_size: size of the positional encoding appended to each frame.
    n_heads: attention heads. Attention requires the model width
      ``inFeatCount + pos_encode_size`` to be divisible by it; if it is not
      (or it is None / non-positive), one head per model feature is used.
    n_hidden: feed-forward width inside every encoder and decoder layer.
    dropout: dropout rate inside every encoder and decoder layer.
    outFeatCount: number of output classes.
  """

  def __init__(self, inFeatCount, num_T_layers, num_frames, num_fc_layers = 5, pos_encode_size = 3, n_heads = 5, n_hidden = 500, dropout = 0.3, outFeatCount = 2):
    super().__init__()

    # Model width seen by the transformer: raw features plus positional encoding.
    d_model = inFeatCount + pos_encode_size

    self.posEncoder = positionalEncoder(num_frames, pos_encode_size)

    # Honour the caller's n_heads when it is valid instead of always discarding
    # it; fall back to one head per feature, which always divides d_model.
    if n_heads is None or n_heads <= 0 or d_model % n_heads != 0:
      n_heads = d_model

    encoder_layer = nn.TransformerEncoderLayer(d_model, n_heads, n_hidden, dropout)
    self.encoder = nn.TransformerEncoder(encoder_layer, num_T_layers)

    # Same feed-forward width and dropout as the encoder; omitting them would
    # silently use PyTorch's defaults (2048 and 0.1) for the decoder only.
    decoder_layer = nn.TransformerDecoderLayer(d_model, n_heads, n_hidden, dropout)
    self.decoder = nn.TransformerDecoder(decoder_layer, num_T_layers)

    # Hidden width of the head: roughly halfway between d_model and outFeatCount.
    mid = (d_model - outFeatCount) // 2 + outFeatCount

    self.fc1 = nn.Linear(d_model, mid)
    self.fc2 = nn.Linear(mid, outFeatCount)

    self.init_weights()

  def init_weights(self):
    """Re-initialise the classification head: zero biases, weights ~ U(-0.1, 0.1)."""
    initrange = 0.1
    for layer in (self.fc1, self.fc2):
      nn.init.zeros_(layer.bias)
      nn.init.uniform_(layer.weight, -initrange, initrange)

  def forward(self, x):
    """Return per-frame class probabilities.

    Args:
      x: ``(num_frames, inFeatCount)`` for one sequence. NOTE(review): a
        sequence-first batched input ``(num_frames, batch, inFeatCount)`` only
        works if the positional encoder accepts 3-D input — confirm before batching.

    Returns:
      Probabilities over ``outFeatCount`` classes along the last axis.
    """
    encoded = self.posEncoder(x)
    memory = self.encoder(encoded)
    data = self.decoder(encoded, memory)

    # Nonlinearity between the two linear layers; without it fc2(fc1(.)) is a
    # single linear map and the hidden layer adds nothing.
    data = nn.functional.relu(self.fc1(data))
    data = self.fc2(data)
    # Normalise over the class axis (last). dim=1 would normalise over the
    # batch axis of a 3-D (seq, batch, classes) tensor.
    # NOTE(review): if this output feeds nn.CrossEntropyLoss (which applies
    # log-softmax itself), return logits instead — confirm against training code.
    return nn.functional.softmax(data, dim=-1)

In [None]:
# Hyperparameters for this run.
SEED = 0            # fixes weight initialisation so runs are reproducible
featCount = 10      # input features per frame
num_frames = 20     # frames per sequence; sizes the positional-encoding table
encoder_layers = 2  # layers in both the transformer encoder and decoder

torch.manual_seed(SEED)
net = classifierNetwork(featCount, encoder_layers, num_frames)

# Smoke check: one unbatched (num_frames, featCount) sequence must map to
# per-frame probabilities over the 2 default output classes.
with torch.no_grad():
    smoke_probs = net(torch.randn(num_frames, featCount))
assert smoke_probs.shape == (num_frames, 2)
assert torch.allclose(smoke_probs.sum(-1), torch.ones(num_frames))

# TODO: training / evaluation.