In [46]:
from functools import partial
from pathlib import Path
from subprocess import check_call
from typing import Any, Callable, Optional, Tuple, Union

import numpy as np
import torch
from lightly.transforms import SimCLRTransform
from numpy.lib.npyio import NpzFile
from numpy.random import Generator
from PIL import Image
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST, VisionDataset
from torchvision.transforms import Compose, Normalize, ToTensor
from torch.utils.data import DataLoader
from lightly.data import LightlyDataset
from torch import nn
from torch.optim import SGD
from itertools import chain
from models import ConvNet
from lightly.models.modules.heads import SimCLRProjectionHead
from lightly.loss import NTXentLoss
from functools import partial
from pathlib import Path
from torchvision.datasets import MNIST

In [106]:
# Reproducibility: seeds weight initialisation and the shuffled DataLoader order
# so Restart & Run All gives the same training trajectory.
SEED = 0
torch.manual_seed(SEED)

device = "cpu"
dataset_class = partial(MNIST, root=Path("mnist_data"), download=True)

# Two augmented 28x28 views per image for contrastive pretraining.
# Horizontal/vertical flips and rotations are disabled (probability 0) —
# presumably because they can change a digit's identity (e.g. 6 vs 9).
# mean/std are the commonly used MNIST pixel statistics.
train_transform = SimCLRTransform(
    input_size=28,
    min_scale=0.5,
    hf_prob=0,
    rr_prob=0,
    vf_prob=0,
    normalize=dict(mean=[0.1307], std=[0.3081]),
)

# Deterministic preprocessing (no augmentation), same normalization as above.
test_transform = Compose([ToTensor(), Normalize(mean=[0.1307], std=[0.3081])])

train_dataset = dataset_class(train=True)

# Step 1: Wrap the MNIST training split so every sample is returned as two SimCLR views.
train_set_for_training = LightlyDataset.from_torch_dataset(train_dataset, transform=train_transform)

# Step 2: Training DataLoader; drop_last=True keeps every batch at exactly 1,024 image pairs.
train_loader_for_training = DataLoader(train_set_for_training, batch_size=1_024, shuffle=True, drop_last=True)

# Step 3: Encoder (ConvNet); output_size=128 must match the projection head's input_dim.
encoder = ConvNet(input_shape=[28, 28], output_size=128).to(device)

# Step 4: Projection head, only used during pretraining (embeddings are taken from the encoder).
proj_head = SimCLRProjectionHead(input_dim=128, hidden_dim=512).to(device)

# Step 5: NT-Xent contrastive loss between the projections of the two views.
loss_fn = NTXentLoss()

# Step 6: SGD over encoder + projection-head parameters.
# NOTE(review): lr=1 is unusually large for plain SGD — confirm it is intended.
params = chain(encoder.parameters(), proj_head.parameters())
optimizer = SGD(params, lr=1, weight_decay=1e-4)




In [107]:
from pandas import DataFrame

def get_next(dataloader: DataLoader) -> Union[Tensor, Tuple]:
    """Return the next batch from ``dataloader``, restarting it when exhausted.

    A re-iterable source (e.g. a ``DataLoader`` or a list) gets one persistent
    iterator, cached on this function, so successive calls walk through the
    data in order. A new pass (with a fresh shuffle, if the loader shuffles)
    only starts once the previous pass is used up.

    An iterator passed directly is simply advanced. It cannot be restarted,
    so ``StopIteration`` propagates once it is exhausted.

    Args:
        dataloader: A ``DataLoader`` (or any re-iterable object) or an iterator.

    Returns:
        Whatever the underlying iterator yields for one step.

    Raises:
        StopIteration: if an exhausted iterator is passed, or the iterable is empty.
    """
    from collections.abc import Iterator

    if isinstance(dataloader, Iterator):
        return next(dataloader)

    # Keyed by id(). The loader itself is stored next to its iterator so it stays
    # alive and its id cannot be recycled by a different object.
    cache = get_next.__dict__.setdefault("_iterators", {})
    key = id(dataloader)
    entry = cache.get(key)
    if entry is None:
        entry = cache[key] = (dataloader, iter(dataloader))
    try:
        return next(entry[1])
    except StopIteration:
        # The current pass is finished: start a new one.
        entry = cache[key] = (dataloader, iter(dataloader))
        return next(entry[1])

def save_table(
    path: Union[Path, str],
    table: DataFrame,
    formatting: Optional[dict] = None,
) -> None:
    """Write ``table`` to a CSV file after applying per-column formatters.

    Formatting is applied to a copy, so the caller's DataFrame keeps its
    original (e.g. numeric) column values.

    Args:
        path: Destination CSV file. Missing parent directories are created.
        table: Data to write; its index is not written.
        formatting: Optional mapping of column name -> callable applied
            element-wise to that column (e.g. ``"{:.4f}".format``). Entries for
            columns not present in ``table`` are ignored.
    """
    formatted = table.copy()
    for column, formatter in (formatting or {}).items():
        if column in formatted:
            formatted[column] = formatted[column].apply(formatter)

    Path(path).parent.mkdir(parents=True, exist_ok=True)
    formatted.to_csv(path, index=False)

