In [22]:
import torch
import pickle

class RPLANDataset(torch.utils.data.Dataset):
    """Map-style dataset backed by a single pickle file.

    The pickle must hold a dict with the keys ``'param'``, ``'seq_mask'`` and
    ``'ignore_mask'``; each value is indexed per sample and has one entry per
    sample.

    Each item is a tuple ``(param, seq_mask, ignore_mask)`` of tensors with
    dtypes ``torch.int32``, ``torch.bool`` and ``torch.bool`` respectively.

    NOTE: ``pickle.load`` can execute arbitrary code; only load trusted files.

    Args:
        path: Path to the pickle file.

    Raises:
        ValueError: If the three per-sample sequences have different lengths.
    """

    def __init__(self, path):
        print('Loading Data...')

        with open(path, 'rb') as f:
            dataset = pickle.load(f)  # trusted input only (arbitrary code execution risk)
        self.param = dataset['param']
        # Attribute names kept for backward compatibility: despite the "_length"
        # suffix these hold the per-sample masks, not lengths.
        self.seq_length = dataset['seq_mask']
        self.mask_length = dataset['ignore_mask']

        # __len__ only looks at `param`; fail fast here instead of raising an
        # IndexError mid-iteration when the other sequences are shorter.
        n_param, n_seq, n_ignore = len(self.param), len(self.seq_length), len(self.mask_length)
        if not (n_param == n_seq == n_ignore):
            raise ValueError(
                f"Inconsistent sample counts in {path!r}: "
                f"param={n_param}, seq_mask={n_seq}, ignore_mask={n_ignore}"
            )

    def __len__(self):
        """Number of samples (length of the ``param`` sequence)."""
        return len(self.param)

    def __getitem__(self, idx):
        """Return ``(param, seq_mask, ignore_mask)`` for sample ``idx`` as tensors."""
        param_sample = torch.tensor(self.param[idx], dtype=torch.int)
        seq_length_sample = torch.tensor(self.seq_length[idx], dtype=torch.bool)
        mask_length_sample = torch.tensor(self.mask_length[idx], dtype=torch.bool)

        return param_sample, seq_length_sample, mask_length_sample

In [23]:
import os

# Show the working directory: relative paths used in this notebook
# ('test.pkl', '../proj_log/RPLAN/...', 'result.pkl') resolve against it.
cwd_label = "当前工作目录是："
current_directory = os.getcwd()
print(cwd_label, current_directory)


当前工作目录是： /root/new/hnc-cad/codebook


In [24]:
# Load the evaluation split from the pickle in the working directory.
dataset = RPLANDataset(path='test.pkl')

Loading Data...


In [25]:
# shuffle=False keeps batches in dataset order, so the per-sample outputs
# appended by the inference loop line up with dataset indices.
# num_workers=0 loads batches in the main process.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=False,
    num_workers=0,
)

In [26]:
from model.encoder import RPLANEncoder
from model.decoder import RPLANDecoder
# Checkpoint location/epoch (relative to the working directory).
CKPT_DIR = '../proj_log/RPLAN'
CKPT_EPOCH = 500

# map_location='cpu': deserialize the weights on the host so loading does not
# depend on the GPU the checkpoint was saved from (and does not allocate GPU
# memory for tensors that load_state_dict immediately copies into the CPU
# model); .cuda() then moves each model to the GPU once.
encoder = RPLANEncoder()
encoder.load_state_dict(torch.load(f'{CKPT_DIR}/enc_epoch_{CKPT_EPOCH}.pt', map_location='cpu'))
encoder = encoder.cuda().eval()
decoder = RPLANDecoder()
decoder.load_state_dict(torch.load(f'{CKPT_DIR}/dec_epoch_{CKPT_EPOCH}.pt', map_location='cpu'))
decoder = decoder.cuda().eval()

In [27]:
# Run encoder + decoder over every batch and collect, per sample, the argmax
# predictions at positions where ignore_mask is set; rows that are entirely
# zero after masking are dropped. result[i] corresponds to dataset[i]
# because the dataloader does not shuffle.
result = []

with torch.no_grad():  # pure inference: don't build autograd graphs per batch
    for param, seq_mask, ignore_mask in dataloader:
        param = param.cuda()
        seq_mask = seq_mask.cuda()
        ignore_mask = ignore_mask.cuda()
        # Zero out the positions flagged by ignore_mask before feeding the models.
        visible_param = torch.logical_not(ignore_mask) * param
        latent_code, vq_loss, selection, _ = encoder(visible_param, seq_mask)
        param_logits = decoder(visible_param, seq_mask, ignore_mask, latent_code)
        max_indices = param_logits.argmax(dim=3)

        # Keep predictions only where ignore_mask is set; one host transfer per
        # batch instead of one per sample.
        masked_preds = (max_indices * ignore_mask).cpu()
        for matrix in masked_preds:
            non_zero_rows = matrix[~torch.all(matrix == 0, dim=1)]
            result.append(non_zero_rows.numpy())

In [28]:
import pickle


# Persist the list of per-sample prediction arrays to 'result.pkl'.
with open('result.pkl', mode='wb') as result_file:
    pickle.dump(result, result_file)
