In [None]:
from typing import List


def predict_greedy_raw(dataset,
                       greedy_word_generator: GreedyGenerator,
                       max_n_steps: int = 19,  # length of the longest word in the validation set
                      ) -> List[str]:
    """
    Creates predictions using greedy generation.

    Supposed to be used with a dataset of a single grid.

    Arguments:
    ----------
    dataset: NeuroSwipeDatasetv2
        Sized, iterable dataset whose items unpack as
        ``((xyt, kb_tokens, _), _)``.
    greedy_word_generator: GreedyGenerator
        Generator whose ``generate_word_only(xyt, kb_tokens, max_n_steps)``
        is called once per dataset item.
    max_n_steps: int
        Maximum number of decoding steps forwarded to ``generate_word_only``.

    Returns:
    --------
    List[str]
        One predicted word per dataset item, in dataset order, with a
        leading "<sos>" removed if present.
    """
    preds = []
    for (xyt, kb_tokens, _), _ in tqdm(dataset, total=len(dataset)):
        pred = greedy_word_generator.generate_word_only(xyt, kb_tokens, max_n_steps)
        # The generator output may start with the <sos> marker; strip it so
        # predictions can be compared directly against decoded targets.
        preds.append(pred.removeprefix("<sos>"))
    return preds


def get_targets(dataset: "CurveDataset", tokenizer=None) -> List[str]:
    """
    Decodes the target word of every dataset item.

    Arguments:
    ----------
    dataset: CurveDataset
        Iterable of ``(x, target_tokens)`` pairs. The last element of
        ``target_tokens`` is the ``<eos>`` token and is dropped before decoding.
    tokenizer: optional
        Object with a ``decode(tokens) -> str`` method. When omitted, the
        notebook-level ``word_char_tokenizer`` is used (the original behavior).

    Returns:
    --------
    List[str]
        Decoded target words, in dataset order.
    """
    if tokenizer is None:
        tokenizer = word_char_tokenizer
    # Last token is <eos>, so it is excluded from the decoded string.
    return [tokenizer.decode(target_tokens[:-1]) for _, target_tokens in dataset]


def get_accuracy(preds, targets) -> float:
    """
    Returns the fraction of predictions that exactly equal their targets.

    Arguments:
    ----------
    preds: iterable
        Predicted words.
    targets: iterable
        Reference words; must contain as many elements as ``preds``.

    Returns:
    --------
    float
        ``n_exact_matches / len(targets)``.

    Raises:
    -------
    ValueError
        If ``preds`` and ``targets`` differ in length (``zip`` would otherwise
        silently truncate and produce a wrong accuracy), or if they are empty.
    """
    preds = list(preds)
    targets = list(targets)
    if len(preds) != len(targets):
        raise ValueError(
            f"preds and targets must have the same length, "
            f"got {len(preds)} and {len(targets)}")
    if not targets:
        raise ValueError("Cannot compute accuracy on empty targets.")
    n_correct = sum(pred == target for pred, target in zip(preds, targets))
    return n_correct / len(targets)


def get_greedy_generator_accuracy(val_dataset, model,
                                  word_char_tokenizer, device) -> float:
    """
    Word-level accuracy of greedy decoding on ``val_dataset``.

    A ``GreedyGenerator`` is built from ``model``, ``word_char_tokenizer`` and
    ``device``; every curve in ``val_dataset`` is decoded greedily and the
    resulting words are compared with the dataset targets by exact match.

    NOTE(review): targets come from ``get_targets(val_dataset)``, which decodes
    with the notebook-level ``word_char_tokenizer`` rather than the argument
    passed here — confirm both refer to the same tokenizer.
    """
    # TODO: rather than generating each word in full, decode char by char and
    # stop as soon as a char disagrees with the target. The curve already
    # counts as a miss at that point, so generating the rest is wasted work.
    reference_words = get_targets(val_dataset)
    generator = GreedyGenerator(model, word_char_tokenizer, device)
    predicted_words = predict_greedy_raw(val_dataset, generator)
    accuracy = get_accuracy(predicted_words, reference_words)
    return accuracy

In [None]:
# ###################### протестируем predict_greedy_raw ######################


# # Главное тестировать не на случайных весах, потому что тогда будут генерироваться не короткие слова, а слова длиной max_seq_len


# MODEL_TO_TEST_GREEDY_GEN__PATH = "../data/trained_models_for_final_submit/m1_bigger/" \
#     "m1_bigger_v2__2023_11_11__14_29_37__0.13679_default_l2_0_ls0_switch_0.pt"

# # Leads to super slow inference.  I think it's due to 
# # high price of operations on small-amplitude floats.
# # MODEL_TO_TEST_GREEDY_GEN__PATH = None


# def test_greedy_generator(val_dataset, model_getter, model_weights, word_char_tokenizer, device) -> float:
    
#     model = model_getter(device, model_weights)

#     return get_greedy_generator_accuracy(val_dataset, model, word_char_tokenizer, device)



# test_greedy_generator(val_dataset, get_m1_bigger_model, MODEL_TO_TEST_GREEDY_GEN__PATH, word_char_tokenizer, device)

In [None]:
# ! Word-level accuracy of greedy search results and in-training-predictions are equal.
# ! Thus GreedyAccuracyCallback doesn't make sense and should be deleted.

# ! However if we would count char-level metrics,
# ! greedy search results and in-training-predictions would be different

