# In [None]:
from transformers import DeiTForImageClassification, DeiTImageProcessor
import torch
from thop import profile
from datasets import load_dataset
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)


# Only the validation split is used in this script, so request just that split
# instead of also materializing (and discarding) the training split.
test_ds = load_dataset("imagenet-1k", split="validation")
key_to_get_image = 'image'  # dataset column holding the PIL image

processor = DeiTImageProcessor.from_pretrained("/home/dhruv/pruning_vit/pruning/trained-models/deit-small-baseline-imagenet-1k-20240110-014420/best-baseline-loaded-at-end")

# Reuse the checkpoint's normalization statistics so inputs match training.
image_mean, image_std = processor.image_mean, processor.image_std
# NOTE(review): both Resize and CenterCrop below use crop_size. With an int size,
# Resize scales the shorter edge to `size`, and CenterCrop then yields a
# size x size square. The processor's own pipeline may resize to processor.size
# first; confirm if accuracy (not only MAC counting) matters.
size = processor.crop_size["height"]
normalize = Normalize(mean=image_mean, std=image_std)

# Deterministic evaluation-time preprocessing: no random augmentation.
_val_transforms = Compose(
    [
        Resize(size),
        CenterCrop(size),
        ToTensor(),
        normalize,
    ]
)

def val_transforms(examples):
    """Batch transform used with ``test_ds.set_transform``.

    Each image in ``examples[key_to_get_image]`` is converted to RGB and passed
    through ``_val_transforms``. The resulting tensors are stored as a list under
    ``examples['pixel_values']``. The same dict is modified in place and returned.
    """
    def _prepare(img):
        # Force 3-channel RGB so every image has the same channel layout
        # before tensor conversion and normalization.
        return _val_transforms(img.convert("RGB"))

    examples['pixel_values'] = [_prepare(img) for img in examples[key_to_get_image]]
    return examples

# Register val_transforms as an on-access transform: it is applied when rows are
# indexed, rather than by rewriting the dataset up front.
test_ds.set_transform(transform=val_transforms)


def collate_fn(examples):
    """Merge preprocessed samples into a single batch dict for the model.

    Args:
        examples: sequence of per-sample dicts. Each holds a ``"pixel_values"``
            tensor (all must share one shape for ``torch.stack``) and a
            ``"label"`` value (presumably an int class index).

    Returns:
        Dict with ``"pixel_values"`` (samples stacked along a new leading batch
        dimension) and ``"labels"`` (tensor built from the labels).
    """
    stacked_pixels = torch.stack([sample["pixel_values"] for sample in examples])
    label_tensor = torch.tensor([sample["label"] for sample in examples])
    batch = {"pixel_values": stacked_pixels, "labels": label_tensor}
    return batch


# Build a one-sample batch from the first validation row to drive the profiler.
inputs = collate_fn([test_ds[0]])
print(inputs['pixel_values'].shape)  # presumably (1, 3, size, size)

from inference_patchers_blockprune import optimize_model_deit

pruned_checkpoint = "trained-models/imagenet-1k_deit_drop_tokens_variant_i_small/epochs_11_blockPruningInfo_finalThreshold_0.5_blockSize_32_method_topK_tokenDropInfo_layerCount_3_layerType_default_keepRate_0.9_fused_True_20231230-020027/fine-pruned-MASKED"
# NOTE(review): what mode="dense" does is defined in inference_patchers_blockprune,
# which is not visible here.
model = optimize_model_deit(
    DeiTForImageClassification.from_pretrained(pretrained_model_name_or_path=pruned_checkpoint),
    mode="dense",
)

# thop profiles the model on the sample batch and returns (MACs, parameter count).
macs, params = profile(model, inputs=(inputs['pixel_values'],))

# Report GMACs and millions of parameters.
print(macs/10**9, params/10**6)