In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.xlstm.utils import BlockDiagonal, CausalConv1D
class sLSTMblock(nn.Module):
    """Single sLSTM block: causal conv -> stabilised sLSTM cell -> gated up/down projection.

    The recurrent state (``ht_1``, ``ct_1``, ``nt_1``, ``mt_1``) lives on the
    module between ``forward`` calls and is detached after every step, so no
    gradient flows across calls.

    NOTE(review): the block width is taken from ``params.vocab_size``, not
    ``params.n_embd``; in the config displayed in this notebook those differ
    (835 vs 128) -- confirm this is intended.
    """

    def __init__(self, params):
        """Build all sub-layers and zero the recurrent state.

        Args:
            params: config namespace; reads ``vocab_size`` (block width),
                ``dropout``, ``depth`` (forwarded as the third argument of
                ``BlockDiagonal``, presumably the number of diagonal blocks)
                and ``device`` (where the initial states are allocated).
        """
        super().__init__()

        self.n_embd = params.vocab_size
        # Width of the gated up-projection: 4/3 of the block width, truncated.
        hidden = int(self.n_embd * (4 / 3))

        self.ln = nn.LayerNorm(self.n_embd)

        # Third argument is n_embd/8 (truncated), presumably the kernel size.
        self.conv = CausalConv1D(self.n_embd, self.n_embd, int(self.n_embd / 8))
        self.drop = nn.Dropout(params.dropout)

        # Input-to-gate projections: input (i), forget (f), output (o), candidate (z).
        self.i_gate = BlockDiagonal(self.n_embd, self.n_embd, params.depth)
        self.f_gate = BlockDiagonal(self.n_embd, self.n_embd, params.depth)
        self.o_gate = BlockDiagonal(self.n_embd, self.n_embd, params.depth)
        self.z_gate = BlockDiagonal(self.n_embd, self.n_embd, params.depth)

        # Recurrent (hidden-to-gate) projections, bias-free.
        self.ri_gate = BlockDiagonal(self.n_embd, self.n_embd, params.depth, bias=False)
        self.rf_gate = BlockDiagonal(self.n_embd, self.n_embd, params.depth, bias=False)
        self.ro_gate = BlockDiagonal(self.n_embd, self.n_embd, params.depth, bias=False)
        self.rz_gate = BlockDiagonal(self.n_embd, self.n_embd, params.depth, bias=False)

        # Per-gate pre-activation norms.
        self.ln_i = nn.LayerNorm(self.n_embd)
        self.ln_f = nn.LayerNorm(self.n_embd)
        self.ln_o = nn.LayerNorm(self.n_embd)
        self.ln_z = nn.LayerNorm(self.n_embd)

        # Output norm (a LayerNorm despite the "GN" name) and per-state norms.
        self.GN = nn.LayerNorm(self.n_embd)
        self.ln_c = nn.LayerNorm(self.n_embd)
        self.ln_n = nn.LayerNorm(self.n_embd)
        self.ln_h = nn.LayerNorm(self.n_embd)

        # GELU-gated up-projection followed by a down-projection.
        self.left_linear = nn.Linear(self.n_embd, hidden)
        self.right_linear = nn.Linear(self.n_embd, hidden)

        self.ln_out = nn.LayerNorm(hidden)

        self.proj = nn.Linear(hidden, self.n_embd)

        self.init_states(params)

    def init_states(self, params):
        """(Re)set the recurrent state to zeros of shape (1, 1, n_embd) on ``params.device``.

        The states are registered as non-persistent buffers: ``module.to(...)``
        moves/casts them together with the parameters, while ``state_dict()``
        keys are unchanged. Calling this again simply overwrites them.
        """
        for name in ("nt_1", "ct_1", "ht_1", "mt_1"):
            self.register_buffer(
                name,
                torch.zeros(1, 1, self.n_embd, device=params.device),
                persistent=False,
            )

    def forward(self, x):
        """Run one recurrent step on a batch and update the stored state.

        Args:
            x: input whose last dim is ``n_embd``; assumed (batch, n_embd) --
               TODO confirm. A length-1 axis is inserted at dim 1.

        Returns:
            For a (batch, n_embd) input, a tensor of shape (1, 1, n_embd).

        NOTE(review): ``ct``, ``nt`` and ``ht`` are averaged over dims 0 and 1
        (the batch axis included), so the returned tensor and the carried
        state are shared by every sample in the batch -- confirm intended.
        """
        x = x.unsqueeze(1)
        x = self.ln(x)

        x_conv = F.silu(self.drop(self.conv(x.transpose(1, 2)).transpose(1, 2)))

        batch_size = x.size(0)
        stored_rows = self.mt_1.size(0)
        if stored_rows > batch_size:
            # Smaller batch than the previous step: keep the leading rows.
            self.mt_1 = self.mt_1[:batch_size]
        elif 1 < stored_rows < batch_size:
            # Larger batch than the previous step (e.g. a full batch after a
            # short final batch). Slicing cannot grow the per-sample
            # stabiliser and the stale shape would not broadcast, so restart
            # it from zeros as init_states does. A single row broadcasts as-is.
            self.mt_1 = self.mt_1.new_zeros(1, 1, self.n_embd)

        # --- sLSTM cell ---
        ht_1 = self.ht_1
        m_prev = self.mt_1[:, 0, :].unsqueeze(1)

        # Exponential gates are kept in log space (log i = pre-activation);
        # this avoids the exp -> log round trip overflowing to inf.
        log_i = self.ln_i(self.i_gate(x_conv) + self.ri_gate(ht_1))
        log_f = self.ln_f(self.f_gate(x_conv) + self.rf_gate(ht_1))

        # Stabiliser m_t = max(log f + m_{t-1}, log i); both gates are rescaled by it.
        m = torch.max(log_f + m_prev, log_i)
        i = torch.exp(log_i - m)
        f = torch.exp(log_f + m_prev - m)
        self.mt_1 = m.detach()

        # Output gate and cell candidate read the un-convolved input.
        o = torch.sigmoid(self.ln_o(self.o_gate(x) + self.ro_gate(ht_1)))
        z = torch.tanh(self.ln_z(self.z_gate(x) + self.rz_gate(ht_1)))

        ct = f * self.ct_1 + i * z
        ct = torch.mean(self.ln_c(ct), [0, 1], keepdim=True)
        self.ct_1 = ct.detach()

        nt = f * self.nt_1 + i
        nt = torch.mean(self.ln_n(nt), [0, 1], keepdim=True)
        self.nt_1 = nt.detach()

        # NOTE(review): nt has been through LayerNorm, so it can be ~0 or
        # negative and ct / nt is then unbounded -- confirm this normaliser.
        ht = o * (ct / nt)
        ht = torch.mean(self.ln_h(ht), [0, 1], keepdim=True)
        self.ht_1 = ht.detach()
        # --- end sLSTM cell ---

        slstm_out = self.GN(ht)

        left = self.left_linear(slstm_out)
        right = F.gelu(self.right_linear(slstm_out))

        out = self.ln_out(left * right)
        return self.proj(out)
  

