# Data

In [24]:
import torchvision.transforms as T
import torchvision
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
import random

download_root = './MNIST_DATASET'

mnist_transform = T.Compose([
    T.ToTensor(),
])

train_dataset = torchvision.datasets.MNIST(download_root, transform=mnist_transform, train=True, download=False)
test_dataset = torchvision.datasets.MNIST(download_root, transform=mnist_transform, train=False, download=False) 

total_size = len(train_dataset)
batch_size = 32
batch_num = total_size // batch_size
labeled_batch_idx = random.choices(range(batch_num), k = int(batch_num * 0.1))
batch_idx = -1
def collate_fn(batch):
    global batch_idx
    images, labels = zip(*batch)  # 배치 내 샘플을 이미지와 라벨로 분리

    batch_idx += 1
    if batch_idx in labeled_batch_idx:
        labels = torch.tensor(labels)  # 라벨 리스트를 배치 형태로 쌓음
    else:
        labels = None

    images = torch.stack(images, dim=0)  # 이미지 리스트를 배치 형태로 쌓음
    return images, labels

valid_indices = random.sample(list(range(total_size)), int(total_size * 0.1))

train_dataloader = DataLoader(dataset = train_dataset, batch_size = batch_size, shuffle= True)
shuffled_train_dataloader = DataLoader(dataset = train_dataset, batch_size = batch_size, shuffle= True, collate_fn = collate_fn)
valid_dataloader = DataLoader(dataset = train_dataset, batch_size = batch_size, sampler = SubsetRandomSampler(valid_indices))
test_dataloader = DataLoader(test_dataset, batch_size = batch_size, shuffle = False)

# Model

In [6]:
import torch
import torch.nn as nn
from  pytorch_lightning import LightningModule 
import torch.optim as optim
import torchmetrics

In [34]:
class AutoEncoderLinear(LightningModule):
    def __init__(self, symetric_dimensions, lr):
        super().__init__()

        self.lr = lr
        self.criterion = nn.MSELoss()

        encoder_layer = []
        decoder_layer = []

        self.active = nn.ReLU()
        self.sig_act = nn.Sigmoid()
        
        for idx, dim in enumerate(symetric_dimensions[1:]):
            input = symetric_dimensions[idx]
            fc_layer = nn.Linear(input, dim)
            encoder_layer.append(fc_layer)
            encoder_layer.append(self.active)

        reverse_symetric_dimensions = symetric_dimensions[::-1]
        for idx, dim in enumerate(reverse_symetric_dimensions[1:]):
            input = reverse_symetric_dimensions[idx]
            fc_layer = nn.Linear(input, dim)
            decoder_layer.append(fc_layer)

            # 마지막 레이어는 sigmoid(0~1 값으로 출력하기 위해서)
            if idx == len(reverse_symetric_dimensions) - 2:
                decoder_layer.append(self.sig_act)
            else:
                decoder_layer.append(self.active)

        self.encoder = nn.Sequential(*encoder_layer)
        self.decoder = nn.Sequential(*decoder_layer)

    def forward(self, x) :
        '''
        INPUT
            x : [batch_size, 1, 28, 28]
        OUTPUT
            out : [batch_size, 28 * 28]
        '''
        x = x.view(x.size(0), -1) # [batch_size, 28 * 28]
        out = self.encoder(x)
        out = self.decoder(out)

        return out


    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr = self.lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
        return [optimizer], [scheduler]

    def training_step(self, batch):
        data, label = batch
        
        y = data.view(data.size(0), -1)
        y_hat = self(data)

        loss = self.criterion(y_hat, y)

        self.log(f"train_loss", loss, on_step=True, on_epoch=True, logger=True)

        return loss

    def validation_step(self, batch):
        data, label = batch
        
        y = data.view(data.size(0), -1)
        y_hat = self(data)

        loss = self.criterion(y_hat, y)

        self.log(f"valid_loss", loss, on_step=True, on_epoch=True, logger=True)

        return loss

    def test_step(self, batch):
        data, label = batch
        
        y = data.view(data.size(0), -1)
        y_hat = self(data)

        loss = self.criterion(y_hat, y)

        self.log(f"test_loss", loss, on_step=True, on_epoch=True, logger=True)

        return loss

    def predict_step(self, x):
        x = x.view(x.size(0), -1) 
        latent_vec = self.encoder(x)

        return latent_vec # [batch_size, 28]

# Train & Test

In [35]:
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer

In [36]:

symetric_dimensions = [784, 112, 56, 28]
auto_encoder_model = AutoEncoderLinear(symetric_dimensions=symetric_dimensions, lr = 0.001)

early_stopping = EarlyStopping(monitor='valid_loss', mode='min', patience=5)
lr_monitor = LearningRateMonitor(logging_interval='epoch')
wandb_logger = WandbLogger(name = 'AutoEncoderLinear')

