In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch.nn as nn
from torch.utils.data import DataLoader

from contrastive_fs.data import LabeledDataset
from contrastive_fs.models import CFS_SG
from contrastive_fs.models import CFS_SG

In [None]:
# Mini-batch size for the training DataLoader.
batch_size = 128

# Load the background and target groups from disk (in that order), casting each to float32.
background, target = (
    np.load(os.path.join("data", "Grassy_MNIST", file_name)).astype(np.float32)
    for file_name in ("background.npy", "target.npy")
)

In [None]:
# Stack both groups into one training set. Each label records which group a row
# came from: 0.0 for background rows, 1.0 for target rows.
data_train = np.concatenate([background, target])
labels_train = np.repeat([0.0, 1.0], [len(background), len(target)])

In [None]:
# Build the dataset and the CFS_SG model, then train it. Input and output sizes
# are both the number of features per sample.
input_size = data_train.shape[1]
output_size = data_train.shape[1]
dataset = LabeledDataset(data_train, labels_train)

# Seed Python, NumPy and torch RNGs so weight initialization and the shuffled
# batch order are reproducible across Restart & Run All.
pl.seed_everything(0)

model = CFS_SG(
    input_size=input_size,
    output_size=output_size,
    hidden=[512, 512],  # Number of units in each hidden layer
    k_prime=20,  # Background dimension size
    lam=0.175,  # Tuned to select about 10 features
    lr=1e-3,
    loss_fn=nn.MSELoss(),
)

loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# "auto" picks an available accelerator (e.g. a GPU) and falls back to CPU
# otherwise, so the notebook still runs end to end on machines without CUDA.
trainer = pl.Trainer(max_epochs=10, accelerator="auto", devices=1)
trainer.fit(model, loader)

In [None]:
# Get the indices of the top 10 most strongly selected features from the trained
# model (lam in the training cell was tuned to select about 10 features).
# NOTE(review): the plotting cell uses these as NumPy fancy indices into a
# flattened 28x28 image, so this assumes get_inds returns integer positions
# usable on CPU — confirm against contrastive_fs.models.CFS_SG.get_inds.
indices = model.get_inds(10)

In [None]:
# Show the selected features as a pixel mask next to one example image from each
# group. Samples are flattened 28x28 Grassy MNIST images, so every selected
# feature index is a pixel position.
image_shape = (28, 28)

fig, axes = plt.subplots(1, 3, figsize=(15, 10))

# Mask over the flattened image: 1 at each selected pixel, 0 elsewhere.
blank_image = np.zeros(int(np.prod(image_shape)))
blank_image[indices] = 1

# (image, colormap, title) per panel; cmap=None keeps matplotlib's default
# colormap for the mask, matching the original figure.
panels = [
    (blank_image, None, "CFS Features"),
    (target[0], "gray", "Example target image"),
    (background[0], "gray", "Example background image"),
]
for ax, (image, cmap, title) in zip(axes, panels):
    ax.imshow(image.reshape(image_shape), cmap=cmap)
    ax.set_title(title, fontsize=16)
    ax.set_xticks([])
    ax.set_yticks([])

# Render explicitly so the cell doesn't also echo the last call's return value.
plt.show()