# In [None]:  (notebook cell marker left over from the export)
import heapq
import json
import math
import pickle
from collections import Counter

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from gensim.models import Word2Vec
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
# Fixed input length (in tokens) used when padding/truncating the encoder input.
max_token = 200
# Default for the `embedding_dims` parameter of autoregressive_predict.
embedding_dims = 300

# Tokenizer shared by retrieval (token strings) and the model input (token ids).
tokenizer = Tokenizer.from_file(r'../file_vocab.json')

# NOTE(security): pickle.load executes arbitrary code embedded in the file;
# only load retrieval indexes that come from a trusted source.
# NOTE(review): the two index file names differ ("BM250" vs "BM25O") - confirm
# both match the files on disk.
# Forward slashes are used for every path: the original '..\' separator on this
# line only resolved on Windows, while '/' works on Windows and POSIX alike.
with open(r'../BM250_retrival_Gomoe.pkl' , 'rb' ) as f:
  retrival_gomoe = pickle.load(f)
with open(r'../BM25O_Retrival_clinc.pkl' , 'rb') as f:
  retrival_clinc = pickle.load(f)

# Word2Vec models queried with a class label to obtain that class's vector.
word2vec_clinc = Word2Vec.load(r'../Word2vec_cls_clinc.model')
word2vec_gomoe = Word2Vec.load(r'../Word2vec_cls_gomoe.model')

# Lookup tables from a document index (as a string key) to its class label.
# JSON is read as UTF-8 explicitly so the result does not depend on the
# platform's locale encoding.
with open(r'../copus_clinc.json' , 'r', encoding='utf-8') as f:
  copus_clinc = json.load(f)
with open(r'../copus_gomoe.json' , 'r', encoding='utf-8') as f:
  copus_gomoe = json.load(f)
# Retrieval helpers: map a text to a class vector (the model checkpoint itself is loaded further below)


def read_file_retrival_clinc(text, top_k = 9):
  """Return the CLINC class vector chosen by a majority vote over the best hits.

  The text is tokenized with the shared ``tokenizer``, every corpus entry is
  scored by ``retrival_clinc.get_scores`` (presumably a BM25 index, judging by
  the pickle's file name), and the ``top_k`` best-scoring entries vote with the
  class label stored for them in ``copus_clinc``.

  Args:
    text: Query string.
    top_k: Number of best-scoring corpus entries that take part in the vote
      (defaults to 9, the value that used to be hard-coded).

  Returns:
    ``word2vec_clinc.wv[label]`` for the winning class label.

  Raises:
    ValueError: If there is nothing to vote on (no scores or ``top_k`` < 1).
    KeyError: If a selected index is missing from ``copus_clinc``.
  """
  scores = retrival_clinc.get_scores(tokenizer.encode(text).tokens)
  # nlargest is documented as equivalent to sorted(..., reverse=True)[:top_k],
  # so ties keep the same order as before without sorting the whole corpus.
  top_indices = heapq.nlargest(top_k, range(len(scores)), key=scores.__getitem__)
  if not top_indices:
    raise ValueError('no retrieval candidates to vote on (empty scores or top_k < 1)')
  # The corpus mapping is keyed by the index rendered as a string.
  votes = Counter(copus_clinc[str(idx)] for idx in top_indices)
  # On equal counts the label met first (i.e. from the better-ranked hit) wins.
  clas = votes.most_common(1)[0][0]
  return word2vec_clinc.wv[clas]

# load gomoe -> vector
def read_file_retrival_gomoe(text, top_k = 9):
  """Return the GoMoE class vector chosen by a majority vote over the best hits.

  The text is tokenized with the shared ``tokenizer``, every corpus entry is
  scored by ``retrival_gomoe.get_scores`` (presumably a BM25 index, judging by
  the pickle's file name), and the ``top_k`` best-scoring entries vote with the
  class label stored for them in ``copus_gomoe``.

  Args:
    text: Query string.
    top_k: Number of best-scoring corpus entries that take part in the vote
      (defaults to 9, the value that used to be hard-coded).

  Returns:
    ``word2vec_gomoe.wv[label]`` for the winning class label.

  Raises:
    ValueError: If there is nothing to vote on (no scores or ``top_k`` < 1).
    KeyError: If a selected index is missing from ``copus_gomoe``.
  """
  scores = retrival_gomoe.get_scores(tokenizer.encode(text).tokens)
  # nlargest is documented as equivalent to sorted(..., reverse=True)[:top_k],
  # so ties keep the same order as before without sorting the whole corpus.
  top_indices = heapq.nlargest(top_k, range(len(scores)), key=scores.__getitem__)
  if not top_indices:
    raise ValueError('no retrieval candidates to vote on (empty scores or top_k < 1)')
  # The corpus mapping is keyed by the index rendered as a string.
  votes = Counter(copus_gomoe[str(idx)] for idx in top_indices)
  # On equal counts the label met first (i.e. from the better-ranked hit) wins.
  clas = votes.most_common(1)[0][0]
  return word2vec_gomoe.wv[clas]



