In [85]:
import collections

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import torch
import numpy as np

import curiosidade


%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Torch setup

In [57]:
loss_fn = torch.nn.CrossEntropyLoss()

rnd_state = np.random.RandomState(8)

X = torch.from_numpy(rnd_state.randn(150, 2)).float()
y = X.mean(axis=1).round().abs().long()

probing_dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(X, y),
    batch_size=10,
    shuffle=True,
)

num_cls = y.unique().numel()

In [39]:
X.shape, y.shape

(torch.Size([150, 3]), torch.Size([150]))

In [67]:
class BaseModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
        self.params = torch.nn.Sequential(
            collections.OrderedDict((
                ("lin1", torch.nn.Linear(2, 25, bias=True)),
                ("relu1", torch.nn.ReLU(inplace=True)),
                ("lin2", torch.nn.Linear(25, 25, bias=True)),
                ("relu2", torch.nn.ReLU(inplace=True)),
                ("lin3", torch.nn.Linear(25, num_cls)),
            )),
        )
    
    def forward(self, X):
        return self.params(X)
    
    
class ProbingModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.params = torch.nn.Sequential(
            torch.nn.Linear(25, num_cls, bias=True),
        )
    
    def forward(self, X):
        return self.params(X)

In [71]:
base_model = BaseModel()
optim_base = torch.optim.Adam(base_model.parameters(), lr=0.01)

beta = 0.9
mov_avg_loss = 0.0

base_model = base_model.train()
for _ in range(50):
    for X_batch, y_batch in probing_dataloader:
        optim_base.zero_grad()
        y_logits = base_model(X_batch)
        loss = loss_fn(input=y_logits, target=y_batch)
        loss.backward()
        optim_base.step()
        mov_avg_loss = beta * mov_avg_loss + (1.0 - beta) * float(loss.cpu().detach().item())
    
    print(f"{mov_avg_loss=:.3f}")

probing_model = ProbingModel()

mov_avg_loss=0.655
mov_avg_loss=0.695
mov_avg_loss=0.508
mov_avg_loss=0.365
mov_avg_loss=0.306
mov_avg_loss=0.222
mov_avg_loss=0.164
mov_avg_loss=0.130
mov_avg_loss=0.101
mov_avg_loss=0.101
mov_avg_loss=0.071
mov_avg_loss=0.063
mov_avg_loss=0.054
mov_avg_loss=0.051
mov_avg_loss=0.036
mov_avg_loss=0.036
mov_avg_loss=0.037
mov_avg_loss=0.031
mov_avg_loss=0.032
mov_avg_loss=0.032
mov_avg_loss=0.033
mov_avg_loss=0.020
mov_avg_loss=0.023
mov_avg_loss=0.023
mov_avg_loss=0.019
mov_avg_loss=0.026
mov_avg_loss=0.051
mov_avg_loss=0.069
mov_avg_loss=0.098
mov_avg_loss=0.038
mov_avg_loss=0.017
mov_avg_loss=0.013
mov_avg_loss=0.012
mov_avg_loss=0.008
mov_avg_loss=0.007
mov_avg_loss=0.007
mov_avg_loss=0.006
mov_avg_loss=0.007
mov_avg_loss=0.007
mov_avg_loss=0.005
mov_avg_loss=0.005
mov_avg_loss=0.004
mov_avg_loss=0.005
mov_avg_loss=0.005
mov_avg_loss=0.005
mov_avg_loss=0.005
mov_avg_loss=0.004
mov_avg_loss=0.004
mov_avg_loss=0.004
mov_avg_loss=0.004


In [5]:
import transformers

transformer_model = transformers.BertForTokenClassification.from_pretrained(
    "../../segmentador/pretrained_segmenter_model/2_6000_layer_model/"
)

## Curiosidade setup

In [79]:
import functools

task = curiosidade.TaskCustom(
    probing_dataloader=probing_dataloader,
    loss_fn=loss_fn,
    task_name="debug_task",
)

prober = curiosidade.core.attach_probers(
    base_model=base_model,
    probing_model=probing_model,
    task=task,
    optim_fn=functools.partial(torch.optim.Adam, lr=0.001),
    layers_to_attach=["params.relu1", "params.relu2"],
)

prober

ProberPack:
(a): Base model     : BaseModel(
  (params): Sequential(
    (lin1): Linear(in_features=2, out_features=25, bias=True)
    (relu1): ReLU(inplace=True)
    (lin2): Linear(in_features=25, out_features=25, bias=True)
    (relu2): ReLU(inplace=True)
    (lin3): Linear(in_features=25, out_features=3, bias=True)
  )
)
(b): Task           : debug_task
(c): Probed modules :
  (0): params.relu1
  (1): params.relu2

In [83]:
probing_res = prober.train(num_epochs=10)

In [87]:
probing_res

{0: defaultdict(list,
             {'params.relu1': [0.1525202989578247,
               0.12062942981719971,
               0.1292928010225296,
               0.18931007385253906,
               0.16209308803081512,
               0.15137222409248352,
               0.26500827074050903,
               0.1470971256494522,
               0.14856980741024017,
               0.23882833123207092,
               0.23954267799854279,
               0.236884206533432,
               0.34331849217414856,
               0.19775231182575226,
               0.20574763417243958],
              'params.relu2': [0.01922135427594185,
               0.018652724102139473,
               0.014263021759688854,
               0.0274181067943573,
               0.023412588983774185,
               0.01422050315886736,
               0.044416021555662155,
               0.008822912350296974,
               0.005839099176228046,
               0.03657924383878708,
               0.057273704558610916,
        

In [86]:
pd.DataFrame.from_dict(probing_res)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
params.relu1,"[0.1525202989578247, 0.12062942981719971, 0.12...","[0.15209797024726868, 0.20792396366596222, 0.3...","[0.3009136915206909, 0.1465167999267578, 0.176...","[0.14140690863132477, 0.17509359121322632, 0.2...","[0.19007666409015656, 0.25576549768447876, 0.1...","[0.27120277285575867, 0.1974187195301056, 0.18...","[0.1547449827194214, 0.21993950009346008, 0.16...","[0.3226400315761566, 0.14968356490135193, 0.27...","[0.1409863531589508, 0.23320817947387695, 0.18...","[0.24033460021018982, 0.1607469767332077, 0.22..."
params.relu2,"[0.01922135427594185, 0.018652724102139473, 0....","[0.010111590847373009, 0.028941934928297997, 0...","[0.04917234927415848, 0.026193806901574135, 0....","[0.013638722710311413, 0.02252885140478611, 0....","[0.015215988270938396, 0.036623138934373856, 0...","[0.04773978143930435, 0.01920115388929844, 0.0...","[0.011275229044258595, 0.05170627683401108, 0....","[0.06579801440238953, 0.014671983197331429, 0....","[0.008323609828948975, 0.0599777027964592, 0.0...","[0.028094002977013588, 0.018543189391493797, 0..."


In [None]:
fig, ax = plt.subplots(1, figsize=(10, 10))
sns.