# Evaluating downstream model against pre-trained Self-Supervised models

In [15]:
from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform
from torch.utils.data import Dataset, DataLoader, random_split
from pl_bolts.models.self_supervised import SimCLR, BYOL
import torchvision.transforms as T
from lightly.models import utils
from lightly.models.modules import masked_autoencoder
from sklearn.manifold import TSNE
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
import plotly.express as px
import pytorch_lightning as pl
from time import time
from PIL import Image
import torchmetrics
import torchvision
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.models.mae.model import MAE
import os


In [16]:
def get_best_checkpoint(
    selected_model,
    choice=None,
    models_root='/home/woody/iwfa/iwfa028h/dev/faps/data/trained_models/',
):
    """Return the path of a checkpoint file stored for ``selected_model``.

    Checkpoints are looked up in
    ``<models_root>/<selected_model>/lightning_logs/version_<N>/checkpoints``.

    Args:
        selected_model: Name of the model sub-directory below ``models_root``.
        choice: Version to use (e.g. ``'614645'`` or ``0``). When ``None`` or
            an empty string, the highest-numbered version that actually
            contains at least one checkpoint file is used.
        models_root: Directory holding one sub-directory per trained model.

    Returns:
        Path of the first checkpoint file (in sorted name order) of the
        selected version.

    Raises:
        FileNotFoundError: If the logs directory does not exist, or if no
            checkpoint file can be found for the requested/latest version.
    """
    logs_dir = os.path.join(models_root, selected_model, 'lightning_logs')
    prefix = 'version_'

    def checkpoint_names(version):
        """Sorted checkpoint file names of ``version`` ([] if it has none)."""
        ckpt_dir = os.path.join(logs_dir, f'{prefix}{version}', 'checkpoints')
        if not os.path.isdir(ckpt_dir):
            return []
        return sorted(os.listdir(ckpt_dir))

    # `choice=0` is a legitimate Lightning version, so only None / '' mean
    # "pick the latest one automatically".
    if choice is None or choice == '':
        numbered_versions = sorted(
            (
                int(entry[len(prefix):])
                for entry in os.listdir(logs_dir)
                if entry.startswith(prefix) and entry[len(prefix):].isdecimal()
            ),
            reverse=True,
        )
        # A run that is still in progress (or crashed early) leaves a version
        # directory without checkpoints; skip those instead of failing.
        version = next((v for v in numbered_versions if checkpoint_names(v)), None)
        if version is None:
            raise FileNotFoundError(
                f'No version with checkpoints found in {logs_dir}'
            )
    else:
        version = choice

    version_dir = os.path.join(logs_dir, f'{prefix}{version}', 'checkpoints')
    names = checkpoint_names(version)
    if not names:
        raise FileNotFoundError(f'No checkpoint files found in {version_dir}')

    # Sorted order makes the pick deterministic (os.listdir order is arbitrary).
    best_checkpoint = os.path.join(version_dir, names[0])
    print('LATEST CHECKPOINT', best_checkpoint)

    return best_checkpoint

