In [1]:
from sklearn.manifold import TSNE
import plotly.express as px
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
import pl_bolts
import os

from torch.utils.data import Dataset, DataLoader, random_split
from lightly.data import LightlyDataset
from glob import glob
from PIL import Image
from sklearn.manifold import TSNE

from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform
from torchvision.transforms import Compose, ToTensor, Normalize
from lightly.transforms.mae_transform import MAETransform
from src.models.mae.model import MAE

import plotly.io as pio
pio.renderers.default = "iframe"

  warn(
  "lr_options": generate_power_seq(LEARNING_RATE_CIFAR, 11),
  contrastive_task: Union[FeatureMapContrastiveTask] = FeatureMapContrastiveTask("01, 02, 11"),
  self.nce_loss = AmdimNCELoss(tclip)
  warn_missing_pkg("gym")


In [2]:
# Registry of image transforms, keyed first by model name and then by
# dataset split ('train' / 'val' / 'test'). Every entry holds its own,
# freshly constructed transform instance.
transforms_dict = {}

# SimCLR: SimCLRTrainDataTransform for 'train', SimCLREvalDataTransform
# for the remaining splits.
transforms_dict['simclr'] = {
    'train': SimCLRTrainDataTransform(),
    **{split: SimCLREvalDataTransform() for split in ('val', 'test')},
}

# MAE: MAETransform for 'train'; ToTensor followed by Normalize(0.5, 0.5)
# for validation and test.
transforms_dict['mae'] = {
    'train': MAETransform(),
    **{split: Compose([ToTensor(), Normalize(0.5, 0.5)]) for split in ('val', 'test')},
}

# BYOL is registered with the same transform classes as the SimCLR entry.
transforms_dict['byol'] = {
    'train': SimCLRTrainDataTransform(),
    **{split: SimCLREvalDataTransform() for split in ('val', 'test')},
}

# Downstream linear evaluation: SimCLREvalDataTransform for every split.
transforms_dict['downstream_linear'] = {
    split: SimCLREvalDataTransform() for split in ('train', 'val', 'test')
}


def transform_factory(model_name, mode):
    """Return the transform registered for a model and dataset split.

    Args:
        model_name: Top-level key of ``transforms_dict`` (e.g. ``'simclr'``).
        mode: Split key inside that model's entry (e.g. ``'train'``).

    Returns:
        The object stored at ``transforms_dict[model_name][mode]``.

    Raises:
        NotImplementedError: If ``model_name`` or ``mode`` is not registered.
    """
    try:
        # Index instead of chaining ``.get()``: ``.get()`` yields None for a
        # missing model, so the second lookup would raise AttributeError (never
        # reaching the handler below), and a missing mode would silently
        # return None.
        return transforms_dict[model_name][mode]
    except KeyError:
        raise NotImplementedError(f'{model_name} {mode} transform not implemented') from None


In [3]:
class ICDARDataset(Dataset):
    """Image dataset described by a semicolon-separated label CSV.

    The CSV must provide a ``FILENAME`` column (image path relative to
    ``root_dir``) and a ``SCRIPT_TYPE`` column (returned as the label).
    """

    def __init__(self, csv_filepath, root_dir, transforms=None, convert_rgb=True, mask_generator=None):
        """Read the label table and remember the loading options.

        Args:
            csv_filepath: Path to the ``;``-separated label file.
            root_dir: Directory the ``FILENAME`` entries are joined onto.
            transforms: Optional callable applied to each loaded image.
            convert_rgb: If true, images are converted to RGB after opening.
            mask_generator: Stored on the instance; not used by this class.
        """
        self.root_dir = root_dir
        self.transforms = transforms
        self.data = pd.read_csv(csv_filepath, sep=';')
        self.convert_rgb = convert_rgb
        self.mask_generator = mask_generator

    def __len__(self):
        """Return the number of rows in the label table."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Row label in the label table; tensors are converted with
                ``tolist()`` first.

        Returns:
            ``(image, script_type)``, or ``None`` when the image file cannot
            be opened. A warning naming the file is emitted in that case, so
            a collate function that does not filter ``None`` fails with a
            traceable cause.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = os.path.join(self.root_dir, self.data.loc[idx, 'FILENAME'])

        try:
            handle = Image.open(img_path)
        except Exception as ex:
            # Best-effort behaviour is kept (return None), but the failure is
            # no longer silent.
            import warnings
            warnings.warn(f'Could not open image {img_path!r}: {ex!r}')
            return None

        # Image.open is lazy and keeps the file open; read the pixel data
        # inside the context manager so the handle is released right away
        # instead of lingering until garbage collection.
        with handle:
            if self.convert_rgb:
                image = handle.convert('RGB')
            else:
                handle.load()
                image = handle

        if self.transforms is not None:
            image = self.transforms(image)

        return image, self.data.loc[idx, 'SCRIPT_TYPE']


In [4]:
def data_factory(dataset_name, root_dir, label_filepath, transforms, mode, batch_size, collate_fn=None, num_cpus=None, seed=42):
    """Build the data loaders for one dataset and one mode.

    The dataset is split 70 % / 10 % / remainder into train / val / test
    subsets and wrapped in ``DataLoader`` objects.

    Args:
        dataset_name: ``'icdar'`` or ``'icdar_lightly'`` (case-insensitive).
        root_dir: Directory containing the images.
        label_filepath: Label CSV path (used by the ``'icdar'`` dataset only).
        transforms: Transform passed on to the dataset.
        mode: ``'train'`` (returns ``'train'`` and ``'val'`` loaders) or
            ``'test'`` (returns a ``'test'`` loader).
        batch_size: Batch size for every loader.
        collate_fn: Optional collate function for every loader.
        num_cpus: Number of loader workers. ``None`` uses ``os.cpu_count()``;
            ``0`` loads in the main process.
        seed: Seed for the train/val/test split. A fixed seed keeps the
            partition identical across calls, so a separate ``'test'`` call
            never receives samples that a ``'train'`` call trained on. Pass
            ``None`` for a split drawn from the global RNG.

    Returns:
        Dict mapping split name to ``DataLoader``.

    Raises:
        NotImplementedError: Unknown ``dataset_name``.
        KeyError: Unknown ``mode``.
    """
    name = dataset_name.lower()
    if name == 'icdar':
        dataset = ICDARDataset(label_filepath, root_dir, transforms=transforms, convert_rgb=True)
    elif name == 'icdar_lightly':
        dataset = LightlyDataset(input_dir=root_dir, transform=transforms, filenames=glob(root_dir + '/*.tif'))
    else:
        raise NotImplementedError(f'Dataset {dataset_name} is not implemented')

    total_count = len(dataset)
    train_count = int(0.7 * total_count)
    val_count = int(0.1 * total_count)
    test_count = total_count - train_count - val_count

    split_kwargs = {}
    if seed is not None:
        # A dedicated generator makes the partition reproducible without
        # touching the global RNG state.
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)

    train_dataset, val_dataset, test_dataset = random_split(
        dataset,
        (train_count, val_count, test_count),
        **split_kwargs
    )

    # ``is None`` rather than ``or`` so an explicit 0 is honoured;
    # ``os.cpu_count()`` may itself return None, which DataLoader rejects.
    if num_cpus is None:
        num_workers = os.cpu_count() or 0
    else:
        num_workers = num_cpus

    def _make_loader(subset, training):
        # Training batches are shuffled and incomplete batches dropped;
        # evaluation loaders keep every sample in order.
        return DataLoader(
            subset,
            batch_size=batch_size,
            shuffle=training,
            drop_last=training,
            pin_memory=True,
            num_workers=num_workers,
            collate_fn=collate_fn or None
        )

    # Exact comparison: ``mode in 'train'`` was a substring test that also
    # accepted '', 't', 'rain', ...
    if mode == 'train':
        return {
            'train': _make_loader(train_dataset, True),
            'val': _make_loader(val_dataset, False)
        }
    if mode == 'test':
        return {
            'test': _make_loader(test_dataset, False)
        }
    raise KeyError(f'Unknown mode: {mode}')


In [5]:
def plot_features(model, data_loader, num_feats, batch_size, num_samples, perplexity, is_3d=False):
    """Encode batches with ``model.encoder`` and embed the features with t-SNE.

    Args:
        model: Object exposing an ``encoder`` module; the last element of the
            encoder's output is used as the feature matrix of a batch.
        data_loader: Iterable yielding ``((x1, x2, x3), label)`` batches; only
            ``x1`` is encoded.
        num_feats: Number of feature columns produced per sample.
        batch_size: Kept for interface compatibility; progress is counted
            from the actual size of each encoded batch.
        num_samples: Stop once at least this many samples were encoded (the
            last batch is kept whole, so the total may overshoot). A falsy
            value means "use every batch".
        perplexity: t-SNE perplexity.
        is_3d: Embed into 3 dimensions instead of 2.

    Returns:
        ``(embedding, labels)``: the t-SNE coordinates and the concatenated
        labels of the encoded samples.
    """
    # The previous fallback used len(data_loader), i.e. the number of
    # *batches*, as a sample budget, which ended the loop far too early.
    sample_limit = num_samples if num_samples else float('inf')

    # Use the GPU when present, otherwise stay on the CPU instead of crashing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder = model.encoder
    encoder.eval()
    encoder.to(device)

    # Chunks are gathered in lists and concatenated once (np.append in a loop
    # copies the whole array every iteration). The empty float seeds keep the
    # result dtype identical to the previous incremental np.append.
    feat_chunks = [np.empty((0, num_feats))]
    label_chunks = [np.array([])]

    processed_samples = 0
    with torch.no_grad():
        for (x1, _x2, _x3), label in data_loader:
            if processed_samples >= sample_limit:
                break
            # NOTE(review): squeeze() also removes the batch axis when a batch
            # holds a single sample; confirm the loader never yields one.
            x1 = x1.squeeze().to(device)
            out = encoder(x1)[-1].cpu().numpy()
            feat_chunks.append(out)
            label_chunks.append(np.asarray(label))
            processed_samples += out.shape[0]

    feats = np.concatenate(feat_chunks, axis=0)
    labels = np.concatenate(label_chunks, axis=0)

    # NOTE(review): newer scikit-learn releases rename ``n_iter`` to
    # ``max_iter``; confirm against the installed version.
    tsne = TSNE(n_components=3 if is_3d else 2, perplexity=perplexity, n_iter=3000, init='pca')
    x_feats = tsne.fit_transform(feats)

    return x_feats, labels


In [6]:
# --- Configuration ----------------------------------------------------------
# Checkpoint file per self-supervised model.
checkpoints = {
    'simclr': '/home/woody/iwfa/iwfa028h/dev/faps/data/trained_models/SimCLR/lightning_logs/version_605090/checkpoints/epoch=370-step=5565.ckpt',
    'byol': '/home/woody/iwfa/iwfa028h/dev/faps/data/trained_models/BYOL/lightning_logs/version_616672/checkpoints/epoch=497-step=3984.ckpt'
}

MODEL_NAME = 'byol'
MODE = 'test'
BATCH_SIZE = 64
MAX_SAMPLES = 500
NUM_FEATS = 2048  # feature width handed to plot_features

DATA_ROOT = '/home/woody/iwfa/iwfa028h/dev/faps/data/ICDAR2017_CLaMM_Training'
LABEL_CSV = '/home/woody/iwfa/iwfa028h/dev/faps/data/ICDAR2017_CLaMM_Training/@ICDAR2017_CLaMM_Training.csv'

# --- Model ------------------------------------------------------------------
# Pick the loader that matches MODEL_NAME; previously BYOL was hard-coded, so
# switching MODEL_NAME to 'simclr' would have fed a SimCLR checkpoint to BYOL.
if MODEL_NAME == 'byol':
    ssm = pl_bolts.models.self_supervised.BYOL.load_from_checkpoint(checkpoints[MODEL_NAME]).online_network
elif MODEL_NAME == 'simclr':
    ssm = pl_bolts.models.self_supervised.SimCLR.load_from_checkpoint(checkpoints[MODEL_NAME])
else:
    raise NotImplementedError(f'No checkpoint loader for model {MODEL_NAME}')

# --- Data -------------------------------------------------------------------
# NOTE(review): the MODE ('test') transform is applied to the 'train' split
# requested below; confirm this combination is intentional.
transforms = transform_factory(MODEL_NAME, MODE)

data = data_factory(
    'icdar',
    DATA_ROOT,
    LABEL_CSV,
    transforms,
    'train',
    BATCH_SIZE,
    collate_fn=None,
    num_cpus=8
)

# --- t-SNE embeddings, one per perplexity -------------------------------------
perplexities = [20, 30, 40, 50]
is_3d = True

all_feats = []
all_labels = []

# NOTE(review): every call re-encodes batches from the shuffled train loader,
# so each perplexity may be embedded from a different sample subset.
for p in perplexities:
    print('Perplexity: ', p)
    feats, labels = plot_features(
        ssm,
        data.get('train'),
        NUM_FEATS,
        BATCH_SIZE,
        MAX_SAMPLES,
        p,
        is_3d
    )

    # Tag every embedded point with the perplexity that produced it.
    feats = np.hstack((feats, np.full((feats.shape[0], 1), p)))
    all_feats.append(feats)
    all_labels.append(labels)


dim_red_df = pd.DataFrame(np.concatenate(all_feats), columns=list(range(3 if is_3d else 2)) + ['Perplexity'])
dim_red_df['labels'] = pd.Categorical(np.concatenate(all_labels).tolist())

# --- Plot -------------------------------------------------------------------
if is_3d:
    # Coordinates of independent t-SNE runs are not comparable, so each
    # perplexity gets its own figure instead of being overlaid in one scene.
    figures = [
        px.scatter_3d(
            group,
            x=0, y=1, z=2,
            color='labels',
            symbol='labels',
            size_max=20,
            title=f'Perplexity: {perplexity_value:g}'
        )
        for perplexity_value, group in dim_red_df.groupby('Perplexity')
    ]
else:
    # Two facets per row; round up so one or an odd number of perplexities
    # still gets a non-zero, sufficient height.
    facet_rows = -(-len(perplexities) // 2)
    figures = [
        px.scatter(
            dim_red_df,
            x=0, y=1,
            facet_col='Perplexity',
            facet_col_wrap=2,
            color='labels',
            color_discrete_sequence=px.colors.qualitative.Light24,
            symbol='labels',
            size_max=12,
            width=1000,
            height=500 * facet_rows
        )
    ]

for fig in figures:
    # Let every facet scale its axes independently.
    fig.update_xaxes(matches=None)
    fig.update_yaxes(matches=None)
    fig.show()


















This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 4, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.



Perplexity:  20
Perplexity:  30



This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 4, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.



Perplexity:  40



This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 4, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.



Perplexity:  50



This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 4, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.

