In [1]:
import os

# Switch the working directory to the project root before the project imports
# below, then list its contents as a quick sanity check.
os.chdir('/app/dpl/')
print(os.listdir())

from data.dataset import MNIST, TorchDataset
from torch.utils.data import DataLoader

# Load MNIST and call drop_samples on each split (0.5 for train, 0.8 for test)
# before binding the splits to their own names.
mnist = MNIST()
mnist.train.drop_samples(0.5)
mnist.test.drop_samples(0.8)
mnist_train_set, mnist_test_set = mnist.train, mnist.test

# Training is batched in groups of 20; the test loader yields one sample at a time.
mnist_train_dataloader = DataLoader(dataset=mnist_train_set, batch_size=20)
mnist_test_dataloader = DataLoader(dataset=mnist_test_set, batch_size=1)

['snapshot', 'data', 'results.ipynb', 'logs', 'neural_baseline', 'dpl']


In [2]:
import numpy as np
import torch

from neural_baseline.lenet import LeNet, train, run_test

# Pick the compute backend: CUDA if present, otherwise Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

seed = 42

# Seed torch's global RNG and a standalone generator with the same value.
torch.manual_seed(seed)
g = torch.Generator().manual_seed(seed)

# LeNet model (n=10), cross-entropy objective, Adam over the model's parameters.
encoder = LeNet(n=10)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

num_epochs = 10
current_epoch = 0

# One train pass followed by one test pass per epoch; both calls receive the
# same 'logs/tmp.csv' path. `current_epoch` ends equal to `num_epochs`.
while current_epoch < num_epochs:
    print(f"Epoch {current_epoch + 1}\n-------------------------------")
    train(encoder, mnist_train_dataloader, loss_fn, optimizer, device, current_epoch, r'logs/tmp.csv')
    run_test(encoder, mnist_test_dataloader, loss_fn, device, current_epoch, r'logs/tmp.csv')
    current_epoch += 1
print("Done!")

Using cpu device
Epoch 1
-------------------------------
loss: 2.292399  [   10/ 1500]
loss: 2.251404  [   20/ 1500]
loss: 2.174598  [   30/ 1500]
loss: 2.100462  [   40/ 1500]
loss: 1.972739  [   50/ 1500]
loss: 1.857429  [   60/ 1500]
loss: 1.816437  [   70/ 1500]
loss: 1.767038  [   80/ 1500]
loss: 1.782081  [   90/ 1500]
loss: 1.748628  [  100/ 1500]
loss: 1.728826  [  110/ 1500]
loss: 1.727902  [  120/ 1500]
loss: 1.737071  [  130/ 1500]
loss: 1.674755  [  140/ 1500]
loss: 1.673000  [  150/ 1500]
loss: 1.665255  [  160/ 1500]
loss: 1.669287  [  170/ 1500]
loss: 1.642205  [  180/ 1500]
loss: 1.632200  [  190/ 1500]
loss: 1.612558  [  200/ 1500]
loss: 1.619659  [  210/ 1500]
loss: 1.602475  [  220/ 1500]
loss: 1.602597  [  230/ 1500]
loss: 1.612812  [  240/ 1500]
loss: 1.628534  [  250/ 1500]
loss: 1.598486  [  260/ 1500]
loss: 1.618499  [  270/ 1500]
loss: 1.598336  [  280/ 1500]
loss: 1.590416  [  290/ 1500]
loss: 1.627841  [  300/ 1500]
loss: 1.600454  [  310/ 1500]
loss: 1.59531

In [3]:
from data.dataset import TorchDataset

# Build the dataset from the two metadata CSVs (train file first, test file second).
mnist_sub = TorchDataset(
    r'data/metadata/full_trainset.csv',
    r'data/metadata/test_80.csv',
)

# Same batching scheme as the MNIST loaders: 20 per training batch, 1 per test batch.
mnist_sub_train_dataloader = DataLoader(dataset=mnist_sub.train_set, batch_size=20)
mnist_sub_test_dataloader = DataLoader(dataset=mnist_sub.test_set, batch_size=1)

In [4]:
# NOTE(review): this wildcard import may rebind names such as `train` / `run_test`
# that an earlier cell imported from neural_baseline.lenet -- confirm which
# implementations are meant to run below and import them explicitly.
from neural_baseline.mlp import *

# Re-seed so this run does not depend on how much randomness earlier cells consumed.
torch.manual_seed(seed)
g = torch.Generator()
g.manual_seed(seed)

# Build the MLP from the model currently bound to `encoder` (the LeNet from the
# earlier training cell), passed as its first argument, with n=19.
# NOTE: this rebinds `encoder`, so re-running this cell on its own would pass the
# MLP back into MLP; re-run the LeNet cell first.
encoder = MLP(encoder, n=19)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

num_epochs = 10
current_epoch = 0

for _ in range(num_epochs):
    print(f"Epoch {current_epoch + 1}\n-------------------------------")
    train(encoder, mnist_sub_train_dataloader, loss_fn, optimizer, device, current_epoch, r'logs/MLP_train_full.csv')
    # Evaluate on the held-out test loader. The training loader was passed here
    # before, so the values written to MLP_test_full.csv were training-set metrics.
    run_test(encoder, mnist_sub_test_dataloader, loss_fn, device, current_epoch, r'logs/MLP_test_full.csv')
    current_epoch += 1
print("Done!")

Epoch 1
-------------------------------
loss: 2.948983  [   10/ 2999]
loss: 2.929049  [   20/ 2999]
loss: 2.944313  [   30/ 2999]
loss: 2.881559  [   40/ 2999]
loss: 2.888810  [   50/ 2999]
loss: 2.900779  [   60/ 2999]
loss: 2.874713  [   70/ 2999]
loss: 2.857065  [   80/ 2999]
loss: 2.849841  [   90/ 2999]
loss: 2.849207  [  100/ 2999]
loss: 2.818681  [  110/ 2999]
loss: 2.834379  [  120/ 2999]
loss: 2.813931  [  130/ 2999]
loss: 2.762838  [  140/ 2999]
loss: 2.788597  [  150/ 2999]
loss: 2.764090  [  160/ 2999]
loss: 2.760105  [  170/ 2999]
loss: 2.745203  [  180/ 2999]
loss: 2.723538  [  190/ 2999]
loss: 2.690132  [  200/ 2999]
loss: 2.738022  [  210/ 2999]
loss: 2.722848  [  220/ 2999]
loss: 2.700193  [  230/ 2999]
loss: 2.682294  [  240/ 2999]
loss: 2.673448  [  250/ 2999]
loss: 2.642087  [  260/ 2999]
loss: 2.649398  [  270/ 2999]
loss: 2.620746  [  280/ 2999]
loss: 2.609026  [  290/ 2999]
loss: 2.583212  [  300/ 2999]
loss: 2.530556  [  310/ 2999]
loss: 2.574799  [  320/ 2999]
