# Deep Learning for Business Applications course

## TOPIC 3: Computer Vision advanced. SotA model example

### 1. Libraries

In [None]:
!pip install transformers

In [None]:
import requests
import numpy as np
import matplotlib.pyplot as plt
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image

### 2. Looking ahead - ViT

Let's look at the [Vision Transformer (ViT)](https://huggingface.co/google/vit-base-patch16-224), a transformer encoder model (BERT-like) pretrained in a supervised fashion on ImageNet-21k, a large collection of images, at a resolution of 224x224 pixels. The model was then fine-tuned on ImageNet (also referred to as ILSVRC2012), a dataset comprising 1 million images and 1,000 classes, also at a resolution of 224x224.
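
The `patch16-224` part of the model name indicates that a 224x224 input is cut into 16x16 patches before it reaches the encoder. As a rough sketch of the resulting sequence length, assuming non-overlapping patches and one extra classification token as in the standard ViT design (the helper name `vit_sequence_length` is hypothetical and not part of `transformers`):

```python
def vit_sequence_length(image_size: int = 224, patch_size: int = 16) -> int:
    """Number of tokens a ViT encoder sees for a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size  # 224 // 16 = 14
    num_patches = patches_per_side ** 2          # 14 * 14 = 196
    return num_patches + 1                       # +1 for the classification token


print(vit_sequence_length())  # 14 * 14 + 1 = 197
```

So each image is treated much like a 197-token "sentence", which is what makes the BERT comparison apt.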

In [None]:
MODEL_NAME = 'google/vit-base-patch16-224'
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model = ViTForImageClassification.from_pretrained(MODEL_NAME)

In [None]:
def classify_and_plot_coco_images(
    idxs, processor, model, 
    rows=None, cols=None, figsize=(12, 8)
):
    """
    Classify and plot multiple images in a grid layout with titles.
    
    Parameters:
    - idxs: Relative image paths on images.cocodataset.org, without the .jpg extension
    - processor: Image processor for the model
    - model: Model used to classify images
    - rows: Number of rows in the grid (optional - will be calculated if not provided)
    - cols: Number of columns in the grid (optional - will be calculated if not provided)
    - figsize: Tuple specifying the figure size (width, height)
    """
    
    n_images = len(idxs)
    
    if rows is None and cols is None:
        cols = int(np.ceil(np.sqrt(n_images)))
        rows = int(np.ceil(n_images / cols))
    elif rows is None:
        rows = int(np.ceil(n_images / cols))
    elif cols is None:
        cols = int(np.ceil(n_images / rows))

    # Download each image and use the ViT classifier's prediction as its title
    images, titles = [], []
    for idx in idxs:
        url = f'http://images.cocodataset.org/{idx}.jpg'
        response = requests.get(url, stream=True, timeout=30)
        response.raise_for_status()  # fail early on a bad URL
        # Convert to RGB in case an image is grayscale; the model expects 3 channels
        image = Image.open(response.raw).convert('RGB')
        inputs = processor(images=image, return_tensors="pt")
        outputs = model(**inputs)
        logits = outputs.logits
        # The model predicts one of the 1,000 ImageNet classes
        predicted_class_idx = logits.argmax(-1).item()
        title = model.config.id2label[predicted_class_idx]
        images.append(image)
        titles.append(title)
        print('Predicted class:', title)
    
    # Create the figure; squeeze=False keeps axes a 2D array for any grid shape
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    
    # Flatten the axes array for easier indexing
    axes_flat = axes.flatten()
    
    # Plot each image
    for i, (img, title) in enumerate(zip(images, titles)):
        ax = axes_flat[i]
        
        # Handle different image types
        if isinstance(img, str):  # File path
            img_data = plt.imread(img)
            ax.imshow(img_data)
        elif isinstance(img, Image.Image):  # PIL Image
            ax.imshow(np.array(img))
        else:  # Assume it's a numpy array or compatible
            ax.imshow(img)
        
        ax.set_title(title, fontsize=10)
        ax.axis('off')  # Hide axes
    
    # Hide any unused subplots
    for i in range(n_images, len(axes_flat)):
        axes_flat[i].axis('off')
    
    plt.tight_layout()
    plt.show()

### 3. Result

You will see that the model can make mistakes in hard cases where a picture contains many objects, because it assigns a single label to the whole image. That is why we will move on to the `Object detection` problem.
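
One way to see why a single label is fragile on cluttered images is to look at the top few classes instead of only the argmax. The sketch below uses plain Python with made-up logits and labels (both hypothetical, not actual model output), and the helper name `top_k_probs` is an assumption for illustration:

```python
import math


def top_k_probs(logits, labels, k=3):
    """Apply softmax to logits and return the k most probable (label, prob) pairs."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Hypothetical logits for an image that contains several objects
labels = ["dining table", "pizza", "wine bottle", "plate"]
logits = [2.1, 2.0, 1.8, 0.3]

for label, prob in top_k_probs(logits, labels):
    print(f"{label}: {prob:.2f}")
```

When several classes have similar probabilities, the argmax is close to an arbitrary pick among the objects present. With the real model, a comparable view could likely be obtained from `outputs.logits.softmax(-1).topk(k)`.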

In [None]:
idxs = [
    'test-stuff2017/000000027294',
    'test-stuff2017/000000027303',
    'test-stuff2017/000000027155',
    'test-stuff2017/000000027049',
    'test-stuff2017/000000027057',
    'test-stuff2017/000000027833',
    'test-stuff2017/000000022859',
    'test-stuff2017/000000022895',
    'test-stuff2017/000000023209'
]

classify_and_plot_coco_images(
    idxs=idxs, 
    processor=processor, 
    model=model
)