In [2]:
import torch 
from torchvision import transforms, models
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
import torch.nn as nn 
from PIL import Image, ImageFile
from torchsummary import summary
import warnings
from torchmetrics.classification import Accuracy, Precision, Recall, ConfusionMatrix
from tqdm import tqdm

In [3]:
# ignoring user warnings because it just points out the 'P' images
warnings.simplefilter('ignore', UserWarning)

In [4]:
device = torch.device('mps')
if not(torch.backends.mps.is_available()): 
    device = torch.device('cpu')
device

device(type='mps')

In [5]:
swin_transforms = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(), 
        transforms.Normalize(
            mean=[0.485,0.456, 0.406], 
            std = [0.229, 0.224, 0.225]
        )
    ])


In [6]:
ImageFile.LOAD_TRUNCATED_IMAGES = True
class CustomDataset(ImageFolder): 
    def __init__(self, root, transform=None):
        super().__init__(root, transform=None)
        self.transform = transform
    
    def __getitem__(self, idx):
        
        path, label = self.samples[idx]
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', Image.DecompressionBombWarning)

            try:
                image = Image.open(path).convert('RGB')
                
            except Image.DecompressionBombWarning:
                print(f'Decompression Bomb Warning at {path}')
                raise IndexError(f'Skipping {path} because decompression warning')

            except Exception as e:
                print(e)
                print('PATH = ', path)    
                
        if self.transform: 
            image = self.transform(image)
            
        # label manipulation 
        label = torch.tensor(label)
        label = nn.functional.one_hot(label, num_classes=2)
            
        return image, label
        

In [6]:
BATCH_SIZE = 32
train_link = "datasets/30 k datapoints/train"

dataset = CustomDataset(root = train_link, transform=swin_transforms)

In [7]:
TRAIN_SPLIT = int(0.8 * len(dataset)) 
train, val = random_split(dataset= dataset, lengths=[TRAIN_SPLIT, len(dataset) - TRAIN_SPLIT])

In [8]:
TrainLoader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
ValLoader = DataLoader(val, batch_size=BATCH_SIZE, shuffle=False)

In [8]:
# printing shapes and testing 
for images, labels in TrainLoader: 
    print(images.shape, labels)
    break

for image, label in ValLoader: 
    print(image.shape, label)
    break

torch.Size([32, 3, 224, 224]) tensor([[1, 0],
        [1, 0],
        [1, 0],
        [0, 1],
        [1, 0],
        [0, 1],
        [1, 0],
        [1, 0],
        [0, 1],
        [0, 1],
        [1, 0],
        [0, 1],
        [0, 1],
        [1, 0],
        [1, 0],
        [1, 0],
        [0, 1],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [0, 1],
        [1, 0],
        [0, 1],
        [1, 0],
        [0, 1],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0]])
