In [30]:
import torch
# Prefer Apple's Metal (MPS) backend when this PyTorch build can use it; otherwise run on CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print('Device:', device)

Device: mps


In [31]:
import os
# Number of DataLoader worker processes. os.cpu_count() returns None when the CPU
# count cannot be determined, and DataLoader needs an int; fall back to 1 (not 0)
# so loaders created with persistent_workers=True remain valid.
NUM_WORKERS = os.cpu_count() or 1
print("Number of workers:", NUM_WORKERS)

Number of workers: 8


In [32]:
import ssl
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

# Load MNIST from `data_path` (the current working directory). With download=False the
# dataset files must already be present there; nothing is fetched over the network.
#
# The previous process-wide `ssl._create_default_https_context = ssl._create_unverified_context`
# override was removed. It disabled TLS certificate verification for every later HTTPS
# request made by this kernel, and it was never needed because nothing is downloaded.
# If a download is ever required, fix the local certificate store instead.
data_path = os.getcwd()

# ToTensor keeps no per-sample state, so one instance is shared by both splits.
to_tensor = transforms.ToTensor()

train_dataset = MNIST(
    data_path,
    train=True,
    download=False,
    transform=to_tensor
)

test_dataset = MNIST(
    data_path,
    train=False,
    download=False,
    transform=to_tensor
)


In [6]:
# Shape of the raw `.data` tensor of the test split (not passed through `transform`).
test_images_shape = test_dataset.data.shape
print(test_images_shape)

torch.Size([10000, 28, 28])


In [7]:
# Insert a channel axis into the raw test images: (N, H, W) -> (N, 1, H, W).
# unsqueeze(1) gives the same result as view(N, 1, 28, 28) for MNIST without
# hard-coding the 28x28 image size.
# NOTE(review): `.data` bypasses `transform`, so these are presumably raw pixel
# values, not the ToTensor-scaled inputs the model is trained on — confirm before
# feeding X_test to the model.
X_test = test_dataset.data.unsqueeze(1)
print(X_test.shape)

torch.Size([10000, 1, 28, 28])


In [33]:
import torch.utils.data as data

# Hold out 20% of the training images for validation (80/20 split). A fixed-seed
# torch.Generator makes the split identical on every run.
n_train_total = len(train_dataset)
train_set_size = int(n_train_total * 0.8)
valid_set_size = n_train_total - train_set_size

seed = torch.Generator().manual_seed(42)  # despite the name, this is the seeded Generator

split_lengths = [train_set_size, valid_set_size]
train_set, valid_set = data.random_split(train_dataset, split_lengths, generator=seed)

print('train_set_size:{}, valid_set_size:{}'.format(train_set_size, valid_set_size))

train_set_size:48000, valid_set_size:12000


In [34]:
batch_size: int = 100  # samples per mini-batch, shared by the train/valid/test loaders

In [35]:

# Shuffled training loader. drop_last=True keeps every training batch at exactly
# `batch_size` samples.
# - pin_memory only benefits host-to-CUDA copies; on the MPS/CPU device selected above
#   PyTorch warns that it is unsupported and ignores it, so enable it only for CUDA.
# - persistent_workers=True raises a ValueError when num_workers == 0, so tie it to
#   NUM_WORKERS.
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    pin_memory=device.type == 'cuda',
    num_workers=NUM_WORKERS,
    persistent_workers=NUM_WORKERS > 0
)
print('train_loader || batch_size:{}, batch_count:{}'.format(batch_size, len(train_loader)))

train_loader || batch_size:100, batch_count:480


In [36]:
# Validation loader, in a fixed order.
# - drop_last=False: an evaluation loader must see every sample. drop_last=True would
#   silently discard the final partial batch whenever len(valid_set) is not a multiple
#   of batch_size, biasing the reported validation metrics.
# - pin_memory only helps CUDA transfers (ignored with a warning on MPS/CPU).
# - persistent_workers is only legal when worker processes exist.
valid_loader = DataLoader(
    valid_set,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    pin_memory=device.type == 'cuda',
    num_workers=NUM_WORKERS,
    persistent_workers=NUM_WORKERS > 0
)
print('valid_loader || batch_size:{}, batch_count:{}'.format(batch_size, len(valid_loader)))

