# Petfinder Pawpularity Vision Transformer
This notebook is based on the following notebook:<br/>
https://www.kaggle.com/szuzhangzhi/vision-transformer-vit-cuda-as-usual

In [None]:
!pip install vision_transformer_pytorch

In [None]:
import sys
# Put the VisionTransformer-Pytorch sources from the attached ../input dataset
# on the import path (presumably a fallback for when the pip install above has
# no network access — confirm).
package_path = '../input/vision-transformer-pytorch/VisionTransformer-Pytorch'
sys.path.append(package_path)

In [None]:
import os
import pandas as pd
import time
import datetime
import copy
import matplotlib.pyplot as plt
import json
import seaborn as sns
import cv2
import albumentations as albu
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold, train_test_split


# ALBUMENTATIONS
import albumentations as albu

from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, 
    ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, 
    OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, 
    MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, 
    Flip, OneOf, Compose, Normalize, Cutout, 
    CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)
    
from albumentations.pytorch import ToTensorV2

# ADAMP
# from adamp import AdamP

In [None]:
# Competition data root and the folder that holds the training images.
BASE_DIR="../input/petfinder-pawpularity-score"
TRAIN_IMAGES_DIR = os.path.join(BASE_DIR,'train')

In [None]:
# Load the training metadata (the `Id` and `Pawpularity` columns are used
# below) and preview the first rows.
train_df = pd.read_csv(os.path.join(BASE_DIR, 'train.csv'))
train_df.head()

## Target: Pawpularity

In [None]:
# Distinct Pawpularity scores and how often each occurs, most frequent first.
_paw_counts = train_df['Pawpularity'].value_counts()  #target name
class_name = _paw_counts.index
class_count = _paw_counts.values

In [None]:
# Display the frequency of each Pawpularity score as the cell output.
train_df.Pawpularity.value_counts()    #target name

In [None]:
# Counting target values.
# Bar chart of how many images have each Pawpularity score.
targ_cts = train_df.Pawpularity.value_counts()    #target name
sorted_cts = targ_cts.sort_values(ascending=False)  # sort once, reuse for x and y

fig = plt.figure(figsize=(15, 6))
sns.barplot(x=sorted_cts.index,
            y=sorted_cts.values,
            palette='summer')
plt.title('Target Distribution')
plt.show()

# NOTE: this cell previously label-encoded `train_df.pollen_carrying`. That
# column does not exist in this competition's train.csv, so the cell raised
# AttributeError. Pawpularity is already an integer, so no label encoding is
# needed here.

