In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import numpy as np
from torch.optim.lr_scheduler import StepLR
#import torch.nn.functional as F

torch.manual_seed(0)
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
batch_size = 64
test_batch_size = 64

In [3]:
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.1307], [0.3081])
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=img_transform)
train_dataset, val_dataset = random_split(train_dataset, [50000,10000])
train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=test_batch_size, 
                        shuffle=False)

test_dataset = datasets.MNIST('./data', train=False, download=True, transform=img_transform)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size,
                             shuffle=True, num_workers=2)

In [4]:
class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True), 
            nn.Linear(64, 32), 
            nn.ReLU(True), 
            nn.Linear(32, 2))
        
        self.decoder = nn.Sequential(
            nn.Linear(2, 32),
            nn.ReLU(True),
            nn.Linear(32, 64),
            nn.ReLU(True),
            nn.Linear(64, 128),
            nn.ReLU(True), 
            nn.Linear(128, 28 * 28), 
            nn.Tanh())

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

ae = autoencoder().to(device)

In [5]:
ae_criterion = nn.MSELoss()

epochs = 1000
learning_rate = 0.01
#ae_optimizer = torch.optim.Adam(ae.parameters(), lr=0.001, weight_decay=learning_rate)
ae_optimizer = torch.optim.SGD(ae.parameters(), lr=0.5, momentum=0.5)

In [6]:
train_loss  = []
val_loss = []
for epoch in range(epochs):
    ae.train()
    batch_loss  = []
    for data in train_loader:
        img, _ = data
        img = img.view(img.size(0), -1).to(device)
        output = ae(img)
        
        scheduler = StepLR(ae_optimizer, step_size=50, gamma=0.9)
        loss = ae_criterion(output, img)
        ae_optimizer.zero_grad()
        loss.backward()
        ae_optimizer.step()
        # ===================log========================
        batch_loss.append(loss.item())
    train_loss.append(np.mean(batch_loss))
    
    
    ae.eval()
    batch_loss  = []
    with torch.no_grad():
        for img, y in val_loader:
            img = img.view(img.size(0), -1)
            img, y = img.to(device), y.to(device)
            output = ae(img)
            loss = ae_criterion(output, img)
            batch_loss.append(loss.item())
            val_loss.append(np.mean(batch_loss))
        print(f'Epoch [{epoch + 1}/{epochs}], train loss:{train_loss[-1]:.4f}, val loss:{val_loss[-1]:.4f}')

Epoch [1/1000], train loss:0.6733, val loss:0.6215
Epoch [2/1000], train loss:0.6103, val loss:0.6074
Epoch [3/1000], train loss:0.5978, val loss:0.5931
Epoch [4/1000], train loss:0.5896, val loss:0.5872
Epoch [5/1000], train loss:0.5827, val loss:0.5808
Epoch [6/1000], train loss:0.5769, val loss:0.5751
Epoch [7/1000], train loss:0.5724, val loss:0.5721
Epoch [8/1000], train loss:0.5682, val loss:0.5670
Epoch [9/1000], train loss:0.5646, val loss:0.5634
Epoch [10/1000], train loss:0.5609, val loss:0.5605
Epoch [11/1000], train loss:0.5590, val loss:0.5586
Epoch [12/1000], train loss:0.5563, val loss:0.5598
Epoch [13/1000], train loss:0.5546, val loss:0.5564
Epoch [14/1000], train loss:0.5530, val loss:0.5543
Epoch [15/1000], train loss:0.5511, val loss:0.5526
Epoch [16/1000], train loss:0.5505, val loss:0.5532
Epoch [17/1000], train loss:0.5488, val loss:0.5495
Epoch [18/1000], train loss:0.5471, val loss:0.5554
Epoch [19/1000], train loss:0.5465, val loss:0.5497
Epoch [20/1000], trai

