# Vision Transformers (ViT) for Cell Phenotyping

In [2]:
!pip -q install transformers datasets

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.2.2[0m[39;49m -> [0m[32;49m23.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
from transformers import ViTFeatureExtractor, ViTForImageClassification
import torch
import sagemaker

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Execution role for this session; passed as `role=` to the estimators below.
role = sagemaker.get_execution_role()

In [9]:
from sagemaker.inputs import TrainingInput

# S3 data channels for the training job, both using the "FastFile" input mode.
train_input = TrainingInput(
    input_mode="FastFile",
    s3_data="s3://mlbucket-876f4491/cell-analysis-processed/train",
)
test_input = TrainingInput(
    input_mode="FastFile",
    s3_data="s3://mlbucket-876f4491/cell-analysis-processed/test",
)

# Channel name -> channel definition; this mapping is what `fit` receives.
inputs = dict(training=train_input, testing=test_input)

In [None]:
from sagemaker.pytorch import PyTorch

# Configure the SageMaker PyTorch training job that runs `timm_ViT.py`
# from the local `./source` directory.
icp_estimator = PyTorch(
    entry_point='timm_ViT.py',
    source_dir='./source',
    role=role,
    framework_version='1.11.0',
    py_version='py38',
    # sagemaker>=2 renamed `train_instance_count` / `train_instance_type`
    # to `instance_count` / `instance_type`; use the current names so the
    # SDK no longer emits rename warnings.
    instance_count=1,
    instance_type='ml.p3.16xlarge',
    volume_size=500,
    max_run=(24 * 60 * 60),  # 24 hours expressed in seconds
    base_job_name='pytorch-vit-deep-phenotyping',
    # Forwarded to the entry-point script as its hyperparameters.
    hyperparameters={'epochs': 20, 'batch_size': 32, 'lr': 1e-4},
    # Each regex captures the text between the label and the next ';' in
    # the training log, so the script must log lines in that exact format.
    metric_definitions=[
        {'Name': 'Train: Loss', "Regex": "Train Loss:(.*?);"},
        {'Name': 'Train: Accuracy', "Regex": "Train Acc:(.*?);"},
        {'Name': 'Validation: Loss', "Regex": "Valid Loss:(.*?);"},
        {'Name': 'Validation: Accuracy', "Regex": "Valid Acc:(.*?);"},
    ],
    enable_sagemaker_metrics=True,
    # Enable the SageMaker distributed data-parallel library.
    distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
    debugger_hook_config=False,
)

# Launch training with the "training" and "testing" channels from `inputs`.
icp_estimator.fit(inputs)

train_instance_count has been renamed in sagemaker>=2.
See: https://sagemaker.readthedocs.io/en/stable/v2.html for details.
train_instance_type has been renamed in sagemaker>=2.
See: https://sagemaker.readthedocs.io/en/stable/v2.html for details.
train_instance_type has been renamed in sagemaker>=2.
See: https://sagemaker.readthedocs.io/en/stable/v2.html for details.


2023-03-21 16:29:24 Starting - Starting the training job...ProfilerReport-1679416163: InProgress
......
2023-03-21 16:30:54 Starting - Preparing the instances for training......
2023-03-21 16:31:54 Downloading - Downloading input data...
2023-03-21 16:32:14 Training - Downloading the training image..................
2023-03-21 16:35:24 Training - Training image download completed. Training in progress.....[34mbash: cannot set terminal process group (-1): Inappropriate ioctl for device[0m
[34mbash: no job control in this shell[0m
[34m2023-03-21 12:35:49,921 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container.training[0m
[34m2023-03-21 12:35:49,985 sagemaker-training-toolkit INFO     No Neurons detected (normal if no neurons installed)[0m
[34m2023-03-21 12:35:49,998 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.[0m
[34m2023-03-21 12:35:50,000 sagemaker_pytorch_container.training INFO     Invoking SMDataPar

In [5]:
import timm 
import torch

In [11]:
# Names of the timm models that are listed as having pretrained weights.
avail_pretrained_models = timm.list_models(pretrained=True)

In [7]:
# Instantiate the ViT-Large/32 architecture without pretrained weights,
# adapted to 4-channel 75x75 inputs and 8 output classes, with dropout 0.5.
model = timm.create_model(
    'vit_large_patch32_224_in21k',
    pretrained=False,
    img_size=75,
    in_chans=4,
    num_classes=8,
    drop_rate=0.5,
)

In [8]:
# Display the module tree of the model to inspect its layers.
model

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(4, 1024, kernel_size=(32, 32), stride=(32, 32))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.5, inplace=False)
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1024, out_features=3072, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1024, out_features=1024, bias=True)
        (proj_drop): Dropout(p=0.5, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (act): GELU(approximate=none)
        (drop1): Dropout(p=0.5, inplace=False)
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (drop2): Dropout(p=0.5, inp

In [12]:
# Display the full list of pretrained model names (long output).
avail_pretrained_models

['adv_inception_v3',
 'bat_resnext26ts',
 'beit_base_patch16_224',
 'beit_base_patch16_224_in22k',
 'beit_base_patch16_384',
 'beit_large_patch16_224',
 'beit_large_patch16_224_in22k',
 'beit_large_patch16_384',
 'beit_large_patch16_512',
 'beitv2_base_patch16_224',
 'beitv2_base_patch16_224_in22k',
 'beitv2_large_patch16_224',
 'beitv2_large_patch16_224_in22k',
 'botnet26t_256',
 'cait_m36_384',
 'cait_m48_448',
 'cait_s24_224',
 'cait_s24_384',
 'cait_s36_384',
 'cait_xs24_384',
 'cait_xxs24_224',
 'cait_xxs24_384',
 'cait_xxs36_224',
 'cait_xxs36_384',
 'coat_lite_mini',
 'coat_lite_small',
 'coat_lite_tiny',
 'coat_mini',
 'coat_tiny',
 'coatnet_0_rw_224',
 'coatnet_1_rw_224',
 'coatnet_bn_0_rw_224',
 'coatnet_nano_rw_224',
 'coatnet_rmlp_1_rw_224',
 'coatnet_rmlp_2_rw_224',
 'coatnet_rmlp_nano_rw_224',
 'coatnext_nano_rw_224',
 'convit_base',
 'convit_small',
 'convit_tiny',
 'convmixer_768_32',
 'convmixer_1024_20_ks9_p14',
 'convmixer_1536_20',
 'convnext_atto',
 'convnext_atto_

In [None]:
from sagemaker.huggingface import HuggingFace

# Hugging Face estimator for a fine-tuning job on a single GPU instance.
# NOTE(review): `entry_point` and `hyperparameters` are not defined in any
# visible cell of this notebook -- confirm they are set before running.
huggingface_estimator = HuggingFace(
    # Fine-tuning script and its arguments
    entry_point=entry_point,
    hyperparameters=hyperparameters,
    # Container versions
    transformers_version="4.12.3",
    pytorch_version="1.9.1",
    py_version="py38",
    # Infrastructure
    instance_count=1,
    instance_type="ml.g4dn.2xlarge",
    role=role,
)