# MISS meets TRAK

MISS for non-linear models: linearizing the model as in TRAK

> The target function is assumed to be the raw logit, i.e., $\phi(x_{\text{test}}) = \theta^\top x_{\text{test}} + b$.

In [2]:
from TRAK.MISS_trak import MISS_TRAK
from IF.MISS_IF import MISS_IF
from model_train import MLP, SubsetSamper, MNISTModelOutput
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from TRAK.projector import CudaProjector, ProjectionType, BasicProjector
from TRAK.grad_calculator import count_parameters, grad_calculator, out_to_loss_grad_calculator
from tqdm import tqdm

# Pick the compute device: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cuda
Using device: cuda


In [7]:
# Experiment configuration shared by the cells below.
seed = 0           # seed tag embedded in the stored result file names
ensemble_size = 5  # number of model checkpoints in the ensemble
k = 5              # subset-size tag embedded in the stored result file names

In [3]:
# Load MNIST; every image is converted to a tensor and then normalized with
# mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
train_dataset = datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

# Restrict the run to the first 5000 training indices and the first 3 test
# indices, served one example per batch.
sampler_train = SubsetSamper(list(range(5000)))
train_loader = DataLoader(train_dataset, batch_size=1, sampler=sampler_train)

sampler_test = SubsetSamper(list(range(3)))
test_loader = DataLoader(test_dataset, batch_size=1, sampler=sampler_test)

from TRAK.MISS_trak import MISS_TRAK

# One checkpoint per ensemble member. The path is keyed on the notebook's
# `seed` setting (previously a hard-coded literal 0), so changing `seed`
# selects the matching checkpoints, consistent with the result-file paths
# that are also built from `seed`.
checkpoint_files = [
    f"./checkpoint/seed_{seed}_ensemble_{i}.pt" for i in range(ensemble_size)
]

# Build the MISS/TRAK object over the checkpoint ensemble and the
# train/test loaders.
trak = MISS_TRAK(
    model=MLP().to(device),
    model_checkpoints=checkpoint_files,
    train_loader=train_loader,
    test_loader=test_loader,
    model_output_class=MNISTModelOutput,
    device=device,
)


In [4]:
# Run both selection variants exposed by MISS_TRAK with a subset size of 3:
# the adaptive one first, then the non-adaptive one.
# NOTE(review): the size is hard-coded to 3 here, whereas the notebook-level
# `k` is 5 — confirm this difference is intentional.
adaptive_MIS = trak.adaptive_most_k(3)
MIS = trak.most_k(3)

  0%|          | 0/1 [00:00<?, ?it/s]

MLP(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (layers): ModuleList(
    (0): Linear(in_features=784, out_features=128, bias=True)
    (1): Linear(in_features=128, out_features=10, bias=True)
  )
  (relu): ReLU()
)
#Parameters: 101770


100%|██████████| 5000/5000 [02:04<00:00, 40.23it/s]
100%|██████████| 5000/5000 [00:16<00:00, 305.33it/s]
100%|██████████| 3/3 [00:00<00:00, 39.66it/s]
100%|██████████| 1/1 [02:20<00:00, 140.81s/it]


tensor([1088], device='cuda:0')
tensor([3717], device='cuda:0')
tensor([1460], device='cuda:0')
tensor([2184], device='cuda:0')
tensor([1412], device='cuda:0')
tensor([444], device='cuda:0')
tensor([1412], device='cuda:0')
tensor([370], device='cuda:0')
tensor([1918], device='cuda:0')


  0%|          | 0/1 [00:00<?, ?it/s]

MLP(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (layers): ModuleList(
    (0): Linear(in_features=784, out_features=128, bias=True)
    (1): Linear(in_features=128, out_features=10, bias=True)
  )
  (relu): ReLU()
)
#Parameters: 101770


100%|██████████| 5000/5000 [02:19<00:00, 35.72it/s]
100%|██████████| 5000/5000 [00:16<00:00, 309.77it/s]
100%|██████████| 3/3 [00:00<00:00, 40.77it/s]
100%|██████████| 1/1 [02:36<00:00, 156.24s/it]

torch.Size([3, 5000])
tensor([1088, 3718, 1461], device='cuda:0')
tensor([2184, 1412,  444], device='cuda:0')
tensor([1412,  370, 1920], device='cuda:0')





In [8]:

# Load the stored TRAK selections for the current (seed, k, ensemble_size)
# setting, in both the plain and the adaptive variant, and show the first
# entry of each.
# NOTE(review): a loop over k in [1, 2, 5, 10] was sketched here previously;
# presumably result files exist for those values too — verify before looping.
TRAK_result = torch.load(
    f"./TRAK/results/seed_{seed}_k_{k}_ensemble_{ensemble_size}.pt"
)
adaptive_TRAK_result = torch.load(
    f"./TRAK/results/seed_{seed}_k_{k}_ensemble_{ensemble_size}_adaptive.pt"
)
print(TRAK_result[0])
print(adaptive_TRAK_result[0])

tensor([2676, 1604, 1744, 2980,  991], dtype=torch.int32)
tensor([2980, 1604, 3692, 2676,  966], dtype=torch.int32)