valid_loader || batch_size:100, batch_count:120


In [37]:
# Test loader, in a fixed order. It is used by both trainer.test and trainer.predict.
# - drop_last=False so every test image contributes to the reported test accuracy.
# - pin_memory only helps CUDA transfers (ignored with a warning on MPS/CPU).
# - persistent_workers keeps workers alive between the test and predict runs; Lightning
#   warns about its absence. It is only legal when NUM_WORKERS > 0.
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    pin_memory=device.type == 'cuda',
    num_workers=NUM_WORKERS,
    persistent_workers=NUM_WORKERS > 0
)
print('test_loader || batch_size:{}, batch_count:{}'.format(batch_size, len(test_loader)))

test_loader || batch_size:100, batch_count:100


In [38]:
from torch import nn

class CNN(nn.Module):
    """Two-stage convolutional classifier for single-channel images.

    Layout (spatial sizes assume 28x28 inputs, which the hard-coded
    ``7*7*64`` flatten size requires):

    * ``l1``: Conv2d(1 -> 32, 3x3, stride 1, pad 1) + ReLU + MaxPool(2) -> 32 x 14 x 14
    * ``l2``: Conv2d(32 -> 64, 3x3, stride 1, pad 1) + ReLU + MaxPool(2) -> 64 x 7 x 7
    * ``l3``: Linear(3136 -> 1024, Xavier-uniform weight) + ReLU + Dropout(0.5)
    * ``fc2``: Linear(1024 -> 10, Xavier-uniform weight), returning raw logits.

    ``fc1`` is the same module object as ``l3[0]``. It is kept as a named
    attribute so the state_dict keys, and therefore existing checkpoints,
    stay unchanged.
    """

    def __init__(self):
        super().__init__()

        # Submodules are created in the same order as before, so under a fixed
        # torch seed the weight initialisation draws random numbers identically.
        self.l1 = self._conv_stage(1, 32)
        self.l2 = self._conv_stage(32, 64)

        self.fc1 = nn.Linear(in_features=7 * 7 * 64, out_features=1024, bias=True)
        nn.init.xavier_uniform_(self.fc1.weight)
        self.l3 = nn.Sequential(self.fc1, nn.ReLU(), nn.Dropout(p=0.5))

        self.fc2 = nn.Linear(in_features=1024, out_features=10)
        nn.init.xavier_uniform_(self.fc2.weight)

        self.flatten = nn.Flatten()

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """3x3 same-padding conv -> ReLU -> 2x2 max-pool (halves height and width)."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        """Return class logits of shape (batch, 10); assumes x is (batch, 1, 28, 28)."""
        feature_maps = self.l2(self.l1(x))
        hidden = self.l3(self.flatten(feature_maps))
        return self.fc2(hidden)

In [108]:
from torchmetrics.classification import MulticlassAccuracy
import lightning as L
import torch.nn.functional as F 

class LitAutoEncoder(L.LightningModule):
    """LightningModule that trains ``encoder`` as a 10-class classifier.

    Despite its name this is not an autoencoder. The wrapped ``encoder`` maps an
    input batch to class logits, which are trained with cross-entropy and Adam.
    The name is kept so existing code and checkpoints keep working.

    Logged metrics:
        * every train / validation / test step: ``loss`` and ``acc`` (accuracy of
          that batch);
        * at the end of each epoch: ``total_training_acc``,
          ``total_validation_acc`` or ``total_test_acc``, computed from the metric
          state accumulated over the whole epoch.

    NOTE(review): all three stages log under the same ``loss`` / ``acc`` keys, so
    they share TensorBoard tags and ``trainer.callback_metrics`` entries. Stage
    prefixes (``train_loss``, ``val_acc``, ...) would separate them, but callbacks
    in this notebook monitor the current names, so the keys are left unchanged.

    Args:
        encoder: module producing logits. Assumed to be (batch, 10) to match
            ``num_classes=10`` below.
        lr: learning rate for Adam.
    """

    def __init__(self, encoder, lr=1e-3):
        super().__init__()
        # Store `lr` in the checkpoint so load_from_checkpoint restores it instead of
        # silently falling back to the default. `encoder` is an nn.Module whose weights
        # already live in the state_dict, so it is excluded from the hyperparameters
        # and must still be passed explicitly when loading.
        self.save_hyperparameters(ignore=["encoder"])
        self.encoder = encoder

        # One metric object per stage, so accumulated state never mixes across stages.
        # MulticlassAccuracy averages per class ("macro") unless `average` is given.
        self.training_metric = MulticlassAccuracy(num_classes=10)
        self.validation_metric = MulticlassAccuracy(num_classes=10)
        self.test_metric = MulticlassAccuracy(num_classes=10)

        self.lr = lr

        self.loss_fn = nn.CrossEntropyLoss()

    def _shared_step(self, batch, metric, prog_bar=False):
        """Forward an (inputs, labels) batch, update `metric`, log and return loss/acc."""
        x_batch, y_batch = batch
        pred = self.encoder(x_batch)
        # Calling the metric returns this batch's accuracy and also accumulates state
        # for the epoch-level compute() in the on_*_epoch_end hooks.
        acc = metric(pred, y_batch)
        loss = self.loss_fn(pred, y_batch)
        self.log_dict({"loss": loss, "acc": acc}, prog_bar=prog_bar)
        return {"loss": loss, "acc": acc}

    def _log_epoch_metric(self, metric, name):
        """Log `metric` accumulated over the epoch under `name`, then reset it."""
        acc = metric.compute()
        self.log(name, acc, prog_bar=True)
        metric.reset()
        return acc

    def training_step(self, batch, batch_idx):
        """Training step; the returned 'loss' is what Lightning backpropagates."""
        return self._shared_step(batch, self.training_metric)

    def on_train_epoch_end(self):
        return self._log_epoch_metric(self.training_metric, 'total_training_acc')

    def validation_step(self, batch, batch_idx):
        """Validation step; per-batch loss/acc are also shown in the progress bar."""
        return self._shared_step(batch, self.validation_metric, prog_bar=True)

    def on_validation_epoch_end(self):
        return self._log_epoch_metric(self.validation_metric, 'total_validation_acc')

    def test_step(self, batch, batch_idx):
        """Test step; logs per-batch loss/acc on the test set."""
        return self._shared_step(batch, self.test_metric)

    def on_test_epoch_end(self):
        return self._log_epoch_metric(self.test_metric, 'total_test_acc')

    def configure_optimizers(self):
        """Adam over all parameters with the learning rate given at construction."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return raw logits (no softmax) for a batch.

        Accepts either an (inputs, labels) batch, such as those from test_loader,
        or a bare input tensor. Labels are ignored.
        """
        x_batch = batch[0] if isinstance(batch, (list, tuple)) else batch
        # Explicitly disable dropout and autograd tracking for inference.
        self.encoder.eval()
        with torch.no_grad():
            pred = self.encoder(x_batch)
        return pred



In [40]:
# Wrap a freshly initialised CNN in the LightningModule (default lr). The Trainer
# handles device placement itself; .to(device) just moves the model there up front.
encoder_net = CNN()
autoencoder = LitAutoEncoder(encoder=encoder_net).to(device)

In [41]:
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

# Early-stopping candidates for the Trainers below.
#
# callback_loss watches 'loss' (minimise).
# NOTE(review): LitAutoEncoder logs 'loss' from the train, validation and test steps
# alike, so the value actually monitored depends on which stage logged last.
# Confirm before enabling it.
callback_loss = EarlyStopping(
    monitor='loss',
    mode='min'
)

# callback_accuracy stops training when validation accuracy stops improving.
# It monitors 'total_validation_acc', which LitAutoEncoder logs once per validation
# epoch from the accumulated metric. The old key 'acc' is also logged by
# training_step, which made the monitored quantity ambiguous.
# patience counts validation checks, so with check_val_every_n_epoch=5 on the
# Trainer, patience=3 allows 15 training epochs without improvement.
callback_accuracy = EarlyStopping(
    monitor='total_validation_acc',
    min_delta=0.00,
    patience=3,
    verbose=False,
    mode="max"
)

In [None]:
# Smoke test: fast_dev_run=True runs a single batch through the loops to catch
# wiring errors quickly. The accelerator follows the device chosen at the top of
# the notebook ("mps" or "cpu") instead of hard-coding "mps", so the notebook also
# runs on machines without an Apple GPU.
trainer = L.Trainer(
    fast_dev_run=True,
    accelerator=device.type,
    devices=1
)

In [None]:
# Quick experiment: train on 10% of the training batches and validate on 1% of the
# validation batches for 10 epochs, in bfloat16 mixed precision. The accelerator
# follows the device chosen at the top of the notebook instead of hard-coding "mps".
trainer = L.Trainer(
    default_root_dir=data_path,
    limit_train_batches=0.1,  # an int (e.g. 10) would mean an absolute batch count
    limit_val_batches=0.01,  # an int (e.g. 5) would mean an absolute batch count
    # callbacks=[callback_loss],
    precision="bf16-mixed",
    max_epochs=10,
    accelerator=device.type,
    devices=1
)


In [42]:
# Full training run: up to 100 epochs with bfloat16 mixed precision. Validation
# runs every 5th epoch (after 2 sanity-check batches), and early stopping is driven
# by callback_accuracy. The accelerator follows the device chosen at the top of the
# notebook instead of hard-coding "mps".
trainer = L.Trainer(
    default_root_dir=data_path,
    num_sanity_val_steps=2,
    check_val_every_n_epoch=5,
    callbacks=[callback_accuracy],
    precision="bf16-mixed",
    accelerator=device.type,
    devices=1,
    max_epochs=100
)


Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [43]:

# Train on train_loader; validation on valid_loader follows the Trainer's schedule.
trainer.fit(
    model=autoencoder,
    train_dataloaders=train_loader,
    val_dataloaders=valid_loader,
)


Missing logger folder: /Users/unchil/PythonProjects/lightning_logs

  | Name              | Type               | Params
---------------------------------------------------------
0 | encoder           | CNN                | 3.2 M 
1 | training_metric   | MulticlassAccuracy | 0     
2 | validation_metric | MulticlassAccuracy | 0     
3 | test_metric       | MulticlassAccuracy | 0     
4 | loss_fn           | CrossEntropyLoss   | 0     
---------------------------------------------------------
3.2 M     Trainable params
0         Non-trainable params
3.2 M     Total params
12.965    Total estimated model params size (MB)


Sanity Checking: |                                                                                            …

Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

In [44]:
# Evaluate the weights currently held by `autoencoder` (as left by fit) on the test set.
test_results = trainer.test(model=autoencoder, dataloaders=test_loader)
test_results

/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'test_dataloader' to speed up the dataloader worker initialization.


Testing: |                                                                                                    …

[{'loss': 0.033184271305799484,
  'acc': 0.9940900206565857,
  'total_test_acc': 0.9943863153457642}]

In [None]:
# Show TensorBoard inline for the runs logged under ./lightning_logs. The Trainers use
# default_root_dir=data_path (the working directory), so their logs are written there.
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/

In [46]:
# Checkpoint written by the full training run. The path is built from data_path
# (the Trainer's default_root_dir) rather than a user-specific absolute path, so it
# resolves to the same file on the original machine and stays portable.
# NOTE: version_0 and epoch=29-step=14400 are specific to that run; after retraining,
# `trainer.checkpoint_callback.best_model_path` can be used to locate the new file.
checkpoint = os.path.join(
    data_path, 'lightning_logs', 'version_0', 'checkpoints', 'epoch=29-step=14400.ckpt'
)

In [106]:
# Rebuild the LightningModule from the checkpoint. LitAutoEncoder.__init__ requires an
# `encoder`, so a fresh CNN is passed and the checkpoint's weights are loaded into it.
model = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=CNN()).to(device)
# Inference-only Trainer on the device chosen at the top of the notebook
# (instead of a hard-coded "mps").
trainer = L.Trainer(accelerator=device.type, devices=1)
# One tensor of raw logits per test batch (predict_step returns the encoder output).
pred = trainer.predict(model, test_loader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Predicting: |                                                                                                 …