# Testing Models

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### Loading Pretrained Models

In [None]:
from algo_reasoning.src.models.network import EncodeProcessDecode
from algo_reasoning.src.lightning.AlgorithmicReasoningTask import AlgorithmicReasoningTask 
from algo_reasoning.src.specs import CLRS_30_ALGS
from algo_reasoning.src.losses.AlgorithmicReasoningLoss import AlgorithmicReasoningLoss

In [None]:
# Restore the "hidden" insertion-sort model from its checkpoint
# (epoch 88, val_loss 0.10 according to the file name).
# A fresh network and loss object are built first and forwarded to
# `load_from_checkpoint` as keyword arguments; only the restored `.model`
# attribute of the returned task is kept.
model = EncodeProcessDecode(["insertion_sort"])
loss_fn = AlgorithmicReasoningLoss()
ckpt_path = "../checkpoints/insertion_sort/insertion_sort-epoch=88-val_loss=0.10.ckpt"

model_hidden = AlgorithmicReasoningTask.load_from_checkpoint(
    ckpt_path,
    model=model,
    loss_fn=loss_fn,
).model

In [None]:
# Restore the "nohidden" insertion-sort model from its checkpoint
# (epoch 96, val_loss 0.07 according to the file name).
# `model`, `loss_fn` and `ckpt_path` are rebound here, so a new network
# instance (not the one used for the previous checkpoint) receives the weights.
model = EncodeProcessDecode(["insertion_sort"])
loss_fn = AlgorithmicReasoningLoss()
ckpt_path = "../checkpoints/insertion_sort/insertion_sort-epoch=96-val_loss=0.07.ckpt"

model_nohidden = AlgorithmicReasoningTask.load_from_checkpoint(
    ckpt_path,
    model=model,
    loss_fn=loss_fn,
).model

In [None]:
# Baseline: a newly constructed network with no checkpoint loaded into it.
# It is the model whose result is stored as `output_random` further down.
model = EncodeProcessDecode(["insertion_sort"])

### Load Dataset

In [None]:
from algo_reasoning.src.data import OriginalCLRSDataset, CLRSSampler, collate
from torch.utils.data import DataLoader
# Algorithms to draw samples from; shared by the dataset and the sampler.
algorithms = ["insertion_sort"]

# NOTE(review): the variables are named `test_*` but the split requested is
# "val" - confirm this is the intended split.
test_dataset = OriginalCLRSDataset(
    algorithms,
    "val",
    "../tmp/CLRS30",
)
test_sampler = CLRSSampler(
    test_dataset,
    algorithms=algorithms,
    batch_size=32,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_sampler=test_sampler,
    num_workers=0,  # load in the main process
    collate_fn=collate,
)

# Take only the first batch; every model below is run on this same batch.
obj = next(iter(test_dataloader))


In [None]:
# Run the "hidden" checkpoint model on the batch under inference conditions:
# eval() disables training-only behaviour (e.g. dropout), and no_grad() avoids
# recording an autograd graph that is never used in this notebook.
# Assumes `model_hidden` is a torch.nn.Module - TODO confirm.
model_hidden.eval()
with torch.no_grad():
    output_hidden = model_hidden(obj)

In [None]:
# Run the "nohidden" checkpoint model on the batch under inference conditions:
# eval() disables training-only behaviour (e.g. dropout), and no_grad() avoids
# recording an autograd graph that is never used in this notebook.
# Assumes `model_nohidden` is a torch.nn.Module - TODO confirm.
model_nohidden.eval()
with torch.no_grad():
    output_nohidden = model_nohidden(obj)

In [None]:
# Run the baseline model (no checkpoint loaded) on the same batch.
# It is put in eval mode and run without gradient tracking so its output is
# produced under the same conditions as an evaluation run should be.
# Assumes `model` is a torch.nn.Module - TODO confirm.
model.eval()
with torch.no_grad():
    output_random = model(obj)

In [None]:
# Display the `outputs.pred` field carried by the batch itself (presumably the
# ground-truth target - verify), for comparison with the model predictions
# shown in the cells that follow.
obj.outputs.pred

In [None]:
# Display the "hidden" model's prediction as indices: argmax over the last
# dimension of its `outputs.pred` tensor.
output_hidden.output.outputs.pred.argmax(dim=-1)

In [None]:
# Display the "nohidden" model's prediction as indices: argmax over the last
# dimension of its `outputs.pred` tensor.
output_nohidden.output.outputs.pred.argmax(dim=-1)

In [None]:
# Display the baseline model's prediction as indices: argmax over the last
# dimension of its `outputs.pred` tensor.
output_random.output.outputs.pred.argmax(dim=-1)

In [None]:
from algo_reasoning.src.eval import eval_function

# Score the "hidden" model's output against the batch `obj`, requesting
# micro averaging from `eval_function`.
eval_function(output_hidden.output, obj, average="micro")

In [None]:
# Score the "nohidden" model's output against the same batch, with the same
# micro averaging as the previous evaluation.
eval_function(output_nohidden.output, obj, average="micro")

In [None]:
# Score the baseline model's output against the same batch.
# `average="micro"` is passed explicitly so this score uses the same averaging
# as the two checkpoint evaluations and the three numbers are comparable.
# NOTE(review): the default of `average` is not visible here; if it is already
# "micro" this argument is a no-op.
eval_function(output_random.output, obj, average="micro")