This notebook combines a time-compressing VAE with an SDAPM module. Training is split into three sets:

1) text to text VAE training (reconstruction)

2) vector-to-vector SDAPM training (prediction)

3) text-to-text SDAPM training (prediction)

Set 1 involves only the VAE, set 2 involves only the encoder and the SDAPM, and set 3 involves all modules, with the VAE frozen.

In [None]:
# Pseudo-code sketch of the intended three-phase training loop. This cell only
# evaluates a string literal; none of the code inside it is executed.
# NOTE(review): in the sketch `iters` is never incremented, so as written the loop
# would stay in the `iters%30 < 10` branch (set 1, VAE reconstruction) forever.
# The `else` branch (10 <= iters%30 <= 20) is set 2 (vec-to-vec MSE on detached
# encodings) and the `iters%30 > 20` branch is set 3 (text-to-text through the
# frozen decoder). `import SDPAM` and 'unigram_25600.json' do not match the names
# used by the real code below (SDAPM, "unigram-6400.json").
''' Psuedo code

import SDPAM
import VAE

import dataset
dataset.load("gutenberg conversations")
import tokenizers
tokenizers.load('unigram_25600.json')


sdapm = SDAPM(data_width)
vae = VAE(data_width)


iters=0
while true:
  if(iters%30 < 10):
    data = tokenize(load_conv())
    reconstruction, t_loss, kl_loss = vae(data)
    loss = cross_entropy(reconstruction, data) + t_loss + kl_loss
    loss.backward()
    opt.step()
  elif(iters%30>20):
    data = tokenize(load_conv())
    vae.decoder.freeze()
    prediction, splits = vae.decode(sdapm(vae.encode(data)[:-1].detach()))
    loss = cross_entropy(prediction, data[splits[0]:])
    loss.backward()
    opt.step()
    vae.decoder.unfreeze()

  else:
    data = tokenize(load_conv())
    enc_data = vae.encode(data).detach()
    prediction = sdapm(enc_data[:-1], enc_data[1:])
    loss = mse_loss(prediction, enc_data[1:])
    loss.backward()
    opt.step()

  print(loss)


'''

' Psuedo code\n\nimport SDPAM\nimport VAE\n\nimport dataset\ndataset.load("gutenberg conversations")\nimport tokenizers\ntokenizers.load(\'unigram_25600.json\')\n\n\nsdapm = SDAPM(data_width)\nvae = VAE(data_width)\n\n\niters=0\nwhile true:\n  if(iters%30 < 10):\n    data = tokenize(load_conv())\n    reconstruction, t_loss, kl_loss = vae(data)\n    loss = cross_entropy(reconstruction, data) + t_loss + kl_loss\n    loss.backward()\n    opt.step()\n  elif(iters%30>20):\n    data = tokenize(load_conv())\n    vae.decoder.freeze()\n    prediction, splits = vae.decode(sdapm(vae.encode(data)[:-1].detach()))\n    loss = cross_entropy(prediction, data[splits[0]:])\n    loss.backward()\n    opt.step()\n    vae.decoder.unfreeze()\n\n  else:\n    data = tokenize(load_conv())\n    enc_data = vae.encode(data).detach()\n    prediction = sdapm(enc_data[:-1], enc_data[1:])\n    loss = mse_loss(prediction, enc_data[1:])\n    loss.backward()\n    opt.step()\n\n  print(loss)\n\n\n'

In [None]:
!pip install -U datasets
!pip install torch-optimizer
!pip install tensordict



In [None]:
import torch as T
import torch.nn.functional as F
import os


import json

# Build (or reload) the list of conversations. Each conversation is one string
# whose utterances are wrapped in `GO` ... `STOP` markers.
if os.path.exists("gutenberg_conv.json"):
  # A previous run already cached the conversation list.
  with open("gutenberg_conv.json",'r') as f:
    data  = json.load(f)
else:
  from datasets import load_dataset
  ds = load_dataset("willwade/Gutenberg-dialog-en")
  ds.set_format("torch")

  print(ds['train']['text'][:10])
  # Wrap every line in `GO`/`STOP`, then split wherever an empty line (which
  # becomes "`GO``STOP`") separates two conversations.
  texts = ds['train']['text']
  wrapped = "`GO`" + "`STOP`\n`GO`".join(texts) + "`STOP`"
  data = wrapped.split("\n`GO``STOP`\n")
  with open("gutenberg_conv.json","w") as f:
    json.dump(data, f)
print(data[0])

from tokenizers import Tokenizer
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
import tokenizers

vocab_size = 6400
tokenizer_path = "unigram-" + str(vocab_size) + ".json"

# Train a lower-casing Unigram tokenizer on a random sample of conversations,
# or reload the one saved by a previous run.
if not os.path.exists(tokenizer_path):
  import random
  import tokenizers.normalizers as nz
  import tokenizers.pre_tokenizers as pretok

  tokenizer = Tokenizer(Unigram())
  tokenizer.normalizer=nz.Lowercase()

  # Raw string so "\s", "\[" and "\]" reach the regex engine intact (they are
  # invalid Python escapes otherwise). The '-' is escaped: unescaped between '"'
  # and '=' it formed the range '"'..'=', which also isolated every digit and '/'.
  # Digits() still separates runs of digits from surrounding text.
  split_chars = tokenizers.Regex(r"[\s,.:;'\"\-=+_~{}\[\]<>?!@#$%^&*()|]")
  tokenizer.pre_tokenizer = pretok.Sequence([pretok.Digits(), pretok.Split(split_chars, 'isolated')])

  trainer = UnigramTrainer(vocab_size=vocab_size, unk_token="`UNK`", special_tokens=["`UNK`","`GO`", "`PAD`", "`STOP`"])
  print("training... (tokenizer)")
  # random.sample raises ValueError when asked for more items than exist.
  tokenizer.train_from_iterator(random.sample(data, min(10_000, len(data))), trainer)

  tokenizer.save(tokenizer_path)
else:
  tokenizer = Tokenizer.from_file(tokenizer_path)


import pickle
import random

