In [None]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
# Project-local model factory (FovConvNeXt package)
from FovConvNeXt.models import make_model

In [None]:
# Every image is resized to 224x224, converted to a tensor and normalized
# per channel with the commonly used ImageNet mean/std values (the same
# values `imshow` uses to undo the normalization for display).
dataset = datasets.ImageFolder(
    root='dataset',
    transform=transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    ),
)

In [None]:
# Hold out a small evaluation subset. A fixed-seed generator makes the split
# (and therefore the images, accuracies and timings reported below)
# reproducible across Restart & Run All; without it every run picks a
# different subset.
N_EVAL = 10
SPLIT_SEED = 0
assert len(dataset) >= N_EVAL, (
    f"dataset has {len(dataset)} images; need at least {N_EVAL} for evaluation"
)
test, rest = torch.utils.data.random_split(
    dataset,
    [N_EVAL, len(dataset) - N_EVAL],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
# The notebook only consumes a single batch from this loader, so the batch
# size is tied to the subset size: that one batch is then always the whole
# evaluation subset, even if the subset size changes.
loader = torch.utils.data.DataLoader(
    test,
    batch_size=len(test),
    shuffle=False,   # keep a stable order so predictions line up run to run
    num_workers=0,
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def imshow(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a normalized (C, H, W) image tensor with matplotlib.

    Undoes ``transforms.Normalize(mean, std)`` channel-wise, moves the
    channel axis last, and shows the result.

    Args:
        img: Tensor of shape (C, H, W) with C == len(mean) == len(std),
            e.g. the output of ``torchvision.utils.make_grid``. It may live
            on any device and may require grad; it is not modified.
        mean: Per-channel mean used when normalizing. The default matches
            the dataset transform in this notebook.
        std: Per-channel std used when normalizing. The default matches
            the dataset transform in this notebook.
    """
    mean_t = torch.as_tensor(mean, dtype=img.dtype).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=img.dtype).view(-1, 1, 1)
    # detach/cpu so .numpy() works for GPU or grad-tracking tensors; clamp
    # because denormalized values can land slightly outside [0, 1], which
    # matplotlib would otherwise clip with a warning.
    denorm = (img.detach().cpu() * std_t + mean_t).clamp(0.0, 1.0)
    plt.imshow(denorm.permute(1, 2, 0).numpy())
    plt.show()


In [None]:
# Pull the evaluation batch once; `images` and `labels` are reused by the
# inference cells below.
images, labels = next(iter(loader))
imshow(torchvision.utils.make_grid(images, nrow=5))
# Iterate over the actual labels rather than a fixed range(10), so a smaller
# batch does not raise IndexError.
print(' '.join(f'{dataset.classes[label]:5s}' for label in labels.tolist()))


In [None]:
# Unfoveated parameters
# Unfoveated configuration: identical to the foveated configuration in the
# next cell except for `radius` (0 here vs 0.4 there).
n_fixations = 1
radius = 0
block_sigma = 0.8
block_max_ord = 4
patch_sigma = 1.0
patch_max_ord = 4
ds_sigma = 0.6
ds_max_ord = 0

unfovea = make_model(
    n_fixations=n_fixations,
    n_classes=100,  # Use full 100 classes
    radius=radius,
    block_sigma=block_sigma,
    block_max_ord=block_max_ord,
    patch_sigma=patch_sigma,
    patch_max_ord=patch_max_ord,
    ds_sigma=ds_sigma,
    ds_max_ord=ds_max_ord,
)
# map_location='cpu': the model and input batch in this notebook are never
# moved to a GPU, and this lets a checkpoint saved from a GPU process load on
# a CPU-only machine. torch.load is pickle-based -- only load trusted files.
unfovea.load_state_dict(
    torch.load('best_model_unfoveated.pth', map_location='cpu')['model_state_dict']
)

In [None]:
# Foveated parameters
# Foveated configuration: identical to the unfoveated configuration in the
# previous cell except for `radius` (0.4 here vs 0 there).
n_fixations = 1
radius = 0.4
block_sigma = 0.8
block_max_ord = 4
patch_sigma = 1.0
patch_max_ord = 4
ds_sigma = 0.6
ds_max_ord = 0

fovea = make_model(
    n_fixations=n_fixations,
    n_classes=100,  # Use full 100 classes
    radius=radius,
    block_sigma=block_sigma,
    block_max_ord=block_max_ord,
    patch_sigma=patch_sigma,
    patch_max_ord=patch_max_ord,
    ds_sigma=ds_sigma,
    ds_max_ord=ds_max_ord,
)
# map_location='cpu': the model and input batch in this notebook are never
# moved to a GPU, and this lets a checkpoint saved from a GPU process load on
# a CPU-only machine. torch.load is pickle-based -- only load trusted files.
fovea.load_state_dict(
    torch.load('best_model_foveated.pth', map_location='cpu')['model_state_dict']
)

In [None]:
import timeit


def _run_and_report(model, name):
    """Evaluate `model` on the global `images`/`labels` batch and print results.

    Prints the wall-clock time of one forward pass, the top-1 accuracy on
    the batch, and the predicted class names (looked up in
    `dataset.classes`).

    The timing is a single, un-warmed call. Whichever model runs first may
    also absorb one-time start-up overheads, so treat the numbers as
    indicative only.

    Args:
        model: classifier whose output is scored along dim 1 (presumably
            shape (batch, n_classes) -- confirm against make_model).
        name: prefix for the printed lines, e.g. "Foveated".
    """
    model.eval()
    with torch.no_grad():
        start = timeit.default_timer()
        output = model(images)
        end = timeit.default_timer()
        print(f"{name} inference time: {end - start:.4f} seconds")
        _, predicted = torch.max(output, 1)
        accuracy = (predicted == labels).sum().item() / len(labels)
        print(f"{name} model accuracy: {accuracy:.2%}")
        # Iterate the predictions themselves instead of a fixed range(10), so
        # a batch of any size works.
        print('Predicted:', ' '.join(f'{dataset.classes[p]:5s}' for p in predicted.tolist()))


def fovea_run():
    """Evaluate the foveated model on the shared evaluation batch."""
    _run_and_report(fovea, "Foveated")


def unfovea_run():
    """Evaluate the unfoveated model on the shared evaluation batch."""
    _run_and_report(unfovea, "Unfoveated")

In [None]:
# Runs before unfovea_run(), so its single timed forward pass may also absorb
# one-time start-up overheads; keep that in mind when comparing timings.
fovea_run()

In [None]:
# Evaluate the unfoveated model on the same `images`/`labels` batch for comparison.
unfovea_run()