In [220]:
# import libraries
import torchvision
import torch
import cv2
import numpy as np
import torchvision.transforms as transforms
from tqdm import tqdm
import torch.nn as nn

In [221]:
# Select the compute device. Precedence: Apple MPS, then the first CUDA GPU,
# then the CPU as the fallback.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

print(device)

mps


In [222]:
# Build the FashionMNIST train and test splits with identical settings:
# files live under ./data (downloaded on first use) and every image is
# converted to a tensor by ToTensor().
trainDataset, testDataset = (
    torchvision.datasets.FashionMNIST(
        root='./data',
        train=is_train_split,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )
    for is_train_split in (True, False)
)

# Sanity check: each sample should be a single-channel (grayscale) image,
# i.e. shape [1, 28, 28].
print(trainDataset[0][0].shape, testDataset[0][0].shape, sep='\n')

torch.Size([1, 28, 28])
torch.Size([1, 28, 28])


In [223]:
def siftFeatureDetector(image):
    """Detect SIFT keypoints and compute their descriptors for one image.

    Args:
        image: Array-like of pixel intensities, expected to be scaled to
            [0, 1] (the callers in this notebook pass 2-D float arrays
            derived from ``ToTensor`` output).

    Returns:
        The ``(keypoints, descriptors)`` pair returned by
        ``sift.detectAndCompute``.
        NOTE(review): OpenCV appears to return ``None`` for ``descriptors``
        when no keypoint is found -- callers should be prepared for that.
    """
    sift = cv2.SIFT_create()
    # Scale to 8-bit. A bare astype(np.uint8) truncates toward zero, so a
    # product that lands just below an integer (e.g. 127.99999) would lose a
    # grey level, and anything outside [0, 255] would wrap around. Round to
    # the nearest integer and clip before casting instead.
    scaled = np.rint(np.asarray(image) * 255)
    image = np.clip(scaled, 0, 255).astype(np.uint8)
    keypoints, descriptors = sift.detectAndCompute(image, None)
    return keypoints, descriptors

In [224]:
def extraDescriptorsAndLabels(dataSet, batch_size=128):
    """Compute SIFT descriptors for every image in ``dataSet``, paired with labels.

    Args:
        dataSet: Dataset yielding ``(image, label)`` pairs. Images are assumed
            to be tensors of shape (1, H, W), as FashionMNIST with
            ``ToTensor`` produces -- TODO confirm for other datasets.
        batch_size: Number of samples the DataLoader fetches per batch.

    Returns:
        ``(all_descriptors, all_labels)``: two lists with exactly one entry
        per image, in dataset order (the loader does not shuffle).
        ``all_descriptors[i]`` is the second value returned by
        ``siftFeatureDetector`` for image ``i``; ``all_labels[i]`` is that
        image's own label, as obtained by iterating the batch's labels.
        If the list lengths do not match ``len(dataSet)`` an error is printed
        and ``None`` is returned instead of a tuple.
    """
    loader = torch.utils.data.DataLoader(
        dataSet, batch_size=batch_size, shuffle=False)

    all_descriptors = []
    all_labels = []

    # use tqdm to show progress bar
    for images, labels in tqdm(loader, desc="Processing Batches", unit="batch"):
        # Remove only the channel axis. A bare squeeze() would also remove the
        # batch axis whenever a batch holds a single image, and the loop below
        # would then iterate over pixel rows instead of images.
        images = images.squeeze(1).numpy()
        # Pair each image with its own label. Appending the whole `labels`
        # batch for every image would store batch_size labels per entry.
        for image, label in zip(images, labels):
            _, descriptors = siftFeatureDetector(image)
            all_descriptors.append(descriptors)
            all_labels.append(label)

    # verify that the descriptors and the labels are being stored correctly
    if len(all_descriptors) == len(all_labels) == len(dataSet):
        return all_descriptors, all_labels

    print("Error: descriptors and labels are not the same length")
    return None

In [225]:
# get descriptors and labels
train_descriptors, train_labels = extraDescriptorsAndLabels(trainDataset)
test_descriptors, test_labels = extraDescriptorsAndLabels(testDataset)

Processing Batches: 100%|██████████| 469/469 [00:19<00:00, 24.41batch/s]
Processing Batches: 100%|██████████| 79/79 [00:03<00:00, 24.65batch/s]


In [226]:
class Task1_MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_dim: Number of features per sample after flattening
            (default 28 * 28 = 784).
        hidden_dim: Width of the hidden layer (default 128).
        output_dim: Number of output values per sample (default 10).
    """

    def __init__(self, input_dim=28 * 28, hidden_dim=128, output_dim=10):
        super().__init__()
        # Input layer: input_dim -> hidden_dim (784 -> 128 with the defaults).
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        # Output layer: hidden_dim -> output_dim (128 -> 10 with the defaults).
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Flatten ``x`` to (N, input_dim) and return the (N, output_dim) outputs.

        No softmax is applied; the second linear layer's output is returned
        as-is.
        """
        # Flatten to the width fc1 was built with rather than a hard-coded
        # 28 * 28, so a non-default input_dim works. reshape (unlike view)
        # also accepts non-contiguous inputs.
        x = x.reshape(-1, self.fc1.in_features)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

In [None]:
def imageToVector(image):
    """Return ``image`` viewed as rows of 784 values (one flattened 28x28 image per row)."""
    pixels_per_image = 28 * 28
    flattened = image.view(-1, pixels_per_image)
    return flattened

# Build the model first (its weights are initialised here), then the loss
# function and the optimizer that updates the model's parameters.
model = Task1_MLP().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training hyperparameters and the shuffled mini-batch loader.
num_epochs = 5
batch_size = 128
train_loader = torch.utils.data.DataLoader(
    trainDataset, batch_size=batch_size, shuffle=True)

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Move the batch to the same device as the model.
        images, labels = images.to(device), labels.to(device)

        # Clear gradients left over from the previous step, then run the
        # forward pass on the flattened images.
        optimizer.zero_grad()
        outputs = model(imageToVector(images))
        loss = criterion(outputs, labels)

        # Backpropagate and apply one optimizer update.
        loss.backward()
        optimizer.step()

        # Report the current batch loss every 100 steps.
        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item()}')

Epoch [1/5], Step [100/469], Loss: 0.6045718193054199
Epoch [1/5], Step [200/469], Loss: 0.4898671507835388
Epoch [1/5], Step [300/469], Loss: 0.4991433620452881
Epoch [1/5], Step [400/469], Loss: 0.4031980633735657
Epoch [2/5], Step [100/469], Loss: 0.38540399074554443
Epoch [2/5], Step [200/469], Loss: 0.4107586145401001
Epoch [2/5], Step [300/469], Loss: 0.44177696108818054
Epoch [2/5], Step [400/469], Loss: 0.43348944187164307
Epoch [3/5], Step [100/469], Loss: 0.48607054352760315
Epoch [3/5], Step [200/469], Loss: 0.44146448373794556
Epoch [3/5], Step [300/469], Loss: 0.2834351062774658
Epoch [3/5], Step [400/469], Loss: 0.3967074751853943
Epoch [4/5], Step [100/469], Loss: 0.4138443171977997
Epoch [4/5], Step [200/469], Loss: 0.2710345387458801
Epoch [4/5], Step [300/469], Loss: 0.4316185712814331
Epoch [4/5], Step [400/469], Loss: 0.34148502349853516
Epoch [5/5], Step [100/469], Loss: 0.5268257856369019
Epoch [5/5], Step [200/469], Loss: 0.23402893543243408
Epoch [5/5], Step [30