In [1]:
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import time

import matplotlib.pyplot as plt
import mmcv
import numpy as np
import torch


from mmcv import Config, DictAction
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
from sklearn.manifold import TSNE

import sys
sys.path.append('..')
from mmselfsup.apis import set_random_seed
from mmselfsup.datasets import build_dataloader, build_dataset
from mmselfsup.models import build_algorithm
from mmselfsup.models.utils import ExtractProcess
from mmselfsup.utils import get_root_logger

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Absolute path to the t-SNE extraction config; adjust for your environment.
CONFIG_FILE = '/gpfs/data/geraslab/Nan/mmselfsup/configs/benchmarks/classification/nyubreast/us_tsne.py'
cfg = Config.fromfile(CONFIG_FILE)

# cuDNN autotuning is opt-in through the config's `cudnn_benchmark` flag.
if cfg.get('cudnn_benchmark', False):
    torch.backends.cudnn.benchmark = True

# Distributed (multi-process) execution is disabled in this notebook.
distributed = False

# Each run gets its own timestamped sub-directory under `cfg.work_dir`.
# The logger is created here, before any other step, so later cells can log into it.
# The trailing '/' on `tsne_work_dir` lets later cells build paths by f-string concatenation.
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
tsne_work_dir = osp.join(cfg.work_dir, f'tsne_{timestamp}/')
mmcv.mkdir_or_exist(osp.abspath(tsne_work_dir))

log_file = osp.join(tsne_work_dir, 'extract.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

In [3]:
# Build the dataset described by `cfg.data.extract`.
dataset = build_dataset(cfg.data.extract)

# Gather each sample's 'noisy_token_indicies' entry from the data source.
# NOTE(review): `tmp_infos` is not read by any of the cells shown here;
# it looks like a leftover from an experiment. Confirm before deleting.
tmp_infos = [
    dataset.data_source.data_infos[idx]['noisy_token_indicies']
    for idx in range(len(dataset))
]
logger.info(f'Apply t-SNE to visualize {len(dataset)} samples.')

# Keep the loader unshuffled so extracted features stay in dataset order.
# `breast=True` is a project-specific build_dataloader option (TODO: document it).
data_loader = build_dataloader(
    dataset,
    samples_per_gpu=cfg.data.samples_per_gpu,
    workers_per_gpu=cfg.data.workers_per_gpu,
    dist=distributed,
    breast=True,
    shuffle=False)


2022-04-26 18:05:57,605 - mmselfsup - INFO - Apply t-SNE to visualize 269 samples.


In [4]:
# for i, data in enumerate(data_loader):
#     print(data['img'].shape, data['us_counts'])

In [5]:
# Build the self-supervised algorithm from the config and run its weight
# initialisation. Per the config's init_cfg, this may load a pretrained checkpoint.
model = build_algorithm(cfg.model)
model.init_weights()

# Wrap the model for inference. The distributed wrapper needs the model on the
# current CUDA device first; the single-GPU wrapper is pinned to device 0.
if distributed:
    model = MMDistributedDataParallel(
        model.cuda(),
        device_ids=[torch.cuda.current_device()],
        broadcast_buffers=False)
else:
    model = MMDataParallel(model, device_ids=[0])

2022-04-26 18:05:57,898 - mmselfsup - INFO - initialize ResNet with init_cfg {'type': 'Pretrained', 'checkpoint': '/gpfs/data/geraslab/Nan/saves/selfsup/swav_breast/data_20220111_full/swav_resnet18_avgpool_coslr-100e_largebatch_skynet-gpu32/20220411_ffdm_latest.pth'}
2022-04-26 18:05:57,899 - mmcv - INFO - load model from: /gpfs/data/geraslab/Nan/saves/selfsup/swav_breast/data_20220111_full/swav_resnet18_avgpool_coslr-100e_largebatch_skynet-gpu32/20220411_ffdm_latest.pth
2022-04-26 18:05:57,901 - mmcv - INFO - load checkpoint from local path: /gpfs/data/geraslab/Nan/saves/selfsup/swav_breast/data_20220111_full/swav_resnet18_avgpool_coslr-100e_largebatch_skynet-gpu32/20220411_ffdm_latest.pth
2022-04-26 18:05:57,995 - mmselfsup - INFO - initialize ClsHead with init_cfg [{'type': 'Normal', 'std': 0.01, 'layer': 'Linear'}, {'type': 'Constant', 'val': 1, 'layer': ['_BatchNorm', 'GroupNorm']}]


In [6]:
# Run the model over every batch of `data_loader`. `features` is a mapping
# from feature name to the stacked array that is later fed to t-SNE.
extractor = ExtractProcess()
features = extractor.extract(
    model,
    data_loader,
    distributed=distributed)

# Ground-truth labels (used to colour the t-SNE scatter) are not computed here.
# Enable the line below only if this data source provides them:
# labels = dataset.data_source.get_gt_labels()

[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 68/68, 1.2 task/s, elapsed: 58s, ETA:     0s

In [None]:
# Save every extracted feature array as `<work_dir>/features/<dataset>_<key>.npy`.
# `dataset_cfg` is not defined anywhere in this notebook, so using it raised
# NameError. The file-name prefix now comes from the loaded config: a top-level
# `name` key if present, otherwise the extract dataset's `type`.
dataset_name = cfg.get('name', cfg.data.extract.get('type', 'dataset'))
feature_dir = f'{tsne_work_dir}features/'
mmcv.mkdir_or_exist(feature_dir)
logger.info(f'Save features to {feature_dir}')
for key, val in features.items():
    output_file = f'{feature_dir}{dataset_name}_{key}.npy'
    np.save(output_file, val)

# Build the t-SNE model.
# NOTE: newer scikit-learn releases rename `n_iter` to `max_iter` and eventually
# drop `n_iter`. Update this argument if the installed version rejects it.
tsne_model = TSNE(
    n_components=2,
    perplexity=20,
    early_exaggeration=12,
    learning_rate=200,
    n_iter=1000,
    n_iter_without_progress=300,
    init='random')

# Colour points by a `labels` array only if one exists in the notebook namespace.
# Previously an undefined `labels` raised NameError; now such runs are plotted
# without colouring. `cmap` is passed only together with `c`.
plot_labels = globals().get('labels')
scatter_kwargs = {} if plot_labels is None else {'c': plot_labels, 'cmap': 'tab20'}

# Run t-SNE on each feature set and save one scatter plot per key.
picture_dir = f'{tsne_work_dir}saved_pictures/'
mmcv.mkdir_or_exist(picture_dir)
logger.info('Running t-SNE......')
for key, val in features.items():
    result = tsne_model.fit_transform(val)
    # Min-max normalise each embedding axis to [0, 1]. A zero-range axis
    # (e.g. a single sample) would otherwise divide by zero, so use 1 there.
    res_min, res_max = result.min(0), result.max(0)
    res_range = np.where(res_max > res_min, res_max - res_min, 1.0)
    res_norm = (result - res_min) / res_range

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.scatter(
        res_norm[:, 0],
        res_norm[:, 1],
        alpha=1.0,
        s=15,
        **scatter_kwargs)
    ax.set_title(f't-SNE of {key} features')
    fig.savefig(f'{picture_dir}{key}.png')
    # Display inline, then release the figure so the loop does not accumulate them.
    plt.show()
    plt.close(fig)