class GreedyAccuracyCallback(Callback):
    """
    Periodically logs word-level greedy-decoding accuracy on a validation set.

    Every ``each_n_steps`` training steps, the wrapped ``pl_module.model`` is
    evaluated with ``get_greedy_generator_accuracy`` and the result is logged
    under ``"greedy_val_accuracy"`` via ``logger.log_metrics``.

    Arguments:
    ----------
    each_n_steps: int
        Evaluation period in training steps; must be positive.
    val_dataset:
        Dataset passed to ``get_greedy_generator_accuracy``.
    word_char_tokenizer:
        Tokenizer used to build the greedy generator.
    logger:
        Object exposing ``log_metrics(metrics: dict, step=...)``.
    """

    def __init__(self, each_n_steps: int, val_dataset, word_char_tokenizer, logger):
        if each_n_steps <= 0:
            raise ValueError(f"each_n_steps must be positive, got {each_n_steps}")
        self.each_n_steps = each_n_steps
        self.val_dataset = val_dataset
        self.word_char_tokenizer = word_char_tokenizer
        self.logger = logger

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_index):
        # NOTE(review): the trigger uses ``global_step + 1``; confirm against the
        # Lightning version in use whether global_step is already incremented
        # at this hook.
        if (pl_module.global_step + 1) % self.each_n_steps != 0:
            return
        # Build the generator on whatever device the module's weights live on.
        device = next(pl_module.parameters()).device
        # Use the dataset/tokenizer stored on the callback, not notebook globals.
        greedy_accuracy = get_greedy_generator_accuracy(
            self.val_dataset, pl_module.model, self.word_char_tokenizer, device)
        self.logger.log_metrics({"greedy_val_accuracy": greedy_accuracy}, step=pl_module.global_step)

In [None]:
# Evaluate greedy-decoding accuracy on the validation set every 9000 training
# steps and log it to TensorBoard.
greedy_acc_callback = GreedyAccuracyCallback(
    each_n_steps=9000,
    val_dataset=val_dataset,
    word_char_tokenizer=word_char_tokenizer,
    logger=tb_logger,
)

In [None]:
# # Протестируем корректность collate_fn (вызывается неявно в DataLoader)

# batch_size = 6


# PAD_CHAR_TOKEN = word_char_tokenizer.char_to_idx["<pad>"]


# train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False,
#                               num_workers=0, collate_fn=collate_fn)


# dataset_els = [train_dataset[i] for i in range(batch_size)]
# unproc_batch_x, unproc_batch_y = zip(*dataset_els)

# batch_x, batch_y = next(iter(train_dataloader))


# ############### Проверка корректности batch_y ###################
# max_out_seq_len = max([len(y) for y in unproc_batch_y])

# assert batch_y.shape == (max_out_seq_len, batch_size)


# for i in range(batch_size):
#     assert (batch_y[:len(unproc_batch_y[i]), i] == unproc_batch_y[i]).all()
#     assert (batch_y[len(unproc_batch_y[i]):, i] == PAD_CHAR_TOKEN).all()

# print("batch_y is correct")



# ############### Проверка корректности batch_x ###################
# unproc_batch_traj_feats, unproc_batch_kb_tokens, unproc_batch_dec_in_char_seq = zip(*unproc_batch_x)

# (traj_feats, kb_tokens, dec_in_char_seq, traj_pad_mask, word_pad_mask) = batch_x


# # каждая сущность, полученная выше из unproc_batch_x - это tuple длины batch_size.
# # Например, unproc_batch_traj_feats[i] = train_dataset[i][0][0]

# N_TRAJ_FEATS = 6
# max_curve_len = max([el.shape[0] for el in unproc_batch_traj_feats]) 

# assert max_curve_len == max([el.shape[0] for el in unproc_batch_kb_tokens])

# assert traj_feats.shape == (max_curve_len, batch_size, N_TRAJ_FEATS)
# assert kb_tokens.shape == (max_curve_len, batch_size)
# assert dec_in_char_seq.shape == (max_out_seq_len, batch_size)
# assert traj_pad_mask.shape == (batch_size, max_curve_len)
# assert word_pad_mask.shape == (batch_size, max_out_seq_len)


# for i in range(batch_size):
#     assert (traj_feats[:len(unproc_batch_traj_feats[i]), i] == unproc_batch_traj_feats[i]).all()
#     assert (kb_tokens[:len(unproc_batch_kb_tokens[i]), i] == unproc_batch_kb_tokens[i]).all()

#     assert (dec_in_char_seq[:len(unproc_batch_dec_in_char_seq[i]), i] == unproc_batch_dec_in_char_seq[i]).all()
#     assert (dec_in_char_seq[len(unproc_batch_dec_in_char_seq[i]):, i] == PAD_CHAR_TOKEN).all()

#     assert (traj_pad_mask[i, :len(unproc_batch_traj_feats[i])] == False).all()
#     assert (traj_pad_mask[i, len(unproc_batch_traj_feats[i]):] == True).all()
    
#     assert (word_pad_mask[i, :len(unproc_batch_dec_in_char_seq[i])] == False).all()
#     assert (word_pad_mask[i, len(unproc_batch_dec_in_char_seq[i]):] == True).all()

# print("batch_x is correct")

In [None]:
# def move_all_to_device(x, device):
#     if torch.is_tensor(x):
#         return x.to(device)
#     elif not isinstance(x, (list, tuple)):
#         raise ValueError(f'Unexpected data type {type(x)}')
#     new_x = []
#     for el in x:
#         if not torch.is_tensor(el):
#             raise ValueError(f'Unexpected data type {type(el)}')
#         new_x.append(el.to(device))
#     return new_x

In [None]:
# TENSORBOARD_LOG_PATH = f"/kaggle/working/tensorboard_log/{EXPERIMENT_NAME}"

# tb = SummaryWriter(TENSORBOARD_LOG_PATH)
