# Building a Shifted Window Transformer for Multi-class Classification of Breast Cancer Histopathological Dataset(BreaKHis Dataset)



*   In this notebook, we've built a Swin Transformer, whose parameters are as close to that described in the original Shifted Window Transformers paper titled, "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows".
*   Since the dataset can be used for both binary and multi-class classification, for this project, we've narrowed it down to just multi-class classification.
*   The dataset also captures the tumours at multiple zoom levels(40x, 100x, 200x and 400x), to remove unnecessary complexity, we've just focused on the 400x zoom images.

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### Importing libraries and setting up device agnostic code

In [2]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

device = "cuda" if torch.cuda.is_available() else "cpu"

### Setting up transforms and dataloader

In [3]:
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import v2
import torch.nn.functional as F
from torch.nn import init

FLIP_PROBABILITY = 0.1

data_transform = v2.Compose([
    v2.Resize(size=(224,224)),
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomRotation(degrees=(-10,10)),
    v2.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
    v2.RandomAffine(degrees=0, translate=(0.05, 0.05)),
    v2.ToTensor(),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Add normalization
])



### Getting dataset and class labels

In [4]:
train_data = datasets.ImageFolder(root="/content/drive/MyDrive/dataset_v3/train",
                                  transform=data_transform,
                                  target_transform=None)
test_data = datasets.ImageFolder(root="/content/drive/MyDrive/dataset_v3/test",
                                  transform=data_transform)
class_names = train_data.classes

### Preprocessing the dataset

The dataset is imbalanced, as the training and test examples for 'ductal_carcinoma' type of tumour are ~5x that of other cases.

So, we've randomly reduced the samples to come on line with the numbers from other cases so it doesn't imbalance the model.

In [5]:
from sklearn.utils import resample

def balance_dataset(dataset, size):
    # Separate data by class
    class_data = {i: [] for i in range(8)}  # For 8 classes

    # Group data by class
    for idx in range(len(dataset)):
        _, label = dataset[idx]
        class_data[label].append(idx)

    # Randomly sample indices from each class
    balanced_indices = []
    for class_idx, indices in class_data.items():
        # If class has more samples than min_size, downsample it
        if len(indices) > size:
            balanced_indices.extend(np.random.choice(indices, size=size, replace=False))
        else:
            balanced_indices.extend(indices)

    # Create a subset dataset
    from torch.utils.data import Subset
    balanced_dataset = Subset(dataset, balanced_indices)

    return balanced_dataset

# Use it like this:
balanced_train_data = balance_dataset(train_data, 200)

balanced_test_data = balance_dataset(test_data, 50)

### Loading up the DataLoader

In [6]:
from torch.utils.data import DataLoader

BATCH_SIZE = 16
train_dataloader = DataLoader(balanced_train_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(balanced_test_data, batch_size=BATCH_SIZE, shuffle=False)

## Implementing the Swin Transformer

### Defining the Swin Transformer

In [7]:
import torch.nn.functional as F
from torch.nn import init

class PatchEmbed(nn.Module):
    """Split image into patches and then embed them."""
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96):
        super().__init__()
        self.img_size = (img_size, img_size)
        self.patch_size = (patch_size, patch_size)
        self.patches_resolution = [img_size // patch_size, img_size // patch_size]
        self.num_patches = (img_size // patch_size) ** 2

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        x = self.norm(x)
        return x

class WindowAttention(nn.Module):
    """Window based multi-head self attention."""
    def __init__(self, dim, window_size, num_heads, qkv_bias=True):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))

        # Get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += window_size[0] - 1
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)

        init.trunc_normal_(self.relative_position_bias_table, std=.02)

    def forward(self, x, mask=None):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)

        attn = F.softmax(attn, dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        return x

class SwinTransformerBlock(nn.Module):
    """Swin Transformer Block."""
    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(
            dim, window_size=(window_size, window_size), num_heads=num_heads, qkv_bias=qkv_bias)

        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, dim)
        )

    def forward(self, x, H, W):
        B, L, C = x.shape
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + x
        x = x + self.mlp(self.norm2(x))

        return x

class SwinTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0.):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_rate, sum(depths))]

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                attn_drop=attn_drop_rate,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None
            )
            self.layers.append(layer)

        self.norm = nn.LayerNorm(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        x = self.avgpool(x.transpose(1, 2))
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

class PatchMerging(nn.Module):
    """Patch Merging Layer."""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = nn.LayerNorm(4 * dim)

    def forward(self, x, H, W):
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        x = x.view(B, H, W, C)

        # padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage."""
    def __init__(self, dim, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 downsample=None):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.window_size = window_size
        self.shift_size = window_size // 2

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else self.shift_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop[i] if isinstance(drop, list) else drop,
                attn_drop=attn_drop)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim)
        else:
            self.downsample = None

    def forward(self, x):
        B, L, C = x.shape
        H = W = int(L ** 0.5)

        for blk in self.blocks:
            x = blk(x, H, W)

        if self.downsample is not None:
            x = self.downsample(x, H, W)
            H, W = H // 2, W // 2

        return x

## Running the SwinT

In [8]:
import torch
import torch.nn as nn
from tqdm import tqdm
from torchvision import datasets, transforms
import os

# Create the Vision Transformer model
model = SwinTransformer(
    img_size=224,
    patch_size=4,
    in_chans=3,
    num_classes=len(class_names),
    embed_dim=96,
    depths=[2, 2, 6, 2],
    num_heads=[3, 6, 12, 24],
    window_size=7,
    mlp_ratio=4,
    qkv_bias=True,
    drop_rate=0.1,
    attn_drop_rate=0.1
).to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.05)
# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)

# Learning rate scheduler
steps_per_epoch = len(train_dataloader)
total_steps = steps_per_epoch * 100  # 100 epochs

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    progress_bar = tqdm(dataloader, desc='Training')

    for batch_idx, (images, labels) in enumerate(progress_bar):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        # Update progress bar
        progress_bar.set_postfix({
            'loss': f'{total_loss/(batch_idx+1):.4f}',
            'acc': f'{100.*correct/total:.2f}%'
        })

    return total_loss / len(dataloader), 100. * correct / total

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    progress_bar = tqdm(dataloader, desc='Testing')

    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(progress_bar):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Update progress bar
            progress_bar.set_postfix({
                'loss': f'{total_loss/(batch_idx+1):.4f}',
                'acc': f'{100.*correct/total:.2f}%'
            })

    return total_loss / len(dataloader), 100. * correct / total

# Training loop
num_epochs = 100
best_acc = 0
best_model_state = None
best_results = None

print("Starting training...")
for epoch in range(num_epochs):
    print(f'\nEpoch: {epoch+1}/{num_epochs}')

    # Train
    train_loss, train_acc = train_one_epoch(model, train_dataloader, criterion, optimizer, device)

    # Evaluate
    test_loss, test_acc = evaluate(model, test_dataloader, criterion, device)
    if epoch % 10 == 0:
      print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
      print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')

    scheduler.step()

    # Save best model
    if test_acc > best_acc:
        best_acc = test_acc
        best_results = (test_loss, test_acc)
        best_model_state = model.state_dict().copy()

        # Save the checkpoint
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': train_loss,
            'train_acc': train_acc,
            'test_loss': test_loss,
            'test_acc': test_acc,
        }
        torch.save(checkpoint, 'best_swin_model.pth')
        print(f'New best model saved! Accuracy: {best_acc:.2f}%')

# Load the best model before finishing
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print('\nLoaded best model')
    print(f'Best model performance:')
    print(f'Test Loss: {best_results[0]:.4f}, Test Acc: {best_results[1]:.2f}%')

print('\nTraining completed!')

Starting training...

Epoch: 1/100


Training: 100%|██████████| 64/64 [00:26<00:00,  2.37it/s, loss=2.1674, acc=22.78%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.18it/s, loss=1.9183, acc=31.27%]


Train Loss: 2.1674, Train Acc: 22.78%
Test Loss: 1.9183, Test Acc: 31.27%
New best model saved! Accuracy: 31.27%

Epoch: 2/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.54it/s, loss=1.8953, acc=30.89%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.14it/s, loss=1.8927, acc=33.20%]


New best model saved! Accuracy: 33.20%