Epoch [158/1000], train loss:0.5119, val loss:0.5275
Epoch [159/1000], train loss:0.5107, val loss:0.5306
Epoch [160/1000], train loss:0.5145, val loss:0.5171
Epoch [161/1000], train loss:0.5120, val loss:0.5189
Epoch [162/1000], train loss:0.5137, val loss:0.5322
Epoch [163/1000], train loss:0.5106, val loss:0.5112
Epoch [164/1000], train loss:0.5089, val loss:0.5188
Epoch [165/1000], train loss:0.5081, val loss:0.5135
Epoch [166/1000], train loss:0.5079, val loss:0.5124
Epoch [167/1000], train loss:0.5082, val loss:0.5192
Epoch [168/1000], train loss:0.5104, val loss:0.5120
Epoch [169/1000], train loss:0.5115, val loss:0.5170
Epoch [170/1000], train loss:0.5099, val loss:0.5121
Epoch [171/1000], train loss:0.5083, val loss:0.5105
Epoch [172/1000], train loss:0.5078, val loss:0.5110
Epoch [173/1000], train loss:0.5084, val loss:0.5301
Epoch [174/1000], train loss:0.5124, val loss:0.5142
Epoch [175/1000], train loss:0.5089, val loss:0.5130
Epoch [176/1000], train loss:0.5094, val loss:

Epoch [313/1000], train loss:0.5015, val loss:0.5024
Epoch [314/1000], train loss:0.4986, val loss:0.5121
Epoch [315/1000], train loss:0.4993, val loss:0.5057
Epoch [316/1000], train loss:0.4982, val loss:0.5009
Epoch [317/1000], train loss:0.4983, val loss:0.5171
Epoch [318/1000], train loss:0.4966, val loss:0.5161
Epoch [319/1000], train loss:0.4966, val loss:0.5016
Epoch [320/1000], train loss:0.4979, val loss:0.5007
Epoch [321/1000], train loss:0.4972, val loss:0.5031
Epoch [322/1000], train loss:0.4983, val loss:0.5113
Epoch [323/1000], train loss:0.5019, val loss:0.5166
Epoch [324/1000], train loss:0.5007, val loss:0.5099
Epoch [325/1000], train loss:0.5031, val loss:0.5073
Epoch [326/1000], train loss:0.5006, val loss:0.5081
Epoch [327/1000], train loss:0.5009, val loss:0.5133
Epoch [328/1000], train loss:0.5011, val loss:0.5254
Epoch [329/1000], train loss:0.4992, val loss:0.5042
Epoch [330/1000], train loss:0.4981, val loss:0.5019
Epoch [331/1000], train loss:0.4987, val loss:

Epoch [468/1000], train loss:0.4925, val loss:0.4991
Epoch [469/1000], train loss:0.4931, val loss:0.4998
Epoch [470/1000], train loss:0.4956, val loss:0.5077
Epoch [471/1000], train loss:0.4940, val loss:0.4988
Epoch [472/1000], train loss:0.4922, val loss:0.4989
Epoch [473/1000], train loss:0.4923, val loss:0.5254
Epoch [474/1000], train loss:0.4921, val loss:0.5015
Epoch [475/1000], train loss:0.4926, val loss:0.5043
Epoch [476/1000], train loss:0.4948, val loss:0.4983
Epoch [477/1000], train loss:0.4937, val loss:0.4956
Epoch [478/1000], train loss:0.4919, val loss:0.5037
Epoch [479/1000], train loss:0.4936, val loss:0.5079
Epoch [480/1000], train loss:0.4915, val loss:0.4968
Epoch [481/1000], train loss:0.4938, val loss:0.4997
Epoch [482/1000], train loss:0.4943, val loss:0.5024
Epoch [483/1000], train loss:0.4919, val loss:0.4959
Epoch [484/1000], train loss:0.4929, val loss:0.5121
Epoch [485/1000], train loss:0.4947, val loss:0.5004
Epoch [486/1000], train loss:0.4924, val loss:

Epoch [623/1000], train loss:0.4917, val loss:0.4923
Epoch [624/1000], train loss:0.4927, val loss:0.5063
Epoch [625/1000], train loss:0.4929, val loss:0.4954
Epoch [626/1000], train loss:0.4913, val loss:0.4990
Epoch [627/1000], train loss:0.4909, val loss:0.4955
Epoch [628/1000], train loss:0.4912, val loss:0.5127
Epoch [629/1000], train loss:0.4900, val loss:0.4930
Epoch [630/1000], train loss:0.4914, val loss:0.4973
Epoch [631/1000], train loss:0.4894, val loss:0.4967
Epoch [632/1000], train loss:0.4895, val loss:0.5048
Epoch [633/1000], train loss:0.4900, val loss:0.4929
Epoch [634/1000], train loss:0.4891, val loss:0.4913
Epoch [635/1000], train loss:0.4944, val loss:0.4960
Epoch [636/1000], train loss:0.4907, val loss:0.5023
Epoch [637/1000], train loss:0.4920, val loss:0.4978
Epoch [638/1000], train loss:0.4924, val loss:0.4970
Epoch [639/1000], train loss:0.4914, val loss:0.4973
Epoch [640/1000], train loss:0.4908, val loss:0.4945
Epoch [641/1000], train loss:0.4913, val loss:

