<a href="https://colab.research.google.com/github/ylchen-QsNb/Brett/blob/main/Brett1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers
!pip install datasets

In [None]:
import torch
import torch.nn as nn
import random
from transformers import AutoTokenizer

# Width of the token embedding vectors; used as the default `vec_size` of Embed.
emb_size = 768

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'
print('which device we are using: ' + device)

class SSM(nn.Module):
  """Recurrent cell that advances a (d_out, d_state) state matrix by one time step.

  Three small MLPs (``alpha``, ``beta``, ``gamma``) are evaluated on the current
  state and combined into a state-dependent term ``abgd``. The next state is

      (state + d_t * (abgd * A + B)) / (1 + d_t * (tau + abgd))

  where ``A``, ``B`` and ``tau`` are learnable (d_out, d_state) parameters.

  Args:
    d_state: size of the last dimension of the state.
    d_out: number of rows of the state.
    d_a: output width of the ``alpha`` and ``beta`` networks.
    hidden: hidden width of the three MLPs.
  """

  def __init__(self, d_state, d_out, d_a, hidden):
    super().__init__()

    self.state_ini = nn.Linear(d_out, d_state, bias=False)
    # Identity matrix that is pushed through `state_ini` to produce the initial
    # state. It is a non-persistent buffer so that it follows the module through
    # .to()/.cuda()/.double() while staying out of state_dict (checkpoints keep
    # exactly the same keys as before).
    self.register_buffer('get_ini', torch.eye(d_out, d_out), persistent=False)

    # alpha is normalised over its d_a outputs, beta over the rows of the state.
    self.alpha = self._mlp(d_state, hidden, d_a, nn.Softmax(dim=-1))
    self.beta = self._mlp(d_state, hidden, d_a, nn.Softmax(dim=-2))
    self.gamma = self._mlp(d_state, hidden, d_state, nn.Tanh())
    self.delta = nn.Linear(d_state, d_state, bias=False)
    self.A = nn.Parameter(torch.randn(d_out, d_state))
    self.B = nn.Parameter(torch.randn(d_out, d_state))
    self.tau = nn.Parameter(torch.rand(d_out, d_state))

  @staticmethod
  def _mlp(d_in, hidden, d_last, final):
    """Build the Linear-GELU-Linear-GELU-Linear stack shared by alpha/beta/gamma."""
    return nn.Sequential(
        nn.Linear(d_in, hidden),
        nn.GELU(),
        nn.Linear(hidden, hidden),
        nn.GELU(),
        nn.Linear(hidden, d_last),
        final
    )

  def initialize(self):
    """Return the learned initial state, a (d_out, d_state) tensor."""
    return self.state_ini(self.get_ini)

  def forward(self, d_t, state):
    """Advance `state` by a step of size `d_t` and return the new state.

    Args:
      d_t: step size; it only enters through multiplication, so a Python
        number or a 0-dim tensor both work.
      state: current state; its trailing dimensions must be (d_out, d_state)
        to broadcast against A, B and tau.
    """
    beta = self.beta(state)      # (d_out, d_a) for a (d_out, d_state) state
    gamma = self.gamma(state)    # (d_out, d_state)
    # Pool the rows of gamma with the beta weights: (d_a, d_state).
    bg = torch.matmul(torch.transpose(beta, -1, -2), gamma)
    bgd = self.delta(bg)
    alpha = self.alpha(state)    # (d_out, d_a)
    # Redistribute the pooled rows back onto the state rows: (d_out, d_state).
    abgd = torch.matmul(alpha, bgd)

    out = (state + d_t * (abgd*self.A + self.B)) / (1 + d_t * (self.tau + abgd))

    return out


