# MISS meets TRAK

MISS for non-linear models: the model is linearized as in TRAK.

> The target function is assumed to be the raw logit, i.e., $\phi(x_{\text{test}}) = \theta^\top x_{\text{test}} + b$.

In [1]:
from TRAK.MISS_trak import MISS_TRAK
from IF.MISS_IF import MISS_IF
from model_train import MLP, SubsetSamper, MNISTModelOutput
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from TRAK.projector import CudaProjector, ProjectionType, BasicProjector
from TRAK.grad_calculator import count_parameters, grad_calculator, out_to_loss_grad_calculator
from tqdm import tqdm

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print("Using device:", device)

Using device: cuda
Using device: cuda


In [2]:
# Experiment configuration.
#   seed     -- appears in the checkpoint file names below
#   ensemble -- number of checkpoints to load
#   k        -- number of training samples selected per test sample
seed = 0
ensemble = 5
k = 5

# MNIST images are converted to tensors and normalized with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
train_dataset = datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

# Restrict the run to the first 5000 training images and the first test image.
sampler_train = SubsetSamper(list(range(5000)))
sampler_test = SubsetSamper([0])

# Batch size 1: one sample per step in both loaders.
train_loader = DataLoader(train_dataset, sampler=sampler_train, batch_size=1)
test_loader = DataLoader(test_dataset, sampler=sampler_test, batch_size=1)

# One checkpoint path per ensemble member.
checkpoint_files = [
    f"./checkpoint/seed_{seed}_ensemble_{member}.pt" for member in range(ensemble)
]

In [3]:
# Collect, for every checkpoint in the ensemble, the projected training
# gradients (Φ) and the output-to-loss gradients (Q). The results are kept in
# two lists so that the next cell can stack and average them.
model = MLP().to(device)
model_checkpoints = checkpoint_files
model_output_class = MNISTModelOutput

all_grads_p_list = []
Q_list = []

for checkpoint_id, checkpoint_file in enumerate(tqdm(model_checkpoints)):
    # map_location lets checkpoints saved on a GPU load when `device` is the
    # CPU fallback (and vice versa).
    model.load_state_dict(torch.load(checkpoint_file, map_location=device))
    model.eval()

    num_params = count_parameters(model)
    print(model)
    print("#Parameters:", num_params)

    parameters = list(model.parameters())
    normalize_factor = torch.sqrt(torch.tensor(num_params, dtype=torch.float32))

    # Random projection of the gradients down to 2048 dimensions. The same
    # seed is used for every checkpoint. CudaProjector (imported above) takes
    # the same arguments and can be substituted here.
    # The projector follows `device` instead of a hard-coded "cuda" so the CPU
    # fallback works -- assumes BasicProjector accepts a torch.device; TODO confirm.
    projector = BasicProjector(
        grad_dim=num_params,
        proj_dim=2048,
        seed=0,
        proj_type=ProjectionType.rademacher,
        device=device,
        max_batch_size=8,
    )

    # Φ: projected gradients of the model output over the training loader.
    all_grads_p = grad_calculator(
        data_loader=train_loader,
        model=model,
        parameters=parameters,
        func=model_output_class.model_output,
        normalize_factor=normalize_factor,
        device=device,
        projector=projector,
        checkpoint_id=checkpoint_id,
    )
    # Q: output-to-loss gradients over the training loader.
    out_to_loss_grads = out_to_loss_grad_calculator(
        data_loader=train_loader,
        model=model,
        func=model_output_class.get_out_to_loss_grad,
    )
    # ϕ: projected gradients of the model output over the test loader.
    # NOTE(review): this name is rebound on every iteration, so after the loop
    # it holds the test gradients of the LAST checkpoint only, while Φ and Q
    # are kept for all checkpoints. Confirm that this is intended.
    all_grads_test_p = grad_calculator(
        data_loader=test_loader,
        model=model,
        parameters=parameters,
        func=model_output_class.model_output,
        normalize_factor=normalize_factor,
        device=device,
        projector=projector,
        checkpoint_id=checkpoint_id,
    )

    # Keep the per-checkpoint results for later averaging.
    all_grads_p_list.append(all_grads_p)
    Q_list.append(out_to_loss_grads)

  0%|          | 0/5 [00:00<?, ?it/s]

MLP(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (layers): ModuleList(
    (0): Linear(in_features=784, out_features=128, bias=True)
    (1): Linear(in_features=128, out_features=10, bias=True)
  )
  (relu): ReLU()
)
#Parameters: 101770


100%|██████████| 5000/5000 [00:52<00:00, 95.35it/s]
100%|██████████| 5000/5000 [00:03<00:00, 1662.09it/s]
100%|██████████| 1/1 [00:00<00:00, 222.14it/s]
 20%|██        | 1/5 [00:55<03:42, 55.50s/it]

MLP(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (layers): ModuleList(
    (0): Linear(in_features=784, out_features=128, bias=True)
    (1): Linear(in_features=128, out_features=10, bias=True)
  )
  (relu): ReLU()
)
#Parameters: 101770


