In [1]:
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the "ylecun/mnist" dataset through Hugging Face `datasets`; the progress
# output below reports a 60000-example train split and a 10000-example test split.
ds = load_dataset("ylecun/mnist")

Downloading readme: 100%|██████████| 6.97k/6.97k [00:00<00:00, 7.05MB/s]
Downloading data: 100%|██████████| 15.6M/15.6M [00:07<00:00, 2.09MB/s]
Downloading data: 100%|██████████| 2.60M/2.60M [00:00<00:00, 2.75MB/s]
Generating train split: 100%|██████████| 60000/60000 [00:00<00:00, 82503.65 examples/s]
Generating test split: 100%|██████████| 10000/10000 [00:00<00:00, 100974.62 examples/s]


In [None]:
def imageAtIndex(index):
    """Print the image object and the label of one training example.

    Args:
        index: Position of the example inside ``ds['train']``.

    Relies on the notebook-level ``ds`` dataset object.
    """
    # Select the row first and then its fields. Going column-first
    # (ds['train']['image'][index]) can materialise and decode the entire
    # image column just to read a single element.
    example = ds['train'][index]
    print(example['image'], example['label'])

In [8]:
# Sanity check: print the image object and label stored at training index 2.
imageAtIndex(2)

<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x22A861AD310> 4


In [9]:
import torch
import torch.nn as nn

In [10]:
class AutoEncoder(nn.Module):
    """Fully connected autoencoder with layer widths 784-128-64-32-64-128-784.

    The encoder flattens its input and narrows it to a 32-value code; the
    decoder widens that code back to 784 values and ends with a sigmoid.
    """

    def __init__(self):
        super().__init__()

        # Layers are created encoder-first so parameter initialisation order
        # (and therefore seeded weights) matches the layer listing.
        encoder_layers = [
            nn.Flatten(),
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 784),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to the 32-value code and decode it back to 784 values."""
        return self.decoder(self.encoder(x))

In [11]:
# Create the autoencoder instance that the following cells train and evaluate.
model = AutoEncoder()

In [12]:
# Mean-squared-error loss; the training loop applies it between the
# reconstructed output and the input image tensor.
criterion = nn.MSELoss()

In [13]:
# Adam over every model parameter with a learning rate of 0.001.
# NOTE(review): this optimizer is built before the later `model.to(device)`
# cell. The recorded run trained on cuda with this order, but constructing the
# optimizer after moving the model is the more conventional sequence - confirm.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [15]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [16]:
# Move the model onto the selected device. Because this call is the cell's last
# expression, the returned module is echoed as the cell output.
model.to(device)

AutoEncoder(
  (encoder): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=784, out_features=128, bias=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=64, bias=True)
    (4): ReLU()
    (5): Linear(in_features=64, out_features=32, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=32, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=784, bias=True)
    (5): Sigmoid()
  )
)

In [17]:
# Number of full passes over the training split made by the training loop.
num_epochs = 20

In [19]:
from torchvision import transforms

In [20]:
# Image preprocessing pipeline; ToTensor is its only step.
transform = transforms.Compose([transforms.ToTensor()])

In [22]:
# Train for `num_epochs` passes over the training split. Each item yielded by
# ds["train"] is a single example, so every optimizer step uses one image.
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0  # running sum of the per-example losses in this epoch

    for batch in ds["train"]:
        # Image -> float tensor with a leading batch axis, placed on `device`.
        inputs = transform(batch["image"]).unsqueeze(0).float().to(device)

        # Forward pass on the flattened image, then reshape the reconstruction
        # back to image layout so it is compared against `inputs` directly.
        inputs_flat = inputs.flatten(start_dim=1)
        outputs = model(inputs_flat)
        outputs_reshaped = outputs.view(-1, 1, 28, 28)
        loss = criterion(outputs_reshaped, inputs)

        # Standard update: clear old gradients, backpropagate, apply the step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Report the mean per-example loss for the epoch.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss / len(ds['train'])}")

Epoch [1/20], Loss: 0.025724028729721127
Epoch [2/20], Loss: 0.018425951996253572
Epoch [3/20], Loss: 0.016289797009700367
Epoch [4/20], Loss: 0.015077385172684444
Epoch [5/20], Loss: 0.014452993810085657
Epoch [6/20], Loss: 0.014072945572167131
Epoch [7/20], Loss: 0.013792522682270889
Epoch [8/20], Loss: 0.013620988921096432
Epoch [9/20], Loss: 0.013479727258562344
Epoch [10/20], Loss: 0.013343840716884006
Epoch [11/20], Loss: 0.01325978291054489
Epoch [12/20], Loss: 0.013176993361789695
Epoch [13/20], Loss: 0.01310378706776731
Epoch [14/20], Loss: 0.013064102615062924
Epoch [15/20], Loss: 0.012985477360788112
Epoch [16/20], Loss: 0.012945383848281927
Epoch [17/20], Loss: 0.012892594518507636
Epoch [18/20], Loss: 0.0128848624984336
Epoch [19/20], Loss: 0.012869797686257517
Epoch [20/20], Loss: 0.012820577094336235


In [24]:
# Persist only the model's state dict (not the module object) to "autoencoder.pth".
torch.save(model.state_dict(), "autoencoder.pth")

In [23]:
import matplotlib.pyplot as plt

def visualize_reconstruction(model, dataset, index=0):
    """Show one test image side by side with the model's reconstruction.

    Args:
        model: Module called on the flattened image; its output is reshaped
            to (batch, 1, 28, 28) for display.
        dataset: Object with a "test" split whose examples have an "image" field.
        index: Position of the test example to display. Defaults to 0, the
            first example of the split.

    Side effects:
        Puts ``model`` into eval mode and shows a matplotlib figure.

    Relies on the notebook-level ``transform``, ``device`` and ``plt``.
    """
    model.eval()

    test_split = dataset["test"]
    if len(test_split) == 0:
        # Nothing to draw for an empty split.
        return

    # Fetch the requested example directly instead of starting an iteration
    # over the whole split and breaking out after the first item.
    example = test_split[index]

    with torch.no_grad():
        inputs = transform(example["image"]).unsqueeze(0).float().to(device)
        inputs_flat = inputs.view(inputs.size(0), -1)

        outputs = model(inputs_flat)
        outputs_reshaped = outputs.view(inputs.size(0), 1, 28, 28)

    # Tensors are moved to the CPU and squeezed to 2-D before plotting.
    fig, axs = plt.subplots(1, 2)
    axs[0].imshow(inputs.cpu().squeeze(), cmap="gray")
    axs[0].set_title("Original")

    axs[1].imshow(outputs_reshaped.cpu().squeeze(), cmap="gray")
    axs[1].set_title("Reconstructed")

    plt.show()

In [25]:
# Restore the trained weights from disk.
# - map_location=device places the loaded tensors on the notebook's `device`,
#   so a checkpoint written on a GPU also loads on a CPU-only machine.
# - weights_only=True limits unpickling to tensors and plain containers, which
#   is all a state dict contains, instead of executing arbitrary pickled code.
state_dict = torch.load("autoencoder.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
# Switch to inference mode; as the last expression this echoes the module.
model.eval()

AutoEncoder(
  (encoder): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=784, out_features=128, bias=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=64, bias=True)
    (4): ReLU()
    (5): Linear(in_features=64, out_features=32, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=32, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=784, bias=True)
    (5): Sigmoid()
  )
)