Epoch: 3/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.55it/s, loss=1.9168, acc=30.11%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.10it/s, loss=1.9956, acc=17.37%]



Epoch: 4/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.54it/s, loss=1.9653, acc=28.05%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.77it/s, loss=1.8901, acc=31.27%]



Epoch: 5/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.54it/s, loss=1.8041, acc=33.33%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.51it/s, loss=1.8025, acc=33.20%]



Epoch: 6/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.52it/s, loss=1.8459, acc=32.65%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.69it/s, loss=1.8125, acc=31.66%]



Epoch: 7/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.49it/s, loss=1.8628, acc=32.36%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.16it/s, loss=1.8023, acc=33.98%]


New best model saved! Accuracy: 33.98%

Epoch: 8/100


Training: 100%|██████████| 64/64 [00:26<00:00,  2.41it/s, loss=1.7306, acc=39.30%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.08it/s, loss=1.7452, acc=36.68%]


New best model saved! Accuracy: 36.68%

Epoch: 9/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.47it/s, loss=1.7699, acc=35.29%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.90it/s, loss=1.7971, acc=34.36%]



Epoch: 10/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.49it/s, loss=1.7850, acc=34.60%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.48it/s, loss=1.7797, acc=34.75%]



Epoch: 11/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.50it/s, loss=1.6956, acc=38.81%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.67it/s, loss=1.7403, acc=39.00%]


Train Loss: 1.6956, Train Acc: 38.81%
Test Loss: 1.7403, Test Acc: 39.00%
New best model saved! Accuracy: 39.00%

Epoch: 12/100


Training: 100%|██████████| 64/64 [00:26<00:00,  2.45it/s, loss=1.6982, acc=39.78%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.18it/s, loss=1.8829, acc=33.98%]



Epoch: 13/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.47it/s, loss=1.7528, acc=35.48%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.16it/s, loss=1.7973, acc=34.75%]



Epoch: 14/100


Training: 100%|██████████| 64/64 [00:26<00:00,  2.45it/s, loss=1.6192, acc=43.21%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.13it/s, loss=1.6735, acc=42.47%]


New best model saved! Accuracy: 42.47%

Epoch: 15/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.46it/s, loss=1.5832, acc=44.57%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.44it/s, loss=1.7467, acc=35.14%]



Epoch: 16/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.50it/s, loss=1.6613, acc=40.76%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.70it/s, loss=1.7235, acc=40.54%]



Epoch: 17/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.46it/s, loss=1.5169, acc=48.09%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.12it/s, loss=1.5681, acc=46.33%]


New best model saved! Accuracy: 46.33%

Epoch: 18/100


Training: 100%|██████████| 64/64 [00:26<00:00,  2.46it/s, loss=1.4953, acc=49.95%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.14it/s, loss=1.6245, acc=43.63%]



Epoch: 19/100


Training: 100%|██████████| 64/64 [00:26<00:00,  2.46it/s, loss=1.5195, acc=47.90%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.09it/s, loss=1.5697, acc=48.65%]


New best model saved! Accuracy: 48.65%

Epoch: 20/100


Training: 100%|██████████| 64/64 [00:26<00:00,  2.42it/s, loss=1.4318, acc=52.20%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.43it/s, loss=1.4948, acc=51.35%]


New best model saved! Accuracy: 51.35%

Epoch: 21/100


Training: 100%|██████████| 64/64 [00:26<00:00,  2.39it/s, loss=1.3861, acc=56.30%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.11it/s, loss=1.5456, acc=50.19%]


Train Loss: 1.3861, Train Acc: 56.30%
Test Loss: 1.5456, Test Acc: 50.19%

Epoch: 22/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.47it/s, loss=1.4844, acc=52.30%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.09it/s, loss=1.6774, acc=39.38%]



Epoch: 23/100


Training: 100%|██████████| 64/64 [00:26<00:00,  2.44it/s, loss=1.3858, acc=56.30%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.18it/s, loss=1.4809, acc=52.90%]


New best model saved! Accuracy: 52.90%

Epoch: 24/100


Training: 100%|██████████| 64/64 [00:26<00:00,  2.45it/s, loss=1.3114, acc=59.04%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.71it/s, loss=1.5785, acc=50.58%]