In [None]:
def visualize_images(image_ids, labels):
    """Show training images in a 3x3 grid, each titled with its label.

    Args:
        image_ids: file names inside ``TRAIN_IMAGES_DIR`` (joined as-is, so
            they must carry their extension). The grid has nine cells, so at
            most nine images fit.
        labels: Pawpularity values, paired positionally with ``image_ids``.
    """
    plt.figure(figsize=(16, 12))

    for slot, pair in enumerate(zip(image_ids, labels), start=1):
        image_id, label = pair
        plt.subplot(3, 3, slot)

        # OpenCV decodes as BGR; matplotlib expects RGB.
        bgr = cv2.imread(os.path.join(TRAIN_IMAGES_DIR, image_id))
        plt.imshow(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        plt.title(f"Pawpularity: {label}", fontsize=12)    #target name
        plt.axis("off")

    plt.show()
    

def plot_augmentation(image_id, transform):
    """Show one training image next to two independent draws of ``transform``.

    Args:
        image_id: file name inside ``TRAIN_IMAGES_DIR`` (joined as-is, so it
            must carry its extension).
        transform: albumentations-style callable used as
            ``transform(image=img)['image']``.
    """
    plt.figure(figsize=(16, 4))

    # OpenCV decodes as BGR; matplotlib expects RGB.
    img = cv2.imread(os.path.join(TRAIN_IMAGES_DIR, image_id))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Panel 1 is the original; panels 2 and 3 each apply the transform afresh,
    # so random augmentations produce two different results.
    for position in (1, 2, 3):
        plt.subplot(1, 3, position)
        plt.imshow(img if position == 1 else transform(image=img)['image'])
        plt.axis('off')

    plt.show()
    
    
def visualize(images, transform):
    """Plot images (top row) and one transformed copy of each (bottom row).

    Uses a 2x5 grid, so at most five images fit. ``transform`` is called as
    ``transform(image=im)['image']`` once per image.
    """
    fig = plt.figure(figsize=(32, 16))

    # Top row: the inputs as given.
    for slot, original in enumerate(images, start=1):
        fig.add_subplot(2, 5, slot, xticks=[], yticks=[])
        plt.imshow(original)

    # Bottom row: the transformed versions, aligned under their originals.
    for slot, original in enumerate(images, start=6):
        fig.add_subplot(2, 5, slot, xticks=[], yticks=[])
        plt.imshow(transform(image=original)['image'])

In [None]:
# CUSTOM DATASET CLASS
class PlantDataset(Dataset):
    """Image dataset backed by a DataFrame of image ids (and optionally labels).

    Despite the name, this serves the Petfinder images: sample ``i`` is read
    from ``<imfolder>/<df.iloc[i]['Id']>.jpg``.

    Args:
        df: DataFrame with an ``Id`` column and, when ``train`` is True, a
            ``Pawpularity`` column used as the label.
        imfolder: directory containing the ``.jpg`` files.
        train: if True, ``__getitem__`` returns ``(image, label)``;
            otherwise it returns only the image.
        transforms: optional albumentations-style callable, invoked as
            ``transforms(image=...)`` and returning a dict with key ``'image'``.
    """

    def __init__(
        self, df: pd.DataFrame, imfolder: str, train: bool = True, transforms=None
    ):
        self.df = df
        self.imfolder = imfolder
        self.train = train
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        im_path = os.path.join(self.imfolder, row['Id'] + '.jpg')

        # cv2.imread signals failure by returning None rather than raising;
        # fail here with the path instead of an opaque cvtColor error.
        im = cv2.imread(im_path, cv2.IMREAD_COLOR)
        if im is None:
            raise FileNotFoundError(f"Could not read image: {im_path}")
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

        if self.transforms:
            # Albumentations Compose returns a dict; the image is under 'image'.
            im = self.transforms(image=im)['image']

        if self.train:
            label = row['Pawpularity']    #target name
            return im, label
        return im

In [None]:
# AUGMENTATIONS
# Input size and normalization statistics shared by both pipelines, so the
# train and valid transforms cannot drift apart.
IMAGE_SIZE = 384
NORM_MEAN = [0.3, 0.3, 0.3]
NORM_STD = [0.1, 0.1, 0.1]

# Training: random crop, flips, colour and geometry jitter, normalization,
# random region dropout, then HWC numpy array -> CHW torch tensor.
train_augs = albu.Compose([
    albu.RandomResizedCrop(height=IMAGE_SIZE, width=IMAGE_SIZE, p=1.0),
    albu.HorizontalFlip(p=0.5),
    albu.VerticalFlip(p=0.5),
    albu.RandomBrightnessContrast(p=0.5),
    albu.ShiftScaleRotate(p=0.5),
    albu.Normalize(mean=NORM_MEAN, std=NORM_STD),
    albu.CoarseDropout(p=0.5),
    albu.Cutout(p=0.5),
    ToTensorV2(),
])

# Validation: deterministic resize plus the same normalization; no augmentation.
valid_augs = albu.Compose([
    albu.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE, p=1.0),
    albu.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ToTensorV2(),
])


In [None]:
# DATA SPLIT
# Hold out 10% for validation, stratified on the target so both parts keep
# the same score distribution. Each part gets a fresh 0..n-1 index.
train, valid = (
    part.reset_index(drop=True)
    for part in train_test_split(
        train_df,
        test_size=0.1,
        random_state=42,
        stratify=train_df['Pawpularity'].values,    #target name
    )
)

# Target arrays for each split.
train_targets, valid_targets = (
    frame['Pawpularity'].values for frame in (train, valid)
)

In [None]:
# DEFINE PYTORCH CUSTOM DATASET
# Both splits read from the training image folder and both return labels
# (train=True); only the augmentation pipeline differs.
train_dataset, valid_dataset = (
    PlantDataset(df=split_df, imfolder=TRAIN_IMAGES_DIR, train=True, transforms=augs)
    for split_df, augs in ((train, train_augs), (valid, valid_augs))
)

In [None]:
def plot_image(img_dict):
    """Print a sample's target and display its image.

    Despite the name, ``img_dict`` is indexed positionally: ``[0]`` is an
    image tensor (assumed channels-first, as produced by ``ToTensorV2``) and
    ``[1]`` is its target.
    """
    tensor, target = img_dict[0], img_dict[1]
    print(target)
    # matplotlib wants height x width x channels.
    plt.imshow(tensor.permute(1, 2, 0))

In [None]:
# Preview training sample 7 after train_augs (augmentation is random per call).
plot_image(train_dataset[7])

In [None]:
# Preview training sample 10 after train_augs.
plot_image(train_dataset[10])

