In [None]:
%autoreload 2
import sys
sys.path.append('..')
import torch
import collections
from torch.utils.data import DataLoader
from torchvision import transforms

import tqdm

import pytorch_lightning as pl
import json
import pandas as pd

from ecgnet.utils.transforms import ToTensor, ApplyGain, Resample

from EcgCaptionGenerator.utils.dataset import collate_fn, CaptionDataset
from EcgCaptionGenerator.utils.pycocoevalcap.eval import COCOEvalCap


from EcgCaptionGenerator.systems.top_down_attention_lstm import TopDownLSTM

# from EcgCaptionGenerator.systems.topic_unchanged_decoder import TopicSimDecoder
from topic_transformer import TopicTransformer

# from EcgCaptionGenerator.systems.transformer import Transformer
from util import get_loaders, get_loaders_toy_data, FakeDataset
from utils_model import beam_search



In [None]:
# --- Run configuration -------------------------------------------------------
# Base directory string; not referenced by the cells visible in this notebook.
basedir = './training/captioning/models/'

# Passed through to the dataset constructor in the setup cell.
use_topic = True

# Checkpoint to evaluate, and the JSON parameter file that goes with it.
checkpoint_loc = "./training/transformertopic/models/tansformertopic/TRAN1-2/checkpoints/epoch=16-step=70464.ckpt"  # 6.08
param_file = 'config_transformer_muse.json'

# Alternative checkpoint/config pair (swap in for the two assignments above):
# checkpoint_loc, param_file ="./training/transformertopic/models/tansformertopic/TRAN1-9/checkpoints/epoch=14-step=11549.ckpt", 'config_transformer_consult_topic.json' # 6.08



In [None]:
# Fix the random seeds through Lightning so repeated runs are comparable.
pl.seed_everything(1234)

# Read the run parameters; the context manager closes the file handle
# instead of leaving it open for the lifetime of the kernel.
with open(param_file, 'r') as param_fh:
    params = json.load(param_fh)

# Preprocessing pipeline applied to every sample of the dataset.
transform = transforms.Compose([Resample(500), ToTensor(), ApplyGain()])

# Restore the trained model from the checkpoint and move it to the GPU.
model = TopicTransformer.load_from_checkpoint(checkpoint_path=checkpoint_loc).cuda()
# Put the module in evaluation mode so layers such as dropout / batch-norm
# (if the model has any) behave deterministically while sampling.
model.eval()

threshold, is_train, vocab = 0, False, model.vocab

testset_df = pd.read_csv(params['test_labels_csv'], index_col=0)
# NOTE(review): the loader below iterates a FakeDataset, not the rows of
# `testset_df` -- confirm this is intended for a real evaluation run.
# The first argument (100) is presumably the number of samples -- TODO confirm.
testset = FakeDataset(100, use_topic, vocab, transform=transform)

# Map each TestID to its Label (for duplicate IDs the last row wins).
# NOTE(review): the evaluation cell re-initialises `gts` before use.
gts = dict(zip(testset_df['TestID'], testset_df['Label']))

test_loader = DataLoader(
    testset,
    batch_size=64,
    num_workers=4,
    collate_fn=collate_fn,
)

In [None]:
# Evaluation of the topic transformer: generate a caption for every sample in
# `test_loader`, then score the generations against the references.

# Decoding options handed to `model.sample`; only greedy decoding is enabled.
sample_method = {'temp':None, 'k':None, 'p':None, 'greedy':True, 'm':None}
# Maximum length passed to `model.sample` for each generated sequence.
max_length = 50

# Both dicts map a sample id to a one-element list holding a caption.
gts = {}  # reference captions decoded from the batch targets
res = {}  # captions generated by the model

# Gradients are never needed while sampling; disabling autograd avoids
# building a computation graph for every batch.
with torch.no_grad():
    for batch in tqdm.tqdm(test_loader):
        # Only the waveforms, ids and targets of the 7-tuple are used here.
        waveforms, _, _, ids, targets, _, _ = batch
        words = model.sample(waveforms, sample_method, max_length)

        generated = model.vocab.decode(words, skip_first=False)
        truth = model.vocab.decode(targets)
        for i in range(waveforms.shape[0]):
            res[ids[i]] = [generated[i]]
            gts[ids[i]] = [truth[i]]

# Optional: persist the generations next to the checkpoint.
# pd.DataFrame(gts).to_csv(checkpoint_loc[:-5] + 'gts_.csv')
# pd.DataFrame(res).to_csv(checkpoint_loc[:-5] + 'res_.csv')

# Compute the captioning metrics and report them with the decoding settings.
COCOEval = COCOEvalCap()
COCOEval.evaluate(gts, res)
print(sample_method, COCOEval.eval)