In [1]:
import torch
from torch.utils.data import DataLoader
import os
from tqdm import tqdm

from networks.model import MainNet
from datasets.dataset import CompetitionDataset
import warnings
warnings.filterwarnings('ignore')
from config import proposalN, num_classes, channels
from utils.second_eval import eval


In [2]:
def load_model(checkpoint_path, model_class, device):
    """Instantiate ``model_class`` and restore its weights from a training checkpoint.

    Args:
        checkpoint_path: Path to a file written by ``torch.save`` that holds a
            dict with a ``'model_state_dict'`` entry.
        model_class: Model constructor, called as
            ``model_class(proposalN, num_classes, channels)`` with the values
            imported from ``config``.
        device: ``torch.device`` the model is moved to. It is also used as
            ``map_location``, so tensors saved from a GPU run are remapped
            onto this device while loading.

    Returns:
        The model on ``device``, switched to eval mode.

    Raises:
        KeyError: if the checkpoint has no ``'model_state_dict'`` entry.
    """
    # Use the constructor the caller passed in (previously ignored in favour
    # of a hard-coded MainNet).
    model = model_class(proposalN, num_classes, channels)

    # Without map_location, a checkpoint saved from CUDA cannot be loaded on a
    # CPU-only machine, even though the caller explicitly falls back to CPU.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    state_dict = checkpoint['model_state_dict']

    # The load is strict (the default): every key must match the architecture.
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()
    return model

def create_testloader(data_dir, input_size, batch_size):
    """Build an ordered (non-shuffled) DataLoader over the competition test split.

    Args:
        data_dir: Root directory handed to ``CompetitionDataset`` as ``root``.
        input_size: Image size handed to ``CompetitionDataset``.
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` that yields the test split in dataset order.
    """
    # is_train=False selects the evaluation split. Keeping the order fixed means
    # prediction indices line up with dataset indices.
    return DataLoader(
        CompetitionDataset(input_size=input_size, root=data_dir, is_train=False),
        batch_size=batch_size,
        shuffle=False,
    )

def main(data_dir='datasets/CompetitionData',
         checkpoint_path='models/epoch5.pth',
         input_size=448,
         batch_size=32,
         epoch=0):
    """Evaluate a trained checkpoint on the competition test split and print predictions.

    All arguments default to the values this notebook previously hard-coded,
    so ``main()`` behaves exactly as before.

    Args:
        data_dir: Root directory of the competition dataset.
        checkpoint_path: Checkpoint file containing a ``'model_state_dict'`` entry.
        input_size: Image size passed to the test dataset.
        batch_size: Test DataLoader batch size.
        epoch: Forwarded unchanged to ``eval`` from ``utils.second_eval``. Its
            role there is not visible here; presumably it is only used for
            logging (TODO confirm).

    Returns:
        Whatever ``eval`` returns. It is iterated here, and each element is
        printed with its running index.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = load_model(checkpoint_path, MainNet, device)
    testloader = create_testloader(data_dir, input_size, batch_size)

    # `eval` is utils.second_eval.eval, which shadows Python's builtin eval.
    # NOTE(review): `device` is not passed to it; this assumes it moves each
    # batch onto the model's device itself. Confirm.
    predictions = eval(model, testloader, epoch)

    for i, pred in enumerate(predictions):
        print(f'Image {i}: Prediction {pred}')
    return predictions
    

if __name__ == '__main__':
    # In a Jupyter/IPython kernel __name__ is '__main__', so this runs every
    # time the cell executes. The returned predictions are bound to `a`.
    a = main()


  0%|          | 0/1 [00:00<?, ?it/s]

Evaluating


100%|██████████| 1/1 [00:00<00:00,  1.04it/s]

Image 0: Prediction [4]
Image 1: Prediction [4]
Image 2: Prediction [4]
Image 3: Prediction [4]
Image 4: Prediction [4]
Image 5: Prediction [1]
Image 6: Prediction [4]
Image 7: Prediction [4]
Image 8: Prediction [4]
Image 9: Prediction [2]
Image 10: Prediction [2]
Image 11: Prediction [4]
Image 12: Prediction [2]
Image 13: Prediction [2]
Image 14: Prediction [4]
Image 15: Prediction [4]



