In [None]:
# WARNING: recursively deletes everything in the current working directory.
# Only run this in a throwaway runtime where the working directory holds nothing you need.
! rm -rf *

In [None]:
# Upload kaggle.json here: the Kaggle setup cell below expects that file
# in the working directory and moves it into ~/.kaggle/.
from google.colab import files
files.upload()

In [None]:
# Install the Kaggle CLI, move the uploaded kaggle.json into ~/.kaggle/,
# and restrict the token file to owner read/write (mode 600).
! pip install -q kaggle
! mkdir -p ~/.kaggle
! mv kaggle.json ~/.kaggle/ 
! chmod 600 ~/.kaggle/kaggle.json

In [None]:
# Download the kwentar/blur-dataset archive and unpack it quietly into the
# working directory (presumably this provides the `sharp/` folder read by the
# augmentation cell below; verify against the archive contents).
!kaggle datasets download -d kwentar/blur-dataset
!unzip -qq blur-dataset.zip -d ./

<h2> Before moving to the images folder, augment! </h2>

In [None]:
import os
import cv2

sharp_folder = "sharp"
rotated_folder = "rotated_sharp"

# Sub-directory name -> OpenCV rotation code. ``None`` means "write an
# unrotated copy". Every key becomes one sub-directory of `rotated_folder`.
rotations = {
    "No_rotation": None,
    "90_degrees_clockwise": cv2.ROTATE_90_CLOCKWISE,
    "90_degrees_counterclockwise": cv2.ROTATE_90_COUNTERCLOCKWISE,
    "180_degrees": cv2.ROTATE_180,
}

# Make sure every output sub-directory exists before writing into it.
for rotation_name in rotations:
    os.makedirs(os.path.join(rotated_folder, rotation_name), exist_ok=True)

# Write one variant of every readable source image into each sub-directory,
# keeping the original file name.
for filename in os.listdir(sharp_folder):
    source_image = cv2.imread(os.path.join(sharp_folder, filename))
    if source_image is None:
        # Entry could not be read as an image; skip it.
        continue

    for rotation_name, rotate_code in rotations.items():
        if rotate_code is None:
            variant = source_image.copy()
        else:
            variant = cv2.rotate(source_image, rotate_code)
        cv2.imwrite(os.path.join(rotated_folder, rotation_name, filename), variant)


In [None]:
import shutil
# Gather the per-rotation folders under ``images/`` so that each rotation
# becomes one class directory for ImageFolder.
#
# Paths are relative to the working directory, matching the augmentation cell
# (which writes to "rotated_sharp") and `data_dir = Path('images')` below.
src = 'rotated_sharp'
dst = 'images'

# exist_ok keeps the cell re-runnable (plain os.mkdir raised on a second run).
os.makedirs(dst, exist_ok=True)

for sub_dir in os.listdir(src):
    sub_dir_path = os.path.join(src, sub_dir)
    if not os.path.isdir(sub_dir_path):
        continue
    target_path = os.path.join(dst, sub_dir)
    if os.path.isdir(target_path):
        # A stale copy from an earlier run would make shutil.move fail;
        # replace it with the freshly generated folder.
        shutil.rmtree(target_path)
    shutil.move(sub_dir_path, dst)


In [None]:
%%capture

# Install the training libraries (versions are not pinned), install git-lfs,
# and configure git to use the "store" credential helper globally.
! pip install transformers pytorch-lightning --quiet
! sudo apt -qq install git-lfs
! git config --global credential.helper store

In [None]:
import requests
import math
import matplotlib.pyplot as plt
import shutil
from getpass import getpass
from PIL import Image, UnidentifiedImageError
from requests.exceptions import HTTPError
from io import BytesIO
from pathlib import Path
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchvision.datasets import ImageFolder
from transformers import ViTFeatureExtractor, ViTForImageClassification

## Init Dataset and Split into Training and Validation Sets


In [None]:
# Root folder passed to ImageFolder below; path is relative to the working directory.
data_dir = Path('images')

In [None]:
ds = ImageFolder(data_dir)