100%|██████████| 5000/5000 [00:51<00:00, 96.68it/s]
100%|██████████| 5000/5000 [00:02<00:00, 1679.67it/s]
100%|██████████| 1/1 [00:00<00:00, 262.21it/s]
 40%|████      | 2/5 [01:50<02:45, 55.05s/it]

MLP(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (layers): ModuleList(
    (0): Linear(in_features=784, out_features=128, bias=True)
    (1): Linear(in_features=128, out_features=10, bias=True)
  )
  (relu): ReLU()
)
#Parameters: 101770


100%|██████████| 5000/5000 [00:52<00:00, 94.91it/s]
100%|██████████| 5000/5000 [00:02<00:00, 1684.86it/s]
100%|██████████| 1/1 [00:00<00:00, 255.28it/s]
 60%|██████    | 3/5 [02:45<01:50, 55.34s/it]

MLP(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (layers): ModuleList(
    (0): Linear(in_features=784, out_features=128, bias=True)
    (1): Linear(in_features=128, out_features=10, bias=True)
  )
  (relu): ReLU()
)
#Parameters: 101770


100%|██████████| 5000/5000 [00:52<00:00, 95.01it/s]
100%|██████████| 5000/5000 [00:02<00:00, 1705.85it/s]
100%|██████████| 1/1 [00:00<00:00, 245.17it/s]
 80%|████████  | 4/5 [03:41<00:55, 55.44s/it]

MLP(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (layers): ModuleList(
    (0): Linear(in_features=784, out_features=128, bias=True)
    (1): Linear(in_features=128, out_features=10, bias=True)
  )
  (relu): ReLU()
)
#Parameters: 101770


100%|██████████| 5000/5000 [00:52<00:00, 94.83it/s]
100%|██████████| 5000/5000 [00:02<00:00, 1691.80it/s]
100%|██████████| 1/1 [00:00<00:00, 252.21it/s]
100%|██████████| 5/5 [04:37<00:00, 55.45s/it]


In [46]:
# Greedy selection of the k most influential training samples for each test
# sample: at every step, score all remaining training samples, record the one
# with the highest score, and remove it before the next step.

# Stack the per-checkpoint results along a new leading dimension. From the
# printed shapes, all_grads_p_tensor is (checkpoints, n_train, proj_dim) and
# Q_tensor is (checkpoints, n_train, n_train).
all_grads_p_tensor = torch.stack(all_grads_p_list)
Q_tensor = torch.stack(Q_list)

# MIS[j, i] is the original training index picked at step i for test sample j.
num_test_samples = all_grads_test_p.size(0)
MIS = torch.zeros(num_test_samples, k, dtype=torch.int32)

print(all_grads_p_tensor.shape)
for j in range(num_test_samples):
    # Every test sample starts from the full training set. Pruning happens on
    # these working copies so that one test sample's removals do not leak into
    # the next one (the stacked tensors themselves are left untouched).
    remaining_grads = all_grads_p_tensor
    remaining_Q = Q_tensor
    # index[p] maps position p in the pruned tensors back to the original
    # training index.
    index = list(range(all_grads_p_tensor.size(1)))

    for i in range(k):
        # Average over checkpoints; recomputed each step because the tensors
        # shrink by one training sample per step.
        avg_Q = torch.mean(remaining_Q, dim=0)
        avg_grads = torch.mean(remaining_grads, dim=0)

        # score = ϕ_j (ΦᵀΦ)⁻¹ Φᵀ Q. The Gram matrix ΦᵀΦ is symmetric, so
        # ϕ_j (ΦᵀΦ)⁻¹ equals solve(ΦᵀΦ, ϕ_j); solving avoids forming the
        # explicit inverse.
        gram = avg_grads.T @ avg_grads
        weights = torch.linalg.solve(gram, all_grads_test_p[j])
        score = weights @ avg_grads.T @ avg_Q

        # Position (within the remaining samples) of the highest score.
        i_max = int(torch.argmax(score.detach().flatten()))
        MIS[j, i] = index[i_max]

        # Remove the selected sample everywhere, always at position i_max:
        # its row and column in Q, its gradient row, and its index entry.
        # (Using the step counter i here would drop the wrong row/column of Q
        # and misalign Q with the gradient rows.)
        remaining_Q = torch.cat(
            [remaining_Q[:, :i_max], remaining_Q[:, i_max + 1:]], dim=1
        )
        remaining_Q = torch.cat(
            [remaining_Q[:, :, :i_max], remaining_Q[:, :, i_max + 1:]], dim=2
        )
        remaining_grads = torch.cat(
            [remaining_grads[:, :i_max, :], remaining_grads[:, i_max + 1:, :]],
            dim=1,
        )
        print(remaining_Q.shape)
        del index[i_max]

torch.Size([5, 5000, 2048])
torch.Size([5, 4999, 4999])
torch.Size([5, 4998, 4998])
torch.Size([5, 4997, 4997])
torch.Size([5, 4996, 4996])
torch.Size([5, 4995, 4995])
