In [1]:
# %%capture
# ! pip install pytorch-lightning
# ! pip install pytorch-lightning-bolts

In [2]:
# !pip install torchvision

In [3]:
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from torchmetrics.functional import accuracy
from pl_bolts.datasets import DummyDataset

In [4]:
# Training loader over a synthetic pl_bolts DummyDataset built from the shapes
# (1, 28, 28) and (1,); batches of 32. The dataset is constructed inline so
# `train` is only ever bound to the DataLoader.
train = DataLoader(
    DummyDataset((1, 28, 28), (1,)),
    batch_size=32,
)

In [5]:
# Validation loader: same synthetic DummyDataset configuration as the other
# splits (shapes (1, 28, 28) and (1,)), served in batches of 32.
val = DataLoader(
    DummyDataset((1, 28, 28), (1,)),
    batch_size=32,
)

In [6]:
# Test loader: a separate DummyDataset instance with shapes (1, 28, 28) and
# (1,), served in batches of 32.
test = DataLoader(
    DummyDataset((1, 28, 28), (1,)),
    batch_size=32,
)

In [7]:
class LitAutoEncoder(pl.LightningModule):
    """Fully connected autoencoder trained with a mean-squared reconstruction loss.

    The encoder maps a flattened 28 * 28 input to a 3-dimensional code
    (784 -> 128 -> 3) and the decoder mirrors it (3 -> 128 -> 784).
    Every step method expects ``batch`` to unpack into exactly two items
    ``(x, y)``; only ``x`` is used.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def _reconstruction_loss(self, batch):
        """Flatten the inputs, encode/decode them, and return the MSE to the inputs.

        Shared by the train/validation/test steps so the three stages cannot
        drift apart.

        Args:
            batch: A two-item sequence ``(x, y)``; ``y`` is ignored.

        Returns:
            The ``F.mse_loss`` between the reconstruction and the flattened ``x``.
        """
        inputs, _ = batch
        # Collapse everything after the batch dimension so the input matches
        # the encoder's first Linear layer (28 * 28 features).
        flat = inputs.view(inputs.size(0), -1)
        reconstruction = self.decoder(self.encoder(flat))
        return F.mse_loss(reconstruction, flat)

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        """Compute, log (as ``'train_loss'``) and return the reconstruction loss.

        ``batch_idx`` and ``optimizer_idx`` are accepted but unused here;
        ``optimizer_idx`` is presumably supplied by Lightning because
        ``configure_optimizers`` returns more than one optimizer.
        """
        loss = self._reconstruction_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the reconstruction loss as ``'val_loss'``; returns nothing."""
        self.log('val_loss', self._reconstruction_loss(batch))

    def test_step(self, batch, batch_idx):
        """Log the reconstruction loss as ``'test_loss'``; returns nothing."""
        self.log('test_loss', self._reconstruction_loss(batch))

    def configure_optimizers(self):
        """Return two Adam optimizers (lr=1e-3), each over all module parameters.

        NOTE(review): both optimizers update the *same* parameters with the
        same settings. This looks intentional (exercising the multi-optimizer
        path), but if not, a single optimizer is sufficient — confirm.
        """
        optimizer_1 = torch.optim.Adam(self.parameters(), lr=1e-3)
        optimizer_2 = torch.optim.Adam(self.parameters(), lr=1e-3)
        return [optimizer_1, optimizer_2]

In [11]:
# Seed Python/NumPy/PyTorch RNGs so weight initialisation is repeatable
# across Restart & Run All.
pl.seed_everything(42, workers=True)

# init model
ae = LitAutoEncoder()

# Initialize a trainer.
# - `accelerator`/`devices` replace the older `gpus=1` flag and fall back to
#   CPU when no CUDA device is present instead of raising.
# - The progress-bar refresh rate is configured through the TQDMProgressBar
#   callback; the `progress_bar_refresh_rate` Trainer argument is deprecated.
trainer = pl.Trainer(
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    max_epochs=3,
    callbacks=[pl.callbacks.TQDMProgressBar(refresh_rate=20)],
)

# Train the model ⚡
trainer.fit(ae, train, val)

  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5]

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 100 K 
1 | decoder | Sequential | 101 K 
---------------------------------------
202 K     Trainable params
0         Non-trainable params
202 K     Total params
0.810     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [9]:
# Evaluate on the test loader. No model is passed, so the weights are restored
# from a checkpoint written during `trainer.fit`; naming "best" explicitly
# states which checkpoint is used instead of relying on the implicit default
# (which emits a warning).
trainer.test(dataloaders=test, ckpt_path="best")

  rank_zero_warn(
Restoring states from the checkpoint path at /data/users/cting3/CodeNest/code-style-probing/lightning_logs/version_1/checkpoints/epoch=2-step=1878.ckpt
Loaded model weights from checkpoint at /data/users/cting3/CodeNest/code-style-probing/lightning_logs/version_1/checkpoints/epoch=2-step=1878.ckpt
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.08358001708984375
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.08358001708984375}]

In [10]:
# Start TensorBoard inside the notebook: load the IPython extension, then
# point it at the relative `lightning_logs/` directory (the directory under
# which the checkpoint paths printed by `trainer.test` are located).
%load_ext tensorboard
%tensorboard --logdir lightning_logs/