Epoch [778/1000], train loss:0.4901, val loss:0.5014
Epoch [779/1000], train loss:0.4902, val loss:0.4997
Epoch [780/1000], train loss:0.4907, val loss:0.4963
Epoch [781/1000], train loss:0.4900, val loss:0.5001
Epoch [782/1000], train loss:0.4910, val loss:0.5002
Epoch [783/1000], train loss:0.4906, val loss:0.4980
Epoch [784/1000], train loss:0.4894, val loss:0.4935
Epoch [785/1000], train loss:0.4902, val loss:0.4992
Epoch [786/1000], train loss:0.4909, val loss:0.4967
Epoch [787/1000], train loss:0.4907, val loss:0.4998
Epoch [788/1000], train loss:0.4927, val loss:0.4944
Epoch [789/1000], train loss:0.4906, val loss:0.4985
Epoch [790/1000], train loss:0.4894, val loss:0.4943
Epoch [791/1000], train loss:0.4898, val loss:0.4949
Epoch [792/1000], train loss:0.4894, val loss:0.4990
Epoch [793/1000], train loss:0.4894, val loss:0.4957
Epoch [794/1000], train loss:0.4922, val loss:0.5130
Epoch [795/1000], train loss:0.4938, val loss:0.4950
Epoch [796/1000], train loss:0.4900, val loss:

Epoch [933/1000], train loss:0.4935, val loss:0.5364
Epoch [934/1000], train loss:0.4930, val loss:0.4958
Epoch [935/1000], train loss:0.4929, val loss:0.4990
Epoch [936/1000], train loss:0.4941, val loss:0.5169
Epoch [937/1000], train loss:0.4922, val loss:0.5216
Epoch [938/1000], train loss:0.4930, val loss:0.4987
Epoch [939/1000], train loss:0.4942, val loss:0.5065
Epoch [940/1000], train loss:0.4915, val loss:0.4965
Epoch [941/1000], train loss:0.4923, val loss:0.5035
Epoch [942/1000], train loss:0.4932, val loss:0.4940
Epoch [943/1000], train loss:0.4920, val loss:0.4987
Epoch [944/1000], train loss:0.4927, val loss:0.5010
Epoch [945/1000], train loss:0.4927, val loss:0.5008
Epoch [946/1000], train loss:0.4918, val loss:0.5047
Epoch [947/1000], train loss:0.4917, val loss:0.4962
Epoch [948/1000], train loss:0.4915, val loss:0.5142
Epoch [949/1000], train loss:0.4907, val loss:0.4947
Epoch [950/1000], train loss:0.4933, val loss:0.5091
Epoch [951/1000], train loss:0.4922, val loss:

In [7]:
ae.eval()
batch_loss = []
with torch.no_grad():
    for img, y in test_loader:
        img = img.view(img.size(0), -1)
        img, y = img.cuda(), y.cuda()
        output = ae(img)
        loss = ae_criterion(output, img)
        batch_loss.append(loss.item())
    print(f"Test loss:{np.mean(batch_loss)}")

Test loss:0.5085482958016122


In [8]:
classifier = nn.Sequential(
                nn.Linear(2, 64),
                nn.ReLU(True),
                nn.Linear(64, 128),
                nn.ReLU(True),
                nn.Linear(128, 10)).to(device)

In [9]:
classifier_criterion = nn.CrossEntropyLoss()
classifier_optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate)

In [10]:
for epoch in range(epochs):
    classifier.train()
    running_loss = 0.0
    running_acc = 0
    for img, y in train_loader:
        img = img.view(img.size(0), -1)
        img, y = img.to(device), y.to(device)
        # ===================forward=====================
        code_vector = ae.encoder(img)
        y_hat = classifier(code_vector)
        loss = classifier_criterion(y_hat, y)
        # ===================backward====================
        classifier_optimizer.zero_grad()
        loss.backward()
        classifier_optimizer.step()

        # print statistics
        running_loss += loss.item() * img.size(0)
        out = torch.argmax(y_hat.detach(),dim=1)
        assert out.shape==y.shape
        running_acc += (y==out).sum().item()
    print(f"Train Acc:{running_acc*100/len(train_dataset)}%")

    correct = 0
    classifier.eval()
    with torch.no_grad():
        for img, y in val_loader:
            img = img.view(img.size(0), -1)
            img, y = img.cuda(), y.cuda()
            code_vector = ae.encoder(img)
            y_hat = classifier(code_vector)
            y_hat = torch.argmax(y_hat, dim=1)
            acc = (y_hat==y).sum().item()
            correct += acc
    print(f"Val accuracy:{correct*100/len(val_dataset)}%")


