In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [12]:
import matplotlib.pyplot as plt

In [18]:
import numpy as np
import pywt


In [19]:
batch_size = 64

# Transform: Normalize and Flatten Images
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1, 28, 28))  # Flatten the image
])

# Load MNIST Dataset
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)



In [20]:
# Step 2: Define a function for DWT feature extraction
def dwt_extract_features(images, wavelet='db1'):
    """
    Extract features using Discrete Wavelet Transform (DWT).
    Args:
        images: A batch of images as a PyTorch tensor of shape [batch_size, 1, 28, 28].
        wavelet: The type of wavelet to use (default is 'db1', Haar wavelet).
    Returns:
        Flattened DWT coefficients as a PyTorch tensor of shape [batch_size, num_features].
    """
    batch_size = images.shape[0]
    features = []
    for i in range(batch_size):
        # Convert PyTorch tensor to NumPy array
        image = images[i, 0].cpu().numpy()

        # Perform 2D Discrete Wavelet Transform
        coeffs = pywt.dwt2(image, wavelet)
        cA, (cH, cV, cD) = coeffs

        # Flatten and concatenate coefficients (cA, cH, cV, cD)
        flattened_coeffs = np.concatenate([cA.flatten(), cH.flatten(), cV.flatten(), cD.flatten()])
        features.append(flattened_coeffs)

    # Convert back to PyTorch tensor
    return torch.tensor(features, dtype=torch.float32)

In [21]:
# Step 3: Extract features for training and testing
train_features = []
test_features = []
train_labels = []
test_labels = []

wavelet = 'db1'  # Haar wavelet

# Extract features for training data
for images, labels in train_loader:
    features = dwt_extract_features(images, wavelet)
    train_features.append(features)
    train_labels.append(labels)

train_features = torch.cat(train_features, dim=0)
train_labels = torch.cat(train_labels, dim=0)

# Extract features for testing data
for images, labels in test_loader:
    features = dwt_extract_features(images, wavelet)
    test_features.append(features)
    test_labels.append(labels)

test_features = torch.cat(test_features, dim=0)
test_labels = torch.cat(test_labels, dim=0)

  return torch.tensor(features, dtype=torch.float32)


In [23]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Step 4: Define the Linear Classification Model
class LinearClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(LinearClassifier, self).__init__()
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.linear(x)

# Initialize the linear classifier
latent_dim = train_features.shape[1]  # Dimension of the latent representation
num_classes = 10  # Number of MNIST classes
classifier = LinearClassifier(latent_dim, num_classes).to(device)


In [24]:
# Step 5: Train the Linear Classifier
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=1e-2)
num_epochs = 25

for epoch in range(num_epochs):
    classifier.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for i in range(0, len(train_features), 64):  # Batch processing
        batch_data = train_features[i:i+64].to(device)  # Encoded vectors
        batch_labels = train_labels[i:i+64].to(device)      # Corresponding labels

        # Forward pass
        outputs = classifier(batch_data)
        loss = criterion(outputs, batch_labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        predicted = torch.max(outputs.data, 1)[1]
        total += len(batch_labels)
        correct += (predicted == batch_labels).sum()

    avg_loss = running_loss / (len(train_features) // 64)
    accuracy = 100 * correct / float(total)

    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy}%")


Epoch [1/25], Loss: 0.3619, Accuracy: 89.65666961669922%
Epoch [2/25], Loss: 0.2991, Accuracy: 91.61333465576172%
Epoch [3/25], Loss: 0.2894, Accuracy: 91.9566650390625%
Epoch [4/25], Loss: 0.2841, Accuracy: 92.14500427246094%
Epoch [5/25], Loss: 0.2808, Accuracy: 92.26499938964844%
Epoch [6/25], Loss: 0.2783, Accuracy: 92.34000396728516%
Epoch [7/25], Loss: 0.2764, Accuracy: 92.41500091552734%
Epoch [8/25], Loss: 0.2749, Accuracy: 92.46666717529297%
Epoch [9/25], Loss: 0.2737, Accuracy: 92.52833557128906%
Epoch [10/25], Loss: 0.2726, Accuracy: 92.56999969482422%
Epoch [11/25], Loss: 0.2717, Accuracy: 92.6066665649414%
Epoch [12/25], Loss: 0.2708, Accuracy: 92.62667083740234%
Epoch [13/25], Loss: 0.2701, Accuracy: 92.65666961669922%
Epoch [14/25], Loss: 0.2695, Accuracy: 92.66333770751953%
Epoch [15/25], Loss: 0.2689, Accuracy: 92.68167114257812%
Epoch [16/25], Loss: 0.2684, Accuracy: 92.68499755859375%
Epoch [17/25], Loss: 0.2680, Accuracy: 92.70166778564453%
Epoch [18/25], Loss: 0.26

In [24]:
# Step 6: Save the trained classifier
# torch.save(classifier.state_dict(), "model/linear_classifier.pth")

In [25]:

# Step 7: Evaluate the Classifier
classifier.eval()
correct = 0
total = 0

with torch.no_grad():
    for i in range(0, len(test_features), 64):
        batch_data = test_features[i:i+64].to(device)
        batch_labels = test_labels[i:i+64].to(device)

        outputs = classifier(batch_data)
        _, predicted = torch.max(outputs, 1)
        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()

accuracy = 100 * correct / total
print(f"Accuracy of the Linear Classifier: {accuracy:.2f}%")

Accuracy of the Linear Classifier: 91.54%
