# Train Vision Transformer for a single audio time slice

We make raw code blocks for more user functionality. Convert these to code blocks, thereby making custom models outside the scope of the command line interface.

### Imports

In [1]:
import os
import joblib

# array data manipulation and plotting
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms
from transformers import ViTConfig, ViTForImageClassification, ViTFeatureExtractor

# machine learning
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, \
    accuracy_score, \
    f1_score, \
    auc, \
    recall_score, \
    precision_score, \
    precision_recall_curve, \
    roc_curve
from sklearn.metrics import confusion_matrix

2025-11-13 20:40:53.814038: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-13 20:40:54.223720: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [9]:
# Example for single-channel image
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),  # Resize to pretrained ViT size
    transforms.ToTensor(),          # Converts to [0,1]
])

In [51]:
batch_size = 2**4
X_train_path = os.path.join('data', 'X_train.npy')
X_test_path = os.path.join('data', 'X_test.npy')
Y_train_path = os.path.join('data', 'Y_train.npy')
Y_test_path = os.path.join('data', 'Y_test.npy')

In [52]:
# --- LOAD DATA ---
X_train = np.load(X_train_path)
X_test = np.load(X_test_path)
Y_train = np.load(Y_train_path)
Y_test = np.load(Y_test_path)
num_classes = len(np.unique(Y_train))

# --- MAKE TENSORS ---
Y_train = torch.tensor(Y_train, dtype=torch.long)
X_train = torch.tensor(X_train, dtype=torch.float32)
Y_test = torch.tensor(Y_test, dtype=torch.long)
X_test = torch.tensor(X_test, dtype=torch.float32)

# Apply to a single image (H, W) or (C, H, W)
X_test = torch.stack([transform(img) for img in X_test])  # batch dimension
X_train = torch.stack([transform(img) for img in X_train])  # batch dimension

# --- DATA LOADER ---
train_data = TensorDataset(X_train, Y_train)
test_data = TensorDataset(X_test, Y_test)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# X_height = X_train.shape[1]
# X_width = X_train.shape[2]

del X_train, Y_train, X_test, Y_test

In [53]:
# architecture
hidden_size = 6**3
num_hidden_layers = 6
num_attention_heads = 6
intermediate_size = 1024
hidden_dropout_prob = 0.1
attention_probs_dropout_prob = 0.1

### Set up model

In [54]:
config = ViTConfig(
    patch_size=32,
    num_channels=1,  # Adjust for your input data (e.g., 1 for grayscale, 3 for RGB)
    hidden_size=hidden_size,
    num_hidden_layers=num_hidden_layers,
    num_attention_heads=num_attention_heads,
    intermediate_size=intermediate_size,
    hidden_dropout_prob=hidden_dropout_prob,
    attention_probs_dropout_prob=attention_probs_dropout_prob,
    num_labels=num_classes,
)

model = ViTForImageClassification(config)

# criterion choice
criterion = torch.nn.CrossEntropyLoss()

# choice of optimizer
optimizer = optim.AdamW(model.parameters(), lr=0.001)   

model

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(1, 216, kernel_size=(32, 32), stride=(32, 32))
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-5): 6 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=216, out_features=216, bias=True)
              (key): Linear(in_features=216, out_features=216, bias=True)
              (value): Linear(in_features=216, out_features=216, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=216, out_features=216, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=216, out_features=1024, bias=True)
            (intermedia

### Train the model

In [56]:
# Training loop

model.to(device)

num_epochs = 100
for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    running_loss = 0.0

    for inputs, labels in train_loader:

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Pass to GPU
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(inputs)

        # Compute the loss
        loss = criterion(outputs.logits, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Accumulate the loss
        running_loss += loss.item()

    # Print the average loss for this epoch
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {running_loss / len(train_loader):.4f}")

Epoch 1/100, Loss: 2.2452
Epoch 2/100, Loss: 2.2438
Epoch 3/100, Loss: 2.2509


KeyboardInterrupt: 

In [49]:
# Evaluate the model's accuracy on the validation data
model.eval()  # Set the model to evaluation mode
with torch.no_grad():  # Disable gradient computation for evaluation
    for inputs, labels in train_loader:

        inputs = inputs.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(inputs)

        # Get the predicted class (index of the maximum value in the output)
        # For multi-class (one-hot) targets
        predicted = outputs.logits.argmax(dim=1)
        labels = labels.int()
        predicted = predicted.cpu()
        labels = labels.cpu()
        print(accuracy_score(predicted,labels))

0.21484375
0.16796875
0.19140625
0.18359375
0.2265625
0.21484375
0.24609375
0.21484375
0.2578125
0.19140625
0.16796875
0.171875
0.18359375
0.1953125
0.1875
0.19921875
0.203125
0.24609375
0.16796875
0.21875
0.16015625
0.21484375
0.18359375
0.234375
0.15625
0.1640625
0.2109375
0.20703125
0.1796875
0.21875
0.23828125
0.2109375
0.20703125
0.1953125
0.234375
0.1484375
0.22265625
0.2265625
0.22265625
0.13671875
0.1640625
0.1484375
0.171875
0.1640625
0.1796875
0.18359375
0.21875
0.17578125
0.20703125
0.234375
0.21484375
0.19140625
0.16796875
0.1640625
0.16796875
0.1953125
0.2109375
0.17578125
0.15625
0.203125
0.16796875
0.1953125
0.15625
0.1953125
0.14453125
0.16796875
0.2109375
0.1875
0.1796875
0.234375
0.17578125
0.2265625
0.18359375
0.171875
0.23828125
0.18359375
0.20703125
0.1640625
0.1875
0.1796875
0.265625
0.15625
0.20703125
0.23046875
0.19921875
0.25
0.16796875
0.1875
0.171875
0.2109375
0.21484375
0.23046875
0.2109375
0.18359375
0.2109375
0.1640625
0.18359375
0.1640625
0.1640625
0.1796

In [50]:
# Evaluate the model's accuracy on the validation data
model.eval()  # Set the model to evaluation mode
with torch.no_grad():  # Disable gradient computation for evaluation
    for inputs, labels in test_loader:

        inputs = inputs.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(inputs)

        # Get the predicted class (index of the maximum value in the output)
        # For multi-class (one-hot) targets
        predicted = outputs.logits.argmax(dim=1)
        labels = labels.int()
        predicted = predicted.cpu()
        labels = labels.cpu()
        print(accuracy_score(predicted,labels))

0.05859375
0.03515625
0.0
0.58984375
0.484375
0.453125
0.22265625
0.0
0.0
0.1015625
0.08984375
0.03515625
0.0
0.0
