In [46]:
from functools import partial
from pathlib import Path
from subprocess import check_call
from typing import Any, Callable, Optional, Tuple, Union

import numpy as np
import torch
from lightly.transforms import SimCLRTransform
from numpy.lib.npyio import NpzFile
from numpy.random import Generator
from PIL import Image
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST, VisionDataset
from torchvision.transforms import Compose, Normalize, ToTensor
from torch.utils.data import DataLoader
from lightly.data import LightlyDataset
from torch import nn
from torch.optim import SGD
from itertools import chain
from models import ConvNet
from lightly.models.modules.heads import SimCLRProjectionHead
from lightly.loss import NTXentLoss
from functools import partial
from pathlib import Path
from torchvision.datasets import MNIST

In [47]:
# --- Configuration -----------------------------------------------------------
device = "cpu"
SEED = 0  # makes weight init, loader shuffling and augmentation sampling repeatable
MNIST_MEAN, MNIST_STD = [0.1307], [0.3081]  # shared by the train and test transforms
BATCH_SIZE = 1_024
EMBEDDING_DIM = 128  # encoder output size, which is also the projection head's input size

# Seed before anything random is created: torch drives weight init, DataLoader
# shuffling and the torchvision-based augmentations. numpy's global RNG is seeded
# too in case some lightly augmentations draw from it (TODO confirm).
torch.manual_seed(SEED)
np.random.seed(SEED)

dataset = partial(MNIST, root=Path("mnist_data"), download=True)
train_set = dataset(train=True)  # Loads the training set
test_set = dataset(train=False)  # Loads the test set

# Flip/rotation probabilities are 0, presumably because such augmentations can
# change a digit's identity (e.g. 6 vs 9).
train_transform = SimCLRTransform(
    input_size=28,
    min_scale=0.5,
    hf_prob=0,
    rr_prob=0,
    vf_prob=0,
    normalize=dict(mean=MNIST_MEAN, std=MNIST_STD),
)

test_transform = Compose([ToTensor(), Normalize(mean=MNIST_MEAN, std=MNIST_STD)])

# Step 1: Transform the train and test sets with the respective transformations
train_set = LightlyDataset.from_torch_dataset(train_set, transform=train_transform)
test_set = LightlyDataset.from_torch_dataset(test_set, transform=test_transform)

# Step 2: Create DataLoader for both train and test sets.
# drop_last keeps every training batch full-sized, so the contrastive loss always
# sees the same number of in-batch negatives.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

# Step 3: Define the encoder (ConvNet)
encoder = ConvNet(input_shape=[28, 28], output_size=EMBEDDING_DIM)
encoder = encoder.to(device)

# Step 4: Define the projection head (SimCLRProjectionHead); its input width must
# match the encoder's output width.
proj_head = SimCLRProjectionHead(input_dim=EMBEDDING_DIM, hidden_dim=512)
proj_head = proj_head.to(device)

# Step 5: Define the loss function (NTXentLoss)
loss_fn = NTXentLoss()

# Step 6: Set up the optimizer over the encoder and projection head jointly
params = chain(encoder.parameters(), proj_head.parameters())
optimizer = SGD(params, lr=1, weight_decay=1e-4)


In [48]:
from pandas import DataFrame

# Live iterators used by get_next, keyed by id() of the iterable they came from.
# The iterable is stored next to its iterator so it stays alive and its id()
# cannot be recycled by a different object while the entry exists.
_active_iterators: dict = {}


def get_next(dataloader: DataLoader) -> Union[Tensor, Tuple]:
    """Return the next batch from ``dataloader``, cycling through it indefinitely.

    A ``DataLoader`` is an iterable, not an iterator, so ``next()`` cannot be
    called on it directly. This helper keeps one iterator per loader across calls
    and starts a new pass over the loader (a new epoch, reshuffled when the loader
    was built with ``shuffle=True``) once the current pass is exhausted.

    Args:
        dataloader: A ``DataLoader`` or any other re-iterable object. If an
            iterator is passed instead, ``next()`` is simply called on it.

    Returns:
        Whatever the loader yields for one batch.

    Raises:
        StopIteration: If ``dataloader`` yields no items at all, or if it is an
            iterator that is already exhausted (iterators cannot be restarted).
    """
    # Plain iterators (e.g. the result of ``iter(loader)``) cannot be rewound.
    if hasattr(dataloader, "__next__"):
        return next(dataloader)

    key = id(dataloader)
    entry = _active_iterators.get(key)
    if entry is None or entry[0] is not dataloader:
        entry = (dataloader, iter(dataloader))
        _active_iterators[key] = entry

    try:
        return next(entry[1])
    except StopIteration:
        # Current pass finished: begin a new one from the start of the loader.
        restarted = iter(dataloader)
        _active_iterators[key] = (dataloader, restarted)
        return next(restarted)

def save_table(path: Union[Path, str], table: DataFrame, formatting: dict) -> None:
    """Write ``table`` to ``path`` as CSV (without the index), applying per-column formatters.

    Formatting is applied to a copy, so the caller's frame is left untouched and its
    numeric columns are not silently turned into strings by saving.

    Args:
        path: Destination CSV file; its parent directory must already exist.
        table: Frame to save.
        formatting: Maps column names to callables applied element-wise to that
            column (e.g. ``"{:.4f}".format``). Keys naming columns absent from
            ``table`` are ignored.
    """
    formatted = table.copy()
    for column, formatter in formatting.items():
        if column in formatted:
            formatted[column] = formatted[column].apply(formatter)

    formatted.to_csv(path, index=False)

### Train the Model

In [53]:
import pandas as pd

# Per-step log, turned into a table when the results are saved.
steps, losses = [], []
n_optim_steps = 10

for step in range(n_optim_steps):
    # The first field holds the two augmented views of the same images; the other
    # two fields (presumably labels and filenames) are not needed for SimCLR.
    (inputs_0, inputs_1), _, _ = get_next(train_loader)

    # Encode each view, then map it into the space the contrastive loss operates in.
    embeddings_0, embeddings_1 = (
        proj_head(encoder(view.to(device))) for view in (inputs_0, inputs_1)
    )

    loss = loss_fn(embeddings_0, embeddings_1)

    # Clear stale gradients, backpropagate, and take one SGD step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    steps.append(step)
    losses.append(loss_value)
    print(f"Step {step}: loss = {loss_value}")

Step 0: loss = 7.007874011993408
Step 1: loss = 6.918391704559326
Step 2: loss = 7.011937141418457
Step 3: loss = 7.296550750732422
Step 4: loss = 7.077115535736084
Step 5: loss = 6.87874698638916
Step 6: loss = 6.914821624755859
Step 7: loss = 6.848741054534912
Step 8: loss = 6.859347820281982
Step 9: loss = 6.8735151290893555


In [61]:
print("Saving results...")

# Every artifact of this pre-training run goes into one directory. Create it up
# front: neither DataFrame.to_csv nor torch.save creates missing parent directories,
# so a fresh checkout would otherwise fail here.
output_dir = Path("train_encoders")
output_dir.mkdir(parents=True, exist_ok=True)

train_log = pd.DataFrame({"step": steps, "loss": losses})
formatting = {"loss": "{:.4f}".format}  # losses are written rounded to 4 decimals
save_table(output_dir / "pretraining.csv", train_log, formatting)

torch.save(encoder.state_dict(), output_dir / "encoder.pth")
torch.save(proj_head.state_dict(), output_dir / "projection_head.pth")

Saving results...