# Tokenize a random half of the conversations once and cache the Encoding list.
if not os.path.exists("tokenized_set.pkl"):
  print("tokenizing...")
  dataset = tokenizer.encode_batch(random.sample(data, len(data)//2))
  with open("tokenized_set.pkl", 'wb') as f:
    pickle.dump(dataset, f)
else:
  # NOTE: pickle.load can execute arbitrary code; only load a cache this notebook wrote.
  # The context manager closes the file (the old bare open() leaked the handle).
  with open("tokenized_set.pkl", 'rb') as f:
    dataset = pickle.load(f)
print(dataset[0].tokens)



`GO`the shelling from the enemys mortars was severe . . . and having but little mortar powder we were unable to reply effectually . . . . i regret that our ordnance supplies are so scanty . . . . no powder for the mortars no suitable fuses for the fire on charleston no shells for the 30pounder parrotts a most useful gun for silencing the enemys fire no material for making cartridge bags or grease for lubricating the projectiles . . . . more ammunition for the 300pounder the most useful guns in these works is also very much needed . . . .`STOP`
`GO`within the last 2 days the work . . . has been greatly interfered with by a corps of sharpshooters . . . stationed on fort sumter . the bullets came in very thick when i was at the front this morning . . . .`STOP`
['`GO`', 'my', ' ', 'dear', ' ', 'simon', ' ', 'you', ' ', 'have', ' ', 'behaved', ' ', 'i', 'r', 'reproach', 'ably', ' ', '.', ' ', 'eleanor', ' ', 'will', ' ', 'feel', ' ', 'it', ' ', 'for', ' ', 'some', ' ', 'time', ' ', 'no', ' 

In [None]:
# modified from https://avandekleut.github.io/vae/
import math

class PositionalEncoding(T.nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    Even feature columns hold ``sin(pos * f_k)`` and odd columns ``cos(pos * f_k)``,
    with ``f_k = 10000 ** (-2k / d_model)``.

    Args:
        d_model: feature width of the sequence (odd widths are supported).
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence length that can be encoded.
        device: device the encoding table is created on.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, device='cpu'):
        super().__init__()
        self.device=device
        self.dropout = T.nn.Dropout(p=dropout)

        position = T.arange(max_len, device=self.device).unsqueeze(1)
        div_term = T.exp(T.arange(0, d_model, 2, device=self.device) * (-math.log(10000.0) / d_model))
        pe = T.zeros((max_len, 1, d_model), device=self.device)
        pe[:, 0, 0::2] = T.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine column,
        # so only the first d_model // 2 frequencies are used here.
        pe[:, 0, 1::2] = T.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``

        Returns:
            ``x`` plus the first ``seq_len`` encodings, after dropout (same shape).
        """
        x = x.transpose(0, 1)  # -> [seq_len, batch, dim] to match the pe table layout
        x = x + self.pe[:x.size(0)]
        return self.dropout(x).transpose(0, 1)


def drop_tokens(x, probability):
    """Randomly zero whole feature vectors (tokens).

    Each position along all but the last axis is independently zeroed with
    probability ``probability``; surviving vectors are returned unchanged.
    (The previous comparison was inverted and zeroed tokens with probability
    ``1 - probability``.)
    """
    keep = (T.rand(x.size()[:-1], device=x.device) >= probability).unsqueeze(-1)
    return x * keep

def direction(x):
    """Clamp ``x`` to [-1, 1] and round each element to the nearest integer (-1, 0 or 1)."""
    bounded = x.clamp(min=-1, max=1)
    return bounded.round()

def deNaNed(x):
    """Return a copy of ``x`` with every NaN replaced by 0 (infinities are kept)."""
    return x.masked_fill(x.isnan(), 0)

def max_norm(x, max):
  """Cap the L1 norm of ``x`` along its last dimension at ``max``.

  Rows whose L1 norm exceeds ``max`` are rescaled to L1 norm exactly ``max``;
  other rows are returned unchanged. Works for 1-D vectors and batched input.
  (Previously the signed sum of the *whole* tensor was compared against ``max``.)
  """
  l1 = T.linalg.vector_norm(x, ord=1, dim=-1, keepdim=True)
  return T.where(l1 > max, F.normalize(x, p=1, dim=-1) * max, x)

class BernoulliMix(T.nn.Module):
  """Randomly mix two tensors element-wise during training.

  In training mode every output element is taken from ``a`` with probability
  ``mix_ratio`` and from ``b`` otherwise; in eval mode ``a`` is returned.
  """
  def __init__(self, mix_ratio=0.5, device='cpu'):
    super().__init__()
    self.device=device
    self.B = self._make_distribution(mix_ratio)

  def _make_distribution(self, probs):
    """Build a Bernoulli distribution whose parameter lives on ``self.device``."""
    if not isinstance(probs, T.Tensor):
      probs = T.tensor(float(probs))
    # Moving probs to self.device (rather than calling .cuda() afterwards) also
    # honours torch.device objects and indexed devices such as "cuda:1".
    return T.distributions.Bernoulli(probs=probs.to(self.device))

  def set_probs(self, probs):
    """Change the probability of picking ``a``."""
    self.B = self._make_distribution(probs)

  def forward(self, a, b):

    if(self.training):
      return T.where(self.B.sample(a.size()).bool(), a, b)
    else:
      return a

def softabs(x,k):
  """Smooth |x|-like penalty: about x**2/k near zero and about |x| - k for large |x|."""
  magnitude = T.abs(x)
  return x.pow(2) / (magnitude + k)



class TVET(T.nn.Module): # transformer variational encoder across time dimension.
    """Causal transformer VAE encoder that compresses a sequence along time.

    The input ``(batch, seq, data_width)`` passes through causal self-attention
    layers. It is then projected to ``2 * emb_size`` channels (mu and log-sigma)
    and sampled with the reparameterisation trick (training mode only). Finally
    it is mean-pooled over variable-length "sections" whose boundaries are
    driven by the last latent channel.

    Args:
        data_width: feature width of the input sequence.
        emb_size: latent width; defaults to ``data_width``.
        temporal_division: target compression ratio used by the length loss.
        layer_num: ``layer_num - 1`` causal layers run before the variational layer.
        nheads: attention heads (must divide ``data_width`` and ``2 * emb_size``).
        device: device the parameters are created on.
    """
    def __init__(self, data_width, emb_size=None,  temporal_division=8, layer_num=3, nheads = 6, device='cpu'):
        super().__init__()
        self.device=device
        self.emb_size = emb_size if emb_size is not None else data_width
        # ModuleList of distinct layers: ``[layer] * n`` shared one layer n times and,
        # being a plain list, hid it from .parameters(), .to() and state_dict().
        self.layers = T.nn.ModuleList([
            T.nn.TransformerEncoderLayer(data_width, nheads, dim_feedforward=data_width, batch_first=True, device=self.device)
            for _ in range(layer_num - 1)
        ])
        self.expansion = T.nn.Linear(data_width, self.emb_size*2, device=self.device)
        self.var_layer = T.nn.TransformerEncoderLayer(self.emb_size*2, nheads, dim_feedforward=data_width*2, batch_first=True, device=self.device)
        self.poscoder = PositionalEncoding(data_width, max_len=10_000, device=self.device)

        self.expansion_offset=0

        self.temporal_division = temporal_division

        self.N = T.distributions.Normal(0, 1)
        if(isinstance(device, str) and "cuda" in device):
            self.N.loc = self.N.loc.cuda() # hack to get sampling on the GPU
            self.N.scale = self.N.scale.cuda()
        self.kl = 0

    def forward(self, x):
        """Encode and time-compress ``x``.

        Returns:
            densified_z: ``(batch, num_sections, emb_size)`` per-section mean latents.
            len_loss: ``(batch,)`` penalty on the number of sections above
                ``seq / temporal_division``.
            sections: ``(batch, seq)`` running section coordinate of each position.
        Also stores the KL term in ``self.kl``.
        """
        x = self.poscoder(x)

        # One causal mask, built on the input's device and shared by every layer.
        causal_mask = T.nn.Transformer.generate_square_subsequent_mask(x.size(1), device=x.device)
        for layer in self.layers:
            x = layer(x, src_mask=causal_mask, is_causal=True) + x
        x = self.expansion(x)
        x = self.var_layer(x, src_mask=causal_mask, is_causal=True)

        # Reparameterisation: first half of the channels is mu, second half log-sigma.
        mu, log_sigma = T.chunk(x, 2, dim=-1)
        sigma = T.exp(log_sigma)
        z = mu + (sigma*self.N.sample(mu.shape) if self.training else 0)
        self.kl = (sigma**2 + mu**2 - T.log(sigma) - 1/2).sum()

        # Variable compression: each position advances a running section counter by
        # a step in (0.01, 1] taken from its last latent channel. Positions whose
        # counter shares an integer part are pooled into the same section.
        steps = T.clamp(deNaNed(F.sigmoid(self.expansion_offset + z[:, :, -1])) + 0.01, max=1)
        sections = T.clamp(T.cumsum(steps, dim=-1) - 1e-6, min=0)

        len_loss = softabs(T.amax(sections,dim=-1)-(x.size(1)/self.temporal_division), 1/3) ##punish lengths above len/temporal_div

        # Mean-pool each sample's latent rows into its own sections. Sizing from the
        # largest index (+1) keeps every index in range even when float rounding
        # swallows the -1e-6 offset on large integer-valued counters.
        section_idx = T.floor(sections).long()
        num_sections = int(section_idx.amax().item()) + 1
        densified_z = T.stack([
            T.zeros((num_sections, z.size(-1)), device=z.device, dtype=z.dtype).index_reduce(
                0, section_idx[i], z[i], 'mean', include_self=False)
            for i in range(z.size(0))
        ])
        return densified_z, len_loss, sections


class TVDT(T.nn.Module): # transformer variational decoder across time dimension.
    """Causal transformer decoder for TVAT's time-expanded latents.

    Args:
        data_width: feature width of the sequence being decoded.
        emb_size: stored for reference; defaults to ``data_width``.
        temporal_division: accepted for signature parity with TVET; unused here.
        layer_num: number of causal self-attention layers.
        nheads: attention heads (must divide ``data_width``).
        device: device the layers are created on.
    """
    def __init__(self, data_width, emb_size=None,  temporal_division=8, layer_num=3, nheads = 6, device='cpu'):
        super().__init__()

        self.device=device
        self.nheads = nheads
        self.emb_size= emb_size if emb_size is not None else data_width
        # Distinct, registered layers (``[layer] * n`` shared a single unregistered layer).
        self.layers = T.nn.ModuleList([
            T.nn.TransformerEncoderLayer(data_width, nheads, dim_feedforward=data_width, batch_first=True, device=self.device)
            for _ in range(layer_num)
        ])
        self.poscoder = PositionalEncoding(data_width, max_len=10_000, device=self.device)
        self.b_mix = BernoulliMix(0.5, device=self.device)

    def forward(self, z, skip_vec=None):
        """Decode ``z`` ``(batch, seq, data_width)`` with residual causal layers.

        When ``skip_vec`` (shaped like ``z``) is given and the module is training,
        the middle layer's input is a random element-wise mix of ``z`` and
        ``skip_vec`` ('correct' answer injection).
        """
        z = self.poscoder(z)
        causal_mask = T.nn.Transformer.generate_square_subsequent_mask(z.size(1), device=self.device)
        mix_layer = len(self.layers) // 2

        for i, layer in enumerate(self.layers):
            layer_in = z
            if i == mix_layer and skip_vec is not None and self.training: ## mix in 'correct' answer at the middle layer
                layer_in = self.b_mix(z, skip_vec)
            z = layer(layer_in, src_mask=causal_mask) + z
        return z

class TVAT(T.nn.Module):
    """Time-compressing transformer VAE: a TVET encoder followed by a TVDT decoder.

    The encoder pools the input sequence into fewer latent positions. decode() and
    forward() stretch those latents back out along time: each latent row is repeated
    a number of times derived from its last channel. They then project the result
    up to ``data_width`` and run the decoder.
    """
    def __init__(self, data_width, emb_size=None, temporal_division=8, layer_num=12, nheads=12, device='cpu'):
        super().__init__()
        self.device=device
        self.emb_size = emb_size if emb_size!=None else data_width
        # The layer budget is split: ceil(layer_num/2) for the encoder and
        # floor(layer_num/2)+1 for the decoder. emb_size is forwarded as given
        # (possibly None); TVET/TVDT fall back to data_width themselves.
        self.encoder = TVET(data_width, emb_size, temporal_division, math.ceil(layer_num/2), nheads, device=self.device)
        self.decoder = TVDT(data_width, emb_size, temporal_division, math.floor(layer_num/2)+1, nheads, device=self.device)
        self.dropout = T.nn.Dropout()  # NOTE(review): not used by any method in this class.
        # Projects stretched latent rows (emb_size) up to the decoder width (data_width).
        self.proj_up = T.nn.Linear(self.emb_size, data_width, device=self.device)
        self.kl=0  # KL term copied from the encoder by encode()/forward().



    def generate_mask(self, z, length, return_sizes=False):
      """Build a per-sample block attention mask from the latent stretch factors.

      NOTE(review): not called by any other method in this class.
      Each latent position gets a repeat count
      ``clamp(1/(sigmoid(z[i,:,-1])+0.01), max=length)`` truncated to int, and
      ``repeat_interleave`` turns those counts into the latent index used by each of
      the first ``length`` output steps. A histogram of those indices supplies run
      lengths that open (set to 0) rectangular regions of an otherwise ``-inf``
      mask, which is transposed before returning.
      NOTE(review): ``histc`` uses ``bins = max index`` over ``max index + 1``
      distinct values, so bins may not line up one-to-one with latent indices.

      Returns the ``(batch, length, length)`` mask, plus the per-sample run-length
      lists when ``return_sizes`` is true.
      """
      idxs=[0]*z.size(0)
      for i in range(z.size(0)):
          idxs[i] = T.repeat_interleave(T.clamp((1/(F.sigmoid(z[i,:,-1])+0.01)),max=length).int())[:length]

      #idxs = T.stack(idxs, dim=0)
      target_mask = T.full((z.size(0), length,length), -float('inf'), device=self.device)
      sizes=[None]*z.size(0)
      for i in range(z.size(0)): ###AAAAHHHH I HATE THISSS
        sizes[i] = T.histc(idxs[i].float(), T.clamp(T.amax(idxs[i]),min=1)   ).int().tolist()

        xpos = 0
        ypos = sizes[i][0]

        for j in range(len(sizes[i])-1):
          target_mask[i,xpos:xpos+sizes[i][j],ypos:]=0
          ypos+=sizes[i][j+1]
          xpos+=sizes[i][j]
      target_mask = target_mask.mT

      if(return_sizes):
        return target_mask, sizes
      else:
        return target_mask

    def encode(self, x, return_mask=False):
        """Run the encoder; returns ``(latents, len_loss, sections)`` and stores the KL.

        ``return_mask`` is accepted but not used.
        """
        result, len_loss, sections = self.encoder(x)
        self.kl = self.encoder.kl

        return result, len_loss, sections


    def decode(self, x, length, z=None):
        """Stretch latent sequence ``x`` back to ``length`` steps and decode it.

        Args:
            x: latent sequence; ``x[i, :, -1]`` sets how many output steps each
                latent row covers.
            length: number of output time steps.
            z: optional input passed to the decoder as its ``skip_vec``.

        Returns the decoder output with steps beyond each sample's expanded
        length zeroed out.
        """

        expanded_z = T.zeros((x.size(0), length, self.emb_size), device=self.device)
        idxs=[0]*x.size(0)
        for i in range(x.size(0)):
            # Per-latent repeat counts in about (0.99, 100), capped at x.size(1); the
            # total is then limited to 1.2 * x.size(1) by max_norm.
            sizes = max_norm(T.clamp((1/(F.sigmoid(x[i,:,-1])+0.01)),max=x.size(1) ), x.size(1)*1.2)
            # idxs[i][t] = latent row copied into output step t (after int truncation).
            idxs[i] = T.repeat_interleave(sizes.int())[:length]
            expanded_z[i,:min(length,idxs[i].size(-1))] = x[i,idxs[i]][:length]
        # The triple-quoted block below is an older mask implementation kept as a
        # string literal; it is evaluated and discarded, never executed as code.
        '''
        #idxs = T.stack(idxs, dim=0)
        target_mask = T.full((x.size(0), x.size(-2),x.size(-2)), -float('inf'), device=self.device)

        for i in range(x.size(0)): ###AAAAHHHH I HATE THISSS
          #print(T.amax(idxs[i]))i
          sizes = T.histc(idxs[i].float(), T.clamp(T.amax(idxs[i]),min=1)    ).int().tolist()

          xpos = 0
          ypos = sizes[0]

          for j in range(len(sizes)-1):
            target_mask[i,xpos:xpos+sizes[j],ypos:]=0
            ypos+=sizes[j+1]
            xpos+=sizes[j]
        target_mask = target_mask.mT

        '''

        result = self.decoder(self.proj_up(expanded_z), z)
        # Zero every step past the number of steps actually filled for each sample.
        mask = T.zeros((x.size(0), length), device=self.device).unsqueeze(-1)
        for i in range(x.size(0)):
            mask[i,:idxs[i].size(-1)]=1
        return mask * result




    def forward(self, x):
        """Reconstruct ``x`` ``(batch, seq, data_width)`` through encoder and decoder.

        The positionally encoded ``x`` is handed to the decoder as ``skip_vec``.
        Returns ``(masked_reconstruction, len_loss)`` and stores the KL in ``self.kl``.
        """
        z, len_loss, sections = self.encoder(x)
        self.kl = self.encoder.kl
        #print(T.mean(T.amax(sections,dim=-1)).item())

        expanded_z = T.zeros((x.size(0), x.size(1), self.emb_size), device=self.device)
        idxs=[0]*x.size(0)
        for i in range(x.size(0)):
            # Same stretching rule as decode(), with length = x.size(1).
            sizes = max_norm(T.clamp((1/(F.sigmoid(z[i,:,-1])+0.01)),max=x.size(1) ), x.size(1)*1.2)

            idxs[i] = T.repeat_interleave(sizes.int())[:x.size(1)]
            expanded_z[i,:min(x.size(1),idxs[i].size(-1))] = z[i,idxs[i]][:x.size(1)]
        # Unused older mask implementation kept as a string literal (not executed).
        '''
        #idxs = T.stack(idxs, dim=0)
        target_mask = T.full((x.size(0), x.size(-2),x.size(-2)), -float('inf'), device=self.device)

        for i in range(x.size(0)): ###AAAAHHHH I HATE THISSS
          sizes = T.histc(idxs[i].float(), T.clamp(T.amax(idxs[i]),min=1)   ).int().tolist()

          xpos = 0
          ypos = sizes[0]

          for j in range(len(sizes)-1):
            target_mask[i,xpos:xpos+sizes[j],ypos:]=0
            ypos+=sizes[j+1]
            xpos+=sizes[j]
        target_mask = target_mask.mT

        '''

        result = self.decoder(self.proj_up(expanded_z), self.decoder.poscoder(x))
        # Zero every step past the number of steps actually filled for each sample.
        mask = T.zeros(x.size()[:-1], device=self.device).unsqueeze(-1)
        for i in range(x.size(0)):
            mask[i,:idxs[i].size(-1)]=1
        return mask * result, len_loss


In [None]:
from torch.nn.attention import SDPBackend, sdpa_kernel

class LanguageModel(T.nn.Module):
    """Token-level wrapper around TVAT: embedding -> TVAT -> vocabulary logits.

    Args:
        data_width: embedding / model width.
        emb_size: TVAT latent width.
        vocab_size: number of token ids.
        temporal_division: TVAT target compression ratio.
        device: device the parameters are created on.
    """

    def __init__(self, data_width, emb_size, vocab_size=12800, temporal_division=8,  device='cpu'):
        super().__init__()

        self.device=device

        self.embedding = T.nn.Embedding(vocab_size, data_width, device=self.device)
        self.tvat = TVAT(data_width, emb_size = emb_size, temporal_division = temporal_division, device=self.device)
        self.exbedding = T.nn.Linear(data_width, vocab_size, device=self.device)
        self.kl=0
        # (A never-executed embedding/exbedding warm-up loop, ``for i in range(0)``,
        # and its unused Adam optimizer were removed; behaviour is unchanged.)

    def forward(self, x):
        """Reconstruct token ids ``x``; returns ``(logits, len_loss)``.

        Logits at positions where the decoder output is all-zero (padding past
        the expanded length) are masked to zero.
        """
        with sdpa_kernel(SDPBackend.MATH):
            decoded, loss = self.tvat(self.embedding(x))
            self.kl = self.tvat.kl
            mask = T.where(T.sum(T.abs(decoded), dim=-1,keepdim=True)==0,0,1)


            # technique to lessen repetitive outputs
            #hist_mask = T.zeros((x.size(1),x.size(1)),device=self.device)
            #for i in range(10):
            #    hist_mask = T.diagonal_scatter(hist_mask, T.ones(x.size(1)-i-1,device=self.device)/(2**i), offset=-1-i)
            #decoded -= T.sum(decoded.unsqueeze(-2).expand(-1,-1,x.size(1),-1) * (hist_mask.flip((1,)).reshape(1,x.size(1),x.size(1),1)), dim=-2)*0.5


            return mask*self.exbedding(decoded), loss

    def embed(self, x):
        """Encode token ids to compressed latents: ``(latents, len_loss, sections)``."""
        with sdpa_kernel(SDPBackend.MATH):
            result, len_loss, sections =  self.tvat.encode(self.embedding(x))
            self.kl = self.tvat.kl
            return result, len_loss, sections

    def exbed(self, x, len, correct=None):
        """Decode latents ``x`` to ``len`` steps of vocabulary logits."""
        with sdpa_kernel(SDPBackend.MATH):
            result = self.tvat.decode(x, len, z=correct)
            return self.exbedding(result)



In [None]:



# automatic procedural memory
import torch as T
import torch.nn.functional as F
#import tensordict
import math, random

## altered from mitm4 to use a transformer as worth estimator.

@T.no_grad()
def pos_encoding(data_width, index):
    """Encode ``index`` as the signs of sinusoids at power-of-two frequencies.

    Returns ``2 * (data_width // 2)`` values: ``sign(sin(index * 2**k))`` for
    ``k = 0 .. data_width//2 - 1``, followed by the matching ``sign(cos(...))``.
    Every entry is -1, 0 or 1.
    """
    freqs = 2**T.arange(data_width//2)
    phases = index * freqs
    return T.cat((T.sign(T.sin(phases)), T.sign(T.cos(phases))), dim=-1)



def softswish(x): ## an activation function similar to swish.
    """Swish-like activation: a scaled softplus gated by softsign.

    Above x = 20 it switches to the linear asymptote ``x - 0.944031344``.
    """
    b = 0.2974953
    gated_softplus = (1/b) * T.log(1 + T.exp(b * x)) * (x / (1 + T.abs(x)))
    linear_tail = x - 0.944031344
    return T.where(x > 20, linear_tail, gated_softplus)

def tapered_swish(x):
    """Return ``softsign(x) * (softsign(x + 1) + 1)``, a bounded swish-like activation."""
    def _softsign(v):
        return v / (1 + T.abs(v))

    gate = _softsign(x + 1) + 1
    return _softsign(x) * gate



class SDAPM(T.nn.Module):

    def __init__(self, data_width, max_steps=10, mem_len=100, layer_gain=None, var_lr=False, param_mem_mats=False, autoscale=True, sparse_router='top 3', router_search = False, device = 'cpu'):
        """Set up the memory bank, its attention/worth models and the initial engrams.

        Args:
            data_width: width of the vectors the memory reads and writes.
            max_steps: number of routing steps (used by ``mitm_route``).
            mem_len: number of engrams (memory slots).
            layer_gain: scale of the initial permuted, sign-flipped identity
                matrices; defaults to ``1/data_width``.
            var_lr, sparse_router, router_search: NOTE(review): not read in this constructor.
            param_mem_mats: if true, the initial engram tensors become trainable
                ``nn.Parameter``s.
            autoscale: rescale each step's output to its target's range (see ``forward``).
            device: device for every tensor created here.
        """

        super().__init__()

        self.device = device
        self.data_width = data_width
        self.max_steps = max_steps
        self.mem_len = mem_len

        if layer_gain==None:
            layer_gain=1/data_width
        self.layer_gain=layer_gain

        self.memories=None  ## NOTE(review): reset() stores memories as a plain Python list of dicts (one per time step), not tensordict / nested tensors.
        self.autoscale=autoscale
        # Feature width budgeted per engram for the attention model; engram_info()
        # recomputes the same expression and zero-pads up to a multiple of num_attn_heads.
        atten_dat_size= data_width*9+1 + 2*int(math.ceil(math.sqrt(mem_len)))

        self.num_attn_heads=6
        # d_model is rounded up to a multiple of num_attn_heads so the heads divide it.
        self.attention_model = T.nn.TransformerEncoder(T.nn.TransformerEncoderLayer(math.ceil(atten_dat_size/self.num_attn_heads)*self.num_attn_heads, self.num_attn_heads, dim_feedforward=atten_dat_size, batch_first=True, device=self.device), 4)

        #self.worth_model = T.nn.LSTM(data_width*2, data_width*2*4, num_layers=3,  batch_first=True, device=self.device)
        # LSTM over cat(x, correct) that estimates per-step worth in forward();
        # its (h, c) state is created in reset().
        self.worth_model = T.nn.LSTM(data_width*2, data_width*2*4, num_layers=3, batch_first=True, device=self.device)
        self.hx_cx = None

        # Initial engram matrices: layer_gain * I plus uniform noise in [-0.001, 0.001),
        # then (below) a random row permutation and random element-wise sign flips.
        self.initial_mem_mats = layer_gain*T.eye(self.data_width, device=self.device).expand(mem_len,-1,-1)+0.001*(2*T.rand((mem_len, data_width,data_width), device=self.device)-1)

        for i in range(self.mem_len):
            permutations = T.randperm(self.data_width)
            self.initial_mem_mats[i] = self.initial_mem_mats[i,permutations]
            self.initial_mem_mats[i]*=T.sign(T.rand((self.data_width,self.data_width),device=self.device)-0.5)

        # Per-engram input bias (bv), output bias (bw), worth score and write-time code.
        # The write-time width ceil(sqrt(mem_len)) is the width pos_encoding() is meant to produce.
        self.initial_bv = T.zeros((mem_len,data_width), device=self.device)
        self.initial_bw = T.zeros((mem_len,data_width), device=self.device)
        self.initial_worth = T.zeros((mem_len,1), device=self.device)
        self.initial_write_time = T.zeros((mem_len, int(math.ceil(math.sqrt(mem_len)))), device=self.device)

        if(param_mem_mats):
            self.initial_mem_mats = T.nn.Parameter(self.initial_mem_mats)
            self.initial_bv = T.nn.Parameter(self.initial_bv)
            self.initial_bw = T.nn.Parameter(self.initial_bw)
            self.initial_worth = T.nn.Parameter(self.initial_worth)
            self.initial_write_time=T.nn.Parameter(self.initial_write_time)
        #self.initial_mem_mats[::2]=T.round(T.rand((mem_len//2,data_width,data_width),device=self.device)-0.5)

        #self.initial_mem_mats[0]=T.eye(self.data_width,device=self.device)
        #self.initial_mem_mats[1] = T.flip(T.eye(data_width,device=self.device),dims=(-1,))
        #self.initial_bv[0] = T.zeros(data_width,device=self.device)
        #self.initial_bw[0] = T.zeros(data_width,device=self.device)


        self.worth_mult=1  # NOTE(review): not read in the visible part of this class.

        self.attn_activation = (lambda x: x)#self.deNaNed(F.normalize(x**3,dim=-2)))# T.where(x.abs()-x.abs().mean(dim=-2,keepdim=True)>0, x, 0))


        #self.activation = (lambda x: x)#x/(1+T.abs(x)))
        #self.inv_activation = (lambda x: x)#T.clamp(x,min=-0.999,max=0.999)/(1-T.abs(T.clamp(x,min=-0.999,max=0.999))))

        #self.activation = (lambda x: x/(1+T.abs(x)))
        #self.inv_activation = (lambda x: T.clamp(x,min=-0.999,max=0.999)/(1-T.abs(T.clamp(x,min=-0.999,max=0.999))))

        # Inverse softsign y/(1-|y|), with y clamped just inside (-1, 1); applied to
        # the worth LSTM outputs in forward().
        self.lstm_activation = (lambda x: T.clamp(x,min=-0.99999,max=0.99999)/(1-T.abs(T.clamp(x,min=-0.99999,max=0.99999))))

        #self.activation = (lambda x: T.sign(x)*T.log(T.abs(x)+1))
        #self.inv_activation = (lambda x: T.sign(x)*(T.exp(T.abs(x))-1))

        #self.activation = softswish # self-normalizing (ish)

        # Non-linearity applied to each engram's (x - bv) @ m product in mitm_route().
        self.activation = tapered_swish

        self.just_reset=True
        # Index into self.memories: zeroed by reset(), incremented once per forward() step.
        self.steps_since_reset=0




    def engram_info(self):
        bv = self.memories[self.steps_since_reset]['bv']
        bw = self.memories[self.steps_since_reset]['bw']
        worth = self.memories[self.steps_since_reset]['w']
        diag = T.diagonal(self.memories[self.steps_since_reset]['m'],dim1=-1,dim2=-2)
        antidiag = T.diagonal(T.flip(self.memories[self.steps_since_reset]['m'],dims=(-1,)),dim1=-1,dim2=-2)
        top = self.memories[self.steps_since_reset]['m'][:,:,:,0]
        bottom = self.memories[self.steps_since_reset]['m'][:,:,:,-1]
        left = self.memories[self.steps_since_reset]['m'][:,:,0,:]
        right = self.memories[self.steps_since_reset]['m'][:,:,-1,:]
        write_time = self.memories[self.steps_since_reset]['t'].to(self.device)
        current_time = self.pos_encoding(self.steps_since_reset).unsqueeze(0).unsqueeze(-2).expand(bv.size(0), bv.size(1),-1).to(self.device)

        atten_dat_size= self.data_width*9+1 + 2*int(math.ceil(math.sqrt(self.mem_len)))
        zero_padding = T.zeros((self.memories[self.steps_since_reset]['m'].size(0), self.mem_len, math.ceil(atten_dat_size/self.num_attn_heads)*self.num_attn_heads - atten_dat_size), device=self.device)


        return T.cat((bv, bw, worth, diag, antidiag, top, bottom, left, right, write_time, current_time, zero_padding), dim=-1)


    def reset(self, x):
        self.memories=[{
                'm':self.initial_mem_mats.unsqueeze(0).expand(x.size(0),-1,-1,-1).clone(),
                'bv':self.initial_bv.unsqueeze(0).expand(x.size(0),-1,-1).clone(),
                'bw':self.initial_bw.unsqueeze(0).expand(x.size(0),-1,-1).clone(),
                'w':self.initial_worth.unsqueeze(0).expand(x.size(0),-1,-1).clone(),
                't':self.initial_write_time.unsqueeze(0).expand(x.size(0),-1,-1).clone()
                }]
        temp = {
            'm':T.empty((x.size(0), self.mem_len, self.data_width, self.data_width), device=self.device),
            'bv':T.empty((x.size(0), self.mem_len, self.data_width), device=self.device),
            'bw':T.empty((x.size(0), self.mem_len, self.data_width), device=self.device),
            'w':T.empty((x.size(0), self.mem_len, 1), device=self.device),
            't':T.empty((x.size(0), self.mem_len, int(math.ceil(math.sqrt(self.mem_len)))), device=self.device)
            }
        self.memories= self.memories*2 + [temp]*(x.size(-2)+1)
        self.steps_since_reset=0
        self.hx_cx = (T.zeros((3,x.size(0),self.data_width*2*4), device=self.device), T.zeros((3,x.size(0),self.data_width*2*4), device=self.device))

    def deNaNed(self, x):
        if(isinstance(x, dict)):
            for key in x.keys():

                x[key] = T.where(x[key].isnan(), T.rand(x[key].size(),device=self.device)*0.02-0.01, x[key])
                x[key] = T.where(x[key].isposinf(), 1_000_000, x[key])
                x[key] = T.where(x[key].isneginf(), -1_000_000, x[key])
            return x
        else:
            x = T.where(x.isnan(), T.rand(x.size(),device=self.device)*0.02-0.01, x)
            x = T.where(x.isposinf(), 1_000_000, x)
            x = T.where(x.isneginf(), -1_000_000, x)
            return x

    # adapted from https://en.wikipedia.org/wiki/Gray_code#Converting_to_and_from_Gray_code
    @T.no_grad()
    def pos_encoding(self, index):
        data_width = int(math.ceil(math.sqrt(self.mem_len)))
        mult = 2**(T.arange(data_width//2, device=self.device)-2)
        sins = T.sign(T.sin(index * mult))
        coss = T.sign(T.cos(index * mult))

        return T.cat((sins,coss),dim=-1)



    def sparsity(self, x):  ## uses the Hoyer measure. Gini would be invariant under cloning, but Gini requires
                            ## that the data is sorted.
        return ((self.data_width**0.5) - (T.sum(x,dim=-1)/T.sum(x**2,dim=-1 )))*(((self.data_width**0.5)-1)**-1)


    def forward(self,x, correct=None, reset=True):
        """Run the memory over ``x`` one sequence position at a time.

        Args:
            x: input sequence; position ``i`` is ``x[:, i]``.
            correct: optional targets shaped like ``x``. When given, the memory
                learns from them and auxiliary losses are returned.
            reset: start a fresh memory timeline (``self.reset(x)``) first.

        Returns:
            With ``correct``: ``(results, routing_loss, worth_loss)``. Here
            ``routing_loss`` is the mean of the per-step routing losses from
            ``read``, and ``worth_loss`` is the MSE between the worth LSTM's
            estimate and the log per-step inaccuracy.
            Without ``correct``: ``results`` only.
        """
        if(reset):
            self.reset(x)

        if correct is not None:
            loops = x.size(-2)
            results = T.empty(x.size(), device=self.device)
            # One fresh entry per step. ``[T.empty(1)] * loops`` aliased a single
            # uninitialised tensor, so ``+=`` piled every step's loss onto garbage.
            routing_losses = []

            for i in range(loops):
                self.steps_since_reset+=1
                ## changes to main memory are returned as a tuple of (engrams, memory_mask).
                ## Engrams & memory mask have the same length as memory, and memory_mask
                ## has masks for each key/attribute of the engram. Once all changes are decided,
                ## the memory masks are multiplied by the engram, summed together,  and then
                ## LERP'ed into the next memory step.

                result, routing_loss, read_mem = self.read(x[:,i], correct[:,i])

                if(self.autoscale): ## rescale to size of correct output
                    result = self._autoscale_to(result, correct[:,i])

                results[:,i]=result
                # zeros(1) + loss keeps the (1,)-shaped entries the old code produced.
                routing_losses.append(T.zeros(1, device=self.device) + routing_loss)

                consolidate_mem = self.consolidate(5,2) ## merge some memories together.

                inaccuracy = T.mean(F.mse_loss(result, correct[:,i], reduction='none'), dim=-1,keepdim=True) ## TODO: add hook

                memorize_mem = self.memorize(x[:,i],  correct[:,i], log_inaccuracy=T.log(inaccuracy+1e-6))

                self._commit_memory(read_mem, consolidate_mem, memorize_mem)

            inaccuracies = T.sum(F.mse_loss(results, correct, reduction='none'), dim=-1) ## TODO: add hook so the inaccuracies can accurately represent the error of any supermodels.

            est_worth, self.hx_cx = self.worth_model(T.cat((x,correct), dim=-1), self.hx_cx)
            estimated_worth = T.sum(self.lstm_activation(est_worth), dim=-1)
            loss_estimated_worth = F.mse_loss(T.log(inaccuracies+1e-6), estimated_worth)

            routing_losses = T.mean(T.stack(routing_losses, dim=0))

            return self.deNaNed(results), routing_losses, loss_estimated_worth

        self.input=x
        self.step_index=0
        loops = x.size(-2)
        results = T.empty(x.size(), device=self.device)
        for i in range(loops):
            self.steps_since_reset+=1

            result,log_inaccuracy, read_mem = self.read(x[:,i])
            if(self.autoscale): ## rescale to size of the input
                result = self._autoscale_to(result, x[:,i])

            results[:,i]=result
            consolidate_mem = self.consolidate(5,2)

            memorize_mem = self.memorize(x[:,i], result, log_inaccuracy = log_inaccuracy)

            self._commit_memory(read_mem, consolidate_mem, memorize_mem)

        return results

    def _autoscale_to(self, result, reference):
        """Rescale ``result`` to the value range of ``reference`` (the ``autoscale`` option)."""
        shift = T.mean(reference, dim=-1, keepdim=True) - T.mean(result, dim=-1, keepdim=True)
        scale = T.amax(reference,dim=-1, keepdim=True) - T.amin(reference,dim=-1, keepdim=True)
        # NOTE(review): subtracting ``shift`` moves result away from reference's mean;
        # ``+ shift`` may have been intended. Kept as-is to preserve behaviour.
        return F.normalize(result-shift,dim=-1, p=float('inf'))*scale

    def _commit_memory(self, read_mem, consolidate_mem, memorize_mem):
        """Blend the three proposed memory edits into the next timeline slot.

        Each proposal is ``(values, masks)`` keyed like a memory dict. Values are
        weighted by their share of the summed mask. Whatever mask weight is left
        below 1 keeps the current slot's value. The result is stored at
        ``self.memories[self.steps_since_reset + 1]``.
        """
        current = self.memories[self.steps_since_reset]
        proposals = (read_mem, consolidate_mem, memorize_mem)
        new_mem = {}
        for key in current.keys():
            mask_sum = T.clamp(read_mem[1][key] + consolidate_mem[1][key] + memorize_mem[1][key], min=1e-6)
            blended = sum(values[key] * (masks[key]/mask_sum) for values, masks in proposals)
            new_mem[key] = self.deNaNed(blended + current[key] * T.clamp(1-mask_sum, min=0))
        self.memories[self.steps_since_reset+1]=new_mem






    @T.inference_mode()
    def mitm_route(self, x, output):
        """Build a per-step routing schedule by meeting in the middle between ``x`` and ``output``.

        The loop alternates two kinds of step:

        * Even ``i`` (forward): the attention model scores every engram from
          ``self.engram_info()`` and the current forward estimate. The estimate is
          then advanced by the routed engram outputs plus itself.
        * Odd ``i`` (backward): a target is moved from the current backward vector
          toward the forward estimate by ``1/(max_steps - i)`` of the gap
          (``T.lerp``). A least-squares solve then finds the route whose engram
          outputs, evaluated at that moved target, best reproduce the previous
          backward vector.

        Forward steps fill ``routing[0], routing[1], ...`` from the front. Backward
        steps fill from the end, because ``-i//2`` parses as ``(-i)//2`` and gives
        -1, -2, .... Together they cover every slot, and the last slot is the
        route that maps into ``output``.

        Args:
            x: input batch; ``x.size(0)`` is used as the batch size (presumably
                (batch, data_width) -- TODO confirm).
            output: target batch, presumably laid out like ``x``.

        Returns:
            ``T.stack(routing, dim=-3)``. Routes start as (batch, mem_len, 1), so the
            result is presumably (batch, max_steps, mem_len, 1).
        """

        # y_{i} = g((x - bv_{i}) @ m_{i}) + bw_{i}
        # y = sum( y_{i} * a_{i} )

        forward_vec = x.detach()
        backward_vec = output.detach().unsqueeze(-2)
        # One route per step, initialised to zeros; every slot is overwritten in the loop below.
        routing = [T.zeros((x.size(0), self.mem_len,1), device=self.device) for i in range(self.max_steps)]

        # NOTE(review): `worth` is never read in this method.
        worth = T.zeros((x.size(0), self.mem_len,1), device=self.device)


        ## forward half is neural router. backward half is perfect routing.
        ## backwards goal = lerp(backward_vector, forward_vector, 1/(max_steps-step)) <-- this means backwards routing will move the same distance each step.
        for i in range(self.max_steps):
            if(i%2==1):
                #backward_vec_candidates =  T.linalg.solve(self.memories[-1]['m']+T.eye(self.data_width,device=self.device) * 1e-6, self.inv_activation(backward_vec-self.memories[-1]['bw']).unsqueeze(-2), left=False).squeeze(-2)

                # Move the backward target 1/(max_steps-i) of the way toward the forward estimate.
                lerp_forward = T.lerp(backward_vec.squeeze(-2), forward_vec, 1/(self.max_steps-i))


                # Every engram applied to lerp_forward, plus lerp_forward itself -> one row per engram.
                matrix_results = self.activation(T.matmul((lerp_forward.unsqueeze(-2).expand(-1,self.mem_len,-1)-self.memories[self.steps_since_reset]['bv']).unsqueeze(-2), self.memories[self.steps_since_reset]['m'])).squeeze(-2) + self.memories[self.steps_since_reset]['bw'] + lerp_forward.unsqueeze(-2)
                # Least-squares route: solves matrix_results.mT @ route ~= backward_vec.mT.
                backwards_route =T.linalg.lstsq(matrix_results.mT, backward_vec.mT).solution


                ## Theory: A high sparsity in the backwards route indicates a clear solution to the problem.
                ## If there is a clear solution to the problem, the number of steps needed to complete the solution
                ## should be low, so therefore be stopped early to prevent data degradation.

                #print(self.sparsity(backwards_route[:,:,0])[0])

                ## This isn't strictly necessary, but it pushes the router to activate when forward vec is approx. equal to bv

                # Per-engram preference w = cos(forward_vec, bv) + cos(current position encoding, t) + softsign(stored worth).
                softsign =(lambda x_: x_/(1+T.abs(x_)))
                w = ((F.cosine_similarity(forward_vec.unsqueeze(-2), self.memories[self.steps_since_reset]['bv'], dim=-1).unsqueeze(-1) + \
                      F.cosine_similarity(self.pos_encoding(self.steps_since_reset).unsqueeze(0).unsqueeze(0), self.memories[self.steps_since_reset]['t'], dim=-1).unsqueeze(-1)) + \
                      softsign(self.memories[self.steps_since_reset]['w']))
                #w=T.where(backwards_route.abs()>0.5, 1,-backwards_route)
                # With A = matrix_results.mT, lstsq(A, A) is the projector A^+ A onto A's row space, so
                # (I - A^+ A) @ w lies in A's null space. Adding it steers the route toward w without
                # changing A @ route, i.e. the reproduced backward vector (up to numerical error).
                offset = (T.eye(self.mem_len,device=self.device).unsqueeze(0)-T.linalg.lstsq(matrix_results.mT, matrix_results.mT).solution)@w
                backwards_route +=offset

                #print("offset+route",self.sparsity(backwards_route[:,:,0])[0])
                #backwards_route *=0
                #backwards_route[:,1]=1
                # (-i)//2 -> -1, -2, ...: backward routes are written from the end of the schedule.
                routing[-i//2]=backwards_route

                '''
                matrix_sum = T.sum(backwards_route.unsqueeze(-1)*self.memories[-1]['m'], dim=-3)
                bv_sum = T.sum(backwards_route*self.memories[-1]['bv'],dim=-2,keepdim=True)
                bw_sum = T.sum(backwards_route*self.memories[-1]['bw'],dim=-2,keepdim=True)
                '''
                #matrix_results = self.activation(T.matmul((lerp_forward.unsqueeze(-2).expand(-1,self.mem_len,-1)-self.memories[self.steps_since_reset]['bv']).unsqueeze(-2), self.memories[self.steps_since_reset]['m'])).squeeze(-2) + self.memories[self.steps_since_reset]['bw']
                #print( (T.sum(backwards_route*matrix_results,dim=-2) - backward_vec[:,0])[0])

                # The moved target becomes the output that the next backward step must reproduce.
                backward_vec = lerp_forward.unsqueeze(-2)#F.normalize(T.sum(backward_vec_candidates*backwards_route, dim=-2, keepdim=True), dim=(-1,), p=float('inf'))
            else:
                # Forward step: the attention model scores each engram for the current forward estimate.
                attention_dat = T.cat((self.engram_info(),forward_vec.unsqueeze(-2).expand(-1,self.mem_len,-1)),dim=-1)
                routing[i//2]= self.attn_activation(T.sum(self.attention_model(attention_dat), dim=-1, keepdim=True))
                matrix_results = self.activation(T.matmul((forward_vec.unsqueeze(-2).expand(-1,self.mem_len,-1)-self.memories[self.steps_since_reset]['bv']).unsqueeze(-2), self.memories[self.steps_since_reset]['m'])).squeeze(-2) + self.memories[self.steps_since_reset]['bw']
                #matrix_sum = T.sum(routing[i//2].unsqueeze(-1)*self.memories[-1]['m'], dim=-3)
                #bv_sum = T.sum(routing[i//2]*self.memories[-1]['bv'],dim=-2,keepdim=True)
                #bw_sum = T.sum(routing[i//2]*self.memories[-1]['bw'],dim=-2,keepdim=True)
                ## I'd like to use residual connections, but they don't work with the backward routing.
                # NOTE(review): the update below does add forward_vec back (a residual step), so the
                # comment above may be stale.
                forward_vec = T.sum(matrix_results*routing[i//2], dim=-2) + forward_vec
                #forward_vec = F.normalize(self.activation(T.sum(T.bmm((forward_vec.unsqueeze(-2) -bv_sum), matrix_sum)) ) + bw_sum, dim=(-1,), p=float('inf')).squeeze(-2)




            #print("Backward", backward_vec[0])
        #self.debug_vectors[:,-1]= output
        return T.stack(routing, dim=-3)








    def read(self, x, correct=None):
        """Refine ``x`` through ``self.max_steps`` routed, residual engram steps.

        At every step, each of the ``self.mem_len`` engrams in the current memory
        state is applied to the running estimate ``r``::

            y_j = activation((r - bv_j) @ m_j) + bw_j
            r   = r + sum_j(route_j * y_j)

        Args:
            x: input batch; ``x.size(0)`` is treated as the batch size (presumably
                (batch, data_width) -- TODO confirm).
            correct: optional target for ``x``. When given, the per-step routes come
                from ``self.mitm_route(x, correct)``. In training mode, the attention
                router is also regressed (MSE) onto the routes of steps
                ``i >= max_steps // 2``. When ``None``, the attention router picks the
                routes and ``self.worth_model`` is called once, updating ``self.hx_cx``.

        Returns:
            ``(result, aux, (read_mem_vals, read_mem_mask))``.

            * ``result`` is the estimate after the last step.
            * ``aux`` is the accumulated routing MSE (target given) or the worth
              estimate (no target).
            * The mask enables only ``'w'``. Its value is the current worth minus
              ``worth_mult * d_worth``, where ``d_worth`` is each engram's total
              routing weight times log-inaccuracy (target) or the worth estimate
              (no target), divided by ``max_steps``.
        """

        def engram_outputs(r):
            # Every engram applied to r -> one output row per engram.
            mem = self.memories[self.steps_since_reset]
            shifted = r.unsqueeze(-2).expand(-1, self.mem_len, -1) - mem['bv']
            return self.activation(T.matmul(shifted.unsqueeze(-2), mem['m'])).squeeze(-2) + mem['bw']

        def attention_routes(r):
            # The neural router's weight for every engram, given the estimate r.
            attention_dat = T.cat((self.engram_info(), r.unsqueeze(-2).expand(-1, self.mem_len, -1)), dim=-1)
            return self.attn_activation(T.sum(self.attention_model(attention_dat), dim=-1)).unsqueeze(-1)

        def worth_update(d_worth):
            # Only 'w' is written back (mask 1); every other memory entry is masked out.
            read_mem_vals = {
                'm': 0,
                'bv': 0,
                'bw': 0,
                'w': self.memories[self.steps_since_reset]['w'] - self.worth_mult * d_worth,
                't': 0,
            }
            read_mem_mask = {'m': 0, 'bv': 0, 'bw': 0, 'w': 1, 't': 0}
            return read_mem_vals, read_mem_mask

        # result[i] is the estimate entering step i; result[0] is x itself.
        result = [x] * (self.max_steps + 1)

        if correct is not None:
            # mitm_route runs under inference mode; clone so its routes can be used in
            # ordinary autograd-tracked ops below.
            routing = self.mitm_route(x, correct).clone()
            routing_loss = T.zeros(1, device=self.device)
            for i in range(self.max_steps):
                result[i + 1] = T.sum(routing[:, i] * engram_outputs(result[i]), dim=-2) + result[i]

                if self.training and i >= self.max_steps // 2:
                    routing_loss += F.mse_loss(routing[:, i], attention_routes(result[i]))

            inaccuracy = T.sum(F.mse_loss(result[-1], correct, reduction='none'), dim=-1, keepdim=True)
            d_worth = T.sum(routing, dim=1) * T.log(inaccuracy + 1e-6).unsqueeze(-2).expand(-1, self.mem_len, -1) / self.max_steps
            return result[-1], routing_loss, worth_update(d_worth)

        # No target: the attention router chooses every step's routes.
        routes = []
        for i in range(self.max_steps):
            routes.append(attention_routes(result[i]))
            result[i + 1] = T.sum(routes[i] * engram_outputs(result[i]), dim=-2) + result[i]

        ## Unfortunately, there's no pretty way to incrementally predict the worth of each output using a transformer,
        ## so a stateful worth model is stepped once per read; its state is threaded through self.hx_cx.
        worth_est, self.hx_cx = self.worth_model(T.cat((x, result[-1]), dim=-1).unsqueeze(-2).detach(), self.hx_cx)
        worth_est = T.sum(self.lstm_activation(worth_est), dim=-1)
        d_worth = T.sum(T.stack(routes, dim=1), dim=1) * worth_est.unsqueeze(-2).expand(-1, self.mem_len, -1) / self.max_steps
        return result[-1], worth_est, worth_update(d_worth)



    def memorize(self, x, y, log_inaccuracy=None):
        """Propose writing the pair ``x -> y`` into each batch row's lowest-worth engram slot.

        Deliberately simple; memory consolidation is done separately (see ``consolidate``).

        Args:
            x: input batch; ``x.size(0)`` is the batch size.
            y: output batch for ``x``; L2-normalized along the last dim like ``x``.
            log_inaccuracy: optional per-row score. The new engram's worth is its
                negation. When ``None``, the new engram gets the median of the current
                worths in its row.

        Returns:
            ``(memorize_mem_vals, memorize_mem_mask)``. The values hold an identity
            matrix ``m``, ``bv``/``bw`` = normalized ``x``/``y``, the worth above, and
            the current positional encoding as ``t``. Every mask is 1 only at the
            slot being overwritten.
        """
        # Store unit-length copies so the stored biases share one scale.
        x = F.normalize(x, dim=-1)
        y = F.normalize(y, dim=-1)

        memory = self.memories[self.steps_since_reset]

        # The least valuable slot of each batch row is the one that gets replaced.
        min_worth_idx = T.min(memory['w'].squeeze(-1), dim=-1)[1]
        idx = T.arange(memory['m'].size(0), device=self.device)

        if log_inaccuracy is not None:
            new_worth = -log_inaccuracy.unsqueeze(-2)
        else:
            new_worth = T.median(memory['w'].squeeze(-1), dim=-1)[0].unsqueeze(-1).unsqueeze(-2)

        memorize_mem_vals = {
            'm': T.eye(self.data_width, device=self.device),
            'bv': x.unsqueeze(-2),
            'bw': y.unsqueeze(-2),
            'w': new_worth,
            't': self.pos_encoding(self.steps_since_reset).to(self.device).unsqueeze(-2).expand(x.size(0), -1, -1),
        }

        mask = T.zeros((x.size(0), self.mem_len, 1), device=self.device)
        mask[idx, min_worth_idx] = 1.0

        memorize_mem_mask = {'m': mask.unsqueeze(-1), 'bv': mask, 'bw': mask, 'w': mask, 't': mask}
        return (memorize_mem_vals, memorize_mem_mask)

    def consolidate(self, num_k=5, loops=1):
        """Merge groups of similar engrams into single least-squares engrams.

        Engrams are compared by the cosine similarity of their bias shift
        ``bw - bv``; engrams whose matrix ``m`` is all zeros get a zero shift. In
        each of ``loops`` rounds, every batch row:

        1. picks the engram with the strongest and most distinctive similarity,
        2. gathers it with its most similar neighbours (``num_k`` engrams in total),
        3. replaces it with one engram fitted to that group:

           * ``m``: least-squares map from the group's centred ``bv`` to its centred ``bw``;
           * ``bv``/``bw``: the group means;
           * ``t``: the current positional encoding.

        Args:
            num_k: group size; must not exceed ``self.mem_len`` (``T.topk`` requirement).
            loops: number of consolidation rounds. A replaced engram's similarity row
                is set to -1, so it is not picked again.

        Returns:
            ``(consolidate_mem_vals, consolidate_mem_mask)``. The values are copies of
            ``m``/``bv``/``bw``/``t`` with the new engrams written in. The masks count
            how many rounds wrote each slot. ``'w'`` is untouched (value and mask 0).

        Raises:
            RuntimeError: if the least-squares fit fails; the offending inputs are printed first.
        """
        memory = self.memories[self.steps_since_reset]

        # Each engram's bias shift, zeroed for engrams whose matrix is all zero.
        delta_b = (memory['bw']-memory['bv'])* T.sign(T.sum(T.abs(memory['m']), dim=(-1,-2))).unsqueeze(-1)
        # Pairwise engram similarity, presumably (batch, mem_len, mem_len).
        cos_sim = F.cosine_similarity(delta_b.unsqueeze(-2), delta_b.unsqueeze(-2).transpose(-2,-3),dim=-1)
        cos_sim += T.rand(cos_sim.size(),device=self.device)*0.002-0.001 ## inject a little bit of randomness.

        # Exact zeros -> -1, and push the diagonal down so no engram counts as similar to itself.
        # NOTE(review): the noise above is added first, so exact zeros (e.g. from zeroed engrams)
        # are practically never seen here -- confirm the intended order.
        cos_sim = T.where(cos_sim.bool(), cos_sim,-1)-2*T.eye(self.mem_len, device=self.device).unsqueeze(0)

        consolidate_mem_vals={
            'm':memory['m'].clone().detach(),
            'bv':memory['bv'].clone().detach(),
            'bw':memory['bw'].clone().detach(),
            'w':0,
            't':memory['t'].clone().detach()
        }
        consolidate_mem_mask = {
            'm':T.zeros((memory['m'].size(0), self.mem_len,1,1), device=self.device),
            'bv':T.zeros((memory['bv'].size(0), self.mem_len,1), device=self.device),
            'bw':T.zeros((memory['bw'].size(0), self.mem_len,1), device=self.device),
            'w':0,
            't':T.zeros((memory['t'].size(0), self.mem_len,1), device=self.device)
            }

        idx = T.arange(memory['m'].size(0), device=self.device)
        rows = idx.unsqueeze(-1)  # (batch, 1): pairs every batch row with its own top-k indices
        eye = T.eye(self.mem_len, device=self.device).unsqueeze(0)

        for _ in range(loops):
            # Pick, per batch row, the engram whose best match is strong relative to its average match.
            relu_sim = F.relu(cos_sim)
            most_similarities = T.max(T.max(relu_sim,dim=-1)[0]/(1-T.mean(relu_sim+0.0001,dim=-1)), dim=-1)[1]
            # We want the index of the most similarities to be in the top k.
            # NOTE(review): the diagonal is not lowered again for later rounds (it is only clamped to +1).
            cos_sim += eye*2
            topk = T.topk(cos_sim[idx,most_similarities,:],num_k,dim=-1).indices

            # Gather each batch row's OWN group -> (batch, num_k, width). Indexing with
            # [:, topk][:, 0] would make every row reuse batch row 0's indices.
            group_bv = memory['bv'][rows, topk]
            group_bw = memory['bw'][rows, topk]

            # generate new engram from the centred group members
            bv_diff = self.deNaNed(group_bv-T.mean(group_bv,dim=-2,keepdim=True))
            bw_diff = self.deNaNed(group_bw-T.mean(group_bw,dim=-2,keepdim=True))

            # lazy way to prevent non-full rank issues
            bv_diff += T.rand(bv_diff.size(), device=self.device) * 0.0001
            bw_diff += T.rand(bw_diff.size(), device=self.device) * 0.0001

            try:
                generated_matrix = T.linalg.lstsq(bv_diff, bw_diff).solution
            except RuntimeError:
                # Show the offending inputs, then propagate instead of exiting the interpreter.
                print(bv_diff, bw_diff)
                raise
            new_bv = T.mean(group_bv,dim=-2)
            new_bw = T.mean(group_bw,dim=-2)

            # remove nans
            generated_matrix = T.where(generated_matrix.isnan(), 0.0001, generated_matrix)
            new_bv = T.where(new_bv.isnan(), 0, new_bv)
            new_bw = T.where(new_bw.isnan(), 0, new_bw)

            # replace the picked engram
            consolidate_mem_vals['m'][idx,most_similarities] = generated_matrix
            consolidate_mem_vals['bv'][idx,most_similarities] = new_bv
            consolidate_mem_vals['bw'][idx,most_similarities] = new_bw
            consolidate_mem_vals['t'][idx,most_similarities] = self.pos_encoding( self.steps_since_reset).to(self.device)

            consolidate_mem_mask['m'][idx, most_similarities]+=1
            consolidate_mem_mask['bv'][idx, most_similarities]+=1
            consolidate_mem_mask['bw'][idx, most_similarities]+=1
            consolidate_mem_mask['t'][idx, most_similarities]+=1

            # set the similarity of previous consolidated engram to -1.
            cos_sim[idx,most_similarities]*=0
            cos_sim[idx,most_similarities]-=1
            cos_sim = T.clamp(cos_sim, min=-1, max=1)

        return (consolidate_mem_vals, consolidate_mem_mask)


In [None]:


import torch_optimizer as optim

# Drop models/optimizers left over from a previous run of this cell, then return
# cached GPU memory before building new ones.
sdapm=None
vae=None
loss=None
vae_opt=None
sdapm_opt=None
import gc
gc.collect()
T.cuda.empty_cache()
gc.collect()


batch_size=4
emb_size = 768
sdapm_emb_size = 60
epoch=0
iter=0  # NOTE: shadows the builtin iter(); kept because the training loop uses this name.
count=0
device='cuda'

# Resume the VAE from the newest 'TVAT*' checkpoint in the working directory, or build
# a fresh one. The training loop saves TVAT_<unix time>.pt, so the lexicographically
# last name is the newest as long as the timestamps have the same number of digits.
checkpoints = list(filter(lambda x: x.startswith('TVAT'), os.listdir()))
if not checkpoints:
    vae = LanguageModel(emb_size, sdapm_emb_size, vocab_size, temporal_division=4, device=device)
    vae.tvat.encoder.expansion_offset=0
else:
    print("loading", sorted(checkpoints)[-1])
    # weights_only=False unpickles arbitrary Python objects: only load checkpoints
    # this notebook wrote itself.
    vae = T.load(sorted(checkpoints)[-1], weights_only=False)
    vae.tvat.encoder.temporal_division=4

# SDAPM is rebuilt from code; only its state_dict is restored.
checkpoints = list(filter(lambda x: x.startswith('sdapm'), os.listdir()))
sdapm = SDAPM(sdapm_emb_size, mem_len=128, device=device, param_mem_mats=True)
if checkpoints:
    print("loading", sorted(checkpoints)[-1])
    sdapm.load_state_dict(T.load(sorted(checkpoints)[-1]) )


vae_opt = optim.DiffGrad(vae.parameters())
sdapm_opt = optim.DiffGrad(sdapm.parameters())

def softclamp(x, max):
  """Smoothly squash tensor ``x`` toward the range (-max, max) for positive ``max``.

  Computes ``max * softsign(x / max)`` with ``softsign(u) = u / (1 + |u|)``.
  It has slope 1 at zero and approaches ``+-max`` as ``|x|`` grows.
  """
  scaled = x/max
  return max*(scaled/(1+T.abs(scaled)))


# torch.compile experiments, currently disabled:
#sdapm=T.compile(sdapm, backend='cudagraphs', dynamic=True)
#T._dynamo.config.capture_scalar_outputs = True
#vae=T.compile(vae, backend='cudagraphs')


# A single optimizer over the VAE and SDAPM parameters together, used by the
# joint ("sdapm_only+vae") training branch.
vae_sdapm_opt = optim.DiffGrad(
    list(vae.parameters()) + list(sdapm.parameters()),
    lr=5e-3,
)

max_grad_norm=100
def grad_norm(grad):
    """Gradient hook: clip over-large gradients row by row.

    Does nothing unless some element of ``grad`` exceeds ``max_grad_norm`` in
    magnitude. When triggered, every row (last dimension) whose L2 norm exceeds
    ``max_grad_norm`` is scaled down to exactly that norm; rows already within
    the limit are left unchanged.
    """
    if T.any(grad.abs() > max_grad_norm):
        # Scale only the rows above the limit. A plain F.normalize(grad) * max_grad_norm
        # would also stretch small, healthy rows up to norm max_grad_norm.
        row_norms = T.linalg.vector_norm(grad, dim=-1, keepdim=True)
        return grad * T.clamp(max_grad_norm / row_norms, max=1.0)
    return grad

# Token id of the tokenizer's "`PAD`" token. The cross-entropy losses below pass it
# as ignore_index, so padded positions do not contribute.
padding_idx = tokenizer.token_to_id("`PAD`")


def max_norm(x, max):
  """Cap the L1 norm of each last-dim row of ``x`` at ``max``.

  Rows whose L1 norm exceeds ``max`` are rescaled to L1 norm ``max``; other
  rows are returned unchanged. The check uses each row's own L1 norm, not the
  signed sum of the whole tensor: a signed sum can be negative for huge
  vectors, and a single global test would rescale every row at once.
  """
  row_l1 = T.linalg.vector_norm(x, ord=1, dim=-1, keepdim=True)
  return T.where(row_l1 > max, F.normalize(x, p=1, dim=-1)*max, x)



for p in sdapm.parameters():
    p.register_hook(grad_norm)
for p in vae.parameters():
    p.register_hook(grad_norm)

import time

print(sdapm.initial_mem_mats[0,7:,7:])
sdapm_loss=T.zeros(1)
vae_loss = T.zeros(1)

while True:

  # Advance the sample cursor and wrap it. The parentheses matter: `%` binds tighter
  # than `+`, so without them the cursor never wraps and eventually runs past the end
  # of the dataset, producing empty batches.
  count = (count+batch_size) % (len(dataset)-batch_size)
  data_batch = T.nn.utils.rnn.pad_sequence([T.tensor(data.ids) for data in dataset[count:count+batch_size]], batch_first=True, padding_value=tokenizer.token_to_id("`PAD`"), padding_side='right').int().to(device)[:,:128]


  if iter < 0:  # VAE-only phase; disabled while the threshold is 0 (iter starts at 0).
    print("vae_only")

    reconstruction, len_loss = vae(data_batch)
    loss = F.cross_entropy(reconstruction.transpose(-1,-2), data_batch.long(), ignore_index = padding_idx) + vae.kl/emb_size/batch_size + T.mean(len_loss)
    loss.backward()
    vae_opt.step()
    if(iter%100==0):
      print("---===+++ EXAMPLE RECONSTRUCTION +++===---\n"+tokenizer.decode(T.argmax(reconstruction[0], dim=-1).tolist(), skip_special_tokens=False)+"\n---===+++#############+++===---")
  elif(True or iter%100>=66):  # `True or` forces this joint branch on every iteration.
    print("sdapm_only+vae")

    embedded, len_loss, _ = vae.embed(data_batch[:,:-1])


    #print(T.histc(T.floor(sections)[:,-1])sizes
    z = sdapm(embedded)
    with T.no_grad():
      # NOTE(review): data_batch is (batch, seq), so size(-2) is the batch size, while
      # exbed below receives size(-1) (the sequence length) -- confirm which one
      # generate_mask expects.
      _, sizes = vae.tvat.generate_mask(embedded.detach(), data_batch.size(-2), return_sizes=True)
      # Target per row: the row with its first sizes[i][0] tokens dropped, right-padded
      # back to the input length.
      correct_results = [data_batch[i,sizes[i][0]:] for i in range(batch_size)]
      correct_result= T.nn.utils.rnn.pad_sequence(correct_results, batch_first=True, padding_value=tokenizer.token_to_id("`PAD`"), padding_side='right').int().to(device).detach()
      correct_result = T.cat((correct_result, T.full((batch_size, data_batch.size(-1)-correct_result.size(-1)),tokenizer.token_to_id("`PAD`"), device=device)), dim=-1)
    #ae.tvat.generate_mask(z, data_batch.size(-2))

    prediction= vae.exbed(z, data_batch.size(-1), correct=vae.embedding(correct_result))
    loss = F.cross_entropy(prediction.transpose(-1,-2), correct_result.long(), ignore_index = padding_idx) +T.mean(len_loss) + vae.kl/emb_size/batch_size

    loss.backward()
    #sdapm_opt.step()
    vae_sdapm_opt.step()
    #vae_opt.step()
    if(iter%100==99):
      print("---===+++ EXAMPLE RESULT +++===---\n"+tokenizer.decode(T.argmax(prediction[0], dim=-1).tolist(), skip_special_tokens=False)+"\n---===+++################+++===---")

  else:
    if(iter>100):
      #print("sdapm_only")
      enc_data = vae.embed(data_batch)[0].detach()
      #print(T.max(T.abs(enc_data)))
      prediction, r_loss, w_loss = sdapm(enc_data[:,:-1], enc_data[:,1:])

      #print(T.max(T.abs(prediction)))
      print(r_loss.item(), w_loss.item())
      sdapm_loss = F.mse_loss(prediction, enc_data[:,1:]) + softclamp(r_loss,1000) + softclamp(w_loss, 100)
      sdapm_loss.backward()
      sdapm_opt.step()

    reconstruction, len_loss = vae(data_batch)
    vae_loss = F.cross_entropy(reconstruction.transpose(-1,-2), data_batch.long(), ignore_index = padding_idx) + vae.kl/emb_size/batch_size + T.mean(len_loss)
    vae_loss.backward()
    vae_opt.step()
    if(iter%100==0):
      print("---===+++ EXAMPLE RECONSTRUCTION +++===---\n"+tokenizer.decode(T.argmax(reconstruction[0], dim=-1).tolist(), skip_special_tokens=False)+"\n---===+++#############+++===---")


    #print(tokenizer.decode(T.argmax(prediction[0], dim=-1).tolist()))

  iter+=1
  #vae_opt.zero_grad(set_to_none=True)
  #sdapm_opt.zero_grad(set_to_none=True)
  # vae_sdapm_opt covers every VAE and SDAPM parameter, so this clears all gradients.
  vae_sdapm_opt.zero_grad(set_to_none=True)

  if iter%30==29:
    T.cuda.empty_cache()
  gc.collect()
  # Each iteration consumes batch_size samples, so epochs are counted in samples, not iterations.
  # NOTE: "LOSS_SDAPM" reports `loss`, which only the first two branches assign.
  print("EPOCH "+str(math.floor(iter*batch_size/len(dataset))) +"\tLOSS_VAE: "+ str(vae_loss.item())+"\tLOSS_SDAPM: "+ str(loss.item()))
  vae.tvat.decoder.b_mix.set_probs(0.1**(1/((iter+1500000)*0.003))) ## slowly converges to 1, decreasing the influence of the input skip layer on the decoder.

  # NaN recovery: restore the newest SDAPM checkpoint.
  # NOTE(review): the DiffGrad optimizer state is not restored and may itself hold NaNs.
  if(T.any(T.isnan(sdapm.initial_mem_mats)) or T.any(T.isnan(sdapm.initial_bv)) or T.any(T.isnan(sdapm.initial_bw))):
    checkpoints = list(filter(lambda x: x.startswith('sdapm'), os.listdir()))
    if not checkpoints:
      raise RuntimeError("SDAPM parameters became NaN and there is no 'sdapm*' checkpoint to restore from")
    print("loading", sorted(checkpoints)[-1])
    sdapm.load_state_dict(T.load(sorted(checkpoints)[-1]) )

  if(iter%500==0):
    print("saving...")
    current_time = time.time()
    T.save(sdapm.state_dict(), "sdapm_"+str(int(current_time))+".pt")
    T.save(vae, "TVAT_"+str(int(current_time))+".pt")





loading TVAT_1752865178.pt
loading sdapm_1752865178.pt
tensor([[ 0.0286,  0.1285, -0.1284,  ...,  0.0129, -0.0653, -0.0211],
        [-0.0389,  0.1149,  0.0314,  ...,  0.0205, -0.0044, -0.0640],
        [-0.0198,  0.0047, -0.1036,  ..., -0.0749, -0.1251, -0.0753],
        ...,
        [-0.0977,  0.0415, -0.0689,  ..., -0.0153, -0.1933, -0.0215],
        [-0.0306, -0.0017, -0.0076,  ...,  0.0641, -0.0439, -0.0259],
        [ 0.2209,  0.2566,  0.2286,  ...,  0.1570,  0.0913,  0.3188]],
       device='cuda:0', grad_fn=<SliceBackward0>)
sdapm_only+vae


  densified_z[i,T.floor(sections).int()] = T.index_reduce(z, dim=1, index=T.floor(sections[i]).int(), source=z, reduce='mean')


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
EPOCH 0	LOSS_VAE: 0.0	LOSS_SDAPM: 10.549012184143066
sdapm_only+vae
EPOCH 0	LOSS_VAE: 0.0	LOSS_SDAPM: 9.587272644042969
sdapm_only+vae
EPOCH 0	LOSS_VAE: 0.0	LOSS_SDAPM: 9.50847053527832
sdapm_only+vae
EPOCH 0	LOSS_VAE: 0.0	LOSS_SDAPM: 10.364312171936035
sdapm_only+vae
EPOCH 0	LOSS_VAE: 0.0	LOSS_SDAPM: 11.192941665649414
sdapm_only+vae
EPOCH 0	LOSS_VAE: 0.0	LOSS_SDAPM: 10.776065826416016
sdapm_only+vae
EPOCH 0	LOSS_VAE: 0.0	LOSS_SDAPM: 10.602923393249512
sdapm_only+vae
EPOCH 0	LOSS_VAE: 0.0	LOSS_SDAPM: 10.161172866821289
sdapm_only+vae
EPOCH 0	LOSS_VAE: 0.0	LOSS_SDAPM: 11.056880950927734
sdapm_only+vae
EPOCH 0	LOSS_VAE: 0.0	LOSS_SDAPM: 11.377476692199707
sdapm_only+vae
EPOCH 0	LOSS_VAE: 0.0	LOSS_SDAPM: 10.124530792236328
sdapm_only+vae
EPOCH 0	LOSS_VAE: 0.0	LOSS_SDAPM: 10.317265510559082
sdapm_only+vae
EPOCH 0	LOSS_VAE: 0.0	LOSS_SDAPM: 10.117131233215332
sdapm_only+vae
EPOCH 0	LOSS_VAE: 0.0	LOSS_SDAPM: 10.665565490722656
s

RuntimeError: amax(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.