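Install the latest local version of the `sas` package (from `sas-pip/`). Restart the kernel afterwards if `sas` was already imported in this session.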

In [1]:
# Install (or reinstall) the local `sas` package from the `sas-pip/` directory so
# the cells below import its current version. `%pip` targets this kernel's
# environment; restart the kernel afterwards if `sas` was already imported.
%pip install sas-pip/

Processing ./sas-pip
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: sas
  Building wheel for sas (setup.py) ... [?25ldone
[?25h  Created wheel for sas: filename=sas-1.0-py3-none-any.whl size=6289 sha256=6e8f8d3141702ae426b4a9635e99beaa3da3ecf8c32cedb7dcc76cad8522aca4
  Stored in directory: /home/sjoshi/.cache/pip/wheels/4e/07/53/a089817b38c15451794418a74eb8812ee557a2982d04e9d60a
Successfully built sas
Installing collected packages: sas
  Attempting uninstall: sas
    Found existing installation: sas 1.0
    Uninstalling sas-1.0:
      Successfully uninstalled sas-1.0
Successfully installed sas-1.0
Note: you may need to restart the kernel to use updated packages.


Load data

Load the CIFAR-10 training set (images converted to tensors) and choose the compute device.

In [2]:
import os

import torch
import torchvision
from torchvision import transforms

# CIFAR-10 is cached under ~/.cache. download=True fetches it on a fresh machine,
# so Restart & Run All works anywhere; it is not re-downloaded when already present.
root = os.path.expanduser("~/.cache")
cifar10 = torchvision.datasets.CIFAR10(
    root,
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

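Partition into approximate latent classes

Use `clip_approx` with 500 randomly sampled labeled examples to partition the CIFAR-10 training set into approximate latent classes.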

In [4]:
import random

from sas.approx_latent_classes import clip_approx
from sas.subset_dataset import SASSubsetDataset  # used in the "Determine subset" cell

NUM_LABELED_EXAMPLES = 500
LABEL_SEED = 0  # fixes which examples count as "labeled", making the partition reproducible

# A seeded local RNG makes the draw reproducible without touching global `random` state.
rand_labeled_examples_indices = random.Random(LABEL_SEED).sample(
    range(len(cifar10)), NUM_LABELED_EXAMPLES
)
# Read labels from `targets` directly: `cifar10[i][1]` would also load and
# ToTensor-transform each image just to throw it away.
rand_labeled_examples_labels = [cifar10.targets[i] for i in rand_labeled_examples_indices]

partition = clip_approx(
    img_trainset=cifar10,
    labeled_example_indices=rand_labeled_examples_indices,
    labeled_examples_labels=rand_labeled_examples_labels,
    num_classes=len(cifar10.classes),  # 10 for CIFAR-10
    device=device,
)

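Define the proxy model

`ProxyModel` wraps an encoder (`net`) and applies the critic's `project` method to the encoder's output. The checkpoints themselves are loaded in the next section.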

In [5]:
from torch import nn


class ProxyModel(nn.Module):
    """Wrap an encoder and a critic for use as the SAS ``proxy_model``.

    The forward pass is ``critic.project(net(x))``.

    Args:
        net: module applied first to the input batch.
        critic: object exposing a ``project`` method, applied to ``net``'s output.
    """

    def __init__(self, net, critic):
        super(ProxyModel, self).__init__()
        # Registered as submodules in this order: net, then critic.
        self.net, self.critic = net, critic

    def forward(self, x):
        """Return ``self.critic.project(self.net(x))``."""
        features = self.net(x)
        projection = self.critic.project(features)
        return projection

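Determine subset

Load the proxy encoder and critic checkpoints, then use `SASSubsetDataset` to keep 20% of the CIFAR-10 training examples.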

In [7]:
import torch

# Proxy checkpoints; the file names indicate a ResNet-18 trained on CIFAR-100.
PROXY_NET_PATH = "2023-12-0111:47:53.610549-cifar100-resnet18-19-net.pt"
PROXY_CRITIC_PATH = "2023-12-0111:47:53.610549-cifar100-resnet18-19-critic.pt"

# These files pickle whole nn.Module objects (they are called as modules by
# ProxyModel), so weights_only=False is required: PyTorch >= 2.6 defaults to True.
# Unpickling can execute arbitrary code, so only load checkpoints from a trusted source.
# map_location lets GPU-saved checkpoints load on whatever `device` is in use.
net = torch.load(PROXY_NET_PATH, map_location=device, weights_only=False)
critic = torch.load(PROXY_CRITIC_PATH, map_location=device, weights_only=False)
proxy_model = ProxyModel(net, critic)

subset_dataset = SASSubsetDataset(
    dataset=cifar10,
    subset_fraction=0.2,
    # The downstream task is CIFAR-10 classification (10 classes), matching the
    # partition above, not the 100 classes of the proxy's CIFAR-100 training data.
    num_downstream_classes=len(cifar10.classes),
    device=device,
    proxy_model=proxy_model,
    approx_latent_class_partition=partition,
    verbose=True,
)

Subset Selection:: 100%|██████████| 10/10 [00:08<00:00,  1.22it/s]

Subset Size: 10000
Discarded 40000 examples





Save subset to file

Write the selected subset indices to `cifar10-0.2-sas-indices.pkl`.

In [8]:
# Persist the selected subset. The file name records the dataset (CIFAR-10),
# the subset fraction (0.2) and the selection method (SAS).
SUBSET_INDICES_FILE = "cifar10-0.2-sas-indices.pkl"
subset_dataset.save_to_file(SUBSET_INDICES_FILE)