In [None]:
# Import necessary libraries
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import torch
from vit_pytorch.efficient import ViT
from linformer import Linformer
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR


In [3]:
# Define a custom dataset class to load the dataset with some necessary preprocess 
class GTZANDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = os.listdir(root_dir)
        self.files = [os.path.join(root, file) for root, _, files in os.walk(root_dir) for file in files]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # Load image and label for given idx
        img_path = self.files[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.classes.index(os.path.basename(os.path.dirname(img_path)))
        #image = image.rotate(90)       # Change the sequence for comparison (Mentioned in paper)
        # Apply transformation (preprocess)
        if self.transform:
            image = self.transform(image)
        return image, label

# Define data transforms
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [2]:
# Set root directory for dataset
root = "path/of/dataset"

In [4]:
# Load and create dataset
dataset = GTZANDataset(root, transform=data_transforms)

In [5]:
# Split into train and test
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [6]:
# Set gpu for training
device = "cuda"

In [7]:
# Define a pre-implemented transformer
efficient_transformer = Linformer(
    dim=128,
    seq_len=49+1,  # 7x7 patches + 1 cls-token
    depth=12,
    heads=8,
    k=64
)

# Create a pre-implemented vision transformer model using the transformer above
model = ViT(
    dim=128,
    image_size=224,
    patch_size=32,
    num_classes=10,
    transformer=efficient_transformer,
    channels=3,
).to(device)

In [9]:
# Set learning rate and num of epochs 
lr = 3e-5
epochs = 20

In [10]:
# Define loss function, optimizer and scheduler

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

In [12]:
# Training loop
for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0

    # Iterate over batches in the training set
    for data, label in train_loader:
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    # Validation loop
    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in test_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(test_loader)
            epoch_val_loss += val_loss / len(test_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

Epoch : 1 - loss : 2.3428 - acc: 0.0983 - val_loss : 2.2958 - val_acc: 0.0986

Epoch : 2 - loss : 2.2716 - acc: 0.1646 - val_loss : 2.2678 - val_acc: 0.1466

Epoch : 3 - loss : 2.1799 - acc: 0.2247 - val_loss : 2.1107 - val_acc: 0.2428

Epoch : 4 - loss : 1.9654 - acc: 0.2995 - val_loss : 1.9152 - val_acc: 0.3077

Epoch : 5 - loss : 1.7482 - acc: 0.3851 - val_loss : 1.6578 - val_acc: 0.3774

Epoch : 6 - loss : 1.5509 - acc: 0.4561 - val_loss : 1.5092 - val_acc: 0.4543

Epoch : 7 - loss : 1.3226 - acc: 0.5392 - val_loss : 1.5152 - val_acc: 0.4784

Epoch : 8 - loss : 1.1793 - acc: 0.5857 - val_loss : 1.3772 - val_acc: 0.5024

Epoch : 9 - loss : 1.0617 - acc: 0.6368 - val_loss : 1.3129 - val_acc: 0.5457

Epoch : 10 - loss : 0.9136 - acc: 0.6940 - val_loss : 1.2570 - val_acc: 0.5577

Epoch : 11 - loss : 0.8282 - acc: 0.7272 - val_loss : 1.2983 - val_acc: 0.5481

Epoch : 12 - loss : 0.7070 - acc: 0.7728 - val_loss : 1.2747 - val_acc: 0.5577

Epoch : 13 - loss : 0.5427 - acc: 0.8484 - val_lo

KeyboardInterrupt: 