In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import os
import transformers
from accelerate import Accelerator

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
accelerator = Accelerator()
device = accelerator.device

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import ViTImageProcessor, ViTForImageClassification, ViTConfig
from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

pretrained_name = 'google/vit-base-patch16-224'
config = ViTConfig.from_pretrained(pretrained_name)
processor = ViTImageProcessor.from_pretrained(pretrained_name)
pred_model = ViTForImageClassification.from_pretrained(pretrained_name)
pred_model.to(device)

inputs = processor(images=image, return_tensors="pt")
inputs.to(device)
outputs = pred_model(**inputs, output_hidden_states=True)
logits = outputs.logits
# model predicts one of the 1000 ImageNet classes
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", pred_model.config.id2label[predicted_class_idx])




Predicted class: Egyptian cat


In [36]:
from torchvision import models
import numpy as np

model = models.resnet18(pretrained=True)
model.to(device)
mean = torch.tensor([0.485, 0.456, 0.406], device=device).reshape(1,-1,1,1)
std = torch.tensor([0.229, 0.224, 0.225], device=device).reshape(1,-1,1,1)



In [37]:
input_tensor = inputs['pixel_values'] * std + mean
logits = model(input_tensor)
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", pred_model.config.id2label[predicted_class_idx])

Predicted class: bucket, pail


In [35]:
predicted_class_idx

600

In [28]:
inputs['pixel_values'].shape

torch.Size([1, 3, 224, 224])