In [3]:
# Load model directly
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification, AutoImageProcessor

processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
model = AutoModelForZeroShotImageClassification.from_pretrained("openai/clip-vit-large-patch14")
model.eval()

feature_extractor = model.get_image_features



In [2]:
import os 
from torchvision import transforms
import os
from torchvision import transforms
import matplotlib.pyplot as plt

from PIL import Image
dir = '../../../PHOENIX-2014-T-release-v3/PHOENIX-2014-T/features/fullFrame-210x260px/train/01December_2011_Thursday_tagesschau-3479'
# PyTorch transformation to resize and convert to tensor
image_paths = os.listdir(dir)
images = [] 
for path in image_paths: 
    actual_path = os.path.join(dir, path)
    
    # Load the image

    img = Image.open(actual_path).convert("RGB")  # Ensures image is in RGB format
    images.append(img)

processed_images = processor(images=images , return_tensors="pt")
print(processed_images.pixel_values.shape)
# output = model.get_image_features(**processed_images)
# feature_extractor = model.get_image_features
# print(feature_extractor)

torch.Size([38, 3, 224, 224])


In [4]:
feature_extractor(processed_images.pixel_values).shape

torch.Size([38, 768])

In [None]:
total_params = sum(p.numel() for p in model.parameters())
print(f'Total number of parameters in the model: {total_params}')

Total number of parameters in the model: 427616513


: 

## Creating S^2 wrapper


In [3]:
#  ------------------------------------------------------------------------------------------
#  Copyright (c) 2024 Baifeng Shi.
#  All rights reserved.
#
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------

import math
import torch
import torch.nn.functional as F
from einops import rearrange

# Load model directly
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification

processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
model = AutoModelForZeroShotImageClassification.from_pretrained("openai/clip-vit-large-patch14")
model.eval()
feature_extractor = model.get_image_features

def split_chessboard(x, num_split):
    """
        x: b * c * h * w
        Deividing x into num_split**2 sub-squares, and concatenate all the sub-squares on the batch dimension
        E.g: If num_split =2, x will be divided into 4 sub-squares, and the output will be a tensor of shape 4b * c * h/2 * w/2
    """
    B, C, H, W = x.shape
    assert H % num_split == 0 and W % num_split == 0
    h, w = H // num_split, W // num_split
    x_split = torch.cat([x[:, :, i*h:(i+1)*h, j*w:(j+1)*w] for i in range(num_split) for j in range(num_split)], dim=0)
    return x_split

def merge_chessboard(x, num_split):
    """
        x: b * c * h * w
        Assuming x contains num_split**2 sub-squares concatenated along batch dimension, merge the sub-squares back to the original whole square.
        (inverse of split_chessboard)
    """
    B, C, H, W = x.shape
    assert B % (num_split**2) == 0
    b = B // (num_split**2)
    x_merge = torch.cat([torch.cat([x[(i*num_split + j)*b:(i*num_split + j + 1)*b] for j in range(num_split)], dim=-1)
                         for i in range(num_split)], dim=-2)
    return x_merge


def forward(
    model, 
    input, 
    scales=None, 
    img_sizes=None, 
    max_split_size=None, 
    resize_output_to_idx=0, 
    num_prefix_token=0,
    output_shape='bnc',
):
    assert input.dim() == 4, "Input image must be in the shape of BxCxHxW."
    assert input.shape[2] == input.shape[3], "Currently only square images are supported."
    assert output_shape in ['bnc', 'bchw'], "Output shape should be either BxNxC (e.g., ViT) or BxCxHxW (e.g., ConvNet)."
    assert output_shape == 'bnc' or num_prefix_token == 0, "For ConvNet there shouldn't be any prefix token."

    b, c, input_size, _ = input.shape

    # image size for each scale
    assert scales is not None or img_sizes is not None, "Please assign either scales or img_sizes."
    
    img_sizes = img_sizes or [int(input_size * scale) for scale in scales]
    # img_sizes is a list of sizes to cut the image up into
    # img_sizes should be bigger than 1 (best to be multiples of 2) or scales of 224
    
    # prepare multiscale inputs
    max_split_size = max_split_size or input_size   # Default = 224; The maximum size of each split of image. Set as the input size by default
    num_splits = [math.ceil(size / max_split_size) for size in img_sizes]   # number of splits each scale
    input_multiscale = []
    for size, num_split in zip(img_sizes, num_splits):
        print(f"size: {size}, num_split: {num_split}")
        x = F.interpolate(input.to(torch.float32), size=size, mode='bicubic').to(input.dtype) # resize the input image to the larger target size
        x = split_chessboard(x, num_split=num_split)
        input_multiscale.append(x)
    print(len(input_multiscale))
    print(input_multiscale[0].shape)    
    # run feedforward on each scale
    outs_multiscale = [model(x) for x in input_multiscale]
    if num_prefix_token > 0:
        outs_prefix_multiscale = [out[:, :num_prefix_token] for out in outs_multiscale]
        outs_multiscale = [out[:, num_prefix_token:] for out in outs_multiscale]
    if output_shape == 'bnc':
        outs_multiscale = [rearrange(out, 'b (h w) c -> b c h w', h=int(out.shape[1] ** 0.5), w=int(out.shape[1] ** 0.5))
                           for out in outs_multiscale]
    
    # merge outputs of different splits for each scale separately
    outs_multiscale = [merge_chessboard(out, num_split=num_split) for num_split, out in zip(num_splits, outs_multiscale)]
    
    # interpolate outputs from different scales and concat together
    output_size = outs_multiscale[resize_output_to_idx].shape[-2]
    out = torch.cat([F.interpolate(outs_multiscale[i].to(torch.float32), size=output_size,
                                   mode='area').to(outs_multiscale[i].dtype)
                     for i in range(len(outs_multiscale))], dim=1)
    
    if output_shape == 'bnc':
        out = rearrange(out, 'b c h w -> b (h w) c')
    if num_prefix_token > 0:
        # take the mean of prefix tokens from different splits for each scale
        outs_prefix_multiscale = [torch.stack(out.split(b, dim=0), dim=0).mean(dim=0) for out in outs_prefix_multiscale]
        out_prefix_multiscale = torch.cat(outs_prefix_multiscale, dim=-1)
        out = torch.cat([out_prefix_multiscale, out], dim=1)
    
    return out




In [10]:
from s2wrapper import * 
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
img_sizes = 224

out = forward(feature_extractor, processed_images.pixel_values.to("mps"),img_sizes=[448])

size: 448, num_split: 2


NotImplementedError: The operator 'aten::upsample_bicubic2d.out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

In [None]:
out