In [5]:
from transformers import ConvNextImageProcessor, ConvNextModel
import numpy as np
import torch

In [3]:
# Load the ConvNeXt image processor and backbone once; the extraction cells below reuse both.
CONVNEXT_CHECKPOINT = "facebook/convnext-base-224-22k"
# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU
# instead of failing on `.to("cuda")`.
CONVNEXT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

convnext_processor = ConvNextImageProcessor.from_pretrained(CONVNEXT_CHECKPOINT)
# `.eval()` makes the inference-only intent explicit (chained so the cell prints nothing).
convnext_model = ConvNextModel.from_pretrained(CONVNEXT_CHECKPOINT).to(CONVNEXT_DEVICE).eval()

In [8]:
import os

# Each sample is a sub-directory (named by a float label, e.g. "-0.00019737") that
# holds image.png. Keep only directories so that stray files in the dataset root
# (such as .DS_Store) are not treated as samples by the extraction loop.
UNSPLASH_BALL_DIR = "datasets/scale-1_1/unsplash2000_ball"
files = sorted(
    entry
    for entry in os.listdir(UNSPLASH_BALL_DIR)
    if os.path.isdir(os.path.join(UNSPLASH_BALL_DIR, entry))
)

In [6]:
import torch
from tqdm.auto import tqdm
from PIL import Image

In [11]:
# Extract ConvNeXt last_hidden_state for every Unsplash sample and write it to
# unsplash2000_convnext/<sample>/last_hidden_states.pt. The reload check later in
# the notebook shows a saved tensor of shape [1, 1024, 7, 7].
with torch.no_grad():
    for filename in tqdm(files):
        image_path = f"datasets/scale-1_1/unsplash2000_ball/{filename}/image.png"
        # The context manager closes the file handle. convert("RGB") makes every image
        # reach the processor as 3-channel RGB (a no-op for files that are already RGB).
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        convnext_inputs = convnext_processor(image, return_tensors="pt")
        # Follow wherever the model lives instead of hard-coding "cuda".
        convnext_inputs = {k: v.to(convnext_model.device) for k, v in convnext_inputs.items()}
        convnext_outputs = convnext_model(**convnext_inputs)
        convnext_last_hidden_states = convnext_outputs.last_hidden_state
        output_dir = f"datasets/scale-1_1/unsplash2000_convnext/{filename}"
        os.makedirs(output_dir, exist_ok=True)
        torch.save(convnext_last_hidden_states.cpu(), f"{output_dir}/last_hidden_states.pt")

  0%|          | 0/2000 [00:00<?, ?it/s]

In [12]:
import torchvision

# Quick sanity check: decode one raw Unsplash sample with torchvision. The result is
# laid out channels x height x width ([3, 256, 256] once scaled below). Values are
# presumably uint8 in [0, 255], given the /255 scaling in the next cell — confirm.
sample_image_path = "datasets/scale-1_1/unsplash2000_ball/-0.00019737/image.png"
a = torchvision.io.read_image(sample_image_path)

# Process the small studio renders (rotate_studio_ball)

In [15]:
# Extract ConvNeXt last_hidden_state for the 360 studio frames 000.png ... 359.png.
# Frame `idx` is labelled with the idx-th value of linspace(-1, 1, 360), formatted
# with 8 decimals to match the Unsplash sample directory names.
# NOTE(review): the files currently under rotate_studio_convnext reload as [1, 1024]
# (see the final check in this notebook), not the 4-D last_hidden_state this cell
# saves. They appear to come from an earlier run — regenerate them before comparing
# with the Unsplash features.
with torch.no_grad():
    for idx, label in enumerate(tqdm(np.linspace(-1, 1, 360))):
        filename = f"{label:.8f}"
        # Close the file handle; normalize every frame to 3-channel RGB.
        with Image.open(f"datasets/rotate_studio_ball/{idx:03d}.png") as img:
            image = img.convert("RGB")
        # `padding=True` was removed: it is not an option of ConvNextImageProcessor,
        # which only ignored it (recent transformers versions warn about unused kwargs).
        inputs = convnext_processor(images=image, return_tensors="pt")
        inputs = {k: v.to(convnext_model.device) for k, v in inputs.items()}
        outputs = convnext_model(**inputs)
        convnext_last_hidden_states = outputs.last_hidden_state
        output_dir = f"datasets/scale-1_1/rotate_studio_convnext/{filename}"
        os.makedirs(output_dir, exist_ok=True)
        torch.save(convnext_last_hidden_states.cpu(), f"{output_dir}/last_hidden_states.pt")

  0%|          | 0/360 [00:00<?, ?it/s]

In [14]:
# Scale the decoded image's 0-255 pixel values into [0, 1] floats.
b = torch.div(a, 255.0)

In [16]:
# Display the scaled image's dimensions.
b.size()

torch.Size([3, 256, 256])

In [11]:
# Reload one cached Unsplash feature tensor to inspect what was written to disk.
# (Rebinds `a`, which earlier held a raw image.)
unsplash_feature_path = "datasets/scale-1_1/unsplash2000_convnext/0.84577371/last_hidden_states.pt"
a = torch.load(unsplash_feature_path)

In [12]:
# Display the cached Unsplash feature shape.
a.size()

torch.Size([1, 1024, 7, 7])

In [13]:
# Reload one cached rotate-studio feature tensor for comparison.
# NOTE(review): this loads as [1, 1024], while the Unsplash features are
# [1, 1024, 7, 7] — the two caches were not produced the same way; confirm before use.
studio_feature_path = "datasets/scale-1_1/rotate_studio_convnext/-0.00278552/last_hidden_states.pt"
b = torch.load(studio_feature_path)

In [14]:
# Display the cached rotate-studio feature shape.
b.size()

torch.Size([1, 1024])