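# Sample notebook to test the trained transformer (Chignolin) and its generated components

This notebook does the following:
- Loads a trained transformer checkpoint.
- Autoregressively samples sequences of component ids for every test backbone.
- Discards samples that are malformed or that use a component id a residue does not have.
- Saves the kept samples for each backbone, plus the per-backbone sample counts.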

In [2]:
import os
import torch
import torch.nn as nn

import numpy as np
from utils import Transformer
# Print tensor values without scientific notation so sampled ids/probabilities read plainly.
torch.set_printoptions(sci_mode=False)

import matplotlib.pyplot as plt

In [3]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU so
# the notebook still runs (slowly) on machines without CUDA instead of failing at .to(device).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
# Root folders (with trailing "/") for trained checkpoints and for the sampling dataset.
root_network_folder_name = "./ChignolinGMMTransformerTraining/"
root_dataset_folder_name = "./ChignolinGMMTransformerDataset/"

# Object stored via np.save and unwrapped with .item(); it is iterated later as
# (res_num, num_components) pairs, so presumably a dict -- confirm against the writer.
# NOTE(review): allow_pickle=True can execute code embedded in the file; only load trusted data.
num_components_path = os.path.join(root_dataset_folder_name, "all_num_components.npy")
all_num_components_by_res_num = np.load(num_components_path, allow_pickle=True).item()


In [7]:
# Hyperparameters of the trained run. They must match the checkpoint: they select
# the checkpoint folder below and define the architecture the weights are loaded into.
dim_model = 256
dropout_p = 0.1
kt_cutoff = -50
beta_target = 1.0  # NOTE(review): not used by the visible cells of this notebook
epoch_num = 14     # which saved epoch checkpoint (t_{epoch_num}.pt) to load

network_folder_name = f"{root_network_folder_name}dropout_p_{dropout_p}_dim_model_{dim_model}_kt_cutoff_{kt_cutoff}/"

t = Transformer(num_tokens_src=7, num_tokens_tgt=67,
                dim_model=dim_model, num_heads=8, num_encoder_layers=6,
                num_decoder_layers=6, dropout_p=dropout_p).to(device)
# map_location lets a checkpoint saved on another device (a different GPU index,
# or GPU -> CPU) be loaded onto the device selected for this notebook.
t.load_state_dict(torch.load(f"{network_folder_name}t_{epoch_num}.pt", map_location=device))
# Inference mode: disables dropout during sampling.
t.eval()


# Dataset generated for this kt cutoff; the trailing "/" lets later paths be built
# by plain concatenation.
dataset_tag = f"prop_temp_300.0_dt_0.001_num_steps_5_cutoff_to_use_kt_{kt_cutoff}"
dataset_folder_name = f"{root_dataset_folder_name}{dataset_tag}/"

# A test backbone is identified by the prefix of its "<index>_c_info.npy" file, which
# is exactly the file the sampling loop loads. Listing only those files keeps unrelated
# files in test/ from becoming indices the loop would then fail to load.
# np.unique returns the indices sorted (lexicographically, as strings).
c_info_suffix = "_c_info.npy"
all_backbone_indices = np.unique([f[:-len(c_info_suffix)]
                                  for f in os.listdir(f"{dataset_folder_name}test")
                                  if f.endswith(c_info_suffix)])

# Sampled components are written under <dataset_tag>/test/, relative to the working directory.
save_folder_name = f"{dataset_tag}/test/"
os.makedirs(save_folder_name, exist_ok=True)

# Decoder token ids and sampling sizes (the model was built with num_tokens_tgt=67).
START_TOKEN = 65            # seeds every decoded sequence
END_TOKEN = 66              # a kept sample has this id only at its final position
NUM_DECODE_STEPS = 13       # tokens sampled after START_TOKEN; the last must be END_TOKEN
SAMPLES_PER_BACKBONE = 100  # independent draws per backbone conditioning
MAX_SRC_CAT_ROWS = 100000   # rows of test_all_src_cat kept (shared by all backbones)

# Slice on the host before moving to the device so only usable rows are transferred.
test_all_src_cat = torch.tensor(
    np.load(f"{dataset_folder_name}test_all_src_cat.npy")[:MAX_SRC_CAT_ROWS],
    device=device)

all_pred_sizes = []
for backbone_index in all_backbone_indices:
    all_c_info = np.load(f"{dataset_folder_name}test/{backbone_index}_c_info.npy")
    # Channels 0 and 1 of c_info (presumably angles in radians -- TODO confirm) are
    # encoded as (sin, cos) pairs, giving 4 continuous features on the last axis.
    all_src_cont = np.stack((np.sin(all_c_info[:, :, 0]),
                             np.cos(all_c_info[:, :, 0]),
                             np.sin(all_c_info[:, :, 1]),
                             np.cos(all_c_info[:, :, 1])), axis=-1)
    all_src_cont = torch.tensor(all_src_cont, device=device).float()
    # Tile the conditioning so every copy yields its own independent sample.
    all_src_cont = all_src_cont.repeat((SAMPLES_PER_BACKBONE, 1, 1))
    num_samples = all_src_cont.shape[0]

    # The same leading rows of categorical source are reused for every backbone.
    all_src_cat = test_all_src_cat[:num_samples]
    assert all_src_cat.shape[0] == num_samples, (
        f"backbone {backbone_index}: needs {num_samples} rows of "
        f"test_all_src_cat, only {all_src_cat.shape[0]} available")

    with torch.no_grad():
        # Autoregressive sampling: start from START_TOKEN and append one token drawn
        # from the softmax over the model's last-position output at each step.
        pred_components = torch.full((num_samples, 1), START_TOKEN,
                                     dtype=torch.long, device=device)
        for _ in range(NUM_DECODE_STEPS):
            tgt_mask = t.get_tgt_mask(pred_components.shape[1]).to(device)
            logits = t(src_cont=all_src_cont,
                       src_cat=all_src_cat,
                       tgt=pred_components, tgt_mask=tgt_mask)[:, -1, :]
            probs = torch.softmax(logits, dim=-1)
            # NOTE(review): no RNG seed is set in this notebook, so samples differ between runs.
            next_item = torch.multinomial(probs, 1, replacement=True)
            pred_components = torch.cat((pred_components, next_item), dim=-1)

        # Keep sequences whose END_TOKEN sits in the final position and nowhere
        # earlier, then drop the START/END columns.
        ends_only_at_last = ((~torch.any(pred_components[:, :-1] == END_TOKEN, dim=-1))
                             & (pred_components[:, -1] == END_TOKEN))
        pred_components = pred_components[ends_only_at_last][:, 1:-1]

        # Reject samples that use a component id >= the number of components
        # available at position res_num.
        all_allowed = torch.ones_like(pred_components[:, 0], dtype=torch.bool)
        for (res_num, num_components) in all_num_components_by_res_num.items():
            all_allowed &= pred_components[:, res_num] < num_components
        pred_components_filtered = pred_components[all_allowed]

        # Sanity check. Uses .all() instead of torch.max so it also holds when no
        # sample survived filtering (torch.max raises on an empty tensor).
        for (res_num, num_components) in all_num_components_by_res_num.items():
            assert bool((pred_components_filtered[:, res_num] < num_components).all())

    torch.save(pred_components_filtered, f"{save_folder_name}{backbone_index}_pred_components.pt")
    all_pred_sizes.append(pred_components_filtered.shape[0])





In [10]:
# Number of kept samples per test backbone, in the same order the sampling loop
# visited all_backbone_indices.
pred_sizes_array = np.asarray(all_pred_sizes)
np.save("./all_pred_sizes.npy", pred_sizes_array)