Testing Google Colab and GitHub integration.

In [19]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import math

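Below, I run a mock Transformer model I built (scaled dot-product and multi-head attention, sinusoidal positional encoding, encoder layers, and a small classifier head) on random inputs to check that the tensor shapes line up.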

In [5]:
dmodel = 512    # embedding / internal latent vector size
h = 8           # number of attention heads (v/k/q projection sets per layer)
# Heads are concatenated back to dmodel, so the split must be exact
# (int(dmodel / h) would silently truncate otherwise).
assert dmodel % h == 0, "dmodel must be divisible by h"
dk = dmodel // h  # per-head v/k/q projection size
inSeqLen = 250  # input sequence length
dff = 2048      # feed-forward hidden size
SEED = 0
torch.manual_seed(SEED)  # reproducible weight init and random demo inputs

In [6]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
class SDAttention(torch.nn.Module):
  """Scaled dot-product attention: softmax(q @ k^T / sqrt(d_k)) @ v.

  Args:
    masked: if True, apply a causal (look-ahead) mask so query position i
      can only attend to key positions <= i.
  """
  def __init__(self, masked=False):
    super(SDAttention, self).__init__()
    self.smax = torch.nn.Softmax(dim=-1)
    self.masked = masked

  def forward(self, v, k, q):
    """Attend from q over (k, v).

    Args:
      v: values, (..., seq_k, d_v).
      k: keys, (..., seq_k, d_k).
      q: queries, (..., seq_q, d_k).
    Returns:
      Tensor of shape (..., seq_q, d_v).
    """
    scores = torch.matmul(q, torch.transpose(k, -1, -2))
    # Scale by the width of the keys actually passed in, not a global constant.
    scores = scores / (k.shape[-1] ** 0.5)
    if self.masked:
      seq_q, seq_k = scores.shape[-2], scores.shape[-1]
      # True strictly above the diagonal = "future" keys query i must not see.
      # The diagonal stays visible, so no row is fully masked (no NaN softmax).
      future = torch.ones(seq_q, seq_k, device=scores.device).triu(1).bool()
      scores = scores.masked_fill(future, float('-inf'))
    return torch.matmul(self.smax(scores), v)

In [8]:
class MultiHead(torch.nn.Module):
  """Multi-head attention.

  Each of the h heads projects v, k and q into a dk-dim subspace, runs
  SDAttention, and the concatenated head outputs are projected back to dmodel.

  Args:
    dmodel: input/output embedding size.
    dk: per-head projection size.
    h: number of heads.
    masked: forwarded to SDAttention.
  """
  def __init__(self, dmodel, dk, h, masked=False):
    super().__init__()
    self.sdattention = SDAttention(masked=masked)

    def _xavier_param(rows, cols):
      # nn.Parameter registers the weight with the module, so it is visible to
      # .parameters(), state_dict() and .to().
      weight = torch.nn.Parameter(torch.empty([rows, cols], device=device))
      torch.nn.init.xavier_normal_(weight)
      return weight

    # Same creation/initialisation order as before: all v, then k, then q, then out.
    self.vmats = torch.nn.ParameterList([_xavier_param(dmodel, dk) for _ in range(h)])
    self.kmats = torch.nn.ParameterList([_xavier_param(dmodel, dk) for _ in range(h)])
    self.qmats = torch.nn.ParameterList([_xavier_param(dmodel, dk) for _ in range(h)])
    # Concatenated heads have width h * dk (== dmodel when h divides dmodel).
    self.outMat = _xavier_param(h * dk, dmodel)

    # Flat list kept for callers that collect parameters manually.
    self.params = list(self.vmats) + list(self.kmats) + list(self.qmats)
    self.params.append(self.outMat)

  def forward(self, v, k, q):
    """Run every head and project the concatenation back to dmodel.

    v, k, q must have last dim dmodel. The result has q's sequence length
    and last dim dmodel.
    """
    # Iterate over this module's own heads rather than a global head count.
    heads = [
        self.sdattention(torch.matmul(v, vmat), torch.matmul(k, kmat), torch.matmul(q, qmat))
        for vmat, kmat, qmat in zip(self.vmats, self.kmats, self.qmats)
    ]
    out = torch.concat(heads, dim=-1)
    return torch.matmul(out, self.outMat)

In [10]:
# Smoke test: causal multi-head attention on random (inSeqLen, dmodel) inputs.
selfAtt = MultiHead(dmodel, dk, h, masked=True)
v, k, q = (torch.rand(inSeqLen, dmodel, device=device) for _ in range(3))
selfAtt(v, k, q).shape

torch.Size([250, 512])

In [12]:
class PositionEncode(torch.nn.Module):
  """Adds the fixed sinusoidal positional encoding:

    encoded[pos, i]   = sin(pos / 10000 ** (i / dmodel))   for even i
    encoded[pos, i+1] = cos(pos / 10000 ** (i / dmodel))

  Args:
    seq_len: maximum sequence length supported.
    dmodel: embedding size. It may be odd; the last column is then a sine.
  """
  def __init__(self, seq_len, dmodel):
    super().__init__()
    position = torch.arange(seq_len, dtype=torch.float64).unsqueeze(1)  # (seq_len, 1)
    even_cols = torch.arange(0, dmodel, 2, dtype=torch.float64)         # 0, 2, 4, ...
    angles = position / (10000 ** (even_cols / dmodel))                 # (seq_len, ceil(dmodel/2))
    encoded = torch.zeros(seq_len, dmodel, dtype=torch.float64)
    encoded[:, 0::2] = torch.sin(angles)
    encoded[:, 1::2] = torch.cos(angles[:, : dmodel // 2])
    # A buffer follows the module through .to()/.cuda(). persistent=False keeps
    # state_dict() keys unchanged, since the table is recomputed on construction.
    self.register_buffer("encoded", encoded.to(torch.float32), persistent=False)

  def forward(self, x):
    """Return x plus the encoding for its first x.shape[-2] positions.

    x's last two dims are (seq, dmodel), with seq <= seq_len. Leading batch
    dims are broadcast.
    """
    encoded = self.encoded[: x.shape[-2]]
    return x + encoded.to(x.device)

In [13]:
class FF(torch.nn.Module):
  """Position-wise feed-forward sublayer.

  Pipeline: Linear(dmodel -> dff) -> Dropout(p=0.1) -> ReLU -> Linear(dff -> dmodel).
  Both Linear weights are re-initialised Xavier-normal; biases keep PyTorch's default init.
  """
  def __init__(self, dmodel, dff):
    super().__init__()

    def xavier_linear(fan_in, fan_out):
      # Build the layer on the notebook-level device, then overwrite its weight init.
      layer = torch.nn.Linear(fan_in, fan_out, device=device)
      torch.nn.init.xavier_normal_(layer.weight)
      return layer

    self.lin0 = xavier_linear(dmodel, dff)
    self.lin1 = xavier_linear(dff, dmodel)
    self.relu = torch.nn.ReLU()
    self.dropout = torch.nn.Dropout(p=0.1)

  def forward(self, x):
    hidden = self.relu(self.dropout(self.lin0(x)))
    return self.lin1(hidden)

In [14]:
class TransformerEncoder(torch.nn.Module):
  """One encoder layer (post-LayerNorm).

  The self-attention sublayer and the feed-forward sublayer are each followed
  by dropout, a residual add and a LayerNorm. `seq_len` is accepted but not
  used by this layer.
  """
  def __init__(self, seq_len, dmodel, dk, h, dff):
    super().__init__()
    self.multiHead = MultiHead(dmodel, dk, h)
    self.ff = FF(dmodel, dff)
    self.layerNorm0 = torch.nn.LayerNorm(dmodel, device=device)
    self.layerNorm1 = torch.nn.LayerNorm(dmodel, device=device)
    self.dropout = torch.nn.Dropout(p=0.1)

    # Flat parameter list: feed-forward, attention, then both LayerNorms.
    self.params = [
        *self.ff.parameters(),
        *self.multiHead.params,
        *self.layerNorm0.parameters(),
        *self.layerNorm1.parameters(),
    ]

  def _add_and_norm(self, sublayer_out, residual, norm):
    """Dropout the sublayer output, add the residual, then apply `norm`."""
    return norm(self.dropout(sublayer_out) + residual)

  def forward(self, x):
    attended = self._add_and_norm(self.multiHead(x, x, x), x, self.layerNorm0)
    return self._add_and_norm(self.ff(attended), attended, self.layerNorm1)

In [16]:
# Smoke test: one encoder layer on a batch of 3 random sequences.
inVal = torch.rand(size=(3, inSeqLen, dmodel), device=device)
encoder = TransformerEncoder(seq_len=inSeqLen, dmodel=dmodel, dk=dk, h=h, dff=dff).to(device)
encoder(inVal).shape

torch.Size([3, 250, 512])

In [17]:
class Senti(torch.nn.Module):
  """Binary classifier (presumably sentiment, per the name).

  Pipeline: positional encoding -> stack of encoder layers -> flatten the whole
  sequence -> Linear -> Sigmoid, giving one score in (0, 1) per example.

  Args:
    seq_len: sequence length. The head expects exactly seq_len * dmodel
      features after flattening, so inputs must be (batch, seq_len, dmodel).
    dmodel, dk, h, dff: encoder hyperparameters forwarded to TransformerEncoder.
    num_layers: number of stacked encoder layers (previously hard-coded to 3).
  """
  def __init__(self, seq_len, dmodel, dk, h, dff, num_layers=3):
    super().__init__()
    self.posEncode = PositionEncode(seq_len, dmodel)
    self.encoders = torch.nn.Sequential(
        *(TransformerEncoder(seq_len, dmodel, dk, h, dff) for _ in range(num_layers)))
    self.flatten = torch.nn.Flatten()

    head = torch.nn.Linear(seq_len * dmodel, 1)
    torch.nn.init.xavier_normal_(head.weight)
    # Sequential indices (ff.0 = Linear, ff.1 = Sigmoid) match the previous state_dict layout.
    self.ff = torch.nn.Sequential(head, torch.nn.Sigmoid())

    encoderParams = []
    for encoder in self.encoders:
      encoderParams += encoder.params
    self.params = encoderParams + list(self.ff.parameters())

  def forward(self, x):
    """Map x of shape (batch, seq_len, dmodel) to scores of shape (batch, 1)."""
    x = self.posEncode(x)
    x = self.encoders(x)
    x = self.flatten(x)
    return self.ff(x)

In [20]:
# Smoke test: full classifier on a batch of 3 random sequences -> one score each.
inVal = torch.rand(size=(3, inSeqLen, dmodel), device=device)
senti = Senti(seq_len=inSeqLen, dmodel=dmodel, dk=dk, h=h, dff=dff).to(device)
senti(inVal).shape

torch.Size([3, 1])