torch.Size([32, 3, 224, 224]) tensor([[1, 0],
        [0, 1],
        [1, 0],
        [0, 1],
        [1, 0],
        [1, 0],
        [1, 0],
        [0, 1],
        [0, 1],
        [1, 0],
        [0, 1],
        [1, 0],
        [0, 1],
        [1, 0],
        [0, 1],
        [0, 1],
        [1, 0],
        [0, 1],
        [0, 1],
        [1, 0],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [1,

In [7]:
class SWIN(nn.Module): 
    def __init__(self, fine_tune = False):
        super(SWIN, self).__init__()
        
        # load pretrained model 
        self.swin = models.swin_v2_t(weights = 'DEFAULT')
        
        # freeze the vgg 16 
        if not(fine_tune): 
            for params in self.swin.parameters(): 
                params.requires_grad = False
        
        self.swin.head = nn.Sequential(
                nn.Linear(self.swin.head.in_features, 256),
                nn.BatchNorm1d(num_features=256), 
                nn.Dropout(p=0.6),
                nn.ReLU(), 
                nn.Linear(256, 2), 
        )
        
        # making classifier segment trainable 
        for params in self.swin.head.parameters(): 
            params.requires_grad = True
        
    def forward(self, x): 
        return self.swin(x)
    
    def get_prediction(self, x):
        outputs = self.forward(x)
        outputs = torch.softmax(outputs)
        return torch.argmax(outputs)
        
        
    def StartFineTuning(self, blocks_to_unfreeze=4): 
        ''' unfreeze the last block and train that as well 
        do this only when you train the classifier model 
        '''
        
        for idx in range(blocks_to_unfreeze, 4):
            for param in self.swin.features[idx].parameters():
                param.requires_grad = True
        

In [10]:
model = SWIN().to(device=device)

Downloading: "https://download.pytorch.org/models/swin_v2_t-b137f0e2.pth" to /Users/manedge/.cache/torch/hub/checkpoints/swin_v2_t-b137f0e2.pth
100%|██████████| 109M/109M [00:03<00:00, 37.9MB/s] 


In [11]:
model

SWIN(
  (swin): SwinTransformer(
    (features): Sequential(
      (0): Sequential(
        (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
        (1): Permute()
        (2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      )
      (1): Sequential(
        (0): SwinTransformerBlockV2(
          (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (attn): ShiftedWindowAttentionV2(
            (qkv): Linear(in_features=96, out_features=288, bias=True)
            (proj): Linear(in_features=96, out_features=96, bias=True)
            (cpb_mlp): Sequential(
              (0): Linear(in_features=2, out_features=512, bias=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=512, out_features=3, bias=False)
            )
          )
          (stochastic_depth): StochasticDepth(p=0.0, mode=row)
          (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (mlp): MLP(
            (0): Linear(in_features=96, ou

In [12]:
summary(model);

Layer (type:depth-idx)                             Param #
├─SwinTransformer: 1-1                             --
|    └─Sequential: 2-1                             --
|    |    └─Sequential: 3-1                        (4,896)
|    |    └─Sequential: 3-2                        (229,830)
|    |    └─PatchMergingV2: 3-3                    (74,112)
|    |    └─Sequential: 3-4                        (898,956)
|    |    └─PatchMergingV2: 3-5                    (295,680)
|    |    └─Sequential: 3-6                        (10,692,936)
|    |    └─PatchMergingV2: 3-7                    (1,181,184)
|    |    └─Sequential: 3-8                        (14,203,440)
|    └─LayerNorm: 2-2                              (1,536)
|    └─Permute: 2-3                                --
|    └─AdaptiveAvgPool2d: 2-4                      --
|    └─Flatten: 2-5                                --
|    └─Sequential: 2-6                             --
|    |    └─Linear: 3-9                            196,864
|    |

In [13]:
def train_swin(TrainLoader, ValLoader, model, EPOCHS = 5):
    
    # model parameters definition
    lossfn = torch.nn.CrossEntropyLoss()
    LEARNING_RATE = 1E-3
    opt = torch.optim.Adam(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=1e-5
    )
    
    # accuracy 
    accuracy = Accuracy(task='multiclass', num_classes=2).to(device=device)
    
    # training segment
    training_loss_list = []
    val_loss_list = []
    for epoch in range(EPOCHS): 
        running_loss = 0
        TrainLoader_tqdm = tqdm(TrainLoader)
        for image, label in TrainLoader_tqdm:
            
            
            # moving labels and images to GPU
            image = image.to(device=device)
            label = label.to(device=device)
        
            opt.zero_grad()
        
            # predicting and training 
            output = model(image)
            loss = lossfn(output.squeeze(1), label.float())
            loss.backward()
            running_loss += loss.item() / len(TrainLoader)
            
            accuracy.update(output.squeeze(1), label.argmax(dim=1))
            TrainLoader_tqdm.set_postfix({"Training Loss": running_loss})
            opt.step()
            
        train_accuracy = accuracy.compute().item()
        accuracy.reset()
        
        # storing train loss 
        training_loss_list.append(running_loss)
        
        # print progress bar 
        
        
        
        
        
        # validation segment
        val_running_loss = 0
        for image, label in ValLoader: 
            image = image.to(device = device)
            label = label.to(device = device)
            
            # model output
            output = model(image)
            
            # loss computation
            loss = lossfn(output.squeeze(1), label.float())
            val_running_loss += loss.item() / len(ValLoader)
            
            # accuracy computation
            accuracy.update(output.squeeze(1), label.argmax(dim=1))
        
        # validation loss storing 
        val_loss_list.append(val_running_loss)
    
        # final accuracy calculations 
        val_accuracy = accuracy.compute().item()
        accuracy.reset()
        
        # printing metrics
        print(f'''epoch [{epoch+1}/{EPOCHS}]
        \t training loss: {running_loss},
        \t validation loss: {val_running_loss},
        \t Train Accuracy: {train_accuracy},
        \t Val acc: {val_accuracy},
            ''')
    
    return model


In [14]:
EPOCHS = 5
model = train_swin(TrainLoader, ValLoader, model, EPOCHS=EPOCHS)

100%|██████████| 1200/1200 [42:20<00:00,  2.12s/it, Training Loss=0.266]  


epoch [1/5]
        	 training loss: 0.26564220673094224,
        	 validation loss: 0.22959856534997627,
        	 Train Accuracy: 0.8905468583106995,
        	 Val acc: 0.9078124761581421,
            


100%|██████████| 1200/1200 [42:36<00:00,  2.13s/it, Training Loss=0.173]


epoch [2/5]
        	 training loss: 0.1729116088128648,
        	 validation loss: 0.20733012164632492,
        	 Train Accuracy: 0.9341145753860474,
        	 Val acc: 0.9190624952316284,
            


100%|██████████| 1200/1200 [42:49<00:00,  2.14s/it, Training Loss=0.11]   


epoch [3/5]
        	 training loss: 0.10988528079353259,
        	 validation loss: 0.21978169227639843,
        	 Train Accuracy: 0.9595833420753479,
        	 Val acc: 0.918749988079071,
            


100%|██████████| 1200/1200 [42:51<00:00,  2.14s/it, Training Loss=0.0644]


epoch [4/5]
        	 training loss: 0.06441883566944556,
        	 validation loss: 0.25049035825145727,
        	 Train Accuracy: 0.9780208468437195,
        	 Val acc: 0.9201041460037231,
            


100%|██████████| 1200/1200 [42:48<00:00,  2.14s/it, Training Loss=0.042] 


epoch [5/5]
        	 training loss: 0.042004853563630595,
        	 validation loss: 0.31358435168707127,
        	 Train Accuracy: 0.9860156178474426,
        	 Val acc: 0.910729169845581,
            


In [15]:
torch.save(model, './saved models/swin_T_no_finetune.pt')

In [37]:
# fine tuning segment 
model.StartFineTuning(blocks_to_unfreeze=3)
summary(model);

Layer (type:depth-idx)                             Param #
├─SwinTransformer: 1-1                             --
|    └─Sequential: 2-1                             --
|    |    └─Sequential: 3-1                        (4,896)
|    |    └─Sequential: 3-2                        (224,694)
|    |    └─PatchMerging: 3-3                      (74,496)
|    |    └─Sequential: 3-4                        891,756
|    |    └─PatchMerging: 3-5                      (296,448)
|    |    └─Sequential: 3-6                        (31,976,856)
|    |    └─PatchMerging: 3-7                      (1,182,720)
|    |    └─Sequential: 3-8                        (14,183,856)
|    └─LayerNorm: 2-2                              (1,536)
|    └─Permute: 2-3                                --
|    └─AdaptiveAvgPool2d: 2-4                      --
|    └─Flatten: 2-5                                --
|    └─Sequential: 2-6                             --
|    |    └─Linear: 3-9                            196,864
|    |  

In [38]:
model = train_swin(TrainLoader, ValLoader, model, EPOCHS=3)

100%|██████████| 1200/1200 [56:03<00:00,  2.80s/it, Training Loss=0.194]  


epoch [1/3]
        	 training loss: 0.1939756562568555,
        	 validation loss: 0.2718391049901644,
        	 Train Accuracy: 0.9234114289283752,
        	 Val acc: 0.8854166865348816,
            


 26%|██▌       | 311/1200 [20:07<57:32,  3.88s/it, Training Loss=0.0623]  


KeyboardInterrupt: 

In [28]:
torch.save(model, './saved models/swin_finetuned.pt')

# Testing

In [8]:
model = torch.load('./saved models/swin_no_finetune.pt')

In [15]:
testlink = "datasets/30 k datapoints/test"
testset = CustomDataset(root= testlink, transform=swin_transforms)
TestLoader = DataLoader(testset, shuffle=False, batch_size=BATCH_SIZE)

In [17]:

running_loss = 0
BATCH_SIZE = 32
# for metrics
accuracy = Accuracy(task='multiclass', num_classes=2).to(device=device)
precision = Precision(task='multiclass', num_classes=2).to(device=device)
recall = Recall(task='multiclass', num_classes=2).to(device=device)
ConMat = ConfusionMatrix(task='multiclass', num_classes=2).to(device=device)

# loss function 
lossfn = torch.nn.CrossEntropyLoss()

# testing phase
TestLoader_tqdm = tqdm(TestLoader)
model.eval()
for image, label in TestLoader_tqdm:
    image = image.to(device= device)
    label = label.to(device=device)

    output = model(image)

    loss = lossfn(output, label.float())
    running_loss += loss.item() / len(TestLoader)

    TestLoader_tqdm.set_postfix({'Loss':running_loss})

    accuracy.update(output, label.argmax(dim=1))
    precision.update(output, label.argmax(dim=1))
    recall.update(output, label.argmax(dim=1))
    ConMat.update(output, label.argmax(dim=1))
    
    

100%|██████████| 375/375 [13:28<00:00,  2.16s/it, Loss=0.306] 


In [18]:
print(f'''Accuracy [{accuracy.compute()}]
    \t Test loss: {running_loss},
    \t Precision: {precision.compute()},
    \t Recall: {recall.compute()},
    \t 
        ''')


Accuracy [0.921999990940094]
    	 Test loss: 0.30621550742816206,
    	 Precision: 0.921999990940094,
    	 Recall: 0.921999990940094,
    	 
        


In [19]:
ConMat.compute()

tensor([[5549,  451],
        [ 485, 5515]], device='mps:0')

# Random test set

In [37]:
link = './datasets/real life'
BATCH_SIZE = 32
randomTestSet = CustomDataset(root = link, transform=swin_transforms)
TestLoader = DataLoader(randomTestSet, batch_size=BATCH_SIZE, shuffle=True)

In [38]:

running_loss = 0

# for metrics
accuracy = Accuracy(task='multiclass', num_classes=2).to(device=device)
precision = Precision(task='multiclass', num_classes=2).to(device=device)
recall = Recall(task='multiclass', num_classes=2).to(device=device)
ConMat = ConfusionMatrix(task='multiclass', num_classes=2).to(device=device)

accuracy.reset()
precision.reset()
recall.reset()
ConMat.reset()
# loss function 
lossfn = torch.nn.CrossEntropyLoss()

# testing phase
TestLoader_tqdm = tqdm(TestLoader)
model.eval()
for idx, (image, label) in enumerate(TestLoader_tqdm):
    image = image.to(device= device)
    label = label.to(device=device)

    output = model(image)

    loss = lossfn(output, label.float())
    running_loss += loss.item() / len(TestLoader)

    TestLoader_tqdm.set_postfix({'Loss':running_loss})

    accuracy.update(output, label.argmax(dim=1))
    precision.update(output, label.argmax(dim=1))
    recall.update(output, label.argmax(dim=1))
    ConMat.update(output, label.argmax(dim=1))

    if idx == 100:
        break

100%|██████████| 1/1 [00:01<00:00,  1.23s/it, Loss=10.2]


In [39]:
print(f'''Accuracy [{accuracy.compute()}]
    \t Test loss: {running_loss},
    \t Precision: {precision.compute()},
    \t Recall: {recall.compute()},
    \t 
        ''')


Accuracy [0.2857142984867096]
    	 Test loss: 10.234702110290527,
    	 Precision: 0.2857142984867096,
    	 Recall: 0.2857142984867096,
    	 
        


In [40]:
ConMat.compute()

tensor([[2, 0],
        [5, 0]], device='mps:0')