# Swin Transformers in Pytorch

In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import torch
import torchvision
from torchvision import datasets
from torchvision import transforms as T # for simplifying the transforms
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, sampler, random_split
from torchvision import models

In [2]:
import torch

if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("✅ MPS is available. Using Apple Silicon GPU.")
else:
    device = torch.device("cpu")
    print("❌ MPS not available. Falling back to CPU.")

✅ MPS is available. Using Apple Silicon GPU.


In [3]:
## Now, we import timm, torchvision image models
# !pip install timm
import timm
from timm.loss import LabelSmoothingCrossEntropy # This is better than normal nn.CrossEntropyLoss

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# remove warnings
import warnings
warnings.filterwarnings("ignore")

In [5]:
import matplotlib.pyplot as plt
%matplotlib inline

In [6]:
import sys
from tqdm import tqdm
import time
import copy

In [7]:
def get_classes(data_dir):
    all_data = datasets.ImageFolder(data_dir)
    return all_data.classes

In [8]:
def get_data_loaders(data_dir, batch_size, train = False):
    if train:
        #train
        transform = T.Compose([
            T.Resize((256, 256)),                             # Resize to uniform size
            T.RandomRotation(degrees=10),                     # Small rotation for viewpoint variation
            T.RandomHorizontalFlip(p=0.5),                    # Flip can be safe for lungs (if not labeled left/right)
            T.CenterCrop(224),                                # Crop to match model input
            T.ToTensor(),
            T.Normalize(mean=[0.485], std=[0.229]),           # Normalize for grayscale X-ray images (you may recompute dataset stats)
            T.RandomErasing(p=0.1, scale=(0.02, 0.1), ratio=(0.3, 3.3), value='random')  # Simulate occlusions
                ])
        train_data = datasets.ImageFolder(os.path.join(data_dir, "train/"), transform = transform)
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
        return train_loader, len(train_data)
    else:
        # val/test
        transform = T.Compose([ # We dont need augmentation for test transforms
            T.Resize((256, 256)),                             # Resize to uniform size
            T.RandomRotation(degrees=10),                     # Small rotation for viewpoint variation
            T.RandomHorizontalFlip(p=0.5),                    # Flip can be safe for lungs (if not labeled left/right)
            T.CenterCrop(224),                                # Crop to match model input
            T.ToTensor(),
            T.Normalize(mean=[0.485], std=[0.229]),           # Normalize for grayscale X-ray images (you may recompute dataset stats)
            T.RandomErasing(p=0.1, scale=(0.02, 0.1), ratio=(0.3, 3.3), value='random')  # Simulate occlusions
        ])
        val_data = datasets.ImageFolder(os.path.join(data_dir, "valid/"), transform=transform)
        test_data = datasets.ImageFolder(os.path.join(data_dir, "test/"), transform=transform)
        val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True, num_workers=4)
        test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True, num_workers=4)
        return val_loader, test_loader, len(val_data), len(test_data)

In [9]:
dataset_path = "/Users/zagarsuren/Documents/GitHub/xray-classification/data"

In [10]:
(train_loader, train_data_len) = get_data_loaders(dataset_path, 128, train=True)

In [11]:
(val_loader, test_loader, valid_data_len, test_data_len) = get_data_loaders(dataset_path, 32, train=False)

In [12]:
classes = get_classes("/Users/zagarsuren/Documents/GitHub/xray-classification/data/train")
print(classes, len(classes))

['Atelectasis', 'Cardiomegaly', 'No Finding', 'Nodule', 'Pneumothorax'] 5


In [13]:
dataloaders = {
    "train": train_loader,
    "val": val_loader
}
dataset_sizes = {
    "train": train_data_len,
    "val": valid_data_len
}

In [14]:
print(len(train_loader), len(val_loader), len(test_loader))

35 18 18


In [15]:
print(train_data_len, valid_data_len, test_data_len)

4370 545 550


In [16]:
# now, for the model
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device("cpu") # torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
device

device(type='cpu')

In [17]:
classes

['Atelectasis', 'Cardiomegaly', 'No Finding', 'Nodule', 'Pneumothorax']

In [18]:
HUB_URL = "SharanSMenon/swin-transformer-hub:main"
MODEL_NAME = "swin_tiny_patch4_window7_224"
# check hubconf for more models.
model = torch.hub.load(HUB_URL, MODEL_NAME, pretrained=True) # load from torch hub

Using cache found in /Users/zagarsuren/.cache/torch/hub/SharanSMenon_swin-transformer-hub_main


In [19]:
for param in model.parameters(): #freeze model
    param.requires_grad = False

n_inputs = model.head.in_features
model.head = nn.Sequential(
    nn.Linear(n_inputs, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, len(classes))
)
model = model.to(device)
print(model.head)