In [1]:
import os

import processing

# Build the dataset path portably instead of hard-coding Windows separators.
DATASET_DIR = os.path.join("..", "dataset", "np_dataset")
train_dataloader, test_dataloader = processing.get_train_test_dataloaders(DATASET_DIR)

In [2]:
import models.xlstm
import train 
# Build the full xLSTM model from the project's config.
xlstm_dict = train.get_xlstm_dict()
model = models.xlstm.xLSTM(xlstm_dict)

# Smoke test: push every training batch through the forward pass.
# The outputs are discarded, so skip autograd graph construction.
# NOTE: `src` (the last batch) intentionally outlives this loop; the
# step-by-step replay cell that follows reads it.
with torch.no_grad():
    for src, trg, metadata in train_dataloader:
        output = model(src.float())

In [None]:
# Replay the block's forward pass line by line on `model`, using the last
# `src` batch left over from the smoke-test loop, so every intermediate
# tensor is available for inspection.
# NOTE: this mutates model.mt_1 / ct_1 / nt_1 / ht_1 exactly like a real
# forward call, so re-running the cell is not idempotent.
x = src.float()
x = x.unsqueeze(1)
x = model.ln(x)

x_conv = F.silu( model.drop(model.conv( x.transpose(1, 2) ).transpose(1, 2) ) )

