In [1]:
# import libraries
import torchvision
import torch
import cv2
import numpy as np
import torchvision.transforms as transforms
from tqdm import tqdm
import torch.nn as nn

In [2]:
# check device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device = torch.device("mps" if torch.backends.mps.is_available() else device)

print(device)

mps


In [3]:
# prepare dataset
trainDataset = torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor()
)

testDataset = torchvision.datasets.FashionMNIST(
    root='./data',
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor()
)

# Confirm they are in grayscale (they should already be).
# should be [1, 28, 28], 1 channel, 28x28 pixels (grayscale)
print(trainDataset[0][0].shape)
print(testDataset[0][0].shape)

torch.Size([1, 28, 28])
torch.Size([1, 28, 28])


In [4]:
def siftFeatureDetector(image):
    sift = cv2.SIFT_create()
    image = (image * 255).astype(np.uint8)
    keypoints, descriptors = sift.detectAndCompute(image, None)
    return keypoints, descriptors

In [5]:
def extraDescriptorsAndLabels(dataSet, batch_size=128):
    train_loader = torch.utils.data.DataLoader(
        dataSet, batch_size=batch_size, shuffle=False)

    all_descriptors = []
    all_labels = []

    # use tqdm to show progress bar
    for batch in tqdm(train_loader, desc="Processing Batches", unit="batch"):
        images, labels = batch
        images = images.squeeze().numpy()
        for image in images:
            keypoints, descriptors = siftFeatureDetector(image)
            all_descriptors.append(descriptors)
            all_labels.append(labels)
    # verify that the descriptors and the labels are being stored correctly
    if len(all_descriptors) == len(all_labels) and len(all_descriptors) == len(dataSet):
        return all_descriptors, all_labels
    else:
        print("Error: descriptors and labels are not the same length")
        return None

In [6]:
# get descriptors and labels
train_descriptors, train_labels = extraDescriptorsAndLabels(trainDataset)
test_descriptors, test_labels = extraDescriptorsAndLabels(testDataset)

Processing Batches: 100%|██████████| 469/469 [00:19<00:00, 23.80batch/s]
Processing Batches: 100%|██████████| 79/79 [00:03<00:00, 24.45batch/s]


In [7]:
class Task1_MLP(nn.Module):
    def __init__(self, input_dim=28 * 28, hidden_dim=128, output_dim=10):
        super(Task1_MLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)  # 输入层: 50 -> 128
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)  # 隐藏层: 128 -> 10（类别数）
        
    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

In [8]:
# convert 28*28 image to 784-dim vector
def imageToVector(image):
    return image.view(-1, 28 * 28)


# train the model
model = Task1_MLP().to(device)

# define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# train the model
num_epochs = 50
batch_size = 600
train_loader = torch.utils.data.DataLoader(
    trainDataset, batch_size=batch_size, shuffle=True)

for epoch in range(num_epochs):
    loop = tqdm(enumerate(train_loader), total=len(
        train_loader), desc=f"Epoch {epoch+1}/{num_epochs}")

    for i, (images, labels) in loop:
        images = images.to(device)
        labels = labels.to(device)

        # forward pass
        outputs = model(imageToVector(images))
        loss = criterion(outputs, labels)

        # backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) == loop.total:
            print(
                f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item()}')

Epoch 1/50: 100%|██████████| 100/100 [00:02<00:00, 40.13it/s]


Epoch [1/50], Step [100/100], Loss: 0.5250208973884583


Epoch 2/50: 100%|██████████| 100/100 [00:01<00:00, 52.38it/s]


Epoch [2/50], Step [100/100], Loss: 0.5676847100257874


Epoch 3/50: 100%|██████████| 100/100 [00:01<00:00, 52.58it/s]


Epoch [3/50], Step [100/100], Loss: 0.3656439185142517


Epoch 4/50: 100%|██████████| 100/100 [00:01<00:00, 52.94it/s]


Epoch [4/50], Step [100/100], Loss: 0.41279366612434387


