In [73]:
import importlib
import os
import warnings

import pandas as pd

# Blanket filter: hides every warning (deprecation notices included) for the rest of the session.
warnings.filterwarnings('ignore')

# Project-local modules; their relative import order is kept unchanged.
import utils
import datasets as cstm_dataset
import config
import trainvit
import transferlearningvit

# Re-execute the project modules so edits to their .py files are picked up
# without restarting the kernel (order preserved: config first).
for _module in (config, utils, cstm_dataset, trainvit, transferlearningvit):
    importlib.reload(_module)
del _module

# Fetch and unpack the Stanford Dogs archive into ./stanford-dog-dataset.
# Per the recorded output, the helper prints a notice when that folder already exists.
utils.download_and_extract_tar(
    url=config.STANFORD_DOG_DATASET_URL,
    extract_to='.',
    rename_folder_to='stanford-dog-dataset',
)

The folder ./stanford-dog-dataset already exists.


In [74]:
def count_subfolders(root_dir):
    """Return how many immediate subdirectories ``root_dir`` contains (non-recursive).

    Symlinks that point at directories are counted, as ``os.walk``'s
    ``dirnames`` would; regular files are ignored.

    Args:
        root_dir: Path (str or os.PathLike) of the directory to inspect.

    Returns:
        int: number of directory entries directly under ``root_dir``.

    Raises:
        OSError: e.g. FileNotFoundError if ``root_dir`` does not exist, or
            NotADirectoryError if it is a file.
    """
    # os.scandir raises on a bad path. The old os.walk-based version silently
    # returned 0 in that case, which then became a 0-class model downstream.
    with os.scandir(root_dir) as entries:
        return sum(1 for entry in entries if entry.is_dir())

# Dataset configuration: each folder directly under the dataset root is treated as one class.
root_directory = config.STANFORD_DOG_DATASET_LOCAL_PATH
STANFORD_DOG_DATASET_NUM_CLASSES = count_subfolders(root_directory)
DATASET = 'stanford_dog'
print(f"Number of subfolders in '{root_directory}': {STANFORD_DOG_DATASET_NUM_CLASSES}")
# Fail fast: a wrong or empty path would otherwise build a classification head with 0 labels.
assert STANFORD_DOG_DATASET_NUM_CLASSES > 0, f"No class folders found under {root_directory!r}"

Number of subfolders in './stanford-dog-dataset/': 120


In [78]:
# Short model name (the key passed to do_transfer_learning) -> Hugging Face Hub checkpoint id.
# NOTE(review): "vit_tiny_patch16_224" / "vit_small_patch16_224" are intentionally absent;
# "google/vit-tiny-patch16-224" and "google/vit-small-patch16-224" are believed not to be
# published on the Hub (Google's ViT release starts at base) — confirm before adding them.
model_mapping = {
    # Facebook DeiT
    "deit_tiny_patch16_224": "facebook/deit-tiny-patch16-224",
    "deit_small_patch16_224": "facebook/deit-small-patch16-224",
    "deit_base_patch16_224": "facebook/deit-base-patch16-224",
    # Microsoft Swin Transformer
    "swin_tiny_patch4_window7_224": "microsoft/swin-tiny-patch4-window7-224",
    "swin_small_patch4_window7_224": "microsoft/swin-small-patch4-window7-224",
    "swin_base_patch4_window7_224": "microsoft/swin-base-patch4-window7-224",
    # Google ViT
    "vit_base_patch16_224": "google/vit-base-patch16-224",
}

In [76]:
def do_transfer_learning(model_name, num_epochs_ = 5, learning_rate_ = 4e-5, dataset_ = DATASET, num_labels_ = STANFORD_DOG_DATASET_NUM_CLASSES,
                         batch_size_ = 256, train_pct_ = 0.8, val_pct_ = 0.10, resize_img = 224):
    """Build data loaders for one pretrained checkpoint and run transfer learning on it.

    Looks up the Hub checkpoint for ``model_name`` in ``model_mapping``, builds
    train/val/test loaders with ``cstm_dataset.create_dataset`` for that checkpoint,
    then passes all three loaders to
    ``transferlearningvit.transfer_learning_pretrain_vit``.

    Args:
        model_name: Key of ``model_mapping`` (e.g. "vit_base_patch16_224").
        num_epochs_: Number of training epochs forwarded to the training routine.
        learning_rate_: Learning rate forwarded to the training routine.
        dataset_: Dataset tag forwarded to the training routine.
        num_labels_: Number of output classes forwarded as ``num_labels_``.
        batch_size_: Batch size passed to ``create_dataset``.
        train_pct_: Training fraction passed to ``create_dataset``.
        val_pct_: Validation fraction passed to ``create_dataset`` (the remainder
            presumably becomes the test split — confirm in datasets.create_dataset).
        resize_img: Image size passed to ``create_dataset``.

    Returns:
        None.

    Raises:
        KeyError: if ``model_name`` is not a key of ``model_mapping``.

    Note:
        The ``dataset_`` and ``num_labels_`` defaults are captured when this cell
        is executed; re-run this cell after changing DATASET or the class count.
    """
    try:
        model_checkpoint = model_mapping[model_name]
    except KeyError:
        # Re-raise with the valid choices so a typo in a run cell is easy to spot.
        raise KeyError(
            f"Unknown model_name {model_name!r}; expected one of {sorted(model_mapping)}"
        ) from None

    train_loader, val_loader, test_loader = cstm_dataset.create_dataset(model_checkpoint=model_checkpoint, batch_size_=batch_size_,
                                                                        train_pct=train_pct_, val_pct=val_pct_, resize_img=resize_img)
    transferlearningvit.transfer_learning_pretrain_vit(model_checkpoint, train_loader, val_loader, test_loader, dataset = dataset_,
                                                       num_labels_ = num_labels_, learning_rate = learning_rate_, num_epochs = num_epochs_)

## Google's ViT

In [None]:
# Fine-tune google/vit-base-patch16-224 with do_transfer_learning's default hyper-parameters.
model_name = "vit_base_patch16_224"
do_transfer_learning(model_name)


## Facebook DeiT (Data-efficient Image Transformers)

In [None]:
# Fine-tune facebook/deit-base-patch16-224 with do_transfer_learning's default hyper-parameters.
model_name = "deit_base_patch16_224"
do_transfer_learning(model_name)

In [None]:
# Fine-tune facebook/deit-tiny-patch16-224 with do_transfer_learning's default hyper-parameters.
model_name = "deit_tiny_patch16_224"
do_transfer_learning(model_name)


In [None]:
# Fine-tune facebook/deit-small-patch16-224 with do_transfer_learning's default hyper-parameters.
model_name = "deit_small_patch16_224"
do_transfer_learning(model_name)


## Microsoft Swin Transformer

In [None]:
# Fine-tune microsoft/swin-base-patch4-window7-224 with do_transfer_learning's default hyper-parameters.
model_name = "swin_base_patch4_window7_224"
do_transfer_learning(model_name)


In [None]:
# Fine-tune microsoft/swin-small-patch4-window7-224 with do_transfer_learning's default hyper-parameters.
model_name = "swin_small_patch4_window7_224"
do_transfer_learning(model_name)


In [None]:
# Fine-tune microsoft/swin-tiny-patch4-window7-224 with do_transfer_learning's default hyper-parameters.
model_name = "swin_tiny_patch4_window7_224"
do_transfer_learning(model_name)