Train Acc:77.206%
Val accuracy:83.34%
Train Acc:83.182%
Val accuracy:81.6%
Train Acc:81.982%
Val accuracy:83.57%
Train Acc:82.588%
Val accuracy:82.95%
Train Acc:82.994%
Val accuracy:80.84%
Train Acc:84.086%
Val accuracy:82.8%
Train Acc:83.51%
Val accuracy:83.81%
Train Acc:83.728%
Val accuracy:81.37%
Train Acc:83.798%
Val accuracy:83.11%
Train Acc:84.132%
Val accuracy:83.48%
Train Acc:83.448%
Val accuracy:85.29%
Train Acc:84.444%
Val accuracy:84.82%
Train Acc:84.314%
Val accuracy:82.7%
Train Acc:84.43%
Val accuracy:79.93%
Train Acc:84.442%
Val accuracy:84.8%
Train Acc:84.644%
Val accuracy:81.72%
Train Acc:84.188%
Val accuracy:84.17%
Train Acc:84.592%
Val accuracy:84.91%
Train Acc:84.816%
Val accuracy:83.78%
Train Acc:84.65%
Val accuracy:82.87%
Train Acc:84.8%
Val accuracy:84.03%
Train Acc:84.876%
Val accuracy:85.41%
Train Acc:84.478%
Val accuracy:81.91%
Train Acc:84.732%
Val accuracy:84.6%
Train Acc:84.78%
Val accuracy:84.53%
Train Acc:84.856%
Val accuracy:83.4%
Train Acc:84.872%
Val ac

Val accuracy:84.79%
Train Acc:85.504%
Val accuracy:85.59%
Train Acc:85.402%
Val accuracy:85.45%
Train Acc:85.462%
Val accuracy:85.3%
Train Acc:85.302%
Val accuracy:85.01%
Train Acc:85.48%
Val accuracy:84.66%
Train Acc:85.218%
Val accuracy:84.92%
Train Acc:85.354%
Val accuracy:83.66%
Train Acc:85.368%
Val accuracy:83.89%
Train Acc:85.24%
Val accuracy:84.27%
Train Acc:85.102%
Val accuracy:85.48%
Train Acc:85.448%
Val accuracy:85.43%
Train Acc:85.712%
Val accuracy:85.13%
Train Acc:85.078%
Val accuracy:85.64%
Train Acc:85.41%
Val accuracy:85.19%
Train Acc:84.864%
Val accuracy:85.09%
Train Acc:85.166%
Val accuracy:85.14%
Train Acc:85.316%
Val accuracy:85.35%
Train Acc:85.466%
Val accuracy:85.59%
Train Acc:85.54%
Val accuracy:85.52%
Train Acc:85.448%
Val accuracy:82.91%
Train Acc:85.296%
Val accuracy:84.52%
Train Acc:85.44%
Val accuracy:85.16%
Train Acc:85.326%
Val accuracy:84.68%
Train Acc:85.452%
Val accuracy:85.43%
Train Acc:85.404%
Val accuracy:84.62%
Train Acc:85.648%
Val accuracy:83.99

Val accuracy:85.3%
Train Acc:85.306%
Val accuracy:84.13%
Train Acc:85.118%
Val accuracy:85.59%
Train Acc:85.434%
Val accuracy:85.09%
Train Acc:85.492%
Val accuracy:85.43%
Train Acc:85.342%
Val accuracy:85.11%
Train Acc:85.298%
Val accuracy:84.87%
Train Acc:85.304%
Val accuracy:84.97%
Train Acc:85.454%
Val accuracy:85.45%
Train Acc:85.628%
Val accuracy:85.53%
Train Acc:85.024%
Val accuracy:85.47%
Train Acc:85.418%
Val accuracy:84.82%
Train Acc:85.034%
Val accuracy:84.37%
Train Acc:85.67%
Val accuracy:85.44%
Train Acc:84.898%
Val accuracy:84.72%
Train Acc:85.506%
Val accuracy:85.69%
Train Acc:85.254%
Val accuracy:85.73%
Train Acc:85.286%
Val accuracy:85.61%
Train Acc:85.548%
Val accuracy:84.73%
Train Acc:85.47%
Val accuracy:85.56%
Train Acc:85.144%
Val accuracy:85.19%
Train Acc:85.342%
Val accuracy:84.84%
Train Acc:85.394%
Val accuracy:85.61%
Train Acc:85.602%
Val accuracy:85.44%
Train Acc:84.832%
Val accuracy:85.42%
Train Acc:85.0%
Val accuracy:84.91%
Train Acc:85.398%
Val accuracy:84.4