Epoch 5/50: 100%|██████████| 100/100 [00:01<00:00, 50.57it/s]


Epoch [5/50], Step [100/100], Loss: 0.3919188380241394


Epoch 6/50: 100%|██████████| 100/100 [00:01<00:00, 52.05it/s]


Epoch [6/50], Step [100/100], Loss: 0.3983255624771118


Epoch 7/50: 100%|██████████| 100/100 [00:01<00:00, 52.57it/s]


Epoch [7/50], Step [100/100], Loss: 0.3363197445869446


Epoch 8/50: 100%|██████████| 100/100 [00:01<00:00, 51.50it/s]


Epoch [8/50], Step [100/100], Loss: 0.3290833830833435


Epoch 9/50: 100%|██████████| 100/100 [00:01<00:00, 52.72it/s]


Epoch [9/50], Step [100/100], Loss: 0.3864980638027191


Epoch 10/50: 100%|██████████| 100/100 [00:01<00:00, 52.83it/s]


Epoch [10/50], Step [100/100], Loss: 0.3238428831100464


Epoch 11/50: 100%|██████████| 100/100 [00:01<00:00, 51.50it/s]


Epoch [11/50], Step [100/100], Loss: 0.3280735909938812


Epoch 12/50: 100%|██████████| 100/100 [00:01<00:00, 52.62it/s]


Epoch [12/50], Step [100/100], Loss: 0.32458868622779846


Epoch 13/50: 100%|██████████| 100/100 [00:01<00:00, 52.64it/s]


Epoch [13/50], Step [100/100], Loss: 0.3096567690372467


Epoch 14/50: 100%|██████████| 100/100 [00:01<00:00, 51.70it/s]


Epoch [14/50], Step [100/100], Loss: 0.3147822320461273


Epoch 15/50: 100%|██████████| 100/100 [00:01<00:00, 52.93it/s]


Epoch [15/50], Step [100/100], Loss: 0.28022098541259766


Epoch 16/50: 100%|██████████| 100/100 [00:01<00:00, 52.67it/s]


Epoch [16/50], Step [100/100], Loss: 0.23281265795230865


Epoch 17/50: 100%|██████████| 100/100 [00:01<00:00, 52.62it/s]


Epoch [17/50], Step [100/100], Loss: 0.289583295583725


Epoch 18/50: 100%|██████████| 100/100 [00:01<00:00, 51.44it/s]


Epoch [18/50], Step [100/100], Loss: 0.23393496870994568


Epoch 19/50: 100%|██████████| 100/100 [00:01<00:00, 52.65it/s]


Epoch [19/50], Step [100/100], Loss: 0.2691209018230438


Epoch 20/50: 100%|██████████| 100/100 [00:01<00:00, 51.92it/s]


Epoch [20/50], Step [100/100], Loss: 0.32970160245895386


Epoch 21/50: 100%|██████████| 100/100 [00:01<00:00, 52.72it/s]


Epoch [21/50], Step [100/100], Loss: 0.2576725482940674


Epoch 22/50: 100%|██████████| 100/100 [00:01<00:00, 52.66it/s]


Epoch [22/50], Step [100/100], Loss: 0.3094936013221741


Epoch 23/50: 100%|██████████| 100/100 [00:01<00:00, 52.87it/s]


Epoch [23/50], Step [100/100], Loss: 0.2914637625217438


Epoch 24/50: 100%|██████████| 100/100 [00:01<00:00, 50.57it/s]


Epoch [24/50], Step [100/100], Loss: 0.28212907910346985


Epoch 25/50: 100%|██████████| 100/100 [00:01<00:00, 50.04it/s]


Epoch [25/50], Step [100/100], Loss: 0.22744138538837433


Epoch 26/50: 100%|██████████| 100/100 [00:01<00:00, 52.79it/s]


Epoch [26/50], Step [100/100], Loss: 0.22099167108535767


Epoch 27/50: 100%|██████████| 100/100 [00:01<00:00, 52.84it/s]


Epoch [27/50], Step [100/100], Loss: 0.3278041183948517


