In [None]:
from drivers.loaders.imdb import IMDB

from drivers.tokenizers.word_piece_vocab import WordPieceVocab
from drivers.tokenizers.word_level_vocab import WordLevelVocab
from drivers.tokenizers.unigram_vocab import UnigramVocab
from drivers.tokenizers.bpe_vocab import BPEVocab

# Tokenizer settings used when constructing the vocabulary below.
UNK_TOKEN = "[UNK]"
VOCAB_SIZE = 1000

# Dataset wrapper; "data" is presumably the dataset directory — TODO confirm against IMDB.
db = IMDB("data")

# Word-level vocabulary over the raw training texts.
# Training (or loading from cache) happens in a later cell.
vocab = WordLevelVocab(
    db.get_train()["text"].values,
    UNK_TOKEN,
    VOCAB_SIZE,
)

In [None]:
import os

# Directory holding cached vocabulary files (one JSON file per vocab/dataset pair).
PATH_VOCABS = "vocabs/"

# Cache file name is "<vocab name>_<dataset name>.json" inside PATH_VOCABS.
file_name_vocab = os.path.join(PATH_VOCABS, vocab.name + "_" + db.name + ".json")

print(file_name_vocab)

if os.path.isfile(file_name_vocab):
    # A cached vocabulary exists: load it instead of retraining.
    vocab.load(file_name_vocab)
    print("LOADED:", db.name)
else:
    vocab.train()
    print("TRAINED:", db.name)

    # Create the cache directory first so saving works on a fresh checkout;
    # whether vocab.save() creates missing directories itself is not visible here.
    os.makedirs(PATH_VOCABS, exist_ok=True)
    vocab.save(file_name_vocab)
    print("SAVED:", db.name)

In [None]:
from drivers.rl.util_tensorboard import TensorboardLoggerSimple, DummyLogger
from drivers.rl.vocab_search import VocabEnv, VocabSearch

# Fetch each split exactly once so the text and label columns of a split are
# guaranteed to come from the same object (and the loader is not re-run).
train_df = db.get_train()
test_df = db.get_test()

X_train = train_df["text"]
y_train = train_df["label"]

# Fix: test features must come from the test split. They were previously read
# from the training split, which mismatched them with y_test.
X_test = test_df["text"]
y_test = test_df["label"]

# Number of distinct labels observed in the training split.
n_classes = len(y_train.unique())

In [None]:
from drivers.models.simple import Simple

# Model name is "Simple_<dataset name>_<vocab name>".
simple_name = "".join(("Simple_", db.name, "_", vocab.name))

# NOTE(review): output_size is taken from db.get_labels(), while other cells use
# n_classes for the class count — confirm get_labels() returns what Simple expects.
model = Simple(
    input_length=128,
    output_size=db.get_labels(),
    repeate=1,  # keyword spelling kept as in the existing call; presumably matches Simple's signature
    name=simple_name,
)

In [None]:
# Build the search driver over the training data.
# The tokenizer vocabulary is passed inverted (values become keys); presumably
# get_vocab() maps token -> id, so this yields id -> token — TODO confirm.
# input_length is read from the model instead of repeating the literal 128, so it
# cannot drift from the model's configuration (the VocabEnv cell does the same).
vocab_search = VocabSearch(
    X_train,
    y_train,
    {token_id: token for token, token_id in vocab.tokenizer.get_vocab().items()},
    n_classes,
    model=model,
    input_length=model.input_length,
    logger=DummyLogger(log_dir=""),  # alternative: logger=TensorboardLoggerSimple(log_dir="tb_logs")
)

In [None]:
from tensorflow.python.keras.utils.vis_utils import plot_model

# Stand-alone environment used here only to build the model and draw its graph.
# The third and fourth positional arguments are left as None — presumably an
# optional second data split; confirm against VocabEnv's signature.
env = VocabEnv(
    X_train,
    y_train,
    None,
    None,
    possible_words=vocab.tokenizer.get_vocab(),
    input_length=model.input_length,
    n_classes=n_classes,
    model=model,
)

# reset() is called before building the model; its return value is not used.
env.reset()

# NOTE(review): _build_model is a private method of VocabEnv, called directly here.
model_built = env._build_model()

# NOTE(review): plot_model is imported from tensorflow.python.keras (a private
# path); the public entry point is tensorflow.keras.utils.plot_model.
# Kept as the cell's last expression so the rendered diagram is displayed.
plot_model(model_built, show_shapes=True, show_layer_names=True)

In [None]:

#model._create_model(vocab_size=10)
#model.model.fit(env._preprocess_input(X_train), y_train)

In [None]:
# Run the vocabulary search with n_envs=1 and single_thread=True; the exact
# semantics of these flags are defined by VocabSearch.search (not visible here).
vocab_search.search(n_envs=1, single_thread=True)