Epoch: 25/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.49it/s, loss=1.4469, acc=52.98%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.44it/s, loss=1.4912, acc=51.35%]



Epoch: 26/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.49it/s, loss=1.3270, acc=59.14%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.91it/s, loss=1.4147, acc=59.46%]


New best model saved! Accuracy: 59.46%

Epoch: 27/100


Training: 100%|██████████| 64/64 [00:26<00:00,  2.40it/s, loss=1.2310, acc=62.95%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.16it/s, loss=1.5003, acc=50.19%]



Epoch: 28/100


Training: 100%|██████████| 64/64 [00:26<00:00,  2.46it/s, loss=1.4008, acc=56.21%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.15it/s, loss=1.4373, acc=52.51%]



Epoch: 29/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.47it/s, loss=1.2580, acc=60.61%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.82it/s, loss=1.3663, acc=55.98%]



Epoch: 30/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.50it/s, loss=1.1813, acc=65.88%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.43it/s, loss=1.3463, acc=60.62%]


New best model saved! Accuracy: 60.62%

Epoch: 31/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.48it/s, loss=1.3019, acc=60.22%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.04it/s, loss=1.4066, acc=53.28%]


Train Loss: 1.3019, Train Acc: 60.22%
Test Loss: 1.4066, Test Acc: 53.28%

Epoch: 32/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.46it/s, loss=1.2327, acc=64.52%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.14it/s, loss=1.3029, acc=59.07%]



Epoch: 33/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.47it/s, loss=1.1247, acc=69.40%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.16it/s, loss=1.3148, acc=63.32%]


New best model saved! Accuracy: 63.32%

Epoch: 34/100


Training: 100%|██████████| 64/64 [00:26<00:00,  2.40it/s, loss=1.2068, acc=66.67%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.43it/s, loss=1.3953, acc=58.69%]



Epoch: 35/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.50it/s, loss=1.2316, acc=63.64%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.79it/s, loss=1.3012, acc=58.30%]



Epoch: 36/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.48it/s, loss=1.0787, acc=72.24%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.15it/s, loss=1.2443, acc=62.16%]



Epoch: 37/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.49it/s, loss=1.2189, acc=64.81%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.17it/s, loss=1.7328, acc=48.26%]



Epoch: 38/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.49it/s, loss=1.2687, acc=59.63%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.97it/s, loss=1.2465, acc=62.55%]



Epoch: 39/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.49it/s, loss=1.0310, acc=72.83%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.51it/s, loss=1.2098, acc=64.09%]


New best model saved! Accuracy: 64.09%

Epoch: 40/100


Training: 100%|██████████| 64/64 [00:26<00:00,  2.40it/s, loss=1.1502, acc=67.16%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.05it/s, loss=1.4047, acc=57.53%]



Epoch: 41/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.47it/s, loss=1.2235, acc=62.37%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.09it/s, loss=1.2891, acc=65.64%]


Train Loss: 1.2235, Train Acc: 62.37%
Test Loss: 1.2891, Test Acc: 65.64%
New best model saved! Accuracy: 65.64%

Epoch: 42/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.47it/s, loss=1.0177, acc=74.49%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.98it/s, loss=1.1707, acc=69.11%]


New best model saved! Accuracy: 69.11%

Epoch: 43/100


Training: 100%|██████████| 64/64 [00:26<00:00,  2.45it/s, loss=1.1213, acc=70.19%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.68it/s, loss=1.3316, acc=60.23%]



Epoch: 44/100


Training: 100%|██████████| 64/64 [00:26<00:00,  2.46it/s, loss=1.1705, acc=66.37%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.12it/s, loss=1.1901, acc=66.41%]



Epoch: 45/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.49it/s, loss=0.9725, acc=76.83%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.15it/s, loss=1.1259, acc=66.80%]



Epoch: 46/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.48it/s, loss=1.0204, acc=74.49%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.02it/s, loss=1.1860, acc=66.41%]



Epoch: 47/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.49it/s, loss=1.1249, acc=70.58%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.52it/s, loss=1.2983, acc=57.14%]



Epoch: 48/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.50it/s, loss=0.9599, acc=77.81%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.80it/s, loss=1.0861, acc=73.36%]


