In [5]:
#| default_exp pil_to_tensor

In [6]:
#| export
import os
from torchvision import transforms

In [7]:
#| export
def pil_to_tensor(
    image,
    resize_size=(224, 224),
    normalize=True,
    *,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
):
    """
    Transform a PIL image into a tensor with optional resizing and normalization.

    Args:
        image (PIL.Image.Image): The input PIL image to be transformed.
        resize_size (tuple | int | None): Target size passed to
            ``transforms.Resize``. A ``(height, width)`` pair resizes to exactly
            that size; an int resizes the shorter edge to that length while
            keeping the aspect ratio. ``None`` skips resizing entirely.
            Default is (224, 224).
        normalize (bool): Whether to normalize the tensor channel-wise with
            ``mean`` and ``std``. Default is True.
        mean (sequence of float): Keyword-only. Per-channel means used when
            ``normalize`` is True. Default is the ImageNet mean
            (0.485, 0.456, 0.406).
        std (sequence of float): Keyword-only. Per-channel standard deviations
            used when ``normalize`` is True. Default is the ImageNet std
            (0.229, 0.224, 0.225).

    Returns:
        torch.Tensor: The transformed image as a channels-first (C, H, W)
        float tensor.

    Notes:
        - ``transforms.ToTensor`` scales 8-bit images from [0, 255] to [0, 1]
          before any normalization is applied.
        - When ``normalize`` is True, the image's channel count must equal
          ``len(mean)`` / ``len(std)`` (3 by default, i.e. RGB). Grayscale or
          RGBA images should be converted first, e.g. ``image.convert("RGB")``;
          otherwise ``transforms.Normalize`` raises a shape error.
    """
    # Build a single flat pipeline instead of nesting Compose objects.
    steps = []
    if resize_size is not None:
        steps.append(transforms.Resize(resize_size))
    steps.append(transforms.ToTensor())
    if normalize:
        # list() keeps the argument type identical to the original hard-coded lists.
        steps.append(transforms.Normalize(mean=list(mean), std=list(std)))

    # Apply the transformation pipeline to the input image.
    return transforms.Compose(steps)(image)


In [8]:
#| hide
# Export the `#| export` cells of this notebook into the shared API package.
# The notebook name below is a relative path; presumably this cell must run
# with the notebook's own directory as the working directory (the chdir below
# was the original way to ensure that) — confirm before running elsewhere.
# os.chdir("/project/validating_attribution_techniques/shardul/api_notebooks/")
from nbdev.export import nb_export

_nb_name = 'pil_to_tensor.ipynb'
_lib_dir = '/project/validating_attribution_techniques/commons/api/'
nb_export(_nb_name, _lib_dir)