In [1]:
import torch

In [2]:
import numpy as np
import pandas as pd
import os
import torchvision
from torchvision import datasets
from torchvision import transforms as T
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, sampler, random_split
from torchvision import models

In [3]:
import timm
#import torchvision image models
from timm.loss import LabelSmoothingCrossEntropy

In [4]:
import matplotlib.pyplot as plt
%matplotlib inline

In [5]:
import sys
from tqdm import tqdm
import time
import copy

In [6]:
def get_classes(data_dir):
    all_data = datasets.ImageFolder(data_dir)
    return all_data.classes

In [36]:
def get_data_loaders(data_dir, batch_size, train = False):
    if train:
        transform = T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            T.RandomApply(torch.nn.ModuleList([T.ColorJitter()]), p = 0.25),
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(timm.data.IMAGENET_DEFAULT_MEAN, timm.data.IMAGENET_DEFAULT_STD),
            #imagenet means
            T.RandomErasing(p = 0.1, value = 'random')
        ])
        
        train_data = datasets.ImageFolder(os.path.join(data_dir, 'train\\'), transform=transform)
        train_loader = DataLoader(train_data, batch_size = batch_size, shuffle = True, num_workers= 4)
        return train_loader, len(train_data)
    else:
        transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(timm.data.IMAGENET_DEFAULT_MEAN, timm.data.IMAGENET_DEFAULT_STD)
            
        ])
        val_data = datasets.ImageFolder(os.path.join(data_dir, "valid\\"), transform = transform)
        test_data = datasets.ImageFolder(os.path.join(data_dir, "test\\"), transform = transform)
        val_loader = DataLoader(val_data, batch_size = batch_size , shuffle=True, num_workers=4)
        test_loader = DataLoader(test_data, batch_size= batch_size, shuffle =True, num_workers = 4)
        return val_loader, test_loader, len(val_data), len(test_data)

In [18]:
!pip install opendatasets

Defaulting to user installation because normal site-packages is not writeable


In [32]:
dataset_path = r"C:\F\f_drive\dataset\archive"

In [33]:
print(dataset_path)

C:\F\f_drive\dataset\archive


In [34]:
print(os.path.join(dataset_path, 'valid\\'))

C:\F\f_drive\dataset\archive\valid\


In [84]:
(train_loader, train_data_len) = get_data_loaders(dataset_path, batch_size=128, train = True)
(val_loader, test_loader, val_data_len, test_data_len) = get_data_loaders(dataset_path, batch_size =32 , train = False)

In [47]:
classes = get_classes(r'C:\F\f_drive\dataset\archive\test')

In [48]:
print(len(classes))

100


In [45]:
dataloaders = {
    'train': train_loader,
    'val': val_loader
}
dataset_sizes = {
    'train': train_data_len,
    'val' : val_data_len
}

In [49]:
print(len(train_loader), len(test_loader), len(val_loader))

99 4 4


In [50]:
print(train_data_len, test_data_len, val_data_len)

12639 500 500


In [59]:
HUB_URL = "SharanSMenon/swin-transformer-hub:main"
MODEL_NAME = "swin_tiny_patch4_window7_224"
#load from hub 
model = torch.hub.load(HUB_URL, MODEL_NAME, pretrained = True)

Using cache found in C:\Users\shwet/.cache\torch\hub\SharanSMenon_swin-transformer-hub_main


In [60]:
#freez model
for param in model.parameters():
    param.requires_grad = False

In [62]:
print(model)

SwinTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0): BasicLayer(
      dim=96, input_resolution=(56, 56), depth=2
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          dim=96, input_resolution=(56, 56), num_heads=3, window_size=7, shift_size=0, mlp_ratio=4.0
          (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            dim=96, window_size=(7, 7), num_heads=3
            (qkv): Linear(in_features=96, out_features=288, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=96, out_features=96, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): Identity()
          (norm2): LayerNo

In [63]:
n_input = model.head.in_features
model.head = nn.Sequential(
        nn.Linear(n_input, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, len(classes))
)
print(n_input, model.head)

768 Sequential(
  (0): Linear(in_features=768, out_features=512, bias=True)
  (1): ReLU()
  (2): Dropout(p=0.3, inplace=False)
  (3): Linear(in_features=512, out_features=100, bias=True)
)


In [67]:
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [68]:
model = model.to(device)
print(model.head)

Sequential(
  (0): Linear(in_features=768, out_features=512, bias=True)
  (1): ReLU()
  (2): Dropout(p=0.3, inplace=False)
  (3): Linear(in_features=512, out_features=100, bias=True)
)


In [69]:
criterion = LabelSmoothingCrossEntropy()
criterion = criterion.to(device)
optimizer = optim.AdamW(model.head.parameters(), lr = 0.001)

In [70]:
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size= 3, gamma= 0.97)
#lr scheduler


In [78]:
def train_model(model, criterion, optimizer, scheduler, num_epochs = 10):
    since = time.time()
    best_model_wt = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/ {num_epochs - 1}')
        print("-" * 10)
        
        for phase in ['train', 'val']:    # Training and validation phase per epoch
            if phase == 'train':
                model.train()   #model to training phase
            else:
                model.eval()  #model to evaluation phase
                
            running_loss = 0.0
            running_corrects = 0.0
            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):  #no autograd makes validation step faster
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                        
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                
            if phase == 'train':
                scheduler.step() #step at end of epoch
                
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            
            print("{} Loss : {:.4f} Acc : {:.4f}".format(phase, epoch_loss, epoch_acc))
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wt = copy.deepcopy(model.state_dict())   #keep the best vlidation accuracy model
                
        print()
        
    time_elapsed  = time.time() - since  
    print("training complete in {:.0f}m {:.0f}s".format(time_elapsed //60, time_elapsed % 60))
    print("Best val acc : {:.4f}".format(best_acc))
        
        
    model.load_state_dict(best_model_wt)
    return model
            
            
            
        
        
        

In [79]:
model_ft = train_model(model, criterion, optimizer, exp_lr_scheduler, num_epochs = 8)

Epoch 0/ 7
----------


100%|██████████████████████████████████████████████████████████████████████████████████| 99/99 [01:14<00:00,  1.32it/s]


train Loss : 1.6534 Acc : 0.7865


100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:05<00:00,  1.40s/it]


val Loss : 1.4415 Acc : 0.8760

Epoch 1/ 7
----------


100%|██████████████████████████████████████████████████████████████████████████████████| 99/99 [01:16<00:00,  1.30it/s]


train Loss : 1.4427 Acc : 0.8528


100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:05<00:00,  1.38s/it]


val Loss : 1.3111 Acc : 0.9100

Epoch 2/ 7
----------


100%|██████████████████████████████████████████████████████████████████████████████████| 99/99 [01:16<00:00,  1.29it/s]


train Loss : 1.3356 Acc : 0.8859


100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:05<00:00,  1.36s/it]


val Loss : 1.2505 Acc : 0.9020

Epoch 3/ 7
----------


100%|██████████████████████████████████████████████████████████████████████████████████| 99/99 [01:16<00:00,  1.30it/s]


train Loss : 1.2816 Acc : 0.8996


100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:05<00:00,  1.36s/it]


val Loss : 1.2071 Acc : 0.9180

Epoch 4/ 7
----------


100%|██████████████████████████████████████████████████████████████████████████████████| 99/99 [01:16<00:00,  1.30it/s]


train Loss : 1.2377 Acc : 0.9115


100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:05<00:00,  1.37s/it]


val Loss : 1.1801 Acc : 0.9240

Epoch 5/ 7
----------


100%|██████████████████████████████████████████████████████████████████████████████████| 99/99 [01:16<00:00,  1.29it/s]


train Loss : 1.2056 Acc : 0.9236


100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:05<00:00,  1.39s/it]


val Loss : 1.1601 Acc : 0.9300

Epoch 6/ 7
----------


100%|██████████████████████████████████████████████████████████████████████████████████| 99/99 [01:16<00:00,  1.29it/s]


train Loss : 1.1797 Acc : 0.9259


100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:05<00:00,  1.38s/it]


val Loss : 1.1549 Acc : 0.9420

Epoch 7/ 7
----------


100%|██████████████████████████████████████████████████████████████████████████████████| 99/99 [01:17<00:00,  1.28it/s]


train Loss : 1.1603 Acc : 0.9325


100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:05<00:00,  1.42s/it]

val Loss : 1.1344 Acc : 0.9300

training complete in 10m 57s
Best val acc : 0.9420





In [88]:
#Lets run the dataset on the test loader and calculate accuracy
test_loss = 0.0
class_correct = list(0 for i in range(len(classes)))
class_total = list(0 for i in range(len(classes)))
model_ft.eval()

for data, traget in tqdm(test_loader):
    data, traget = data.to(device), traget.to(device)
    with torch.no_grad():
        output = model_ft(data)
        loss = criterion(output, traget)
        test_loss = loss.item() * data.size(0)
        _, pred = torch.max(output, 1)
        correct_tensor = pred.eq(traget.data.view_as(pred))
        correct = np.squeeze(correct_tensor.cpu().numpy())
        if len(traget) == 32:
            for i in range(32):
                label = traget.data[i]
                class_correct[label] += correct[i].item()
                class_total[label] += 1
                
