In [1]:
from data import sample_data
import torch

import pytorch_lightning as pl
from models import pretrained
import torchmetrics
from data.core import split_dataset
from torch.utils.data import DataLoader, Dataset
from scipy.stats import binom_test
import numpy as np

In [2]:
p_all = sample_data.cifar10(split='test')
q_all = sample_data.cifar10_1()

In [18]:
class DomainClassifier(pl.LightningModule):
    def __init__(self, lr=1e-3):
        super().__init__()
        self.model = pretrained.resnet18_trained_on_cifar10(domain_classifier=True).model
        self.lr = lr
        self.loss = torch.nn.CrossEntropyLoss()
        self.accuracy = torchmetrics.Accuracy()
        self.save_hyperparameters()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        self.log('val_acc', self.accuracy(y_hat, y))
        return loss

    def reset(self):
        self.accuracy.reset()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


class DomainClassifierDataset(Dataset):
    def __init__(self, p, q):
        self.p = p
        self.q = q
        self.p_labels = torch.zeros(len(p)).long()
        self.q_labels = torch.ones(len(q)).long()
        self.x = self.p + self.q
        self.y = torch.cat([self.p_labels, self.q_labels])

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx][0], self.y[idx]

In [None]:
from tqdm import tqdm
import logging

logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)

N = 10

res = []
for seed in tqdm(range(100)):
    p, _ = split_dataset(p_all, N, seed)
    p1, p2 = split_dataset(p, N // 2, seed)
    q, _ = split_dataset(q_all, N, seed)
    q1, q2 = split_dataset(q, N // 2, seed)

    d1 = DomainClassifierDataset(p1, q1)
    d2 = DomainClassifierDataset(p2, q2)

    tr = pl.Trainer(max_epochs=10, gpus=1, auto_select_gpus=True, enable_checkpointing=False,
                    enable_model_summary=False,
                    logger=False)
    model = DomainClassifier(lr=1e-3)
    tr.fit(model, DataLoader(d1, batch_size=N, shuffle=True, num_workers=4))
    n = int(N * tr.test(model, DataLoader(d2, batch_size=N, shuffle=False, num_workers=4), verbose=False)[0]['val_acc'])
    test = binom_test(n - 5, N, 0.5) <= 0.05
    res.append(test)

In [32]:
N = 10
path = f'results/ctst_cam/tests_{N=}.npy'
res = np.load(path) <= 0.05
res = res.mean(), res.std() / np.sqrt(len(res))
print(f'${res[0]:.2f} \pm {res[1]:.2f}$'.replace('0.', '.'))

$.11 \pm .03$


In [20]:
res = np.load(f'results/ctst_cam/tests_{N=}.npy') < 0.05

In [21]:
res.mean(), res.std() / np.sqrt(len(res))

(0.59, 0.04918333050943175)

In [31]:
np.load(f'results/ctst_cam/tests_N=20.npy') == np.load(f'results/ctst_cam/tests_N=50.npy')

array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False])

In [48]:
n = 98
f'{np.std([True] * n + [False] * (100 - n)) / np.sqrt(100):.2f}'

'0.01'