Sequential(
  (0): Linear(in_features=768, out_features=512, bias=True)
  (1): ReLU()
  (2): Dropout(p=0.3, inplace=False)
  (3): Linear(in_features=512, out_features=5, bias=True)
)


In [20]:
criterion = LabelSmoothingCrossEntropy()
criterion = criterion.to(device)
optimizer = optim.AdamW(model.head.parameters(), lr=0.001)

In [21]:
# lr scheduler
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.97)

In [23]:
import time
import copy
from tqdm import tqdm
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt

def train_model(model, criterion, optimizer, scheduler, dataloaders, dataset_sizes, device, num_epochs=10, class_names=None):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    # Lists for tracking metrics
    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print("-"*10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            all_preds = []
            all_labels = []

            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            if phase == 'train':
                train_losses.append(epoch_loss)
                train_accuracies.append(epoch_acc.item())
            else:
                val_losses.append(epoch_loss)
                val_accuracies.append(epoch_acc.item())

            print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

            # Print classification report
            if class_names is not None:
                print(f"{phase} Classification Report:")
                print(classification_report(all_labels, all_preds, target_names=class_names))
            else:
                print(f"{phase} Classification Report:")
                print(classification_report(all_labels, all_preds))

            # Save best model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print("Best Val Acc: {:.4f}".format(best_acc))

    # Plot Loss
    plt.figure()
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.title('Loss per Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('loss_curve.png')
    plt.close()

    # Plot Accuracy
    plt.figure()
    plt.plot(train_accuracies, label='Train Accuracy')
    plt.plot(val_accuracies, label='Val Accuracy')
    plt.title('Accuracy per Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.savefig('accuracy_curve.png')
    plt.close()

    # Load best model
    model.load_state_dict(best_model_wts)
    return model

In [24]:
model_ft = train_model(model, criterion, optimizer, exp_lr_scheduler, dataloaders, dataset_sizes, device, num_epochs=50)

Epoch 0/49
----------


100%|██████████| 35/35 [03:41<00:00,  6.33s/it]


train Loss: 1.5464 Acc: 0.3181
train Classification Report:
              precision    recall  f1-score   support

           0       0.31      0.38      0.34       874
           1       0.32      0.28      0.30       874
           2       0.27      0.24      0.26       874
           3       0.30      0.26      0.28       874
           4       0.38      0.43      0.40       874

    accuracy                           0.32      4370
   macro avg       0.32      0.32      0.32      4370
weighted avg       0.32      0.32      0.32      4370



100%|██████████| 18/18 [00:38<00:00,  2.16s/it]


val Loss: 1.4686 Acc: 0.3835
val Classification Report:
              precision    recall  f1-score   support

           0       0.37      0.39      0.38       109
           1       0.36      0.61      0.46       109
           2       0.17      0.04      0.06       109
           3       0.33      0.45      0.38       109
           4       0.63      0.43      0.51       109

    accuracy                           0.38       545
   macro avg       0.37      0.38      0.36       545
weighted avg       0.37      0.38      0.36       545


Epoch 1/49
----------


100%|██████████| 35/35 [03:17<00:00,  5.65s/it]


train Loss: 1.4705 Acc: 0.3950
train Classification Report:
              precision    recall  f1-score   support

           0       0.38      0.45      0.41       874
           1       0.44      0.45      0.45       874
           2       0.33      0.24      0.28       874
           3       0.35      0.33      0.34       874
           4       0.45      0.50      0.47       874

    accuracy                           0.39      4370
   macro avg       0.39      0.39      0.39      4370
weighted avg       0.39      0.39      0.39      4370



100%|██████████| 18/18 [00:40<00:00,  2.24s/it]


val Loss: 1.4643 Acc: 0.4239
val Classification Report:
              precision    recall  f1-score   support

           0       0.46      0.24      0.32       109
           1       0.48      0.54      0.51       109
           2       0.33      0.28      0.30       109
           3       0.34      0.51      0.41       109
           4       0.56      0.54      0.55       109

    accuracy                           0.42       545
   macro avg       0.43      0.42      0.42       545
weighted avg       0.43      0.42      0.42       545


Epoch 2/49
----------


100%|██████████| 35/35 [03:14<00:00,  5.56s/it]


train Loss: 1.4566 Acc: 0.4023
train Classification Report:
              precision    recall  f1-score   support

           0       0.39      0.41      0.40       874
           1       0.44      0.48      0.46       874
           2       0.33      0.26      0.29       874
           3       0.36      0.33      0.34       874
           4       0.46      0.53      0.49       874

    accuracy                           0.40      4370
   macro avg       0.40      0.40      0.40      4370
weighted avg       0.40      0.40      0.40      4370



100%|██████████| 18/18 [00:39<00:00,  2.19s/it]


val Loss: 1.4330 Acc: 0.4220
val Classification Report:
              precision    recall  f1-score   support

           0       0.41      0.38      0.39       109
           1       0.51      0.43      0.47       109
           2       0.33      0.43      0.37       109
           3       0.40      0.23      0.29       109
           4       0.48      0.64      0.55       109

    accuracy                           0.42       545
   macro avg       0.42      0.42      0.41       545
weighted avg       0.42      0.42      0.41       545


Epoch 3/49
----------


100%|██████████| 35/35 [03:21<00:00,  5.77s/it]


train Loss: 1.4296 Acc: 0.4167
train Classification Report:
              precision    recall  f1-score   support

           0       0.39      0.41      0.40       874
           1       0.46      0.52      0.49       874
           2       0.33      0.27      0.30       874
           3       0.38      0.31      0.34       874
           4       0.47      0.58      0.52       874

    accuracy                           0.42      4370
   macro avg       0.41      0.42      0.41      4370
weighted avg       0.41      0.42      0.41      4370



100%|██████████| 18/18 [00:40<00:00,  2.27s/it]


val Loss: 1.4309 Acc: 0.4349
val Classification Report:
              precision    recall  f1-score   support

           0       0.44      0.36      0.39       109
           1       0.41      0.67      0.51       109
           2       0.42      0.17      0.25       109
           3       0.34      0.55      0.42       109
           4       0.77      0.42      0.54       109

    accuracy                           0.43       545
   macro avg       0.48      0.43      0.42       545
weighted avg       0.48      0.43      0.42       545


Epoch 4/49
----------


100%|██████████| 35/35 [03:46<00:00,  6.47s/it]


train Loss: 1.4350 Acc: 0.4142
train Classification Report:
              precision    recall  f1-score   support

           0       0.41      0.38      0.39       874
           1       0.46      0.54      0.50       874
           2       0.35      0.31      0.33       874
           3       0.34      0.33      0.34       874
           4       0.48      0.51      0.49       874

    accuracy                           0.41      4370
   macro avg       0.41      0.41      0.41      4370
weighted avg       0.41      0.41      0.41      4370



100%|██████████| 18/18 [00:39<00:00,  2.21s/it]


val Loss: 1.4182 Acc: 0.4147
val Classification Report:
              precision    recall  f1-score   support

           0       0.37      0.47      0.41       109
           1       0.40      0.52      0.45       109
           2       0.35      0.11      0.17       109
           3       0.39      0.34      0.36       109
           4       0.51      0.63      0.57       109

    accuracy                           0.41       545
   macro avg       0.40      0.41      0.39       545
weighted avg       0.40      0.41      0.39       545


Epoch 5/49
----------


100%|██████████| 35/35 [03:31<00:00,  6.03s/it]


train Loss: 1.4192 Acc: 0.4215
train Classification Report:
              precision    recall  f1-score   support

           0       0.39      0.42      0.41       874
           1       0.45      0.55      0.50       874
           2       0.38      0.23      0.29       874
           3       0.37      0.33      0.35       874
           4       0.47      0.57      0.52       874

    accuracy                           0.42      4370
   macro avg       0.41      0.42      0.41      4370
weighted avg       0.41      0.42      0.41      4370



100%|██████████| 18/18 [00:40<00:00,  2.24s/it]


val Loss: 1.4163 Acc: 0.4275
val Classification Report:
              precision    recall  f1-score   support

           0       0.40      0.25      0.31       109
           1       0.42      0.65      0.51       109
           2       0.33      0.36      0.34       109
           3       0.46      0.21      0.29       109
           4       0.53      0.67      0.59       109

    accuracy                           0.43       545
   macro avg       0.43      0.43      0.41       545
weighted avg       0.43      0.43      0.41       545


Epoch 6/49
----------


100%|██████████| 35/35 [03:12<00:00,  5.49s/it]


train Loss: 1.4152 Acc: 0.4279
train Classification Report:
              precision    recall  f1-score   support

           0       0.41      0.44      0.42       874
           1       0.46      0.55      0.50       874
           2       0.35      0.26      0.30       874
           3       0.38      0.35      0.36       874
           4       0.50      0.55      0.52       874

    accuracy                           0.43      4370
   macro avg       0.42      0.43      0.42      4370
weighted avg       0.42      0.43      0.42      4370



100%|██████████| 18/18 [00:38<00:00,  2.12s/it]


val Loss: 1.4281 Acc: 0.3945
val Classification Report:
              precision    recall  f1-score   support

           0       0.37      0.28      0.32       109
           1       0.50      0.46      0.48       109
           2       0.34      0.17      0.23       109
           3       0.34      0.29      0.32       109
           4       0.39      0.76      0.52       109

    accuracy                           0.39       545
   macro avg       0.39      0.39      0.37       545
weighted avg       0.39      0.39      0.37       545


Epoch 7/49
----------


100%|██████████| 35/35 [03:20<00:00,  5.73s/it]


train Loss: 1.4017 Acc: 0.4366
train Classification Report:
              precision    recall  f1-score   support

           0       0.42      0.44      0.43       874
           1       0.47      0.55      0.51       874
           2       0.36      0.26      0.30       874
           3       0.40      0.36      0.38       874
           4       0.49      0.58      0.53       874

    accuracy                           0.44      4370
   macro avg       0.43      0.44      0.43      4370
weighted avg       0.43      0.44      0.43      4370



100%|██████████| 18/18 [00:39<00:00,  2.17s/it]


val Loss: 1.4107 Acc: 0.4495
val Classification Report:
              precision    recall  f1-score   support

           0       0.42      0.50      0.45       109
           1       0.52      0.55      0.54       109
           2       0.35      0.27      0.30       109
           3       0.37      0.32      0.34       109
           4       0.54      0.61      0.58       109

    accuracy                           0.45       545
   macro avg       0.44      0.45      0.44       545
weighted avg       0.44      0.45      0.44       545


Epoch 8/49
----------


100%|██████████| 35/35 [03:21<00:00,  5.76s/it]


train Loss: 1.3998 Acc: 0.4419
train Classification Report:
              precision    recall  f1-score   support

           0       0.41      0.46      0.43       874
           1       0.49      0.57      0.52       874
           2       0.37      0.27      0.32       874
           3       0.40      0.34      0.37       874
           4       0.51      0.57      0.54       874

    accuracy                           0.44      4370
   macro avg       0.43      0.44      0.44      4370
weighted avg       0.43      0.44      0.44      4370



100%|██████████| 18/18 [00:39<00:00,  2.22s/it]


val Loss: 1.4197 Acc: 0.4422
val Classification Report:
              precision    recall  f1-score   support

           0       0.43      0.29      0.35       109
           1       0.49      0.57      0.53       109
           2       0.30      0.59      0.40       109
           3       0.55      0.19      0.29       109
           4       0.64      0.57      0.60       109

    accuracy                           0.44       545
   macro avg       0.48      0.44      0.43       545
weighted avg       0.48      0.44      0.43       545


Epoch 9/49
----------


100%|██████████| 35/35 [03:23<00:00,  5.81s/it]


train Loss: 1.3916 Acc: 0.4481
train Classification Report:
              precision    recall  f1-score   support

           0       0.43      0.43      0.43       874
           1       0.49      0.58      0.53       874
           2       0.37      0.30      0.33       874
           3       0.41      0.36      0.38       874
           4       0.50      0.57      0.53       874

    accuracy                           0.45      4370
   macro avg       0.44      0.45      0.44      4370
weighted avg       0.44      0.45      0.44      4370



100%|██████████| 18/18 [00:38<00:00,  2.14s/it]


val Loss: 1.4157 Acc: 0.4495
val Classification Report:
              precision    recall  f1-score   support

           0       0.37      0.41      0.39       109
           1       0.59      0.46      0.52       109
           2       0.37      0.30      0.33       109
           3       0.35      0.51      0.41       109
           4       0.71      0.56      0.63       109

    accuracy                           0.45       545
   macro avg       0.48      0.45      0.46       545
weighted avg       0.48      0.45      0.46       545


Epoch 10/49
----------


100%|██████████| 35/35 [03:11<00:00,  5.48s/it]


train Loss: 1.3946 Acc: 0.4458
train Classification Report:
              precision    recall  f1-score   support

           0       0.43      0.43      0.43       874
           1       0.50      0.59      0.54       874
           2       0.37      0.25      0.30       874
           3       0.38      0.38      0.38       874
           4       0.50      0.58      0.53       874

    accuracy                           0.45      4370
   macro avg       0.44      0.45      0.44      4370
weighted avg       0.44      0.45      0.44      4370



100%|██████████| 18/18 [00:38<00:00,  2.12s/it]


val Loss: 1.4080 Acc: 0.4514
val Classification Report:
              precision    recall  f1-score   support

           0       0.34      0.62      0.44       109
           1       0.47      0.65      0.54       109
           2       0.42      0.28      0.34       109
           3       0.51      0.26      0.34       109
           4       0.75      0.44      0.55       109

    accuracy                           0.45       545
   macro avg       0.50      0.45      0.44       545
weighted avg       0.50      0.45      0.44       545


Epoch 11/49
----------


100%|██████████| 35/35 [03:07<00:00,  5.35s/it]


train Loss: 1.3837 Acc: 0.4540
train Classification Report:
              precision    recall  f1-score   support

           0       0.43      0.46      0.44       874
           1       0.50      0.60      0.55       874
           2       0.40      0.32      0.35       874
           3       0.41      0.34      0.37       874
           4       0.50      0.55      0.52       874

    accuracy                           0.45      4370
   macro avg       0.45      0.45      0.45      4370
weighted avg       0.45      0.45      0.45      4370



100%|██████████| 18/18 [00:38<00:00,  2.14s/it]


val Loss: 1.3856 Acc: 0.4532
val Classification Report:
              precision    recall  f1-score   support

           0       0.44      0.33      0.38       109
           1       0.50      0.57      0.53       109
           2       0.39      0.29      0.33       109
           3       0.36      0.47      0.41       109
           4       0.57      0.61      0.59       109

    accuracy                           0.45       545
   macro avg       0.45      0.45      0.45       545
weighted avg       0.45      0.45      0.45       545


Epoch 12/49
----------


100%|██████████| 35/35 [03:08<00:00,  5.38s/it]


train Loss: 1.3766 Acc: 0.4556
train Classification Report:
              precision    recall  f1-score   support

           0       0.45      0.46      0.45       874
           1       0.50      0.58      0.54       874
           2       0.38      0.31      0.34       874
           3       0.39      0.36      0.38       874
           4       0.53      0.57      0.55       874

    accuracy                           0.46      4370
   macro avg       0.45      0.46      0.45      4370
weighted avg       0.45      0.46      0.45      4370



100%|██████████| 18/18 [00:37<00:00,  2.11s/it]


val Loss: 1.4106 Acc: 0.4440
val Classification Report:
              precision    recall  f1-score   support

           0       0.35      0.41      0.38       109
           1       0.47      0.61      0.53       109
           2       0.41      0.28      0.34       109
           3       0.57      0.22      0.32       109
           4       0.48      0.69      0.57       109

    accuracy                           0.44       545
   macro avg       0.46      0.44      0.43       545
weighted avg       0.46      0.44      0.43       545


Epoch 13/49
----------


100%|██████████| 35/35 [03:04<00:00,  5.28s/it]


train Loss: 1.3707 Acc: 0.4568
train Classification Report:
              precision    recall  f1-score   support

           0       0.44      0.48      0.46       874
           1       0.51      0.56      0.53       874
           2       0.39      0.31      0.34       874
           3       0.41      0.35      0.37       874
           4       0.51      0.58      0.54       874

    accuracy                           0.46      4370
   macro avg       0.45      0.46      0.45      4370
weighted avg       0.45      0.46      0.45      4370



100%|██████████| 18/18 [00:38<00:00,  2.16s/it]


val Loss: 1.3889 Acc: 0.4569
val Classification Report:
              precision    recall  f1-score   support

           0       0.35      0.47      0.40       109
           1       0.53      0.48      0.50       109
           2       0.36      0.39      0.38       109
           3       0.47      0.37      0.41       109
           4       0.63      0.58      0.60       109

    accuracy                           0.46       545
   macro avg       0.47      0.46      0.46       545
weighted avg       0.47      0.46      0.46       545


Epoch 14/49
----------


100%|██████████| 35/35 [03:21<00:00,  5.76s/it]


train Loss: 1.3665 Acc: 0.4648
train Classification Report:
              precision    recall  f1-score   support

           0       0.45      0.47      0.46       874
           1       0.50      0.61      0.55       874
           2       0.39      0.29      0.33       874
           3       0.42      0.37      0.39       874
           4       0.52      0.58      0.55       874

    accuracy                           0.46      4370
   macro avg       0.46      0.46      0.46      4370
weighted avg       0.46      0.46      0.46      4370



100%|██████████| 18/18 [00:38<00:00,  2.13s/it]


val Loss: 1.4032 Acc: 0.4477
val Classification Report:
              precision    recall  f1-score   support

           0       0.33      0.38      0.35       109
           1       0.56      0.48      0.51       109
           2       0.36      0.35      0.35       109
           3       0.41      0.48      0.44       109
           4       0.63      0.56      0.59       109

    accuracy                           0.45       545
   macro avg       0.46      0.45      0.45       545
weighted avg       0.46      0.45      0.45       545


Epoch 15/49
----------


100%|██████████| 35/35 [03:06<00:00,  5.32s/it]


train Loss: 1.3604 Acc: 0.4634
train Classification Report:
              precision    recall  f1-score   support

           0       0.44      0.46      0.45       874
           1       0.52      0.60      0.56       874
           2       0.39      0.35      0.37       874
           3       0.41      0.34      0.37       874
           4       0.52      0.56      0.54       874

    accuracy                           0.46      4370
   macro avg       0.46      0.46      0.46      4370
weighted avg       0.46      0.46      0.46      4370



100%|██████████| 18/18 [00:39<00:00,  2.17s/it]


val Loss: 1.3984 Acc: 0.4477
val Classification Report:
              precision    recall  f1-score   support

           0       0.35      0.48      0.40       109
           1       0.47      0.59      0.52       109
           2       0.41      0.22      0.29       109
           3       0.44      0.35      0.39       109
           4       0.57      0.61      0.59       109

    accuracy                           0.45       545
   macro avg       0.45      0.45      0.44       545
weighted avg       0.45      0.45      0.44       545


Epoch 16/49
----------


100%|██████████| 35/35 [03:18<00:00,  5.67s/it]


train Loss: 1.3539 Acc: 0.4730
train Classification Report:
              precision    recall  f1-score   support

           0       0.45      0.47      0.46       874
           1       0.53      0.60      0.56       874
           2       0.40      0.29      0.34       874
           3       0.42      0.39      0.41       874
           4       0.52      0.61      0.56       874

    accuracy                           0.47      4370
   macro avg       0.47      0.47      0.47      4370
weighted avg       0.47      0.47      0.47      4370



100%|██████████| 18/18 [00:41<00:00,  2.32s/it]


val Loss: 1.3723 Acc: 0.4734
val Classification Report:
              precision    recall  f1-score   support

           0       0.39      0.54      0.45       109
           1       0.50      0.68      0.58       109
           2       0.36      0.27      0.31       109
           3       0.54      0.31      0.40       109
           4       0.60      0.57      0.58       109

    accuracy                           0.47       545
   macro avg       0.48      0.47      0.46       545
weighted avg       0.48      0.47      0.46       545


Epoch 17/49
----------


100%|██████████| 35/35 [03:09<00:00,  5.42s/it]


train Loss: 1.3567 Acc: 0.4714
train Classification Report:
              precision    recall  f1-score   support

           0       0.45      0.48      0.47       874
           1       0.51      0.62      0.56       874
           2       0.41      0.31      0.35       874
           3       0.42      0.36      0.39       874
           4       0.53      0.59      0.56       874

    accuracy                           0.47      4370
   macro avg       0.46      0.47      0.46      4370
weighted avg       0.46      0.47      0.46      4370



100%|██████████| 18/18 [00:38<00:00,  2.14s/it]


val Loss: 1.3920 Acc: 0.4587
val Classification Report:
              precision    recall  f1-score   support

           0       0.37      0.54      0.44       109
           1       0.62      0.55      0.59       109
           2       0.36      0.28      0.32       109
           3       0.40      0.40      0.40       109
           4       0.59      0.51      0.55       109

    accuracy                           0.46       545
   macro avg       0.47      0.46      0.46       545
weighted avg       0.47      0.46      0.46       545


Epoch 18/49
----------


100%|██████████| 35/35 [03:17<00:00,  5.65s/it]


train Loss: 1.3494 Acc: 0.4737
train Classification Report:
              precision    recall  f1-score   support

           0       0.47      0.51      0.48       874
           1       0.52      0.62      0.57       874
           2       0.40      0.29      0.33       874
           3       0.40      0.36      0.38       874
           4       0.54      0.59      0.57       874

    accuracy                           0.47      4370
   macro avg       0.46      0.47      0.47      4370
weighted avg       0.46      0.47      0.47      4370



100%|██████████| 18/18 [00:39<00:00,  2.20s/it]


val Loss: 1.4119 Acc: 0.4459
val Classification Report:
              precision    recall  f1-score   support

           0       0.44      0.40      0.42       109
           1       0.53      0.50      0.52       109
           2       0.36      0.42      0.39       109
           3       0.47      0.21      0.29       109
           4       0.45      0.69      0.55       109

    accuracy                           0.45       545
   macro avg       0.45      0.45      0.43       545
weighted avg       0.45      0.45      0.43       545


Epoch 19/49
----------


100%|██████████| 35/35 [03:18<00:00,  5.67s/it]


train Loss: 1.3494 Acc: 0.4725
train Classification Report:
              precision    recall  f1-score   support

           0       0.44      0.47      0.46       874
           1       0.53      0.60      0.56       874
           2       0.39      0.32      0.35       874
           3       0.45      0.36      0.40       874
           4       0.52      0.61      0.56       874

    accuracy                           0.47      4370
   macro avg       0.47      0.47      0.47      4370
weighted avg       0.47      0.47      0.47      4370



100%|██████████| 18/18 [00:40<00:00,  2.26s/it]


val Loss: 1.3886 Acc: 0.4459
val Classification Report:
              precision    recall  f1-score   support

           0       0.35      0.46      0.40       109
           1       0.50      0.55      0.52       109
           2       0.37      0.23      0.28       109
           3       0.39      0.39      0.39       109
           4       0.61      0.61      0.61       109

    accuracy                           0.45       545
   macro avg       0.45      0.45      0.44       545
weighted avg       0.45      0.45      0.44       545


Epoch 20/49
----------


100%|██████████| 35/35 [03:15<00:00,  5.59s/it]


train Loss: 1.3367 Acc: 0.4805
train Classification Report:
              precision    recall  f1-score   support

           0       0.45      0.49      0.47       874
           1       0.52      0.64      0.58       874
           2       0.40      0.29      0.34       874
           3       0.44      0.39      0.41       874
           4       0.55      0.59      0.57       874

    accuracy                           0.48      4370
   macro avg       0.47      0.48      0.47      4370
weighted avg       0.47      0.48      0.47      4370



100%|██████████| 18/18 [00:39<00:00,  2.18s/it]


val Loss: 1.3933 Acc: 0.4899
val Classification Report:
              precision    recall  f1-score   support

           0       0.49      0.29      0.37       109
           1       0.64      0.59      0.61       109
           2       0.36      0.42      0.39       109
           3       0.46      0.48      0.47       109
           4       0.53      0.67      0.59       109

    accuracy                           0.49       545
   macro avg       0.50      0.49      0.49       545
weighted avg       0.50      0.49      0.49       545


Epoch 21/49
----------


100%|██████████| 35/35 [03:23<00:00,  5.80s/it]


train Loss: 1.3437 Acc: 0.4728
train Classification Report:
              precision    recall  f1-score   support

           0       0.46      0.46      0.46       874
           1       0.53      0.61      0.57       874
           2       0.39      0.37      0.38       874
           3       0.41      0.34      0.37       874
           4       0.54      0.58      0.56       874

    accuracy                           0.47      4370
   macro avg       0.47      0.47      0.47      4370
weighted avg       0.47      0.47      0.47      4370



100%|██████████| 18/18 [00:40<00:00,  2.23s/it]


val Loss: 1.3636 Acc: 0.4606
val Classification Report:
              precision    recall  f1-score   support

           0       0.41      0.48      0.44       109
           1       0.47      0.69      0.56       109
           2       0.35      0.25      0.29       109
           3       0.53      0.28      0.37       109
           4       0.53      0.61      0.56       109

    accuracy                           0.46       545
   macro avg       0.46      0.46      0.45       545
weighted avg       0.46      0.46      0.45       545


Epoch 22/49
----------


100%|██████████| 35/35 [03:20<00:00,  5.72s/it]


train Loss: 1.3340 Acc: 0.4785
train Classification Report:
              precision    recall  f1-score   support

           0       0.45      0.50      0.47       874
           1       0.53      0.62      0.57       874
           2       0.42      0.32      0.36       874
           3       0.43      0.37      0.40       874
           4       0.54      0.59      0.56       874

    accuracy                           0.48      4370
   macro avg       0.47      0.48      0.47      4370
weighted avg       0.47      0.48      0.47      4370



100%|██████████| 18/18 [00:39<00:00,  2.21s/it]


val Loss: 1.3782 Acc: 0.4532
val Classification Report:
              precision    recall  f1-score   support

           0       0.35      0.58      0.43       109
           1       0.57      0.58      0.57       109
           2       0.39      0.28      0.33       109
           3       0.44      0.27      0.33       109
           4       0.58      0.56      0.57       109

    accuracy                           0.45       545
   macro avg       0.46      0.45      0.45       545
weighted avg       0.46      0.45      0.45       545


Epoch 23/49
----------


100%|██████████| 35/35 [03:16<00:00,  5.61s/it]


train Loss: 1.3332 Acc: 0.4817
train Classification Report:
              precision    recall  f1-score   support

           0       0.46      0.48      0.47       874
           1       0.52      0.61      0.56       874
           2       0.41      0.35      0.38       874
           3       0.42      0.38      0.40       874
           4       0.55      0.59      0.57       874

    accuracy                           0.48      4370
   macro avg       0.48      0.48      0.48      4370
weighted avg       0.48      0.48      0.48      4370



100%|██████████| 18/18 [00:38<00:00,  2.15s/it]


val Loss: 1.3848 Acc: 0.4679
val Classification Report:
              precision    recall  f1-score   support

           0       0.38      0.41      0.39       109
           1       0.54      0.57      0.55       109
           2       0.43      0.38      0.40       109
           3       0.49      0.36      0.41       109
           4       0.50      0.62      0.56       109

    accuracy                           0.47       545
   macro avg       0.47      0.47      0.46       545
weighted avg       0.47      0.47      0.46       545


Epoch 24/49
----------


100%|██████████| 35/35 [03:21<00:00,  5.75s/it]


train Loss: 1.3312 Acc: 0.4776
train Classification Report:
              precision    recall  f1-score   support

           0       0.46      0.53      0.49       874
           1       0.52      0.59      0.55       874
           2       0.40      0.31      0.35       874
           3       0.43      0.35      0.39       874
           4       0.53      0.62      0.57       874

    accuracy                           0.48      4370
   macro avg       0.47      0.48      0.47      4370
weighted avg       0.47      0.48      0.47      4370



100%|██████████| 18/18 [00:40<00:00,  2.26s/it]


val Loss: 1.3792 Acc: 0.4642
val Classification Report:
              precision    recall  f1-score   support

           0       0.37      0.51      0.43       109
           1       0.59      0.55      0.57       109
           2       0.39      0.25      0.30       109
           3       0.42      0.38      0.40       109
           4       0.54      0.63      0.58       109

    accuracy                           0.46       545
   macro avg       0.46      0.46      0.46       545
weighted avg       0.46      0.46      0.46       545


Epoch 25/49
----------


100%|██████████| 35/35 [03:22<00:00,  5.80s/it]


train Loss: 1.3148 Acc: 0.4968
train Classification Report:
              precision    recall  f1-score   support

           0       0.46      0.49      0.47       874
           1       0.54      0.65      0.59       874
           2       0.45      0.33      0.38       874
           3       0.44      0.41      0.43       874
           4       0.56      0.60      0.58       874

    accuracy                           0.50      4370
   macro avg       0.49      0.50      0.49      4370
weighted avg       0.49      0.50      0.49      4370



100%|██████████| 18/18 [00:39<00:00,  2.20s/it]


val Loss: 1.3701 Acc: 0.4661
val Classification Report:
              precision    recall  f1-score   support

           0       0.37      0.46      0.41       109
           1       0.59      0.47      0.52       109
           2       0.33      0.38      0.35       109
           3       0.55      0.40      0.47       109
           4       0.56      0.62      0.59       109

    accuracy                           0.47       545
   macro avg       0.48      0.47      0.47       545
weighted avg       0.48      0.47      0.47       545


Epoch 26/49
----------


100%|██████████| 35/35 [03:34<00:00,  6.12s/it]


train Loss: 1.3309 Acc: 0.4886
train Classification Report:
              precision    recall  f1-score   support

           0       0.47      0.50      0.48       874
           1       0.54      0.60      0.57       874
           2       0.42      0.33      0.37       874
           3       0.43      0.40      0.41       874
           4       0.55      0.61      0.58       874

    accuracy                           0.49      4370
   macro avg       0.48      0.49      0.48      4370
weighted avg       0.48      0.49      0.48      4370



100%|██████████| 18/18 [00:39<00:00,  2.21s/it]


val Loss: 1.3818 Acc: 0.4697
val Classification Report:
              precision    recall  f1-score   support

           0       0.39      0.49      0.43       109
           1       0.49      0.57      0.53       109
           2       0.37      0.43      0.40       109
           3       0.46      0.34      0.39       109
           4       0.75      0.52      0.62       109

    accuracy                           0.47       545
   macro avg       0.49      0.47      0.47       545
weighted avg       0.49      0.47      0.47       545


Epoch 27/49
----------


 80%|████████  | 28/35 [02:33<00:38,  5.48s/it]


KeyboardInterrupt: 

## Testing

Ok, now we finished training. Lets run the dataset on the test loader and calculate accuracy

In [None]:
test_loss = 0.0
class_correct = list(0 for i in range(len(classes)))
class_total = list(0 for i in range(len(classes)))
model_ft.eval()

for data, target in tqdm(test_loader):
    data, target = data.to(device), target.to(device)
    with torch.no_grad(): # turn off autograd for faster testing
        output = model_ft(data)
        loss = criterion(output, target)
    test_loss = loss.item() * data.size(0)
    _, pred = torch.max(output, 1)
    correct_tensor = pred.eq(target.data.view_as(pred))
    correct = np.squeeze(correct_tensor.cpu().numpy())
    if len(target) == 32:
        for i in range(32):
            label = target.data[i]
            class_correct[label] += correct[i].item()
            class_total[label] += 1

test_loss = test_loss / test_data_len
print('Test Loss: {:.4f}'.format(test_loss))
for i in range(len(classes)):
    if class_total[i] > 0:
        print("Test Accuracy of %5s: %2d%% (%2d/%2d)" % (
            classes[i], 100*class_correct[i]/class_total[i], np.sum(class_correct[i]), np.sum(class_total[i])
        ))
    else:
        print("Test accuracy of %5s: NA" % (classes[i]))
print("Test Accuracy of %2d%% (%2d/%2d)" % (
            100*np.sum(class_correct)/np.sum(class_total), np.sum(class_correct), np.sum(class_total)
        ))

In [None]:
# Save model
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model.cpu(), example)
traced_script_module.save("/Users/zagarsuren/Documents/GitHub/swin-transformer-xray/models/xray_swin_transformer.pt")