# Shuffle with a dedicated, seeded generator so the train/val membership is
# the same on every Restart & Run All (the global seed is only set later).
split_generator = torch.Generator().manual_seed(42)
indices = torch.randperm(len(ds), generator=split_generator).tolist()

# Hold out 15% for validation. Split on an explicit boundary: with negative
# slicing, n_val == 0 turned `indices[:-0]` into an empty training set and
# `indices[-0:]` into a validation set containing everything.
n_val = math.floor(len(indices) * .15)
n_train = len(indices) - n_val
train_ds = torch.utils.data.Subset(ds, indices[:n_train])
val_ds = torch.utils.data.Subset(ds, indices[n_train:])

In [None]:
# Preview grid: one row per class, up to `num_examples_per_class` images each.
plt.figure(figsize=(20,10))
num_examples_per_class = 5
for class_idx, class_name in enumerate(ds.classes):
    folder = Path(ds.root) / class_name
    # Count only images actually drawn, so skipped (non-image) files can
    # neither shift the grid nor prevent the per-class limit from triggering.
    shown = 0
    for image_path in sorted(folder.glob('*')):
        # Compare the lowercased suffix so files like "photo.JPG" are kept.
        if image_path.suffix.lower() not in ds.extensions:
            continue

        # Slot is derived from (row, column) instead of a running counter, so a
        # class with fewer images leaves blanks rather than misaligning rows.
        ax = plt.subplot(
            len(ds.classes),
            num_examples_per_class,
            class_idx * num_examples_per_class + shown + 1,
        )
        ax.set_title(
            class_name,
            size='xx-large',
            pad=5,
            loc='left',
            y=0,
            backgroundcolor='white'
        )
        ax.axis('off')
        # Close the file handle as soon as the pixels have been handed to matplotlib.
        with Image.open(image_path) as image:
            ax.imshow(image)

        shown += 1
        if shown == num_examples_per_class:
            break

## Preparing Labels for Our Model's Config

By adding `label2id` + `id2label` to our model's config, we'll get friendlier labels in the inference API.

In [None]:
# Build both lookup directions from the class list. The index is stored as a
# string on both sides: class name -> "0", "1", ... and "0", "1", ... -> class name.
label2id = {name: str(idx) for idx, name in enumerate(ds.classes)}
id2label = {str(idx): name for idx, name in enumerate(ds.classes)}

## Image Classification Collator

To apply our transforms to images, we'll use a custom collator class. We'll initialize it using an instance of `ViTFeatureExtractor` and pass the collator instance to `torch.utils.data.DataLoader`'s `collate_fn` kwarg.

In [None]:
class ImageClassificationCollator:
    """Collate ``(image, label)`` samples into a single batch.

    The images are passed as a list to ``feature_extractor`` (called with
    ``return_tensors='pt'``); the labels are added to its result under the
    ``'labels'`` key as a ``torch.long`` tensor.
    """

    def __init__(self, feature_extractor):
        # Callable applied to the list of images of every batch.
        self.feature_extractor = feature_extractor

    def __call__(self, batch):
        """Return the feature extractor's output for ``batch`` plus a ``'labels'`` entry."""
        images = [sample[0] for sample in batch]
        targets = [sample[1] for sample in batch]
        features = self.feature_extractor(images, return_tensors='pt')
        features['labels'] = torch.tensor(targets, dtype=torch.long)
        return features

## Init Feature Extractor, Model, Data Loaders


In [None]:
# The same pretrained checkpoint supplies the feature extractor and the model.
pretrained_checkpoint = 'google/vit-base-patch16-224-in21k'

feature_extractor = ViTFeatureExtractor.from_pretrained(pretrained_checkpoint)
model = ViTForImageClassification.from_pretrained(
    pretrained_checkpoint,
    num_labels=len(label2id),
    label2id=label2id,
    id2label=id2label,
)
collator = ImageClassificationCollator(feature_extractor)

# Both loaders share these settings; only the training loader shuffles.
loader_kwargs = dict(batch_size=8, collate_fn=collator, num_workers=2)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, **loader_kwargs)

# Training

⚡ We'll use [PyTorch Lightning](https://pytorchlightning.ai/) to fine-tune our model.