test_loss = test_loss / test_data_len
print('Test Loss: {:.4f}'.format(test_loss))
for i in range(len(classes)):
    if class_total[i] > 0:
        print("Test Accuracy of %5s: %2d %% (%2d/%2d)"%(classes[i], 100 * class_correct[i] / class_total[i], np.sum(class_correct[i]), np.sum(class_total[i])))
        
    else:
        print("Test accuracy of %5s : NA" % (classes[i]))
        
print("Test accuracy of %2d%% (%2d%2d)" % (100*np.sum(class_correct)/np.sum(class_total), np.sum(class_correct), np.sum(class_total)))





100%|██████████████████████████████████████████████████████████████████████████████████| 16/16 [00:05<00:00,  2.79it/s]

Test Loss: 0.0463
Test Accuracy of ADONIS: 100 % ( 4/ 4)
Test Accuracy of AFRICAN GIANT SWALLOWTAIL: 100 % ( 5/ 5)
Test Accuracy of AMERICAN SNOOT: 80 % ( 4/ 5)
Test Accuracy of AN 88: 100 % ( 5/ 5)
Test Accuracy of APPOLLO: 100 % ( 5/ 5)
Test Accuracy of ARCIGERA FLOWER MOTH: 100 % ( 5/ 5)
Test Accuracy of ATALA: 100 % ( 5/ 5)
Test Accuracy of ATLAS MOTH: 80 % ( 4/ 5)
Test Accuracy of BANDED ORANGE HELICONIAN: 100 % ( 4/ 4)
Test Accuracy of BANDED PEACOCK: 100 % ( 5/ 5)
Test Accuracy of BANDED TIGER MOTH: 100 % ( 5/ 5)
Test Accuracy of BECKERS WHITE: 100 % ( 5/ 5)
Test Accuracy of BIRD CHERRY ERMINE MOTH: 75 % ( 3/ 4)
Test Accuracy of BLACK HAIRSTREAK: 100 % ( 5/ 5)
Test Accuracy of BLUE MORPHO: 80 % ( 4/ 5)
Test Accuracy of BLUE SPOTTED CROW: 100 % ( 5/ 5)
Test Accuracy of BROOKES BIRDWING: 100 % ( 5/ 5)
Test Accuracy of BROWN ARGUS: 100 % ( 4/ 4)
Test Accuracy of BROWN SIPROETA: 100 % ( 4/ 4)
Test Accuracy of CABBAGE WHITE: 100 % ( 5/ 5)
Test Accuracy of CAIRNS BIRDWING: 100 % ( 5/ 




In [86]:
for data, traget in tqdm(test_loader):
    data, traget = data.to(device), traget.to(device)
    print(len(traget))

 56%|██████████████████████████████████████████████▋                                    | 9/16 [00:02<00:01,  4.63it/s]

32
32
32
32
32
32
32
32
32
32
32
32
32


 94%|████████████████████████████████████████████████████████████████████████████▉     | 15/16 [00:02<00:00,  8.47it/s]

32
32
20


100%|██████████████████████████████████████████████████████████████████████████████████| 16/16 [00:03<00:00,  5.23it/s]


In [90]:
example = torch.rand(1,3, 224,224)
traced_script_module = torch.jit.trace(model.cpu(), example)
traced_script_module.save("butterfly_swin_transformer.pt")

  assert L == H * W, "input feature has wrong size"
  B = int(windows.shape[0] / (H * W / window_size / window_size))
  assert L == H * W, "input feature has wrong size"


In [91]:
ls

 Volume in drive C is OS
 Volume Serial Number is 4839-38DC

 Directory of C:\Users\shwet

03-03-2023  00:19    <DIR>          .
10-02-2023  02:39    <DIR>          ..
02-03-2023  03:34    <DIR>          .cache
14-02-2023  03:38    <DIR>          .conda
27-01-2023  17:23               188 .gitconfig
01-03-2023  18:44    <DIR>          .ipynb_checkpoints
14-02-2023  02:03    <DIR>          .ipython
14-02-2023  03:15    <DIR>          .jupyter
28-02-2023  12:30    <DIR>          .keras
16-02-2023  03:53    <DIR>          .matplotlib
20-01-2023  23:42    <DIR>          .ms-ad
24-01-2023  18:00    <DIR>          .RapidMiner
18-01-2023  17:27    <DIR>          .virtualenvs
18-01-2023  17:22    <DIR>          .vscode
25-01-2023  15:20    <DIR>          ansel
17-02-2023  20:32             5,154 bar_code_creater.ipynb
03-03-2023  00:19       113,434,106 butterfly_swin_transformer.pt
18-01-2023  17:23    <DIR>          code
10-02-2023  02:42    <DIR>          Contacts
10-02-2023  02:42    <DIR>