### Train the Model

In [108]:
import pandas as pd

steps, losses = [], []
n_optim_steps = 10  # number of SGD updates to run in this cell

# Ensure dropout / batch-norm layers (if the encoder or head has any) are in
# training mode, even if an earlier cell switched them to eval().
encoder.train()
proj_head.train()

for step in range(n_optim_steps):
    # Each batch unpacks as ((view_0, view_1), _, _): two augmented views per
    # image from SimCLRTransform; the other two fields are unused here.
    (inputs_0, inputs_1), _, _ = get_next(train_loader_for_training)

    embeddings_0 = proj_head(encoder(inputs_0.to(device)))
    embeddings_1 = proj_head(encoder(inputs_1.to(device)))

    loss = loss_fn(embeddings_0, embeddings_1)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Convert to a plain float once; used for both the log and the printout.
    loss_value = loss.item()
    steps.append(step)
    losses.append(loss_value)
    print(f"Step {step}: loss = {loss_value}")

Step 0: loss = 7.496316432952881
Step 1: loss = 7.665528774261475


KeyboardInterrupt: 

In [61]:
print("Saving results...")

# Neither DataFrame.to_csv nor torch.save creates missing folders.
output_dir = Path("train_encoders")
output_dir.mkdir(parents=True, exist_ok=True)

train_log = pd.DataFrame({"step": steps, "loss": losses})
formatting = {"loss": "{:.4f}".format}
save_table(output_dir / "pretraining.csv", train_log, formatting)

torch.save(encoder.state_dict(), output_dir / "encoder.pth")
torch.save(proj_head.state_dict(), output_dir / "projection_head.pth")

Saving results...


### Compute Embeddings

In [116]:
import math
@torch.inference_mode()
def encode(loader: DataLoader, encoder: nn.Module, device: str, log_every: int = 10) -> np.ndarray:
    """Run ``encoder`` over every batch in ``loader`` and stack the outputs.

    The encoder is switched to ``eval()`` for the pass, so dropout /
    batch-norm layers (if it has any) behave deterministically; note that
    ``inference_mode`` alone does not do this. Its previous train/eval mode is
    restored afterwards, even if encoding fails.

    Args:
        loader: Iterable of ``(images, labels)`` batches; labels are ignored.
        encoder: Module mapping a batch of images to a batch of embeddings.
        device: Device the images are moved to before encoding.
        log_every: Print a progress line every ``log_every`` batches
            (0 or negative disables progress output).

    Returns:
        The per-batch encoder outputs concatenated along axis 0, as NumPy.
    """
    was_training = encoder.training
    encoder.eval()
    try:
        embeddings = []
        for idx, (images_i, _) in enumerate(loader):
            if log_every > 0 and idx % log_every == 0:
                print(f"encoding step {idx}")
            embeddings_i = encoder(images_i.to(device))
            embeddings.append(embeddings_i.cpu().numpy())
    finally:
        encoder.train(was_training)
    return np.concatenate(embeddings)

# Non-augmented datasets for embedding extraction; `test_transform`
# (ToTensor + Normalize) is defined in the setup cell.
train_set_for_embeddings = dataset_class(train=True, transform=test_transform)
test_set_for_embeddings = dataset_class(train=False, transform=test_transform)

# Unshuffled loaders, so embedding row i lines up with dataset.targets[i].
train_loader_for_embeddings = DataLoader(train_set_for_embeddings, batch_size=1_024)
test_loader_for_embeddings = DataLoader(test_set_for_embeddings, batch_size=1_024)

# GitHub rejects files over 100MB, i.e. 2.5e7 32-bit floats (assuming float32
# embeddings); 2e7 leaves some headroom.
max_array_size = int(2e7)

# np.save does not create missing directories.
embeddings_dir = Path("embeddings")
embeddings_dir.mkdir(parents=True, exist_ok=True)

subsets = {
    "train": (train_set_for_embeddings, train_loader_for_embeddings),
    "test": (test_set_for_embeddings, test_loader_for_embeddings),
}

for subset, (saved_dataset, loader) in subsets.items():
    print(f"\n********** On subset {subset}")

    embeddings = encode(loader, encoder, device)

    if embeddings.size > max_array_size:
        # Split row-wise so every file stays under the size limit.
        n_splits = math.ceil(embeddings.size / max_array_size)
        for i, embeddings_i in enumerate(np.array_split(embeddings, n_splits, axis=0)):
            filepath = embeddings_dir / f"simclr_{subset}_part{i + 1}of{n_splits}.npy"
            np.save(filepath, embeddings_i, allow_pickle=False)
    else:
        np.save(embeddings_dir / f"simclr_{subset}.npy", embeddings, allow_pickle=False)

    # Labels, in the same order as the embedding rows.
    np.save(embeddings_dir / f"labels_{subset}.npy", np.asarray(saved_dataset.targets), allow_pickle=False)


********** On subset train
encoding step 0
encoding step 10
encoding step 20
encoding step 30
encoding step 40
encoding step 50

********** On subset test
encoding step 0