batch_size = x.size(0)  # Get dynamic batch size from input

if model.mt_1.size(0) != batch_size:
    # Adjust model.mt_1 to match the current batch size.
    # NOTE(review): slicing can only shrink; if mt_1 holds more than one row
    # but fewer than batch_size, the broadcast in `m = ...` below fails.
    model.mt_1 = model.mt_1[:batch_size]

# start sLSTM
ht_1 = model.ht_1

# Exponential input/forget gates.
i = torch.exp(model.ln_i( model.i_gate(x_conv) + model.ri_gate(ht_1) ) )
f = torch.exp( model.ln_f(model.f_gate(x_conv) + model.rf_gate(ht_1) ) )

# Stabiliser m = max(log f + m_prev, log i); both gates are rescaled by it.
m = torch.max(torch.log(f)+model.mt_1[:, 0, :].unsqueeze(1), torch.log(i))
i = torch.exp(torch.log(i) - m)
f = torch.exp(torch.log(f) + model.mt_1[:, 0, :].unsqueeze(1)-m)
model.mt_1 = m.detach()

# Output gate and cell candidate use the un-convolved input x.
o = torch.sigmoid( model.ln_o(model.o_gate(x) + model.ro_gate(ht_1) ) )
z = torch.tanh( model.ln_z(model.z_gate(x) + model.rz_gate(ht_1) ) )

# Cell, normaliser and hidden states are averaged over dims 0 and 1
# (batch included), giving shape (1, 1, n_embd) in the printed output.
ct_1 = model.ct_1
ct = f*ct_1 + i*z
ct = torch.mean(model.ln_c(ct), [0, 1], keepdim=True)
model.ct_1 = ct.detach()

nt_1 = model.nt_1
nt = f*nt_1 + i
nt = torch.mean(model.ln_n(nt), [0, 1], keepdim=True)
model.nt_1 = nt.detach()

ht = o*(ct/nt) # (batch, 1, n_embd) here, broadcast from o; (1, 1, n_embd) after the mean below
ht = torch.mean(model.ln_h(ht), [0, 1], keepdim=True)
model.ht_1 = ht.detach()
# end sLSTM

slstm_out = model.GN(ht)

# GELU-gated up-projection, then down-projection back to n_embd.
left = model.left_linear(slstm_out)
right = F.gelu(model.right_linear(slstm_out))

out = model.ln_out(left*right)
out = model.proj(out)

In [23]:
# Shapes of each intermediate from the step-by-step replay, one per line, in order:
# x, x_conv, ht_1, i, f, o, z, ct_1, ct, nt_1, ht, slstm_out, left, right, out.
intermediates = (x, x_conv, ht_1, i, f, o, z, ct_1, ct, nt_1, ht, slstm_out, left, right, out)
print(' \n '.join(str(t.shape) for t in intermediates))

torch.Size([3, 1, 128]) 
 torch.Size([3, 1, 128]) 
 torch.Size([1, 1, 128]) 
 torch.Size([3, 1, 128]) 
 torch.Size([3, 1, 128]) 
 torch.Size([3, 1, 128]) 
 torch.Size([3, 1, 128]) 
 torch.Size([1, 1, 128]) 
 torch.Size([1, 1, 128]) 
 torch.Size([1, 1, 128]) 
 torch.Size([1, 1, 128]) 
 torch.Size([1, 1, 128]) 
 torch.Size([1, 1, 170]) 
 torch.Size([1, 1, 170]) 
 torch.Size([1, 1, 128])


In [5]:
# Display the config namespace returned by train.get_xlstm_dict().
# Note that n_embd and vocab_size differ here; sLSTMblock sizes itself from vocab_size.
xlstm_dict

namespace(layers=['s', 'm', 's', 'm'],
          n_embd=128,
          depth=4,
          factor=2,
          vocab_size=835,
          block_len=128,
          device='cpu',
          metadata_dims=namespace(composer=8),
          dropout=0.01,
          epochs=200,
          eval_interval=100,
          save_interval=500,
          learning_rate=0.1,
          eval_iters=200,
          test_ratio=0.2,
          batch_size=8)