Epoch 28/50: 100%|██████████| 100/100 [00:01<00:00, 52.54it/s]


Epoch [28/50], Step [100/100], Loss: 0.2452799528837204


Epoch 29/50: 100%|██████████| 100/100 [00:01<00:00, 52.34it/s]


Epoch [29/50], Step [100/100], Loss: 0.22180898487567902


Epoch 30/50: 100%|██████████| 100/100 [00:01<00:00, 52.77it/s]


Epoch [30/50], Step [100/100], Loss: 0.207909494638443


Epoch 31/50: 100%|██████████| 100/100 [00:01<00:00, 50.08it/s]


Epoch [31/50], Step [100/100], Loss: 0.247975155711174


Epoch 32/50: 100%|██████████| 100/100 [00:01<00:00, 52.54it/s]


Epoch [32/50], Step [100/100], Loss: 0.272438645362854


Epoch 33/50: 100%|██████████| 100/100 [00:01<00:00, 52.57it/s]


Epoch [33/50], Step [100/100], Loss: 0.2543664276599884


Epoch 34/50: 100%|██████████| 100/100 [00:01<00:00, 50.22it/s]


Epoch [34/50], Step [100/100], Loss: 0.22874465584754944


Epoch 35/50: 100%|██████████| 100/100 [00:02<00:00, 49.30it/s]


Epoch [35/50], Step [100/100], Loss: 0.26947200298309326


Epoch 36/50: 100%|██████████| 100/100 [00:01<00:00, 52.27it/s]


Epoch [36/50], Step [100/100], Loss: 0.21244977414608002


Epoch 37/50: 100%|██████████| 100/100 [00:01<00:00, 52.48it/s]


Epoch [37/50], Step [100/100], Loss: 0.209865540266037


Epoch 38/50: 100%|██████████| 100/100 [00:01<00:00, 51.09it/s]


Epoch [38/50], Step [100/100], Loss: 0.21432387828826904


Epoch 39/50: 100%|██████████| 100/100 [00:01<00:00, 52.45it/s]


Epoch [39/50], Step [100/100], Loss: 0.20792488753795624


Epoch 40/50: 100%|██████████| 100/100 [00:01<00:00, 51.31it/s]


Epoch [40/50], Step [100/100], Loss: 0.21940411627292633


Epoch 41/50: 100%|██████████| 100/100 [00:02<00:00, 49.28it/s]


Epoch [41/50], Step [100/100], Loss: 0.20176605880260468


Epoch 42/50: 100%|██████████| 100/100 [00:01<00:00, 51.18it/s]


Epoch [42/50], Step [100/100], Loss: 0.20604363083839417


Epoch 43/50: 100%|██████████| 100/100 [00:02<00:00, 49.73it/s]


Epoch [43/50], Step [100/100], Loss: 0.16763564944267273


Epoch 44/50: 100%|██████████| 100/100 [00:01<00:00, 50.39it/s]


Epoch [44/50], Step [100/100], Loss: 0.2230391502380371


Epoch 45/50: 100%|██████████| 100/100 [00:01<00:00, 52.48it/s]


Epoch [45/50], Step [100/100], Loss: 0.2106865495443344


Epoch 46/50: 100%|██████████| 100/100 [00:01<00:00, 52.68it/s]


Epoch [46/50], Step [100/100], Loss: 0.2014211267232895


Epoch 47/50: 100%|██████████| 100/100 [00:01<00:00, 52.57it/s]


Epoch [47/50], Step [100/100], Loss: 0.17731967568397522


Epoch 48/50: 100%|██████████| 100/100 [00:01<00:00, 52.30it/s]


Epoch [48/50], Step [100/100], Loss: 0.1795942485332489


Epoch 49/50: 100%|██████████| 100/100 [00:01<00:00, 51.81it/s]


Epoch [49/50], Step [100/100], Loss: 0.18676866590976715


Epoch 50/50: 100%|██████████| 100/100 [00:01<00:00, 52.81it/s]

Epoch [50/50], Step [100/100], Loss: 0.21617922186851501