# Id of the padding token in the shared tokenizer's vocabulary.
pad = tokenizer.token_to_id("[PAD]")
def trun_pad_in(sequen):
  """Encode ``sequen`` and truncate/pad it to exactly ``max_token`` ids.

  Args:
    sequen: Input to encode; it is converted with ``str()`` first.

  Returns:
    A pair of 1-D ``torch.long`` tensors of length ``max_token``:
    the token ids, and a mask that is 0 on positions holding the pad id and
    1 everywhere else.
  """
  encoded = tokenizer.encode(str(sequen))
  # Keep at most max_token - 1 real tokens, so at least one pad always follows.
  text_ids = encoded.ids[:max_token - 1]
  text_ids += [pad] * (max_token - len(text_ids))

  ids = torch.tensor(text_ids, dtype = torch.long)
  # Derive the mask from the actual pad id; the previous literal `1` was only
  # correct when "[PAD]" happened to be id 1 in the vocabulary.
  mask = (ids != pad).to(torch.long)

  return ids, mask



# NOTE(security): weights_only=False unpickles arbitrary Python objects (needed
# here because the checkpoint stores the whole module under 'Model'); only load
# checkpoints from a trusted source.
# map_location='cpu' keeps the module on the CPU, where this script builds all
# of its input tensors, even if the checkpoint was saved from a GPU run.
checkpoint = torch.load(r'../best_model.pth', map_location='cpu', weights_only=False)
# The checkpoint carries both the module object and a separate state dict.
model = checkpoint['Model']
model.load_state_dict(checkpoint['model_state_dict'])
# Inference only: put the module in evaluation mode.
model.eval()

def autoregressive_predict(
    model,
    text,
    max_token = 200,
    embedding_dims = embedding_dims, max_len = 200):
  """Greedily decode an output string for ``text`` with ``model``.

  The encoder side is built once: the padded token ids and mask from
  ``trun_pad_in`` plus the CLINC and GoMoE class vectors (concatenated in that
  order) repeated over ``max_token`` positions. The decoder side starts from a
  single start token and grows by the arg-max token of the last output step
  until the "[END]" token is produced or ``max_len`` steps have run.

  Args:
    model: Module exposing ``embedding_pos(ids)`` and a forward call taking
      ``x_emb_retrival``, ``x_embedding``, ``tgt_out`` and ``mask_ids``.
      Its output is indexed as ``output[-1, 0, :]``, i.e. it is assumed to be
      (target_len, batch, vocab) - TODO confirm against the model definition.
    text: Input sentence.
    max_token: Number of positions the retrieval features are repeated over.
      NOTE(review): ``trun_pad_in`` pads to the module-level ``max_token``, so
      a different value here presumably produces mismatching lengths.
    embedding_dims: Unused; kept so existing calls keep working.
    max_len: Maximum number of decoding steps.

  Returns:
    The decoded text (``tokenizer.decode`` of the generated ids without the
    start token).
  """
  model.eval()

  sos_token_id = tokenizer.token_to_id("[SOS]")
  eos_token_id = tokenizer.token_to_id("[END]")
  # Seed decoding with [SOS]. The previous code looked this id up but then
  # seeded with [END]; [END] remains the fallback when the vocabulary has no
  # [SOS] entry. NOTE(review): confirm the start token used during training.
  start_token_id = eos_token_id if sos_token_id is None else sos_token_id

  with torch.no_grad():
    token_ids, mask_in = trun_pad_in(text)
    mask_in = mask_in.unsqueeze(0).type(torch.bool)  # (1, max_token)

    # NOTE(review): the ids are embedded as (max_token, 1) and then permuted,
    # exactly as before. If embedding_pos only appends an embedding axis this
    # yields (1, max_token, E), unlike the (max_token, 1, .) retrieval features
    # below - confirm which layout the model expects.
    x_emb = model.embedding_pos(token_ids.unsqueeze(1)).permute(1, 0, 2)

    # Same feature order as before (CLINC vector first, then GoMoE); only the
    # swapped variable names were corrected.
    clinc_vec = torch.tensor(read_file_retrival_clinc(text))
    gomoe_vec = torch.tensor(read_file_retrival_gomoe(text))
    retrieval_vec = torch.cat((clinc_vec, gomoe_vec), dim = -1)
    # One copy of the combined vector per position: (max_token, 1, features).
    x_emb_retrival = retrieval_vec.unsqueeze(0).unsqueeze(0).repeat(max_token, 1, 1)

    # Token ids are tracked separately from their embeddings, so each step can
    # append an id and re-embed the whole prefix.
    generated = [start_token_id]
    for _ in range(max_len):
      tgt_ids = torch.tensor([generated], dtype = torch.long)  # (1, steps)
      tgt_out = model.embedding_pos(tgt_ids).permute(1, 0, 2)
      output = model(
          x_emb_retrival = x_emb_retrival,
          x_embedding = x_emb,
          tgt_out = tgt_out,
          mask_ids = mask_in,
      )
      # Greedy choice from the last decoded position of the single batch item.
      next_token_id = output[-1, 0, :].argmax(dim = -1).item()
      generated.append(next_token_id)
      if next_token_id == eos_token_id:
        break

  # A leftover debug `return output` used to exit after the first forward
  # pass, so this decoding result was never reached.
  return tokenizer.decode(generated[1:])

# Placeholder query; the Vietnamese text means "enter your test sentence here".
input_text = "nhập câu test của bạn ở đây"
# Run the predictor on the placeholder and keep whatever it returns in `linh`.
linh = autoregressive_predict(model, input_text)