In [1]:
%load_ext autoreload
%autoreload 2

import torch
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

from dinopl.probing import Prober, LinearAnalysis, KNNAnalysis, ToySet

from datasets import *
from models import *

import dinopl.utils as U
# Pick the compute device for the rest of the notebook.
# Use a GPU chosen by the project helper when CUDA is available; otherwise
# fall back to the CPU so the notebook still runs on machines without a GPU.
# NOTE(review): assumes U.pick_single_gpu() returns a value torch.device()
# accepts (the recorded output of this cell shows it resolving to a CUDA device).
if torch.cuda.is_available():
    device = torch.device(U.pick_single_gpu())
else:
    device = torch.device('cpu')
device

device(type='cuda', index=1)

#### Toy Data Set

In [None]:
# Normalise the toy samples with the mean/std attributes exposed by ToySet.
trfm = transforms.Normalize(ToySet.mean, ToySet.std)
toy = ToySet(transform=trfm, n_samples=10000)

# Iterate over the transformed dataset exactly once, so every sample stays
# paired with the label yielded in the same pass and the transform is not
# applied a second time just to collect the labels.
toy_pairs = list(toy)
data = torch.Tensor([(s[0], s[1]) for s, _ in toy_pairs]) # get data
lbls = torch.Tensor([l for _, l in toy_pairs]) # get labels

# Scatter the two sample coordinates, coloured by label.
plt.scatter(data[:, 0], data[:, 1], c=lbls, s=1)
# Per-coordinate mean and std of the transformed data (shown as cell output).
data.mean(dim=0), data.std(dim=0)

In [3]:
# Small train / validation splits of the toy dataset (no transform passed).
train_set = ToySet(train=True, n_samples=100)
valid_set = ToySet(train=False, n_samples=100)

# Wrap both splits in dataloaders; only the training split is shuffled.
train_dl = DataLoader(train_set, batch_size=10, shuffle=True)
valid_dl = DataLoader(valid_set, batch_size=10)

# Probe the raw inputs: the single encoder (registered under the empty name)
# is the identity, analysed with a linear probe and a k-NN probe.
prober = Prober(
    encoders={'': nn.Identity()},
    analyses=dict(
        lin=LinearAnalysis(n_epochs=100),
        knn=KNNAnalysis(k=20),
    ),
    train_dl=train_dl,
    valid_dl=valid_dl,
    n_classes=2,
)

# Train and validate the probes; the returned value is the cell output.
prober.probe()


Starting analyses ['lin', 'knn'] of ..

Training:   0%|          | 0/100 [00:00<?, ?it/s, loss=0.337]        

                                                                        

 .. took 00:00min => {'probe//lin': '0.98', 'probe//knn': '0.98'}


{'probe//lin': 0.9800000190734863, 'probe//knn': 0.9800000190734863}

#### CIFAR10

In [4]:
# Preprocessing pipeline: convert each image to RGB, turn it into a tensor,
# then normalise with the mean/std attributes of the CIFAR10 dataset class.
to_rgb = transforms.Lambda(lambda img: img.convert('RGB'))
trfm = transforms.Compose(
    [to_rgb, transforms.ToTensor(), transforms.Normalize(CIFAR10.mean, CIFAR10.std)]
)
del to_rgb  # helper name only; keep the notebook namespace unchanged

# Both splits share the same preprocessing.
train_set = CIFAR10(transform=trfm, train=True)
valid_set = CIFAR10(transform=trfm, train=False)

