In [None]:
### Example: how to load the model and run inference with it
import torch
import timm
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F

def load_model(model_path, num_classes, device='cuda'):
    """
    Build a Swin-Tiny classifier and load fine-tuned weights into it.

    The architecture created here must match the one the checkpoint was
    trained with: ``swin_tiny_patch4_window7_224`` with ``num_classes``
    outputs and an input size of 100x100.

    :param model_path: Path to the saved .pth file containing the model's state_dict.
    :param num_classes: Number of output classes for the model.
    :param device: Device to run the model on ('cuda' or 'cpu').
    :return: The loaded model, moved to ``device`` and switched to evaluation mode.
    """
    # pretrained=False: every weight is replaced by the strict load_state_dict
    # call below, so fetching ImageNet weights would only cost download time
    # and make loading fail on machines without network access.
    model_swin = timm.create_model(
        'swin_tiny_patch4_window7_224',
        pretrained=False,
        num_classes=num_classes,
        img_size=100,
    ).to(device)

    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from trusted sources.
    state_dict = torch.load(model_path, map_location=device)
    # strict=True (the default) raises if any key is missing or unexpected,
    # guaranteeing that every parameter comes from the checkpoint.
    model_swin.load_state_dict(state_dict)

    # Switch to inference mode (affects layers such as dropout).
    model_swin.eval()

    return model_swin

In [None]:
def predict_image_with_softmax(model, image_path, device='cuda', img_size=100):
    """
    Classify a single image and return its class probabilities.

    Processing steps:

    - Convert the image to RGB.
    - Resize it to exactly ``img_size`` x ``img_size``.
    - Normalize it with the ImageNet channel mean/std.
    - Run it through ``model`` without tracking gradients.

    The model is expected to already live on ``device`` and to be in
    evaluation mode (as returned by ``load_model``).

    :param model: The model to make predictions with.
    :param image_path: Path to the image you want to classify.
    :param device: Device the model runs on ('cuda' or 'cpu').
    :param img_size: Side length the image is resized to; must equal the
        input size the model was built with (100 for ``load_model``).
    :return: numpy array of softmax probabilities with shape (num_classes,),
        e.g. (77,) for the 77-class model.
    """
    transform = transforms.Compose([
        # An (h, w) tuple forces an exact square output. A bare int would only
        # rescale the shorter side, so non-square images would not match the
        # model's fixed input size.
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # The context manager closes the underlying file once the pixels are loaded.
    with Image.open(image_path) as img:
        # unsqueeze(0) adds the batch dimension: (1, 3, img_size, img_size).
        batch = transform(img.convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)

    # Drop only the batch dimension so the result stays 1-D even when the
    # model has a single output class (a bare squeeze() would make it 0-d).
    return probs[0].cpu().numpy()