New best model saved! Accuracy: 73.36%

Epoch: 49/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.46it/s, loss=0.9929, acc=74.29%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.13it/s, loss=1.2342, acc=61.00%]



Epoch: 50/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.48it/s, loss=1.0561, acc=71.95%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.20it/s, loss=1.2405, acc=62.93%]



Epoch: 51/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.47it/s, loss=0.9190, acc=77.91%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.65it/s, loss=1.1146, acc=70.66%]


Train Loss: 0.9190, Train Acc: 77.91%
Test Loss: 1.1146, Test Acc: 70.66%

Epoch: 52/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.49it/s, loss=0.9329, acc=78.10%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.66it/s, loss=1.3927, acc=57.53%]



Epoch: 53/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.47it/s, loss=1.1226, acc=70.19%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.19it/s, loss=1.2566, acc=62.55%]



Epoch: 54/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.47it/s, loss=0.9219, acc=78.20%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.10it/s, loss=1.0610, acc=74.52%]


New best model saved! Accuracy: 74.52%

Epoch: 55/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.48it/s, loss=0.9004, acc=80.25%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.89it/s, loss=1.1163, acc=65.64%]



Epoch: 56/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.49it/s, loss=1.0641, acc=71.46%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.49it/s, loss=1.2065, acc=65.64%]



Epoch: 57/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.47it/s, loss=0.8686, acc=82.31%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.12it/s, loss=1.0194, acc=74.13%]



Epoch: 58/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.48it/s, loss=0.8629, acc=82.40%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.22it/s, loss=1.2245, acc=61.78%]



Epoch: 59/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.48it/s, loss=1.0906, acc=71.36%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.15it/s, loss=1.2022, acc=64.86%]



Epoch: 60/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.47it/s, loss=0.8696, acc=81.33%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.60it/s, loss=1.0513, acc=70.66%]



Epoch: 61/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.51it/s, loss=0.8218, acc=84.56%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.58it/s, loss=1.1986, acc=65.64%]


Train Loss: 0.8218, Train Acc: 84.56%
Test Loss: 1.1986, Test Acc: 65.64%

Epoch: 62/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.49it/s, loss=0.9497, acc=77.52%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.14it/s, loss=1.1482, acc=65.25%]



Epoch: 63/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.48it/s, loss=0.8702, acc=81.43%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.07it/s, loss=1.0018, acc=74.90%]


New best model saved! Accuracy: 74.90%

Epoch: 64/100


Training: 100%|██████████| 64/64 [00:26<00:00,  2.44it/s, loss=0.8165, acc=83.97%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.90it/s, loss=1.0395, acc=72.97%]



Epoch: 65/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.47it/s, loss=1.0060, acc=75.17%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.48it/s, loss=1.1822, acc=71.04%]



Epoch: 66/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.48it/s, loss=0.9002, acc=78.20%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.99it/s, loss=0.9810, acc=74.52%]



Epoch: 67/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.48it/s, loss=0.7816, acc=86.22%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.11it/s, loss=0.9964, acc=72.97%]



Epoch: 68/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.47it/s, loss=0.9705, acc=76.74%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.13it/s, loss=1.1588, acc=67.95%]



Epoch: 69/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.47it/s, loss=0.8873, acc=79.77%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.76it/s, loss=1.0165, acc=72.97%]



Epoch: 70/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.49it/s, loss=0.7602, acc=87.39%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.54it/s, loss=1.0344, acc=74.52%]



Epoch: 71/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.48it/s, loss=0.9369, acc=77.71%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.18it/s, loss=1.2578, acc=63.71%]


Train Loss: 0.9369, Train Acc: 77.71%
Test Loss: 1.2578, Test Acc: 63.71%

Epoch: 72/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.47it/s, loss=0.8860, acc=79.77%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.13it/s, loss=1.0014, acc=74.13%]



Epoch: 73/100


Training: 100%|██████████| 64/64 [00:26<00:00,  2.46it/s, loss=0.7448, acc=87.29%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.16it/s, loss=1.0162, acc=73.36%]



Epoch: 74/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.49it/s, loss=0.8430, acc=83.77%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.62it/s, loss=1.3218, acc=60.62%]