In [None]:
# Preview training sample 13 after train_augs.
plot_image(train_dataset[13])

In [None]:
# MAKE PYTORCH DATALOADER
# Both loaders share batch size and worker count; only training shuffles.
_loader_opts = dict(batch_size=4, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_opts)
valid_loader = DataLoader(valid_dataset, shuffle=False, **_loader_opts)

In [None]:
# TRAIN
def _run_phase(model, loader, n_samples, criterion, optimizer, device, train):
    """Run one pass of ``model`` over ``loader`` and return ``(mean_loss, accuracy)``.

    Both statistics are averaged over ``n_samples`` (and are 0.0 when
    ``n_samples`` is 0). When ``train`` is True the optimizer is stepped after
    every batch; otherwise no autograd history is recorded.
    """
    # model.train(False) is the same as model.eval().
    model.train(train)

    running_loss = 0.0
    running_corrects = 0

    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        if train:
            optimizer.zero_grad()

        # Track history only while training.
        with torch.set_grad_enabled(train):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            if train:
                loss.backward()
                optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        # .item() keeps this a plain int, so the result is valid even when
        # the loader yields no batches.
        running_corrects += (preds == labels).sum().item()

    if not n_samples:
        return 0.0, 0.0
    return running_loss / n_samples, running_corrects / n_samples


def train_model(datasets, dataloaders, model, criterion, optimizer, scheduler, num_epochs, device):
    """Train ``model`` and return it loaded with its best validation-accuracy weights.

    Each epoch runs a ``'train'`` phase (forward, backward, optimizer step per
    batch, then one ``scheduler.step()``) followed by a ``'valid'`` phase
    without gradient tracking. Loss and accuracy (argmax of the outputs versus
    ``labels``) are printed per phase.

    Args:
        datasets: dict with ``'train'`` and ``'valid'`` datasets; only ``len()``
            is used, to average the epoch statistics.
        dataloaders: dict with ``'train'`` and ``'valid'`` iterables yielding
            ``(inputs, labels)`` batches.
        model: module returning per-class scores along dim 1. It is moved to
            ``device`` and updated in place.
        criterion: loss callable, ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per epoch after the train phase.
        num_epochs: number of epochs to run.
        device: ``torch.device`` to run on.

    Returns:
        The same ``model`` object with the state dict from the epoch that had
        the highest validation accuracy loaded. If ``num_epochs`` is 0, the
        initial weights are kept.
    """
    since = time.time()

    # Move once up front rather than on every batch.
    model = model.to(device)

    best_model_wts = copy.deepcopy(model.state_dict())
    # Start below any achievable accuracy so the first epoch is always kept.
    # Starting at 0.0 with a strict ">" meant that a 0-accuracy run discarded
    # all training and returned the initial weights.
    best_acc = float('-inf')

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'valid']:
            epoch_loss, epoch_acc = _run_phase(
                model, dataloaders[phase], len(datasets[phase]),
                criterion, optimizer, device, train=(phase == 'train'),
            )

            if phase == 'train':
                scheduler.step()

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    model.load_state_dict(best_model_wts)

    return model

In [None]:
from vision_transformer_pytorch import VisionTransformer

# Prefer the GPU when one is visible.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Phase-keyed lookups consumed by train_model.
datasets = dict(train=train_dataset, valid=valid_dataset)
dataloaders = dict(train=train_loader, valid=valid_loader)

# LOAD PRETRAINED ViT MODEL
# ViT-B/16 with a 101-way head: Pawpularity is used directly as a class index
# (presumably covering scores 0-100 — confirm against the target range).
model = VisionTransformer.from_pretrained('ViT-B_16', num_classes=101)

# OPTIMIZER: AdamW (plain Adam and AdamP, see the commented import above,
# were the alternatives considered).
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.001)

# LEARNING RATE SCHEDULER: multiply the LR by 0.1 every 2 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

criterion = nn.CrossEntropyLoss()
num_epochs = 8

In [None]:
# MODEL TRAIN
# train_model returns the same module object it was given (moved to `device`),
# with the best-validation-accuracy weights loaded into it in place.
trained_model = train_model(datasets, dataloaders, 
                            model, criterion, optimizer, 
                            scheduler, num_epochs, device)

In [None]:
# Save the model after training.
# `model` is the object train_model returned, so it holds the best-validation
# weights that were loaded into it.
# NOTE(review): the file name says "2epoch" but num_epochs is 8 — confirm the
# intended name before downstream inference code relies on it.
torch.save(model.state_dict(), 'vit_b-16_2epoch.pt')