In [5]:
from PIL import Image
from torchvision.transforms import v2 , transforms
import torch
img_path = 'styles/style (20).jpg'
# Side length (in pixels) of the square image tensors produced by the pipeline.
IMAGE_SIZE = 720
# Scale the shorter side to IMAGE_SIZE (keeping the aspect ratio), then
# center-crop to an IMAGE_SIZE x IMAGE_SIZE square. Resizing straight to
# (720, 720) stretched non-square images and left the CenterCrop as a no-op.
# ToImage + ToDtype(float32, scale=True) is the torchvision v2 replacement for
# ToTensor: a float32 CHW tensor with uint8 pixel values scaled to [0, 1].
image_transforms = transforms.Compose([
    v2.Resize(size=IMAGE_SIZE),
    v2.CenterCrop(size=IMAGE_SIZE),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
])

In [6]:
def load_image(img_path: str,
               image_transforms: transforms.Compose,
               mode: str = "RGB") -> torch.Tensor:
    """
    Load an image from disk, apply the given transformations, and add a batch dimension.

    The file is opened in a context manager so its handle is released as soon as the
    pixels have been decoded. Before the image is transformed, it is converted to
    ``mode`` (RGB by default). This lets grayscale, RGBA or palette images produce
    the same channel count as ordinary RGB JPEGs.

    Args:
        img_path (str): Path to the image file to load.
        image_transforms (transforms.Compose): Transformations applied to the PIL image;
            expected to return a tensor (presumably shaped (C, H, W)).
        mode (str): PIL mode the image is converted to before transforming.
            Defaults to "RGB".

    Returns:
        torch.Tensor: The transformed image with a leading batch dimension of size 1
        added via ``unsqueeze(0)``. No device placement is performed; the tensor is
        wherever ``image_transforms`` produced it (normally CPU).

    Raises:
        FileNotFoundError: If ``img_path`` does not exist.
        PIL.UnidentifiedImageError: If the file cannot be identified as an image.

    Example:
        image_transforms = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        image_tensor = load_image('path/to/image.jpg', image_transforms)
    """
    # Image.open is lazy, so the pixel data must be decoded (convert + transforms)
    # before the `with` block closes the underlying file.
    with Image.open(img_path) as raw_image:
        converted = raw_image.convert(mode)
        tensor = image_transforms(converted)
    return tensor.unsqueeze(0)

In [10]:
# Load the configured style image as a single-image batch and inspect its shape.
img = load_image(img_path=img_path, image_transforms=image_transforms)
img.shape  # leading dim of 1 is the batch dimension added by load_image

torch.Size([1, 3, 720, 720])