In [11]:
import sys
sys.path.insert(0,'../')
from libs.ssl_task import Classification
from libs.ssl_data import SSLHBNDataModule
from torchmetrics.functional import f1_score, accuracy
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [29]:
import wandb

# W&B run handle; in this notebook it is only used to fetch model artifacts
# from the sccn/hbn-regression project via `run.use_artifact`.
run = wandb.init(job_type="training", project="hbn-regression")

[34m[1mwandb[0m: Currently logged in as: [33mdt-young112[0m ([33msccn[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [30]:
import os

# W&B run id whose checkpoint artifact is evaluated below, and the artifact version.
# NOTE: `model` is rebound to a LightningModule in later cells.
model = "FGVsQMup"
version = 0
artifact_name = f'model-{model}:v{version}'
# One local directory for both branches: the cache check and the download target
# are the same path, so `artifact_dir` always points at the directory holding
# the checkpoint. Passing `root` explicitly keeps the download from landing in
# wandb's (configurable) default location instead of the path checked here.
artifact_dir = f'artifacts/{artifact_name}'
if os.path.exists(artifact_dir):
    print(f"Artifact directory {artifact_dir} already exists. Skipping download.")
else:
    artifact = run.use_artifact(f'sccn/hbn-regression/{artifact_name}', type='model')
    artifact_dir = artifact.download(root=artifact_dir)

[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [52]:
import yaml
import lightning as L

# Start from the training-run configuration, then override it for offline evaluation.
with open('../runs/config_Classification.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Data: sex-classification task (labels F -> 1, M -> 0), local cache, 2 loader workers.
ssl_task = Classification()
config['data'].update(
    ssl_task=ssl_task,
    cache_dir="data",
    num_workers=2,
    mapping={'F': 1, 'M': 0},
)

# Model: architecture overrides. These kwargs are also passed to
# `load_from_checkpoint` later, so they presumably must match the checkpoint.
config['model']['init_args'].update(
    emb_size=100,
    encoder_emb_size=100,
    encoder_path="braindecode.models.Deep4Net",
    window_norm='channel_wise',
)
config['model']['init_args']['encoder_kwargs']['n_chans'] = 128

# Evaluation only: no checkpoint callbacks and no logger.
config['trainer'].update(callbacks=None, logger=None)
trainer = L.Trainer(**config['trainer'])

mode = 'validate'
print('Loading data module...')
litDataModule = SSLHBNDataModule(**config['data'])
litDataModule.setup(stage=mode)
val_dataloader = litDataModule.val_dataloader()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Loading data module...
Number of subjects in balanced dataset: 92
Gender distribution in balanced dataset: (array(['F', 'M'], dtype='<U1'), array([46, 46]))


In [48]:
import torch


def rankme(embeddings: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    """Effective rank (RankMe) of a 2D embedding matrix.

    Computes exp(H(p)), where p is the vector of singular values of
    ``embeddings`` normalised to sum to 1 (plus ``eps``) and H is Shannon
    entropy. The result is roughly 1 for fully collapsed embeddings and up to
    K when all K directions carry equal energy.

    Args:
        embeddings: tensor of shape (N, K) with N >= K.
        eps: small constant added to every normalised singular value so that
            log(0) is never evaluated. Defaults to 1e-7.

    Returns:
        0-d tensor holding the effective rank.

    Raises:
        ValueError: if ``embeddings`` is not 2D or has fewer rows than columns.
    """
    embs = embeddings
    # Reject anything that is not exactly 2D (including 1D/0D input) up front,
    # instead of failing later with an IndexError on embs.shape[1].
    if embs.ndim != 2:
        raise ValueError('Expect 2D embeddings of shape (N, K)')
    print('RankMe embs shape', embs.shape)
    if embs.shape[0] < embs.shape[1]:
        raise ValueError(f'Expect N >= K but received ({embs.shape})')
    # Only the singular values are needed; svdvals avoids computing the (N, N)
    # U factor that a default (full_matrices=True) svd would materialise.
    S = torch.linalg.svdvals(embs)
    p = S / torch.linalg.norm(S, ord=1) + eps
    return torch.exp(-torch.sum(p * torch.log(p)))

In [59]:
def normalize_data(x: torch.Tensor) -> torch.Tensor:
    """Robustly standardise ``x`` along its last axis: (x - median) / IQR.

    The median and the interquartile range (75th minus 25th percentile) are
    computed independently for every slice along the last dimension; all
    leading dimensions (e.g. batch) are preserved.

    Slices whose IQR is zero (a constant signal) are only centred, not scaled,
    so the result never contains inf/NaN introduced by this function.

    Returns a new tensor; ``x`` itself is not modified.
    """
    center = x.median(dim=-1, keepdim=True).values
    iqr = x.quantile(0.75, dim=-1, keepdim=True) - x.quantile(0.25, dim=-1, keepdim=True)
    # Dividing by a zero IQR would give inf (or NaN for 0/0) and poison every
    # downstream activation; scale such slices by 1 instead.
    iqr = iqr.masked_fill(iqr == 0, 1.0)
    return (x - center) / iqr

In [69]:
import copy
# Evaluation only: no checkpoint callbacks and no logger.
config['trainer']['callbacks'] = None
config['trainer']['logger'] = None

trainer = L.Trainer(**config['trainer'])
# Baseline: a freshly initialised (untrained) model, evaluated the same way as
# the checkpoint in the next cell.
model = Classification.ClassificationLit(**config['model']['init_args'])
scores = trainer.validate(model=model, datamodule=litDataModule)

embeddings = []
preds = []
labels = []
model = model.to('cpu')
model.eval()
# Temporarily detach the classifier head so `model.encoder(X)` returns the
# features that feed it. The head is re-attached in `finally`, so `model`
# remains a complete model after this cell even if the loop raises.
# NOTE(review): re-assignment re-registers `final_layer` as the encoder's last
# child; this assumes it was the last child originally (the features/logits
# split below relies on the same assumption).
final_layer = model.encoder.final_layer
del model.encoder.final_layer
try:
    with torch.no_grad():
        for batch in val_dataloader:
            X, y_batch = batch[0], batch[1].to(torch.long)
            X = normalize_data(model.remove_chan(X))
            Z = model.encoder(X)
            # Flatten trailing dims per sample; unlike .squeeze() this keeps
            # the batch axis when a batch contains a single window.
            embeddings.append(Z.reshape(Z.shape[0], -1).cpu())
            preds.append(final_layer(Z).argmax(dim=1).cpu())
            labels.append(y_batch.cpu())
finally:
    model.encoder.final_layer = final_layer

embeddings = torch.cat(embeddings, dim=0)
preds = torch.cat(preds, dim=0)
Y = torch.cat(labels, dim=0)
print('accuracy', accuracy(preds, Y, task='binary', num_classes=2))
print('RankMe score', rankme(embeddings))

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Train releases: ['ds005506', 'ds005507', 'ds005508', 'ds005509', 'ds005511', 'ds005512', 'ds005514', 'ds005515', 'ds005516']
Validation release: ds005505
Test release: ds005510


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Number of subjects in balanced dataset: 92
Gender distribution in balanced dataset: (array(['F', 'M'], dtype='<U1'), array([46, 46]))


Validation: |          | 0/? [00:00<?, ?it/s]

accuracy tensor(0.4939)
RankMe embs shape torch.Size([18787, 200])
RankMe score tensor(129.6006)


In [70]:
import copy
version = 0
artifact_path = f'artifacts/model-FGVsQMup:v{version}/model.ckpt'
print(f'Loading model from {artifact_path}...')
# Evaluation only: no checkpoint callbacks and no logger.
config['trainer']['callbacks'] = None
config['trainer']['logger'] = None

trainer = L.Trainer(**config['trainer'])
model = Classification.ClassificationLit(**config['model']['init_args'])
# Lightning's own validation loop, with weights restored from `artifact_path`.
scores = trainer.validate(model=model, ckpt_path=artifact_path, datamodule=litDataModule)
# Separately loaded checkpointed model used for the manual feature extraction
# below, built with the same architecture overrides as `model`.
model_best = Classification.ClassificationLit.load_from_checkpoint(artifact_path, **config['model']['init_args'])
embeddings_best = []
preds = []
labels = []
model_best = model_best.to('cpu')
model_best.eval()
# Temporarily detach the classifier head so `model_best.encoder(X)` returns the
# features that feed it. The head is re-attached in `finally`, so `model_best`
# remains a complete model after this cell even if the loop raises.
# NOTE(review): re-assignment re-registers `final_layer` as the encoder's last
# child; this assumes it was the last child originally (the features/logits
# split below relies on the same assumption).
final_layer = model_best.encoder.final_layer
del model_best.encoder.final_layer
try:
    with torch.no_grad():
        for batch in val_dataloader:
            X, y_batch = batch[0], batch[1].to(torch.long)
            X = normalize_data(model_best.remove_chan(X))
            Z = model_best.encoder(X)
            # Flatten trailing dims per sample; unlike .squeeze() this keeps
            # the batch axis when a batch contains a single window.
            embeddings_best.append(Z.reshape(Z.shape[0], -1).cpu())
            preds.append(final_layer(Z).argmax(dim=1).cpu())
            labels.append(y_batch.cpu())
finally:
    model_best.encoder.final_layer = final_layer

embeddings_best = torch.cat(embeddings_best, dim=0)
preds = torch.cat(preds, dim=0)
print('preds shape:', preds.shape)
Y = torch.cat(labels, dim=0)
print('accuracy', accuracy(preds, Y, task='binary', num_classes=2))
print('RankMe score', rankme(embeddings_best))

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Loading model from artifacts/model-FGVsQMup:v0/model.ckpt...
Train releases: ['ds005506', 'ds005507', 'ds005508', 'ds005509', 'ds005511', 'ds005512', 'ds005514', 'ds005515', 'ds005516']
Validation release: ds005505
Test release: ds005510


Restoring states from the checkpoint path at artifacts/model-FGVsQMup:v0/model.ckpt
/home/dung/eeg-ssl/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:277: Be aware that when using `ckpt_path`, callbacks used to create the checkpoint need to be provided during `Trainer` instantiation. Please add the following callbacks: ["ModelCheckpoint{'monitor': 'val_Classifier/accuracy', 'mode': 'max', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}"].
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at artifacts/model-FGVsQMup:v0/model.ckpt


Number of subjects in balanced dataset: 92
Gender distribution in balanced dataset: (array(['F', 'M'], dtype='<U1'), array([46, 46]))


Validation: |          | 0/? [00:00<?, ?it/s]

preds shape: torch.Size([18787])
accuracy tensor(0.8585)
RankMe embs shape torch.Size([18787, 200])
RankMe score tensor(12.1231)