In [None]:
class Classifier(pl.LightningModule):
    """LightningModule that fine-tunes a wrapped classification model.

    ``forward`` is bound directly to the wrapped model's ``forward``, so
    calling the module with a batch's keyword arguments returns the model's
    own output object; the steps below read ``.loss`` and ``.logits`` from it.

    Args:
        model: Model to train; must expose ``config.num_labels``.
        lr: Learning rate handed to Adam.
        **kwargs: Extra names forwarded to ``save_hyperparameters``.
    """

    def __init__(self, model, lr: float = 2e-5, **kwargs):
        super().__init__()
        self.save_hyperparameters('lr', *list(kwargs))
        self.model = model
        self.forward = self.model.forward

        # More than two labels -> multiclass accuracy, otherwise binary.
        n_labels = model.config.num_labels
        metric_task = 'multiclass' if n_labels > 2 else 'binary'
        self.val_acc = Accuracy(task=metric_task, num_classes=n_labels)

    def training_step(self, batch, batch_idx):
        """Log and return the model's loss for one training batch."""
        loss = self(**batch).loss
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log loss and accuracy for one validation batch; return the loss."""
        result = self(**batch)
        predictions = result.logits.argmax(1)
        batch_acc = self.val_acc(predictions, batch['labels'])
        self.log("val_loss", result.loss)
        self.log("val_acc", batch_acc, prog_bar=True)
        return result.loss

    def configure_optimizers(self):
        """Optimise all parameters with Adam at the stored learning rate."""
        learning_rate = self.hparams.lr
        return torch.optim.Adam(self.parameters(), lr=learning_rate)

In [None]:
# Fix the global seed (42) before building the Lightning module and trainer.
pl.seed_everything(42)
classifier = Classifier(model, lr=2e-5)
# Requires a GPU runtime: a single 'gpu' device, 32-bit precision, at most 12 epochs.
trainer = pl.Trainer(accelerator='gpu', devices=1, precision=32, max_epochs=12)
trainer.fit(classifier, train_loader, val_loader)

In [None]:
# Sanity check: compare predictions with labels on one validation batch.
val_batch = next(iter(val_loader))

# Inference only: use eval mode and skip autograd so no graph is built.
model.eval()
with torch.no_grad():
    # Move the batch to wherever the model currently lives to avoid a
    # device-mismatch error.
    outputs = model(**{name: tensor.to(model.device) for name, tensor in val_batch.items()})

# argmax over the logits equals argmax over softmax(logits), so no softmax is needed.
print('Preds: ', outputs.logits.argmax(1).cpu())
print('Labels:', val_batch['labels'])

In [None]:
# Save the fine-tuned model together with the feature extractor's
# preprocessing config, so the exported directory is self-contained and can
# be reloaded without this notebook's in-memory state.
model_dir = "./"
model_name = "vit_model"
save_path = model_dir + model_name
model.save_pretrained(save_path)
feature_extractor.save_pretrained(save_path)


In [None]:
from PIL import Image

# Load the saved model and put it in inference mode
model_dir = "./"
model_name = "vit_model"
loaded_model = ViTForImageClassification.from_pretrained(model_dir + model_name)
loaded_model.eval()

# Load and preprocess the image. Convert to RGB so grayscale / RGBA / palette
# files still yield the three channels the model expects, and close the file
# handle once the pixels are decoded.
image_path = "/content/90anticlock.jpg"
with Image.open(image_path) as raw_image:
    image = raw_image.convert("RGB")
inputs = feature_extractor(image, return_tensors="pt")

# Make the prediction (no gradients needed for inference)
with torch.no_grad():
    outputs = loaded_model(**inputs)

# Read the label from the reloaded model's own config (integer keys after
# loading) rather than from a notebook global, so the result reflects what
# was actually saved.
predicted_class = loaded_model.config.id2label[outputs.logits.argmax(1).item()]

print(f"Predicted class: {predicted_class}")


In [None]:
 # Archive the saved model directory into model.zip (absolute /content path is hardcoded).
 !zip -r model.zip /content/vit_model

In [None]:
# Report the total on-disk size of the saved model directory.
! du -sh /content/vit_model