trainer = Trainer(
    max_epochs = 50,
    accelerator = 'auto',
    callbacks= [ early_stopping, lr_monitor],
    logger = wandb_logger
)

trainer.fit(auto_encoder_model, 
            train_dataloader,
            valid_dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/anaconda3/lib/python3.11/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.

  | Name      | Type       | Params | Mode 
-------------------------------------------------
0 | criterion | MSELoss    | 0      | train
1 | active    | ReLU       | 0      | train
2 | sig_act   | Sigmoid    | 0      | train
3 | encoder   | Sequential | 95.8 K | train
4 | decoder   | Sequential | 96.6 K | train
-------------------------------------------------
192 K     Trainable params
0         Non-trainable params
192 K     Total params
0.770     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/anaconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


In [37]:
trainer.test(auto_encoder_model, test_dataloader)

Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss_epoch': 0.0157100111246109}]

# Classifier

In [51]:
class AutoEnocderClassifier(LightningModule):
    def __init__(self, auto_encoder, num_classes = 10, lr = 0.001):
        super().__init__()

        self.lr = lr

        self.criterion = nn.CrossEntropyLoss()
        self.acc = torchmetrics.Accuracy(task='multiclass', num_classes = num_classes)

        self.auto_encoder = auto_encoder
        self.fc = nn.Linear(28, num_classes)

        self.batch_counter = 0
        self.epoch_counter = 0
        
    def forward(self, x): 
        '''
        INPUT
            x : [batch_size, 1 , 28, 28]
        OUTPUT
            out : [batch_size, 10]
        '''
        latent = self.auto_encoder.predict_step(x) # [batch_size, 28]
        out = self.fc(latent)
        
        return out

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr= self.lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer = optimizer, step_size = 5, gamma = 0.5)
        
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        self.batch_counter += 1

        if self.batch_counter % 32 == 0:
            self.epoch_counter += 1

        image, label =  batch

        if label is None:
            if self.epoch_counter % 3 == 0:
                label = self.predict_step(batch)
            else:
                return None

        y_predict = self(image) # [batch_size, 10]

        loss = self.criterion(y_predict, label)

        _, pred = torch.max(y_predict, dim = 1)
        acc = self.acc(pred, label)

        self.log("train loss", loss, on_step = True, on_epoch = True,  logger = True)
        self.log("train Accuracy", acc, on_step = True, on_epoch = True, logger = True)

        return loss
    
    def validation_step(self, batch):
        image, label =  batch

        y_predict = self(image) # [batch_size, 10]

        loss = self.criterion(y_predict, label)

        _, pred = torch.max(y_predict, dim = 1)
        acc = self.acc(pred, label)

        self.log("valid_loss", loss, on_step = True, on_epoch = True,  logger = True)
        self.log("valid_acc", acc, on_step = True, on_epoch = True, logger = True)

        return loss
        
    def test_step(self, batch):
        image, label =  batch

        y_predict = self(image) # [batch_size, 10]

        loss = self.criterion(y_predict, label)

        _, pred = torch.max(y_predict, dim = 1)
        acc = self.acc(pred, label)

        self.log("test loss", loss, on_step = True, on_epoch = True,  logger = True)
        self.log("test Accuracy", acc, on_step = True, on_epoch = True, logger = True)

        return acc

    def predict_step(self, batch):
        data, label = batch

        y_predict = self(data) # [batch_size, 10]

        _, pred = torch.max(y_predict, dim = 1)
        
        return pred

In [52]:
auto_encoder_classifier = AutoEnocderClassifier(auto_encoder = auto_encoder_model)
wandb_logger = WandbLogger(name = 'AutoEncoder Classifier')

trainer = Trainer(
    max_epochs = 30,
    accelerator = 'auto',
    callbacks = [lr_monitor],
    logger = wandb_logger
)

trainer.fit(auto_encoder_classifier, shuffled_train_dataloader, valid_dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/anaconda3/lib/python3.11/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/opt/anaconda3/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:652: Checkpoint directory ./lightning_logs/5m84csym/checkpoints exists and is not empty.

  | Name         | Type               | Params | Mode 
------------------------------------------------------------
0 | criterion    | CrossEntropyLoss   | 0      | train
1 | acc          | MulticlassAccuracy | 0      | train
2 | auto_encoder | AutoEncoderLinear  | 192 K  | eval 
3 | fc           | Linear             | 290    | train
------------------------------------------------------------
192 K     Trainable params
0     

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/anaconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/opt/anaconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


In [50]:
trainer.test(auto_encoder_classifier, test_dataloader)

/opt/anaconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test loss_epoch': 0.13764595985412598,
  'test Accuracy_epoch': 0.9638000726699829}]