# Brett: Bidirectional Recursive Encoder from Transformer Time-dependent
class BrettCore(nn.Module):
  """One time step of the encoder: multi-head attention plus feed-forward update.

  The query/key/value projection matrices and the feed-forward weights are not
  fixed parameters; they are read out of SSM states that evolve with every
  call to `forward`.

  Args:
    d_model: feature width of the input `x`.
    head: number of attention heads (at least 1).
    d_k: rows of each query/key SSM state, i.e. the per-head query/key width.
    d_v: rows of each value SSM state, i.e. the per-head value width.
    d_state: last dimension of the query/key/value SSM states.
    d_a: passed to the query/key/value SSMs.
    hidden: hidden width of the MLPs inside the query/key/value SSMs.
    hidden_feed: width of the hidden layer of the feed-forward part.
  """

  def __init__(self, d_model, head, d_k, d_v, d_state, d_a, hidden, hidden_feed):
    super().__init__()
    self.head = head
    self.d_model = d_model
    self.d_k = d_k
    self.d_v = d_v
    self.d_state = d_state

    # One SSM per head and per role (query / key / value).
    self.ssm_q = nn.ModuleList(
        [SSM(d_state, d_k, d_a, hidden) for _ in range(head)]
    )
    self.ssm_k = nn.ModuleList(
        [SSM(d_state, d_k, d_a, hidden) for _ in range(head)]
    )
    self.ssm_v = nn.ModuleList(
        [SSM(d_state, d_v, d_a, hidden) for _ in range(head)]
    )

    # Per-head read-outs that turn an SSM state into a projection matrix
    # with d_model columns.
    self.sq = nn.ModuleList(
        [nn.Linear(d_state, d_model) for _ in range(head)]
    )
    self.sk = nn.ModuleList(
        [nn.Linear(d_state, d_model) for _ in range(head)]
    )
    self.sv = nn.ModuleList(
        [nn.Linear(d_state, d_model) for _ in range(head)]
    )

    # First feed-forward layer (head*d_v -> hidden_feed), read out of an SSM state.
    self.feed_a_state = SSM(2*hidden_feed, hidden_feed, hidden_feed//2, hidden_feed)
    self.feed_a_trans = nn.Linear(2*hidden_feed, head*d_v)
    self.feed_a_bias = nn.Linear(2*hidden_feed, 1)

    self.gelu = nn.GELU()

    # Second feed-forward layer (hidden_feed -> d_model), read out of an SSM state.
    self.feed_b_state = SSM(2*d_model, d_model, d_model//2, d_model)
    self.feed_b_trans = nn.Linear(2*d_model, hidden_feed)
    self.feed_b_bias = nn.Linear(2*d_model, 1)

    self.tau = nn.Parameter(torch.rand(1, 1, d_model))

  def initialize(self):
    """Return the initial state tuple (state_q, state_k, state_v, state_a, state_b).

    The first three entries stack the per-head initial states along a new
    leading dimension of size `head`; the last two are the initial states of
    the two feed-forward SSMs.
    """
    state_q = torch.stack([cell.initialize() for cell in self.ssm_q], 0)
    state_k = torch.stack([cell.initialize() for cell in self.ssm_k], 0)
    state_v = torch.stack([cell.initialize() for cell in self.ssm_v], 0)
    state_a = self.feed_a_state.initialize()
    state_b = self.feed_b_state.initialize()
    out_tuple = (state_q, state_k, state_v, state_a, state_b)
    return out_tuple

  def forward(self, dt, x, in_tuple, mask=None):
    """Advance the hidden sequence `x` and all SSM states by one step `dt`.

    Args:
      dt: step size, forwarded to every SSM and used in the update of `x`.
      x: input of shape (batch, length, d_model).
      in_tuple: state tuple as returned by `initialize` or a previous call.
      mask: optional (batch, length) tensor; positions where it equals 0 are
        filled with -1e9 before the softmaxes.

    Returns:
      (y, out_tuple): the updated sequence, same shape as `x`, and the
      advanced state tuple.
    """
    state_q, state_k, state_v, state_a, state_b = in_tuple

    if mask is not None:
      # Boolean "hide this position" masks, prepared once and broadcast
      # over the d_k axis inside the loop.
      hidden_pos = mask == 0
      key_pad = hidden_pos.unsqueeze(1)     # (batch, 1, length)
      query_pad = hidden_pos.unsqueeze(-1)  # (batch, length, 1)

    head_outputs = []
    next_q, next_k, next_v = [], [], []

    for idx in range(self.head):
      # Each role uses its own read-out layer (previously `sq` was applied to
      # all three states, leaving `sk` and `sv` unused and untrained).
      wq_trans = self.sq[idx](state_q[idx])  # (d_k, d_model)
      wk_trans = self.sk[idx](state_k[idx])  # (d_k, d_model)
      wv_trans = self.sv[idx](state_v[idx])  # (d_v, d_model)

      # Keys are normalised over the sequence axis: (batch, d_k, length).
      k_trans = torch.matmul(wk_trans, torch.transpose(x, -1, -2))
      if mask is not None:
        k_trans = k_trans.masked_fill(key_pad, -1e9)
      k_trans = nn.functional.softmax(k_trans, dim=-1)

      v = torch.matmul(x, torch.transpose(wv_trans, -1, -2))  # (batch, length, d_v)
      ktv = torch.bmm(k_trans, v)                             # (batch, d_k, d_v)

      # Queries are normalised over the d_k axis: (batch, length, d_k).
      q = torch.matmul(x, torch.transpose(wq_trans, -1, -2))
      if mask is not None:
        q = q.masked_fill(query_pad, -1e9)
      q = nn.functional.softmax(q, dim=-1)

      head_outputs.append(torch.matmul(q, ktv))               # (batch, length, d_v)

      next_q.append(self.ssm_q[idx](dt, state_q[idx]))
      next_k.append(self.ssm_k[idx](dt, state_k[idx]))
      next_v.append(self.ssm_v[idx](dt, state_v[idx]))

    attn = torch.cat(head_outputs, -1)  # (batch, length, head*d_v)
    out_q = torch.stack(next_q, 0)
    out_k = torch.stack(next_k, 0)
    out_v = torch.stack(next_v, 0)

    out_a = self.feed_a_state(dt, state_a)
    out_b = self.feed_b_state(dt, state_b)

    # Feed-forward whose weights and biases come from the *current* states.
    f = torch.matmul(attn, torch.transpose(self.feed_a_trans(state_a), -1, -2))
    f = f + torch.transpose(self.feed_a_bias(state_a), -1, -2)
    f = self.gelu(f)
    f = torch.matmul(f, torch.transpose(self.feed_b_trans(state_b), -1, -2))
    f = f + torch.transpose(self.feed_b_bias(state_b), -1, -2)
    denom = torch.sum(f*x, -1).unsqueeze(-1) / self.d_model + self.tau
    y = (x + dt*f) / (1 + dt*denom)
    out_tuple = (out_q, out_k, out_v, out_a, out_b)
    return y, out_tuple

class Embed(nn.Module):
  """Sum of word, position and segment embeddings for a tokenized batch.

  Args:
    vec_size: width of every embedding vector.
    dict_size: number of rows of the word embedding table.
    max_len: number of rows of the position table (longest accepted sequence).
    max_sentence: number of rows of the segment (token type) table.
  """

  def __init__(self, vec_size=emb_size, dict_size=30522, max_len=512, max_sentence=32):
    super().__init__()
    self.word_embed = nn.Embedding(dict_size, vec_size, padding_idx=0)
    # NOTE(review): padding_idx=0 on the next two tables means position 0 and
    # segment 0 map to a vector that receives no gradient. Confirm this is
    # intended and not copied from the word table.
    self.position_embed = nn.Embedding(max_len, vec_size, padding_idx=0)
    self.sentence_embed = nn.Embedding(max_sentence, vec_size, padding_idx=0)

  def forward(self, token):
    """Embed a tokenized batch.

    Args:
      token: mapping with 'attention_mask', 'input_ids' and 'token_type_ids'
        tensors (the commented call below suggests a tokenizer output).

    Returns:
      (vector, attn_mask): the summed embeddings and the attention mask, both
      on the device that holds this module's weights.

    Raises:
      ValueError: if the sequence is longer than the position table.
    """
    # token = self.tokenizer(text1, text2, padding=True, truncation=True, add_special_tokens=True, return_tensors="pt")
    # Follow the module's own weights instead of a process-wide device so the
    # layer keeps working after .to(...) with any target.
    target = self.word_embed.weight.device
    attn_mask = token['attention_mask'].to(target)
    word_ids = token['input_ids'].to(target)
    type_ids = token['token_type_ids'].to(target)
    length = word_ids.size(-1)
    if length > self.position_embed.num_embeddings:
      raise ValueError(
          'sequence length {} exceeds the position table size {}'.format(
              length, self.position_embed.num_embeddings))
    # A single row of positions is enough: it broadcasts over the batch axis.
    position_ids = torch.arange(length, device=target)
    vector = self.word_embed(word_ids) + self.position_embed(position_ids) + self.sentence_embed(type_ids)

    return vector, attn_mask

class Brett(nn.Module):
  """Embedding layer followed by repeated application of one BrettCore over time.

  The constructor arguments are forwarded unchanged to BrettCore.
  """

  def __init__(self, d_model, head, d_k, d_v, d_state, d_a, hidden, hidden_feed):
    super().__init__()
    self.embed = Embed()
    self.brett = BrettCore(d_model, head, d_k, d_v, d_state, d_a, hidden, hidden_feed)

  def forward(self, token, num=11, *, verbose=True):
    """Encode a tokenized batch by stepping the core from time 0 to time 1.

    Args:
      token: tokenized batch accepted by `Embed.forward`.
      num: number of random intermediate time points drawn uniformly from
        [0, 1); the end point 1.0 is always appended, so the core runs
        `num + 1` times.
      verbose: when True (the default, matching the previous behaviour) the
        embedded shape and every time step are printed.

    Returns:
      The hidden sequence after the last step.
    """
    x, pad = self.embed(token)
    if verbose:
      print(x.shape)
    samples = torch.rand(num)
    samples, _ = torch.sort(samples)
    samples = torch.cat((samples, torch.tensor([1.0])), -1)
    t0 = torch.tensor(0.0)
    state = self.brett.initialize()
    for t1 in samples:
      dt = t1 - t0
      if verbose:
        print('t1:', t1, 'dt:', dt)
      # Call the module itself rather than .forward so registered hooks run.
      x, state = self.brett(dt, x, state, mask=pad)
      t0 = t1
    return x

class Pretrain_Brett(nn.Module):
  """Brett encoder with next-sentence-prediction and masked-LM heads and losses.

  The constructor arguments are forwarded unchanged to Brett; `d_model` is
  also the input width of both heads.
  """

  def __init__(self, d_model, head, d_k, d_v, d_state, d_a, hidden, hidden_feed):
    super().__init__()
    self.brett = Brett(d_model, head, d_k, d_v, d_state, d_a, hidden, hidden_feed)
    # Token classifier over a 30522-entry vocabulary.
    self.mlm_out = nn.Sequential(
        nn.Linear(d_model, 30522),
    )
    # Binary classifier that outputs a probability.
    self.nsp_out = nn.Sequential(
        nn.Linear(d_model, d_model),
        nn.GELU(),
        nn.Linear(d_model, 1),
        nn.Sigmoid()
    )
    self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    # self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
    self.loss = nn.BCELoss()
    # Label 0 is skipped by the masked-LM loss.
    self.loss_mlm = nn.CrossEntropyLoss(ignore_index=0)

  def nsp(self, token, target, num=7):
    """Return the binary cross-entropy loss of the next-sentence head.

    Args:
      token: tokenized sentence pairs accepted by the encoder.
      target: one label per pair (1 = real next sentence, 0 = not).
      num: number of random time points passed to the encoder.
    """
    out = self.brett(token, num=num)
    # The first position of every sequence feeds the classifier.
    classifier_units = out[:, 0, :]
    # Drop only the trailing size-1 axis; a bare squeeze() would also remove
    # the batch axis for a single pair and break the shape check in BCELoss.
    nsp_res = self.nsp_out(classifier_units).squeeze(-1)
    target = target.to(device=nsp_res.device, dtype=nsp_res.dtype)
    loss = self.loss(nsp_res, target)
    return loss

  def mlm(self, masked_token, cheat_sheet, num=7):
    """Return the cross-entropy loss of the masked-LM head.

    Args:
      masked_token: tokenized batch accepted by the encoder.
      cheat_sheet: integer labels with one entry per token position;
        entries equal to 0 do not contribute to the loss.
      num: number of random time points passed to the encoder.
    """
    out = self.brett(masked_token, num=num)
    logits = self.mlm_out(out)
    # Flatten to (positions, vocabulary) using the head's actual output width.
    logits = logits.reshape(-1, logits.size(-1))
    labels = cheat_sheet.reshape(-1).to(logits.device)

    return self.loss_mlm(logits, labels)

def pre_nsp(training_set):
  """Build tokenized sentence pairs and labels for next-sentence prediction.

  Every sentence except the last one is used as a first sentence. At even
  indices it is paired with the sentence that follows it (label 1.0); at odd
  indices it is paired with a randomly drawn sentence that is not the one
  following it (label 0.0).

  Returns:
    (token, target): the tokenizer output for the pairs and a float tensor
    holding one label per pair.
  """
  tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
  last = len(training_set) - 1
  firsts, seconds, labels = [], [], []
  for pos in range(last):
    firsts.append(training_set[pos])
    follower = pos + 1
    if pos % 2:
      # Negative pair: redraw until the pick is not the true next sentence.
      pick = random.randint(0, last)
      while pick == follower:
        pick = random.randint(0, last)
      seconds.append(training_set[pick])
      labels.append(0.0)
      continue
    # Positive pair: the sentence that actually follows.
    seconds.append(training_set[follower])
    labels.append(1.0)

  token = tokenizer(firsts, seconds, padding=True, truncation=True, add_special_tokens=True, return_tensors="pt")
  return token, torch.tensor(labels)

def pre_mlm(training_set):
  """Tokenize `training_set` and corrupt it for masked-language-model training.

  In every row, 15% (rounded down) of the tokens that are not special tokens
  are selected. Each selected token is, with probability 0.8, replaced by the
  mask token; with probability 0.1, replaced by a random id above the mask
  token id; otherwise left unchanged.

  Returns:
    (token, cheat_sheet): the tokenizer output with 'input_ids' corrupted in
    place, and a long tensor of the same shape holding the ORIGINAL token id
    at every selected position and 0 everywhere else.
  """
  tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

  token = tokenizer(training_set, padding=True, truncation=True, add_special_tokens=True, return_tensors="pt")
  input_ids = token['input_ids']
  # Snapshot the clean ids: `input_ids` is edited in place below, so an alias
  # would only ever see the corrupted values.
  origin_token = input_ids.clone()
  cheat_sheet = torch.zeros_like(origin_token, dtype=torch.long)

  mask_id = tokenizer.mask_token_id
  # random.randint includes its upper bound, so stop at the last valid id.
  last_id = tokenizer.vocab_size - 1
  special_ids = set(tokenizer.all_special_ids)

  for row in range(origin_token.size(0)):
    # Padding and the other special tokens are never selected.
    candidates = [
        col for col, tok in enumerate(origin_token[row].tolist())
        if tok not in special_ids
    ]
    for col in random.sample(candidates, len(candidates) * 3 // 20):
      # The label is always the clean token, whatever the corruption is.
      cheat_sheet[row, col] = origin_token[row, col]
      u = random.random()
      if u < 0.8:
        input_ids[row, col] = mask_id
      elif u < 0.9:
        input_ids[row, col] = random.randint(mask_id + 1, last_id)

  return token, cheat_sheet

# Main()

In [None]:
def train_nsp(pretrain, training_set, epochs, lr):
  """Train `pretrain` on next-sentence prediction with Adam.

  The pairs are built once from `training_set`; every epoch is a single
  optimisation step over all of them, after which the loss is printed.
  """
  batch, labels = pre_nsp(training_set)
  batch, labels = batch.to(device), labels.to(device)
  adam = torch.optim.Adam(pretrain.parameters(), lr=lr)
  for step in range(1, epochs + 1):
    adam.zero_grad()
    nsp_loss = pretrain.nsp(batch, labels)
    nsp_loss.backward()
    adam.step()
    print('epoch: {}/{}, loss: {}'.format(step, epochs, nsp_loss.item()))

def train_mlm(pretrain, training_set, epochs, lr):
  """Train `pretrain` on masked-language modelling with Adam.

  The corrupted batch and its label tensor are built once from
  `training_set`; the labels are printed before being moved to the device.
  Every epoch is a single optimisation step over the whole batch, after
  which the loss is printed.
  """
  masked, answers = pre_mlm(training_set)
  masked = masked.to(device)
  print(answers)
  answers = answers.to(device)
  adam = torch.optim.Adam(pretrain.parameters(), lr=lr)
  for step in range(1, epochs + 1):
    adam.zero_grad()
    mlm_loss = pretrain.mlm(masked, answers)
    mlm_loss.backward()
    adam.step()
    print('epoch: {}/{}, loss: {}'.format(step, epochs, mlm_loss.item()))

if __name__ == '__main__':
  # Model hyper-parameters, named after the Pretrain_Brett constructor arguments.
  d_model = 768
  head = 64
  d_k = 64
  d_v = 128
  d_state = 128
  d_a = 32
  hidden = 128
  hidden_feed = 128
  pretrain = Pretrain_Brett(
      d_model, head, d_k, d_v, d_state, d_a, hidden, hidden_feed
  ).to(device)
  batch_size = 64  # defined but not used by the training call below

  # Tiny toy corpus; consecutive sentences form the next-sentence pairs.
  training_set = [
      "I'm good to go for the party.",
      "I have no idea where is the location?",
      "I can guide you guys to the rigth place.",
      "Thank you so much!",
      "No worries, we are friends.",
      "Have a fun time!"
  ]

  # Optimisation settings for the next-sentence-prediction run.
  epochs = 8
  lr = 5e-5
  train_nsp(pretrain, training_set, epochs, lr)

# Trash Bin