In [None]:
import torch
import torchvision.models as models
import torch.nn as nn

# Number of output categories the classification head must produce.
NUM_CLASSES = 10

print(f"Instantiating ResNet18 model structure for {NUM_CLASSES} classes...")

# Build a torchvision ResNet-18 without requesting pretrained weights, then
# replace its final fully connected layer with one sized for NUM_CLASSES.
model = models.resnet18(weights=None)
num_ftrs = model.fc.in_features  # input width of the stock classifier head
model.fc = nn.Linear(in_features=num_ftrs, out_features=NUM_CLASSES)

print("Model structure is ready.")

In [4]:
from typing import Any, Optional

from torchvision.models import ResNet
from torchvision.models._utils import _ovewrite_named_param
from torchvision.models.resnet import BasicBlock, ResNet18_Weights


class Resnet18(ResNet):
    """ResNet-18 (``BasicBlock`` with layer counts ``[2, 2, 2, 2]``) with a configurable class count.

    Args:
        num_classes: Number of outputs of the final ``fc`` layer.
        weights: Optional torchvision ``ResNet18_Weights`` to initialise from.
            When given, the network is first built with the class count stored
            in ``weights.meta["categories"]`` so the checkpoint loads cleanly,
            and the ``fc`` head is then replaced by a fresh ``nn.Linear`` if
            ``num_classes`` differs from that count.
        progress: Forwarded to ``weights.get_state_dict``.
        **kwargs: Forwarded unchanged to ``ResNet.__init__``.
    """

    def __init__(self, num_classes, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any):
        # Plain (non-tensor) attribute, kept in its original position ahead of
        # the parent constructor call.
        self.num_classes = num_classes

        weights = ResNet18_Weights.verify(weights)

        # Previously the pretrained class count was written into ``kwargs`` and
        # ``num_classes`` was *also* passed explicitly, so every call with
        # ``weights`` set failed with "got multiple values for keyword argument
        # 'num_classes'". Pick the construction-time class count once instead.
        if weights is None:
            build_classes = num_classes
        else:
            build_classes = len(weights.meta["categories"])

        super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=build_classes, **kwargs)

        if weights is not None:
            super().load_state_dict(weights.get_state_dict(progress=progress))

            if build_classes != num_classes:
                # Local import keeps this cell independent of notebook
                # execution order.
                from torch import nn

                # Keep the pretrained backbone; attach a newly initialised
                # head of the requested size.
                self.fc = nn.Linear(self.fc.in_features, num_classes)

In [21]:
import torch
import torch.nn as nn
from torchvision import transforms


# NOTE(review): machine-specific absolute path; anyone re-running this notebook
# elsewhere must point MODEL_PATH at their own copy of the checkpoint.
MODEL_PATH = '/home/k3s-server-07/federated_learning/workspace/example_project/prod_00/admin@nvidia.com/transfer/1752bb34-fe21-4b3f-b7ff-94e10e3a7cd0/workspace/app_server/FL_global_model.pt'
NUM_CLASSES = 10
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"Using device: {DEVICE}")

# The architecture must match the one that produced the checkpoint, otherwise
# load_state_dict below rejects the weights.
model = Resnet18(num_classes=NUM_CLASSES)
model.to(DEVICE)

print(f"Instantiated custom Resnet18 for {NUM_CLASSES} classes.")

try:
    # NOTE(review): torch.load un-pickles the file; only open checkpoints from
    # a trusted source (consider weights_only=True where the installed torch
    # version and the checkpoint contents allow it).
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
    print(f'Checkpoint Keys: {checkpoint.keys()}')

    # The state dict may be nested under a known key or be the file itself.
    if 'model' in checkpoint:
        weights = checkpoint['model']
    elif 'model_state_dict' in checkpoint:
        weights = checkpoint['model_state_dict']
    else:
        # Assume the file is the state_dict itself
        weights = checkpoint

    # Load the weights into the model
    model.load_state_dict(weights)

    # Set to evaluation mode
    model.eval()

    print("\n✅ Successfully loaded federated model weights!")

except Exception as e:
    print(f"\n❌ An error occurred: {e}")
    print("Ensure 'resnet_18.py' is present and matches the training model.")
    # Re-raise so execution stops here: continuing would leave `model` with
    # random weights and the later cells would report meaningless results.
    raise


Using device: cuda
Instantiated custom Resnet18 for 10 classes.
Checkpoint Keys: odict_keys(['model', 'meta_props', 'train_conf'])

✅ Successfully loaded federated model weights!


In [22]:
print("Running sample inference...")

# `preprocess` is meant for pictures that are not part of the CIFAR-10 dataset:
# it resizes them to 32x32, converts them to a tensor and normalises each
# channel with mean 0.5 / std 0.5. The Resize step was missing although the
# original comment promised it.
preprocess = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Random stand-in for a single 3-channel 32x32 image. To classify a real
# picture, feed `preprocess(img).unsqueeze(0).to(DEVICE)` to the model instead.
dummy_input = torch.randn(1, 3, 32, 32).to(DEVICE)

with torch.no_grad():
    output = model(dummy_input)

    # output has one row per input image; turn the first row's logits into
    # class probabilities and pick the most likely class.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    predicted_class_index = torch.argmax(probabilities).item()

print(f"Model output shape: {output.shape}")
print(f"Predicted class index: {predicted_class_index}")

Running sample inference...
Model output shape: torch.Size([1, 10])
Predicted class index: 9


In [23]:
import torchvision
import torchvision.transforms as transforms

# Per-image preprocessing: convert to a tensor, then normalise each of the
# three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# CIFAR-10 train and test splits stored under ./data (download=True requests
# a download of the dataset).
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform,
)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform,
)

# Mini-batch iterators: the training split is shuffled, the test split keeps
# its original order.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=4, shuffle=True, num_workers=2,
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=4, shuffle=False, num_workers=2,
)

In [24]:
# Running tallies: correctly classified images and images seen so far.
correct = 0
total = 0

model.to(DEVICE)
model.eval()

print("Starting evaluation on test set...")

# Gradients are not needed for evaluation.
with torch.no_grad():
    for images, labels in testloader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        outputs = model(images)
        # The index of the largest score along dim 1 is the predicted class.
        predicted = outputs.argmax(dim=1)

        total += labels.shape[0]
        correct += predicted.eq(labels).sum().item()

print(f'\nAccuracy of the network on the 10000 test images: {100 * correct / total} %')

Starting evaluation on test set...

Accuracy of the network on the 10000 test images: 78.3 %