In [17]:
class ICDARDataset(Dataset):
    """Image dataset driven by a semicolon-separated label file.

    The CSV must provide a ``FILENAME`` column (image path relative to
    ``root_dir``) and a ``SCRIPT_TYPE`` column, which is returned as the
    label. Rows whose image file does not exist on disk are dropped when the
    dataset is built.
    """

    def __init__(self, csv_filepath, root_dir, transforms=None, convert_rgb=True):
        """Read the label file and keep only rows with an existing image.

        Args:
            csv_filepath: Path of the ``;``-separated label file.
            root_dir: Directory the ``FILENAME`` entries are relative to.
            transforms: Optional callable applied to each loaded image.
            convert_rgb: Convert every image to RGB when True.
        """
        self.transforms = transforms
        self.convert_rgb = convert_rgb

        df = pd.read_csv(csv_filepath, sep=';')
        df['img_path'] = root_dir + os.sep + df.FILENAME
        file_exists = df.img_path.map(os.path.exists).astype(bool)
        self.data = df.loc[file_exists].reset_index(drop=True)

    def __len__(self):
        """Return the number of rows that have an image on disk."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, SCRIPT_TYPE)`` for row ``idx``.

        Returns ``None`` when the image file cannot be opened, so a collate
        function that filters out ``None`` entries is needed if such files
        are expected.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = self.data.loc[idx, 'img_path']

        try:
            handle = Image.open(img_path)
        except Exception:
            # Best effort: unreadable files are reported as a missing sample.
            return None

        # Image.open is lazy and keeps the file open; convert()/copy() decode
        # the pixels so the file can be closed before returning.
        with handle:
            image = handle.convert('RGB') if self.convert_rgb else handle.copy()

        if self.transforms is not None:
            image = self.transforms(image)

        return image, self.data.loc[idx, 'SCRIPT_TYPE']


In [18]:
def load_best_checkpoint(model_name, choice=None, **model_kwargs):
    """Load the checkpoint selected by ``get_best_checkpoint`` for a model.

    Args:
        model_name: ``'SimCLR'``, ``'BYOL'``, ``'MAE'`` or one of the
            downstream classifier names. Any other name is loaded as a
            SimCLR checkpoint.
        choice: Optional version forwarded to ``get_best_checkpoint``.
        **model_kwargs: Extra keyword arguments forwarded to
            ``load_from_checkpoint`` (not used by the SimCLR fallback).

    Returns:
        The SimCLR ``encoder`` for SimCLR checkpoints (including the
        fallback); the full loaded module for BYOL, MAE and downstream
        classifiers.
    """
    checkpoint = get_best_checkpoint(model_name, choice)
    print(model_kwargs)

    if model_name == 'SimCLR':
        model = SimCLR.load_from_checkpoint(checkpoint, strict=False, **model_kwargs)
        return model.encoder

    if model_name == 'BYOL':
        # This branch used to fall through and return None.
        # NOTE(review): the whole module is returned; if only the encoder is
        # wanted, pick the right sub-module once the BYOL layout is confirmed.
        return BYOL.load_from_checkpoint(checkpoint, strict=False, **model_kwargs)

    if model_name == 'MAE':
        return MAE.load_from_checkpoint(checkpoint, strict=False, **model_kwargs)

    if model_name in ('SimCLRDownstream', 'MAEDownstream', 'BYOLDownstream', 'DownstreamClassifier'):
        return DownstreamClassifier.load_from_checkpoint(checkpoint, strict=False, **model_kwargs)

    # Unknown names are treated as SimCLR checkpoints. The previous code
    # returned an undefined name here; the encoder mirrors the SimCLR branch.
    model = SimCLR.load_from_checkpoint(checkpoint, strict=False)
    return model.encoder

In [19]:
def data_factory(dataset_name, root_dir, label_filepath, transforms, mode, batch_size, collate_fn=None, num_cpus=None):
    """Build the data loaders for one mode of a dataset.

    The dataset is split 70% / 10% / remainder into train / val / test with
    a fixed seed (42), so repeated calls produce the same partition.

    Args:
        dataset_name: Only ``'icdar'`` (case-insensitive) is supported.
        root_dir: Directory containing the images.
        label_filepath: Path of the label CSV.
        transforms: Zero-argument callable (e.g. a transform class) whose
            result is applied to every image.
        mode: ``'train'`` (returns ``'train'`` and ``'val'`` loaders) or
            ``'test'`` (returns a ``'test'`` loader).
        batch_size: Batch size used for every loader.
        collate_fn: Optional zero-argument factory; it is called once per
            loader to create that loader's collate function.
        num_cpus: Number of loader workers; ``None`` uses every CPU reported
            by ``os.cpu_count()``, ``0`` loads in the main process.

    Returns:
        Dict mapping split name to ``DataLoader``.

    Raises:
        NotImplementedError: If ``dataset_name`` is not supported.
        KeyError: If ``mode`` is neither ``'train'`` nor ``'test'``.
    """
    if dataset_name.lower() != 'icdar':
        raise NotImplementedError(f'Dataset {dataset_name} is not implemented')

    # Validate before the (slow) dataset scan. Exact comparison matters: the
    # former `mode in 'train'` also accepted substrings such as '' or 'rain'.
    if mode not in ('train', 'test'):
        raise KeyError(f'Unknown mode: {mode}')

    dataset = ICDARDataset(label_filepath, root_dir, transforms=transforms(), convert_rgb=True)

    total_count = len(dataset)
    train_count = int(0.7 * total_count)
    val_count = int(0.1 * total_count)
    test_count = total_count - train_count - val_count

    train_dataset, val_dataset, test_dataset = random_split(
        dataset,
        (train_count, val_count, test_count),
        generator=torch.Generator().manual_seed(42)
    )

    # `num_cpus or ...` would silently turn an explicit 0 into "all CPUs".
    num_workers = (os.cpu_count() or 0) if num_cpus is None else num_cpus

    def make_loader(subset, training):
        """Create a loader; only the training loader shuffles and drops the last batch."""
        return DataLoader(
            subset,
            batch_size=batch_size,
            shuffle=training,
            drop_last=training,
            pin_memory=True,
            num_workers=num_workers,
            collate_fn=collate_fn() if collate_fn else None
        )

    if mode == 'train':
        return {
            'train': make_loader(train_dataset, training=True),
            'val': make_loader(val_dataset, training=False),
        }

    return {'test': make_loader(test_dataset, training=False)}


In [20]:
# root_dir = os.path.join(
#     '/home/woody/iwfa/iwfa028h/dev/faps', 'data', 'ICDAR2017_CLaMM_Training'
# )

# train_dataloaders = data_factory(
#     dataset_name='icdar',
#     root_dir=root_dir,
#     label_filepath=os.path.join(root_dir, '@ICDAR2017_CLaMM_Training.csv'),
#     transforms=SimCLREvalDataTransform,
#     mode='train',
#     batch_size=256,
#     num_cpus=4
# )

# test_dataloaders = data_factory(
#     dataset_name='icdar',
#     root_dir=root_dir, 
#     label_filepath=os.path.join(root_dir, '@ICDAR2017_CLaMM_Training.csv'),
#     transforms=SimCLREvalDataTransform,
#     mode='test',
#     batch_size=256,
#     num_cpus=4
# )

In [21]:
def plot_features(model, data_loader, num_feats, batch_size, num_samples, perplexity=25):
    """Embed features of ``data_loader`` samples with 3-D t-SNE and plot them.

    Each batch must unpack as ``((x1, x2, _), label)``; only ``x1`` is fed to
    the model, and the last element of the model output (``out[-1]``) is used
    as the feature matrix of the batch.

    Args:
        model: Feature extractor; switched to eval mode and moved to CUDA.
        data_loader: Iterable of batches in the format described above.
        num_feats: Feature width, only used to shape the empty result when
            the loader yields no batch.
        batch_size: Unused; kept for backward compatibility. Progress is
            counted from the actual number of rows in each batch.
        num_samples: Approximate number of samples to embed (whole batches
            are processed). A falsy value embeds the entire loader.
        perplexity: t-SNE perplexity.
    """
    model.eval()
    model.cuda()

    feature_chunks = []
    label_chunks = []
    processed_samples = 0

    with torch.no_grad():
        for (x1, _x2, _), label in data_loader:
            # The old default compared a *batch* count (len(data_loader))
            # with a sample counter, which cut the run short.
            if num_samples and processed_samples >= num_samples:
                break
            # NOTE(review): squeeze() also removes the batch axis when a
            # batch holds a single sample - confirm the expected input shape.
            x1 = x1.squeeze().cuda()
            out = model(x1)[-1].detach().cpu().numpy()
            feature_chunks.append(out)
            label_chunks.append(np.asarray(label))
            processed_samples += out.shape[0]

    # One concatenation instead of np.append per batch (which re-copies the
    # whole array every iteration).
    if feature_chunks:
        feats = np.concatenate(feature_chunks, axis=0)
        labels = np.concatenate(label_chunks, axis=0)
    else:
        feats = np.empty((0, num_feats))
        labels = np.array([])

    tsne = TSNE(n_components=3, perplexity=perplexity, init='pca')
    x_feats = tsne.fit_transform(feats)

    dim_red_df = pd.DataFrame(x_feats)
    dim_red_df['labels'] = pd.Categorical(labels)
    fig = px.scatter_3d(dim_red_df, x=0, y=1, z=2, color='labels', size_max=5)
    fig.show()


In [22]:
def generate_from_embeddings(model, dataloader, device='cuda'):
    """Run ``model`` over ``dataloader`` and collect embeddings and labels.

    Each batch must unpack as ``((x1, x2, _), labels)``; only the first view
    ``x1`` is fed to the model, and the last element of the model output
    (``model(x1)[-1]``) is taken as the embedding of the batch.

    Args:
        model: Callable returning an indexable output of tensors. It is not
            moved to ``device`` here; the caller is responsible for that.
        dataloader: Iterable of batches in the format described above;
            ``labels`` must support ``.numpy()``.
        device: Device the inputs are moved to before the forward pass.
            Defaults to ``'cuda'``, the previously hard-coded value.

    Returns:
        Tuple ``(X, y)`` of numpy arrays, concatenated over all batches.

    Raises:
        ValueError: If the loader yields no batch (nothing to concatenate).
    """
    embedding_chunks = []
    label_chunks = []

    for (x1, _x2, _), labels in dataloader:
        with torch.no_grad():
            batch_embeddings = model(x1.to(device))[-1]
        embedding_chunks.append(batch_embeddings.detach().cpu().numpy())
        label_chunks.append(labels.numpy())

    return np.concatenate(embedding_chunks), np.concatenate(label_chunks)

In [23]:
def cluster_embeddings(train_loader=None, test_loader=None, n_neighbors=10):
    """Fit a k-NN classifier on SimCLR embeddings and print its test accuracy.

    Args:
        train_loader: Loader providing the training batches. Defaults to
            ``train_dataloaders.get('train')`` from the notebook globals.
        test_loader: Loader providing the test batches. Defaults to
            ``test_dataloaders.get('test')`` from the notebook globals.
        n_neighbors: Number of neighbours used by the classifier.

    Returns:
        The fitted ``KNeighborsClassifier``.

    Raises:
        NameError: If a loader is omitted and the corresponding notebook
            global has not been defined.
    """
    # Imported here because the notebook's import cell does not provide it.
    from sklearn.neighbors import KNeighborsClassifier

    if train_loader is None:
        train_loader = train_dataloaders.get('train')
    if test_loader is None:
        test_loader = test_dataloaders.get('test')

    simclr_encoder = load_best_checkpoint('SimCLR')
    simclr_encoder.eval()
    simclr_encoder.cuda()

    X_train, y_train = generate_from_embeddings(simclr_encoder, train_loader)

    knn = KNeighborsClassifier(n_neighbors=n_neighbors)
    print(knn)
    knn.fit(X_train, y_train)

    X_test, y_test = generate_from_embeddings(simclr_encoder, test_loader)
    y_pred = knn.predict(X_test)

    print(accuracy_score(y_test, y_pred))

    return knn
    
#     plot_features(simclr_encoder, dataloaders.get('train'), 2048, 256, 1000, 1000)

In [24]:
class DownstreamClassifier(pl.LightningModule):
    """Linear classification head on top of a frozen pre-trained encoder.

    The encoder is loaded through ``load_best_checkpoint`` and is only used
    for inference (its forward pass runs under ``torch.no_grad()``); the
    linear ``classifier`` layer is the only part that receives gradients.
    """

    def __init__(self, base_model_name='SimCLR', base_model_version=None, features=2048, num_classes=13, learning_rate=3e-4):
        """Load the base model and create the classification head.

        Args:
            base_model_name: Model name passed to ``load_best_checkpoint``.
            base_model_version: Checkpoint version of the base model
                (``None`` lets ``load_best_checkpoint`` decide).
            features: Input width of the linear head.
            num_classes: Number of target classes.
            learning_rate: Learning rate of the Adam optimizer.
        """
        print(base_model_name, features, num_classes)
        super().__init__()

        self.save_hyperparameters()

        self.learning_rate = learning_rate
        self.num_classes = num_classes

        s = time()
        self.base_model = load_best_checkpoint(base_model_name, choice=base_model_version, num_classes=num_classes)
        self.base_model.eval()
        # The encoder is never trained: mark its weights as frozen so they
        # are not reported as trainable parameters without gradients.
        for param in self.base_model.parameters():
            param.requires_grad_(False)
        # Guarded so the module can also be built on a CPU-only host.
        if torch.cuda.is_available():
            self.base_model.cuda()
        print('Base model load time: ', time() - s)

        self.classifier = nn.Linear(features, num_classes)

        self.loss_fn = torch.nn.CrossEntropyLoss()

        self.accuracy_fn = torchmetrics.Accuracy(task="multiclass", num_classes=self.num_classes).to(self.device)

    def train(self, mode=True):
        """Switch modes while keeping the frozen encoder in eval mode.

        ``nn.Module.train`` propagates the mode to every sub-module, so a
        plain ``.train()`` call on this module would undo the ``eval()``
        applied to the encoder in ``__init__``.
        """
        super().train(mode)
        base_model = getattr(self, 'base_model', None)
        if base_model is not None:
            base_model.eval()
        return self

    def forward(self, x):
        """Encode ``x`` without gradients and apply the linear head."""
        with torch.no_grad():
            # NOTE(review): the meaning of the second argument (0) is defined
            # by the base model's forward_encoder - confirm against it.
            x = self.base_model.forward_encoder(x, 0)
            # Encoders may return several outputs: take the last element of
            # a list, or the first element of a tuple.
            if isinstance(x, list):
                x = x[-1]
            elif isinstance(x, tuple):
                x = x[0]
        x = self.classifier(x)
        return x

    def _pooled_logits(self, batch):
        """Return ``(logits, label)`` for a ``((x1, x2, _), label)`` batch.

        The head output is reduced with a max over dim 1.
        NOTE(review): presumably dim 1 is the token axis of a
        (batch, tokens, classes) tensor - confirm against the encoder output.
        """
        (x1, _x2, _), label = batch
        y_hat, _ = torch.max(self(x1), dim=1)
        return y_hat, label

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy training loss."""
        y_hat, label = self._pooled_logits(batch)
        loss = self.loss_fn(y_hat, label)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the cross-entropy validation loss."""
        y_hat, label = self._pooled_logits(batch)
        loss = self.loss_fn(y_hat, label)
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Update and log the epoch-level test accuracy."""
        y_hat, label = self._pooled_logits(batch)
        acc = self.accuracy_fn(y_hat, label)
        self.log('test_acc', acc, on_step=False, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        """Return an Adam optimizer over the module parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=0.0008)


In [25]:
def train_downstream_model(model_name, downstream_model_name, model_version=None, feats=2048, max_epochs=10):
    """Train a ``DownstreamClassifier`` on the ICDAR2017 CLaMM (Large) data.

    Args:
        model_name: Name of the pre-trained base model.
        downstream_model_name: Name of the output directory below
            ``trained_models`` where logs and checkpoints are written.
        model_version: Checkpoint version of the base model.
        feats: Feature width passed to the classifier head.
        max_epochs: Maximum number of training epochs.

    Returns:
        The classifier instance that was passed to ``Trainer.fit``.
    """
    data_root = os.path.join('/home/woody/iwfa/iwfa028h/dev/faps', 'data', 'ICDAR2017_CLaMM_Large')
    labels_csv = os.path.join(data_root, '@ICDAR2017_CLaMM_Large.csv')

    loaders = data_factory(
        dataset_name='icdar',
        root_dir=data_root,
        label_filepath=labels_csv,
        transforms=SimCLREvalDataTransform,
        mode='train',
        batch_size=64,
        num_cpus=8,
    )

    # Logs and checkpoints go next to the data directory, under
    # trained_models/<downstream_model_name>.
    output_dir = os.path.abspath(
        os.path.join(data_root, '..', 'trained_models', downstream_model_name)
    )
    # Keep the checkpoint with the lowest validation loss.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(mode='min', monitor='val_loss')
    progress_callback = pl.callbacks.RichProgressBar()

    trainer_options = {
        'default_root_dir': output_dir,
        'accelerator': 'gpu',
        'devices': -1,
        'max_epochs': max_epochs,
        'enable_progress_bar': True,
        'precision': 16,
        'callbacks': [checkpoint_callback, progress_callback],
    }
    trainer = pl.Trainer(**trainer_options)

    # Fixed setup: 13 classes, learning rate 3e-3.
    classifier = DownstreamClassifier(model_name, model_version, feats, 13, 3e-3)

    trainer.fit(classifier, loaders.get('train'), loaders.get('val'))

    return classifier

In [None]:
# downstream_model = train_downstream_model('SimCLR', 'SimCLRDownstream', None, 2048, 100)
# Train the 'MAEDownstream' classifier on top of the 'MAE' base model
# (base checkpoint version '614645', 1024 features) for at most 100 epochs.
downstream_model = train_downstream_model('MAE', 'MAEDownstream', '614645', 1024, 100)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


MAE 1024 13


FileNotFoundError: [Errno 2] No such file or directory: '/home/woody/iwfa/iwfa028h/dev/faps/data/trained_models/MAE/lightning_logs/version_614805/checkpoints'

In [26]:
def test_downstream_model(base_model_name, model_name, num_features, base_model_version, model_version=None):
    """Evaluate a trained downstream classifier on the ICDAR test split.

    Args:
        base_model_name: Name of the pre-trained base model used by the
            classifier.
        model_name: Name of the downstream model whose checkpoint is loaded.
        num_features: Feature width of the classifier head.
        base_model_version: Checkpoint version of the base model.
        model_version: Checkpoint version of the downstream model; ``None``
            lets ``load_best_checkpoint`` pick one.

    Returns:
        The list of metric dictionaries returned by ``Trainer.test``.
    """
    root_dir = os.path.join(
        '/home/woody/iwfa/iwfa028h/dev/faps', 'data', 'ICDAR2017_CLaMM_Large'
    )
    dataloaders = data_factory(
        dataset_name='icdar',
        root_dir=root_dir,
        label_filepath=os.path.join(root_dir, '@ICDAR2017_CLaMM_Large.csv'),
        transforms=SimCLREvalDataTransform,
        mode='test',
        batch_size=64,
        num_cpus=8
    )

    trainer = pl.Trainer(
        accelerator='gpu',
        # Single device on purpose: with several devices a distributed
        # sampler may replicate samples to even out the shards, which skews
        # the reported test metric.
        devices=1,
        max_epochs=1,
        enable_progress_bar=True,
        precision=16,
        enable_checkpointing=False,
        callbacks=[pl.callbacks.RichProgressBar()]
    )

    downstream_classifier = load_best_checkpoint(
        model_name,
        choice=model_version,
        base_model_name=base_model_name,
        base_model_version=base_model_version,
        features=num_features,
        num_classes=13,
    )

    return trainer.test(downstream_classifier, dataloaders.get('test'))


In [27]:
# test_downstream_model('SimCLR', 'SimCLRDownstream')
# Evaluate the 'MAEDownstream' classifier (1024 features) built on the 'MAE'
# base model, using base checkpoint version '614645'.
test_downstream_model('MAE', 'MAEDownstream', 1024, '614645')