# Dataloaders with identical batching/worker settings; only training is shuffled.
train_dl = DataLoader(
    train_set,
    batch_size=512,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
valid_dl = DataLoader(
    valid_set,
    batch_size=512,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)


Files already downloaded and verified


Files already downloaded and verified


In [5]:
# Baseline encoder built from the dataset's pixel and channel counts.
# NOTE(review): presumably just flattens each image into a vector - confirm in `models`.
model = flatten(
    n_pixels=train_set.ds_pixels,
    n_channels=train_set.ds_channels,
)

# Linear and k-NN probes on top of the 'flatten' encoder.
prober = Prober(
    encoders=dict(flatten=model),
    analyses=dict(
        lin=LinearAnalysis(n_epochs=20),
        knn=KNNAnalysis(k=20),
    ),
    train_dl=train_dl,
    valid_dl=valid_dl,
    n_classes=train_set.ds_classes,
)

# Run the probes with the notebook-wide `device` passed as the only argument.
prober.probe(device)


Starting analyses ['lin', 'knn'] of flatten..

Loading embeddings:   0%|          | 0/98 [00:00<?, ?it/s]

                                                                                

 ..flatten took 00:14min => {'probe/flatten/lin': '0.374', 'probe/flatten/knn': '0.339'}


{'probe/flatten/lin': 0.374099999666214,
 'probe/flatten/knn': 0.3386000096797943}

In [6]:
# Baseline encoder built from the dataset's pixel and channel counts.
# NOTE(review): presumably just flattens each image into a vector - confirm in `models`.
model = flatten(
    n_pixels=train_set.ds_pixels,
    n_channels=train_set.ds_channels,
)

# Linear and k-NN probes on top of the 'flatten' encoder.
prober = Prober(
    encoders=dict(flatten=model),
    analyses=dict(
        lin=LinearAnalysis(n_epochs=20),
        knn=KNNAnalysis(k=20),
    ),
    train_dl=train_dl,
    valid_dl=valid_dl,
    n_classes=train_set.ds_classes,
)

# Run the probes with an explicit CPU device passed as the only argument.
cpu = torch.device('cpu')
prober.probe(cpu)


Starting analyses ['lin', 'knn'] of flatten..

                                                                                

 ..flatten took 01:04min => {'probe/flatten/lin': '0.389', 'probe/flatten/knn': '0.336'}




{'probe/flatten/lin': 0.3894999921321869,
 'probe/flatten/knn': 0.3361000120639801}

In [7]:
# Seed the global RNG before building the encoder so that this cell gives
# repeatable probe accuracies: vgg11() is constructed here without loading any
# weights, so its parameters (and the shuffled training batches) would
# otherwise differ on every execution.
# NOTE(review): GPU kernels may still introduce small run-to-run differences.
torch.manual_seed(0)
model = vgg11().to(device)

# Linear and k-NN probes on top of the 'vgg11' encoder.
prober = Prober(encoders = {'vgg11':model},
                analyses = {'lin': LinearAnalysis(n_epochs=20),
                            'knn': KNNAnalysis(k=20)},
                train_dl=train_dl,
                valid_dl=valid_dl,
                n_classes=train_set.ds_classes)

# Run the probes with the notebook-wide `device` passed as the only argument.
prober.probe(device)


Starting analyses ['lin', 'knn'] of vgg11..

Loading embeddings:   0%|          | 0/98 [00:00<?, ?it/s]

                                                                                

 ..vgg11 took 00:10min => {'probe/vgg11/lin': '0.369', 'probe/vgg11/knn': '0.449'}


{'probe/vgg11/lin': 0.36880001425743103, 'probe/vgg11/knn': 0.4489000141620636}

In [8]:
# Seed the global RNG before building the encoder so that this cell gives
# repeatable probe accuracies: vgg11() is constructed here without loading any
# weights, so its parameters (and the shuffled training batches) would
# otherwise differ on every execution.
# NOTE(review): GPU kernels may still introduce small run-to-run differences.
torch.manual_seed(0)
model = vgg11().to(device)

# Linear and k-NN probes on top of the 'vgg11' encoder.
prober = Prober(encoders = {'vgg11':model},
                analyses = {'lin': LinearAnalysis(n_epochs=20),
                            'knn': KNNAnalysis(k=20)},
                train_dl=train_dl,
                valid_dl=valid_dl,
                n_classes=train_set.ds_classes)

# Encode images on `device` (the GPU when one was selected), but analyse the
# resulting embeddings on the CPU.
prober.probe(device_enc=device, device_emb=torch.device('cpu'))


Starting analyses ['lin', 'knn'] of vgg11..

                                                                                

 ..vgg11 took 00:16min => {'probe/vgg11/lin': '0.367', 'probe/vgg11/knn': '0.433'}




{'probe/vgg11/lin': 0.3668999969959259, 'probe/vgg11/knn': 0.43309998512268066}