Epoch: 75/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.49it/s, loss=0.9054, acc=78.59%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.67it/s, loss=1.0608, acc=67.95%]



Epoch: 76/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.49it/s, loss=0.7476, acc=87.39%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.16it/s, loss=0.9721, acc=77.61%]


New best model saved! Accuracy: 77.61%

Epoch: 77/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.47it/s, loss=0.8350, acc=83.77%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.14it/s, loss=1.2777, acc=64.09%]



Epoch: 78/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.47it/s, loss=0.8692, acc=81.43%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.94it/s, loss=0.9986, acc=76.06%]



Epoch: 79/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.48it/s, loss=0.7313, acc=87.49%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.47it/s, loss=0.9995, acc=72.20%]



Epoch: 80/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.50it/s, loss=0.8279, acc=82.40%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.99it/s, loss=1.2654, acc=59.85%]



Epoch: 81/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.49it/s, loss=0.8674, acc=81.62%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.19it/s, loss=0.9726, acc=76.45%]


Train Loss: 0.8674, Train Acc: 81.62%
Test Loss: 0.9726, Test Acc: 76.45%

Epoch: 82/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.47it/s, loss=0.7381, acc=87.39%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.15it/s, loss=0.9770, acc=76.45%]



Epoch: 83/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.48it/s, loss=0.7857, acc=85.34%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.87it/s, loss=1.0694, acc=70.27%]



Epoch: 84/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.50it/s, loss=0.8553, acc=81.92%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.50it/s, loss=0.9944, acc=76.06%]



Epoch: 85/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.49it/s, loss=0.6894, acc=91.01%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.93it/s, loss=0.9612, acc=74.52%]



Epoch: 86/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.48it/s, loss=0.7453, acc=86.80%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.12it/s, loss=1.0944, acc=74.13%]



Epoch: 87/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.48it/s, loss=0.8329, acc=82.21%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.22it/s, loss=1.1158, acc=69.50%]



Epoch: 88/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.49it/s, loss=0.7013, acc=90.32%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.97it/s, loss=0.9282, acc=78.38%]


New best model saved! Accuracy: 78.38%

Epoch: 89/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.51it/s, loss=0.7523, acc=86.02%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.50it/s, loss=1.2677, acc=66.02%]



Epoch: 90/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.48it/s, loss=0.7893, acc=85.34%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.16it/s, loss=1.0319, acc=76.45%]



Epoch: 91/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.48it/s, loss=0.6931, acc=91.50%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.13it/s, loss=0.9603, acc=78.76%]


Train Loss: 0.6931, Train Acc: 91.50%
Test Loss: 0.9603, Test Acc: 78.76%
New best model saved! Accuracy: 78.76%

Epoch: 92/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.47it/s, loss=0.6697, acc=91.50%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.16it/s, loss=1.1195, acc=70.27%]



Epoch: 93/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.48it/s, loss=0.8575, acc=81.52%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.55it/s, loss=1.0276, acc=72.97%]



Epoch: 94/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.50it/s, loss=0.6860, acc=90.32%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.76it/s, loss=0.9270, acc=80.31%]


New best model saved! Accuracy: 80.31%

Epoch: 95/100


Training: 100%|██████████| 64/64 [00:26<00:00,  2.39it/s, loss=0.6759, acc=91.20%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.09it/s, loss=1.1726, acc=64.86%]



Epoch: 96/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.47it/s, loss=0.8639, acc=82.40%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.20it/s, loss=1.0879, acc=69.11%]



Epoch: 97/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.47it/s, loss=0.7230, acc=88.17%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.65it/s, loss=0.9788, acc=77.61%]



Epoch: 98/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.48it/s, loss=0.6722, acc=91.50%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.66it/s, loss=0.9749, acc=76.45%]



Epoch: 99/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.49it/s, loss=0.8275, acc=83.77%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.15it/s, loss=1.0887, acc=74.13%]



Epoch: 100/100


Training: 100%|██████████| 64/64 [00:25<00:00,  2.47it/s, loss=0.6831, acc=90.42%]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.16it/s, loss=0.9508, acc=78.38%]


Loaded best model
Best model performance:
Test Loss: 0.9270, Test Acc: 80.31%

Training completed!