Train Acc:85.552%
Val accuracy:85.5%
Train Acc:85.468%
Val accuracy:82.77%
Train Acc:85.178%
Val accuracy:85.33%
Train Acc:85.23%
Val accuracy:84.29%
Train Acc:85.362%
Val accuracy:85.0%
Train Acc:85.358%
Val accuracy:84.97%
Train Acc:85.19%
Val accuracy:85.05%
Train Acc:85.256%
Val accuracy:85.2%
Train Acc:84.86%
Val accuracy:84.75%
Train Acc:85.224%
Val accuracy:84.83%
Train Acc:85.3%
Val accuracy:83.54%
Train Acc:85.21%
Val accuracy:82.02%
Train Acc:84.902%
Val accuracy:85.11%
Train Acc:85.41%
Val accuracy:83.73%
Train Acc:85.286%
Val accuracy:85.39%
Train Acc:85.388%
Val accuracy:84.89%
Train Acc:85.39%
Val accuracy:84.97%
Train Acc:85.524%
Val accuracy:84.54%
Train Acc:85.364%
Val accuracy:84.74%
Train Acc:85.562%
Val accuracy:84.92%
Train Acc:85.052%
Val accuracy:80.58%
Train Acc:85.224%
Val accuracy:84.5%
Train Acc:85.316%
Val accuracy:85.27%
Train Acc:85.074%
Val accuracy:85.3%
Train Acc:85.122%
Val accuracy:84.6%
Train Acc:84.304%
Val accuracy:82.89%
Train Acc:85.194%
Val accu

Train Acc:84.874%
Val accuracy:83.77%
Train Acc:84.842%
Val accuracy:84.18%
Train Acc:84.71%
Val accuracy:84.57%
Train Acc:85.122%
Val accuracy:84.75%
Train Acc:84.932%
Val accuracy:84.87%
Train Acc:84.94%
Val accuracy:84.24%
Train Acc:85.14%
Val accuracy:83.86%
Train Acc:84.954%
Val accuracy:85.15%
Train Acc:84.414%
Val accuracy:84.62%
Train Acc:84.658%
Val accuracy:82.32%
Train Acc:84.748%
Val accuracy:84.47%
Train Acc:84.75%
Val accuracy:84.01%
Train Acc:84.78%
Val accuracy:83.82%
Train Acc:84.73%
Val accuracy:85.25%
Train Acc:84.234%
Val accuracy:84.55%
Train Acc:84.74%
Val accuracy:84.81%
Train Acc:84.758%
Val accuracy:84.6%
Train Acc:84.682%
Val accuracy:84.58%
Train Acc:84.806%
Val accuracy:84.69%
Train Acc:84.69%
Val accuracy:84.0%
Train Acc:84.624%
Val accuracy:84.59%
Train Acc:84.466%
Val accuracy:84.26%
Train Acc:84.642%
Val accuracy:85.25%
Train Acc:84.934%
Val accuracy:84.47%
Train Acc:83.978%
Val accuracy:84.55%
Train Acc:84.73%
Val accuracy:83.36%
Train Acc:84.61%
Val ac

In [11]:
# evaluate your model on the never-seen-before test data
classifier.eval()
correct = 0
with torch.no_grad():
    for img, y in test_loader:
        img = img.view(img.size(0), -1)
        img, y = img.cuda(), y.cuda()
        code_vector = ae.encoder(img)
        y_hat = classifier(code_vector)
        y_hat = torch.argmax(y_hat, dim=1)
        acc = (y_hat==y).sum().item()
        correct += acc
    print(f"Test accuracy:{correct*100/len(test_dataset)}%")

Test accuracy:84.75%
